mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 20:18:30 +08:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7faca5512a | ||
|
|
ad84272084 | ||
|
|
09e0f594ff | ||
|
|
dd2fbf4424 | ||
|
|
99b12a49c6 | ||
|
|
ea35efe440 | ||
|
|
bf09e740e9 | ||
|
|
60c77cec56 | ||
|
|
0e4a1dddb5 | ||
|
|
1cf18b6e12 | ||
|
|
f9a8be898a | ||
|
|
1521ce5a96 | ||
|
|
f2e62dd197 | ||
|
|
d378630b38 | ||
|
|
d9e6346911 | ||
|
|
238788e0e9 | ||
|
|
68ff828505 |
23
.github/workflows/docker-publish.yml
vendored
23
.github/workflows/docker-publish.yml
vendored
@@ -146,10 +146,33 @@ jobs:
|
|||||||
type=semver,pattern={{major}}.{{minor}}
|
type=semver,pattern={{major}}.{{minor}}
|
||||||
type=sha,prefix=
|
type=sha,prefix=
|
||||||
|
|
||||||
|
- name: Extract version from tag
|
||||||
|
id: version
|
||||||
|
run: |
|
||||||
|
# 从 tag 提取版本号,如 v0.2.5 -> 0.2.5
|
||||||
|
VERSION="${GITHUB_REF#refs/tags/v}"
|
||||||
|
if [ "$VERSION" = "$GITHUB_REF" ]; then
|
||||||
|
# 不是 tag 触发,使用 git describe
|
||||||
|
VERSION=$(git describe --tags --always | sed 's/^v//')
|
||||||
|
fi
|
||||||
|
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||||
|
echo "Extracted version: $VERSION"
|
||||||
|
|
||||||
- name: Update Dockerfile.app to use registry base image
|
- name: Update Dockerfile.app to use registry base image
|
||||||
run: |
|
run: |
|
||||||
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
|
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
|
||||||
|
|
||||||
|
- name: Generate version file
|
||||||
|
run: |
|
||||||
|
# 生成 _version.py 文件
|
||||||
|
cat > src/_version.py << EOF
|
||||||
|
# Auto-generated by CI
|
||||||
|
__version__ = '${{ steps.version.outputs.version }}'
|
||||||
|
__version_tuple__ = tuple(int(x) for x in '${{ steps.version.outputs.version }}'.split('.') if x.isdigit())
|
||||||
|
version = __version__
|
||||||
|
version_tuple = __version_tuple__
|
||||||
|
EOF
|
||||||
|
|
||||||
- name: Build and push app image
|
- name: Build and push app image
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -224,3 +224,6 @@ extracted_*.ts
|
|||||||
.deps-hash
|
.deps-hash
|
||||||
.code-hash
|
.code-hash
|
||||||
.migration-hash
|
.migration-hash
|
||||||
|
|
||||||
|
# Version file (auto-generated by hatch-vcs)
|
||||||
|
src/_version.py
|
||||||
|
|||||||
@@ -147,6 +147,10 @@ RUN printf '%s\n' \
|
|||||||
# 创建目录
|
# 创建目录
|
||||||
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
||||||
|
|
||||||
|
# 入口脚本(启动前执行迁移)
|
||||||
|
COPY entrypoint.sh /entrypoint.sh
|
||||||
|
RUN chmod +x /entrypoint.sh
|
||||||
|
|
||||||
# 环境变量
|
# 环境变量
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
PYTHONDONTWRITEBYTECODE=1 \
|
PYTHONDONTWRITEBYTECODE=1 \
|
||||||
@@ -161,4 +165,5 @@ EXPOSE 80
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -f http://localhost/health || exit 1
|
CMD curl -f http://localhost/health || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
||||||
|
|||||||
@@ -139,6 +139,10 @@ RUN printf '%s\n' \
|
|||||||
# 创建目录
|
# 创建目录
|
||||||
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
||||||
|
|
||||||
|
# 入口脚本(启动前执行迁移)
|
||||||
|
COPY entrypoint.sh /entrypoint.sh
|
||||||
|
RUN chmod +x /entrypoint.sh
|
||||||
|
|
||||||
# 环境变量
|
# 环境变量
|
||||||
ENV PYTHONUNBUFFERED=1 \
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
PYTHONDONTWRITEBYTECODE=1 \
|
PYTHONDONTWRITEBYTECODE=1 \
|
||||||
@@ -152,4 +156,5 @@ EXPOSE 80
|
|||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
CMD curl -f http://localhost/health || exit 1
|
CMD curl -f http://localhost/health || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -57,14 +57,8 @@ cd Aether
|
|||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||||
|
|
||||||
# 3. 部署
|
# 3. 部署 / 更新(自动执行数据库迁移)
|
||||||
docker compose up -d
|
docker compose pull && docker compose up -d
|
||||||
|
|
||||||
# 4. 首次部署时, 初始化数据库
|
|
||||||
./migrate.sh
|
|
||||||
|
|
||||||
# 5. 更新
|
|
||||||
docker compose pull && docker compose up -d && ./migrate.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Docker Compose(本地构建镜像)
|
### Docker Compose(本地构建镜像)
|
||||||
|
|||||||
530
alembic/versions/20260110_2000_consolidated_schema_updates.py
Normal file
530
alembic/versions/20260110_2000_consolidated_schema_updates.py
Normal file
@@ -0,0 +1,530 @@
|
|||||||
|
"""consolidated schema updates
|
||||||
|
|
||||||
|
Revision ID: m4n5o6p7q8r9
|
||||||
|
Revises: 02a45b66b7c4
|
||||||
|
Create Date: 2026-01-10 20:00:00.000000
|
||||||
|
|
||||||
|
This migration consolidates all schema changes from 2026-01-08 to 2026-01-10:
|
||||||
|
|
||||||
|
1. provider_api_keys: Key 直接关联 Provider (provider_id, api_formats)
|
||||||
|
2. provider_api_keys: 添加 rate_multipliers JSON 字段(按格式费率)
|
||||||
|
3. models: global_model_id 改为可空(支持独立 ProviderModel)
|
||||||
|
4. providers: 添加 timeout, max_retries, proxy(从 endpoint 迁移)
|
||||||
|
5. providers: display_name 重命名为 name,删除原 name
|
||||||
|
6. provider_api_keys: max_concurrent -> rpm_limit(并发改 RPM)
|
||||||
|
7. provider_api_keys: 健康度改为按格式存储(health_by_format, circuit_breaker_by_format)
|
||||||
|
8. provider_endpoints: 删除废弃的 rate_limit 列
|
||||||
|
9. usage: 添加 client_response_headers 字段
|
||||||
|
10. provider_api_keys: 删除 endpoint_id(Key 不再与 Endpoint 绑定)
|
||||||
|
11. provider_endpoints: 删除废弃的 max_concurrent 列
|
||||||
|
12. providers: 删除废弃的 rpm_limit, rpm_used, rpm_reset_at 列
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
alembic_logger = logging.getLogger("alembic.runtime.migration")
|
||||||
|
|
||||||
|
revision = "m4n5o6p7q8r9"
|
||||||
|
down_revision = "02a45b66b7c4"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||||
|
"""Check if a column exists in the table"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
columns = [col["name"] for col in inspector.get_columns(table_name)]
|
||||||
|
return column_name in columns
|
||||||
|
|
||||||
|
|
||||||
|
def _constraint_exists(table_name: str, constraint_name: str) -> bool:
|
||||||
|
"""Check if a constraint exists"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
fks = inspector.get_foreign_keys(table_name)
|
||||||
|
return any(fk.get("name") == constraint_name for fk in fks)
|
||||||
|
|
||||||
|
|
||||||
|
def _index_exists(table_name: str, index_name: str) -> bool:
|
||||||
|
"""Check if an index exists"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
indexes = inspector.get_indexes(table_name)
|
||||||
|
return any(idx.get("name") == index_name for idx in indexes)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""Apply all consolidated schema changes"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
|
||||||
|
# ========== 1. provider_api_keys: 添加 provider_id 和 api_formats ==========
|
||||||
|
if not _column_exists("provider_api_keys", "provider_id"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("provider_id", sa.String(36), nullable=True))
|
||||||
|
|
||||||
|
# 数据迁移:从 endpoint 获取 provider_id
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys k
|
||||||
|
SET provider_id = e.provider_id
|
||||||
|
FROM provider_endpoints e
|
||||||
|
WHERE k.endpoint_id = e.id AND k.provider_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 检查无法关联的孤儿 Key
|
||||||
|
result = bind.execute(sa.text(
|
||||||
|
"SELECT COUNT(*) FROM provider_api_keys WHERE provider_id IS NULL"
|
||||||
|
))
|
||||||
|
orphan_count = result.scalar() or 0
|
||||||
|
if orphan_count > 0:
|
||||||
|
# 使用 logger 记录更明显的告警
|
||||||
|
alembic_logger.warning("=" * 60)
|
||||||
|
alembic_logger.warning(f"[MIGRATION WARNING] 发现 {orphan_count} 个无法关联 Provider 的孤儿 Key")
|
||||||
|
alembic_logger.warning("=" * 60)
|
||||||
|
alembic_logger.info("正在备份孤儿 Key 到 _orphan_api_keys_backup 表...")
|
||||||
|
|
||||||
|
# 先备份孤儿数据到临时表,避免数据丢失
|
||||||
|
op.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS _orphan_api_keys_backup AS
|
||||||
|
SELECT *, NOW() as backup_at
|
||||||
|
FROM provider_api_keys
|
||||||
|
WHERE provider_id IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 记录备份的 Key ID
|
||||||
|
orphan_ids = bind.execute(sa.text(
|
||||||
|
"SELECT id, name FROM provider_api_keys WHERE provider_id IS NULL"
|
||||||
|
)).fetchall()
|
||||||
|
alembic_logger.info("备份的孤儿 Key 列表:")
|
||||||
|
for key_id, key_name in orphan_ids:
|
||||||
|
alembic_logger.info(f" - Key: {key_name} (ID: {key_id})")
|
||||||
|
|
||||||
|
# 删除孤儿数据
|
||||||
|
op.execute("DELETE FROM provider_api_keys WHERE provider_id IS NULL")
|
||||||
|
alembic_logger.info(f"已备份并删除 {orphan_count} 个孤儿 Key")
|
||||||
|
|
||||||
|
# 提供恢复指南
|
||||||
|
alembic_logger.warning("-" * 60)
|
||||||
|
alembic_logger.warning("[恢复指南] 如需恢复孤儿 Key:")
|
||||||
|
alembic_logger.warning(" 1. 查询备份表: SELECT * FROM _orphan_api_keys_backup;")
|
||||||
|
alembic_logger.warning(" 2. 确定正确的 provider_id")
|
||||||
|
alembic_logger.warning(" 3. 执行恢复:")
|
||||||
|
alembic_logger.warning(" INSERT INTO provider_api_keys (...)")
|
||||||
|
alembic_logger.warning(" SELECT ... FROM _orphan_api_keys_backup WHERE ...;")
|
||||||
|
alembic_logger.warning("-" * 60)
|
||||||
|
|
||||||
|
# 设置 NOT NULL 并创建外键
|
||||||
|
op.alter_column("provider_api_keys", "provider_id", nullable=False)
|
||||||
|
|
||||||
|
if not _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_provider_api_keys_provider",
|
||||||
|
"provider_api_keys",
|
||||||
|
"providers",
|
||||||
|
["provider_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="CASCADE",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
|
||||||
|
op.create_index("idx_provider_api_keys_provider_id", "provider_api_keys", ["provider_id"])
|
||||||
|
|
||||||
|
if not _column_exists("provider_api_keys", "api_formats"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("api_formats", sa.JSON(), nullable=True))
|
||||||
|
|
||||||
|
# 数据迁移:从 endpoint 获取 api_format
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys k
|
||||||
|
SET api_formats = json_build_array(e.api_format)
|
||||||
|
FROM provider_endpoints e
|
||||||
|
WHERE k.endpoint_id = e.id AND k.api_formats IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
op.alter_column("provider_api_keys", "api_formats", nullable=False, server_default="[]")
|
||||||
|
|
||||||
|
# 修改 endpoint_id 为可空,外键改为 SET NULL
|
||||||
|
if _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
|
||||||
|
op.drop_constraint("provider_api_keys_endpoint_id_fkey", "provider_api_keys", type_="foreignkey")
|
||||||
|
op.alter_column("provider_api_keys", "endpoint_id", nullable=True)
|
||||||
|
# 不再重建外键,因为后面会删除这个字段
|
||||||
|
|
||||||
|
# ========== 2. provider_api_keys: 添加 rate_multipliers ==========
|
||||||
|
if not _column_exists("provider_api_keys", "rate_multipliers"):
|
||||||
|
op.add_column(
|
||||||
|
"provider_api_keys",
|
||||||
|
sa.Column("rate_multipliers", postgresql.JSON(astext_type=sa.Text()), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据迁移:将 rate_multiplier 按 api_formats 转换
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys
|
||||||
|
SET rate_multipliers = (
|
||||||
|
SELECT jsonb_object_agg(elem, rate_multiplier)
|
||||||
|
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
|
||||||
|
)
|
||||||
|
WHERE api_formats IS NOT NULL
|
||||||
|
AND api_formats::text != '[]'
|
||||||
|
AND api_formats::text != 'null'
|
||||||
|
AND rate_multipliers IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ========== 3. models: global_model_id 改为可空 ==========
|
||||||
|
op.alter_column("models", "global_model_id", existing_type=sa.String(36), nullable=True)
|
||||||
|
|
||||||
|
# ========== 4. providers: 添加 timeout, max_retries, proxy ==========
|
||||||
|
if not _column_exists("providers", "timeout"):
|
||||||
|
op.add_column(
|
||||||
|
"providers",
|
||||||
|
sa.Column("timeout", sa.Integer(), nullable=True, comment="请求超时(秒)"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _column_exists("providers", "max_retries"):
|
||||||
|
op.add_column(
|
||||||
|
"providers",
|
||||||
|
sa.Column("max_retries", sa.Integer(), nullable=True, comment="最大重试次数"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _column_exists("providers", "proxy"):
|
||||||
|
op.add_column(
|
||||||
|
"providers",
|
||||||
|
sa.Column("proxy", postgresql.JSONB(), nullable=True, comment="代理配置"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从端点迁移数据到 provider
|
||||||
|
op.execute("""
|
||||||
|
UPDATE providers p
|
||||||
|
SET
|
||||||
|
timeout = COALESCE(
|
||||||
|
p.timeout,
|
||||||
|
(SELECT MAX(e.timeout) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.timeout IS NOT NULL),
|
||||||
|
300
|
||||||
|
),
|
||||||
|
max_retries = COALESCE(
|
||||||
|
p.max_retries,
|
||||||
|
(SELECT MAX(e.max_retries) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.max_retries IS NOT NULL),
|
||||||
|
2
|
||||||
|
),
|
||||||
|
proxy = COALESCE(
|
||||||
|
p.proxy,
|
||||||
|
(SELECT e.proxy FROM provider_endpoints e WHERE e.provider_id = p.id AND e.proxy IS NOT NULL ORDER BY e.created_at LIMIT 1)
|
||||||
|
)
|
||||||
|
WHERE p.timeout IS NULL OR p.max_retries IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ========== 5. providers: display_name -> name ==========
|
||||||
|
# 注意:这里假设 display_name 已经被重命名为 name
|
||||||
|
# 如果 display_name 仍然存在,则需要执行重命名
|
||||||
|
if _column_exists("providers", "display_name"):
|
||||||
|
# 删除旧的 name 索引
|
||||||
|
if _index_exists("providers", "ix_providers_name"):
|
||||||
|
op.drop_index("ix_providers_name", table_name="providers")
|
||||||
|
|
||||||
|
# 如果存在旧的 name 列,先删除
|
||||||
|
if _column_exists("providers", "name"):
|
||||||
|
op.drop_column("providers", "name")
|
||||||
|
|
||||||
|
# 重命名 display_name 为 name
|
||||||
|
op.alter_column("providers", "display_name", new_column_name="name")
|
||||||
|
|
||||||
|
# 创建新索引
|
||||||
|
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
|
||||||
|
|
||||||
|
# ========== 6. provider_api_keys: max_concurrent -> rpm_limit ==========
|
||||||
|
if _column_exists("provider_api_keys", "max_concurrent"):
|
||||||
|
op.alter_column("provider_api_keys", "max_concurrent", new_column_name="rpm_limit")
|
||||||
|
|
||||||
|
if _column_exists("provider_api_keys", "learned_max_concurrent"):
|
||||||
|
op.alter_column("provider_api_keys", "learned_max_concurrent", new_column_name="learned_rpm_limit")
|
||||||
|
|
||||||
|
if _column_exists("provider_api_keys", "last_concurrent_peak"):
|
||||||
|
op.alter_column("provider_api_keys", "last_concurrent_peak", new_column_name="last_rpm_peak")
|
||||||
|
|
||||||
|
# 删除废弃字段
|
||||||
|
for col in ["rate_limit", "daily_limit", "monthly_limit"]:
|
||||||
|
if _column_exists("provider_api_keys", col):
|
||||||
|
op.drop_column("provider_api_keys", col)
|
||||||
|
|
||||||
|
# ========== 7. provider_api_keys: 健康度改为按格式存储 ==========
|
||||||
|
if not _column_exists("provider_api_keys", "health_by_format"):
|
||||||
|
op.add_column(
|
||||||
|
"provider_api_keys",
|
||||||
|
sa.Column(
|
||||||
|
"health_by_format",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=True,
|
||||||
|
comment="按API格式存储的健康度数据",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _column_exists("provider_api_keys", "circuit_breaker_by_format"):
|
||||||
|
op.add_column(
|
||||||
|
"provider_api_keys",
|
||||||
|
sa.Column(
|
||||||
|
"circuit_breaker_by_format",
|
||||||
|
postgresql.JSONB(astext_type=sa.Text()),
|
||||||
|
nullable=True,
|
||||||
|
comment="按API格式存储的熔断器状态",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据迁移:如果存在旧字段,迁移数据到新结构
|
||||||
|
if _column_exists("provider_api_keys", "health_score"):
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys
|
||||||
|
SET health_by_format = (
|
||||||
|
SELECT jsonb_object_agg(
|
||||||
|
elem,
|
||||||
|
jsonb_build_object(
|
||||||
|
'health_score', COALESCE(health_score, 1.0),
|
||||||
|
'consecutive_failures', COALESCE(consecutive_failures, 0),
|
||||||
|
'last_failure_at', last_failure_at,
|
||||||
|
'request_results_window', COALESCE(request_results_window::jsonb, '[]'::jsonb)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
|
||||||
|
)
|
||||||
|
WHERE api_formats IS NOT NULL
|
||||||
|
AND api_formats::text != '[]'
|
||||||
|
AND health_by_format IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Circuit Breaker 迁移策略:
|
||||||
|
# 不复制旧的 circuit_breaker_open 状态到所有 format,而是全部重置为 closed
|
||||||
|
# 原因:旧的单一 circuit breaker 状态可能因某一个 format 失败而打开,
|
||||||
|
# 如果复制到所有 format,会导致其他正常工作的 format 被错误标记为不可用
|
||||||
|
if _column_exists("provider_api_keys", "circuit_breaker_open"):
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys
|
||||||
|
SET circuit_breaker_by_format = (
|
||||||
|
SELECT jsonb_object_agg(
|
||||||
|
elem,
|
||||||
|
jsonb_build_object(
|
||||||
|
'open', false,
|
||||||
|
'open_at', NULL,
|
||||||
|
'next_probe_at', NULL,
|
||||||
|
'half_open_until', NULL,
|
||||||
|
'half_open_successes', 0,
|
||||||
|
'half_open_failures', 0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
|
||||||
|
)
|
||||||
|
WHERE api_formats IS NOT NULL
|
||||||
|
AND api_formats::text != '[]'
|
||||||
|
AND circuit_breaker_by_format IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 设置默认空对象
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys
|
||||||
|
SET health_by_format = '{}'::jsonb
|
||||||
|
WHERE health_by_format IS NULL
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
UPDATE provider_api_keys
|
||||||
|
SET circuit_breaker_by_format = '{}'::jsonb
|
||||||
|
WHERE circuit_breaker_by_format IS NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 创建 GIN 索引
|
||||||
|
if not _index_exists("provider_api_keys", "ix_provider_api_keys_health_by_format"):
|
||||||
|
op.create_index(
|
||||||
|
"ix_provider_api_keys_health_by_format",
|
||||||
|
"provider_api_keys",
|
||||||
|
["health_by_format"],
|
||||||
|
postgresql_using="gin",
|
||||||
|
)
|
||||||
|
if not _index_exists("provider_api_keys", "ix_provider_api_keys_circuit_breaker_by_format"):
|
||||||
|
op.create_index(
|
||||||
|
"ix_provider_api_keys_circuit_breaker_by_format",
|
||||||
|
"provider_api_keys",
|
||||||
|
["circuit_breaker_by_format"],
|
||||||
|
postgresql_using="gin",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除旧字段
|
||||||
|
old_health_columns = [
|
||||||
|
"health_score",
|
||||||
|
"consecutive_failures",
|
||||||
|
"last_failure_at",
|
||||||
|
"request_results_window",
|
||||||
|
"circuit_breaker_open",
|
||||||
|
"circuit_breaker_open_at",
|
||||||
|
"next_probe_at",
|
||||||
|
"half_open_until",
|
||||||
|
"half_open_successes",
|
||||||
|
"half_open_failures",
|
||||||
|
]
|
||||||
|
for col in old_health_columns:
|
||||||
|
if _column_exists("provider_api_keys", col):
|
||||||
|
op.drop_column("provider_api_keys", col)
|
||||||
|
|
||||||
|
# ========== 8. provider_endpoints: 删除废弃的 rate_limit 列 ==========
|
||||||
|
if _column_exists("provider_endpoints", "rate_limit"):
|
||||||
|
op.drop_column("provider_endpoints", "rate_limit")
|
||||||
|
|
||||||
|
# ========== 9. usage: 添加 client_response_headers ==========
|
||||||
|
if not _column_exists("usage", "client_response_headers"):
|
||||||
|
op.add_column(
|
||||||
|
"usage",
|
||||||
|
sa.Column("client_response_headers", sa.JSON(), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ========== 10. provider_api_keys: 删除 endpoint_id ==========
|
||||||
|
# Key 不再与 Endpoint 绑定,通过 provider_id + api_formats 关联
|
||||||
|
if _column_exists("provider_api_keys", "endpoint_id"):
|
||||||
|
# 确保外键已删除(前面可能已经删除)
|
||||||
|
try:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
for fk in inspector.get_foreign_keys("provider_api_keys"):
|
||||||
|
constrained = fk.get("constrained_columns") or []
|
||||||
|
if "endpoint_id" in constrained:
|
||||||
|
name = fk.get("name")
|
||||||
|
if name:
|
||||||
|
op.drop_constraint(name, "provider_api_keys", type_="foreignkey")
|
||||||
|
except Exception:
|
||||||
|
pass # 外键可能已经不存在
|
||||||
|
op.drop_column("provider_api_keys", "endpoint_id")
|
||||||
|
|
||||||
|
# ========== 11. provider_endpoints: 删除废弃的 max_concurrent 列 ==========
|
||||||
|
if _column_exists("provider_endpoints", "max_concurrent"):
|
||||||
|
op.drop_column("provider_endpoints", "max_concurrent")
|
||||||
|
|
||||||
|
# ========== 12. providers: 删除废弃的 RPM 相关字段 ==========
|
||||||
|
if _column_exists("providers", "rpm_limit"):
|
||||||
|
op.drop_column("providers", "rpm_limit")
|
||||||
|
if _column_exists("providers", "rpm_used"):
|
||||||
|
op.drop_column("providers", "rpm_used")
|
||||||
|
if _column_exists("providers", "rpm_reset_at"):
|
||||||
|
op.drop_column("providers", "rpm_reset_at")
|
||||||
|
|
||||||
|
alembic_logger.info("[OK] Consolidated migration completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""
|
||||||
|
Downgrade is complex due to data migrations.
|
||||||
|
For safety, this only removes new columns without restoring old structure.
|
||||||
|
Manual intervention may be required for full rollback.
|
||||||
|
"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
|
||||||
|
# 12. 恢复 providers RPM 相关字段
|
||||||
|
if not _column_exists("providers", "rpm_limit"):
|
||||||
|
op.add_column("providers", sa.Column("rpm_limit", sa.Integer(), nullable=True))
|
||||||
|
if not _column_exists("providers", "rpm_used"):
|
||||||
|
op.add_column(
|
||||||
|
"providers",
|
||||||
|
sa.Column("rpm_used", sa.Integer(), server_default="0", nullable=True),
|
||||||
|
)
|
||||||
|
if not _column_exists("providers", "rpm_reset_at"):
|
||||||
|
op.add_column(
|
||||||
|
"providers",
|
||||||
|
sa.Column("rpm_reset_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 11. 恢复 provider_endpoints.max_concurrent
|
||||||
|
if not _column_exists("provider_endpoints", "max_concurrent"):
|
||||||
|
op.add_column("provider_endpoints", sa.Column("max_concurrent", sa.Integer(), nullable=True))
|
||||||
|
|
||||||
|
# 10. 恢复 endpoint_id
|
||||||
|
if not _column_exists("provider_api_keys", "endpoint_id"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("endpoint_id", sa.String(36), nullable=True))
|
||||||
|
|
||||||
|
# 9. 删除 client_response_headers
|
||||||
|
if _column_exists("usage", "client_response_headers"):
|
||||||
|
op.drop_column("usage", "client_response_headers")
|
||||||
|
|
||||||
|
# 8. 恢复 provider_endpoints.rate_limit(如果需要)
|
||||||
|
if not _column_exists("provider_endpoints", "rate_limit"):
|
||||||
|
op.add_column("provider_endpoints", sa.Column("rate_limit", sa.Integer(), nullable=True))
|
||||||
|
|
||||||
|
# 7. 删除健康度 JSON 字段
|
||||||
|
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_health_by_format"))
|
||||||
|
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_circuit_breaker_by_format"))
|
||||||
|
if _column_exists("provider_api_keys", "health_by_format"):
|
||||||
|
op.drop_column("provider_api_keys", "health_by_format")
|
||||||
|
if _column_exists("provider_api_keys", "circuit_breaker_by_format"):
|
||||||
|
op.drop_column("provider_api_keys", "circuit_breaker_by_format")
|
||||||
|
|
||||||
|
# 6. rpm_limit -> max_concurrent(简化版:仅重命名)
|
||||||
|
if _column_exists("provider_api_keys", "rpm_limit"):
|
||||||
|
op.alter_column("provider_api_keys", "rpm_limit", new_column_name="max_concurrent")
|
||||||
|
if _column_exists("provider_api_keys", "learned_rpm_limit"):
|
||||||
|
op.alter_column("provider_api_keys", "learned_rpm_limit", new_column_name="learned_max_concurrent")
|
||||||
|
if _column_exists("provider_api_keys", "last_rpm_peak"):
|
||||||
|
op.alter_column("provider_api_keys", "last_rpm_peak", new_column_name="last_concurrent_peak")
|
||||||
|
|
||||||
|
# 恢复已删除的字段
|
||||||
|
if not _column_exists("provider_api_keys", "rate_limit"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("rate_limit", sa.Integer(), nullable=True))
|
||||||
|
if not _column_exists("provider_api_keys", "daily_limit"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("daily_limit", sa.Integer(), nullable=True))
|
||||||
|
if not _column_exists("provider_api_keys", "monthly_limit"):
|
||||||
|
op.add_column("provider_api_keys", sa.Column("monthly_limit", sa.Integer(), nullable=True))
|
||||||
|
|
||||||
|
# 5. name -> display_name (需要先删除索引)
|
||||||
|
if _index_exists("providers", "ix_providers_name"):
|
||||||
|
op.drop_index("ix_providers_name", table_name="providers")
|
||||||
|
op.alter_column("providers", "name", new_column_name="display_name")
|
||||||
|
# 重新添加原 name 字段
|
||||||
|
op.add_column("providers", sa.Column("name", sa.String(100), nullable=True))
|
||||||
|
op.execute("""
|
||||||
|
UPDATE providers
|
||||||
|
SET name = LOWER(REPLACE(REPLACE(display_name, ' ', '_'), '-', '_'))
|
||||||
|
""")
|
||||||
|
op.alter_column("providers", "name", nullable=False)
|
||||||
|
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
|
||||||
|
|
||||||
|
# 4. 删除 providers 的 timeout, max_retries, proxy
|
||||||
|
if _column_exists("providers", "proxy"):
|
||||||
|
op.drop_column("providers", "proxy")
|
||||||
|
if _column_exists("providers", "max_retries"):
|
||||||
|
op.drop_column("providers", "max_retries")
|
||||||
|
if _column_exists("providers", "timeout"):
|
||||||
|
op.drop_column("providers", "timeout")
|
||||||
|
|
||||||
|
# 3. models: global_model_id 改回 NOT NULL
|
||||||
|
result = bind.execute(sa.text(
|
||||||
|
"SELECT COUNT(*) FROM models WHERE global_model_id IS NULL"
|
||||||
|
))
|
||||||
|
orphan_model_count = result.scalar() or 0
|
||||||
|
if orphan_model_count > 0:
|
||||||
|
alembic_logger.warning(f"[WARN] 发现 {orphan_model_count} 个无 global_model_id 的独立模型,将被删除")
|
||||||
|
op.execute("DELETE FROM models WHERE global_model_id IS NULL")
|
||||||
|
alembic_logger.info(f"已删除 {orphan_model_count} 个独立模型")
|
||||||
|
op.alter_column("models", "global_model_id", nullable=False)
|
||||||
|
|
||||||
|
# 2. 删除 rate_multipliers
|
||||||
|
if _column_exists("provider_api_keys", "rate_multipliers"):
|
||||||
|
op.drop_column("provider_api_keys", "rate_multipliers")
|
||||||
|
|
||||||
|
# 1. 删除 provider_id 和 api_formats
|
||||||
|
if _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
|
||||||
|
op.drop_index("idx_provider_api_keys_provider_id", table_name="provider_api_keys")
|
||||||
|
if _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
|
||||||
|
op.drop_constraint("fk_provider_api_keys_provider", "provider_api_keys", type_="foreignkey")
|
||||||
|
if _column_exists("provider_api_keys", "api_formats"):
|
||||||
|
op.drop_column("provider_api_keys", "api_formats")
|
||||||
|
if _column_exists("provider_api_keys", "provider_id"):
|
||||||
|
op.drop_column("provider_api_keys", "provider_id")
|
||||||
|
|
||||||
|
# 恢复 endpoint_id 外键(简化版:仅创建外键,不强制 NOT NULL)
|
||||||
|
if _column_exists("provider_api_keys", "endpoint_id"):
|
||||||
|
if not _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
|
||||||
|
op.create_foreign_key(
|
||||||
|
"provider_api_keys_endpoint_id_fkey",
|
||||||
|
"provider_api_keys",
|
||||||
|
"provider_endpoints",
|
||||||
|
["endpoint_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
alembic_logger.info("[OK] Downgrade completed (simplified version)")
|
||||||
19
deploy.sh
19
deploy.sh
@@ -88,9 +88,28 @@ build_base() {
|
|||||||
save_deps_hash
|
save_deps_hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 生成版本文件
|
||||||
|
generate_version_file() {
|
||||||
|
# 从 git 获取版本号
|
||||||
|
local version
|
||||||
|
version=$(git describe --tags --always 2>/dev/null | sed 's/^v//')
|
||||||
|
if [ -z "$version" ]; then
|
||||||
|
version="unknown"
|
||||||
|
fi
|
||||||
|
echo ">>> Generating version file: $version"
|
||||||
|
cat > src/_version.py << EOF
|
||||||
|
# Auto-generated by deploy.sh - do not edit
|
||||||
|
__version__ = '$version'
|
||||||
|
__version_tuple__ = tuple(int(x) for x in '$version'.split('-')[0].split('.') if x.isdigit())
|
||||||
|
version = __version__
|
||||||
|
version_tuple = __version_tuple__
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
# 构建应用镜像
|
# 构建应用镜像
|
||||||
build_app() {
|
build_app() {
|
||||||
echo ">>> Building app image (code only)..."
|
echo ">>> Building app image (code only)..."
|
||||||
|
generate_version_file
|
||||||
docker build -f Dockerfile.app.local -t aether-app:latest .
|
docker build -f Dockerfile.app.local -t aether-app:latest .
|
||||||
save_code_hash
|
save_code_hash
|
||||||
}
|
}
|
||||||
|
|||||||
8
entrypoint.sh
Normal file
8
entrypoint.sh
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo "Running database migrations..."
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
|
echo "Starting application..."
|
||||||
|
exec "$@"
|
||||||
@@ -67,7 +67,6 @@ export interface GlobalModelExport {
|
|||||||
|
|
||||||
export interface ProviderExport {
|
export interface ProviderExport {
|
||||||
name: string
|
name: string
|
||||||
display_name: string
|
|
||||||
description?: string | null
|
description?: string | null
|
||||||
website?: string | null
|
website?: string | null
|
||||||
billing_type?: string | null
|
billing_type?: string | null
|
||||||
@@ -76,10 +75,13 @@ export interface ProviderExport {
|
|||||||
rpm_limit?: number | null
|
rpm_limit?: number | null
|
||||||
provider_priority?: number
|
provider_priority?: number
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
rate_limit?: number | null
|
|
||||||
concurrent_limit?: number | null
|
concurrent_limit?: number | null
|
||||||
|
timeout?: number | null
|
||||||
|
max_retries?: number | null
|
||||||
|
proxy?: any
|
||||||
config?: any
|
config?: any
|
||||||
endpoints: EndpointExport[]
|
endpoints: EndpointExport[]
|
||||||
|
api_keys: ProviderKeyExport[]
|
||||||
models: ModelExport[]
|
models: ModelExport[]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,27 +91,26 @@ export interface EndpointExport {
|
|||||||
headers?: any
|
headers?: any
|
||||||
timeout?: number
|
timeout?: number
|
||||||
max_retries?: number
|
max_retries?: number
|
||||||
max_concurrent?: number | null
|
|
||||||
rate_limit?: number | null
|
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
custom_path?: string | null
|
custom_path?: string | null
|
||||||
config?: any
|
config?: any
|
||||||
keys: KeyExport[]
|
proxy?: any
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface KeyExport {
|
export interface ProviderKeyExport {
|
||||||
api_key: string
|
api_key: string
|
||||||
name?: string | null
|
name?: string | null
|
||||||
note?: string | null
|
note?: string | null
|
||||||
|
api_formats: string[]
|
||||||
rate_multiplier?: number
|
rate_multiplier?: number
|
||||||
|
rate_multipliers?: Record<string, number> | null
|
||||||
internal_priority?: number
|
internal_priority?: number
|
||||||
global_priority?: number | null
|
global_priority?: number | null
|
||||||
max_concurrent?: number | null
|
rpm_limit?: number | null
|
||||||
rate_limit?: number | null
|
allowed_models?: any
|
||||||
daily_limit?: number | null
|
|
||||||
monthly_limit?: number | null
|
|
||||||
allowed_models?: string[] | null
|
|
||||||
capabilities?: any
|
capabilities?: any
|
||||||
|
cache_ttl_minutes?: number
|
||||||
|
max_probe_interval_minutes?: number
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,6 +160,15 @@ export interface EmailTemplateResetResponse {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查更新响应
|
||||||
|
export interface CheckUpdateResponse {
|
||||||
|
current_version: string
|
||||||
|
latest_version: string | null
|
||||||
|
has_update: boolean
|
||||||
|
release_url: string | null
|
||||||
|
error: string | null
|
||||||
|
}
|
||||||
|
|
||||||
// LDAP 配置响应
|
// LDAP 配置响应
|
||||||
export interface LdapConfigResponse {
|
export interface LdapConfigResponse {
|
||||||
server_url: string | null
|
server_url: string | null
|
||||||
@@ -526,6 +536,14 @@ export const adminApi = {
|
|||||||
return response.data
|
return response.data
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// 检查系统更新
|
||||||
|
async checkUpdate(): Promise<CheckUpdateResponse> {
|
||||||
|
const response = await apiClient.get<CheckUpdateResponse>(
|
||||||
|
'/api/admin/system/check-update'
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
// LDAP 配置相关
|
// LDAP 配置相关
|
||||||
// 获取 LDAP 配置
|
// 获取 LDAP 配置
|
||||||
async getLdapConfig(): Promise<LdapConfigResponse> {
|
async getLdapConfig(): Promise<LdapConfigResponse> {
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ export interface RequestDetail {
|
|||||||
request_body?: Record<string, any>
|
request_body?: Record<string, any>
|
||||||
provider_request_headers?: Record<string, any>
|
provider_request_headers?: Record<string, any>
|
||||||
response_headers?: Record<string, any>
|
response_headers?: Record<string, any>
|
||||||
|
client_response_headers?: Record<string, any>
|
||||||
response_body?: Record<string, any>
|
response_body?: Record<string, any>
|
||||||
metadata?: Record<string, any>
|
metadata?: Record<string, any>
|
||||||
// 阶梯计费信息
|
// 阶梯计费信息
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ export async function toggleAdaptiveMode(
|
|||||||
message: string
|
message: string
|
||||||
key_id: string
|
key_id: string
|
||||||
is_adaptive: boolean
|
is_adaptive: boolean
|
||||||
max_concurrent: number | null
|
rpm_limit: number | null
|
||||||
effective_limit: number | null
|
effective_limit: number | null
|
||||||
}> {
|
}> {
|
||||||
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/mode`, data)
|
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/mode`, data)
|
||||||
@@ -22,16 +22,16 @@ export async function toggleAdaptiveMode(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 设置 Key 的固定并发限制
|
* 设置 Key 的固定 RPM 限制
|
||||||
*/
|
*/
|
||||||
export async function setConcurrentLimit(
|
export async function setRpmLimit(
|
||||||
keyId: string,
|
keyId: string,
|
||||||
limit: number
|
limit: number
|
||||||
): Promise<{
|
): Promise<{
|
||||||
message: string
|
message: string
|
||||||
key_id: string
|
key_id: string
|
||||||
is_adaptive: boolean
|
is_adaptive: boolean
|
||||||
max_concurrent: number
|
rpm_limit: number
|
||||||
previous_mode: string
|
previous_mode: string
|
||||||
}> {
|
}> {
|
||||||
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/limit`, null, {
|
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/limit`, null, {
|
||||||
|
|||||||
@@ -27,15 +27,9 @@ export async function createEndpoint(
|
|||||||
api_format: string
|
api_format: string
|
||||||
base_url: string
|
base_url: string
|
||||||
custom_path?: string
|
custom_path?: string
|
||||||
auth_type?: string
|
|
||||||
auth_header?: string
|
|
||||||
headers?: Record<string, string>
|
headers?: Record<string, string>
|
||||||
timeout?: number
|
timeout?: number
|
||||||
max_retries?: number
|
max_retries?: number
|
||||||
priority?: number
|
|
||||||
weight?: number
|
|
||||||
max_concurrent?: number
|
|
||||||
rate_limit?: number
|
|
||||||
is_active?: boolean
|
is_active?: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
proxy?: ProxyConfig | null
|
proxy?: ProxyConfig | null
|
||||||
@@ -52,16 +46,10 @@ export async function updateEndpoint(
|
|||||||
endpointId: string,
|
endpointId: string,
|
||||||
data: Partial<{
|
data: Partial<{
|
||||||
base_url: string
|
base_url: string
|
||||||
custom_path: string
|
custom_path: string | null
|
||||||
auth_type: string
|
|
||||||
auth_header: string
|
|
||||||
headers: Record<string, string>
|
headers: Record<string, string>
|
||||||
timeout: number
|
timeout: number
|
||||||
max_retries: number
|
max_retries: number
|
||||||
priority: number
|
|
||||||
weight: number
|
|
||||||
max_concurrent: number
|
|
||||||
rate_limit: number
|
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config: Record<string, any>
|
config: Record<string, any>
|
||||||
proxy: ProxyConfig | null
|
proxy: ProxyConfig | null
|
||||||
@@ -74,7 +62,7 @@ export async function updateEndpoint(
|
|||||||
/**
|
/**
|
||||||
* 删除 Endpoint
|
* 删除 Endpoint
|
||||||
*/
|
*/
|
||||||
export async function deleteEndpoint(endpointId: string): Promise<{ message: string; deleted_keys_count: number }> {
|
export async function deleteEndpoint(endpointId: string): Promise<{ message: string; affected_keys_count: number }> {
|
||||||
const response = await client.delete(`/api/admin/endpoints/${endpointId}`)
|
const response = await client.delete(`/api/admin/endpoints/${endpointId}`)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,16 +32,21 @@ export async function getKeyHealth(keyId: string): Promise<HealthStatus> {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 恢复Key健康状态(一键恢复:重置健康度 + 关闭熔断器 + 取消自动禁用)
|
* 恢复Key健康状态(一键恢复:重置健康度 + 关闭熔断器 + 取消自动禁用)
|
||||||
|
* @param keyId Key ID
|
||||||
|
* @param apiFormat 可选,指定 API 格式(如 CLAUDE、OPENAI),不指定则恢复所有格式
|
||||||
*/
|
*/
|
||||||
export async function recoverKeyHealth(keyId: string): Promise<{
|
export async function recoverKeyHealth(keyId: string, apiFormat?: string): Promise<{
|
||||||
message: string
|
message: string
|
||||||
details: {
|
details: {
|
||||||
|
api_format?: string
|
||||||
health_score: number
|
health_score: number
|
||||||
circuit_breaker_open: boolean
|
circuit_breaker_open: boolean
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
}
|
}
|
||||||
}> {
|
}> {
|
||||||
const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`)
|
const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`, null, {
|
||||||
|
params: apiFormat ? { api_format: apiFormat } : undefined
|
||||||
|
})
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import client from '../client'
|
import client from '../client'
|
||||||
import type { EndpointAPIKey } from './types'
|
import type { EndpointAPIKey, AllowedModels } from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 能力定义类型
|
* 能力定义类型
|
||||||
@@ -49,67 +49,6 @@ export async function getModelCapabilities(modelName: string): Promise<ModelCapa
|
|||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 获取 Endpoint 的所有 Keys
|
|
||||||
*/
|
|
||||||
export async function getEndpointKeys(endpointId: string): Promise<EndpointAPIKey[]> {
|
|
||||||
const response = await client.get(`/api/admin/endpoints/${endpointId}/keys`)
|
|
||||||
return response.data
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 为 Endpoint 添加 Key
|
|
||||||
*/
|
|
||||||
export async function addEndpointKey(
|
|
||||||
endpointId: string,
|
|
||||||
data: {
|
|
||||||
endpoint_id: string
|
|
||||||
api_key: string
|
|
||||||
name: string // 密钥名称(必填)
|
|
||||||
rate_multiplier?: number // 成本倍率(默认 1.0)
|
|
||||||
internal_priority?: number // Endpoint 内部优先级(数字越小越优先)
|
|
||||||
max_concurrent?: number // 最大并发数(留空=自适应模式)
|
|
||||||
rate_limit?: number
|
|
||||||
daily_limit?: number
|
|
||||||
monthly_limit?: number
|
|
||||||
cache_ttl_minutes?: number // 缓存 TTL(分钟),0=禁用
|
|
||||||
max_probe_interval_minutes?: number // 熔断探测间隔(分钟)
|
|
||||||
allowed_models?: string[] // 允许使用的模型列表
|
|
||||||
capabilities?: Record<string, boolean> // 能力标签配置
|
|
||||||
note?: string // 备注说明(可选)
|
|
||||||
}
|
|
||||||
): Promise<EndpointAPIKey> {
|
|
||||||
const response = await client.post(`/api/admin/endpoints/${endpointId}/keys`, data)
|
|
||||||
return response.data
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 更新 Endpoint Key
|
|
||||||
*/
|
|
||||||
export async function updateEndpointKey(
|
|
||||||
keyId: string,
|
|
||||||
data: Partial<{
|
|
||||||
api_key: string
|
|
||||||
name: string // 密钥名称
|
|
||||||
rate_multiplier: number // 成本倍率
|
|
||||||
internal_priority: number // Endpoint 内部优先级(提供商优先模式,数字越小越优先)
|
|
||||||
global_priority: number // 全局 Key 优先级(全局 Key 优先模式,数字越小越优先)
|
|
||||||
max_concurrent: number // 最大并发数(留空=自适应模式)
|
|
||||||
rate_limit: number
|
|
||||||
daily_limit: number
|
|
||||||
monthly_limit: number
|
|
||||||
cache_ttl_minutes: number // 缓存 TTL(分钟),0=禁用
|
|
||||||
max_probe_interval_minutes: number // 熔断探测间隔(分钟)
|
|
||||||
allowed_models: string[] | null // 允许使用的模型列表,null 表示允许所有
|
|
||||||
capabilities: Record<string, boolean> | null // 能力标签配置
|
|
||||||
is_active: boolean
|
|
||||||
note: string // 备注说明
|
|
||||||
}>
|
|
||||||
): Promise<EndpointAPIKey> {
|
|
||||||
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
|
|
||||||
return response.data
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取完整的 API Key(用于查看和复制)
|
* 获取完整的 API Key(用于查看和复制)
|
||||||
*/
|
*/
|
||||||
@@ -119,22 +58,71 @@ export async function revealEndpointKey(keyId: string): Promise<{ api_key: strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除 Endpoint Key
|
* 删除 Key
|
||||||
*/
|
*/
|
||||||
export async function deleteEndpointKey(keyId: string): Promise<{ message: string }> {
|
export async function deleteEndpointKey(keyId: string): Promise<{ message: string }> {
|
||||||
const response = await client.delete(`/api/admin/endpoints/keys/${keyId}`)
|
const response = await client.delete(`/api/admin/endpoints/keys/${keyId}`)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ========== Provider 级别的 Keys API ==========
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 批量更新 Endpoint Keys 的优先级(用于拖动排序)
|
* 获取 Provider 的所有 Keys
|
||||||
*/
|
*/
|
||||||
export async function batchUpdateKeyPriority(
|
export async function getProviderKeys(providerId: string): Promise<EndpointAPIKey[]> {
|
||||||
endpointId: string,
|
const response = await client.get(`/api/admin/endpoints/providers/${providerId}/keys`)
|
||||||
priorities: Array<{ key_id: string; internal_priority: number }>
|
return response.data
|
||||||
): Promise<{ message: string; updated_count: number }> {
|
}
|
||||||
const response = await client.put(`/api/admin/endpoints/${endpointId}/keys/batch-priority`, {
|
|
||||||
priorities
|
/**
|
||||||
})
|
* 为 Provider 添加 Key
|
||||||
|
*/
|
||||||
|
export async function addProviderKey(
|
||||||
|
providerId: string,
|
||||||
|
data: {
|
||||||
|
api_formats: string[] // 支持的 API 格式列表(必填)
|
||||||
|
api_key: string
|
||||||
|
name: string
|
||||||
|
rate_multiplier?: number // 默认成本倍率
|
||||||
|
rate_multipliers?: Record<string, number> | null // 按 API 格式的成本倍率
|
||||||
|
internal_priority?: number
|
||||||
|
rpm_limit?: number | null // RPM 限制(留空=自适应模式)
|
||||||
|
cache_ttl_minutes?: number
|
||||||
|
max_probe_interval_minutes?: number
|
||||||
|
allowed_models?: AllowedModels
|
||||||
|
capabilities?: Record<string, boolean>
|
||||||
|
note?: string
|
||||||
|
}
|
||||||
|
): Promise<EndpointAPIKey> {
|
||||||
|
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/keys`, data)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 更新 Key
|
||||||
|
*/
|
||||||
|
export async function updateProviderKey(
|
||||||
|
keyId: string,
|
||||||
|
data: Partial<{
|
||||||
|
api_formats: string[] // 支持的 API 格式列表
|
||||||
|
api_key: string
|
||||||
|
name: string
|
||||||
|
rate_multiplier: number // 默认成本倍率
|
||||||
|
rate_multipliers: Record<string, number> | null // 按 API 格式的成本倍率
|
||||||
|
internal_priority: number
|
||||||
|
global_priority: number | null
|
||||||
|
rpm_limit: number | null // RPM 限制(留空=自适应模式)
|
||||||
|
cache_ttl_minutes: number
|
||||||
|
max_probe_interval_minutes: number
|
||||||
|
allowed_models: AllowedModels
|
||||||
|
capabilities: Record<string, boolean> | null
|
||||||
|
is_active: boolean
|
||||||
|
note: string
|
||||||
|
}>
|
||||||
|
): Promise<EndpointAPIKey> {
|
||||||
|
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -147,14 +147,26 @@ export async function queryProviderUpstreamModels(
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 从上游提供商导入模型
|
* 从上游提供商导入模型
|
||||||
|
* @param providerId 提供商 ID
|
||||||
|
* @param modelIds 模型 ID 列表
|
||||||
|
* @param options 可选配置
|
||||||
|
* @param options.tiered_pricing 阶梯计费配置
|
||||||
|
* @param options.price_per_request 按次计费价格
|
||||||
*/
|
*/
|
||||||
export async function importModelsFromUpstream(
|
export async function importModelsFromUpstream(
|
||||||
providerId: string,
|
providerId: string,
|
||||||
modelIds: string[]
|
modelIds: string[],
|
||||||
|
options?: {
|
||||||
|
tiered_pricing?: object
|
||||||
|
price_per_request?: number
|
||||||
|
}
|
||||||
): Promise<ImportFromUpstreamResponse> {
|
): Promise<ImportFromUpstreamResponse> {
|
||||||
const response = await client.post(
|
const response = await client.post(
|
||||||
`/api/admin/providers/${providerId}/import-from-upstream`,
|
`/api/admin/providers/${providerId}/import-from-upstream`,
|
||||||
{ model_ids: modelIds }
|
{
|
||||||
|
model_ids: modelIds,
|
||||||
|
...options
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import client from '../client'
|
import client from '../client'
|
||||||
import type { ProviderWithEndpointsSummary } from './types'
|
import type { ProviderWithEndpointsSummary, ProxyConfig } from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 Providers 摘要(包含 Endpoints 统计)
|
* 获取 Providers 摘要(包含 Endpoints 统计)
|
||||||
@@ -23,7 +23,7 @@ export async function getProvider(providerId: string): Promise<ProviderWithEndpo
|
|||||||
export async function updateProvider(
|
export async function updateProvider(
|
||||||
providerId: string,
|
providerId: string,
|
||||||
data: Partial<{
|
data: Partial<{
|
||||||
display_name: string
|
name: string
|
||||||
description: string
|
description: string
|
||||||
website: string
|
website: string
|
||||||
provider_priority: number
|
provider_priority: number
|
||||||
@@ -33,6 +33,10 @@ export async function updateProvider(
|
|||||||
quota_last_reset_at: string // 周期开始时间
|
quota_last_reset_at: string // 周期开始时间
|
||||||
quota_expires_at: string
|
quota_expires_at: string
|
||||||
rpm_limit: number | null
|
rpm_limit: number | null
|
||||||
|
// 请求配置(从 Endpoint 迁移)
|
||||||
|
timeout: number
|
||||||
|
max_retries: number
|
||||||
|
proxy: ProxyConfig | null
|
||||||
cache_ttl_minutes: number // 0表示不支持缓存,>0表示支持缓存并设置TTL(分钟)
|
cache_ttl_minutes: number // 0表示不支持缓存,>0表示支持缓存并设置TTL(分钟)
|
||||||
max_probe_interval_minutes: number
|
max_probe_interval_minutes: number
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
@@ -83,7 +87,6 @@ export interface TestModelResponse {
|
|||||||
provider?: {
|
provider?: {
|
||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
display_name: string
|
|
||||||
}
|
}
|
||||||
model?: string
|
model?: string
|
||||||
}
|
}
|
||||||
@@ -92,4 +95,3 @@ export async function testModel(data: TestModelRequest): Promise<TestModelRespon
|
|||||||
const response = await client.post('/api/admin/provider-query/test-model', data)
|
const response = await client.post('/api/admin/provider-query/test-model', data)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,38 @@ export const API_FORMAT_LABELS: Record<string, string> = {
|
|||||||
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// API 格式缩写映射(用于空间紧凑的显示场景)
|
||||||
|
export const API_FORMAT_SHORT: Record<string, string> = {
|
||||||
|
[API_FORMATS.OPENAI]: 'O',
|
||||||
|
[API_FORMATS.OPENAI_CLI]: 'OC',
|
||||||
|
[API_FORMATS.CLAUDE]: 'C',
|
||||||
|
[API_FORMATS.CLAUDE_CLI]: 'CC',
|
||||||
|
[API_FORMATS.GEMINI]: 'G',
|
||||||
|
[API_FORMATS.GEMINI_CLI]: 'GC',
|
||||||
|
}
|
||||||
|
|
||||||
|
// API 格式排序顺序(统一的显示顺序)
|
||||||
|
export const API_FORMAT_ORDER: string[] = [
|
||||||
|
API_FORMATS.OPENAI,
|
||||||
|
API_FORMATS.OPENAI_CLI,
|
||||||
|
API_FORMATS.CLAUDE,
|
||||||
|
API_FORMATS.CLAUDE_CLI,
|
||||||
|
API_FORMATS.GEMINI,
|
||||||
|
API_FORMATS.GEMINI_CLI,
|
||||||
|
]
|
||||||
|
|
||||||
|
// 工具函数:按标准顺序排序 API 格式数组
|
||||||
|
export function sortApiFormats(formats: string[]): string[] {
|
||||||
|
return [...formats].sort((a, b) => {
|
||||||
|
const aIdx = API_FORMAT_ORDER.indexOf(a)
|
||||||
|
const bIdx = API_FORMAT_ORDER.indexOf(b)
|
||||||
|
if (aIdx === -1 && bIdx === -1) return 0
|
||||||
|
if (aIdx === -1) return 1
|
||||||
|
if (bIdx === -1) return -1
|
||||||
|
return aIdx - bIdx
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 代理配置类型
|
* 代理配置类型
|
||||||
*/
|
*/
|
||||||
@@ -37,18 +69,9 @@ export interface ProviderEndpoint {
|
|||||||
api_format: string
|
api_format: string
|
||||||
base_url: string
|
base_url: string
|
||||||
custom_path?: string // 自定义请求路径(可选,为空则使用 API 格式默认路径)
|
custom_path?: string // 自定义请求路径(可选,为空则使用 API 格式默认路径)
|
||||||
auth_type: string
|
|
||||||
auth_header?: string
|
|
||||||
headers?: Record<string, string>
|
headers?: Record<string, string>
|
||||||
timeout: number
|
timeout: number
|
||||||
max_retries: number
|
max_retries: number
|
||||||
priority: number
|
|
||||||
weight: number
|
|
||||||
max_concurrent?: number
|
|
||||||
rate_limit?: number
|
|
||||||
health_score: number
|
|
||||||
consecutive_failures: number
|
|
||||||
last_failure_at?: string
|
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
proxy?: ProxyConfig | null
|
proxy?: ProxyConfig | null
|
||||||
@@ -58,25 +81,55 @@ export interface ProviderEndpoint {
|
|||||||
updated_at: string
|
updated_at: string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模型权限配置类型(支持简单列表和按格式字典两种模式)
|
||||||
|
*
|
||||||
|
* 使用示例:
|
||||||
|
* 1. 不限制(允许所有模型): null
|
||||||
|
* 2. 简单列表模式(所有 API 格式共享同一个白名单): ["gpt-4", "claude-3-opus"]
|
||||||
|
* 3. 按格式字典模式(不同 API 格式使用不同的白名单):
|
||||||
|
* { "OPENAI": ["gpt-4"], "CLAUDE": ["claude-3-opus"] }
|
||||||
|
*/
|
||||||
|
export type AllowedModels = string[] | Record<string, string[]> | null
|
||||||
|
|
||||||
|
// AllowedModels 类型守卫函数
|
||||||
|
export function isAllowedModelsList(value: AllowedModels): value is string[] {
|
||||||
|
return Array.isArray(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isAllowedModelsDict(value: AllowedModels): value is Record<string, string[]> {
|
||||||
|
if (value === null || typeof value !== 'object' || Array.isArray(value)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 验证所有值都是字符串数组
|
||||||
|
return Object.values(value).every(
|
||||||
|
(v) => Array.isArray(v) && v.every((item) => typeof item === 'string')
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
export interface EndpointAPIKey {
|
export interface EndpointAPIKey {
|
||||||
id: string
|
id: string
|
||||||
endpoint_id: string
|
provider_id: string
|
||||||
|
api_formats: string[] // 支持的 API 格式列表
|
||||||
api_key_masked: string
|
api_key_masked: string
|
||||||
api_key_plain?: string | null
|
api_key_plain?: string | null
|
||||||
name: string // 密钥名称(必填,用于识别)
|
name: string // 密钥名称(必填,用于识别)
|
||||||
rate_multiplier: number // 成本倍率(真实成本 = 表面成本 × 倍率)
|
rate_multiplier: number // 默认成本倍率(真实成本 = 表面成本 × 倍率)
|
||||||
internal_priority: number // Endpoint 内部优先级
|
rate_multipliers?: Record<string, number> | null // 按 API 格式的成本倍率,如 {"CLAUDE": 1.0, "OPENAI": 0.8}
|
||||||
|
internal_priority: number // Key 内部优先级
|
||||||
global_priority?: number | null // 全局 Key 优先级
|
global_priority?: number | null // 全局 Key 优先级
|
||||||
max_concurrent?: number
|
rpm_limit?: number | null // RPM 速率限制 (1-10000),null 表示自适应模式
|
||||||
rate_limit?: number
|
allowed_models?: AllowedModels // 允许使用的模型列表(null=不限制,列表=简单白名单,字典=按格式区分)
|
||||||
daily_limit?: number
|
|
||||||
monthly_limit?: number
|
|
||||||
allowed_models?: string[] | null // 允许使用的模型列表(null = 支持所有模型)
|
|
||||||
capabilities?: Record<string, boolean> | null // 能力标签配置(如 cache_1h, context_1m)
|
capabilities?: Record<string, boolean> | null // 能力标签配置(如 cache_1h, context_1m)
|
||||||
// 缓存与熔断配置
|
// 缓存与熔断配置
|
||||||
cache_ttl_minutes: number // 缓存 TTL(分钟),0=禁用
|
cache_ttl_minutes: number // 缓存 TTL(分钟),0=禁用
|
||||||
max_probe_interval_minutes: number // 熔断探测间隔(分钟)
|
max_probe_interval_minutes: number // 熔断探测间隔(分钟)
|
||||||
|
// 按格式的健康度数据
|
||||||
|
health_by_format?: Record<string, FormatHealthData>
|
||||||
|
circuit_breaker_by_format?: Record<string, FormatCircuitBreakerData>
|
||||||
|
// 聚合字段(从 health_by_format 计算,用于列表显示)
|
||||||
health_score: number
|
health_score: number
|
||||||
|
circuit_breaker_open?: boolean
|
||||||
consecutive_failures: number
|
consecutive_failures: number
|
||||||
last_failure_at?: string
|
last_failure_at?: string
|
||||||
request_count: number
|
request_count: number
|
||||||
@@ -89,10 +142,10 @@ export interface EndpointAPIKey {
|
|||||||
last_used_at?: string
|
last_used_at?: string
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
// 自适应并发字段
|
// 自适应 RPM 字段
|
||||||
is_adaptive?: boolean // 是否为自适应模式(max_concurrent=NULL)
|
is_adaptive?: boolean // 是否为自适应模式(rpm_limit=NULL)
|
||||||
effective_limit?: number // 当前有效限制(自适应使用学习值,固定使用配置值)
|
effective_limit?: number // 当前有效 RPM 限制(自适应使用学习值,固定使用配置值)
|
||||||
learned_max_concurrent?: number
|
learned_rpm_limit?: number // 学习到的 RPM 限制
|
||||||
// 滑动窗口利用率采样
|
// 滑动窗口利用率采样
|
||||||
utilization_samples?: Array<{ ts: number; util: number }> // 利用率采样窗口
|
utilization_samples?: Array<{ ts: number; util: number }> // 利用率采样窗口
|
||||||
last_probe_increase_at?: string // 上次探测性扩容时间
|
last_probe_increase_at?: string // 上次探测性扩容时间
|
||||||
@@ -100,8 +153,7 @@ export interface EndpointAPIKey {
|
|||||||
rpm_429_count?: number
|
rpm_429_count?: number
|
||||||
last_429_at?: string
|
last_429_at?: string
|
||||||
last_429_type?: string
|
last_429_type?: string
|
||||||
// 熔断器字段(滑动窗口 + 半开模式)
|
// 单格式场景的熔断器字段
|
||||||
circuit_breaker_open?: boolean
|
|
||||||
circuit_breaker_open_at?: string
|
circuit_breaker_open_at?: string
|
||||||
next_probe_at?: string
|
next_probe_at?: string
|
||||||
half_open_until?: string
|
half_open_until?: string
|
||||||
@@ -110,17 +162,36 @@ export interface EndpointAPIKey {
|
|||||||
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
|
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 按格式的健康度数据
|
||||||
|
export interface FormatHealthData {
|
||||||
|
health_score: number
|
||||||
|
error_rate: number
|
||||||
|
window_size: number
|
||||||
|
consecutive_failures: number
|
||||||
|
last_failure_at?: string | null
|
||||||
|
circuit_breaker: FormatCircuitBreakerData
|
||||||
|
}
|
||||||
|
|
||||||
|
// 按格式的熔断器数据
|
||||||
|
export interface FormatCircuitBreakerData {
|
||||||
|
open: boolean
|
||||||
|
open_at?: string | null
|
||||||
|
next_probe_at?: string | null
|
||||||
|
half_open_until?: string | null
|
||||||
|
half_open_successes: number
|
||||||
|
half_open_failures: number
|
||||||
|
}
|
||||||
|
|
||||||
export interface EndpointAPIKeyUpdate {
|
export interface EndpointAPIKeyUpdate {
|
||||||
|
api_formats?: string[] // 支持的 API 格式列表
|
||||||
name?: string
|
name?: string
|
||||||
api_key?: string // 仅在需要更新时提供
|
api_key?: string // 仅在需要更新时提供
|
||||||
rate_multiplier?: number
|
rate_multiplier?: number // 默认成本倍率
|
||||||
|
rate_multipliers?: Record<string, number> | null // 按 API 格式的成本倍率
|
||||||
internal_priority?: number
|
internal_priority?: number
|
||||||
global_priority?: number | null
|
global_priority?: number | null
|
||||||
max_concurrent?: number | null // null 表示切换为自适应模式
|
rpm_limit?: number | null // RPM 速率限制 (1-10000),null 表示切换为自适应模式
|
||||||
rate_limit?: number
|
allowed_models?: AllowedModels
|
||||||
daily_limit?: number
|
|
||||||
monthly_limit?: number
|
|
||||||
allowed_models?: string[] | null
|
|
||||||
capabilities?: Record<string, boolean> | null
|
capabilities?: Record<string, boolean> | null
|
||||||
cache_ttl_minutes?: number
|
cache_ttl_minutes?: number
|
||||||
max_probe_interval_minutes?: number
|
max_probe_interval_minutes?: number
|
||||||
@@ -198,7 +269,6 @@ export interface PublicEndpointStatusMonitorResponse {
|
|||||||
export interface ProviderWithEndpointsSummary {
|
export interface ProviderWithEndpointsSummary {
|
||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
display_name: string
|
|
||||||
description?: string
|
description?: string
|
||||||
website?: string
|
website?: string
|
||||||
provider_priority: number
|
provider_priority: number
|
||||||
@@ -208,9 +278,10 @@ export interface ProviderWithEndpointsSummary {
|
|||||||
quota_reset_day?: number
|
quota_reset_day?: number
|
||||||
quota_last_reset_at?: string // 当前周期开始时间
|
quota_last_reset_at?: string // 当前周期开始时间
|
||||||
quota_expires_at?: string
|
quota_expires_at?: string
|
||||||
rpm_limit?: number | null
|
// 请求配置(从 Endpoint 迁移)
|
||||||
rpm_used?: number
|
timeout?: number // 请求超时(秒)
|
||||||
rpm_reset_at?: string
|
max_retries?: number // 最大重试次数
|
||||||
|
proxy?: ProxyConfig | null // 代理配置
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
total_endpoints: number
|
total_endpoints: number
|
||||||
active_endpoints: number
|
active_endpoints: number
|
||||||
@@ -253,13 +324,10 @@ export interface HealthSummary {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ConcurrencyStatus {
|
export interface KeyRpmStatus {
|
||||||
endpoint_id?: string
|
key_id: string
|
||||||
endpoint_current_concurrency: number
|
current_rpm: number
|
||||||
endpoint_max_concurrent?: number
|
rpm_limit?: number
|
||||||
key_id?: string
|
|
||||||
key_current_concurrency: number
|
|
||||||
key_max_concurrent?: number
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ProviderModelMapping {
|
export interface ProviderModelMapping {
|
||||||
@@ -361,7 +429,6 @@ export interface ModelPriceRange {
|
|||||||
export interface ModelCatalogProviderDetail {
|
export interface ModelCatalogProviderDetail {
|
||||||
provider_id: string
|
provider_id: string
|
||||||
provider_name: string
|
provider_name: string
|
||||||
provider_display_name?: string | null
|
|
||||||
model_id?: string | null
|
model_id?: string | null
|
||||||
target_model: string
|
target_model: string
|
||||||
input_price_per_1m?: number | null
|
input_price_per_1m?: number | null
|
||||||
@@ -534,10 +601,10 @@ export interface UpstreamModel {
|
|||||||
*/
|
*/
|
||||||
export interface ImportFromUpstreamSuccessItem {
|
export interface ImportFromUpstreamSuccessItem {
|
||||||
model_id: string
|
model_id: string
|
||||||
global_model_id: string
|
|
||||||
global_model_name: string
|
|
||||||
provider_model_id: string
|
provider_model_id: string
|
||||||
created_global_model: boolean
|
global_model_id?: string // 可选,未关联时为空字符串
|
||||||
|
global_model_name?: string // 可选,未关联时为空字符串
|
||||||
|
created_global_model: boolean // 始终为 false(不再自动创建 GlobalModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
112
frontend/src/components/common/UpdateDialog.vue
Normal file
112
frontend/src/components/common/UpdateDialog.vue
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
<template>
|
||||||
|
<Dialog
|
||||||
|
v-model="isOpen"
|
||||||
|
size="md"
|
||||||
|
title=""
|
||||||
|
>
|
||||||
|
<div class="flex flex-col items-center text-center py-2">
|
||||||
|
<!-- Logo -->
|
||||||
|
<HeaderLogo
|
||||||
|
size="h-16 w-16"
|
||||||
|
class-name="text-primary"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<!-- Title -->
|
||||||
|
<h2 class="text-xl font-semibold text-foreground mt-4 mb-2">
|
||||||
|
发现新版本
|
||||||
|
</h2>
|
||||||
|
|
||||||
|
<!-- Version Info -->
|
||||||
|
<div class="flex items-center gap-3 mb-4">
|
||||||
|
<span class="px-3 py-1.5 rounded-lg bg-muted text-sm font-mono text-muted-foreground">
|
||||||
|
v{{ currentVersion }}
|
||||||
|
</span>
|
||||||
|
<svg
|
||||||
|
class="h-4 w-4 text-muted-foreground"
|
||||||
|
fill="none"
|
||||||
|
stroke="currentColor"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
stroke-linecap="round"
|
||||||
|
stroke-linejoin="round"
|
||||||
|
stroke-width="2"
|
||||||
|
d="M13 7l5 5m0 0l-5 5m5-5H6"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
<span class="px-3 py-1.5 rounded-lg bg-primary/10 text-sm font-mono font-medium text-primary">
|
||||||
|
v{{ latestVersion }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Description -->
|
||||||
|
<p class="text-sm text-muted-foreground max-w-xs">
|
||||||
|
新版本已发布,建议更新以获得最新功能和安全修复
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<div class="flex w-full gap-3">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
class="flex-1"
|
||||||
|
@click="handleLater"
|
||||||
|
>
|
||||||
|
稍后提醒
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
class="flex-1"
|
||||||
|
@click="handleViewRelease"
|
||||||
|
>
|
||||||
|
查看更新
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</Dialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, watch } from 'vue'
|
||||||
|
import { Dialog } from '@/components/ui'
|
||||||
|
import Button from '@/components/ui/button.vue'
|
||||||
|
import HeaderLogo from '@/components/HeaderLogo.vue'
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
modelValue: boolean
|
||||||
|
currentVersion: string
|
||||||
|
latestVersion: string
|
||||||
|
releaseUrl: string | null
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
'update:modelValue': [value: boolean]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const isOpen = ref(props.modelValue)
|
||||||
|
|
||||||
|
watch(() => props.modelValue, (val) => {
|
||||||
|
isOpen.value = val
|
||||||
|
})
|
||||||
|
|
||||||
|
watch(isOpen, (val) => {
|
||||||
|
emit('update:modelValue', val)
|
||||||
|
})
|
||||||
|
|
||||||
|
function handleLater() {
|
||||||
|
// 记录忽略的版本,24小时内不再提醒
|
||||||
|
const ignoreKey = 'aether_update_ignore'
|
||||||
|
const ignoreData = {
|
||||||
|
version: props.latestVersion,
|
||||||
|
until: Date.now() + 24 * 60 * 60 * 1000 // 24小时
|
||||||
|
}
|
||||||
|
localStorage.setItem(ignoreKey, JSON.stringify(ignoreData))
|
||||||
|
isOpen.value = false
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleViewRelease() {
|
||||||
|
if (props.releaseUrl) {
|
||||||
|
window.open(props.releaseUrl, '_blank')
|
||||||
|
}
|
||||||
|
isOpen.value = false
|
||||||
|
}
|
||||||
|
</script>
|
||||||
15
frontend/src/components/ui/collapsible-content.vue
Normal file
15
frontend/src/components/ui/collapsible-content.vue
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { CollapsibleContent, type CollapsibleContentProps } from 'radix-vue'
|
||||||
|
import { cn } from '@/lib/utils'
|
||||||
|
|
||||||
|
const props = defineProps<CollapsibleContentProps & { class?: string }>()
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<CollapsibleContent
|
||||||
|
v-bind="props"
|
||||||
|
:class="cn('overflow-hidden data-[state=closed]:animate-collapsible-up data-[state=open]:animate-collapsible-down', props.class)"
|
||||||
|
>
|
||||||
|
<slot />
|
||||||
|
</CollapsibleContent>
|
||||||
|
</template>
|
||||||
11
frontend/src/components/ui/collapsible-trigger.vue
Normal file
11
frontend/src/components/ui/collapsible-trigger.vue
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { CollapsibleTrigger, type CollapsibleTriggerProps } from 'radix-vue'
|
||||||
|
|
||||||
|
const props = defineProps<CollapsibleTriggerProps>()
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<CollapsibleTrigger v-bind="props" as-child>
|
||||||
|
<slot />
|
||||||
|
</CollapsibleTrigger>
|
||||||
|
</template>
|
||||||
15
frontend/src/components/ui/collapsible.vue
Normal file
15
frontend/src/components/ui/collapsible.vue
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
<script setup lang="ts">
|
||||||
|
import { CollapsibleRoot, type CollapsibleRootEmits, type CollapsibleRootProps } from 'radix-vue'
|
||||||
|
import { useForwardPropsEmits } from 'radix-vue'
|
||||||
|
|
||||||
|
const props = defineProps<CollapsibleRootProps>()
|
||||||
|
const emits = defineEmits<CollapsibleRootEmits>()
|
||||||
|
|
||||||
|
const forwarded = useForwardPropsEmits(props, emits)
|
||||||
|
</script>
|
||||||
|
|
||||||
|
<template>
|
||||||
|
<CollapsibleRoot v-bind="forwarded">
|
||||||
|
<slot />
|
||||||
|
</CollapsibleRoot>
|
||||||
|
</template>
|
||||||
@@ -65,3 +65,8 @@ export { default as RefreshButton } from './refresh-button.vue'
|
|||||||
|
|
||||||
// Tooltip 提示系列
|
// Tooltip 提示系列
|
||||||
export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './tooltip'
|
export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './tooltip'
|
||||||
|
|
||||||
|
// Collapsible 折叠系列
|
||||||
|
export { default as Collapsible } from './collapsible.vue'
|
||||||
|
export { default as CollapsibleTrigger } from './collapsible-trigger.vue'
|
||||||
|
export { default as CollapsibleContent } from './collapsible-content.vue'
|
||||||
|
|||||||
@@ -186,7 +186,7 @@
|
|||||||
@click.stop
|
@click.stop
|
||||||
@change="toggleSelection('allowed_providers', provider.id)"
|
@change="toggleSelection('allowed_providers', provider.id)"
|
||||||
>
|
>
|
||||||
<span class="text-sm">{{ provider.display_name || provider.name }}</span>
|
<span class="text-sm">{{ provider.name }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
v-if="providers.length === 0"
|
v-if="providers.length === 0"
|
||||||
|
|||||||
@@ -460,13 +460,13 @@
|
|||||||
<TableHead class="h-10 font-semibold">
|
<TableHead class="h-10 font-semibold">
|
||||||
Provider
|
Provider
|
||||||
</TableHead>
|
</TableHead>
|
||||||
<TableHead class="w-[120px] h-10 font-semibold">
|
<TableHead class="w-[100px] h-10 font-semibold">
|
||||||
能力
|
能力
|
||||||
</TableHead>
|
</TableHead>
|
||||||
<TableHead class="w-[180px] h-10 font-semibold">
|
<TableHead class="w-[200px] h-10 font-semibold">
|
||||||
价格 ($/M)
|
价格 ($/M)
|
||||||
</TableHead>
|
</TableHead>
|
||||||
<TableHead class="w-[80px] h-10 font-semibold text-center">
|
<TableHead class="w-[100px] h-10 font-semibold text-center">
|
||||||
操作
|
操作
|
||||||
</TableHead>
|
</TableHead>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
@@ -484,7 +484,7 @@
|
|||||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||||
:title="provider.is_active ? '活跃' : '停用'"
|
:title="provider.is_active ? '活跃' : '停用'"
|
||||||
/>
|
/>
|
||||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
<span class="font-medium truncate">{{ provider.name }}</span>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell class="py-3">
|
<TableCell class="py-3">
|
||||||
@@ -595,7 +595,7 @@
|
|||||||
class="w-2 h-2 rounded-full shrink-0"
|
class="w-2 h-2 rounded-full shrink-0"
|
||||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||||
/>
|
/>
|
||||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
<span class="font-medium truncate">{{ provider.name }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex items-center gap-1 shrink-0">
|
<div class="flex items-center gap-1 shrink-0">
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
|
|||||||
// 加载数据
|
// 加载数据
|
||||||
async function loadData() {
|
async function loadData() {
|
||||||
await Promise.all([loadGlobalModels(), loadExistingModels()])
|
await Promise.all([loadGlobalModels(), loadExistingModels()])
|
||||||
// 默认折叠全局模型组
|
|
||||||
collapsedGroups.value = new Set(['global'])
|
|
||||||
|
|
||||||
// 检查缓存,如果有缓存数据则直接使用
|
// 检查缓存,如果有缓存数据则直接使用
|
||||||
const cachedModels = getCachedModels(props.providerId)
|
const cachedModels = getCachedModels(props.providerId)
|
||||||
if (cachedModels) {
|
if (cachedModels && cachedModels.length > 0) {
|
||||||
upstreamModels.value = cachedModels
|
upstreamModels.value = cachedModels
|
||||||
upstreamModelsLoaded.value = true
|
upstreamModelsLoaded.value = true
|
||||||
// 折叠所有上游模型组
|
// 有多个分组时全部折叠
|
||||||
|
const allGroups = new Set(['global'])
|
||||||
for (const model of cachedModels) {
|
for (const model of cachedModels) {
|
||||||
if (model.api_format) {
|
if (model.api_format) {
|
||||||
collapsedGroups.value.add(model.api_format)
|
allGroups.add(model.api_format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
collapsedGroups.value = allGroups
|
||||||
|
} else {
|
||||||
|
// 只有全局模型时展开
|
||||||
|
collapsedGroups.value = new Set()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
|
|||||||
} else {
|
} else {
|
||||||
upstreamModels.value = result.models
|
upstreamModels.value = result.models
|
||||||
upstreamModelsLoaded.value = true
|
upstreamModelsLoaded.value = true
|
||||||
// 折叠所有上游模型组
|
// 有多个分组时全部折叠
|
||||||
const allGroups = new Set(collapsedGroups.value)
|
const allGroups = new Set(['global'])
|
||||||
for (const model of result.models) {
|
for (const model of result.models) {
|
||||||
if (model.api_format) {
|
if (model.api_format) {
|
||||||
allGroups.add(model.api_format)
|
allGroups.add(model.api_format)
|
||||||
|
|||||||
@@ -1,245 +1,200 @@
|
|||||||
<template>
|
<template>
|
||||||
<Dialog
|
<Dialog
|
||||||
:model-value="internalOpen"
|
:model-value="internalOpen"
|
||||||
:title="isEditMode ? '编辑 API 端点' : '添加 API 端点'"
|
title="端点管理"
|
||||||
:description="isEditMode ? `修改 ${provider?.display_name} 的端点配置` : '为提供商添加新的 API 端点'"
|
:description="`管理 ${provider?.name} 的 API 端点`"
|
||||||
:icon="isEditMode ? SquarePen : Link"
|
:icon="Settings"
|
||||||
size="xl"
|
size="2xl"
|
||||||
@update:model-value="handleDialogUpdate"
|
@update:model-value="handleDialogUpdate"
|
||||||
>
|
>
|
||||||
<form
|
<div class="space-y-4">
|
||||||
class="space-y-6"
|
<!-- 已有端点列表 -->
|
||||||
@submit.prevent="handleSubmit()"
|
<div
|
||||||
>
|
v-if="localEndpoints.length > 0"
|
||||||
<!-- API 配置 -->
|
class="space-y-2"
|
||||||
<div class="space-y-4">
|
>
|
||||||
<h3
|
<Label class="text-muted-foreground">已配置的端点</Label>
|
||||||
v-if="isEditMode"
|
|
||||||
class="text-sm font-medium"
|
|
||||||
>
|
|
||||||
API 配置
|
|
||||||
</h3>
|
|
||||||
|
|
||||||
<div class="grid grid-cols-2 gap-4">
|
|
||||||
<!-- API 格式 -->
|
|
||||||
<div class="space-y-2">
|
|
||||||
<Label for="api_format">API 格式 *</Label>
|
|
||||||
<template v-if="isEditMode">
|
|
||||||
<Input
|
|
||||||
id="api_format"
|
|
||||||
v-model="form.api_format"
|
|
||||||
disabled
|
|
||||||
class="bg-muted"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground">
|
|
||||||
API 格式创建后不可修改
|
|
||||||
</p>
|
|
||||||
</template>
|
|
||||||
<template v-else>
|
|
||||||
<Select
|
|
||||||
v-model="form.api_format"
|
|
||||||
v-model:open="selectOpen"
|
|
||||||
required
|
|
||||||
>
|
|
||||||
<SelectTrigger>
|
|
||||||
<SelectValue placeholder="请选择 API 格式" />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
<SelectItem
|
|
||||||
v-for="format in apiFormats"
|
|
||||||
:key="format.value"
|
|
||||||
:value="format.value"
|
|
||||||
>
|
|
||||||
{{ format.label }}
|
|
||||||
</SelectItem>
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
</template>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- API URL -->
|
|
||||||
<div class="space-y-2">
|
|
||||||
<Label for="base_url">API URL *</Label>
|
|
||||||
<Input
|
|
||||||
id="base_url"
|
|
||||||
v-model="form.base_url"
|
|
||||||
placeholder="https://api.example.com"
|
|
||||||
required
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 自定义路径 -->
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-2">
|
||||||
<Label for="custom_path">自定义请求路径(可选)</Label>
|
<div
|
||||||
<Input
|
v-for="endpoint in localEndpoints"
|
||||||
id="custom_path"
|
:key="endpoint.id"
|
||||||
v-model="form.custom_path"
|
class="rounded-md border px-3 py-2"
|
||||||
:placeholder="defaultPathPlaceholder"
|
:class="{ 'opacity-50': !endpoint.is_active }"
|
||||||
/>
|
>
|
||||||
</div>
|
<!-- 编辑模式 -->
|
||||||
</div>
|
<template v-if="editingEndpointId === endpoint.id">
|
||||||
|
<div class="space-y-2">
|
||||||
<!-- 请求配置 -->
|
<div class="flex items-center gap-2">
|
||||||
<div class="space-y-4">
|
<span class="text-sm font-medium w-24 shrink-0">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
|
||||||
<h3 class="text-sm font-medium">
|
<div class="flex items-center gap-1 ml-auto">
|
||||||
请求配置
|
<Button
|
||||||
</h3>
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
<div class="grid grid-cols-3 gap-4">
|
class="h-7 w-7"
|
||||||
<div class="space-y-2">
|
title="保存"
|
||||||
<Label for="timeout">超时(秒)</Label>
|
:disabled="savingEndpointId === endpoint.id"
|
||||||
<Input
|
@click="saveEndpointUrl(endpoint)"
|
||||||
id="timeout"
|
>
|
||||||
v-model.number="form.timeout"
|
<Check class="w-3.5 h-3.5" />
|
||||||
type="number"
|
</Button>
|
||||||
placeholder="300"
|
<Button
|
||||||
/>
|
variant="ghost"
|
||||||
</div>
|
size="icon"
|
||||||
|
class="h-7 w-7"
|
||||||
<div class="space-y-2">
|
title="取消"
|
||||||
<Label for="max_retries">最大重试</Label>
|
@click="cancelEdit"
|
||||||
<Input
|
>
|
||||||
id="max_retries"
|
<X class="w-3.5 h-3.5" />
|
||||||
v-model.number="form.max_retries"
|
</Button>
|
||||||
type="number"
|
</div>
|
||||||
placeholder="3"
|
</div>
|
||||||
/>
|
<div class="grid grid-cols-2 gap-2">
|
||||||
</div>
|
<div class="space-y-1">
|
||||||
|
<Label class="text-xs text-muted-foreground">Base URL</Label>
|
||||||
<div class="space-y-2">
|
<Input
|
||||||
<Label for="max_concurrent">最大并发</Label>
|
v-model="editingUrl"
|
||||||
<Input
|
class="h-8 text-sm"
|
||||||
id="max_concurrent"
|
placeholder="https://api.example.com"
|
||||||
:model-value="form.max_concurrent ?? ''"
|
@keyup.escape="cancelEdit"
|
||||||
type="number"
|
/>
|
||||||
placeholder="无限制"
|
</div>
|
||||||
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
|
<div class="space-y-1">
|
||||||
/>
|
<Label class="text-xs text-muted-foreground">自定义路径 (可选)</Label>
|
||||||
</div>
|
<Input
|
||||||
</div>
|
v-model="editingPath"
|
||||||
|
class="h-8 text-sm"
|
||||||
<div class="grid grid-cols-2 gap-4">
|
:placeholder="editingDefaultPath || '留空使用默认路径'"
|
||||||
<div class="space-y-2">
|
@keyup.escape="cancelEdit"
|
||||||
<Label for="rate_limit">速率限制(请求/分钟)</Label>
|
/>
|
||||||
<Input
|
</div>
|
||||||
id="rate_limit"
|
</div>
|
||||||
:model-value="form.rate_limit ?? ''"
|
</div>
|
||||||
type="number"
|
</template>
|
||||||
placeholder="无限制"
|
<!-- 查看模式 -->
|
||||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)"
|
<template v-else>
|
||||||
/>
|
<div class="flex items-center gap-3">
|
||||||
|
<div class="w-24 shrink-0">
|
||||||
|
<span class="text-sm font-medium">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
|
||||||
|
</div>
|
||||||
|
<div class="flex-1 min-w-0">
|
||||||
|
<span class="text-sm text-muted-foreground truncate block">
|
||||||
|
{{ endpoint.base_url }}{{ endpoint.custom_path ? endpoint.custom_path : '' }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center gap-1 shrink-0">
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7"
|
||||||
|
title="编辑"
|
||||||
|
@click="startEdit(endpoint)"
|
||||||
|
>
|
||||||
|
<Edit class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7"
|
||||||
|
:title="endpoint.is_active ? '停用' : '启用'"
|
||||||
|
:disabled="togglingEndpointId === endpoint.id"
|
||||||
|
@click="handleToggleEndpoint(endpoint)"
|
||||||
|
>
|
||||||
|
<Power class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7 text-destructive hover:text-destructive"
|
||||||
|
title="删除"
|
||||||
|
:disabled="deletingEndpointId === endpoint.id"
|
||||||
|
@click="handleDeleteEndpoint(endpoint)"
|
||||||
|
>
|
||||||
|
<Trash2 class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 代理配置 -->
|
<!-- 添加新端点 -->
|
||||||
<div class="space-y-4">
|
<div
|
||||||
<div class="flex items-center justify-between">
|
v-if="availableFormats.length > 0"
|
||||||
<h3 class="text-sm font-medium">
|
class="space-y-3 pt-3 border-t"
|
||||||
代理配置
|
>
|
||||||
</h3>
|
<Label class="text-muted-foreground">添加新端点</Label>
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-end gap-3">
|
||||||
<Switch v-model="proxyEnabled" />
|
<div class="w-32 shrink-0 space-y-1.5">
|
||||||
<span class="text-sm text-muted-foreground">启用代理</span>
|
<Label class="text-xs">API 格式</Label>
|
||||||
</div>
|
<Select
|
||||||
</div>
|
v-model="newEndpoint.api_format"
|
||||||
|
v-model:open="formatSelectOpen"
|
||||||
<div
|
|
||||||
v-if="proxyEnabled"
|
|
||||||
class="space-y-4 rounded-lg border p-4"
|
|
||||||
>
|
|
||||||
<div class="space-y-2">
|
|
||||||
<Label for="proxy_url">代理 URL *</Label>
|
|
||||||
<Input
|
|
||||||
id="proxy_url"
|
|
||||||
v-model="form.proxy_url"
|
|
||||||
placeholder="http://host:port 或 socks5://host:port"
|
|
||||||
required
|
|
||||||
:class="proxyUrlError ? 'border-red-500' : ''"
|
|
||||||
/>
|
|
||||||
<p
|
|
||||||
v-if="proxyUrlError"
|
|
||||||
class="text-xs text-red-500"
|
|
||||||
>
|
>
|
||||||
{{ proxyUrlError }}
|
<SelectTrigger class="h-9">
|
||||||
</p>
|
<SelectValue placeholder="选择格式" />
|
||||||
<p
|
</SelectTrigger>
|
||||||
v-else
|
<SelectContent>
|
||||||
class="text-xs text-muted-foreground"
|
<SelectItem
|
||||||
>
|
v-for="format in availableFormats"
|
||||||
支持 HTTP、HTTPS、SOCKS5 代理
|
:key="format.value"
|
||||||
</p>
|
:value="format.value"
|
||||||
|
>
|
||||||
|
{{ format.label }}
|
||||||
|
</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="flex-1 space-y-1.5">
|
||||||
<div class="grid grid-cols-2 gap-4">
|
<Label class="text-xs">Base URL</Label>
|
||||||
<div class="space-y-2">
|
<Input
|
||||||
<Label for="proxy_user">用户名(可选)</Label>
|
v-model="newEndpoint.base_url"
|
||||||
<Input
|
placeholder="https://api.example.com"
|
||||||
:id="`proxy_user_${formId}`"
|
class="h-9"
|
||||||
v-model="form.proxy_username"
|
/>
|
||||||
:name="`proxy_user_${formId}`"
|
|
||||||
placeholder="代理认证用户名"
|
|
||||||
autocomplete="off"
|
|
||||||
data-form-type="other"
|
|
||||||
data-lpignore="true"
|
|
||||||
data-1p-ignore="true"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="space-y-2">
|
|
||||||
<Label :for="`proxy_pass_${formId}`">密码(可选)</Label>
|
|
||||||
<Input
|
|
||||||
:id="`proxy_pass_${formId}`"
|
|
||||||
v-model="form.proxy_password"
|
|
||||||
:name="`proxy_pass_${formId}`"
|
|
||||||
type="text"
|
|
||||||
:placeholder="passwordPlaceholder"
|
|
||||||
autocomplete="off"
|
|
||||||
data-form-type="other"
|
|
||||||
data-lpignore="true"
|
|
||||||
data-1p-ignore="true"
|
|
||||||
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
<div class="w-40 shrink-0 space-y-1.5">
|
||||||
|
<Label class="text-xs">自定义路径</Label>
|
||||||
|
<Input
|
||||||
|
v-model="newEndpoint.custom_path"
|
||||||
|
:placeholder="newEndpointDefaultPath || '可选'"
|
||||||
|
class="h-9"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
class="h-9 shrink-0"
|
||||||
|
:disabled="!newEndpoint.api_format || !newEndpoint.base_url || addingEndpoint"
|
||||||
|
@click="handleAddEndpoint"
|
||||||
|
>
|
||||||
|
{{ addingEndpoint ? '添加中...' : '添加' }}
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
|
||||||
|
<!-- 空状态 -->
|
||||||
|
<div
|
||||||
|
v-if="localEndpoints.length === 0 && availableFormats.length === 0"
|
||||||
|
class="text-center py-8 text-muted-foreground"
|
||||||
|
>
|
||||||
|
<p>所有 API 格式都已配置</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
|
||||||
variant="outline"
|
variant="outline"
|
||||||
:disabled="loading"
|
@click="handleClose"
|
||||||
@click="handleCancel"
|
|
||||||
>
|
>
|
||||||
取消
|
关闭
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
|
||||||
@click="handleSubmit()"
|
|
||||||
>
|
|
||||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
|
||||||
</Button>
|
</Button>
|
||||||
</template>
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
<!-- 确认清空凭据对话框 -->
|
|
||||||
<AlertDialog
|
|
||||||
v-model="showClearCredentialsDialog"
|
|
||||||
title="清空代理凭据"
|
|
||||||
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
|
|
||||||
type="warning"
|
|
||||||
confirm-text="清空并保存"
|
|
||||||
cancel-text="返回编辑"
|
|
||||||
@confirm="confirmClearCredentials"
|
|
||||||
@cancel="showClearCredentialsDialog = false"
|
|
||||||
/>
|
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted, watch } from 'vue'
|
||||||
import {
|
import {
|
||||||
Dialog,
|
Dialog,
|
||||||
Button,
|
Button,
|
||||||
@@ -250,17 +205,15 @@ import {
|
|||||||
SelectValue,
|
SelectValue,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
SelectItem,
|
SelectItem,
|
||||||
Switch,
|
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
import { Settings, Edit, Trash2, Check, X, Power } from 'lucide-vue-next'
|
||||||
import { Link, SquarePen } from 'lucide-vue-next'
|
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
|
||||||
import { parseNumberInput } from '@/utils/form'
|
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
import {
|
import {
|
||||||
createEndpoint,
|
createEndpoint,
|
||||||
updateEndpoint,
|
updateEndpoint,
|
||||||
|
deleteEndpoint,
|
||||||
|
API_FORMAT_LABELS,
|
||||||
type ProviderEndpoint,
|
type ProviderEndpoint,
|
||||||
type ProviderWithEndpointsSummary
|
type ProviderWithEndpointsSummary
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
@@ -269,7 +222,7 @@ import { adminApi } from '@/api/admin'
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
modelValue: boolean
|
modelValue: boolean
|
||||||
provider: ProviderWithEndpointsSummary | null
|
provider: ProviderWithEndpointsSummary | null
|
||||||
endpoint?: ProviderEndpoint | null // 编辑模式时传入
|
endpoints?: ProviderEndpoint[]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
@@ -279,258 +232,184 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
const loading = ref(false)
|
|
||||||
const selectOpen = ref(false)
|
|
||||||
const proxyEnabled = ref(false)
|
|
||||||
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
|
|
||||||
|
|
||||||
// 生成随机 ID 防止浏览器自动填充
|
// 状态
|
||||||
const formId = Math.random().toString(36).substring(2, 10)
|
const addingEndpoint = ref(false)
|
||||||
|
const editingEndpointId = ref<string | null>(null)
|
||||||
|
const editingUrl = ref('')
|
||||||
|
const editingPath = ref('')
|
||||||
|
const savingEndpointId = ref<string | null>(null)
|
||||||
|
const deletingEndpointId = ref<string | null>(null)
|
||||||
|
const togglingEndpointId = ref<string | null>(null)
|
||||||
|
const formatSelectOpen = ref(false)
|
||||||
|
|
||||||
// 内部状态
|
// 内部状态
|
||||||
const internalOpen = computed(() => props.modelValue)
|
const internalOpen = computed(() => props.modelValue)
|
||||||
|
|
||||||
// 表单数据
|
// 新端点表单
|
||||||
const form = ref({
|
const newEndpoint = ref({
|
||||||
api_format: '',
|
api_format: '',
|
||||||
base_url: '',
|
base_url: '',
|
||||||
custom_path: '',
|
custom_path: '',
|
||||||
timeout: 300,
|
|
||||||
max_retries: 3,
|
|
||||||
max_concurrent: undefined as number | undefined,
|
|
||||||
rate_limit: undefined as number | undefined,
|
|
||||||
is_active: true,
|
|
||||||
// 代理配置
|
|
||||||
proxy_url: '',
|
|
||||||
proxy_username: '',
|
|
||||||
proxy_password: '',
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// API 格式列表
|
// API 格式列表
|
||||||
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([])
|
const apiFormats = ref<Array<{ value: string; label: string; default_path: string }>>([])
|
||||||
|
|
||||||
// 加载API格式列表
|
// 本地端点列表
|
||||||
|
const localEndpoints = ref<ProviderEndpoint[]>([])
|
||||||
|
|
||||||
|
// 可用的格式(未添加的)
|
||||||
|
const availableFormats = computed(() => {
|
||||||
|
const existingFormats = localEndpoints.value.map(e => e.api_format)
|
||||||
|
return apiFormats.value.filter(f => !existingFormats.includes(f.value))
|
||||||
|
})
|
||||||
|
|
||||||
|
// 获取指定 API 格式的默认路径
|
||||||
|
function getDefaultPath(apiFormat: string): string {
|
||||||
|
const format = apiFormats.value.find(f => f.value === apiFormat)
|
||||||
|
return format?.default_path || ''
|
||||||
|
}
|
||||||
|
|
||||||
|
// 当前编辑端点的默认路径
|
||||||
|
const editingDefaultPath = computed(() => {
|
||||||
|
const endpoint = localEndpoints.value.find(e => e.id === editingEndpointId.value)
|
||||||
|
return endpoint ? getDefaultPath(endpoint.api_format) : ''
|
||||||
|
})
|
||||||
|
|
||||||
|
// 新端点选择的格式的默认路径
|
||||||
|
const newEndpointDefaultPath = computed(() => {
|
||||||
|
return getDefaultPath(newEndpoint.value.api_format)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 加载 API 格式列表
|
||||||
const loadApiFormats = async () => {
|
const loadApiFormats = async () => {
|
||||||
try {
|
try {
|
||||||
const response = await adminApi.getApiFormats()
|
const response = await adminApi.getApiFormats()
|
||||||
apiFormats.value = response.formats
|
apiFormats.value = response.formats
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log.error('加载API格式失败:', error)
|
log.error('加载API格式失败:', error)
|
||||||
if (!isEditMode.value) {
|
|
||||||
showError('加载API格式失败', '错误')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据选择的 API 格式计算默认路径
|
|
||||||
const defaultPath = computed(() => {
|
|
||||||
const format = apiFormats.value.find(f => f.value === form.value.api_format)
|
|
||||||
return format?.default_path || '/'
|
|
||||||
})
|
|
||||||
|
|
||||||
// 动态 placeholder
|
|
||||||
const defaultPathPlaceholder = computed(() => {
|
|
||||||
return `留空使用默认路径:${defaultPath.value}`
|
|
||||||
})
|
|
||||||
|
|
||||||
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
|
|
||||||
const hasExistingPassword = computed(() => {
|
|
||||||
if (!props.endpoint?.proxy) return false
|
|
||||||
const proxy = props.endpoint.proxy as { password?: string }
|
|
||||||
return proxy?.password === MASKED_PASSWORD
|
|
||||||
})
|
|
||||||
|
|
||||||
// 密码输入框的 placeholder
|
|
||||||
const passwordPlaceholder = computed(() => {
|
|
||||||
if (hasExistingPassword.value) {
|
|
||||||
return '已保存密码,留空保持不变'
|
|
||||||
}
|
|
||||||
return '代理认证密码'
|
|
||||||
})
|
|
||||||
|
|
||||||
// 代理 URL 验证
|
|
||||||
const proxyUrlError = computed(() => {
|
|
||||||
// 只有启用代理且填写了 URL 时才验证
|
|
||||||
if (!proxyEnabled.value || !form.value.proxy_url) {
|
|
||||||
return ''
|
|
||||||
}
|
|
||||||
const url = form.value.proxy_url.trim()
|
|
||||||
|
|
||||||
// 检查禁止的特殊字符
|
|
||||||
if (/[\n\r]/.test(url)) {
|
|
||||||
return '代理 URL 包含非法字符'
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证协议(不支持 SOCKS4)
|
|
||||||
if (!/^(http|https|socks5):\/\//i.test(url)) {
|
|
||||||
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
const parsed = new URL(url)
|
|
||||||
if (!parsed.host) {
|
|
||||||
return '代理 URL 必须包含有效的 host'
|
|
||||||
}
|
|
||||||
// 禁止 URL 中内嵌认证信息
|
|
||||||
if (parsed.username || parsed.password) {
|
|
||||||
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
|
|
||||||
}
|
|
||||||
} catch {
|
|
||||||
return '代理 URL 格式无效'
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
})
|
|
||||||
|
|
||||||
// 组件挂载时加载API格式
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadApiFormats()
|
loadApiFormats()
|
||||||
})
|
})
|
||||||
|
|
||||||
// 重置表单
|
// 监听 props 变化
|
||||||
function resetForm() {
|
watch(() => props.modelValue, (open) => {
|
||||||
form.value = {
|
if (open) {
|
||||||
api_format: '',
|
localEndpoints.value = [...(props.endpoints || [])]
|
||||||
base_url: '',
|
// 重置编辑状态
|
||||||
custom_path: '',
|
editingEndpointId.value = null
|
||||||
timeout: 300,
|
editingUrl.value = ''
|
||||||
max_retries: 3,
|
editingPath.value = ''
|
||||||
max_concurrent: undefined,
|
} else {
|
||||||
rate_limit: undefined,
|
// 关闭对话框时完全清空新端点表单
|
||||||
is_active: true,
|
newEndpoint.value = { api_format: '', base_url: '', custom_path: '' }
|
||||||
proxy_url: '',
|
|
||||||
proxy_username: '',
|
|
||||||
proxy_password: '',
|
|
||||||
}
|
}
|
||||||
proxyEnabled.value = false
|
}, { immediate: true })
|
||||||
|
|
||||||
|
watch(() => props.endpoints, (endpoints) => {
|
||||||
|
if (props.modelValue) {
|
||||||
|
localEndpoints.value = [...(endpoints || [])]
|
||||||
|
}
|
||||||
|
}, { deep: true })
|
||||||
|
|
||||||
|
// 开始编辑
|
||||||
|
function startEdit(endpoint: ProviderEndpoint) {
|
||||||
|
editingEndpointId.value = endpoint.id
|
||||||
|
editingUrl.value = endpoint.base_url
|
||||||
|
editingPath.value = endpoint.custom_path || ''
|
||||||
}
|
}
|
||||||
|
|
||||||
// 原始密码占位符(后端返回的脱敏标记)
|
// 取消编辑
|
||||||
const MASKED_PASSWORD = '***'
|
function cancelEdit() {
|
||||||
|
editingEndpointId.value = null
|
||||||
// 加载端点数据(编辑模式)
|
editingUrl.value = ''
|
||||||
function loadEndpointData() {
|
editingPath.value = ''
|
||||||
if (!props.endpoint) return
|
|
||||||
|
|
||||||
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
|
|
||||||
|
|
||||||
form.value = {
|
|
||||||
api_format: props.endpoint.api_format,
|
|
||||||
base_url: props.endpoint.base_url,
|
|
||||||
custom_path: props.endpoint.custom_path || '',
|
|
||||||
timeout: props.endpoint.timeout,
|
|
||||||
max_retries: props.endpoint.max_retries,
|
|
||||||
max_concurrent: props.endpoint.max_concurrent || undefined,
|
|
||||||
rate_limit: props.endpoint.rate_limit || undefined,
|
|
||||||
is_active: props.endpoint.is_active,
|
|
||||||
proxy_url: proxy?.url || '',
|
|
||||||
proxy_username: proxy?.username || '',
|
|
||||||
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
|
|
||||||
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据 enabled 字段或 url 存在判断是否启用代理
|
|
||||||
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 useFormDialog 统一处理对话框逻辑
|
// 保存端点
|
||||||
const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
async function saveEndpointUrl(endpoint: ProviderEndpoint) {
|
||||||
isOpen: () => props.modelValue,
|
if (!editingUrl.value) return
|
||||||
entity: () => props.endpoint,
|
|
||||||
isLoading: loading,
|
|
||||||
onClose: () => emit('update:modelValue', false),
|
|
||||||
loadData: loadEndpointData,
|
|
||||||
resetForm,
|
|
||||||
})
|
|
||||||
|
|
||||||
// 构建代理配置
|
savingEndpointId.value = endpoint.id
|
||||||
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
|
|
||||||
// - 无 URL 时返回 null
|
|
||||||
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
|
|
||||||
if (!form.value.proxy_url) {
|
|
||||||
// 没填 URL,无代理配置
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
url: form.value.proxy_url,
|
|
||||||
username: form.value.proxy_username || undefined,
|
|
||||||
password: form.value.proxy_password || undefined,
|
|
||||||
enabled: proxyEnabled.value, // 开关状态决定是否启用
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提交表单
|
|
||||||
const handleSubmit = async (skipCredentialCheck = false) => {
|
|
||||||
if (!props.provider && !props.endpoint) return
|
|
||||||
|
|
||||||
// 只在开关开启且填写了 URL 时验证
|
|
||||||
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
|
|
||||||
showError(proxyUrlError.value, '代理配置错误')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查:开关开启但没有 URL,却有用户名或密码
|
|
||||||
const hasOrphanedCredentials = proxyEnabled.value
|
|
||||||
&& !form.value.proxy_url
|
|
||||||
&& (form.value.proxy_username || form.value.proxy_password)
|
|
||||||
|
|
||||||
if (hasOrphanedCredentials && !skipCredentialCheck) {
|
|
||||||
// 弹出确认对话框
|
|
||||||
showClearCredentialsDialog.value = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
loading.value = true
|
|
||||||
try {
|
try {
|
||||||
const proxyConfig = buildProxyConfig()
|
await updateEndpoint(endpoint.id, {
|
||||||
|
base_url: editingUrl.value,
|
||||||
if (isEditMode.value && props.endpoint) {
|
custom_path: editingPath.value || null, // 空字符串时传 null 清空
|
||||||
// 更新端点
|
})
|
||||||
await updateEndpoint(props.endpoint.id, {
|
success('端点已更新')
|
||||||
base_url: form.value.base_url,
|
emit('endpointUpdated')
|
||||||
custom_path: form.value.custom_path || undefined,
|
cancelEdit()
|
||||||
timeout: form.value.timeout,
|
|
||||||
max_retries: form.value.max_retries,
|
|
||||||
max_concurrent: form.value.max_concurrent,
|
|
||||||
rate_limit: form.value.rate_limit,
|
|
||||||
is_active: form.value.is_active,
|
|
||||||
proxy: proxyConfig,
|
|
||||||
})
|
|
||||||
|
|
||||||
success('端点已更新', '保存成功')
|
|
||||||
emit('endpointUpdated')
|
|
||||||
} else if (props.provider) {
|
|
||||||
// 创建端点
|
|
||||||
await createEndpoint(props.provider.id, {
|
|
||||||
provider_id: props.provider.id,
|
|
||||||
api_format: form.value.api_format,
|
|
||||||
base_url: form.value.base_url,
|
|
||||||
custom_path: form.value.custom_path || undefined,
|
|
||||||
timeout: form.value.timeout,
|
|
||||||
max_retries: form.value.max_retries,
|
|
||||||
max_concurrent: form.value.max_concurrent,
|
|
||||||
rate_limit: form.value.rate_limit,
|
|
||||||
is_active: form.value.is_active,
|
|
||||||
proxy: proxyConfig,
|
|
||||||
})
|
|
||||||
|
|
||||||
success('端点创建成功', '成功')
|
|
||||||
emit('endpointCreated')
|
|
||||||
resetForm()
|
|
||||||
}
|
|
||||||
|
|
||||||
emit('update:modelValue', false)
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
const action = isEditMode.value ? '更新' : '创建'
|
showError(error.response?.data?.detail || '更新失败', '错误')
|
||||||
showError(error.response?.data?.detail || `${action}端点失败`, '错误')
|
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
savingEndpointId.value = null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确认清空凭据并继续保存
|
// 添加端点
|
||||||
const confirmClearCredentials = () => {
|
async function handleAddEndpoint() {
|
||||||
form.value.proxy_username = ''
|
if (!props.provider || !newEndpoint.value.api_format || !newEndpoint.value.base_url) return
|
||||||
form.value.proxy_password = ''
|
|
||||||
showClearCredentialsDialog.value = false
|
addingEndpoint.value = true
|
||||||
handleSubmit(true) // 跳过凭据检查,直接提交
|
try {
|
||||||
|
await createEndpoint(props.provider.id, {
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
api_format: newEndpoint.value.api_format,
|
||||||
|
base_url: newEndpoint.value.base_url,
|
||||||
|
custom_path: newEndpoint.value.custom_path || undefined,
|
||||||
|
is_active: true,
|
||||||
|
})
|
||||||
|
success(`已添加 ${API_FORMAT_LABELS[newEndpoint.value.api_format] || newEndpoint.value.api_format} 端点`)
|
||||||
|
// 重置表单,保留 URL
|
||||||
|
const url = newEndpoint.value.base_url
|
||||||
|
newEndpoint.value = { api_format: '', base_url: url, custom_path: '' }
|
||||||
|
emit('endpointCreated')
|
||||||
|
} catch (error: any) {
|
||||||
|
showError(error.response?.data?.detail || '添加失败', '错误')
|
||||||
|
} finally {
|
||||||
|
addingEndpoint.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换端点启用状态
|
||||||
|
async function handleToggleEndpoint(endpoint: ProviderEndpoint) {
|
||||||
|
togglingEndpointId.value = endpoint.id
|
||||||
|
try {
|
||||||
|
const newStatus = !endpoint.is_active
|
||||||
|
await updateEndpoint(endpoint.id, { is_active: newStatus })
|
||||||
|
success(newStatus ? '端点已启用' : '端点已停用')
|
||||||
|
emit('endpointUpdated')
|
||||||
|
} catch (error: any) {
|
||||||
|
showError(error.response?.data?.detail || '操作失败', '错误')
|
||||||
|
} finally {
|
||||||
|
togglingEndpointId.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除端点
|
||||||
|
async function handleDeleteEndpoint(endpoint: ProviderEndpoint) {
|
||||||
|
deletingEndpointId.value = endpoint.id
|
||||||
|
try {
|
||||||
|
await deleteEndpoint(endpoint.id)
|
||||||
|
success(`已删除 ${API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format} 端点`)
|
||||||
|
emit('endpointUpdated')
|
||||||
|
} catch (error: any) {
|
||||||
|
showError(error.response?.data?.detail || '删除失败', '错误')
|
||||||
|
} finally {
|
||||||
|
deletingEndpointId.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭对话框
|
||||||
|
function handleDialogUpdate(value: boolean) {
|
||||||
|
emit('update:modelValue', value)
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleClose() {
|
||||||
|
emit('update:modelValue', false)
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -1,165 +1,179 @@
|
|||||||
<template>
|
<template>
|
||||||
<Dialog
|
<Dialog
|
||||||
:model-value="isOpen"
|
:model-value="isOpen"
|
||||||
title="配置允许的模型"
|
title="获取上游模型"
|
||||||
description="选择该 API Key 允许访问的模型,留空则允许访问所有模型"
|
:description="`使用密钥 ${props.apiKey?.name || props.apiKey?.api_key_masked || ''} 从上游获取模型列表。导入的模型需要关联全局模型后才能参与路由。`"
|
||||||
:icon="Settings2"
|
:icon="Layers"
|
||||||
size="2xl"
|
size="2xl"
|
||||||
@update:model-value="handleDialogUpdate"
|
@update:model-value="handleDialogUpdate"
|
||||||
>
|
>
|
||||||
<div class="space-y-4 py-2">
|
<div class="space-y-4 py-2">
|
||||||
<!-- 已选模型展示 -->
|
<!-- 操作区域 -->
|
||||||
<div
|
<div class="flex items-center justify-between">
|
||||||
v-if="selectedModels.length > 0"
|
<div class="text-sm text-muted-foreground">
|
||||||
class="space-y-2"
|
<span v-if="!hasQueried">点击获取按钮查询上游可用模型</span>
|
||||||
>
|
<span v-else-if="upstreamModels.length > 0">
|
||||||
<div class="flex items-center justify-between px-1">
|
共 {{ upstreamModels.length }} 个模型,已选 {{ selectedModels.length }} 个
|
||||||
<div class="text-xs font-medium text-muted-foreground">
|
</span>
|
||||||
已选模型 ({{ selectedModels.length }})
|
<span v-else>未找到可用模型</span>
|
||||||
</div>
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="ghost"
|
|
||||||
size="sm"
|
|
||||||
class="h-6 text-xs hover:text-destructive"
|
|
||||||
@click="clearModels"
|
|
||||||
>
|
|
||||||
清空
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
<div class="flex flex-wrap gap-1.5 p-2 bg-muted/20 rounded-lg border border-border/40 min-h-[40px]">
|
|
||||||
<Badge
|
|
||||||
v-for="modelName in selectedModels"
|
|
||||||
:key="modelName"
|
|
||||||
variant="secondary"
|
|
||||||
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm"
|
|
||||||
>
|
|
||||||
{{ getModelLabel(modelName) }}
|
|
||||||
<button
|
|
||||||
class="ml-0.5 hover:text-destructive focus:outline-none"
|
|
||||||
@click.stop="toggleModel(modelName, false)"
|
|
||||||
>
|
|
||||||
×
|
|
||||||
</button>
|
|
||||||
</Badge>
|
|
||||||
</div>
|
</div>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
:disabled="loading"
|
||||||
|
@click="fetchUpstreamModels"
|
||||||
|
>
|
||||||
|
<RefreshCw
|
||||||
|
class="w-3.5 h-3.5 mr-1.5"
|
||||||
|
:class="{ 'animate-spin': loading }"
|
||||||
|
/>
|
||||||
|
{{ hasQueried ? '刷新' : '获取模型' }}
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 模型列表区域 -->
|
<!-- 加载状态 -->
|
||||||
<div class="space-y-2">
|
<div
|
||||||
|
v-if="loading"
|
||||||
|
class="flex flex-col items-center justify-center py-12 space-y-3"
|
||||||
|
>
|
||||||
|
<div class="animate-spin rounded-full h-8 w-8 border-2 border-primary/20 border-t-primary" />
|
||||||
|
<span class="text-xs text-muted-foreground">正在从上游获取模型列表...</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 错误状态 -->
|
||||||
|
<div
|
||||||
|
v-else-if="errorMessage"
|
||||||
|
class="flex flex-col items-center justify-center py-12 text-destructive border border-dashed border-destructive/30 rounded-lg bg-destructive/5"
|
||||||
|
>
|
||||||
|
<AlertCircle class="w-10 h-10 mb-2 opacity-50" />
|
||||||
|
<span class="text-sm text-center px-4">{{ errorMessage }}</span>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
class="mt-3"
|
||||||
|
@click="fetchUpstreamModels"
|
||||||
|
>
|
||||||
|
重试
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 未查询状态 -->
|
||||||
|
<div
|
||||||
|
v-else-if="!hasQueried"
|
||||||
|
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
|
||||||
|
>
|
||||||
|
<Layers class="w-10 h-10 mb-2 opacity-20" />
|
||||||
|
<span class="text-sm">点击上方按钮获取模型列表</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 无模型 -->
|
||||||
|
<div
|
||||||
|
v-else-if="upstreamModels.length === 0"
|
||||||
|
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
|
||||||
|
>
|
||||||
|
<Box class="w-10 h-10 mb-2 opacity-20" />
|
||||||
|
<span class="text-sm">上游 API 未返回可用模型</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 模型列表 -->
|
||||||
|
<div v-else class="space-y-2">
|
||||||
|
<!-- 全选/取消 -->
|
||||||
<div class="flex items-center justify-between px-1">
|
<div class="flex items-center justify-between px-1">
|
||||||
<div class="text-xs font-medium text-muted-foreground">
|
<div class="flex items-center gap-2">
|
||||||
可选模型列表
|
<Checkbox
|
||||||
|
:checked="isAllSelected"
|
||||||
|
:indeterminate="isPartiallySelected"
|
||||||
|
@update:checked="toggleSelectAll"
|
||||||
|
/>
|
||||||
|
<span class="text-xs text-muted-foreground">
|
||||||
|
{{ isAllSelected ? '取消全选' : '全选' }}
|
||||||
|
</span>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div class="text-xs text-muted-foreground">
|
||||||
v-if="!loadingModels && availableModels.length > 0"
|
{{ newModelsCount }} 个新模型(不在本地)
|
||||||
class="text-[10px] text-muted-foreground/60"
|
|
||||||
>
|
|
||||||
共 {{ availableModels.length }} 个模型
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 加载状态 -->
|
<div class="max-h-[320px] overflow-y-auto pr-1 space-y-1 custom-scrollbar">
|
||||||
<div
|
|
||||||
v-if="loadingModels"
|
|
||||||
class="flex flex-col items-center justify-center py-12 space-y-3"
|
|
||||||
>
|
|
||||||
<div class="animate-spin rounded-full h-8 w-8 border-2 border-primary/20 border-t-primary" />
|
|
||||||
<span class="text-xs text-muted-foreground">正在加载模型列表...</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 无模型 -->
|
|
||||||
<div
|
|
||||||
v-else-if="availableModels.length === 0"
|
|
||||||
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
|
|
||||||
>
|
|
||||||
<Box class="w-10 h-10 mb-2 opacity-20" />
|
|
||||||
<span class="text-sm">暂无可选模型</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 模型列表 -->
|
|
||||||
<div
|
|
||||||
v-else
|
|
||||||
class="max-h-[320px] overflow-y-auto pr-1 space-y-1.5 custom-scrollbar"
|
|
||||||
>
|
|
||||||
<div
|
<div
|
||||||
v-for="model in availableModels"
|
v-for="model in upstreamModels"
|
||||||
:key="model.global_model_name"
|
:key="`${model.id}:${model.api_format || ''}`"
|
||||||
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200 cursor-pointer select-none"
|
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200 cursor-pointer select-none"
|
||||||
:class="[
|
:class="[
|
||||||
selectedModels.includes(model.global_model_name)
|
selectedModels.includes(model.id)
|
||||||
? 'border-primary/40 bg-primary/5 shadow-sm'
|
? 'border-primary/40 bg-primary/5 shadow-sm'
|
||||||
: 'border-border/40 bg-background hover:border-primary/20 hover:bg-muted/30'
|
: 'border-border/40 bg-background hover:border-primary/20 hover:bg-muted/30'
|
||||||
]"
|
]"
|
||||||
@click="toggleModel(model.global_model_name, !selectedModels.includes(model.global_model_name))"
|
@click="toggleModel(model.id)"
|
||||||
>
|
>
|
||||||
<!-- Checkbox -->
|
|
||||||
<Checkbox
|
<Checkbox
|
||||||
:checked="selectedModels.includes(model.global_model_name)"
|
:checked="selectedModels.includes(model.id)"
|
||||||
class="data-[state=checked]:bg-primary data-[state=checked]:border-primary"
|
class="data-[state=checked]:bg-primary data-[state=checked]:border-primary"
|
||||||
@click.stop
|
@click.stop
|
||||||
@update:checked="checked => toggleModel(model.global_model_name, checked)"
|
@update:checked="checked => toggleModel(model.id, checked)"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- Info -->
|
|
||||||
<div class="flex-1 min-w-0">
|
<div class="flex-1 min-w-0">
|
||||||
<div class="flex items-center justify-between gap-2">
|
<div class="flex items-center gap-2">
|
||||||
<span class="text-sm font-medium truncate text-foreground/90">{{ model.display_name }}</span>
|
<span class="text-sm font-medium truncate text-foreground/90">
|
||||||
<span
|
{{ model.display_name || model.id }}
|
||||||
v-if="hasPricing(model)"
|
|
||||||
class="text-[10px] font-mono text-muted-foreground/80 bg-muted/30 px-1.5 py-0.5 rounded border border-border/30 shrink-0"
|
|
||||||
>
|
|
||||||
{{ formatPricingShort(model) }}
|
|
||||||
</span>
|
</span>
|
||||||
|
<Badge
|
||||||
|
v-if="model.api_format"
|
||||||
|
variant="outline"
|
||||||
|
class="text-[10px] px-1.5 py-0 shrink-0"
|
||||||
|
>
|
||||||
|
{{ API_FORMAT_LABELS[model.api_format] || model.api_format }}
|
||||||
|
</Badge>
|
||||||
|
<Badge
|
||||||
|
v-if="isModelExisting(model.id)"
|
||||||
|
variant="secondary"
|
||||||
|
class="text-[10px] px-1.5 py-0 shrink-0"
|
||||||
|
>
|
||||||
|
已存在
|
||||||
|
</Badge>
|
||||||
</div>
|
</div>
|
||||||
<div class="text-[11px] text-muted-foreground/60 font-mono truncate mt-0.5">
|
<div class="text-[11px] text-muted-foreground/60 font-mono truncate mt-0.5">
|
||||||
{{ model.global_model_name }}
|
{{ model.id }}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
<!-- 测试按钮 -->
|
v-if="model.owned_by"
|
||||||
<Button
|
class="text-[10px] text-muted-foreground/50 shrink-0"
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
class="h-7 w-7 shrink-0"
|
|
||||||
title="测试模型连接"
|
|
||||||
:disabled="testingModelName === model.global_model_name"
|
|
||||||
@click.stop="testModelConnection(model)"
|
|
||||||
>
|
>
|
||||||
<Loader2
|
{{ model.owned_by }}
|
||||||
v-if="testingModelName === model.global_model_name"
|
</div>
|
||||||
class="w-3.5 h-3.5 animate-spin"
|
|
||||||
/>
|
|
||||||
<Play
|
|
||||||
v-else
|
|
||||||
class="w-3.5 h-3.5"
|
|
||||||
/>
|
|
||||||
</Button>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<div class="flex items-center justify-end gap-2 w-full pt-2">
|
<div class="flex items-center justify-between w-full pt-2">
|
||||||
<Button
|
<div class="text-xs text-muted-foreground">
|
||||||
variant="outline"
|
<span v-if="selectedModels.length > 0 && newSelectedCount > 0">
|
||||||
class="h-9"
|
将导入 {{ newSelectedCount }} 个新模型
|
||||||
@click="handleCancel"
|
</span>
|
||||||
>
|
</div>
|
||||||
取消
|
<div class="flex items-center gap-2">
|
||||||
</Button>
|
<Button
|
||||||
<Button
|
variant="outline"
|
||||||
:disabled="saving"
|
class="h-9"
|
||||||
class="h-9 min-w-[80px]"
|
@click="handleCancel"
|
||||||
@click="handleSave"
|
>
|
||||||
>
|
取消
|
||||||
<Loader2
|
</Button>
|
||||||
v-if="saving"
|
<Button
|
||||||
class="w-3.5 h-3.5 mr-1.5 animate-spin"
|
:disabled="importing || selectedModels.length === 0 || newSelectedCount === 0"
|
||||||
/>
|
class="h-9 min-w-[100px]"
|
||||||
{{ saving ? '保存中' : '保存配置' }}
|
@click="handleImport"
|
||||||
</Button>
|
>
|
||||||
|
<Loader2
|
||||||
|
v-if="importing"
|
||||||
|
class="w-3.5 h-3.5 mr-1.5 animate-spin"
|
||||||
|
/>
|
||||||
|
{{ importing ? '导入中' : `导入 ${newSelectedCount} 个模型` }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
@@ -167,19 +181,19 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, watch } from 'vue'
|
import { ref, computed, watch } from 'vue'
|
||||||
import { Box, Loader2, Settings2, Play } from 'lucide-vue-next'
|
import { Box, Layers, Loader2, RefreshCw, AlertCircle } from 'lucide-vue-next'
|
||||||
import { Dialog } from '@/components/ui'
|
import { Dialog } from '@/components/ui'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Checkbox from '@/components/ui/checkbox.vue'
|
import Checkbox from '@/components/ui/checkbox.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { parseApiError, parseTestModelError } from '@/utils/errorParser'
|
import { adminApi } from '@/api/admin'
|
||||||
import {
|
import {
|
||||||
updateEndpointKey,
|
importModelsFromUpstream,
|
||||||
getProviderAvailableSourceModels,
|
getProviderModels,
|
||||||
testModel,
|
|
||||||
type EndpointAPIKey,
|
type EndpointAPIKey,
|
||||||
type ProviderAvailableSourceModel
|
type UpstreamModel,
|
||||||
|
API_FORMAT_LABELS,
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
@@ -196,130 +210,116 @@ const emit = defineEmits<{
|
|||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
|
||||||
const isOpen = computed(() => props.open)
|
const isOpen = computed(() => props.open)
|
||||||
const saving = ref(false)
|
const loading = ref(false)
|
||||||
const loadingModels = ref(false)
|
const importing = ref(false)
|
||||||
const availableModels = ref<ProviderAvailableSourceModel[]>([])
|
const hasQueried = ref(false)
|
||||||
|
const errorMessage = ref('')
|
||||||
|
const upstreamModels = ref<UpstreamModel[]>([])
|
||||||
const selectedModels = ref<string[]>([])
|
const selectedModels = ref<string[]>([])
|
||||||
const initialModels = ref<string[]>([])
|
const existingModelIds = ref<Set<string>>(new Set())
|
||||||
const testingModelName = ref<string | null>(null)
|
|
||||||
|
// 计算属性
|
||||||
|
const isAllSelected = computed(() =>
|
||||||
|
upstreamModels.value.length > 0 &&
|
||||||
|
selectedModels.value.length === upstreamModels.value.length
|
||||||
|
)
|
||||||
|
|
||||||
|
const isPartiallySelected = computed(() =>
|
||||||
|
selectedModels.value.length > 0 &&
|
||||||
|
selectedModels.value.length < upstreamModels.value.length
|
||||||
|
)
|
||||||
|
|
||||||
|
const newModelsCount = computed(() =>
|
||||||
|
upstreamModels.value.filter(m => !existingModelIds.value.has(m.id)).length
|
||||||
|
)
|
||||||
|
|
||||||
|
const newSelectedCount = computed(() =>
|
||||||
|
selectedModels.value.filter(id => !existingModelIds.value.has(id)).length
|
||||||
|
)
|
||||||
|
|
||||||
|
// 检查模型是否已存在
|
||||||
|
function isModelExisting(modelId: string): boolean {
|
||||||
|
return existingModelIds.value.has(modelId)
|
||||||
|
}
|
||||||
|
|
||||||
// 监听对话框打开
|
// 监听对话框打开
|
||||||
watch(() => props.open, (open) => {
|
watch(() => props.open, (open) => {
|
||||||
if (open) {
|
if (open) {
|
||||||
loadData()
|
resetState()
|
||||||
|
loadExistingModels()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
async function loadData() {
|
function resetState() {
|
||||||
// 初始化已选模型
|
hasQueried.value = false
|
||||||
if (props.apiKey?.allowed_models) {
|
errorMessage.value = ''
|
||||||
selectedModels.value = [...props.apiKey.allowed_models]
|
upstreamModels.value = []
|
||||||
initialModels.value = [...props.apiKey.allowed_models]
|
|
||||||
} else {
|
|
||||||
selectedModels.value = []
|
|
||||||
initialModels.value = []
|
|
||||||
}
|
|
||||||
|
|
||||||
// 加载可选模型
|
|
||||||
if (props.providerId) {
|
|
||||||
await loadAvailableModels()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function loadAvailableModels() {
|
|
||||||
if (!props.providerId) return
|
|
||||||
try {
|
|
||||||
loadingModels.value = true
|
|
||||||
const response = await getProviderAvailableSourceModels(props.providerId)
|
|
||||||
availableModels.value = response.models
|
|
||||||
} catch (err: any) {
|
|
||||||
const errorMessage = parseApiError(err, '加载模型列表失败')
|
|
||||||
showError(errorMessage, '错误')
|
|
||||||
} finally {
|
|
||||||
loadingModels.value = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const modelLabelMap = computed(() => {
|
|
||||||
const map = new Map<string, string>()
|
|
||||||
availableModels.value.forEach(model => {
|
|
||||||
map.set(model.global_model_name, model.display_name || model.global_model_name)
|
|
||||||
})
|
|
||||||
return map
|
|
||||||
})
|
|
||||||
|
|
||||||
function getModelLabel(modelName: string): string {
|
|
||||||
return modelLabelMap.value.get(modelName) ?? modelName
|
|
||||||
}
|
|
||||||
|
|
||||||
function hasPricing(model: ProviderAvailableSourceModel): boolean {
|
|
||||||
const input = model.price.input_price_per_1m ?? 0
|
|
||||||
const output = model.price.output_price_per_1m ?? 0
|
|
||||||
return input > 0 || output > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatPricingShort(model: ProviderAvailableSourceModel): string {
|
|
||||||
const input = model.price.input_price_per_1m ?? 0
|
|
||||||
const output = model.price.output_price_per_1m ?? 0
|
|
||||||
if (input > 0 || output > 0) {
|
|
||||||
return `$${formatPrice(input)}/$${formatPrice(output)}`
|
|
||||||
}
|
|
||||||
return ''
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatPrice(value?: number | null): string {
|
|
||||||
if (value === undefined || value === null || value === 0) return '0'
|
|
||||||
if (value >= 1) {
|
|
||||||
return value.toFixed(2)
|
|
||||||
}
|
|
||||||
return value.toFixed(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
function toggleModel(modelName: string, checked: boolean) {
|
|
||||||
if (checked) {
|
|
||||||
if (!selectedModels.value.includes(modelName)) {
|
|
||||||
selectedModels.value = [...selectedModels.value, modelName]
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
selectedModels.value = selectedModels.value.filter(name => name !== modelName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function clearModels() {
|
|
||||||
selectedModels.value = []
|
selectedModels.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
// 测试模型连接
|
// 加载已存在的模型列表
|
||||||
async function testModelConnection(model: ProviderAvailableSourceModel) {
|
async function loadExistingModels() {
|
||||||
if (!props.providerId || !props.apiKey || testingModelName.value) return
|
if (!props.providerId) return
|
||||||
|
|
||||||
testingModelName.value = model.global_model_name
|
|
||||||
try {
|
try {
|
||||||
const result = await testModel({
|
const models = await getProviderModels(props.providerId)
|
||||||
provider_id: props.providerId,
|
existingModelIds.value = new Set(
|
||||||
model_name: model.provider_model_name,
|
models.map((m: { provider_model_name: string }) => m.provider_model_name)
|
||||||
api_key_id: props.apiKey.id,
|
)
|
||||||
message: "hello"
|
} catch {
|
||||||
})
|
existingModelIds.value = new Set()
|
||||||
|
|
||||||
if (result.success) {
|
|
||||||
success(`模型 "${model.display_name}" 测试成功`)
|
|
||||||
} else {
|
|
||||||
showError(`模型测试失败: ${parseTestModelError(result)}`)
|
|
||||||
}
|
|
||||||
} catch (err: any) {
|
|
||||||
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
|
||||||
showError(`模型测试失败: ${errorMsg}`)
|
|
||||||
} finally {
|
|
||||||
testingModelName.value = null
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function areArraysEqual(a: string[], b: string[]): boolean {
|
// 获取上游模型
|
||||||
if (a.length !== b.length) return false
|
async function fetchUpstreamModels() {
|
||||||
const sortedA = [...a].sort()
|
if (!props.providerId || !props.apiKey) return
|
||||||
const sortedB = [...b].sort()
|
|
||||||
return sortedA.every((value, index) => value === sortedB[index])
|
loading.value = true
|
||||||
|
errorMessage.value = ''
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
|
||||||
|
|
||||||
|
if (response.success && response.data?.models) {
|
||||||
|
upstreamModels.value = response.data.models
|
||||||
|
// 默认选中所有新模型
|
||||||
|
selectedModels.value = response.data.models
|
||||||
|
.filter((m: UpstreamModel) => !existingModelIds.value.has(m.id))
|
||||||
|
.map((m: UpstreamModel) => m.id)
|
||||||
|
hasQueried.value = true
|
||||||
|
// 如果有部分失败,显示警告提示
|
||||||
|
if (response.data.error) {
|
||||||
|
showError(`部分格式获取失败: ${response.data.error}`, '警告')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
errorMessage.value = response.data?.error || '获取上游模型失败'
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
errorMessage.value = err.response?.data?.detail || '获取上游模型失败'
|
||||||
|
} finally {
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换模型选择
|
||||||
|
function toggleModel(modelId: string, checked?: boolean) {
|
||||||
|
const shouldSelect = checked !== undefined ? checked : !selectedModels.value.includes(modelId)
|
||||||
|
if (shouldSelect) {
|
||||||
|
if (!selectedModels.value.includes(modelId)) {
|
||||||
|
selectedModels.value = [...selectedModels.value, modelId]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
selectedModels.value = selectedModels.value.filter(id => id !== modelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全选/取消全选
|
||||||
|
function toggleSelectAll(checked: boolean) {
|
||||||
|
if (checked) {
|
||||||
|
selectedModels.value = upstreamModels.value.map(m => m.id)
|
||||||
|
} else {
|
||||||
|
selectedModels.value = []
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleDialogUpdate(value: boolean) {
|
function handleDialogUpdate(value: boolean) {
|
||||||
@@ -332,30 +332,44 @@ function handleCancel() {
|
|||||||
emit('close')
|
emit('close')
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleSave() {
|
// 导入选中的模型
|
||||||
if (!props.apiKey) return
|
async function handleImport() {
|
||||||
|
if (!props.providerId || selectedModels.value.length === 0) return
|
||||||
|
|
||||||
// 检查是否有变化
|
// 过滤出新模型(不在已存在列表中的)
|
||||||
const hasChanged = !areArraysEqual(selectedModels.value, initialModels.value)
|
const modelsToImport = selectedModels.value.filter(id => !existingModelIds.value.has(id))
|
||||||
if (!hasChanged) {
|
if (modelsToImport.length === 0) {
|
||||||
emit('close')
|
showError('所选模型都已存在', '提示')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
saving.value = true
|
importing.value = true
|
||||||
try {
|
try {
|
||||||
await updateEndpointKey(props.apiKey.id, {
|
const response = await importModelsFromUpstream(props.providerId, modelsToImport)
|
||||||
// 空数组时发送 null,表示允许所有模型
|
|
||||||
allowed_models: selectedModels.value.length > 0 ? [...selectedModels.value] : null
|
const successCount = response.success?.length || 0
|
||||||
})
|
const errorCount = response.errors?.length || 0
|
||||||
success('允许的模型已更新', '成功')
|
|
||||||
emit('saved')
|
if (successCount > 0 && errorCount === 0) {
|
||||||
emit('close')
|
success(`成功导入 ${successCount} 个模型`, '导入成功')
|
||||||
|
emit('saved')
|
||||||
|
emit('close')
|
||||||
|
} else if (successCount > 0 && errorCount > 0) {
|
||||||
|
success(`成功导入 ${successCount} 个模型,${errorCount} 个失败`, '部分成功')
|
||||||
|
emit('saved')
|
||||||
|
// 刷新列表以更新已存在状态
|
||||||
|
await loadExistingModels()
|
||||||
|
// 更新选中列表,移除已成功导入的
|
||||||
|
const successIds = new Set(response.success?.map((s: { model_id: string }) => s.model_id) || [])
|
||||||
|
selectedModels.value = selectedModels.value.filter(id => !successIds.has(id))
|
||||||
|
} else {
|
||||||
|
const errorMsg = response.errors?.[0]?.error || '导入失败'
|
||||||
|
showError(errorMsg, '导入失败')
|
||||||
|
}
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
const errorMessage = parseApiError(err, '保存失败')
|
showError(err.response?.data?.detail || '导入失败', '错误')
|
||||||
showError(errorMessage, '错误')
|
|
||||||
} finally {
|
} finally {
|
||||||
saving.value = false
|
importing.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -0,0 +1,696 @@
|
|||||||
|
<template>
|
||||||
|
<Dialog
|
||||||
|
:model-value="isOpen"
|
||||||
|
title="模型权限"
|
||||||
|
:description="`管理密钥 ${props.apiKey?.name || ''} 可访问的模型,清空右侧列表表示允许全部`"
|
||||||
|
:icon="Shield"
|
||||||
|
size="4xl"
|
||||||
|
@update:model-value="handleDialogUpdate"
|
||||||
|
>
|
||||||
|
<template #default>
|
||||||
|
<div class="space-y-4">
|
||||||
|
<!-- 字典模式警告 -->
|
||||||
|
<div
|
||||||
|
v-if="isDictMode"
|
||||||
|
class="rounded-lg border border-amber-500/50 bg-amber-50 dark:bg-amber-950/30 p-3"
|
||||||
|
>
|
||||||
|
<p class="text-sm text-amber-700 dark:text-amber-400">
|
||||||
|
<strong>注意:</strong>此密钥使用按 API 格式区分的模型权限配置。
|
||||||
|
编辑后将转换为统一列表模式,原有的格式区分信息将丢失。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 密钥信息头部 -->
|
||||||
|
<div class="rounded-lg border bg-muted/30 p-4">
|
||||||
|
<div class="flex items-start justify-between">
|
||||||
|
<div>
|
||||||
|
<p class="font-semibold text-lg">{{ apiKey?.name }}</p>
|
||||||
|
<p class="text-sm text-muted-foreground font-mono">
|
||||||
|
{{ apiKey?.api_key_masked }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<Badge
|
||||||
|
:variant="allowedModels.length === 0 ? 'default' : 'outline'"
|
||||||
|
class="text-xs"
|
||||||
|
>
|
||||||
|
{{ allowedModels.length === 0 ? '允许全部' : `限制 ${allowedModels.length} 个模型` }}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 左右对比布局 -->
|
||||||
|
<div class="flex gap-2 items-stretch">
|
||||||
|
<!-- 左侧:可添加的模型 -->
|
||||||
|
<div class="flex-1 space-y-2">
|
||||||
|
<div class="flex items-center justify-between gap-2">
|
||||||
|
<p class="text-sm font-medium shrink-0">可添加</p>
|
||||||
|
<div class="flex-1 relative">
|
||||||
|
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
|
||||||
|
<Input
|
||||||
|
v-model="searchQuery"
|
||||||
|
placeholder="搜索模型..."
|
||||||
|
class="pl-7 h-7 text-xs"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
v-if="upstreamModelsLoaded"
|
||||||
|
type="button"
|
||||||
|
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
|
||||||
|
title="刷新上游模型"
|
||||||
|
:disabled="fetchingUpstreamModels"
|
||||||
|
@click="fetchUpstreamModels()"
|
||||||
|
>
|
||||||
|
<RefreshCw
|
||||||
|
class="w-3.5 h-3.5"
|
||||||
|
:class="{ 'animate-spin': fetchingUpstreamModels }"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
v-else-if="!fetchingUpstreamModels"
|
||||||
|
type="button"
|
||||||
|
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
|
||||||
|
title="从提供商获取模型"
|
||||||
|
@click="fetchUpstreamModels()"
|
||||||
|
>
|
||||||
|
<Zap class="w-3.5 h-3.5" />
|
||||||
|
</button>
|
||||||
|
<Loader2
|
||||||
|
v-else
|
||||||
|
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="border rounded-lg h-80 overflow-y-auto">
|
||||||
|
<div
|
||||||
|
v-if="loadingGlobalModels"
|
||||||
|
class="flex items-center justify-center h-full"
|
||||||
|
>
|
||||||
|
<Loader2 class="w-6 h-6 animate-spin text-primary" />
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-else-if="totalAvailableCount === 0 && !upstreamModelsLoaded"
|
||||||
|
class="flex flex-col items-center justify-center h-full text-muted-foreground"
|
||||||
|
>
|
||||||
|
<Shield class="w-10 h-10 mb-2 opacity-30" />
|
||||||
|
<p class="text-sm">{{ searchQuery ? '无匹配结果' : '暂无可添加模型' }}</p>
|
||||||
|
</div>
|
||||||
|
<div v-else class="p-2 space-y-2">
|
||||||
|
<!-- 全局模型折叠组 -->
|
||||||
|
<div
|
||||||
|
v-if="availableGlobalModels.length > 0 || !upstreamModelsLoaded"
|
||||||
|
class="border rounded-lg overflow-hidden"
|
||||||
|
>
|
||||||
|
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
|
||||||
|
@click="toggleGroupCollapse('global')"
|
||||||
|
>
|
||||||
|
<ChevronDown
|
||||||
|
class="w-4 h-4 transition-transform shrink-0"
|
||||||
|
:class="collapsedGroups.has('global') ? '-rotate-90' : ''"
|
||||||
|
/>
|
||||||
|
<span class="text-xs font-medium">全局模型</span>
|
||||||
|
<span class="text-xs text-muted-foreground">
|
||||||
|
({{ availableGlobalModels.length }})
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
v-if="availableGlobalModels.length > 0"
|
||||||
|
type="button"
|
||||||
|
class="text-xs text-primary hover:underline shrink-0"
|
||||||
|
@click.stop="selectAllGlobalModels"
|
||||||
|
>
|
||||||
|
{{ isAllGlobalModelsSelected ? '取消' : '全选' }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-show="!collapsedGroups.has('global')"
|
||||||
|
class="p-2 space-y-1 border-t"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
v-if="availableGlobalModels.length === 0"
|
||||||
|
class="py-4 text-center text-xs text-muted-foreground"
|
||||||
|
>
|
||||||
|
所有全局模型均已添加
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-for="model in availableGlobalModels"
|
||||||
|
v-else
|
||||||
|
:key="model.name"
|
||||||
|
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
|
||||||
|
:class="selectedLeftIds.includes(model.name)
|
||||||
|
? 'border-primary bg-primary/10'
|
||||||
|
: 'hover:bg-muted/50'"
|
||||||
|
@click="toggleLeftSelection(model.name)"
|
||||||
|
>
|
||||||
|
<Checkbox
|
||||||
|
:checked="selectedLeftIds.includes(model.name)"
|
||||||
|
@update:checked="toggleLeftSelection(model.name)"
|
||||||
|
@click.stop
|
||||||
|
/>
|
||||||
|
<div class="flex-1 min-w-0">
|
||||||
|
<p class="font-medium text-sm truncate">{{ model.display_name }}</p>
|
||||||
|
<p class="text-xs text-muted-foreground truncate font-mono">{{ model.name }}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 从提供商获取的模型折叠组 -->
|
||||||
|
<div
|
||||||
|
v-for="group in upstreamModelGroups"
|
||||||
|
:key="group.api_format"
|
||||||
|
class="border rounded-lg overflow-hidden"
|
||||||
|
>
|
||||||
|
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
|
||||||
|
@click="toggleGroupCollapse(group.api_format)"
|
||||||
|
>
|
||||||
|
<ChevronDown
|
||||||
|
class="w-4 h-4 transition-transform shrink-0"
|
||||||
|
:class="collapsedGroups.has(group.api_format) ? '-rotate-90' : ''"
|
||||||
|
/>
|
||||||
|
<span class="text-xs font-medium">
|
||||||
|
{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}
|
||||||
|
</span>
|
||||||
|
<span class="text-xs text-muted-foreground">
|
||||||
|
({{ group.models.length }})
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="text-xs text-primary hover:underline shrink-0"
|
||||||
|
@click.stop="selectAllUpstreamModels(group.api_format)"
|
||||||
|
>
|
||||||
|
{{ isUpstreamGroupAllSelected(group.api_format) ? '取消' : '全选' }}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-show="!collapsedGroups.has(group.api_format)"
|
||||||
|
class="p-2 space-y-1 border-t"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
v-for="model in group.models"
|
||||||
|
:key="model.id"
|
||||||
|
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
|
||||||
|
:class="selectedLeftIds.includes(model.id)
|
||||||
|
? 'border-primary bg-primary/10'
|
||||||
|
: 'hover:bg-muted/50'"
|
||||||
|
@click="toggleLeftSelection(model.id)"
|
||||||
|
>
|
||||||
|
<Checkbox
|
||||||
|
:checked="selectedLeftIds.includes(model.id)"
|
||||||
|
@update:checked="toggleLeftSelection(model.id)"
|
||||||
|
@click.stop
|
||||||
|
/>
|
||||||
|
<div class="flex-1 min-w-0">
|
||||||
|
<p class="font-medium text-sm truncate">{{ model.id }}</p>
|
||||||
|
<p class="text-xs text-muted-foreground truncate font-mono">
|
||||||
|
{{ model.owned_by || model.id }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 中间:操作按钮 -->
|
||||||
|
<div class="flex flex-col items-center justify-center w-12 shrink-0 gap-2">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
class="w-9 h-8"
|
||||||
|
:class="selectedLeftIds.length > 0 ? 'border-primary' : ''"
|
||||||
|
:disabled="selectedLeftIds.length === 0"
|
||||||
|
title="添加选中"
|
||||||
|
@click="addSelected"
|
||||||
|
>
|
||||||
|
<ChevronRight
|
||||||
|
class="w-6 h-6 stroke-[3]"
|
||||||
|
:class="selectedLeftIds.length > 0 ? 'text-primary' : ''"
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
class="w-9 h-8"
|
||||||
|
:class="selectedRightIds.length > 0 ? 'border-primary' : ''"
|
||||||
|
:disabled="selectedRightIds.length === 0"
|
||||||
|
title="移除选中"
|
||||||
|
@click="removeSelected"
|
||||||
|
>
|
||||||
|
<ChevronLeft
|
||||||
|
class="w-6 h-6 stroke-[3]"
|
||||||
|
:class="selectedRightIds.length > 0 ? 'text-primary' : ''"
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 右侧:已添加的允许模型 -->
|
||||||
|
<div class="flex-1 space-y-2">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<p class="text-sm font-medium">已添加</p>
|
||||||
|
<Button
|
||||||
|
v-if="allowedModels.length > 0"
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
class="h-6 px-2 text-xs"
|
||||||
|
@click="toggleSelectAllRight"
|
||||||
|
>
|
||||||
|
{{ isAllRightSelected ? '取消' : '全选' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div class="border rounded-lg h-80 overflow-y-auto">
|
||||||
|
<div
|
||||||
|
v-if="allowedModels.length === 0"
|
||||||
|
class="flex flex-col items-center justify-center h-full text-muted-foreground"
|
||||||
|
>
|
||||||
|
<Shield class="w-10 h-10 mb-2 opacity-30" />
|
||||||
|
<p class="text-sm">允许访问全部模型</p>
|
||||||
|
<p class="text-xs mt-1">添加模型以限制访问范围</p>
|
||||||
|
</div>
|
||||||
|
<div v-else class="p-2 space-y-1">
|
||||||
|
<div
|
||||||
|
v-for="modelName in allowedModels"
|
||||||
|
:key="'allowed-' + modelName"
|
||||||
|
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
|
||||||
|
:class="selectedRightIds.includes(modelName)
|
||||||
|
? 'border-primary bg-primary/10'
|
||||||
|
: 'hover:bg-muted/50'"
|
||||||
|
@click="toggleRightSelection(modelName)"
|
||||||
|
>
|
||||||
|
<Checkbox
|
||||||
|
:checked="selectedRightIds.includes(modelName)"
|
||||||
|
@update:checked="toggleRightSelection(modelName)"
|
||||||
|
@click.stop
|
||||||
|
/>
|
||||||
|
<div class="flex-1 min-w-0">
|
||||||
|
<p class="font-medium text-sm truncate">
|
||||||
|
{{ getModelDisplayName(modelName) }}
|
||||||
|
</p>
|
||||||
|
<p class="text-xs text-muted-foreground truncate font-mono">
|
||||||
|
{{ modelName }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<div class="flex items-center justify-between w-full">
|
||||||
|
<p class="text-xs text-muted-foreground">
|
||||||
|
{{ hasChanges ? '有未保存的更改' : '' }}
|
||||||
|
</p>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Button variant="outline" @click="handleCancel">取消</Button>
|
||||||
|
<Button :disabled="saving || !hasChanges" @click="handleSave">
|
||||||
|
{{ saving ? '保存中...' : '保存' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
</Dialog>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed, watch, onUnmounted } from 'vue'
|
||||||
|
import {
|
||||||
|
Shield,
|
||||||
|
Search,
|
||||||
|
RefreshCw,
|
||||||
|
Loader2,
|
||||||
|
Zap,
|
||||||
|
ChevronRight,
|
||||||
|
ChevronLeft,
|
||||||
|
ChevronDown
|
||||||
|
} from 'lucide-vue-next'
|
||||||
|
import { Dialog, Button, Input, Checkbox, Badge } from '@/components/ui'
|
||||||
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { parseApiError } from '@/utils/errorParser'
|
||||||
|
import {
|
||||||
|
updateProviderKey,
|
||||||
|
API_FORMAT_LABELS,
|
||||||
|
type EndpointAPIKey,
|
||||||
|
type AllowedModels,
|
||||||
|
} from '@/api/endpoints'
|
||||||
|
import { getGlobalModels, type GlobalModelResponse } from '@/api/global-models'
|
||||||
|
import { adminApi } from '@/api/admin'
|
||||||
|
import type { UpstreamModel } from '@/api/endpoints/types'
|
||||||
|
|
||||||
|
interface AvailableModel {
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
open: boolean
|
||||||
|
apiKey: EndpointAPIKey | null
|
||||||
|
providerId: string
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
close: []
|
||||||
|
saved: []
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const { success, error: showError } = useToast()
|
||||||
|
|
||||||
|
const isOpen = computed(() => props.open)
|
||||||
|
const saving = ref(false)
|
||||||
|
const loadingGlobalModels = ref(false)
|
||||||
|
const fetchingUpstreamModels = ref(false)
|
||||||
|
const upstreamModelsLoaded = ref(false)
|
||||||
|
|
||||||
|
// 用于取消异步操作的标志
|
||||||
|
let loadingCancelled = false
|
||||||
|
|
||||||
|
// 搜索
|
||||||
|
const searchQuery = ref('')
|
||||||
|
|
||||||
|
// 折叠状态
|
||||||
|
const collapsedGroups = ref<Set<string>>(new Set())
|
||||||
|
|
||||||
|
// 可用模型列表(全局模型)
|
||||||
|
const allGlobalModels = ref<AvailableModel[]>([])
|
||||||
|
// 上游模型列表
|
||||||
|
const upstreamModels = ref<UpstreamModel[]>([])
|
||||||
|
|
||||||
|
// 已添加的允许模型(右侧)
|
||||||
|
const allowedModels = ref<string[]>([])
|
||||||
|
const initialAllowedModels = ref<string[]>([])
|
||||||
|
|
||||||
|
// 选中状态
|
||||||
|
const selectedLeftIds = ref<string[]>([])
|
||||||
|
const selectedRightIds = ref<string[]>([])
|
||||||
|
|
||||||
|
// 是否有更改
|
||||||
|
const hasChanges = computed(() => {
|
||||||
|
if (allowedModels.value.length !== initialAllowedModels.value.length) return true
|
||||||
|
const sorted1 = [...allowedModels.value].sort()
|
||||||
|
const sorted2 = [...initialAllowedModels.value].sort()
|
||||||
|
return sorted1.some((v, i) => v !== sorted2[i])
|
||||||
|
})
|
||||||
|
|
||||||
|
// 计算可添加的全局模型(排除已添加的)
|
||||||
|
const availableGlobalModelsBase = computed(() => {
|
||||||
|
const allowedSet = new Set(allowedModels.value)
|
||||||
|
return allGlobalModels.value.filter(m => !allowedSet.has(m.name))
|
||||||
|
})
|
||||||
|
|
||||||
|
// 搜索过滤后的全局模型
|
||||||
|
const availableGlobalModels = computed(() => {
|
||||||
|
if (!searchQuery.value.trim()) return availableGlobalModelsBase.value
|
||||||
|
const query = searchQuery.value.toLowerCase()
|
||||||
|
return availableGlobalModelsBase.value.filter(m =>
|
||||||
|
m.name.toLowerCase().includes(query) ||
|
||||||
|
m.display_name.toLowerCase().includes(query)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 计算可添加的上游模型(排除已添加的)
|
||||||
|
const availableUpstreamModelsBase = computed(() => {
|
||||||
|
const allowedSet = new Set(allowedModels.value)
|
||||||
|
return upstreamModels.value.filter(m => !allowedSet.has(m.id))
|
||||||
|
})
|
||||||
|
|
||||||
|
// 搜索过滤后的上游模型
|
||||||
|
const availableUpstreamModels = computed(() => {
|
||||||
|
if (!searchQuery.value.trim()) return availableUpstreamModelsBase.value
|
||||||
|
const query = searchQuery.value.toLowerCase()
|
||||||
|
return availableUpstreamModelsBase.value.filter(m =>
|
||||||
|
m.id.toLowerCase().includes(query) ||
|
||||||
|
(m.owned_by && m.owned_by.toLowerCase().includes(query))
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 按 API 格式分组的上游模型
|
||||||
|
const upstreamModelGroups = computed(() => {
|
||||||
|
const groups: Record<string, UpstreamModel[]> = {}
|
||||||
|
for (const model of availableUpstreamModels.value) {
|
||||||
|
const format = model.api_format || 'unknown'
|
||||||
|
if (!groups[format]) groups[format] = []
|
||||||
|
groups[format].push(model)
|
||||||
|
}
|
||||||
|
const order = Object.keys(API_FORMAT_LABELS)
|
||||||
|
return Object.entries(groups)
|
||||||
|
.map(([api_format, models]) => ({ api_format, models }))
|
||||||
|
.sort((a, b) => {
|
||||||
|
const aIndex = order.indexOf(a.api_format)
|
||||||
|
const bIndex = order.indexOf(b.api_format)
|
||||||
|
if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
|
||||||
|
if (aIndex === -1) return 1
|
||||||
|
if (bIndex === -1) return -1
|
||||||
|
return aIndex - bIndex
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 总可添加数量
|
||||||
|
const totalAvailableCount = computed(() => {
|
||||||
|
return availableGlobalModels.value.length + availableUpstreamModels.value.length
|
||||||
|
})
|
||||||
|
|
||||||
|
// 右侧全选状态
|
||||||
|
const isAllRightSelected = computed(() =>
|
||||||
|
allowedModels.value.length > 0 &&
|
||||||
|
selectedRightIds.value.length === allowedModels.value.length
|
||||||
|
)
|
||||||
|
|
||||||
|
// 全局模型是否全选
|
||||||
|
const isAllGlobalModelsSelected = computed(() => {
|
||||||
|
if (availableGlobalModels.value.length === 0) return false
|
||||||
|
return availableGlobalModels.value.every(m => selectedLeftIds.value.includes(m.name))
|
||||||
|
})
|
||||||
|
|
||||||
|
// 检查某个上游组是否全选
|
||||||
|
function isUpstreamGroupAllSelected(apiFormat: string): boolean {
|
||||||
|
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
|
||||||
|
if (!group || group.models.length === 0) return false
|
||||||
|
return group.models.every(m => selectedLeftIds.value.includes(m.id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取模型显示名称
|
||||||
|
function getModelDisplayName(name: string): string {
|
||||||
|
const globalModel = allGlobalModels.value.find(m => m.name === name)
|
||||||
|
if (globalModel) return globalModel.display_name
|
||||||
|
const upstreamModel = upstreamModels.value.find(m => m.id === name)
|
||||||
|
if (upstreamModel) return upstreamModel.id
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// 加载全局模型
|
||||||
|
async function loadGlobalModels() {
|
||||||
|
loadingGlobalModels.value = true
|
||||||
|
try {
|
||||||
|
const response = await getGlobalModels({ limit: 1000 })
|
||||||
|
// 检查是否已取消(dialog 已关闭)
|
||||||
|
if (loadingCancelled) return
|
||||||
|
allGlobalModels.value = response.models.map((m: GlobalModelResponse) => ({
|
||||||
|
name: m.name,
|
||||||
|
display_name: m.display_name
|
||||||
|
}))
|
||||||
|
} catch (err) {
|
||||||
|
if (loadingCancelled) return
|
||||||
|
showError('加载全局模型失败', '错误')
|
||||||
|
} finally {
|
||||||
|
loadingGlobalModels.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从提供商获取模型(使用当前 key)
|
||||||
|
async function fetchUpstreamModels() {
|
||||||
|
if (!props.providerId || !props.apiKey) return
|
||||||
|
try {
|
||||||
|
fetchingUpstreamModels.value = true
|
||||||
|
// 使用当前 key 的 ID 来查询上游模型
|
||||||
|
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
|
||||||
|
// 检查是否已取消
|
||||||
|
if (loadingCancelled) return
|
||||||
|
if (response.success && response.data?.models) {
|
||||||
|
upstreamModels.value = response.data.models
|
||||||
|
upstreamModelsLoaded.value = true
|
||||||
|
const allGroups = new Set(['global'])
|
||||||
|
for (const model of response.data.models) {
|
||||||
|
if (model.api_format) allGroups.add(model.api_format)
|
||||||
|
}
|
||||||
|
collapsedGroups.value = allGroups
|
||||||
|
} else {
|
||||||
|
showError(response.data?.error || '获取上游模型失败', '错误')
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
if (loadingCancelled) return
|
||||||
|
showError(err.response?.data?.detail || '获取上游模型失败', '错误')
|
||||||
|
} finally {
|
||||||
|
fetchingUpstreamModels.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 切换折叠状态
|
||||||
|
function toggleGroupCollapse(group: string) {
|
||||||
|
if (collapsedGroups.value.has(group)) {
|
||||||
|
collapsedGroups.value.delete(group)
|
||||||
|
} else {
|
||||||
|
collapsedGroups.value.add(group)
|
||||||
|
}
|
||||||
|
collapsedGroups.value = new Set(collapsedGroups.value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是否为字典模式(按 API 格式区分)
|
||||||
|
const isDictMode = ref(false)
|
||||||
|
|
||||||
|
// 解析 allowed_models
|
||||||
|
function parseAllowedModels(allowed: AllowedModels): string[] {
|
||||||
|
if (allowed === null || allowed === undefined) {
|
||||||
|
isDictMode.value = false
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
if (Array.isArray(allowed)) {
|
||||||
|
isDictMode.value = false
|
||||||
|
return [...allowed]
|
||||||
|
}
|
||||||
|
// 字典模式:合并所有格式的模型,并设置警告标志
|
||||||
|
isDictMode.value = true
|
||||||
|
const all = new Set<string>()
|
||||||
|
for (const models of Object.values(allowed)) {
|
||||||
|
models.forEach(m => all.add(m))
|
||||||
|
}
|
||||||
|
return Array.from(all)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 左侧选择
|
||||||
|
function toggleLeftSelection(name: string) {
|
||||||
|
const idx = selectedLeftIds.value.indexOf(name)
|
||||||
|
if (idx === -1) {
|
||||||
|
selectedLeftIds.value.push(name)
|
||||||
|
} else {
|
||||||
|
selectedLeftIds.value.splice(idx, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 右侧选择
|
||||||
|
function toggleRightSelection(name: string) {
|
||||||
|
const idx = selectedRightIds.value.indexOf(name)
|
||||||
|
if (idx === -1) {
|
||||||
|
selectedRightIds.value.push(name)
|
||||||
|
} else {
|
||||||
|
selectedRightIds.value.splice(idx, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 右侧全选切换
|
||||||
|
function toggleSelectAllRight() {
|
||||||
|
if (isAllRightSelected.value) {
|
||||||
|
selectedRightIds.value = []
|
||||||
|
} else {
|
||||||
|
selectedRightIds.value = [...allowedModels.value]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全选全局模型
|
||||||
|
function selectAllGlobalModels() {
|
||||||
|
const allNames = availableGlobalModels.value.map(m => m.name)
|
||||||
|
const allSelected = allNames.every(name => selectedLeftIds.value.includes(name))
|
||||||
|
if (allSelected) {
|
||||||
|
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allNames.includes(id))
|
||||||
|
} else {
|
||||||
|
const newNames = allNames.filter(name => !selectedLeftIds.value.includes(name))
|
||||||
|
selectedLeftIds.value.push(...newNames)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 全选某个 API 格式的上游模型
|
||||||
|
function selectAllUpstreamModels(apiFormat: string) {
|
||||||
|
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
|
||||||
|
if (!group) return
|
||||||
|
const allIds = group.models.map(m => m.id)
|
||||||
|
const allSelected = allIds.every(id => selectedLeftIds.value.includes(id))
|
||||||
|
if (allSelected) {
|
||||||
|
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allIds.includes(id))
|
||||||
|
} else {
|
||||||
|
const newIds = allIds.filter(id => !selectedLeftIds.value.includes(id))
|
||||||
|
selectedLeftIds.value.push(...newIds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加选中的模型到右侧
|
||||||
|
function addSelected() {
|
||||||
|
for (const name of selectedLeftIds.value) {
|
||||||
|
if (!allowedModels.value.includes(name)) {
|
||||||
|
allowedModels.value.push(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
selectedLeftIds.value = []
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从右侧移除选中的模型
|
||||||
|
function removeSelected() {
|
||||||
|
allowedModels.value = allowedModels.value.filter(
|
||||||
|
name => !selectedRightIds.value.includes(name)
|
||||||
|
)
|
||||||
|
selectedRightIds.value = []
|
||||||
|
}
|
||||||
|
|
||||||
|
// 监听对话框打开
|
||||||
|
watch(() => props.open, async (open) => {
|
||||||
|
if (open && props.apiKey) {
|
||||||
|
// 重置取消标志
|
||||||
|
loadingCancelled = false
|
||||||
|
|
||||||
|
const parsed = parseAllowedModels(props.apiKey.allowed_models ?? null)
|
||||||
|
allowedModels.value = [...parsed]
|
||||||
|
initialAllowedModels.value = [...parsed]
|
||||||
|
selectedLeftIds.value = []
|
||||||
|
selectedRightIds.value = []
|
||||||
|
searchQuery.value = ''
|
||||||
|
upstreamModels.value = []
|
||||||
|
upstreamModelsLoaded.value = false
|
||||||
|
collapsedGroups.value = new Set()
|
||||||
|
|
||||||
|
await loadGlobalModels()
|
||||||
|
} else {
|
||||||
|
// dialog 关闭时设置取消标志
|
||||||
|
loadingCancelled = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 组件卸载时取消所有异步操作
|
||||||
|
onUnmounted(() => {
|
||||||
|
loadingCancelled = true
|
||||||
|
})
|
||||||
|
|
||||||
|
function handleDialogUpdate(value: boolean) {
|
||||||
|
if (!value) emit('close')
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleCancel() {
|
||||||
|
emit('close')
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleSave() {
|
||||||
|
if (!props.apiKey) return
|
||||||
|
|
||||||
|
saving.value = true
|
||||||
|
try {
|
||||||
|
// 空列表 = null(允许全部)
|
||||||
|
const newAllowed: AllowedModels = allowedModels.value.length > 0
|
||||||
|
? [...allowedModels.value]
|
||||||
|
: null
|
||||||
|
|
||||||
|
await updateProviderKey(props.apiKey.id, { allowed_models: newAllowed })
|
||||||
|
success('模型权限已更新', '成功')
|
||||||
|
emit('saved')
|
||||||
|
emit('close')
|
||||||
|
} catch (err: any) {
|
||||||
|
showError(parseApiError(err, '保存失败'), '错误')
|
||||||
|
} finally {
|
||||||
|
saving.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
@@ -2,57 +2,36 @@
|
|||||||
<Dialog
|
<Dialog
|
||||||
:model-value="isOpen"
|
:model-value="isOpen"
|
||||||
:title="isEditMode ? '编辑密钥' : '添加密钥'"
|
:title="isEditMode ? '编辑密钥' : '添加密钥'"
|
||||||
:description="isEditMode ? '修改 API 密钥配置' : '为端点添加新的 API 密钥'"
|
:description="isEditMode ? '修改 API 密钥配置' : '为提供商添加新的 API 密钥'"
|
||||||
:icon="isEditMode ? SquarePen : Key"
|
:icon="isEditMode ? SquarePen : Key"
|
||||||
size="2xl"
|
size="2xl"
|
||||||
@update:model-value="handleDialogUpdate"
|
@update:model-value="handleDialogUpdate"
|
||||||
>
|
>
|
||||||
<form
|
<form
|
||||||
class="space-y-5"
|
class="space-y-4"
|
||||||
autocomplete="off"
|
autocomplete="off"
|
||||||
@submit.prevent="handleSave"
|
@submit.prevent="handleSave"
|
||||||
>
|
>
|
||||||
<!-- 基本信息 -->
|
<!-- 基本信息 -->
|
||||||
<div class="space-y-3">
|
<div class="grid grid-cols-2 gap-3">
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
<div>
|
||||||
基本信息
|
<Label :for="keyNameInputId">密钥名称 *</Label>
|
||||||
</h3>
|
<Input
|
||||||
<div class="grid grid-cols-2 gap-4">
|
:id="keyNameInputId"
|
||||||
<div>
|
v-model="form.name"
|
||||||
<Label :for="keyNameInputId">密钥名称 *</Label>
|
:name="keyNameFieldName"
|
||||||
<Input
|
required
|
||||||
:id="keyNameInputId"
|
placeholder="例如:主 Key、备用 Key 1"
|
||||||
v-model="form.name"
|
maxlength="100"
|
||||||
:name="keyNameFieldName"
|
autocomplete="off"
|
||||||
required
|
autocapitalize="none"
|
||||||
placeholder="例如:主 Key、备用 Key 1"
|
autocorrect="off"
|
||||||
maxlength="100"
|
spellcheck="false"
|
||||||
autocomplete="off"
|
data-form-type="other"
|
||||||
autocapitalize="none"
|
data-lpignore="true"
|
||||||
autocorrect="off"
|
data-1p-ignore="true"
|
||||||
spellcheck="false"
|
/>
|
||||||
data-form-type="other"
|
|
||||||
data-lpignore="true"
|
|
||||||
data-1p-ignore="true"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Label for="rate_multiplier">成本倍率 *</Label>
|
|
||||||
<Input
|
|
||||||
id="rate_multiplier"
|
|
||||||
v-model.number="form.rate_multiplier"
|
|
||||||
type="number"
|
|
||||||
step="0.01"
|
|
||||||
min="0.01"
|
|
||||||
required
|
|
||||||
placeholder="1.0"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground mt-1">
|
|
||||||
真实成本 = 表面成本 × 倍率
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<Label :for="apiKeyInputId">API 密钥 {{ editingKey ? '' : '*' }}</Label>
|
<Label :for="apiKeyInputId">API 密钥 {{ editingKey ? '' : '*' }}</Label>
|
||||||
<Input
|
<Input
|
||||||
@@ -83,148 +62,161 @@
|
|||||||
v-else-if="editingKey"
|
v-else-if="editingKey"
|
||||||
class="text-xs text-muted-foreground mt-1"
|
class="text-xs text-muted-foreground mt-1"
|
||||||
>
|
>
|
||||||
留空表示不修改,输入新值则覆盖
|
留空表示不修改
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 备注 -->
|
||||||
|
<div>
|
||||||
|
<Label for="note">备注</Label>
|
||||||
|
<Input
|
||||||
|
id="note"
|
||||||
|
v-model="form.note"
|
||||||
|
placeholder="可选的备注信息"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- API 格式选择 -->
|
||||||
|
<div v-if="sortedApiFormats.length > 0">
|
||||||
|
<Label class="mb-1.5 block">支持的 API 格式 *</Label>
|
||||||
|
<div class="grid grid-cols-2 sm:grid-cols-3 gap-2">
|
||||||
|
<div
|
||||||
|
v-for="format in sortedApiFormats"
|
||||||
|
:key="format"
|
||||||
|
class="flex items-center justify-between rounded-md border px-2 py-1.5 transition-colors cursor-pointer"
|
||||||
|
:class="form.api_formats.includes(format)
|
||||||
|
? 'bg-primary/5 border-primary/30'
|
||||||
|
: 'bg-muted/30 border-border hover:border-muted-foreground/30'"
|
||||||
|
@click="toggleApiFormat(format)"
|
||||||
|
>
|
||||||
|
<div class="flex items-center gap-1.5 min-w-0">
|
||||||
|
<span
|
||||||
|
class="w-4 h-4 rounded border flex items-center justify-center text-xs shrink-0"
|
||||||
|
:class="form.api_formats.includes(format)
|
||||||
|
? 'bg-primary border-primary text-primary-foreground'
|
||||||
|
: 'border-muted-foreground/30'"
|
||||||
|
>
|
||||||
|
<span v-if="form.api_formats.includes(format)">✓</span>
|
||||||
|
</span>
|
||||||
|
<span
|
||||||
|
class="text-sm whitespace-nowrap"
|
||||||
|
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
|
||||||
|
>{{ API_FORMAT_LABELS[format] || format }}</span>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
class="flex items-center shrink-0 ml-2 text-xs text-muted-foreground gap-1"
|
||||||
|
@click.stop
|
||||||
|
>
|
||||||
|
<span>×</span>
|
||||||
|
<input
|
||||||
|
:value="form.rate_multipliers[format] ?? ''"
|
||||||
|
type="number"
|
||||||
|
step="0.01"
|
||||||
|
min="0.01"
|
||||||
|
placeholder="1"
|
||||||
|
class="w-9 bg-transparent text-right outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||||
|
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
|
||||||
|
title="成本倍率"
|
||||||
|
@input="(e) => updateRateMultiplier(format, (e.target as HTMLInputElement).value)"
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 配置项 -->
|
||||||
|
<div class="grid grid-cols-4 gap-3">
|
||||||
<div>
|
<div>
|
||||||
<Label for="note">备注</Label>
|
<Label
|
||||||
|
for="internal_priority"
|
||||||
|
class="text-xs"
|
||||||
|
>优先级</Label>
|
||||||
<Input
|
<Input
|
||||||
id="note"
|
id="internal_priority"
|
||||||
v-model="form.note"
|
v-model.number="form.internal_priority"
|
||||||
placeholder="可选的备注信息"
|
type="number"
|
||||||
|
min="0"
|
||||||
|
class="h-8"
|
||||||
/>
|
/>
|
||||||
|
<p class="text-xs text-muted-foreground mt-0.5">
|
||||||
|
越小越优先
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="rpm_limit"
|
||||||
|
class="text-xs"
|
||||||
|
>RPM 限制</Label>
|
||||||
|
<Input
|
||||||
|
id="rpm_limit"
|
||||||
|
:model-value="form.rpm_limit ?? ''"
|
||||||
|
type="number"
|
||||||
|
min="1"
|
||||||
|
max="10000"
|
||||||
|
placeholder="自适应"
|
||||||
|
class="h-8"
|
||||||
|
@update:model-value="(v) => form.rpm_limit = parseNullableNumberInput(v, { min: 1, max: 10000 })"
|
||||||
|
/>
|
||||||
|
<p class="text-xs text-muted-foreground mt-0.5">
|
||||||
|
留空自适应
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="cache_ttl_minutes"
|
||||||
|
class="text-xs"
|
||||||
|
>缓存 TTL</Label>
|
||||||
|
<Input
|
||||||
|
id="cache_ttl_minutes"
|
||||||
|
:model-value="form.cache_ttl_minutes ?? ''"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="60"
|
||||||
|
class="h-8"
|
||||||
|
@update:model-value="(v) => form.cache_ttl_minutes = parseNumberInput(v, { min: 0, max: 60 }) ?? 5"
|
||||||
|
/>
|
||||||
|
<p class="text-xs text-muted-foreground mt-0.5">
|
||||||
|
分钟,0禁用
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="max_probe_interval_minutes"
|
||||||
|
class="text-xs"
|
||||||
|
>熔断探测</Label>
|
||||||
|
<Input
|
||||||
|
id="max_probe_interval_minutes"
|
||||||
|
:model-value="form.max_probe_interval_minutes ?? ''"
|
||||||
|
type="number"
|
||||||
|
min="2"
|
||||||
|
max="32"
|
||||||
|
placeholder="32"
|
||||||
|
class="h-8"
|
||||||
|
@update:model-value="(v) => form.max_probe_interval_minutes = parseNumberInput(v, { min: 2, max: 32 }) ?? 32"
|
||||||
|
/>
|
||||||
|
<p class="text-xs text-muted-foreground mt-0.5">
|
||||||
|
分钟,2-32
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 调度与限流 -->
|
<!-- 能力标签 -->
|
||||||
<div class="space-y-3">
|
<div v-if="availableCapabilities.length > 0">
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
<Label class="text-xs mb-1.5 block">能力标签</Label>
|
||||||
调度与限流
|
<div class="flex flex-wrap gap-1.5">
|
||||||
</h3>
|
<button
|
||||||
<div class="grid grid-cols-2 gap-4">
|
|
||||||
<div>
|
|
||||||
<Label for="internal_priority">内部优先级</Label>
|
|
||||||
<Input
|
|
||||||
id="internal_priority"
|
|
||||||
v-model.number="form.internal_priority"
|
|
||||||
type="number"
|
|
||||||
min="0"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground mt-1">
|
|
||||||
数字越小越优先
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Label for="max_concurrent">最大并发</Label>
|
|
||||||
<Input
|
|
||||||
id="max_concurrent"
|
|
||||||
:model-value="form.max_concurrent ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="1"
|
|
||||||
placeholder="留空启用自适应"
|
|
||||||
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground mt-1">
|
|
||||||
留空 = 自适应模式
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="grid grid-cols-3 gap-4">
|
|
||||||
<div>
|
|
||||||
<Label for="rate_limit">速率限制(/分钟)</Label>
|
|
||||||
<Input
|
|
||||||
id="rate_limit"
|
|
||||||
:model-value="form.rate_limit ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="1"
|
|
||||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Label for="daily_limit">每日限制</Label>
|
|
||||||
<Input
|
|
||||||
id="daily_limit"
|
|
||||||
:model-value="form.daily_limit ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="1"
|
|
||||||
@update:model-value="(v) => form.daily_limit = parseNumberInput(v)"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Label for="monthly_limit">每月限制</Label>
|
|
||||||
<Input
|
|
||||||
id="monthly_limit"
|
|
||||||
:model-value="form.monthly_limit ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="1"
|
|
||||||
@update:model-value="(v) => form.monthly_limit = parseNumberInput(v)"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 缓存与熔断 -->
|
|
||||||
<div class="space-y-3">
|
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
|
||||||
缓存与熔断
|
|
||||||
</h3>
|
|
||||||
<div class="grid grid-cols-2 gap-4">
|
|
||||||
<div>
|
|
||||||
<Label for="cache_ttl_minutes">缓存 TTL (分钟)</Label>
|
|
||||||
<Input
|
|
||||||
id="cache_ttl_minutes"
|
|
||||||
:model-value="form.cache_ttl_minutes ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="0"
|
|
||||||
max="60"
|
|
||||||
@update:model-value="(v) => form.cache_ttl_minutes = parseNumberInput(v, { min: 0, max: 60 }) ?? 5"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground mt-1">
|
|
||||||
0 = 禁用缓存亲和性
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Label for="max_probe_interval_minutes">熔断探测间隔 (分钟)</Label>
|
|
||||||
<Input
|
|
||||||
id="max_probe_interval_minutes"
|
|
||||||
:model-value="form.max_probe_interval_minutes ?? ''"
|
|
||||||
type="number"
|
|
||||||
min="2"
|
|
||||||
max="32"
|
|
||||||
placeholder="32"
|
|
||||||
@update:model-value="(v) => form.max_probe_interval_minutes = parseNumberInput(v, { min: 2, max: 32 }) ?? 32"
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground mt-1">
|
|
||||||
范围 2-32 分钟
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 能力标签配置 -->
|
|
||||||
<div
|
|
||||||
v-if="availableCapabilities.length > 0"
|
|
||||||
class="space-y-3"
|
|
||||||
>
|
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
|
||||||
能力标签
|
|
||||||
</h3>
|
|
||||||
<div class="flex flex-wrap gap-2">
|
|
||||||
<label
|
|
||||||
v-for="cap in availableCapabilities"
|
v-for="cap in availableCapabilities"
|
||||||
:key="cap.name"
|
:key="cap.name"
|
||||||
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
|
type="button"
|
||||||
|
class="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-md border text-sm transition-colors"
|
||||||
|
:class="form.capabilities[cap.name]
|
||||||
|
? 'bg-primary/10 border-primary/50 text-primary'
|
||||||
|
: 'bg-card border-border hover:bg-muted/50 text-muted-foreground'"
|
||||||
|
@click="form.capabilities[cap.name] = !form.capabilities[cap.name]"
|
||||||
>
|
>
|
||||||
<input
|
{{ cap.display_name }}
|
||||||
type="checkbox"
|
</button>
|
||||||
:checked="form.capabilities[cap.name] || false"
|
|
||||||
class="rounded"
|
|
||||||
@change="form.capabilities[cap.name] = !form.capabilities[cap.name]"
|
|
||||||
>
|
|
||||||
<span>{{ cap.display_name }}</span>
|
|
||||||
</label>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
@@ -240,25 +232,27 @@
|
|||||||
:disabled="saving"
|
:disabled="saving"
|
||||||
@click="handleSave"
|
@click="handleSave"
|
||||||
>
|
>
|
||||||
{{ saving ? '保存中...' : '保存' }}
|
{{ saving ? (isEditMode ? '保存中...' : '添加中...') : (isEditMode ? '保存' : '添加') }}
|
||||||
</Button>
|
</Button>
|
||||||
</template>
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted, watch } from 'vue'
|
||||||
import { Dialog, Button, Input, Label } from '@/components/ui'
|
import { Dialog, Button, Input, Label } from '@/components/ui'
|
||||||
import { Key, SquarePen } from 'lucide-vue-next'
|
import { Key, SquarePen } from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
import { useFormDialog } from '@/composables/useFormDialog'
|
||||||
import { parseApiError } from '@/utils/errorParser'
|
import { parseApiError } from '@/utils/errorParser'
|
||||||
import { parseNumberInput } from '@/utils/form'
|
import { parseNumberInput, parseNullableNumberInput } from '@/utils/form'
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
import {
|
import {
|
||||||
addEndpointKey,
|
addProviderKey,
|
||||||
updateEndpointKey,
|
updateProviderKey,
|
||||||
getAllCapabilities,
|
getAllCapabilities,
|
||||||
|
API_FORMAT_LABELS,
|
||||||
|
sortApiFormats,
|
||||||
type EndpointAPIKey,
|
type EndpointAPIKey,
|
||||||
type EndpointAPIKeyUpdate,
|
type EndpointAPIKeyUpdate,
|
||||||
type ProviderEndpoint,
|
type ProviderEndpoint,
|
||||||
@@ -270,6 +264,7 @@ const props = defineProps<{
|
|||||||
endpoint: ProviderEndpoint | null
|
endpoint: ProviderEndpoint | null
|
||||||
editingKey: EndpointAPIKey | null
|
editingKey: EndpointAPIKey | null
|
||||||
providerId: string | null
|
providerId: string | null
|
||||||
|
availableApiFormats: string[] // Provider 支持的所有 API 格式
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
@@ -279,6 +274,9 @@ const emit = defineEmits<{
|
|||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
|
||||||
|
// 排序后的可用 API 格式列表
|
||||||
|
const sortedApiFormats = computed(() => sortApiFormats(props.availableApiFormats))
|
||||||
|
|
||||||
const isOpen = computed(() => props.open)
|
const isOpen = computed(() => props.open)
|
||||||
const saving = ref(false)
|
const saving = ref(false)
|
||||||
const formNonce = ref(createFieldNonce())
|
const formNonce = ref(createFieldNonce())
|
||||||
@@ -297,12 +295,10 @@ const availableCapabilities = ref<CapabilityDefinition[]>([])
|
|||||||
const form = ref({
|
const form = ref({
|
||||||
name: '',
|
name: '',
|
||||||
api_key: '',
|
api_key: '',
|
||||||
rate_multiplier: 1.0,
|
api_formats: [] as string[], // 支持的 API 格式列表
|
||||||
internal_priority: 50,
|
rate_multipliers: {} as Record<string, number>, // 按 API 格式的成本倍率
|
||||||
max_concurrent: undefined as number | undefined,
|
internal_priority: 10,
|
||||||
rate_limit: undefined as number | undefined,
|
rpm_limit: undefined as number | null | undefined, // RPM 限制(null=自适应,undefined=保持原值)
|
||||||
daily_limit: undefined as number | undefined,
|
|
||||||
monthly_limit: undefined as number | undefined,
|
|
||||||
cache_ttl_minutes: 5,
|
cache_ttl_minutes: 5,
|
||||||
max_probe_interval_minutes: 32,
|
max_probe_interval_minutes: 32,
|
||||||
note: '',
|
note: '',
|
||||||
@@ -323,6 +319,43 @@ onMounted(() => {
|
|||||||
loadCapabilities()
|
loadCapabilities()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// API 格式切换
|
||||||
|
function toggleApiFormat(format: string) {
|
||||||
|
const index = form.value.api_formats.indexOf(format)
|
||||||
|
if (index === -1) {
|
||||||
|
// 添加格式
|
||||||
|
form.value.api_formats.push(format)
|
||||||
|
} else {
|
||||||
|
// 移除格式前检查:至少保留一个格式
|
||||||
|
if (form.value.api_formats.length <= 1) {
|
||||||
|
showError('至少需要选择一个 API 格式', '验证失败')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 移除格式,但保留倍率配置(用户可能只是临时取消)
|
||||||
|
form.value.api_formats.splice(index, 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新指定格式的成本倍率
|
||||||
|
function updateRateMultiplier(format: string, value: string | number) {
|
||||||
|
// 使用对象替换以确保 Vue 3 响应性
|
||||||
|
const newMultipliers = { ...form.value.rate_multipliers }
|
||||||
|
|
||||||
|
if (value === '' || value === null || value === undefined) {
|
||||||
|
// 清空时删除该格式的配置(使用默认倍率)
|
||||||
|
delete newMultipliers[format]
|
||||||
|
} else {
|
||||||
|
const numValue = typeof value === 'string' ? parseFloat(value) : value
|
||||||
|
// 限制倍率范围:0.01 - 100
|
||||||
|
if (!isNaN(numValue) && numValue >= 0.01 && numValue <= 100) {
|
||||||
|
newMultipliers[format] = numValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换整个对象以触发响应式更新
|
||||||
|
form.value.rate_multipliers = newMultipliers
|
||||||
|
}
|
||||||
|
|
||||||
// API 密钥输入框样式计算
|
// API 密钥输入框样式计算
|
||||||
function getApiKeyInputClass(): string {
|
function getApiKeyInputClass(): string {
|
||||||
const classes = []
|
const classes = []
|
||||||
@@ -349,8 +382,8 @@ const apiKeyError = computed(() => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果输入了值,检查长度
|
// 如果输入了值,检查长度
|
||||||
if (apiKey.length < 10) {
|
if (apiKey.length < 3) {
|
||||||
return 'API 密钥至少需要 10 个字符'
|
return 'API 密钥至少需要 3 个字符'
|
||||||
}
|
}
|
||||||
|
|
||||||
return ''
|
return ''
|
||||||
@@ -363,12 +396,10 @@ function resetForm() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
name: '',
|
name: '',
|
||||||
api_key: '',
|
api_key: '',
|
||||||
rate_multiplier: 1.0,
|
api_formats: [], // 默认不选中任何格式
|
||||||
internal_priority: 50,
|
rate_multipliers: {},
|
||||||
max_concurrent: undefined,
|
internal_priority: 10,
|
||||||
rate_limit: undefined,
|
rpm_limit: undefined,
|
||||||
daily_limit: undefined,
|
|
||||||
monthly_limit: undefined,
|
|
||||||
cache_ttl_minutes: 5,
|
cache_ttl_minutes: 5,
|
||||||
max_probe_interval_minutes: 32,
|
max_probe_interval_minutes: 32,
|
||||||
note: '',
|
note: '',
|
||||||
@@ -377,6 +408,14 @@ function resetForm() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 添加成功后清除部分字段以便继续添加
|
||||||
|
function clearForNextAdd() {
|
||||||
|
formNonce.value = createFieldNonce()
|
||||||
|
apiKeyFocused.value = false
|
||||||
|
form.value.name = ''
|
||||||
|
form.value.api_key = ''
|
||||||
|
}
|
||||||
|
|
||||||
// 加载密钥数据(编辑模式)
|
// 加载密钥数据(编辑模式)
|
||||||
function loadKeyData() {
|
function loadKeyData() {
|
||||||
if (!props.editingKey) return
|
if (!props.editingKey) return
|
||||||
@@ -385,13 +424,13 @@ function loadKeyData() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
name: props.editingKey.name,
|
name: props.editingKey.name,
|
||||||
api_key: '',
|
api_key: '',
|
||||||
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
|
api_formats: props.editingKey.api_formats?.length > 0
|
||||||
internal_priority: props.editingKey.internal_priority ?? 50,
|
? [...props.editingKey.api_formats]
|
||||||
|
: [], // 编辑模式下保持原有选择,不默认全选
|
||||||
|
rate_multipliers: { ...(props.editingKey.rate_multipliers || {}) },
|
||||||
|
internal_priority: props.editingKey.internal_priority ?? 10,
|
||||||
// 保留原始的 null/undefined 状态,null 表示自适应模式
|
// 保留原始的 null/undefined 状态,null 表示自适应模式
|
||||||
max_concurrent: props.editingKey.max_concurrent ?? undefined,
|
rpm_limit: props.editingKey.rpm_limit ?? undefined,
|
||||||
rate_limit: props.editingKey.rate_limit ?? undefined,
|
|
||||||
daily_limit: props.editingKey.daily_limit ?? undefined,
|
|
||||||
monthly_limit: props.editingKey.monthly_limit ?? undefined,
|
|
||||||
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
|
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
|
||||||
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
|
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
|
||||||
note: props.editingKey.note || '',
|
note: props.editingKey.note || '',
|
||||||
@@ -415,7 +454,11 @@ function createFieldNonce(): string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function handleSave() {
|
async function handleSave() {
|
||||||
if (!props.endpoint) return
|
// 必须有 providerId
|
||||||
|
if (!props.providerId) {
|
||||||
|
showError('无法保存:缺少提供商信息', '错误')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 提交前验证
|
// 提交前验证
|
||||||
if (apiKeyError.value) {
|
if (apiKeyError.value) {
|
||||||
@@ -429,6 +472,12 @@ async function handleSave() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证至少选择一个 API 格式
|
||||||
|
if (form.value.api_formats.length === 0) {
|
||||||
|
showError('请至少选择一个 API 格式', '验证失败')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// 过滤出有效的能力配置(只包含值为 true 的)
|
// 过滤出有效的能力配置(只包含值为 true 的)
|
||||||
const activeCapabilities: Record<string, boolean> = {}
|
const activeCapabilities: Record<string, boolean> = {}
|
||||||
for (const [key, value] of Object.entries(form.value.capabilities)) {
|
for (const [key, value] of Object.entries(form.value.capabilities)) {
|
||||||
@@ -440,21 +489,27 @@ async function handleSave() {
|
|||||||
|
|
||||||
saving.value = true
|
saving.value = true
|
||||||
try {
|
try {
|
||||||
|
// 准备 rate_multipliers 数据:只保留已选中格式的倍率配置
|
||||||
|
const filteredMultipliers: Record<string, number> = {}
|
||||||
|
for (const format of form.value.api_formats) {
|
||||||
|
if (form.value.rate_multipliers[format] !== undefined) {
|
||||||
|
filteredMultipliers[format] = form.value.rate_multipliers[format]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const rateMultipliersData = Object.keys(filteredMultipliers).length > 0
|
||||||
|
? filteredMultipliers
|
||||||
|
: null
|
||||||
|
|
||||||
if (props.editingKey) {
|
if (props.editingKey) {
|
||||||
// 更新模式
|
// 更新模式
|
||||||
// 注意:max_concurrent 需要显式发送 null 来切换到自适应模式
|
// 注意:rpm_limit 使用 null 表示自适应模式
|
||||||
// undefined 会在 JSON 中被忽略,所以用 null 表示"清空/自适应"
|
// undefined 表示"保持原值不变"(会在 JSON 序列化时被忽略)
|
||||||
const updateData: EndpointAPIKeyUpdate = {
|
const updateData: EndpointAPIKeyUpdate = {
|
||||||
|
api_formats: form.value.api_formats,
|
||||||
name: form.value.name,
|
name: form.value.name,
|
||||||
rate_multiplier: form.value.rate_multiplier,
|
rate_multipliers: rateMultipliersData,
|
||||||
internal_priority: form.value.internal_priority,
|
internal_priority: form.value.internal_priority,
|
||||||
// 显式使用 null 表示自适应模式,这样后端能区分"未提供"和"设置为 null"
|
rpm_limit: form.value.rpm_limit,
|
||||||
// 注意:只有 max_concurrent 需要这种处理,因为它有"自适应模式"的概念
|
|
||||||
// 其他限制字段(rate_limit 等)不支持"清空"操作,undefined 会被 JSON 忽略即不更新
|
|
||||||
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
|
|
||||||
rate_limit: form.value.rate_limit,
|
|
||||||
daily_limit: form.value.daily_limit,
|
|
||||||
monthly_limit: form.value.monthly_limit,
|
|
||||||
cache_ttl_minutes: form.value.cache_ttl_minutes,
|
cache_ttl_minutes: form.value.cache_ttl_minutes,
|
||||||
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
|
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
|
||||||
note: form.value.note,
|
note: form.value.note,
|
||||||
@@ -466,26 +521,27 @@ async function handleSave() {
|
|||||||
updateData.api_key = form.value.api_key
|
updateData.api_key = form.value.api_key
|
||||||
}
|
}
|
||||||
|
|
||||||
await updateEndpointKey(props.editingKey.id, updateData)
|
await updateProviderKey(props.editingKey.id, updateData)
|
||||||
success('密钥已更新', '成功')
|
success('密钥已更新', '成功')
|
||||||
} else {
|
} else {
|
||||||
// 新增
|
// 新增模式
|
||||||
await addEndpointKey(props.endpoint.id, {
|
await addProviderKey(props.providerId, {
|
||||||
endpoint_id: props.endpoint.id,
|
api_formats: form.value.api_formats,
|
||||||
api_key: form.value.api_key,
|
api_key: form.value.api_key,
|
||||||
name: form.value.name,
|
name: form.value.name,
|
||||||
rate_multiplier: form.value.rate_multiplier,
|
rate_multipliers: rateMultipliersData,
|
||||||
internal_priority: form.value.internal_priority,
|
internal_priority: form.value.internal_priority,
|
||||||
max_concurrent: form.value.max_concurrent,
|
rpm_limit: form.value.rpm_limit,
|
||||||
rate_limit: form.value.rate_limit,
|
|
||||||
daily_limit: form.value.daily_limit,
|
|
||||||
monthly_limit: form.value.monthly_limit,
|
|
||||||
cache_ttl_minutes: form.value.cache_ttl_minutes,
|
cache_ttl_minutes: form.value.cache_ttl_minutes,
|
||||||
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
|
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
|
||||||
note: form.value.note,
|
note: form.value.note,
|
||||||
capabilities: capabilitiesData || undefined
|
capabilities: capabilitiesData || undefined
|
||||||
})
|
})
|
||||||
success('密钥已添加', '成功')
|
success('密钥已添加', '成功')
|
||||||
|
// 添加模式:不关闭对话框,只清除名称和密钥以便继续添加
|
||||||
|
emit('saved')
|
||||||
|
clearForNextAdd()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
emit('saved')
|
emit('saved')
|
||||||
|
|||||||
@@ -95,7 +95,7 @@
|
|||||||
|
|
||||||
<!-- 提供商信息 -->
|
<!-- 提供商信息 -->
|
||||||
<div class="flex-1 min-w-0 flex items-center gap-2">
|
<div class="flex-1 min-w-0 flex items-center gap-2">
|
||||||
<span class="font-medium text-sm truncate">{{ provider.display_name }}</span>
|
<span class="font-medium text-sm truncate">{{ provider.name }}</span>
|
||||||
<Badge
|
<Badge
|
||||||
v-if="!provider.is_active"
|
v-if="!provider.is_active"
|
||||||
variant="secondary"
|
variant="secondary"
|
||||||
@@ -262,17 +262,17 @@
|
|||||||
<div class="shrink-0 flex items-center gap-3">
|
<div class="shrink-0 flex items-center gap-3">
|
||||||
<!-- 健康度 -->
|
<!-- 健康度 -->
|
||||||
<div
|
<div
|
||||||
v-if="key.success_rate !== null"
|
v-if="key.health_score != null"
|
||||||
class="text-xs text-right"
|
class="text-xs text-right"
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
class="font-medium tabular-nums"
|
class="font-medium tabular-nums"
|
||||||
:class="[
|
:class="[
|
||||||
key.success_rate >= 0.95 ? 'text-green-600' :
|
key.health_score >= 0.95 ? 'text-green-600' :
|
||||||
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500'
|
key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
|
||||||
]"
|
]"
|
||||||
>
|
>
|
||||||
{{ (key.success_rate * 100).toFixed(0) }}%
|
{{ ((key.health_score || 0) * 100).toFixed(0) }}%
|
||||||
</div>
|
</div>
|
||||||
<div class="text-[10px] text-muted-foreground opacity-70">
|
<div class="text-[10px] text-muted-foreground opacity-70">
|
||||||
{{ key.request_count }} reqs
|
{{ key.request_count }} reqs
|
||||||
@@ -319,19 +319,6 @@
|
|||||||
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
||||||
<span class="text-xs text-muted-foreground">调度:</span>
|
<span class="text-xs text-muted-foreground">调度:</span>
|
||||||
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
||||||
<button
|
|
||||||
type="button"
|
|
||||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
|
||||||
:class="[
|
|
||||||
schedulingMode === 'fixed_order'
|
|
||||||
? 'bg-primary text-primary-foreground shadow-sm'
|
|
||||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
|
||||||
]"
|
|
||||||
title="严格按优先级顺序,不考虑缓存"
|
|
||||||
@click="schedulingMode = 'fixed_order'"
|
|
||||||
>
|
|
||||||
固定顺序
|
|
||||||
</button>
|
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||||
@@ -345,6 +332,32 @@
|
|||||||
>
|
>
|
||||||
缓存亲和
|
缓存亲和
|
||||||
</button>
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||||
|
:class="[
|
||||||
|
schedulingMode === 'load_balance'
|
||||||
|
? 'bg-primary text-primary-foreground shadow-sm'
|
||||||
|
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||||
|
]"
|
||||||
|
title="同优先级内随机轮换,不考虑缓存"
|
||||||
|
@click="schedulingMode = 'load_balance'"
|
||||||
|
>
|
||||||
|
负载均衡
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||||
|
:class="[
|
||||||
|
schedulingMode === 'fixed_order'
|
||||||
|
? 'bg-primary text-primary-foreground shadow-sm'
|
||||||
|
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||||
|
]"
|
||||||
|
title="严格按优先级顺序,不考虑缓存"
|
||||||
|
@click="schedulingMode = 'fixed_order'"
|
||||||
|
>
|
||||||
|
固定顺序
|
||||||
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -382,7 +395,7 @@ import { Dialog } from '@/components/ui'
|
|||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { updateProvider, updateEndpointKey } from '@/api/endpoints'
|
import { updateProvider, updateProviderKey } from '@/api/endpoints'
|
||||||
import type { ProviderWithEndpointsSummary } from '@/api/endpoints'
|
import type { ProviderWithEndpointsSummary } from '@/api/endpoints'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi } from '@/api/admin'
|
||||||
|
|
||||||
@@ -400,6 +413,7 @@ interface KeyWithMeta {
|
|||||||
endpoint_base_url: string
|
endpoint_base_url: string
|
||||||
api_format: string
|
api_format: string
|
||||||
capabilities: string[]
|
capabilities: string[]
|
||||||
|
health_score: number | null
|
||||||
success_rate: number | null
|
success_rate: number | null
|
||||||
avg_response_time_ms: number | null
|
avg_response_time_ms: number | null
|
||||||
request_count: number
|
request_count: number
|
||||||
@@ -444,7 +458,7 @@ const saving = ref(false)
|
|||||||
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
||||||
|
|
||||||
// 调度模式状态
|
// 调度模式状态
|
||||||
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
|
const schedulingMode = ref<'fixed_order' | 'load_balance' | 'cache_affinity'>('cache_affinity')
|
||||||
|
|
||||||
// 可用的 API 格式
|
// 可用的 API 格式
|
||||||
const availableFormats = computed(() => {
|
const availableFormats = computed(() => {
|
||||||
@@ -477,7 +491,11 @@ async function loadCurrentPriorityMode() {
|
|||||||
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
||||||
|
|
||||||
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
||||||
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
|
if (currentSchedulingMode === 'fixed_order' || currentSchedulingMode === 'load_balance' || currentSchedulingMode === 'cache_affinity') {
|
||||||
|
schedulingMode.value = currentSchedulingMode
|
||||||
|
} else {
|
||||||
|
schedulingMode.value = 'cache_affinity'
|
||||||
|
}
|
||||||
} catch {
|
} catch {
|
||||||
activeMainTab.value = 'provider'
|
activeMainTab.value = 'provider'
|
||||||
schedulingMode.value = 'cache_affinity'
|
schedulingMode.value = 'cache_affinity'
|
||||||
@@ -678,7 +696,7 @@ async function save() {
|
|||||||
const keys = keysByFormat.value[format]
|
const keys = keysByFormat.value[format]
|
||||||
keys.forEach((key) => {
|
keys.forEach((key) => {
|
||||||
// 使用用户设置的 priority 值,相同 priority 会做负载均衡
|
// 使用用户设置的 priority 值,相同 priority 会做负载均衡
|
||||||
keyUpdates.push(updateEndpointKey(key.id, { global_priority: key.priority }))
|
keyUpdates.push(updateProviderKey(key.id, { global_priority: key.priority }))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,47 +4,29 @@
|
|||||||
:title="isEditMode ? '编辑提供商' : '添加提供商'"
|
:title="isEditMode ? '编辑提供商' : '添加提供商'"
|
||||||
:description="isEditMode ? '更新提供商配置。API 端点和密钥需在详情页面单独管理。' : '创建新的提供商配置。创建后可以为其添加 API 端点和密钥。'"
|
:description="isEditMode ? '更新提供商配置。API 端点和密钥需在详情页面单独管理。' : '创建新的提供商配置。创建后可以为其添加 API 端点和密钥。'"
|
||||||
:icon="isEditMode ? SquarePen : Server"
|
:icon="isEditMode ? SquarePen : Server"
|
||||||
size="2xl"
|
size="xl"
|
||||||
@update:model-value="handleDialogUpdate"
|
@update:model-value="handleDialogUpdate"
|
||||||
>
|
>
|
||||||
<form
|
<form
|
||||||
class="space-y-6"
|
class="space-y-5"
|
||||||
@submit.prevent="handleSubmit"
|
@submit.prevent="handleSubmit"
|
||||||
>
|
>
|
||||||
<!-- 基本信息 -->
|
<!-- 基本信息 -->
|
||||||
<div class="space-y-4">
|
<div class="space-y-3">
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
<h3 class="text-sm font-medium border-b pb-2">
|
||||||
基本信息
|
基本信息
|
||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
<!-- 添加模式显示提供商标识 -->
|
|
||||||
<div
|
|
||||||
v-if="!isEditMode"
|
|
||||||
class="space-y-2"
|
|
||||||
>
|
|
||||||
<Label for="name">提供商标识 *</Label>
|
|
||||||
<Input
|
|
||||||
id="name"
|
|
||||||
v-model="form.name"
|
|
||||||
placeholder="例如: openai-primary"
|
|
||||||
required
|
|
||||||
/>
|
|
||||||
<p class="text-xs text-muted-foreground">
|
|
||||||
唯一ID,创建后不可修改
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="grid grid-cols-2 gap-4">
|
<div class="grid grid-cols-2 gap-4">
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label for="display_name">显示名称 *</Label>
|
<Label for="name">名称 *</Label>
|
||||||
<Input
|
<Input
|
||||||
id="display_name"
|
id="name"
|
||||||
v-model="form.display_name"
|
v-model="form.name"
|
||||||
placeholder="例如: OpenAI 主账号"
|
placeholder="例如: OpenAI 主账号"
|
||||||
required
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label for="website">主站链接</Label>
|
<Label for="website">主站链接</Label>
|
||||||
<Input
|
<Input
|
||||||
id="website"
|
id="website"
|
||||||
@@ -55,24 +37,28 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label for="description">描述</Label>
|
<Label for="description">描述</Label>
|
||||||
<Textarea
|
<Input
|
||||||
id="description"
|
id="description"
|
||||||
v-model="form.description"
|
v-model="form.description"
|
||||||
placeholder="提供商描述(可选)"
|
placeholder="提供商描述(可选)"
|
||||||
rows="2"
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 计费与限流 -->
|
<!-- 计费与限流 / 请求配置 -->
|
||||||
<div class="space-y-4">
|
<div class="space-y-3">
|
||||||
<h3 class="text-sm font-medium border-b pb-2">
|
|
||||||
计费与限流
|
|
||||||
</h3>
|
|
||||||
<div class="grid grid-cols-2 gap-4">
|
<div class="grid grid-cols-2 gap-4">
|
||||||
<div class="space-y-2">
|
<h3 class="text-sm font-medium border-b pb-2">
|
||||||
|
计费与限流
|
||||||
|
</h3>
|
||||||
|
<h3 class="text-sm font-medium border-b pb-2">
|
||||||
|
请求配置
|
||||||
|
</h3>
|
||||||
|
</div>
|
||||||
|
<div class="grid grid-cols-2 gap-4">
|
||||||
|
<div class="space-y-1.5">
|
||||||
<Label>计费类型</Label>
|
<Label>计费类型</Label>
|
||||||
<Select
|
<Select
|
||||||
v-model="form.billing_type"
|
v-model="form.billing_type"
|
||||||
@@ -82,27 +68,35 @@
|
|||||||
<SelectValue />
|
<SelectValue />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
<SelectContent>
|
<SelectContent>
|
||||||
<SelectItem value="monthly_quota">
|
<SelectItem value="monthly_quota">月卡额度</SelectItem>
|
||||||
月卡额度
|
<SelectItem value="pay_as_you_go">按量付费</SelectItem>
|
||||||
</SelectItem>
|
<SelectItem value="free_tier">免费套餐</SelectItem>
|
||||||
<SelectItem value="pay_as_you_go">
|
|
||||||
按量付费
|
|
||||||
</SelectItem>
|
|
||||||
<SelectItem value="free_tier">
|
|
||||||
免费套餐
|
|
||||||
</SelectItem>
|
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
<div class="space-y-2">
|
<div class="grid grid-cols-2 gap-4">
|
||||||
<Label>RPM 限制</Label>
|
<div class="space-y-1.5">
|
||||||
<Input
|
<Label>超时时间 (秒)</Label>
|
||||||
:model-value="form.rpm_limit ?? ''"
|
<Input
|
||||||
type="number"
|
:model-value="form.timeout ?? ''"
|
||||||
min="0"
|
type="number"
|
||||||
placeholder="不限制请留空"
|
min="1"
|
||||||
@update:model-value="(v) => form.rpm_limit = parseNumberInput(v)"
|
max="600"
|
||||||
/>
|
placeholder="默认 300"
|
||||||
|
@update:model-value="(v) => form.timeout = parseNumberInput(v)"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-1.5">
|
||||||
|
<Label>最大重试次数</Label>
|
||||||
|
<Input
|
||||||
|
:model-value="form.max_retries ?? ''"
|
||||||
|
type="number"
|
||||||
|
min="0"
|
||||||
|
max="10"
|
||||||
|
placeholder="默认 2"
|
||||||
|
@update:model-value="(v) => form.max_retries = parseNumberInput(v)"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -111,52 +105,94 @@
|
|||||||
v-if="form.billing_type === 'monthly_quota'"
|
v-if="form.billing_type === 'monthly_quota'"
|
||||||
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
|
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
|
||||||
>
|
>
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label class="text-xs">周期额度 (USD)</Label>
|
<Label class="text-xs">周期额度 (USD)</Label>
|
||||||
<Input
|
<Input
|
||||||
:model-value="form.monthly_quota_usd ?? ''"
|
:model-value="form.monthly_quota_usd ?? ''"
|
||||||
type="number"
|
type="number"
|
||||||
step="0.01"
|
step="0.01"
|
||||||
min="0"
|
min="0"
|
||||||
class="h-9"
|
|
||||||
@update:model-value="(v) => form.monthly_quota_usd = parseNumberInput(v, { allowFloat: true })"
|
@update:model-value="(v) => form.monthly_quota_usd = parseNumberInput(v, { allowFloat: true })"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label class="text-xs">重置周期 (天)</Label>
|
<Label class="text-xs">重置周期 (天)</Label>
|
||||||
<Input
|
<Input
|
||||||
:model-value="form.quota_reset_day ?? ''"
|
:model-value="form.quota_reset_day ?? ''"
|
||||||
type="number"
|
type="number"
|
||||||
min="1"
|
min="1"
|
||||||
max="365"
|
max="365"
|
||||||
class="h-9"
|
|
||||||
@update:model-value="(v) => form.quota_reset_day = parseNumberInput(v) ?? 30"
|
@update:model-value="(v) => form.quota_reset_day = parseNumberInput(v) ?? 30"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label class="text-xs">
|
<Label class="text-xs">
|
||||||
周期开始时间
|
周期开始时间 <span class="text-red-500">*</span>
|
||||||
<span class="text-red-500">*</span>
|
|
||||||
</Label>
|
</Label>
|
||||||
<Input
|
<Input
|
||||||
v-model="form.quota_last_reset_at"
|
v-model="form.quota_last_reset_at"
|
||||||
type="datetime-local"
|
type="datetime-local"
|
||||||
class="h-9"
|
|
||||||
/>
|
/>
|
||||||
<p class="text-xs text-muted-foreground">
|
|
||||||
系统会自动统计从该时间点开始的使用量
|
|
||||||
</p>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="space-y-2">
|
<div class="space-y-1.5">
|
||||||
<Label class="text-xs">过期时间</Label>
|
<Label class="text-xs">过期时间</Label>
|
||||||
<Input
|
<Input
|
||||||
v-model="form.quota_expires_at"
|
v-model="form.quota_expires_at"
|
||||||
type="datetime-local"
|
type="datetime-local"
|
||||||
class="h-9"
|
|
||||||
/>
|
/>
|
||||||
<p class="text-xs text-muted-foreground">
|
</div>
|
||||||
留空表示永久有效
|
</div>
|
||||||
</p>
|
</div>
|
||||||
|
|
||||||
|
<!-- 代理配置 -->
|
||||||
|
<div class="space-y-3">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<h3 class="text-sm font-medium">
|
||||||
|
代理配置
|
||||||
|
</h3>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Switch
|
||||||
|
:model-value="form.proxy_enabled"
|
||||||
|
@update:model-value="(v: boolean) => form.proxy_enabled = v"
|
||||||
|
/>
|
||||||
|
<span class="text-sm text-muted-foreground">启用代理</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="form.proxy_enabled"
|
||||||
|
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
|
||||||
|
>
|
||||||
|
<div class="space-y-1.5">
|
||||||
|
<Label class="text-xs">代理地址 *</Label>
|
||||||
|
<Input
|
||||||
|
v-model="form.proxy_url"
|
||||||
|
placeholder="http://proxy:port 或 socks5://proxy:port"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="grid grid-cols-2 gap-3">
|
||||||
|
<div class="space-y-1.5">
|
||||||
|
<Label class="text-xs">用户名</Label>
|
||||||
|
<Input
|
||||||
|
v-model="form.proxy_username"
|
||||||
|
placeholder="可选"
|
||||||
|
autocomplete="off"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div class="space-y-1.5">
|
||||||
|
<Label class="text-xs">密码</Label>
|
||||||
|
<Input
|
||||||
|
v-model="form.proxy_password"
|
||||||
|
type="password"
|
||||||
|
placeholder="可选"
|
||||||
|
autocomplete="new-password"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -172,7 +208,7 @@
|
|||||||
取消
|
取消
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
:disabled="loading || !form.display_name || (!isEditMode && !form.name)"
|
:disabled="loading || !form.name"
|
||||||
@click="handleSubmit"
|
@click="handleSubmit"
|
||||||
>
|
>
|
||||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存' : '创建') }}
|
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存' : '创建') }}
|
||||||
@@ -187,13 +223,13 @@ import {
|
|||||||
Dialog,
|
Dialog,
|
||||||
Button,
|
Button,
|
||||||
Input,
|
Input,
|
||||||
Textarea,
|
|
||||||
Label,
|
Label,
|
||||||
Select,
|
Select,
|
||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
SelectItem,
|
SelectItem,
|
||||||
|
Switch,
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import { Server, SquarePen } from 'lucide-vue-next'
|
import { Server, SquarePen } from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
@@ -223,7 +259,6 @@ const internalOpen = computed(() => props.modelValue)
|
|||||||
// 表单数据
|
// 表单数据
|
||||||
const form = ref({
|
const form = ref({
|
||||||
name: '',
|
name: '',
|
||||||
display_name: '',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: '',
|
website: '',
|
||||||
// 计费配置
|
// 计费配置
|
||||||
@@ -232,19 +267,25 @@ const form = ref({
|
|||||||
quota_reset_day: 30,
|
quota_reset_day: 30,
|
||||||
quota_last_reset_at: '', // 周期开始时间
|
quota_last_reset_at: '', // 周期开始时间
|
||||||
quota_expires_at: '',
|
quota_expires_at: '',
|
||||||
rpm_limit: undefined as string | number | undefined,
|
|
||||||
provider_priority: 999,
|
provider_priority: 999,
|
||||||
// 状态配置
|
// 状态配置
|
||||||
is_active: true,
|
is_active: true,
|
||||||
rate_limit: undefined as number | undefined,
|
rate_limit: undefined as number | undefined,
|
||||||
concurrent_limit: undefined as number | undefined,
|
concurrent_limit: undefined as number | undefined,
|
||||||
|
// 请求配置
|
||||||
|
timeout: undefined as number | undefined,
|
||||||
|
max_retries: undefined as number | undefined,
|
||||||
|
// 代理配置(扁平化便于表单绑定)
|
||||||
|
proxy_enabled: false,
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
})
|
})
|
||||||
|
|
||||||
// 重置表单
|
// 重置表单
|
||||||
function resetForm() {
|
function resetForm() {
|
||||||
form.value = {
|
form.value = {
|
||||||
name: '',
|
name: '',
|
||||||
display_name: '',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: '',
|
website: '',
|
||||||
billing_type: 'pay_as_you_go',
|
billing_type: 'pay_as_you_go',
|
||||||
@@ -252,11 +293,18 @@ function resetForm() {
|
|||||||
quota_reset_day: 30,
|
quota_reset_day: 30,
|
||||||
quota_last_reset_at: '',
|
quota_last_reset_at: '',
|
||||||
quota_expires_at: '',
|
quota_expires_at: '',
|
||||||
rpm_limit: undefined,
|
|
||||||
provider_priority: 999,
|
provider_priority: 999,
|
||||||
is_active: true,
|
is_active: true,
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
concurrent_limit: undefined,
|
concurrent_limit: undefined,
|
||||||
|
// 请求配置
|
||||||
|
timeout: undefined,
|
||||||
|
max_retries: undefined,
|
||||||
|
// 代理配置
|
||||||
|
proxy_enabled: false,
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,9 +312,9 @@ function resetForm() {
|
|||||||
function loadProviderData() {
|
function loadProviderData() {
|
||||||
if (!props.provider) return
|
if (!props.provider) return
|
||||||
|
|
||||||
|
const proxy = props.provider.proxy
|
||||||
form.value = {
|
form.value = {
|
||||||
name: props.provider.name,
|
name: props.provider.name,
|
||||||
display_name: props.provider.display_name,
|
|
||||||
description: props.provider.description || '',
|
description: props.provider.description || '',
|
||||||
website: props.provider.website || '',
|
website: props.provider.website || '',
|
||||||
billing_type: (props.provider.billing_type as 'monthly_quota' | 'pay_as_you_go' | 'free_tier') || 'pay_as_you_go',
|
billing_type: (props.provider.billing_type as 'monthly_quota' | 'pay_as_you_go' | 'free_tier') || 'pay_as_you_go',
|
||||||
@@ -276,11 +324,18 @@ function loadProviderData() {
|
|||||||
new Date(props.provider.quota_last_reset_at).toISOString().slice(0, 16) : '',
|
new Date(props.provider.quota_last_reset_at).toISOString().slice(0, 16) : '',
|
||||||
quota_expires_at: props.provider.quota_expires_at ?
|
quota_expires_at: props.provider.quota_expires_at ?
|
||||||
new Date(props.provider.quota_expires_at).toISOString().slice(0, 16) : '',
|
new Date(props.provider.quota_expires_at).toISOString().slice(0, 16) : '',
|
||||||
rpm_limit: props.provider.rpm_limit ?? undefined,
|
|
||||||
provider_priority: props.provider.provider_priority || 999,
|
provider_priority: props.provider.provider_priority || 999,
|
||||||
is_active: props.provider.is_active,
|
is_active: props.provider.is_active,
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
concurrent_limit: undefined,
|
concurrent_limit: undefined,
|
||||||
|
// 请求配置
|
||||||
|
timeout: props.provider.timeout ?? undefined,
|
||||||
|
max_retries: props.provider.max_retries ?? undefined,
|
||||||
|
// 代理配置
|
||||||
|
proxy_enabled: proxy?.enabled ?? false,
|
||||||
|
proxy_url: proxy?.url || '',
|
||||||
|
proxy_username: proxy?.username || '',
|
||||||
|
proxy_password: proxy?.password || '',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,17 +357,37 @@ const handleSubmit = async () => {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 启用代理时必须填写代理地址
|
||||||
|
if (form.value.proxy_enabled && !form.value.proxy_url) {
|
||||||
|
showError('启用代理时必须填写代理地址', '验证失败')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
try {
|
try {
|
||||||
|
// 构建代理配置
|
||||||
|
const proxy = form.value.proxy_enabled ? {
|
||||||
|
url: form.value.proxy_url,
|
||||||
|
username: form.value.proxy_username || undefined,
|
||||||
|
password: form.value.proxy_password || undefined,
|
||||||
|
enabled: true,
|
||||||
|
} : null
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
...form.value,
|
name: form.value.name,
|
||||||
rpm_limit:
|
description: form.value.description || undefined,
|
||||||
form.value.rpm_limit === undefined || form.value.rpm_limit === ''
|
website: form.value.website || undefined,
|
||||||
? null
|
billing_type: form.value.billing_type,
|
||||||
: Number(form.value.rpm_limit),
|
monthly_quota_usd: form.value.monthly_quota_usd,
|
||||||
// 空字符串时不发送
|
quota_reset_day: form.value.quota_reset_day,
|
||||||
quota_last_reset_at: form.value.quota_last_reset_at || undefined,
|
quota_last_reset_at: form.value.quota_last_reset_at || undefined,
|
||||||
quota_expires_at: form.value.quota_expires_at || undefined,
|
quota_expires_at: form.value.quota_expires_at || undefined,
|
||||||
|
provider_priority: form.value.provider_priority,
|
||||||
|
is_active: form.value.is_active,
|
||||||
|
// 请求配置
|
||||||
|
timeout: form.value.timeout ?? undefined,
|
||||||
|
max_retries: form.value.max_retries ?? undefined,
|
||||||
|
proxy,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isEditMode.value && props.provider) {
|
if (isEditMode.value && props.provider) {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ export { default as ProviderFormDialog } from './ProviderFormDialog.vue'
|
|||||||
export { default as EndpointFormDialog } from './EndpointFormDialog.vue'
|
export { default as EndpointFormDialog } from './EndpointFormDialog.vue'
|
||||||
export { default as KeyFormDialog } from './KeyFormDialog.vue'
|
export { default as KeyFormDialog } from './KeyFormDialog.vue'
|
||||||
export { default as KeyAllowedModelsDialog } from './KeyAllowedModelsDialog.vue'
|
export { default as KeyAllowedModelsDialog } from './KeyAllowedModelsDialog.vue'
|
||||||
|
export { default as KeyAllowedModelsEditDialog } from './KeyAllowedModelsEditDialog.vue'
|
||||||
export { default as PriorityManagementDialog } from './PriorityManagementDialog.vue'
|
export { default as PriorityManagementDialog } from './PriorityManagementDialog.vue'
|
||||||
export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vue'
|
export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vue'
|
||||||
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'
|
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'
|
||||||
|
|||||||
@@ -178,7 +178,7 @@
|
|||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
class="h-8 w-8 text-destructive hover:text-destructive"
|
class="h-8 w-8 hover:text-destructive"
|
||||||
title="删除"
|
title="删除"
|
||||||
@click="deleteModel(model)"
|
@click="deleteModel(model)"
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -289,14 +289,14 @@
|
|||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 错误信息卡片 -->
|
<!-- 响应客户端错误卡片 -->
|
||||||
<Card
|
<Card
|
||||||
v-if="detail.error_message"
|
v-if="detail.error_message"
|
||||||
class="border-red-200 dark:border-red-800"
|
class="border-red-200 dark:border-red-800"
|
||||||
>
|
>
|
||||||
<div class="p-4">
|
<div class="p-4">
|
||||||
<h4 class="text-sm font-semibold text-red-600 dark:text-red-400 mb-2">
|
<h4 class="text-sm font-semibold text-red-600 dark:text-red-400 mb-2">
|
||||||
错误信息
|
响应客户端错误
|
||||||
</h4>
|
</h4>
|
||||||
<div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-3">
|
<div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-3">
|
||||||
<p class="text-sm text-red-800 dark:text-red-300">
|
<p class="text-sm text-red-800 dark:text-red-300">
|
||||||
@@ -431,7 +431,7 @@
|
|||||||
|
|
||||||
<TabsContent value="response-headers">
|
<TabsContent value="response-headers">
|
||||||
<JsonContent
|
<JsonContent
|
||||||
:data="detail.response_headers"
|
:data="actualResponseHeaders"
|
||||||
:view-mode="viewMode"
|
:view-mode="viewMode"
|
||||||
:expand-depth="currentExpandDepth"
|
:expand-depth="currentExpandDepth"
|
||||||
:is-dark="isDark"
|
:is-dark="isDark"
|
||||||
@@ -614,6 +614,25 @@ const tabs = [
|
|||||||
{ name: 'metadata', label: '元数据' },
|
{ name: 'metadata', label: '元数据' },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
// 判断数据是否有实际内容(非空对象/数组)
|
||||||
|
function hasContent(data: unknown): boolean {
|
||||||
|
if (data === null || data === undefined) return false
|
||||||
|
if (typeof data === 'object') {
|
||||||
|
return Object.keys(data as object).length > 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取实际的响应头(优先 client_response_headers,回退到 response_headers)
|
||||||
|
const actualResponseHeaders = computed(() => {
|
||||||
|
if (!detail.value) return null
|
||||||
|
// 优先返回客户端响应头,如果没有则回退到提供商响应头
|
||||||
|
if (hasContent(detail.value.client_response_headers)) {
|
||||||
|
return detail.value.client_response_headers
|
||||||
|
}
|
||||||
|
return detail.value.response_headers
|
||||||
|
})
|
||||||
|
|
||||||
// 根据实际数据决定显示哪些 Tab
|
// 根据实际数据决定显示哪些 Tab
|
||||||
const visibleTabs = computed(() => {
|
const visibleTabs = computed(() => {
|
||||||
if (!detail.value) return []
|
if (!detail.value) return []
|
||||||
@@ -621,15 +640,15 @@ const visibleTabs = computed(() => {
|
|||||||
return tabs.filter(tab => {
|
return tabs.filter(tab => {
|
||||||
switch (tab.name) {
|
switch (tab.name) {
|
||||||
case 'request-headers':
|
case 'request-headers':
|
||||||
return detail.value!.request_headers && Object.keys(detail.value!.request_headers).length > 0
|
return hasContent(detail.value!.request_headers)
|
||||||
case 'request-body':
|
case 'request-body':
|
||||||
return detail.value!.request_body !== null && detail.value!.request_body !== undefined
|
return hasContent(detail.value!.request_body)
|
||||||
case 'response-headers':
|
case 'response-headers':
|
||||||
return detail.value!.response_headers && Object.keys(detail.value!.response_headers).length > 0
|
return hasContent(actualResponseHeaders.value)
|
||||||
case 'response-body':
|
case 'response-body':
|
||||||
return detail.value!.response_body !== null && detail.value!.response_body !== undefined
|
return hasContent(detail.value!.response_body)
|
||||||
case 'metadata':
|
case 'metadata':
|
||||||
return detail.value!.metadata && Object.keys(detail.value!.metadata).length > 0
|
return hasContent(detail.value!.metadata)
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -775,7 +794,7 @@ function copyJsonToClipboard(tabName: string) {
|
|||||||
data = detail.value.request_body
|
data = detail.value.request_body
|
||||||
break
|
break
|
||||||
case 'response-headers':
|
case 'response-headers':
|
||||||
data = detail.value.response_headers
|
data = actualResponseHeaders.value
|
||||||
break
|
break
|
||||||
case 'response-body':
|
case 'response-body':
|
||||||
data = detail.value.response_body
|
data = detail.value.response_body
|
||||||
|
|||||||
@@ -252,7 +252,7 @@
|
|||||||
@click.stop
|
@click.stop
|
||||||
@change="toggleSelection('allowed_providers', provider.id)"
|
@change="toggleSelection('allowed_providers', provider.id)"
|
||||||
>
|
>
|
||||||
<span class="text-sm">{{ provider.display_name || provider.name }}</span>
|
<span class="text-sm">{{ provider.name }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
v-if="providers.length === 0"
|
v-if="providers.length === 0"
|
||||||
|
|||||||
@@ -295,6 +295,15 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<RouterView />
|
<RouterView />
|
||||||
|
|
||||||
|
<!-- 更新提示弹窗 -->
|
||||||
|
<UpdateDialog
|
||||||
|
v-if="updateInfo"
|
||||||
|
v-model="showUpdateDialog"
|
||||||
|
:current-version="updateInfo.current_version"
|
||||||
|
:latest-version="updateInfo.latest_version || ''"
|
||||||
|
:release-url="updateInfo.release_url"
|
||||||
|
/>
|
||||||
</AppShell>
|
</AppShell>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
@@ -304,10 +313,12 @@ import { useRoute, useRouter } from 'vue-router'
|
|||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
import { useDarkMode } from '@/composables/useDarkMode'
|
import { useDarkMode } from '@/composables/useDarkMode'
|
||||||
import { isDemoMode } from '@/config/demo'
|
import { isDemoMode } from '@/config/demo'
|
||||||
|
import { adminApi, type CheckUpdateResponse } from '@/api/admin'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import AppShell from '@/components/layout/AppShell.vue'
|
import AppShell from '@/components/layout/AppShell.vue'
|
||||||
import SidebarNav from '@/components/layout/SidebarNav.vue'
|
import SidebarNav from '@/components/layout/SidebarNav.vue'
|
||||||
import HeaderLogo from '@/components/HeaderLogo.vue'
|
import HeaderLogo from '@/components/HeaderLogo.vue'
|
||||||
|
import UpdateDialog from '@/components/common/UpdateDialog.vue'
|
||||||
import {
|
import {
|
||||||
Home,
|
Home,
|
||||||
Users,
|
Users,
|
||||||
@@ -345,17 +356,67 @@ const showAuthError = ref(false)
|
|||||||
const mobileMenuOpen = ref(false)
|
const mobileMenuOpen = ref(false)
|
||||||
let authCheckInterval: number | null = null
|
let authCheckInterval: number | null = null
|
||||||
|
|
||||||
|
// 更新检查相关
|
||||||
|
const showUpdateDialog = ref(false)
|
||||||
|
const updateInfo = ref<CheckUpdateResponse | null>(null)
|
||||||
|
|
||||||
// 路由变化时自动关闭移动端菜单
|
// 路由变化时自动关闭移动端菜单
|
||||||
watch(() => route.path, () => {
|
watch(() => route.path, () => {
|
||||||
mobileMenuOpen.value = false
|
mobileMenuOpen.value = false
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 检查是否应该显示更新提示
|
||||||
|
function shouldShowUpdatePrompt(latestVersion: string): boolean {
|
||||||
|
const ignoreKey = 'aether_update_ignore'
|
||||||
|
const ignoreData = localStorage.getItem(ignoreKey)
|
||||||
|
if (!ignoreData) return true
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { version, until } = JSON.parse(ignoreData)
|
||||||
|
// 如果忽略的是同一版本且未过期,则不显示
|
||||||
|
if (version === latestVersion && Date.now() < until) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// 解析失败,显示提示
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查更新
|
||||||
|
async function checkForUpdate() {
|
||||||
|
// 只有管理员才检查更新
|
||||||
|
if (authStore.user?.role !== 'admin') return
|
||||||
|
|
||||||
|
// 同一会话内只检查一次
|
||||||
|
const sessionKey = 'aether_update_checked'
|
||||||
|
if (sessionStorage.getItem(sessionKey)) return
|
||||||
|
sessionStorage.setItem(sessionKey, '1')
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await adminApi.checkUpdate()
|
||||||
|
if (result.has_update && result.latest_version) {
|
||||||
|
if (shouldShowUpdatePrompt(result.latest_version)) {
|
||||||
|
updateInfo.value = result
|
||||||
|
showUpdateDialog.value = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// 静默失败,不影响用户体验
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
authCheckInterval = setInterval(() => {
|
authCheckInterval = setInterval(() => {
|
||||||
if (authStore.user && !authStore.token) {
|
if (authStore.user && !authStore.token) {
|
||||||
showAuthError.value = true
|
showAuthError.value = true
|
||||||
}
|
}
|
||||||
}, 5000)
|
}, 5000)
|
||||||
|
|
||||||
|
// 延迟检查更新,避免影响页面加载
|
||||||
|
setTimeout(() => {
|
||||||
|
checkForUpdate()
|
||||||
|
}, 2000)
|
||||||
})
|
})
|
||||||
|
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
|
|||||||
@@ -424,8 +424,7 @@ export const MOCK_ADMIN_API_KEYS: AdminApiKeysResponse = {
|
|||||||
export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
||||||
{
|
{
|
||||||
id: 'provider-001',
|
id: 'provider-001',
|
||||||
name: 'duck_coding_free',
|
name: 'DuckCodingFree',
|
||||||
display_name: 'DuckCodingFree',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://duckcoding.com',
|
website: 'https://duckcoding.com',
|
||||||
provider_priority: 1,
|
provider_priority: 1,
|
||||||
@@ -451,8 +450,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-002',
|
id: 'provider-002',
|
||||||
name: 'open_claude_code',
|
name: 'OpenClaudeCode',
|
||||||
display_name: 'OpenClaudeCode',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://www.openclaudecode.cn',
|
website: 'https://www.openclaudecode.cn',
|
||||||
provider_priority: 2,
|
provider_priority: 2,
|
||||||
@@ -477,8 +475,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-003',
|
id: 'provider-003',
|
||||||
name: '88_code',
|
name: '88Code',
|
||||||
display_name: '88Code',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://www.88code.org/',
|
website: 'https://www.88code.org/',
|
||||||
provider_priority: 3,
|
provider_priority: 3,
|
||||||
@@ -503,8 +500,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-004',
|
id: 'provider-004',
|
||||||
name: 'ikun_code',
|
name: 'IKunCode',
|
||||||
display_name: 'IKunCode',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://api.ikuncode.cc',
|
website: 'https://api.ikuncode.cc',
|
||||||
provider_priority: 4,
|
provider_priority: 4,
|
||||||
@@ -531,8 +527,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-005',
|
id: 'provider-005',
|
||||||
name: 'duck_coding',
|
name: 'DuckCoding',
|
||||||
display_name: 'DuckCoding',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://duckcoding.com',
|
website: 'https://duckcoding.com',
|
||||||
provider_priority: 5,
|
provider_priority: 5,
|
||||||
@@ -561,8 +556,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-006',
|
id: 'provider-006',
|
||||||
name: 'privnode',
|
name: 'Privnode',
|
||||||
display_name: 'Privnode',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://privnode.com',
|
website: 'https://privnode.com',
|
||||||
provider_priority: 6,
|
provider_priority: 6,
|
||||||
@@ -584,8 +578,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'provider-007',
|
id: 'provider-007',
|
||||||
name: 'undying_api',
|
name: 'UndyingAPI',
|
||||||
display_name: 'UndyingAPI',
|
|
||||||
description: '',
|
description: '',
|
||||||
website: 'https://vip.undyingapi.com',
|
website: 'https://vip.undyingapi.com',
|
||||||
provider_priority: 7,
|
provider_priority: 7,
|
||||||
|
|||||||
@@ -418,16 +418,16 @@ const MOCK_ALIASES = [
|
|||||||
|
|
||||||
// Mock Endpoint Keys
|
// Mock Endpoint Keys
|
||||||
const MOCK_ENDPOINT_KEYS = [
|
const MOCK_ENDPOINT_KEYS = [
|
||||||
{ id: 'ekey-001', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 99, avg_response_time_ms: 1200, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
{ id: 'ekey-001', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 0.99, avg_response_time_ms: 1200, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||||
{ id: 'ekey-002', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 97.5, avg_response_time_ms: 1350, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() },
|
{ id: 'ekey-002', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 0.95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 0.975, avg_response_time_ms: 1350, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||||
{ id: 'ekey-003', endpoint_id: 'ep-002', api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 98.6, avg_response_time_ms: 900, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
{ id: 'ekey-003', provider_id: 'provider-002', api_formats: ['OPENAI'], api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 0.986, avg_response_time_ms: 900, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Mock Endpoints
|
// Mock Endpoints
|
||||||
const MOCK_ENDPOINTS = [
|
const MOCK_ENDPOINTS = [
|
||||||
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 3, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'CLAUDE', base_url: 'https://api.anthropic.com', timeout: 300, max_retries: 2, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||||
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 3, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'OPENAI', base_url: 'https://api.openai.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||||
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 3, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'GEMINI', base_url: 'https://generativelanguage.googleapis.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Mock 能力定义
|
// Mock 能力定义
|
||||||
@@ -581,7 +581,6 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
|||||||
return createMockResponse(MOCK_PROVIDERS.map(p => ({
|
return createMockResponse(MOCK_PROVIDERS.map(p => ({
|
||||||
id: p.id,
|
id: p.id,
|
||||||
name: p.name,
|
name: p.name,
|
||||||
display_name: p.display_name,
|
|
||||||
is_active: p.is_active
|
is_active: p.is_active
|
||||||
})))
|
})))
|
||||||
},
|
},
|
||||||
@@ -1222,13 +1221,8 @@ function generateMockEndpointsForProvider(providerId: string) {
|
|||||||
base_url: format.includes('CLAUDE') ? 'https://api.anthropic.com' :
|
base_url: format.includes('CLAUDE') ? 'https://api.anthropic.com' :
|
||||||
format.includes('OPENAI') ? 'https://api.openai.com' :
|
format.includes('OPENAI') ? 'https://api.openai.com' :
|
||||||
'https://generativelanguage.googleapis.com',
|
'https://generativelanguage.googleapis.com',
|
||||||
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer',
|
timeout: 300,
|
||||||
timeout: 120,
|
max_retries: 2,
|
||||||
max_retries: 3,
|
|
||||||
priority: 100 - index * 10,
|
|
||||||
weight: 100,
|
|
||||||
health_score: healthDetail?.health_score ?? 1.0,
|
|
||||||
consecutive_failures: healthDetail?.health_score && healthDetail.health_score < 0.7 ? 2 : 0,
|
|
||||||
is_active: healthDetail?.is_active ?? true,
|
is_active: healthDetail?.is_active ?? true,
|
||||||
total_keys: Math.ceil(Math.random() * 3) + 1,
|
total_keys: Math.ceil(Math.random() * 3) + 1,
|
||||||
active_keys: Math.ceil(Math.random() * 2) + 1,
|
active_keys: Math.ceil(Math.random() * 2) + 1,
|
||||||
@@ -1238,11 +1232,16 @@ function generateMockEndpointsForProvider(providerId: string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 为 endpoint 生成 keys
|
// 为 provider 生成 keys(Key 归属 Provider,通过 api_formats 关联)
|
||||||
function generateMockKeysForEndpoint(endpointId: string, count: number = 2) {
|
const PROVIDER_KEYS_CACHE: Record<string, any[]> = {}
|
||||||
|
function generateMockKeysForProvider(providerId: string, count: number = 2) {
|
||||||
|
const provider = MOCK_PROVIDERS.find(p => p.id === providerId)
|
||||||
|
const formats = provider?.api_formats || []
|
||||||
|
|
||||||
return Array.from({ length: count }, (_, i) => ({
|
return Array.from({ length: count }, (_, i) => ({
|
||||||
id: `key-${endpointId}-${i + 1}`,
|
id: `key-${providerId}-${i + 1}`,
|
||||||
endpoint_id: endpointId,
|
provider_id: providerId,
|
||||||
|
api_formats: i === 0 ? formats : formats.slice(0, 1),
|
||||||
api_key_masked: `sk-***...${Math.random().toString(36).substring(2, 6)}`,
|
api_key_masked: `sk-***...${Math.random().toString(36).substring(2, 6)}`,
|
||||||
name: i === 0 ? 'Primary Key' : `Backup Key ${i}`,
|
name: i === 0 ? 'Primary Key' : `Backup Key ${i}`,
|
||||||
rate_multiplier: 1.0,
|
rate_multiplier: 1.0,
|
||||||
@@ -1254,6 +1253,8 @@ function generateMockKeysForEndpoint(endpointId: string, count: number = 2) {
|
|||||||
error_count: Math.floor(Math.random() * 100),
|
error_count: Math.floor(Math.random() * 100),
|
||||||
success_rate: 0.95 + Math.random() * 0.04, // 0.95-0.99
|
success_rate: 0.95 + Math.random() * 0.04, // 0.95-0.99
|
||||||
avg_response_time_ms: 800 + Math.floor(Math.random() * 600),
|
avg_response_time_ms: 800 + Math.floor(Math.random() * 600),
|
||||||
|
cache_ttl_minutes: 5,
|
||||||
|
max_probe_interval_minutes: 32,
|
||||||
is_active: true,
|
is_active: true,
|
||||||
created_at: '2024-01-01T00:00:00Z',
|
created_at: '2024-01-01T00:00:00Z',
|
||||||
updated_at: new Date().toISOString()
|
updated_at: new Date().toISOString()
|
||||||
@@ -1463,29 +1464,63 @@ registerDynamicRoute('PUT', '/api/admin/endpoints/:endpointId', async (config, p
|
|||||||
registerDynamicRoute('DELETE', '/api/admin/endpoints/:endpointId', async (_config, _params) => {
|
registerDynamicRoute('DELETE', '/api/admin/endpoints/:endpointId', async (_config, _params) => {
|
||||||
await delay()
|
await delay()
|
||||||
requireAdmin()
|
requireAdmin()
|
||||||
return createMockResponse({ message: '删除成功(演示模式)' })
|
return createMockResponse({ message: '删除成功(演示模式)', affected_keys_count: 0 })
|
||||||
})
|
})
|
||||||
|
|
||||||
// Endpoint Keys 列表
|
// Provider Keys 列表
|
||||||
registerDynamicRoute('GET', '/api/admin/endpoints/:endpointId/keys', async (_config, params) => {
|
registerDynamicRoute('GET', '/api/admin/endpoints/providers/:providerId/keys', async (_config, params) => {
|
||||||
await delay()
|
await delay()
|
||||||
requireAdmin()
|
requireAdmin()
|
||||||
const keys = generateMockKeysForEndpoint(params.endpointId, 2)
|
if (!PROVIDER_KEYS_CACHE[params.providerId]) {
|
||||||
return createMockResponse(keys)
|
PROVIDER_KEYS_CACHE[params.providerId] = generateMockKeysForProvider(params.providerId, 2)
|
||||||
|
}
|
||||||
|
return createMockResponse(PROVIDER_KEYS_CACHE[params.providerId])
|
||||||
})
|
})
|
||||||
|
|
||||||
// 创建 Key
|
// 为 Provider 创建 Key
|
||||||
registerDynamicRoute('POST', '/api/admin/endpoints/:endpointId/keys', async (config, params) => {
|
registerDynamicRoute('POST', '/api/admin/endpoints/providers/:providerId/keys', async (config, params) => {
|
||||||
await delay()
|
await delay()
|
||||||
requireAdmin()
|
requireAdmin()
|
||||||
const body = JSON.parse(config.data || '{}')
|
const body = JSON.parse(config.data || '{}')
|
||||||
return createMockResponse({
|
const apiKeyPlain = body.api_key || 'sk-demo'
|
||||||
|
const masked = apiKeyPlain.length >= 12
|
||||||
|
? `${apiKeyPlain.slice(0, 8)}***${apiKeyPlain.slice(-4)}`
|
||||||
|
: 'sk-***...demo'
|
||||||
|
|
||||||
|
const newKey = {
|
||||||
id: `key-demo-${Date.now()}`,
|
id: `key-demo-${Date.now()}`,
|
||||||
endpoint_id: params.endpointId,
|
provider_id: params.providerId,
|
||||||
api_key_masked: 'sk-***...demo',
|
api_formats: body.api_formats || [],
|
||||||
...body,
|
api_key_masked: masked,
|
||||||
created_at: new Date().toISOString()
|
api_key_plain: null,
|
||||||
})
|
name: body.name || 'New Key',
|
||||||
|
note: body.note,
|
||||||
|
rate_multiplier: body.rate_multiplier ?? 1.0,
|
||||||
|
rate_multipliers: body.rate_multipliers ?? null,
|
||||||
|
internal_priority: body.internal_priority ?? 50,
|
||||||
|
global_priority: body.global_priority ?? null,
|
||||||
|
rpm_limit: body.rpm_limit ?? null,
|
||||||
|
allowed_models: body.allowed_models ?? null,
|
||||||
|
capabilities: body.capabilities ?? null,
|
||||||
|
cache_ttl_minutes: body.cache_ttl_minutes ?? 5,
|
||||||
|
max_probe_interval_minutes: body.max_probe_interval_minutes ?? 32,
|
||||||
|
health_score: 1.0,
|
||||||
|
consecutive_failures: 0,
|
||||||
|
request_count: 0,
|
||||||
|
success_count: 0,
|
||||||
|
error_count: 0,
|
||||||
|
success_rate: 0.0,
|
||||||
|
avg_response_time_ms: 0.0,
|
||||||
|
is_active: true,
|
||||||
|
created_at: new Date().toISOString(),
|
||||||
|
updated_at: new Date().toISOString(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!PROVIDER_KEYS_CACHE[params.providerId]) {
|
||||||
|
PROVIDER_KEYS_CACHE[params.providerId] = []
|
||||||
|
}
|
||||||
|
PROVIDER_KEYS_CACHE[params.providerId].push(newKey)
|
||||||
|
return createMockResponse(newKey)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Key 更新
|
// Key 更新
|
||||||
@@ -1503,6 +1538,50 @@ registerDynamicRoute('DELETE', '/api/admin/endpoints/keys/:keyId', async (_confi
|
|||||||
return createMockResponse({ message: '删除成功(演示模式)' })
|
return createMockResponse({ message: '删除成功(演示模式)' })
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Key Reveal
|
||||||
|
registerDynamicRoute('GET', '/api/admin/endpoints/keys/:keyId/reveal', async (_config, _params) => {
|
||||||
|
await delay()
|
||||||
|
requireAdmin()
|
||||||
|
return createMockResponse({ api_key: 'sk-demo-reveal' })
|
||||||
|
})
|
||||||
|
|
||||||
|
// Keys grouped by format
|
||||||
|
mockHandlers['GET /api/admin/endpoints/keys/grouped-by-format'] = async () => {
|
||||||
|
await delay()
|
||||||
|
requireAdmin()
|
||||||
|
|
||||||
|
// 确保每个 provider 都有 key 数据
|
||||||
|
for (const provider of MOCK_PROVIDERS) {
|
||||||
|
if (!PROVIDER_KEYS_CACHE[provider.id]) {
|
||||||
|
PROVIDER_KEYS_CACHE[provider.id] = generateMockKeysForProvider(provider.id, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const grouped: Record<string, any[]> = {}
|
||||||
|
for (const provider of MOCK_PROVIDERS) {
|
||||||
|
const endpoints = generateMockEndpointsForProvider(provider.id)
|
||||||
|
const baseUrlByFormat = Object.fromEntries(endpoints.map(e => [e.api_format, e.base_url]))
|
||||||
|
const keys = PROVIDER_KEYS_CACHE[provider.id] || []
|
||||||
|
for (const key of keys) {
|
||||||
|
const formats: string[] = key.api_formats || []
|
||||||
|
for (const fmt of formats) {
|
||||||
|
if (!grouped[fmt]) grouped[fmt] = []
|
||||||
|
grouped[fmt].push({
|
||||||
|
...key,
|
||||||
|
api_format: fmt,
|
||||||
|
provider_name: provider.name,
|
||||||
|
endpoint_base_url: baseUrlByFormat[fmt],
|
||||||
|
global_priority: key.global_priority ?? null,
|
||||||
|
circuit_breaker_open: false,
|
||||||
|
capabilities: [],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return createMockResponse(grouped)
|
||||||
|
}
|
||||||
|
|
||||||
// Provider Models 列表
|
// Provider Models 列表
|
||||||
registerDynamicRoute('GET', '/api/admin/providers/:providerId/models', async (_config, params) => {
|
registerDynamicRoute('GET', '/api/admin/providers/:providerId/models', async (_config, params) => {
|
||||||
await delay()
|
await delay()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ interface ValidationError {
|
|||||||
const fieldNameMap: Record<string, string> = {
|
const fieldNameMap: Record<string, string> = {
|
||||||
'api_key': 'API 密钥',
|
'api_key': 'API 密钥',
|
||||||
'priority': '优先级',
|
'priority': '优先级',
|
||||||
'max_concurrent': '最大并发',
|
'rpm_limit': 'RPM 限制',
|
||||||
'rate_limit': '速率限制',
|
'rate_limit': '速率限制',
|
||||||
'daily_limit': '每日限制',
|
'daily_limit': '每日限制',
|
||||||
'monthly_limit': '每月限制',
|
'monthly_limit': '每月限制',
|
||||||
@@ -44,7 +44,6 @@ const fieldNameMap: Record<string, string> = {
|
|||||||
'monthly_quota_usd': '月度配额',
|
'monthly_quota_usd': '月度配额',
|
||||||
'quota_reset_day': '配额重置日',
|
'quota_reset_day': '配额重置日',
|
||||||
'quota_expires_at': '配额过期时间',
|
'quota_expires_at': '配额过期时间',
|
||||||
'rpm_limit': 'RPM 限制',
|
|
||||||
'cache_ttl_minutes': '缓存 TTL',
|
'cache_ttl_minutes': '缓存 TTL',
|
||||||
'max_probe_interval_minutes': '最大探测间隔',
|
'max_probe_interval_minutes': '最大探测间隔',
|
||||||
}
|
}
|
||||||
@@ -54,7 +53,7 @@ const fieldNameMap: Record<string, string> = {
|
|||||||
*/
|
*/
|
||||||
const errorTypeMap: Record<string, (error: ValidationError) => string> = {
|
const errorTypeMap: Record<string, (error: ValidationError) => string> = {
|
||||||
'string_too_short': (error) => {
|
'string_too_short': (error) => {
|
||||||
const minLength = error.ctx?.min_length || 10
|
const minLength = error.ctx?.min_length || 3
|
||||||
return `长度不能少于 ${minLength} 个字符`
|
return `长度不能少于 ${minLength} 个字符`
|
||||||
},
|
},
|
||||||
'string_too_long': (error) => {
|
'string_too_long': (error) => {
|
||||||
@@ -151,11 +150,18 @@ export function parseApiError(err: unknown, defaultMessage: string = '操作失
|
|||||||
return '无法连接到服务器,请检查网络连接'
|
return '无法连接到服务器,请检查网络连接'
|
||||||
}
|
}
|
||||||
|
|
||||||
const detail = err.response?.data?.detail
|
const data = err.response?.data
|
||||||
|
|
||||||
|
// 1. 处理 {error: {type, message}} 格式(ProxyException 返回格式)
|
||||||
|
if (data?.error?.message) {
|
||||||
|
return data.error.message
|
||||||
|
}
|
||||||
|
|
||||||
|
const detail = data?.detail
|
||||||
|
|
||||||
// 如果没有 detail 字段
|
// 如果没有 detail 字段
|
||||||
if (!detail) {
|
if (!detail) {
|
||||||
return err.response?.data?.message || err.message || defaultMessage
|
return data?.message || err.message || defaultMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 处理 Pydantic 验证错误(数组格式)
|
// 1. 处理 Pydantic 验证错误(数组格式)
|
||||||
|
|||||||
@@ -54,6 +54,57 @@ export function parseNumberInput(
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse number input value for nullable fields (like rpm_limit)
|
||||||
|
* Returns `null` when empty (to signal "use adaptive/default mode")
|
||||||
|
* Returns `undefined` when not provided (to signal "keep original value")
|
||||||
|
*
|
||||||
|
* @param value - Input value (string or number)
|
||||||
|
* @param options - Parse options
|
||||||
|
* @returns Parsed number, null (for empty/adaptive), or undefined
|
||||||
|
*/
|
||||||
|
export function parseNullableNumberInput(
|
||||||
|
value: string | number | null | undefined,
|
||||||
|
options: {
|
||||||
|
allowFloat?: boolean
|
||||||
|
min?: number
|
||||||
|
max?: number
|
||||||
|
} = {}
|
||||||
|
): number | null | undefined {
|
||||||
|
const { allowFloat = false, min, max } = options
|
||||||
|
|
||||||
|
// Empty string means "null" (adaptive mode)
|
||||||
|
if (value === '') {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
// null/undefined means "keep original value"
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the value
|
||||||
|
const num = typeof value === 'string'
|
||||||
|
? (allowFloat ? parseFloat(value) : parseInt(value, 10))
|
||||||
|
: value
|
||||||
|
|
||||||
|
// Handle NaN - treat as null (adaptive mode)
|
||||||
|
if (isNaN(num)) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply min/max constraints
|
||||||
|
let result = num
|
||||||
|
if (min !== undefined && result < min) {
|
||||||
|
result = min
|
||||||
|
}
|
||||||
|
if (max !== undefined && result > max) {
|
||||||
|
result = max
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a handler function for number input with specific field
|
* Create a handler function for number input with specific field
|
||||||
* Useful for creating inline handlers in templates
|
* Useful for creating inline handlers in templates
|
||||||
|
|||||||
@@ -530,9 +530,6 @@
|
|||||||
/>
|
/>
|
||||||
<div class="flex-1 min-w-0">
|
<div class="flex-1 min-w-0">
|
||||||
<p class="font-medium text-sm truncate">
|
<p class="font-medium text-sm truncate">
|
||||||
{{ provider.display_name || provider.name }}
|
|
||||||
</p>
|
|
||||||
<p class="text-xs text-muted-foreground truncate">
|
|
||||||
{{ provider.name }}
|
{{ provider.name }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
@@ -645,10 +642,7 @@
|
|||||||
/>
|
/>
|
||||||
<div class="flex-1 min-w-0">
|
<div class="flex-1 min-w-0">
|
||||||
<p class="font-medium text-sm truncate">
|
<p class="font-medium text-sm truncate">
|
||||||
{{ provider.display_name }}
|
{{ provider.name }}
|
||||||
</p>
|
|
||||||
<p class="text-xs text-muted-foreground truncate">
|
|
||||||
{{ provider.identifier }}
|
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<Badge
|
<Badge
|
||||||
@@ -679,7 +673,7 @@
|
|||||||
<ProviderModelFormDialog
|
<ProviderModelFormDialog
|
||||||
:open="editProviderDialogOpen"
|
:open="editProviderDialogOpen"
|
||||||
:provider-id="editingProvider?.id || ''"
|
:provider-id="editingProvider?.id || ''"
|
||||||
:provider-name="editingProvider?.display_name || ''"
|
:provider-name="editingProvider?.name || ''"
|
||||||
:editing-model="editingProviderModel"
|
:editing-model="editingProviderModel"
|
||||||
@update:open="handleEditProviderDialogUpdate"
|
@update:open="handleEditProviderDialogUpdate"
|
||||||
@saved="handleEditProviderSaved"
|
@saved="handleEditProviderSaved"
|
||||||
@@ -939,7 +933,7 @@ async function batchAddSelectedProviders() {
|
|||||||
const errorMessages = result.errors
|
const errorMessages = result.errors
|
||||||
.map(e => {
|
.map(e => {
|
||||||
const provider = providerOptions.value.find(p => p.id === e.provider_id)
|
const provider = providerOptions.value.find(p => p.id === e.provider_id)
|
||||||
const providerName = provider?.display_name || provider?.name || e.provider_id
|
const providerName = provider?.name || e.provider_id
|
||||||
return `${providerName}: ${e.error}`
|
return `${providerName}: ${e.error}`
|
||||||
})
|
})
|
||||||
.join('\n')
|
.join('\n')
|
||||||
@@ -977,7 +971,7 @@ async function batchRemoveSelectedProviders() {
|
|||||||
await deleteModel(providerId, provider.model_id)
|
await deleteModel(providerId, provider.model_id)
|
||||||
successCount++
|
successCount++
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
errors.push(`${provider.display_name}: ${parseApiError(err, '删除失败')}`)
|
errors.push(`${provider.name}: ${parseApiError(err, '删除失败')}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1088,8 +1082,7 @@ async function loadModelProviders(_globalModelId: string) {
|
|||||||
selectedModelProviders.value = response.providers.map(p => ({
|
selectedModelProviders.value = response.providers.map(p => ({
|
||||||
id: p.provider_id,
|
id: p.provider_id,
|
||||||
model_id: p.model_id,
|
model_id: p.model_id,
|
||||||
display_name: p.provider_display_name || p.provider_name,
|
name: p.provider_name,
|
||||||
identifier: p.provider_name,
|
|
||||||
provider_type: 'API',
|
provider_type: 'API',
|
||||||
target_model: p.target_model,
|
target_model: p.target_model,
|
||||||
is_active: p.is_active,
|
is_active: p.is_active,
|
||||||
@@ -1219,7 +1212,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const confirmed = await confirmDanger(
|
const confirmed = await confirmDanger(
|
||||||
`确定要删除 ${provider.display_name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复!`,
|
`确定要删除 ${provider.name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复!`,
|
||||||
'删除关联提供商'
|
'删除关联提供商'
|
||||||
)
|
)
|
||||||
if (!confirmed) return
|
if (!confirmed) return
|
||||||
@@ -1227,7 +1220,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
|
|||||||
try {
|
try {
|
||||||
const { deleteModel } = await import('@/api/endpoints')
|
const { deleteModel } = await import('@/api/endpoints')
|
||||||
await deleteModel(provider.id, provider.model_id)
|
await deleteModel(provider.id, provider.model_id)
|
||||||
success(`已删除 ${provider.display_name} 的模型实现`)
|
success(`已删除 ${provider.name} 的模型实现`)
|
||||||
// 重新加载 Provider 列表
|
// 重新加载 Provider 列表
|
||||||
if (selectedModel.value) {
|
if (selectedModel.value) {
|
||||||
await loadModelProviders(selectedModel.value.id)
|
await loadModelProviders(selectedModel.value.id)
|
||||||
|
|||||||
@@ -134,10 +134,7 @@
|
|||||||
@click="handleRowClick($event, provider.id)"
|
@click="handleRowClick($event, provider.id)"
|
||||||
>
|
>
|
||||||
<TableCell class="py-3.5">
|
<TableCell class="py-3.5">
|
||||||
<div class="flex flex-col gap-0.5">
|
<span class="text-sm font-medium text-foreground">{{ provider.name }}</span>
|
||||||
<span class="text-sm font-medium text-foreground">{{ provider.display_name }}</span>
|
|
||||||
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
|
|
||||||
</div>
|
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell class="py-3.5">
|
<TableCell class="py-3.5">
|
||||||
<Badge
|
<Badge
|
||||||
@@ -219,17 +216,10 @@
|
|||||||
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / <span class="font-medium">${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}</span>
|
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / <span class="font-medium">${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}</span>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
v-if="rpmUsage(provider)"
|
v-else
|
||||||
class="flex items-center gap-1"
|
|
||||||
>
|
|
||||||
<span class="text-muted-foreground/70">RPM:</span>
|
|
||||||
<span class="font-medium text-foreground/80">{{ rpmUsage(provider) }}</span>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
v-if="provider.billing_type !== 'monthly_quota' && !rpmUsage(provider)"
|
|
||||||
class="text-muted-foreground/50"
|
class="text-muted-foreground/50"
|
||||||
>
|
>
|
||||||
无限制
|
按量付费
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
@@ -304,7 +294,7 @@
|
|||||||
<div class="flex items-start justify-between gap-3">
|
<div class="flex items-start justify-between gap-3">
|
||||||
<div class="flex-1 min-w-0">
|
<div class="flex-1 min-w-0">
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
<span class="font-medium text-foreground truncate">{{ provider.display_name }}</span>
|
<span class="font-medium text-foreground truncate">{{ provider.name }}</span>
|
||||||
<Badge
|
<Badge
|
||||||
:variant="provider.is_active ? 'success' : 'secondary'"
|
:variant="provider.is_active ? 'success' : 'secondary'"
|
||||||
class="text-xs shrink-0"
|
class="text-xs shrink-0"
|
||||||
@@ -312,7 +302,6 @@
|
|||||||
{{ provider.is_active ? '活跃' : '停用' }}
|
{{ provider.is_active ? '活跃' : '停用' }}
|
||||||
</Badge>
|
</Badge>
|
||||||
</div>
|
</div>
|
||||||
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
|
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
class="flex items-center gap-0.5 shrink-0"
|
class="flex items-center gap-0.5 shrink-0"
|
||||||
@@ -383,20 +372,17 @@
|
|||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 第四行:配额/限流 -->
|
<!-- 第四行:配额 -->
|
||||||
<div
|
<div
|
||||||
v-if="provider.billing_type === 'monthly_quota' || rpmUsage(provider)"
|
v-if="provider.billing_type === 'monthly_quota'"
|
||||||
class="flex items-center gap-3 text-xs text-muted-foreground"
|
class="flex items-center gap-3 text-xs text-muted-foreground"
|
||||||
>
|
>
|
||||||
<span v-if="provider.billing_type === 'monthly_quota'">
|
<span>
|
||||||
配额: <span
|
配额: <span
|
||||||
class="font-semibold"
|
class="font-semibold"
|
||||||
:class="getQuotaUsedColorClass(provider)"
|
:class="getQuotaUsedColorClass(provider)"
|
||||||
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / ${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}
|
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / ${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}
|
||||||
</span>
|
</span>
|
||||||
<span v-if="rpmUsage(provider)">
|
|
||||||
RPM: {{ rpmUsage(provider) }}
|
|
||||||
</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -509,7 +495,7 @@ const filteredProviders = computed(() => {
|
|||||||
if (searchQuery.value.trim()) {
|
if (searchQuery.value.trim()) {
|
||||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(p => {
|
result = result.filter(p => {
|
||||||
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
const searchableText = `${p.name}`.toLowerCase()
|
||||||
return keywords.every(keyword => searchableText.includes(keyword))
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -525,7 +511,7 @@ const filteredProviders = computed(() => {
|
|||||||
return a.provider_priority - b.provider_priority
|
return a.provider_priority - b.provider_priority
|
||||||
}
|
}
|
||||||
// 3. 按名称排序
|
// 3. 按名称排序
|
||||||
return a.display_name.localeCompare(b.display_name)
|
return a.name.localeCompare(b.name)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -586,7 +572,10 @@ function sortEndpoints(endpoints: any[]) {
|
|||||||
|
|
||||||
// 判断端点是否可用(有 key)
|
// 判断端点是否可用(有 key)
|
||||||
function isEndpointAvailable(endpoint: any, _provider: ProviderWithEndpointsSummary): boolean {
|
function isEndpointAvailable(endpoint: any, _provider: ProviderWithEndpointsSummary): boolean {
|
||||||
// 检查该端点是否有活跃的密钥
|
// 检查端点是否启用,以及是否有活跃的密钥
|
||||||
|
if (endpoint.is_active === false) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return (endpoint.active_keys ?? 0) > 0
|
return (endpoint.active_keys ?? 0) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -639,21 +628,6 @@ function getQuotaUsedColorClass(provider: ProviderWithEndpointsSummary): string
|
|||||||
return 'text-foreground'
|
return 'text-foreground'
|
||||||
}
|
}
|
||||||
|
|
||||||
function rpmUsage(provider: ProviderWithEndpointsSummary): string | null {
|
|
||||||
const rpmLimit = provider.rpm_limit
|
|
||||||
const rpmUsed = provider.rpm_used ?? 0
|
|
||||||
|
|
||||||
if (rpmLimit === null || rpmLimit === undefined) {
|
|
||||||
return rpmUsed > 0 ? `${rpmUsed}` : null
|
|
||||||
}
|
|
||||||
|
|
||||||
if (rpmLimit === 0) {
|
|
||||||
return '已完全禁止'
|
|
||||||
}
|
|
||||||
|
|
||||||
return `${rpmUsed} / ${rpmLimit}`
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用复用的行点击逻辑
|
// 使用复用的行点击逻辑
|
||||||
const { handleMouseDown, shouldTriggerRowClick } = useRowClick()
|
const { handleMouseDown, shouldTriggerRowClick } = useRowClick()
|
||||||
|
|
||||||
@@ -706,7 +680,7 @@ function handleProviderAdded() {
|
|||||||
async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
|
async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
|
||||||
const confirmed = await confirmDanger(
|
const confirmed = await confirmDanger(
|
||||||
'删除提供商',
|
'删除提供商',
|
||||||
`确定要删除提供商 "${provider.display_name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复!`
|
`确定要删除提供商 "${provider.name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复!`
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!confirmed) return
|
if (!confirmed) return
|
||||||
|
|||||||
@@ -511,7 +511,7 @@
|
|||||||
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
||||||
</li>
|
</li>
|
||||||
<li>
|
<li>
|
||||||
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} 个
|
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.api_keys?.length || 0), 0) }} 个
|
||||||
</li>
|
</li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
@@ -1144,7 +1144,7 @@ function handleConfigFileSelect(event: Event) {
|
|||||||
const data = JSON.parse(content) as ConfigExportData
|
const data = JSON.parse(content) as ConfigExportData
|
||||||
|
|
||||||
// 验证版本
|
// 验证版本
|
||||||
if (data.version !== '1.0') {
|
if (data.version !== '2.0') {
|
||||||
error(`不支持的配置版本: ${data.version}`)
|
error(`不支持的配置版本: ${data.version}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,8 @@
|
|||||||
|
|
||||||
<!-- 主要统计卡片 -->
|
<!-- 主要统计卡片 -->
|
||||||
<div class="grid grid-cols-2 gap-3 sm:gap-4 xl:grid-cols-4">
|
<div class="grid grid-cols-2 gap-3 sm:gap-4 xl:grid-cols-4">
|
||||||
<template v-if="loading && stats.length === 0">
|
<!-- 加载中骨架屏 -->
|
||||||
|
<template v-if="loading">
|
||||||
<Card
|
<Card
|
||||||
v-for="i in 4"
|
v-for="i in 4"
|
||||||
:key="'skeleton-' + i"
|
:key="'skeleton-' + i"
|
||||||
@@ -27,62 +28,98 @@
|
|||||||
<Skeleton class="h-4 w-16" />
|
<Skeleton class="h-4 w-16" />
|
||||||
</Card>
|
</Card>
|
||||||
</template>
|
</template>
|
||||||
<Card
|
<!-- 有数据时显示统计卡片 -->
|
||||||
v-for="(stat, index) in stats"
|
<template v-else-if="stats.length > 0">
|
||||||
v-else
|
<Card
|
||||||
:key="stat.name"
|
v-for="(stat, index) in stats"
|
||||||
class="relative overflow-hidden p-3 sm:p-5"
|
:key="stat.name"
|
||||||
:class="statCardBorders[index % statCardBorders.length]"
|
class="relative overflow-hidden p-3 sm:p-5"
|
||||||
>
|
:class="statCardBorders[index % statCardBorders.length]"
|
||||||
<div
|
|
||||||
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-40"
|
|
||||||
:class="statCardGlows[index % statCardGlows.length]"
|
|
||||||
/>
|
|
||||||
<!-- 图标固定在右上角 -->
|
|
||||||
<div
|
|
||||||
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
|
|
||||||
:class="getStatIconColor(index)"
|
|
||||||
>
|
>
|
||||||
<component
|
|
||||||
:is="stat.icon"
|
|
||||||
class="h-4 w-4 sm:h-5 sm:w-5"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<!-- 内容区域 -->
|
|
||||||
<div>
|
|
||||||
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
|
|
||||||
{{ stat.name }}
|
|
||||||
</p>
|
|
||||||
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-foreground">
|
|
||||||
{{ stat.value }}
|
|
||||||
</p>
|
|
||||||
<p
|
|
||||||
v-if="stat.subValue"
|
|
||||||
class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground"
|
|
||||||
>
|
|
||||||
{{ stat.subValue }}
|
|
||||||
</p>
|
|
||||||
<div
|
<div
|
||||||
v-if="stat.change || stat.extraBadge"
|
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-40"
|
||||||
class="mt-1.5 sm:mt-2 flex items-center gap-1 sm:gap-1.5 flex-wrap"
|
:class="statCardGlows[index % statCardGlows.length]"
|
||||||
|
/>
|
||||||
|
<!-- 图标固定在右上角 -->
|
||||||
|
<div
|
||||||
|
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
|
||||||
|
:class="getStatIconColor(index)"
|
||||||
>
|
>
|
||||||
<Badge
|
<component
|
||||||
v-if="stat.change"
|
:is="stat.icon"
|
||||||
variant="secondary"
|
class="h-4 w-4 sm:h-5 sm:w-5"
|
||||||
class="text-[9px] sm:text-xs"
|
/>
|
||||||
>
|
|
||||||
{{ stat.change }}
|
|
||||||
</Badge>
|
|
||||||
<Badge
|
|
||||||
v-if="stat.extraBadge"
|
|
||||||
variant="secondary"
|
|
||||||
class="text-[9px] sm:text-xs"
|
|
||||||
>
|
|
||||||
{{ stat.extraBadge }}
|
|
||||||
</Badge>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
<!-- 内容区域 -->
|
||||||
</Card>
|
<div>
|
||||||
|
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
|
||||||
|
{{ stat.name }}
|
||||||
|
</p>
|
||||||
|
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-foreground">
|
||||||
|
{{ stat.value }}
|
||||||
|
</p>
|
||||||
|
<p
|
||||||
|
v-if="stat.subValue"
|
||||||
|
class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground"
|
||||||
|
>
|
||||||
|
{{ stat.subValue }}
|
||||||
|
</p>
|
||||||
|
<div
|
||||||
|
v-if="stat.change || stat.extraBadge"
|
||||||
|
class="mt-1.5 sm:mt-2 flex items-center gap-1 sm:gap-1.5 flex-wrap"
|
||||||
|
>
|
||||||
|
<Badge
|
||||||
|
v-if="stat.change"
|
||||||
|
variant="secondary"
|
||||||
|
class="text-[9px] sm:text-xs"
|
||||||
|
>
|
||||||
|
{{ stat.change }}
|
||||||
|
</Badge>
|
||||||
|
<Badge
|
||||||
|
v-if="stat.extraBadge"
|
||||||
|
variant="secondary"
|
||||||
|
class="text-[9px] sm:text-xs"
|
||||||
|
>
|
||||||
|
{{ stat.extraBadge }}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
</template>
|
||||||
|
<!-- 无数据时显示占位卡片 -->
|
||||||
|
<template v-else>
|
||||||
|
<Card
|
||||||
|
v-for="(placeholder, index) in emptyStatPlaceholders"
|
||||||
|
:key="'empty-' + index"
|
||||||
|
class="relative overflow-hidden p-3 sm:p-5"
|
||||||
|
:class="statCardBorders[index % statCardBorders.length]"
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-20"
|
||||||
|
:class="statCardGlows[index % statCardGlows.length]"
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
|
||||||
|
:class="getStatIconColor(index)"
|
||||||
|
>
|
||||||
|
<component
|
||||||
|
:is="placeholder.icon"
|
||||||
|
class="h-4 w-4 sm:h-5 sm:w-5"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
|
||||||
|
{{ placeholder.name }}
|
||||||
|
</p>
|
||||||
|
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-muted-foreground/50">
|
||||||
|
--
|
||||||
|
</p>
|
||||||
|
<p class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground/50">
|
||||||
|
暂无数据
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
</template>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 管理员:系统健康摘要 -->
|
<!-- 管理员:系统健康摘要 -->
|
||||||
@@ -872,6 +909,24 @@ const iconMap: Record<string, any> = {
|
|||||||
Users, Activity, TrendingUp, DollarSign, Key, Hash, Database
|
Users, Activity, TrendingUp, DollarSign, Key, Hash, Database
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 空状态占位卡片
|
||||||
|
const emptyStatPlaceholders = computed(() => {
|
||||||
|
if (isAdmin.value) {
|
||||||
|
return [
|
||||||
|
{ name: '今日请求', icon: Activity },
|
||||||
|
{ name: '今日 Tokens', icon: Hash },
|
||||||
|
{ name: '活跃用户', icon: Users },
|
||||||
|
{ name: '今日费用', icon: DollarSign }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
{ name: '今日请求', icon: Activity },
|
||||||
|
{ name: '今日 Tokens', icon: Hash },
|
||||||
|
{ name: 'API Keys', icon: Key },
|
||||||
|
{ name: '今日费用', icon: DollarSign }
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
const totalStats = computed(() => {
|
const totalStats = computed(() => {
|
||||||
if (dailyStats.value.length === 0) {
|
if (dailyStats.value.length === 0) {
|
||||||
return { requests: 0, tokens: 0, cost: 0, avgResponseTime: 0 }
|
return { requests: 0, tokens: 0, cost: 0, avgResponseTime: 0 }
|
||||||
|
|||||||
@@ -78,6 +78,20 @@ export default {
|
|||||||
md: "calc(var(--radius) - 2px)",
|
md: "calc(var(--radius) - 2px)",
|
||||||
sm: "calc(var(--radius) - 4px)",
|
sm: "calc(var(--radius) - 4px)",
|
||||||
},
|
},
|
||||||
|
keyframes: {
|
||||||
|
"collapsible-down": {
|
||||||
|
from: { height: "0" },
|
||||||
|
to: { height: "var(--radix-collapsible-content-height)" },
|
||||||
|
},
|
||||||
|
"collapsible-up": {
|
||||||
|
from: { height: "var(--radix-collapsible-content-height)" },
|
||||||
|
to: { height: "0" },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
animation: {
|
||||||
|
"collapsible-down": "collapsible-down 0.2s ease-out",
|
||||||
|
"collapsible-up": "collapsible-up 0.2s ease-out",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
plugins: [require("tailwindcss-animate")],
|
plugins: [require("tailwindcss-animate")],
|
||||||
|
|||||||
12
migrate.sh
12
migrate.sh
@@ -1,12 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# 数据库迁移脚本 - 在 Docker 容器内执行 Alembic 迁移
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
CONTAINER_NAME="aether-app"
|
|
||||||
|
|
||||||
echo "Running database migrations in container: $CONTAINER_NAME"
|
|
||||||
|
|
||||||
docker exec $CONTAINER_NAME alembic upgrade head
|
|
||||||
|
|
||||||
echo "Database migration completed successfully"
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
# file generated by setuptools-scm
|
|
||||||
# don't change, don't track in version control
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__version__",
|
|
||||||
"__version_tuple__",
|
|
||||||
"version",
|
|
||||||
"version_tuple",
|
|
||||||
"__commit_id__",
|
|
||||||
"commit_id",
|
|
||||||
]
|
|
||||||
|
|
||||||
TYPE_CHECKING = False
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
|
||||||
COMMIT_ID = Union[str, None]
|
|
||||||
else:
|
|
||||||
VERSION_TUPLE = object
|
|
||||||
COMMIT_ID = object
|
|
||||||
|
|
||||||
version: str
|
|
||||||
__version__: str
|
|
||||||
__version_tuple__: VERSION_TUPLE
|
|
||||||
version_tuple: VERSION_TUPLE
|
|
||||||
commit_id: COMMIT_ID
|
|
||||||
__commit_id__: COMMIT_ID
|
|
||||||
|
|
||||||
__version__ = version = '0.2.5'
|
|
||||||
__version_tuple__ = version_tuple = (0, 2, 5)
|
|
||||||
|
|
||||||
__commit_id__ = commit_id = None
|
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
自适应并发管理 API 端点
|
自适应 RPM 管理 API 端点
|
||||||
|
|
||||||
设计原则:
|
设计原则:
|
||||||
- 自适应模式由 max_concurrent 字段决定:
|
- 自适应模式由 rpm_limit 字段决定:
|
||||||
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
|
- rpm_limit = NULL:启用自适应模式,系统自动学习并调整 RPM 限制
|
||||||
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
|
- rpm_limit = 数字:固定限制模式,使用用户指定的 RPM 限制
|
||||||
- learned_max_concurrent:自适应模式下学习到的并发限制值
|
- learned_rpm_limit:自适应模式下学习到的 RPM 限制值
|
||||||
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
|
- adaptive_mode 是计算字段,基于 rpm_limit 是否为 NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -18,12 +18,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
|
from src.config.constants import RPMDefaults
|
||||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import ProviderAPIKey
|
from src.models.database import ProviderAPIKey
|
||||||
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
|
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
|
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive RPM"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
@@ -35,19 +36,19 @@ class EnableAdaptiveRequest(BaseModel):
|
|||||||
|
|
||||||
enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
|
enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
|
||||||
fixed_limit: Optional[int] = Field(
|
fixed_limit: Optional[int] = Field(
|
||||||
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
|
None, ge=1, le=100, description="固定 RPM 限制(仅当 enabled=false 时生效,1-100)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveStatsResponse(BaseModel):
|
class AdaptiveStatsResponse(BaseModel):
|
||||||
"""自适应统计响应"""
|
"""自适应统计响应"""
|
||||||
|
|
||||||
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
adaptive_mode: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
|
||||||
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
|
rpm_limit: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
|
||||||
effective_limit: Optional[int] = Field(
|
effective_limit: Optional[int] = Field(
|
||||||
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
|
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
|
||||||
)
|
)
|
||||||
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
|
learned_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
|
||||||
concurrent_429_count: int
|
concurrent_429_count: int
|
||||||
rpm_429_count: int
|
rpm_429_count: int
|
||||||
last_429_at: Optional[str]
|
last_429_at: Optional[str]
|
||||||
@@ -61,11 +62,12 @@ class KeyListItem(BaseModel):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: Optional[str]
|
name: Optional[str]
|
||||||
endpoint_id: str
|
provider_id: str
|
||||||
is_adaptive: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
|
api_formats: List[str] = Field(default_factory=list)
|
||||||
max_concurrent: Optional[int] = Field(None, description="固定并发限制(NULL=自适应)")
|
is_adaptive: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL)")
|
||||||
|
rpm_limit: Optional[int] = Field(None, description="固定 RPM 限制(NULL=自适应)")
|
||||||
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
||||||
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
|
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
|
||||||
concurrent_429_count: int
|
concurrent_429_count: int
|
||||||
rpm_429_count: int
|
rpm_429_count: int
|
||||||
|
|
||||||
@@ -80,22 +82,22 @@ class KeyListItem(BaseModel):
|
|||||||
)
|
)
|
||||||
async def list_adaptive_keys(
|
async def list_adaptive_keys(
|
||||||
request: Request,
|
request: Request,
|
||||||
endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
|
provider_id: Optional[str] = Query(None, description="按 Provider 过滤"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取所有启用自适应模式的Key列表
|
获取所有启用自适应模式的Key列表
|
||||||
|
|
||||||
可选参数:
|
可选参数:
|
||||||
- endpoint_id: 按 Endpoint 过滤
|
- provider_id: 按 Provider 过滤
|
||||||
"""
|
"""
|
||||||
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
|
adapter = ListAdaptiveKeysAdapter(provider_id=provider_id)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/keys/{key_id}/mode",
|
"/keys/{key_id}/mode",
|
||||||
summary="Toggle key's concurrency control mode",
|
summary="Toggle key's RPM control mode",
|
||||||
)
|
)
|
||||||
async def toggle_adaptive_mode(
|
async def toggle_adaptive_mode(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
@@ -103,10 +105,10 @@ async def toggle_adaptive_mode(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Toggle the concurrency control mode for a specific key
|
Toggle the RPM control mode for a specific key
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
|
- enabled: true=adaptive mode (rpm_limit=NULL), false=fixed limit mode
|
||||||
- fixed_limit: fixed limit value (required when enabled=false)
|
- fixed_limit: fixed limit value (required when enabled=false)
|
||||||
"""
|
"""
|
||||||
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
|
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
|
||||||
@@ -124,7 +126,7 @@ async def get_adaptive_stats(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定Key的自适应并发统计信息
|
获取指定Key的自适应 RPM 统计信息
|
||||||
|
|
||||||
包括:
|
包括:
|
||||||
- 当前配置
|
- 当前配置
|
||||||
@@ -149,12 +151,12 @@ async def reset_adaptive_learning(
|
|||||||
Reset the adaptive learning state for a specific key
|
Reset the adaptive learning state for a specific key
|
||||||
|
|
||||||
Clears:
|
Clears:
|
||||||
- Learned concurrency limit (learned_max_concurrent)
|
- Learned RPM limit (learned_rpm_limit)
|
||||||
- 429 error counts
|
- 429 error counts
|
||||||
- Adjustment history
|
- Adjustment history
|
||||||
|
|
||||||
Does not change:
|
Does not change:
|
||||||
- max_concurrent config (determines adaptive mode)
|
- rpm_limit config (determines adaptive mode)
|
||||||
"""
|
"""
|
||||||
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
|
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
@@ -162,40 +164,40 @@ async def reset_adaptive_learning(
|
|||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/keys/{key_id}/limit",
|
"/keys/{key_id}/limit",
|
||||||
summary="Set key to fixed concurrency limit mode",
|
summary="Set key to fixed RPM limit mode",
|
||||||
)
|
)
|
||||||
async def set_concurrent_limit(
|
async def set_rpm_limit(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
|
limit: int = Query(..., ge=1, le=100, description="RPM limit value (1-100)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Set key to fixed concurrency limit mode
|
Set key to fixed RPM limit mode
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
- After setting this value, key switches to fixed limit mode and won't auto-adjust
|
- After setting this value, key switches to fixed limit mode and won't auto-adjust
|
||||||
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
|
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
|
||||||
"""
|
"""
|
||||||
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
|
adapter = SetRPMLimitAdapter(key_id=key_id, limit=limit)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/summary",
|
"/summary",
|
||||||
summary="获取自适应并发的全局统计",
|
summary="获取自适应 RPM 的全局统计",
|
||||||
)
|
)
|
||||||
async def get_adaptive_summary(
|
async def get_adaptive_summary(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取自适应并发的全局统计摘要
|
获取自适应 RPM 的全局统计摘要
|
||||||
|
|
||||||
包括:
|
包括:
|
||||||
- 启用自适应模式的Key数量
|
- 启用自适应模式的Key数量
|
||||||
- 总429错误数
|
- 总429错误数
|
||||||
- 并发限制调整次数
|
- RPM 限制调整次数
|
||||||
"""
|
"""
|
||||||
adapter = AdaptiveSummaryAdapter()
|
adapter = AdaptiveSummaryAdapter()
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
@@ -206,26 +208,29 @@ async def get_adaptive_summary(
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ListAdaptiveKeysAdapter(AdminApiAdapter):
|
class ListAdaptiveKeysAdapter(AdminApiAdapter):
|
||||||
endpoint_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
# 自适应模式:max_concurrent = NULL
|
# 自适应模式:rpm_limit = NULL
|
||||||
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
|
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None))
|
||||||
if self.endpoint_id:
|
if self.provider_id:
|
||||||
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
query = query.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||||
|
|
||||||
keys = query.all()
|
keys = query.all()
|
||||||
return [
|
return [
|
||||||
KeyListItem(
|
KeyListItem(
|
||||||
id=key.id,
|
id=key.id,
|
||||||
name=key.name,
|
name=key.name,
|
||||||
endpoint_id=key.endpoint_id,
|
provider_id=key.provider_id,
|
||||||
is_adaptive=key.max_concurrent is None,
|
api_formats=key.api_formats or [],
|
||||||
max_concurrent=key.max_concurrent,
|
is_adaptive=key.rpm_limit is None,
|
||||||
|
rpm_limit=key.rpm_limit,
|
||||||
effective_limit=(
|
effective_limit=(
|
||||||
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
|
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||||
|
if key.rpm_limit is None
|
||||||
|
else key.rpm_limit
|
||||||
),
|
),
|
||||||
learned_max_concurrent=key.learned_max_concurrent,
|
learned_rpm_limit=key.learned_rpm_limit,
|
||||||
concurrent_429_count=key.concurrent_429_count or 0,
|
concurrent_429_count=key.concurrent_429_count or 0,
|
||||||
rpm_429_count=key.rpm_429_count or 0,
|
rpm_429_count=key.rpm_429_count or 0,
|
||||||
)
|
)
|
||||||
@@ -252,28 +257,32 @@ class ToggleAdaptiveModeAdapter(AdminApiAdapter):
|
|||||||
raise InvalidRequestException("请求数据验证失败")
|
raise InvalidRequestException("请求数据验证失败")
|
||||||
|
|
||||||
if body.enabled:
|
if body.enabled:
|
||||||
# 启用自适应模式:将 max_concurrent 设为 NULL
|
# 启用自适应模式:将 rpm_limit 设为 NULL
|
||||||
key.max_concurrent = None
|
key.rpm_limit = None
|
||||||
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
|
message = "已切换为自适应模式,系统将自动学习并调整 RPM 限制"
|
||||||
else:
|
else:
|
||||||
# 禁用自适应模式:设置固定限制
|
# 禁用自适应模式:设置固定限制
|
||||||
if body.fixed_limit is None:
|
if body.fixed_limit is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
|
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
|
||||||
)
|
)
|
||||||
key.max_concurrent = body.fixed_limit
|
key.rpm_limit = body.fixed_limit
|
||||||
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
|
message = f"已切换为固定限制模式,RPM 限制设为 {body.fixed_limit}"
|
||||||
|
|
||||||
context.db.commit()
|
context.db.commit()
|
||||||
context.db.refresh(key)
|
context.db.refresh(key)
|
||||||
|
|
||||||
is_adaptive = key.max_concurrent is None
|
is_adaptive = key.rpm_limit is None
|
||||||
return {
|
return {
|
||||||
"message": message,
|
"message": message,
|
||||||
"key_id": key.id,
|
"key_id": key.id,
|
||||||
"is_adaptive": is_adaptive,
|
"is_adaptive": is_adaptive,
|
||||||
"max_concurrent": key.max_concurrent,
|
"rpm_limit": key.rpm_limit,
|
||||||
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
|
"effective_limit": (
|
||||||
|
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||||
|
if is_adaptive
|
||||||
|
else key.rpm_limit
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -286,13 +295,13 @@ class GetAdaptiveStatsAdapter(AdminApiAdapter):
|
|||||||
if not key:
|
if not key:
|
||||||
raise HTTPException(status_code=404, detail="Key not found")
|
raise HTTPException(status_code=404, detail="Key not found")
|
||||||
|
|
||||||
adaptive_manager = get_adaptive_manager()
|
adaptive_manager = get_adaptive_rpm_manager()
|
||||||
stats = adaptive_manager.get_adjustment_stats(key)
|
stats = adaptive_manager.get_adjustment_stats(key)
|
||||||
|
|
||||||
# 转换字段名以匹配响应模型
|
# 转换字段名以匹配响应模型
|
||||||
return AdaptiveStatsResponse(
|
return AdaptiveStatsResponse(
|
||||||
adaptive_mode=stats["adaptive_mode"],
|
adaptive_mode=stats["adaptive_mode"],
|
||||||
max_concurrent=stats["max_concurrent"],
|
rpm_limit=stats["rpm_limit"],
|
||||||
effective_limit=stats["effective_limit"],
|
effective_limit=stats["effective_limit"],
|
||||||
learned_limit=stats["learned_limit"],
|
learned_limit=stats["learned_limit"],
|
||||||
concurrent_429_count=stats["concurrent_429_count"],
|
concurrent_429_count=stats["concurrent_429_count"],
|
||||||
@@ -313,13 +322,13 @@ class ResetAdaptiveLearningAdapter(AdminApiAdapter):
|
|||||||
if not key:
|
if not key:
|
||||||
raise HTTPException(status_code=404, detail="Key not found")
|
raise HTTPException(status_code=404, detail="Key not found")
|
||||||
|
|
||||||
adaptive_manager = get_adaptive_manager()
|
adaptive_manager = get_adaptive_rpm_manager()
|
||||||
adaptive_manager.reset_learning(context.db, key)
|
adaptive_manager.reset_learning(context.db, key)
|
||||||
return {"message": "学习状态已重置", "key_id": key.id}
|
return {"message": "学习状态已重置", "key_id": key.id}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SetConcurrentLimitAdapter(AdminApiAdapter):
|
class SetRPMLimitAdapter(AdminApiAdapter):
|
||||||
key_id: str
|
key_id: str
|
||||||
limit: int
|
limit: int
|
||||||
|
|
||||||
@@ -328,25 +337,25 @@ class SetConcurrentLimitAdapter(AdminApiAdapter):
|
|||||||
if not key:
|
if not key:
|
||||||
raise HTTPException(status_code=404, detail="Key not found")
|
raise HTTPException(status_code=404, detail="Key not found")
|
||||||
|
|
||||||
was_adaptive = key.max_concurrent is None
|
was_adaptive = key.rpm_limit is None
|
||||||
key.max_concurrent = self.limit
|
key.rpm_limit = self.limit
|
||||||
context.db.commit()
|
context.db.commit()
|
||||||
context.db.refresh(key)
|
context.db.refresh(key)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
|
"message": f"已设置为固定限制模式,RPM 限制为 {self.limit}",
|
||||||
"key_id": key.id,
|
"key_id": key.id,
|
||||||
"is_adaptive": False,
|
"is_adaptive": False,
|
||||||
"max_concurrent": key.max_concurrent,
|
"rpm_limit": key.rpm_limit,
|
||||||
"previous_mode": "adaptive" if was_adaptive else "fixed",
|
"previous_mode": "adaptive" if was_adaptive else "fixed",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class AdaptiveSummaryAdapter(AdminApiAdapter):
|
class AdaptiveSummaryAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
# 自适应模式:max_concurrent = NULL
|
# 自适应模式:rpm_limit = NULL
|
||||||
adaptive_keys = (
|
adaptive_keys = (
|
||||||
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
|
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None)).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
total_keys = len(adaptive_keys)
|
total_keys = len(adaptive_keys)
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
Endpoint 并发控制管理 API
|
Key RPM 限制管理 API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -12,83 +11,56 @@ from src.api.base.admin_adapter import AdminApiAdapter
|
|||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
from src.core.exceptions import NotFoundException
|
from src.core.exceptions import NotFoundException
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
from src.models.database import ProviderAPIKey
|
||||||
from src.models.endpoint_models import (
|
from src.models.endpoint_models import KeyRpmStatusResponse
|
||||||
ConcurrencyStatusResponse,
|
|
||||||
ResetConcurrencyRequest,
|
|
||||||
)
|
|
||||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||||
|
|
||||||
router = APIRouter(tags=["Concurrency Control"])
|
router = APIRouter(tags=["RPM Control"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
|
@router.get("/rpm/key/{key_id}", response_model=KeyRpmStatusResponse)
|
||||||
async def get_endpoint_concurrency(
|
async def get_key_rpm(
|
||||||
endpoint_id: str,
|
|
||||||
request: Request,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> ConcurrencyStatusResponse:
|
|
||||||
"""
|
|
||||||
获取 Endpoint 当前并发状态
|
|
||||||
|
|
||||||
查询指定 Endpoint 的实时并发使用情况,包括当前并发数和最大并发限制。
|
|
||||||
|
|
||||||
**路径参数**:
|
|
||||||
- `endpoint_id`: Endpoint ID
|
|
||||||
|
|
||||||
**返回字段**:
|
|
||||||
- `endpoint_id`: Endpoint ID
|
|
||||||
- `endpoint_current_concurrency`: 当前并发数
|
|
||||||
- `endpoint_max_concurrent`: 最大并发限制
|
|
||||||
"""
|
|
||||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
|
|
||||||
async def get_key_concurrency(
|
|
||||||
key_id: str,
|
key_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> ConcurrencyStatusResponse:
|
) -> KeyRpmStatusResponse:
|
||||||
"""
|
"""
|
||||||
获取 Key 当前并发状态
|
获取 Key 当前 RPM 状态
|
||||||
|
|
||||||
查询指定 API Key 的实时并发使用情况,包括当前并发数和最大并发限制。
|
查询指定 API Key 的实时 RPM 使用情况,包括当前 RPM 计数和最大 RPM 限制。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `key_id`: API Key ID
|
- `key_id`: API Key ID
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `key_id`: API Key ID
|
- `key_id`: API Key ID
|
||||||
- `key_current_concurrency`: 当前并发数
|
- `current_rpm`: 当前 RPM 计数
|
||||||
- `key_max_concurrent`: 最大并发限制
|
- `rpm_limit`: RPM 限制
|
||||||
"""
|
"""
|
||||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
adapter = AdminKeyRpmAdapter(key_id=key_id)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/concurrency")
|
@router.delete("/rpm/key/{key_id}")
|
||||||
async def reset_concurrency(
|
async def reset_key_rpm(
|
||||||
request: ResetConcurrencyRequest,
|
key_id: str,
|
||||||
http_request: Request,
|
http_request: Request,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
重置并发计数器
|
重置 Key RPM 计数器
|
||||||
|
|
||||||
重置指定 Endpoint 或 Key 的并发计数器,用于解决计数不准确的问题。
|
重置指定 API Key 的 RPM 计数器,用于解决计数不准确的问题。
|
||||||
管理员功能,请谨慎使用。
|
管理员功能,请谨慎使用。
|
||||||
|
|
||||||
**请求体字段**:
|
**路径参数**:
|
||||||
- `endpoint_id`: Endpoint ID(可选)
|
- `key_id`: API Key ID
|
||||||
- `key_id`: API Key ID(可选)
|
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `message`: 操作结果消息
|
- `message`: 操作结果消息
|
||||||
"""
|
"""
|
||||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
adapter = AdminResetKeyRpmAdapter(key_id=key_id)
|
||||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -96,31 +68,7 @@ async def reset_concurrency(
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
|
class AdminKeyRpmAdapter(AdminApiAdapter):
|
||||||
endpoint_id: str
|
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
|
||||||
db = context.db
|
|
||||||
endpoint = (
|
|
||||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|
||||||
)
|
|
||||||
if not endpoint:
|
|
||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|
||||||
|
|
||||||
concurrency_manager = await get_concurrency_manager()
|
|
||||||
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
|
|
||||||
endpoint_id=self.endpoint_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return ConcurrencyStatusResponse(
|
|
||||||
endpoint_id=self.endpoint_id,
|
|
||||||
endpoint_current_concurrency=endpoint_count,
|
|
||||||
endpoint_max_concurrent=endpoint.max_concurrent,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
|
||||||
key_id: str
|
key_id: str
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
@@ -130,23 +78,20 @@ class AdminKeyConcurrencyAdapter(AdminApiAdapter):
|
|||||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||||
|
|
||||||
concurrency_manager = await get_concurrency_manager()
|
concurrency_manager = await get_concurrency_manager()
|
||||||
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
|
key_count = await concurrency_manager.get_key_rpm_count(key_id=self.key_id)
|
||||||
|
|
||||||
return ConcurrencyStatusResponse(
|
return KeyRpmStatusResponse(
|
||||||
key_id=self.key_id,
|
key_id=self.key_id,
|
||||||
key_current_concurrency=key_count,
|
current_rpm=key_count,
|
||||||
key_max_concurrent=key.max_concurrent,
|
rpm_limit=key.rpm_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminResetConcurrencyAdapter(AdminApiAdapter):
|
class AdminResetKeyRpmAdapter(AdminApiAdapter):
|
||||||
endpoint_id: Optional[str]
|
key_id: str
|
||||||
key_id: Optional[str]
|
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
concurrency_manager = await get_concurrency_manager()
|
concurrency_manager = await get_concurrency_manager()
|
||||||
await concurrency_manager.reset_concurrency(
|
await concurrency_manager.reset_key_rpm(key_id=self.key_id)
|
||||||
endpoint_id=self.endpoint_id, key_id=self.key_id
|
return {"message": "RPM 计数已重置"}
|
||||||
)
|
|
||||||
return {"message": "并发计数已重置"}
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Endpoint 健康监控 API
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
@@ -128,29 +128,32 @@ async def get_api_format_health_monitor(
|
|||||||
async def get_key_health(
|
async def get_key_health(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
api_format: Optional[str] = Query(None, description="API 格式(可选,如 CLAUDE、OPENAI)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> HealthStatusResponse:
|
) -> HealthStatusResponse:
|
||||||
"""
|
"""
|
||||||
获取 Key 健康状态
|
获取 Key 健康状态
|
||||||
|
|
||||||
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
|
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
|
||||||
熔断器状态等信息。
|
熔断器状态等信息。支持按 API 格式查询。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `key_id`: API Key ID
|
- `key_id`: API Key ID
|
||||||
|
|
||||||
|
**查询参数**:
|
||||||
|
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)。
|
||||||
|
- 指定时返回该格式的健康度详情
|
||||||
|
- 不指定时返回所有格式的健康度摘要
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `key_id`: API Key ID
|
- `key_id`: API Key ID
|
||||||
- `key_health_score`: 健康分数(0.0-1.0)
|
- `key_health_score`: 健康分数(0.0-1.0)
|
||||||
- `key_consecutive_failures`: 连续失败次数
|
|
||||||
- `key_last_failure_at`: 最后失败时间
|
|
||||||
- `key_is_active`: 是否活跃
|
- `key_is_active`: 是否活跃
|
||||||
- `key_statistics`: 统计信息
|
- `key_statistics`: 统计信息
|
||||||
- `circuit_breaker_open`: 熔断器是否打开
|
- `health_by_format`: 按格式的健康度数据(无 api_format 参数时)
|
||||||
- `circuit_breaker_open_at`: 熔断器打开时间
|
- `circuit_breaker_open`: 熔断器是否打开(有 api_format 参数时)
|
||||||
- `next_probe_at`: 下次探测时间
|
|
||||||
"""
|
"""
|
||||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
adapter = AdminKeyHealthAdapter(key_id=key_id, api_format=api_format)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -158,17 +161,23 @@ async def get_key_health(
|
|||||||
async def recover_key_health(
|
async def recover_key_health(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
api_format: Optional[str] = Query(None, description="API 格式(可选,不指定则恢复所有格式)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
恢复 Key 健康状态
|
恢复 Key 健康状态
|
||||||
|
|
||||||
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
|
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
|
||||||
取消自动禁用,并重置所有失败计数。
|
取消自动禁用,并重置所有失败计数。支持按 API 格式恢复。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `key_id`: API Key ID
|
- `key_id`: API Key ID
|
||||||
|
|
||||||
|
**查询参数**:
|
||||||
|
- `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
|
||||||
|
- 指定时仅恢复该格式的健康度
|
||||||
|
- 不指定时恢复所有格式
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `message`: 操作结果消息
|
- `message`: 操作结果消息
|
||||||
- `details`: 详细信息
|
- `details`: 详细信息
|
||||||
@@ -176,7 +185,7 @@ async def recover_key_health(
|
|||||||
- `circuit_breaker_open`: 熔断器状态
|
- `circuit_breaker_open`: 熔断器状态
|
||||||
- `is_active`: 是否活跃
|
- `is_active`: 是否活跃
|
||||||
"""
|
"""
|
||||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id, api_format=api_format)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -276,34 +285,9 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
all_formats[api_format] = provider_count
|
all_formats[api_format] = provider_count
|
||||||
|
|
||||||
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
|
# 1.1 建立每个 API 格式对应的 Endpoint ID 列表(用于时间线生成),并收集活跃的 provider+format 组合
|
||||||
active_keys = (
|
|
||||||
db.query(
|
|
||||||
ProviderEndpoint.api_format,
|
|
||||||
func.count(ProviderAPIKey.id).label("key_count"),
|
|
||||||
)
|
|
||||||
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
|
|
||||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
|
||||||
.filter(
|
|
||||||
ProviderEndpoint.is_active.is_(True),
|
|
||||||
Provider.is_active.is_(True),
|
|
||||||
ProviderAPIKey.is_active.is_(True),
|
|
||||||
)
|
|
||||||
.group_by(ProviderEndpoint.api_format)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建所有格式的 key_count 映射
|
|
||||||
key_counts: Dict[str, int] = {}
|
|
||||||
for api_format_enum, key_count in active_keys:
|
|
||||||
api_format = (
|
|
||||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
|
||||||
)
|
|
||||||
key_counts[api_format] = key_count
|
|
||||||
|
|
||||||
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
|
|
||||||
endpoint_rows = (
|
endpoint_rows = (
|
||||||
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
|
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id, ProviderEndpoint.provider_id)
|
||||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
||||||
.filter(
|
.filter(
|
||||||
ProviderEndpoint.is_active.is_(True),
|
ProviderEndpoint.is_active.is_(True),
|
||||||
@@ -312,11 +296,32 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
endpoint_map: Dict[str, List[str]] = defaultdict(list)
|
||||||
for api_format_enum, endpoint_id in endpoint_rows:
|
active_provider_formats: set[tuple[str, str]] = set()
|
||||||
|
for api_format_enum, endpoint_id, provider_id in endpoint_rows:
|
||||||
api_format = (
|
api_format = (
|
||||||
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
|
||||||
)
|
)
|
||||||
endpoint_map[api_format].append(endpoint_id)
|
endpoint_map[api_format].append(endpoint_id)
|
||||||
|
active_provider_formats.add((str(provider_id), api_format))
|
||||||
|
|
||||||
|
# 1.2 统计每个 API 格式可用的活跃 Key 数量(Key 属于 Provider,通过 api_formats 关联格式)
|
||||||
|
key_counts: Dict[str, int] = {}
|
||||||
|
if active_provider_formats:
|
||||||
|
active_provider_keys = (
|
||||||
|
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
|
||||||
|
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
|
||||||
|
.filter(
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
ProviderAPIKey.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for provider_id, api_formats in active_provider_keys:
|
||||||
|
pid = str(provider_id)
|
||||||
|
for fmt in (api_formats or []):
|
||||||
|
if (pid, fmt) not in active_provider_formats:
|
||||||
|
continue
|
||||||
|
key_counts[fmt] = key_counts.get(fmt, 0) + 1
|
||||||
|
|
||||||
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
|
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
|
||||||
# 只统计最终状态:success, failed, skipped
|
# 只统计最终状态:success, failed, skipped
|
||||||
@@ -457,28 +462,45 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AdminKeyHealthAdapter(AdminApiAdapter):
|
class AdminKeyHealthAdapter(AdminApiAdapter):
|
||||||
key_id: str
|
key_id: str
|
||||||
|
api_format: Optional[str] = None
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
health_data = health_monitor.get_key_health(context.db, self.key_id)
|
health_data = health_monitor.get_key_health(context.db, self.key_id, self.api_format)
|
||||||
if not health_data:
|
if not health_data:
|
||||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||||
|
|
||||||
return HealthStatusResponse(
|
# 构建响应
|
||||||
key_id=health_data["key_id"],
|
response_data = {
|
||||||
key_health_score=health_data["health_score"],
|
"key_id": health_data["key_id"],
|
||||||
key_consecutive_failures=health_data["consecutive_failures"],
|
"key_is_active": health_data["is_active"],
|
||||||
key_last_failure_at=health_data["last_failure_at"],
|
"key_statistics": health_data.get("statistics"),
|
||||||
key_is_active=health_data["is_active"],
|
"key_health_score": health_data.get("health_score", 1.0),
|
||||||
key_statistics=health_data["statistics"],
|
}
|
||||||
circuit_breaker_open=health_data["circuit_breaker_open"],
|
|
||||||
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
|
if self.api_format:
|
||||||
next_probe_at=health_data["next_probe_at"],
|
# 单格式查询
|
||||||
)
|
response_data["api_format"] = self.api_format
|
||||||
|
response_data["key_consecutive_failures"] = health_data.get("consecutive_failures")
|
||||||
|
response_data["key_last_failure_at"] = health_data.get("last_failure_at")
|
||||||
|
circuit = health_data.get("circuit_breaker", {})
|
||||||
|
response_data["circuit_breaker_open"] = circuit.get("open", False)
|
||||||
|
response_data["circuit_breaker_open_at"] = circuit.get("open_at")
|
||||||
|
response_data["next_probe_at"] = circuit.get("next_probe_at")
|
||||||
|
response_data["half_open_until"] = circuit.get("half_open_until")
|
||||||
|
response_data["half_open_successes"] = circuit.get("half_open_successes", 0)
|
||||||
|
response_data["half_open_failures"] = circuit.get("half_open_failures", 0)
|
||||||
|
else:
|
||||||
|
# 全格式查询
|
||||||
|
response_data["any_circuit_open"] = health_data.get("any_circuit_open", False)
|
||||||
|
response_data["health_by_format"] = health_data.get("health_by_format")
|
||||||
|
|
||||||
|
return HealthStatusResponse(**response_data)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
||||||
key_id: str
|
key_id: str
|
||||||
|
api_format: Optional[str] = None
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
@@ -486,28 +508,38 @@ class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
|
|||||||
if not key:
|
if not key:
|
||||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||||
|
|
||||||
key.health_score = 1.0
|
# 使用 health_monitor.reset_health 重置健康度
|
||||||
key.consecutive_failures = 0
|
success = health_monitor.reset_health(db, key_id=self.key_id, api_format=self.api_format)
|
||||||
key.last_failure_at = None
|
if not success:
|
||||||
key.circuit_breaker_open = False
|
raise Exception("重置健康度失败")
|
||||||
key.circuit_breaker_open_at = None
|
|
||||||
key.next_probe_at = None
|
# 如果 Key 被禁用,重新启用
|
||||||
if not key.is_active:
|
if not key.is_active:
|
||||||
key.is_active = True
|
key.is_active = True # type: ignore[assignment]
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
admin_name = context.user.username if context.user else "admin"
|
if self.api_format:
|
||||||
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
|
logger.info(f"管理员恢复Key健康状态: {self.key_id}/{self.api_format}")
|
||||||
|
return {
|
||||||
return {
|
"message": f"Key 的 {self.api_format} 格式已恢复",
|
||||||
"message": "Key已完全恢复",
|
"details": {
|
||||||
"details": {
|
"api_format": self.api_format,
|
||||||
"health_score": 1.0,
|
"health_score": 1.0,
|
||||||
"circuit_breaker_open": False,
|
"circuit_breaker_open": False,
|
||||||
"is_active": True,
|
"is_active": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
logger.info(f"管理员恢复Key健康状态: {self.key_id} (所有格式)")
|
||||||
|
return {
|
||||||
|
"message": "Key 所有格式已恢复",
|
||||||
|
"details": {
|
||||||
|
"health_score": 1.0,
|
||||||
|
"circuit_breaker_open": False,
|
||||||
|
"is_active": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
||||||
@@ -516,10 +548,17 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
|||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
|
|
||||||
# 查找所有熔断的 Key
|
# 查找所有有熔断格式的 Key(检查 circuit_breaker_by_format JSON 字段)
|
||||||
circuit_open_keys = (
|
all_keys = db.query(ProviderAPIKey).all()
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
|
|
||||||
)
|
# 筛选出有任何格式熔断的 Key
|
||||||
|
circuit_open_keys = []
|
||||||
|
for key in all_keys:
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
for fmt, circuit_data in circuit_by_format.items():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
circuit_open_keys.append(key)
|
||||||
|
break
|
||||||
|
|
||||||
if not circuit_open_keys:
|
if not circuit_open_keys:
|
||||||
return {
|
return {
|
||||||
@@ -530,17 +569,15 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
recovered_keys = []
|
recovered_keys = []
|
||||||
for key in circuit_open_keys:
|
for key in circuit_open_keys:
|
||||||
key.health_score = 1.0
|
# 重置所有格式的健康度
|
||||||
key.consecutive_failures = 0
|
key.health_by_format = {} # type: ignore[assignment]
|
||||||
key.last_failure_at = None
|
key.circuit_breaker_by_format = {} # type: ignore[assignment]
|
||||||
key.circuit_breaker_open = False
|
|
||||||
key.circuit_breaker_open_at = None
|
|
||||||
key.next_probe_at = None
|
|
||||||
recovered_keys.append(
|
recovered_keys.append(
|
||||||
{
|
{
|
||||||
"key_id": key.id,
|
"key_id": key.id,
|
||||||
"key_name": key.name,
|
"key_name": key.name,
|
||||||
"endpoint_id": key.endpoint_id,
|
"provider_id": key.provider_id,
|
||||||
|
"api_formats": key.api_formats,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -552,7 +589,6 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
|
|||||||
HealthMonitor._open_circuit_keys = 0
|
HealthMonitor._open_circuit_keys = 0
|
||||||
health_open_circuits.set(0)
|
health_open_circuits.set(0)
|
||||||
|
|
||||||
admin_name = context.user.username if context.user else "admin"
|
|
||||||
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
|
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Endpoint API Keys 管理
|
Provider API Keys 管理
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
@@ -12,103 +12,23 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
|
from src.config.constants import RPMDefaults
|
||||||
from src.core.crypto import crypto_service
|
from src.core.crypto import crypto_service
|
||||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||||
from src.core.key_capabilities import get_capability
|
from src.core.key_capabilities import get_capability
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||||
|
from src.services.cache.provider_cache import ProviderCacheService
|
||||||
from src.models.endpoint_models import (
|
from src.models.endpoint_models import (
|
||||||
BatchUpdateKeyPriorityRequest,
|
|
||||||
EndpointAPIKeyCreate,
|
EndpointAPIKeyCreate,
|
||||||
EndpointAPIKeyResponse,
|
EndpointAPIKeyResponse,
|
||||||
EndpointAPIKeyUpdate,
|
EndpointAPIKeyUpdate,
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(tags=["Endpoint Keys"])
|
router = APIRouter(tags=["Provider Keys"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
|
||||||
async def list_endpoint_keys(
|
|
||||||
endpoint_id: str,
|
|
||||||
request: Request,
|
|
||||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
|
||||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> List[EndpointAPIKeyResponse]:
|
|
||||||
"""
|
|
||||||
获取 Endpoint 的所有 Keys
|
|
||||||
|
|
||||||
获取指定 Endpoint 下的所有 API Key 列表,包括 Key 的配置、统计信息等。
|
|
||||||
结果按优先级和创建时间排序。
|
|
||||||
|
|
||||||
**路径参数**:
|
|
||||||
- `endpoint_id`: Endpoint ID
|
|
||||||
|
|
||||||
**查询参数**:
|
|
||||||
- `skip`: 跳过的记录数,用于分页(默认 0)
|
|
||||||
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
|
||||||
|
|
||||||
**返回字段**:
|
|
||||||
- `id`: Key ID
|
|
||||||
- `name`: Key 名称
|
|
||||||
- `api_key_masked`: 脱敏后的 API Key
|
|
||||||
- `internal_priority`: 内部优先级
|
|
||||||
- `global_priority`: 全局优先级
|
|
||||||
- `rate_multiplier`: 速率倍数
|
|
||||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
|
||||||
- `is_adaptive`: 是否为自适应并发模式
|
|
||||||
- `effective_limit`: 有效并发限制
|
|
||||||
- `success_rate`: 成功率
|
|
||||||
- `avg_response_time_ms`: 平均响应时间(毫秒)
|
|
||||||
- 其他配置和统计字段
|
|
||||||
"""
|
|
||||||
adapter = AdminListEndpointKeysAdapter(
|
|
||||||
endpoint_id=endpoint_id,
|
|
||||||
skip=skip,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
|
||||||
async def add_endpoint_key(
|
|
||||||
endpoint_id: str,
|
|
||||||
key_data: EndpointAPIKeyCreate,
|
|
||||||
request: Request,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> EndpointAPIKeyResponse:
|
|
||||||
"""
|
|
||||||
为 Endpoint 添加 Key
|
|
||||||
|
|
||||||
为指定 Endpoint 添加新的 API Key,支持配置并发限制、速率倍数、
|
|
||||||
优先级、配额限制、能力限制等。
|
|
||||||
|
|
||||||
**路径参数**:
|
|
||||||
- `endpoint_id`: Endpoint ID
|
|
||||||
|
|
||||||
**请求体字段**:
|
|
||||||
- `endpoint_id`: Endpoint ID(必须与路径参数一致)
|
|
||||||
- `api_key`: API Key 原文(将被加密存储)
|
|
||||||
- `name`: Key 名称
|
|
||||||
- `note`: 备注(可选)
|
|
||||||
- `rate_multiplier`: 速率倍数(默认 1.0)
|
|
||||||
- `internal_priority`: 内部优先级(默认 100)
|
|
||||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
|
||||||
- `rate_limit`: 每分钟请求限制(可选)
|
|
||||||
- `daily_limit`: 每日请求限制(可选)
|
|
||||||
- `monthly_limit`: 每月请求限制(可选)
|
|
||||||
- `allowed_models`: 允许的模型列表(可选)
|
|
||||||
- `capabilities`: 能力配置(可选)
|
|
||||||
|
|
||||||
**返回字段**:
|
|
||||||
- 包含完整的 Key 信息,其中 `api_key_plain` 为原文(仅在创建时返回)
|
|
||||||
"""
|
|
||||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
||||||
async def update_endpoint_key(
|
async def update_endpoint_key(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
@@ -117,7 +37,7 @@ async def update_endpoint_key(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> EndpointAPIKeyResponse:
|
) -> EndpointAPIKeyResponse:
|
||||||
"""
|
"""
|
||||||
更新 Endpoint Key
|
更新 Provider Key
|
||||||
|
|
||||||
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
|
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
|
||||||
配额限制、能力限制等。支持部分更新。
|
配额限制、能力限制等。支持部分更新。
|
||||||
@@ -131,10 +51,7 @@ async def update_endpoint_key(
|
|||||||
- `note`: 备注
|
- `note`: 备注
|
||||||
- `rate_multiplier`: 速率倍数
|
- `rate_multiplier`: 速率倍数
|
||||||
- `internal_priority`: 内部优先级
|
- `internal_priority`: 内部优先级
|
||||||
- `max_concurrent`: 最大并发数(设置为 null 可切换到自适应模式)
|
- `rpm_limit`: RPM 限制(设置为 null 可切换到自适应模式)
|
||||||
- `rate_limit`: 每分钟请求限制
|
|
||||||
- `daily_limit`: 每日请求限制
|
|
||||||
- `monthly_limit`: 每月请求限制
|
|
||||||
- `allowed_models`: 允许的模型列表
|
- `allowed_models`: 允许的模型列表
|
||||||
- `capabilities`: 能力配置
|
- `capabilities`: 能力配置
|
||||||
- `is_active`: 是否活跃
|
- `is_active`: 是否活跃
|
||||||
@@ -209,7 +126,7 @@ async def delete_endpoint_key(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
删除 Endpoint Key
|
删除 Provider Key
|
||||||
|
|
||||||
删除指定的 API Key。此操作不可逆,请谨慎使用。
|
删除指定的 API Key。此操作不可逆,请谨慎使用。
|
||||||
|
|
||||||
@@ -223,163 +140,66 @@ async def delete_endpoint_key(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{endpoint_id}/keys/batch-priority")
|
# ========== Provider Keys API ==========
|
||||||
async def batch_update_key_priority(
|
|
||||||
endpoint_id: str,
|
|
||||||
request: Request,
|
|
||||||
priority_data: BatchUpdateKeyPriorityRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
批量更新 Endpoint 下 Keys 的优先级
|
|
||||||
|
|
||||||
批量更新指定 Endpoint 下多个 Key 的内部优先级,用于拖动排序。
|
|
||||||
所有 Key 必须属于指定的 Endpoint。
|
@router.get("/providers/{provider_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||||
|
async def list_provider_keys(
|
||||||
|
provider_id: str,
|
||||||
|
request: Request,
|
||||||
|
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||||
|
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> List[EndpointAPIKeyResponse]:
|
||||||
|
"""
|
||||||
|
获取 Provider 的所有 Keys
|
||||||
|
|
||||||
|
获取指定 Provider 下的所有 API Key 列表,支持多 API 格式。
|
||||||
|
结果按优先级和创建时间排序。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `endpoint_id`: Endpoint ID
|
- `provider_id`: Provider ID
|
||||||
|
|
||||||
|
**查询参数**:
|
||||||
|
- `skip`: 跳过的记录数,用于分页(默认 0)
|
||||||
|
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
||||||
|
"""
|
||||||
|
adapter = AdminListProviderKeysAdapter(
|
||||||
|
provider_id=provider_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/providers/{provider_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||||
|
async def add_provider_key(
|
||||||
|
provider_id: str,
|
||||||
|
key_data: EndpointAPIKeyCreate,
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> EndpointAPIKeyResponse:
|
||||||
|
"""
|
||||||
|
为 Provider 添加 Key
|
||||||
|
|
||||||
|
为指定 Provider 添加新的 API Key,支持配置多个 API 格式。
|
||||||
|
|
||||||
|
**路径参数**:
|
||||||
|
- `provider_id`: Provider ID
|
||||||
|
|
||||||
**请求体字段**:
|
**请求体字段**:
|
||||||
- `priorities`: 优先级列表
|
- `api_formats`: 支持的 API 格式列表(必填)
|
||||||
- `key_id`: Key ID
|
- `api_key`: API Key 原文(将被加密存储)
|
||||||
- `internal_priority`: 新的内部优先级
|
- `name`: Key 名称
|
||||||
|
- 其他配置字段同 Key
|
||||||
**返回字段**:
|
|
||||||
- `message`: 操作结果消息
|
|
||||||
- `updated_count`: 实际更新的 Key 数量
|
|
||||||
"""
|
"""
|
||||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
adapter = AdminCreateProviderKeyAdapter(provider_id=provider_id, key_data=key_data)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
# -------- Adapters --------
|
# -------- Adapters --------
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AdminListEndpointKeysAdapter(AdminApiAdapter):
|
|
||||||
endpoint_id: str
|
|
||||||
skip: int
|
|
||||||
limit: int
|
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
|
||||||
db = context.db
|
|
||||||
endpoint = (
|
|
||||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|
||||||
)
|
|
||||||
if not endpoint:
|
|
||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|
||||||
|
|
||||||
keys = (
|
|
||||||
db.query(ProviderAPIKey)
|
|
||||||
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
|
|
||||||
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
|
||||||
.offset(self.skip)
|
|
||||||
.limit(self.limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
result: List[EndpointAPIKeyResponse] = []
|
|
||||||
for key in keys:
|
|
||||||
try:
|
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
|
||||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
|
||||||
except Exception:
|
|
||||||
masked_key = "***ERROR***"
|
|
||||||
|
|
||||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
|
||||||
avg_response_time_ms = (
|
|
||||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
is_adaptive = key.max_concurrent is None
|
|
||||||
key_dict = key.__dict__.copy()
|
|
||||||
key_dict.pop("_sa_instance_state", None)
|
|
||||||
key_dict.update(
|
|
||||||
{
|
|
||||||
"api_key_masked": masked_key,
|
|
||||||
"api_key_plain": None,
|
|
||||||
"success_rate": success_rate,
|
|
||||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
|
||||||
"is_adaptive": is_adaptive,
|
|
||||||
"effective_limit": (
|
|
||||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
result.append(EndpointAPIKeyResponse(**key_dict))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
|
|
||||||
endpoint_id: str
|
|
||||||
key_data: EndpointAPIKeyCreate
|
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
|
||||||
db = context.db
|
|
||||||
endpoint = (
|
|
||||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
|
||||||
)
|
|
||||||
if not endpoint:
|
|
||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|
||||||
|
|
||||||
if self.key_data.endpoint_id != self.endpoint_id:
|
|
||||||
raise InvalidRequestException("endpoint_id 不匹配")
|
|
||||||
|
|
||||||
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
|
|
||||||
new_key = ProviderAPIKey(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
endpoint_id=self.endpoint_id,
|
|
||||||
api_key=encrypted_key,
|
|
||||||
name=self.key_data.name,
|
|
||||||
note=self.key_data.note,
|
|
||||||
rate_multiplier=self.key_data.rate_multiplier,
|
|
||||||
internal_priority=self.key_data.internal_priority,
|
|
||||||
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
|
|
||||||
rate_limit=self.key_data.rate_limit,
|
|
||||||
daily_limit=self.key_data.daily_limit,
|
|
||||||
monthly_limit=self.key_data.monthly_limit,
|
|
||||||
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
|
||||||
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
|
||||||
request_count=0,
|
|
||||||
success_count=0,
|
|
||||||
error_count=0,
|
|
||||||
total_response_time_ms=0,
|
|
||||||
is_active=True,
|
|
||||||
last_used_at=None,
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(new_key)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(new_key)
|
|
||||||
|
|
||||||
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
|
|
||||||
|
|
||||||
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
|
|
||||||
is_adaptive = new_key.max_concurrent is None
|
|
||||||
response_dict = new_key.__dict__.copy()
|
|
||||||
response_dict.pop("_sa_instance_state", None)
|
|
||||||
response_dict.update(
|
|
||||||
{
|
|
||||||
"api_key_masked": masked_key,
|
|
||||||
"api_key_plain": self.key_data.api_key,
|
|
||||||
"success_rate": 0.0,
|
|
||||||
"avg_response_time_ms": 0.0,
|
|
||||||
"is_adaptive": is_adaptive,
|
|
||||||
"effective_limit": (
|
|
||||||
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return EndpointAPIKeyResponse(**response_dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||||
key_id: str
|
key_id: str
|
||||||
@@ -395,14 +215,21 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
if "api_key" in update_data:
|
if "api_key" in update_data:
|
||||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||||
|
|
||||||
# 特殊处理 max_concurrent:需要区分"未提供"和"显式设置为 null"
|
# 特殊处理 rpm_limit:需要区分"未提供"和"显式设置为 null"
|
||||||
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
|
if "rpm_limit" in self.key_data.model_fields_set:
|
||||||
if "max_concurrent" in self.key_data.model_fields_set:
|
update_data["rpm_limit"] = self.key_data.rpm_limit
|
||||||
update_data["max_concurrent"] = self.key_data.max_concurrent
|
if self.key_data.rpm_limit is None:
|
||||||
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
|
update_data["learned_rpm_limit"] = None
|
||||||
if self.key_data.max_concurrent is None:
|
logger.info("Key %s 切换为自适应 RPM 模式", self.key_id)
|
||||||
update_data["learned_max_concurrent"] = None
|
|
||||||
logger.info("Key %s 切换为自适应并发模式", self.key_id)
|
# 统一处理 allowed_models:空列表/空字典 -> None(表示不限制)
|
||||||
|
if "allowed_models" in update_data:
|
||||||
|
am = update_data["allowed_models"]
|
||||||
|
if am is not None and (
|
||||||
|
(isinstance(am, list) and len(am) == 0)
|
||||||
|
or (isinstance(am, dict) and len(am) == 0)
|
||||||
|
):
|
||||||
|
update_data["allowed_models"] = None
|
||||||
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(key, field, value)
|
setattr(key, field, value)
|
||||||
@@ -411,35 +238,13 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(key)
|
db.refresh(key)
|
||||||
|
|
||||||
|
# 任何字段更新都清除缓存,确保缓存一致性
|
||||||
|
# 包括 is_active、allowed_models、capabilities 等影响权限和行为的字段
|
||||||
|
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
|
||||||
|
|
||||||
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
||||||
|
|
||||||
try:
|
return _build_key_response(key)
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
|
||||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
|
||||||
except Exception:
|
|
||||||
masked_key = "***ERROR***"
|
|
||||||
|
|
||||||
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
|
||||||
avg_response_time_ms = (
|
|
||||||
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
is_adaptive = key.max_concurrent is None
|
|
||||||
response_dict = key.__dict__.copy()
|
|
||||||
response_dict.pop("_sa_instance_state", None)
|
|
||||||
response_dict.update(
|
|
||||||
{
|
|
||||||
"api_key_masked": masked_key,
|
|
||||||
"api_key_plain": None,
|
|
||||||
"success_rate": success_rate,
|
|
||||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
|
||||||
"is_adaptive": is_adaptive,
|
|
||||||
"effective_limit": (
|
|
||||||
key.learned_max_concurrent if is_adaptive else key.max_concurrent
|
|
||||||
),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return EndpointAPIKeyResponse(**response_dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -476,7 +281,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
if not key:
|
if not key:
|
||||||
raise NotFoundException(f"Key {self.key_id} 不存在")
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||||
|
|
||||||
endpoint_id = key.endpoint_id
|
provider_id = key.provider_id
|
||||||
try:
|
try:
|
||||||
db.delete(key)
|
db.delete(key)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -485,7 +290,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
|
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Provider={provider_id}")
|
||||||
return {"message": f"Key {self.key_id} 已删除"}
|
return {"message": f"Key {self.key_id} 已删除"}
|
||||||
|
|
||||||
|
|
||||||
@@ -493,31 +298,51 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
|||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
|
|
||||||
|
# Key 属于 Provider:按 key.api_formats 分组展示
|
||||||
keys = (
|
keys = (
|
||||||
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
|
db.query(ProviderAPIKey, Provider)
|
||||||
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
|
||||||
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
|
|
||||||
.filter(
|
.filter(
|
||||||
ProviderAPIKey.is_active.is_(True),
|
ProviderAPIKey.is_active.is_(True),
|
||||||
ProviderEndpoint.is_active.is_(True),
|
|
||||||
Provider.is_active.is_(True),
|
Provider.is_active.is_(True),
|
||||||
)
|
)
|
||||||
.order_by(
|
.order_by(
|
||||||
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
|
ProviderAPIKey.global_priority.asc().nullslast(),
|
||||||
|
ProviderAPIKey.internal_priority.asc(),
|
||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
provider_ids = {str(provider.id) for _key, provider in keys}
|
||||||
|
endpoints = (
|
||||||
|
db.query(
|
||||||
|
ProviderEndpoint.provider_id,
|
||||||
|
ProviderEndpoint.api_format,
|
||||||
|
ProviderEndpoint.base_url,
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
ProviderEndpoint.provider_id.in_(provider_ids),
|
||||||
|
ProviderEndpoint.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
endpoint_base_url_map: Dict[tuple[str, str], str] = {}
|
||||||
|
for provider_id, api_format, base_url in endpoints:
|
||||||
|
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
|
||||||
|
endpoint_base_url_map[(str(provider_id), fmt)] = base_url
|
||||||
|
|
||||||
grouped: Dict[str, List[dict]] = {}
|
grouped: Dict[str, List[dict]] = {}
|
||||||
for key, endpoint, provider in keys:
|
for key, provider in keys:
|
||||||
api_format = endpoint.api_format
|
api_formats = key.api_formats or []
|
||||||
if api_format not in grouped:
|
|
||||||
grouped[api_format] = []
|
if not api_formats:
|
||||||
|
continue # 跳过没有 API 格式的 Key
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(f"解密 Key 失败: key_id={key.id}, error={e}")
|
||||||
masked_key = "***ERROR***"
|
masked_key = "***ERROR***"
|
||||||
|
|
||||||
# 计算健康度指标
|
# 计算健康度指标
|
||||||
@@ -536,72 +361,209 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
|||||||
cap_def = get_capability(cap_name)
|
cap_def = get_capability(cap_name)
|
||||||
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
caps_list.append(cap_def.short_name if cap_def else cap_name)
|
||||||
|
|
||||||
grouped[api_format].append(
|
# 构建 Key 信息(基础数据)
|
||||||
{
|
key_info = {
|
||||||
"id": key.id,
|
"id": key.id,
|
||||||
"name": key.name,
|
"name": key.name,
|
||||||
"api_key_masked": masked_key,
|
"api_key_masked": masked_key,
|
||||||
"internal_priority": key.internal_priority,
|
"internal_priority": key.internal_priority,
|
||||||
"global_priority": key.global_priority,
|
"global_priority": key.global_priority,
|
||||||
"rate_multiplier": key.rate_multiplier,
|
"rate_multiplier": key.rate_multiplier,
|
||||||
"is_active": key.is_active,
|
"is_active": key.is_active,
|
||||||
"circuit_breaker_open": key.circuit_breaker_open,
|
"provider_name": provider.name,
|
||||||
"provider_name": provider.display_name or provider.name,
|
"api_formats": api_formats,
|
||||||
"endpoint_base_url": endpoint.base_url,
|
"capabilities": caps_list,
|
||||||
"api_format": api_format,
|
"success_rate": success_rate,
|
||||||
"capabilities": caps_list,
|
"avg_response_time_ms": avg_response_time_ms,
|
||||||
"success_rate": success_rate,
|
"request_count": key.request_count,
|
||||||
"avg_response_time_ms": avg_response_time_ms,
|
}
|
||||||
"request_count": key.request_count,
|
|
||||||
}
|
# 将 Key 添加到每个支持的格式分组中,并附加格式特定的健康度数据
|
||||||
)
|
health_by_format = key.health_by_format or {}
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
provider_id = str(provider.id)
|
||||||
|
for api_format in api_formats:
|
||||||
|
if api_format not in grouped:
|
||||||
|
grouped[api_format] = []
|
||||||
|
# 为每个格式创建副本,设置当前格式
|
||||||
|
format_key_info = key_info.copy()
|
||||||
|
format_key_info["api_format"] = api_format
|
||||||
|
format_key_info["endpoint_base_url"] = endpoint_base_url_map.get(
|
||||||
|
(provider_id, api_format)
|
||||||
|
)
|
||||||
|
# 添加格式特定的健康度数据
|
||||||
|
format_health = health_by_format.get(api_format, {})
|
||||||
|
format_circuit = circuit_by_format.get(api_format, {})
|
||||||
|
format_key_info["health_score"] = float(format_health.get("health_score") or 1.0)
|
||||||
|
format_key_info["circuit_breaker_open"] = bool(format_circuit.get("open", False))
|
||||||
|
grouped[api_format].append(format_key_info)
|
||||||
|
|
||||||
# 直接返回分组对象,供前端使用
|
# 直接返回分组对象,供前端使用
|
||||||
return grouped
|
return grouped
|
||||||
|
|
||||||
|
|
||||||
|
# ========== Adapters ==========
|
||||||
|
|
||||||
|
|
||||||
|
def _build_key_response(
|
||||||
|
key: ProviderAPIKey, api_key_plain: str | None = None
|
||||||
|
) -> EndpointAPIKeyResponse:
|
||||||
|
"""构建 Key 响应对象的辅助函数"""
|
||||||
|
try:
|
||||||
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
|
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
|
||||||
|
except Exception:
|
||||||
|
masked_key = "***ERROR***"
|
||||||
|
|
||||||
|
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
|
||||||
|
avg_response_time_ms = (
|
||||||
|
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
is_adaptive = key.rpm_limit is None
|
||||||
|
key_dict = key.__dict__.copy()
|
||||||
|
key_dict.pop("_sa_instance_state", None)
|
||||||
|
|
||||||
|
# 从 health_by_format 计算汇总字段(便于列表展示)
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
|
||||||
|
# 计算整体健康度(取所有格式中的最低值)
|
||||||
|
if health_by_format:
|
||||||
|
health_scores = [
|
||||||
|
float(h.get("health_score") or 1.0) for h in health_by_format.values()
|
||||||
|
]
|
||||||
|
min_health_score = min(health_scores) if health_scores else 1.0
|
||||||
|
# 取最大的连续失败次数
|
||||||
|
max_consecutive = max(
|
||||||
|
(int(h.get("consecutive_failures") or 0) for h in health_by_format.values()),
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
# 取最近的失败时间
|
||||||
|
failure_times = [
|
||||||
|
h.get("last_failure_at")
|
||||||
|
for h in health_by_format.values()
|
||||||
|
if h.get("last_failure_at")
|
||||||
|
]
|
||||||
|
last_failure = max(failure_times) if failure_times else None
|
||||||
|
else:
|
||||||
|
min_health_score = 1.0
|
||||||
|
max_consecutive = 0
|
||||||
|
last_failure = None
|
||||||
|
|
||||||
|
# 检查是否有任何格式的熔断器打开
|
||||||
|
any_circuit_open = any(c.get("open", False) for c in circuit_by_format.values())
|
||||||
|
|
||||||
|
key_dict.update(
|
||||||
|
{
|
||||||
|
"api_key_masked": masked_key,
|
||||||
|
"api_key_plain": api_key_plain,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||||
|
"is_adaptive": is_adaptive,
|
||||||
|
"effective_limit": (
|
||||||
|
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
|
||||||
|
if is_adaptive
|
||||||
|
else key.rpm_limit
|
||||||
|
),
|
||||||
|
# 汇总字段
|
||||||
|
"health_score": min_health_score,
|
||||||
|
"consecutive_failures": max_consecutive,
|
||||||
|
"last_failure_at": last_failure,
|
||||||
|
"circuit_breaker_open": any_circuit_open,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 防御性:确保 api_formats 存在(历史数据可能为空/缺失)
|
||||||
|
if "api_formats" not in key_dict or key_dict["api_formats"] is None:
|
||||||
|
key_dict["api_formats"] = []
|
||||||
|
|
||||||
|
return EndpointAPIKeyResponse(**key_dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
|
class AdminListProviderKeysAdapter(AdminApiAdapter):
|
||||||
endpoint_id: str
|
"""获取 Provider 的所有 Keys"""
|
||||||
priority_data: BatchUpdateKeyPriorityRequest
|
|
||||||
|
provider_id: str
|
||||||
|
skip: int
|
||||||
|
limit: int
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
endpoint = (
|
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||||
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
|
if not provider:
|
||||||
)
|
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||||
if not endpoint:
|
|
||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
|
||||||
|
|
||||||
# 获取所有需要更新的 Key ID
|
|
||||||
key_ids = [item.key_id for item in self.priority_data.priorities]
|
|
||||||
|
|
||||||
# 验证所有 Key 都属于该 Endpoint
|
|
||||||
keys = (
|
keys = (
|
||||||
db.query(ProviderAPIKey)
|
db.query(ProviderAPIKey)
|
||||||
.filter(
|
.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||||
ProviderAPIKey.id.in_(key_ids),
|
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
.offset(self.skip)
|
||||||
)
|
.limit(self.limit)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(keys) != len(key_ids):
|
return [_build_key_response(key) for key in keys]
|
||||||
found_ids = {k.id for k in keys}
|
|
||||||
missing_ids = set(key_ids) - found_ids
|
|
||||||
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
|
|
||||||
|
|
||||||
# 批量更新优先级
|
|
||||||
key_map = {k.id: k for k in keys}
|
|
||||||
updated_count = 0
|
|
||||||
for item in self.priority_data.priorities:
|
|
||||||
key = key_map.get(item.key_id)
|
|
||||||
if key and key.internal_priority != item.internal_priority:
|
|
||||||
key.internal_priority = item.internal_priority
|
|
||||||
key.updated_at = datetime.now(timezone.utc)
|
|
||||||
updated_count += 1
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminCreateProviderKeyAdapter(AdminApiAdapter):
|
||||||
|
"""为 Provider 添加 Key"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
key_data: EndpointAPIKeyCreate
|
||||||
|
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
db = context.db
|
||||||
|
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||||
|
if not provider:
|
||||||
|
raise NotFoundException(f"Provider {self.provider_id} 不存在")
|
||||||
|
|
||||||
|
# 验证 api_formats 必填
|
||||||
|
if not self.key_data.api_formats:
|
||||||
|
raise InvalidRequestException("api_formats 为必填字段")
|
||||||
|
|
||||||
|
# 允许同一个 API Key 在同一 Provider 下添加多次
|
||||||
|
# 用户可以为不同的 API 格式创建独立的配置记录,便于分开管理
|
||||||
|
|
||||||
|
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
new_key = ProviderAPIKey(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
api_formats=self.key_data.api_formats,
|
||||||
|
api_key=encrypted_key,
|
||||||
|
name=self.key_data.name,
|
||||||
|
note=self.key_data.note,
|
||||||
|
rate_multiplier=self.key_data.rate_multiplier,
|
||||||
|
rate_multipliers=self.key_data.rate_multipliers, # 按 API 格式的成本倍率
|
||||||
|
internal_priority=self.key_data.internal_priority,
|
||||||
|
rpm_limit=self.key_data.rpm_limit,
|
||||||
|
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
|
||||||
|
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
|
||||||
|
cache_ttl_minutes=self.key_data.cache_ttl_minutes,
|
||||||
|
max_probe_interval_minutes=self.key_data.max_probe_interval_minutes,
|
||||||
|
request_count=0,
|
||||||
|
success_count=0,
|
||||||
|
error_count=0,
|
||||||
|
total_response_time_ms=0,
|
||||||
|
health_by_format={}, # 按格式存储健康度
|
||||||
|
circuit_breaker_by_format={}, # 按格式存储熔断器状态
|
||||||
|
is_active=True,
|
||||||
|
last_used_at=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(new_key)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
db.refresh(new_key)
|
||||||
|
|
||||||
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
|
logger.info(
|
||||||
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|
f"[OK] 添加 Key: Provider={self.provider_id}, "
|
||||||
|
f"Formats={self.key_data.api_formats}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _build_key_response(new_key, api_key_plain=self.key_data.api_key)
|
||||||
|
|||||||
@@ -67,8 +67,6 @@ async def list_provider_endpoints(
|
|||||||
- `custom_path`: 自定义路径
|
- `custom_path`: 自定义路径
|
||||||
- `timeout`: 超时时间(秒)
|
- `timeout`: 超时时间(秒)
|
||||||
- `max_retries`: 最大重试次数
|
- `max_retries`: 最大重试次数
|
||||||
- `max_concurrent`: 最大并发数
|
|
||||||
- `rate_limit`: 速率限制
|
|
||||||
- `is_active`: 是否活跃
|
- `is_active`: 是否活跃
|
||||||
- `total_keys`: Key 总数
|
- `total_keys`: Key 总数
|
||||||
- `active_keys`: 活跃 Key 数量
|
- `active_keys`: 活跃 Key 数量
|
||||||
@@ -107,8 +105,6 @@ async def create_provider_endpoint(
|
|||||||
- `headers`: 自定义请求头(可选)
|
- `headers`: 自定义请求头(可选)
|
||||||
- `timeout`: 超时时间(秒,默认 300)
|
- `timeout`: 超时时间(秒,默认 300)
|
||||||
- `max_retries`: 最大重试次数(默认 2)
|
- `max_retries`: 最大重试次数(默认 2)
|
||||||
- `max_concurrent`: 最大并发数(可选)
|
|
||||||
- `rate_limit`: 速率限制(可选)
|
|
||||||
- `config`: 额外配置(可选)
|
- `config`: 额外配置(可选)
|
||||||
- `proxy`: 代理配置(可选)
|
- `proxy`: 代理配置(可选)
|
||||||
|
|
||||||
@@ -145,8 +141,6 @@ async def get_endpoint(
|
|||||||
- `custom_path`: 自定义路径
|
- `custom_path`: 自定义路径
|
||||||
- `timeout`: 超时时间(秒)
|
- `timeout`: 超时时间(秒)
|
||||||
- `max_retries`: 最大重试次数
|
- `max_retries`: 最大重试次数
|
||||||
- `max_concurrent`: 最大并发数
|
|
||||||
- `rate_limit`: 速率限制
|
|
||||||
- `is_active`: 是否活跃
|
- `is_active`: 是否活跃
|
||||||
- `total_keys`: Key 总数
|
- `total_keys`: Key 总数
|
||||||
- `active_keys`: 活跃 Key 数量
|
- `active_keys`: 活跃 Key 数量
|
||||||
@@ -178,8 +172,6 @@ async def update_endpoint(
|
|||||||
- `headers`: 自定义请求头
|
- `headers`: 自定义请求头
|
||||||
- `timeout`: 超时时间(秒)
|
- `timeout`: 超时时间(秒)
|
||||||
- `max_retries`: 最大重试次数
|
- `max_retries`: 最大重试次数
|
||||||
- `max_concurrent`: 最大并发数
|
|
||||||
- `rate_limit`: 速率限制
|
|
||||||
- `is_active`: 是否活跃
|
- `is_active`: 是否活跃
|
||||||
- `config`: 额外配置
|
- `config`: 额外配置
|
||||||
- `proxy`: 代理配置(设置为 null 可清除代理)
|
- `proxy`: 代理配置(设置为 null 可清除代理)
|
||||||
@@ -203,15 +195,15 @@ async def delete_endpoint(
|
|||||||
"""
|
"""
|
||||||
删除 Endpoint
|
删除 Endpoint
|
||||||
|
|
||||||
删除指定的 Endpoint,同时级联删除所有关联的 API Keys。
|
删除指定的 Endpoint,会影响该 Provider 在该 API 格式下的路由能力。
|
||||||
此操作不可逆,请谨慎使用。
|
Key 不会被删除,但包含该 API 格式的 Key 将无法被调度使用(直到重新创建该格式的 Endpoint)。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `endpoint_id`: Endpoint ID
|
- `endpoint_id`: Endpoint ID
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `message`: 操作结果消息
|
- `message`: 操作结果消息
|
||||||
- `deleted_keys_count`: 同时删除的 Key 数量
|
- `affected_keys_count`: 受影响的 Key 数量(包含该 API 格式)
|
||||||
"""
|
"""
|
||||||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
@@ -241,39 +233,33 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
endpoint_ids = [ep.id for ep in endpoints]
|
# Key 是 Provider 级别资源:按 key.api_formats 归类到各 Endpoint.api_format 下
|
||||||
total_keys_map = {}
|
keys = (
|
||||||
active_keys_map = {}
|
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||||
if endpoint_ids:
|
.filter(ProviderAPIKey.provider_id == self.provider_id)
|
||||||
total_rows = (
|
.all()
|
||||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
|
)
|
||||||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
total_keys_map: dict[str, int] = {}
|
||||||
.group_by(ProviderAPIKey.endpoint_id)
|
active_keys_map: dict[str, int] = {}
|
||||||
.all()
|
for api_formats, is_active in keys:
|
||||||
)
|
for fmt in (api_formats or []):
|
||||||
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
|
total_keys_map[fmt] = total_keys_map.get(fmt, 0) + 1
|
||||||
|
if is_active:
|
||||||
active_rows = (
|
active_keys_map[fmt] = active_keys_map.get(fmt, 0) + 1
|
||||||
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
|
|
||||||
.filter(
|
|
||||||
and_(
|
|
||||||
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
|
|
||||||
ProviderAPIKey.is_active.is_(True),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.group_by(ProviderAPIKey.endpoint_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
|
|
||||||
|
|
||||||
result: List[ProviderEndpointResponse] = []
|
result: List[ProviderEndpointResponse] = []
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
|
endpoint_format = (
|
||||||
|
endpoint.api_format
|
||||||
|
if isinstance(endpoint.api_format, str)
|
||||||
|
else endpoint.api_format.value
|
||||||
|
)
|
||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
**endpoint.__dict__,
|
**endpoint.__dict__,
|
||||||
"provider_name": provider.name,
|
"provider_name": provider.name,
|
||||||
"api_format": endpoint.api_format,
|
"api_format": endpoint.api_format,
|
||||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
"total_keys": total_keys_map.get(endpoint_format, 0),
|
||||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
"active_keys": active_keys_map.get(endpoint_format, 0),
|
||||||
"proxy": mask_proxy_password(endpoint.proxy),
|
"proxy": mask_proxy_password(endpoint.proxy),
|
||||||
}
|
}
|
||||||
endpoint_dict.pop("_sa_instance_state", None)
|
endpoint_dict.pop("_sa_instance_state", None)
|
||||||
@@ -321,8 +307,6 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
headers=self.endpoint_data.headers,
|
headers=self.endpoint_data.headers,
|
||||||
timeout=self.endpoint_data.timeout,
|
timeout=self.endpoint_data.timeout,
|
||||||
max_retries=self.endpoint_data.max_retries,
|
max_retries=self.endpoint_data.max_retries,
|
||||||
max_concurrent=self.endpoint_data.max_concurrent,
|
|
||||||
rate_limit=self.endpoint_data.rate_limit,
|
|
||||||
is_active=True,
|
is_active=True,
|
||||||
config=self.endpoint_data.config,
|
config=self.endpoint_data.config,
|
||||||
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
||||||
@@ -367,19 +351,23 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||||
|
|
||||||
endpoint_obj, provider = endpoint
|
endpoint_obj, provider = endpoint
|
||||||
total_keys = (
|
endpoint_format = (
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
endpoint_obj.api_format
|
||||||
|
if isinstance(endpoint_obj.api_format, str)
|
||||||
|
else endpoint_obj.api_format.value
|
||||||
)
|
)
|
||||||
active_keys = (
|
keys = (
|
||||||
db.query(ProviderAPIKey)
|
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||||
.filter(
|
.filter(ProviderAPIKey.provider_id == endpoint_obj.provider_id)
|
||||||
and_(
|
.all()
|
||||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
|
||||||
ProviderAPIKey.is_active.is_(True),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
)
|
||||||
|
total_keys = 0
|
||||||
|
active_keys = 0
|
||||||
|
for api_formats, is_active in keys:
|
||||||
|
if endpoint_format in (api_formats or []):
|
||||||
|
total_keys += 1
|
||||||
|
if is_active:
|
||||||
|
active_keys += 1
|
||||||
|
|
||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
@@ -431,19 +419,21 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
|
||||||
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
|
||||||
|
|
||||||
total_keys = (
|
endpoint_format = (
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
|
||||||
)
|
)
|
||||||
active_keys = (
|
keys = (
|
||||||
db.query(ProviderAPIKey)
|
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
|
||||||
.filter(
|
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
|
||||||
and_(
|
.all()
|
||||||
ProviderAPIKey.endpoint_id == self.endpoint_id,
|
|
||||||
ProviderAPIKey.is_active.is_(True),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
)
|
||||||
|
total_keys = 0
|
||||||
|
active_keys = 0
|
||||||
|
for api_formats, is_active in keys:
|
||||||
|
if endpoint_format in (api_formats or []):
|
||||||
|
total_keys += 1
|
||||||
|
if is_active:
|
||||||
|
active_keys += 1
|
||||||
|
|
||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
@@ -472,12 +462,26 @@ class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
if not endpoint:
|
if not endpoint:
|
||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||||
|
|
||||||
keys_count = (
|
endpoint_format = (
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
|
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
|
||||||
|
)
|
||||||
|
keys = (
|
||||||
|
db.query(ProviderAPIKey.api_formats)
|
||||||
|
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
affected_keys_count = sum(
|
||||||
|
1 for (api_formats,) in keys if endpoint_format in (api_formats or [])
|
||||||
)
|
)
|
||||||
db.delete(endpoint)
|
db.delete(endpoint)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
|
logger.warning(
|
||||||
|
f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, Format={endpoint_format}, "
|
||||||
|
f"AffectedKeys={affected_keys_count}"
|
||||||
|
)
|
||||||
|
|
||||||
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|
return {
|
||||||
|
"message": f"Endpoint {self.endpoint_id} 已删除",
|
||||||
|
"affected_keys_count": affected_keys_count,
|
||||||
|
}
|
||||||
|
|||||||
@@ -125,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
|||||||
ModelCatalogProviderDetail(
|
ModelCatalogProviderDetail(
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
provider_display_name=provider.display_name,
|
|
||||||
model_id=model.id,
|
model_id=model.id,
|
||||||
target_model=model.provider_model_name,
|
target_model=model.provider_model_name,
|
||||||
# 显示有效价格
|
# 显示有效价格
|
||||||
|
|||||||
@@ -452,7 +452,6 @@ class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
|
|||||||
ModelCatalogProviderDetail(
|
ModelCatalogProviderDetail(
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
provider_display_name=provider.display_name,
|
|
||||||
model_id=model.id,
|
model_id=model.id,
|
||||||
target_model=model.provider_model_name,
|
target_model=model.provider_model_name,
|
||||||
input_price_per_1m=model.get_effective_input_price(),
|
input_price_per_1m=model.get_effective_input_price(),
|
||||||
|
|||||||
@@ -819,7 +819,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
|||||||
"username": user.username if user else None,
|
"username": user.username if user else None,
|
||||||
"email": user.email if user else None,
|
"email": user.email if user else None,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_name": provider.display_name if provider else None,
|
"provider_name": provider.name if provider else None,
|
||||||
"endpoint_id": endpoint_id,
|
"endpoint_id": endpoint_id,
|
||||||
"endpoint_api_format": (
|
"endpoint_api_format": (
|
||||||
endpoint.api_format if endpoint and endpoint.api_format else None
|
endpoint.api_format if endpoint and endpoint.api_format else None
|
||||||
@@ -1369,9 +1369,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
for model, provider in models:
|
for model, provider in models:
|
||||||
# 检查是否是主模型名称
|
# 检查是否是主模型名称
|
||||||
if model.provider_model_name == mapping_name:
|
if model.provider_model_name == mapping_name:
|
||||||
provider_names.append(
|
provider_names.append(provider.name)
|
||||||
provider.display_name or provider.name
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
# 检查是否在映射列表中
|
# 检查是否在映射列表中
|
||||||
if model.provider_model_mappings:
|
if model.provider_model_mappings:
|
||||||
@@ -1381,9 +1379,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
if isinstance(a, dict)
|
if isinstance(a, dict)
|
||||||
]
|
]
|
||||||
if mapping_name in mapping_list:
|
if mapping_name in mapping_list:
|
||||||
provider_names.append(
|
provider_names.append(provider.name)
|
||||||
provider.display_name or provider.name
|
|
||||||
)
|
|
||||||
provider_names = sorted(list(set(provider_names)))
|
provider_names = sorted(list(set(provider_names)))
|
||||||
|
|
||||||
mappings.append({
|
mappings.append({
|
||||||
@@ -1473,7 +1469,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
provider_model_mappings.append({
|
provider_model_mappings.append({
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_name": provider.display_name or provider.name,
|
"provider_name": provider.name,
|
||||||
"global_model_id": global_model_id,
|
"global_model_id": global_model_id,
|
||||||
"global_model_name": global_model.name,
|
"global_model_name": global_model.name,
|
||||||
"global_model_display_name": global_model.display_name,
|
"global_model_display_name": global_model.display_name,
|
||||||
|
|||||||
@@ -13,10 +13,11 @@ from sqlalchemy.orm import Session, joinedload
|
|||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import get_adapter_class
|
from src.api.handlers.base.chat_adapter_base import get_adapter_class
|
||||||
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
|
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
|
||||||
|
from src.config.constants import TimeoutDefaults
|
||||||
from src.core.crypto import crypto_service
|
from src.core.crypto import crypto_service
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database.database import get_db
|
from src.database.database import get_db
|
||||||
from src.models.database import Provider, ProviderEndpoint, User
|
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||||
from src.utils.auth_utils import get_current_user
|
from src.utils.auth_utils import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||||
@@ -81,10 +82,13 @@ async def query_available_models(
|
|||||||
Returns:
|
Returns:
|
||||||
所有端点的模型列表(合并)
|
所有端点的模型列表(合并)
|
||||||
"""
|
"""
|
||||||
# 获取提供商及其端点
|
# 获取提供商及其端点和 API Keys
|
||||||
provider = (
|
provider = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
.options(
|
||||||
|
joinedload(Provider.endpoints),
|
||||||
|
joinedload(Provider.api_keys),
|
||||||
|
)
|
||||||
.filter(Provider.id == request.provider_id)
|
.filter(Provider.id == request.provider_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@@ -95,49 +99,70 @@ async def query_available_models(
|
|||||||
# 收集所有活跃端点的配置
|
# 收集所有活跃端点的配置
|
||||||
endpoint_configs: list[dict] = []
|
endpoint_configs: list[dict] = []
|
||||||
|
|
||||||
|
# 构建 api_format -> endpoint 映射
|
||||||
|
format_to_endpoint: dict[str, ProviderEndpoint] = {}
|
||||||
|
for endpoint in provider.endpoints:
|
||||||
|
if endpoint.is_active:
|
||||||
|
format_to_endpoint[endpoint.api_format] = endpoint
|
||||||
|
|
||||||
if request.api_key_id:
|
if request.api_key_id:
|
||||||
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
# 指定了特定的 API Key(从 provider.api_keys 查找)
|
||||||
for endpoint in provider.endpoints:
|
api_key = next(
|
||||||
for api_key in endpoint.api_keys:
|
(key for key in provider.api_keys if key.id == request.api_key_id),
|
||||||
if api_key.id == request.api_key_id:
|
None
|
||||||
try:
|
)
|
||||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
|
||||||
except Exception as e:
|
if not api_key:
|
||||||
logger.error(f"Failed to decrypt API key: {e}")
|
raise HTTPException(status_code=404, detail="API Key not found")
|
||||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
|
||||||
endpoint_configs.append({
|
try:
|
||||||
"api_key": api_key_value,
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
"base_url": endpoint.base_url,
|
except Exception as e:
|
||||||
"api_format": endpoint.api_format,
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
"extra_headers": endpoint.headers,
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
})
|
|
||||||
break
|
# 根据 Key 的 api_formats 找对应的 Endpoint
|
||||||
if endpoint_configs:
|
key_formats = api_key.api_formats or []
|
||||||
break
|
for fmt in key_formats:
|
||||||
|
endpoint = format_to_endpoint.get(fmt)
|
||||||
|
if endpoint:
|
||||||
|
endpoint_configs.append({
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": fmt,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
|
|
||||||
if not endpoint_configs:
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No matching endpoint found for this API Key's formats"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
# 遍历所有活跃端点,为每个端点找一个支持该格式的 Key
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
if not endpoint.is_active or not endpoint.api_keys:
|
if not endpoint.is_active:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 找第一个可用的 Key
|
# 找第一个支持该格式的可用 Key
|
||||||
for api_key in endpoint.api_keys:
|
for api_key in provider.api_keys:
|
||||||
if api_key.is_active:
|
if not api_key.is_active:
|
||||||
try:
|
continue
|
||||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
key_formats = api_key.api_formats or []
|
||||||
except Exception as e:
|
if endpoint.api_format not in key_formats:
|
||||||
logger.error(f"Failed to decrypt API key: {e}")
|
continue
|
||||||
continue # 尝试下一个 Key
|
try:
|
||||||
endpoint_configs.append({
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
"api_key": api_key_value,
|
except Exception as e:
|
||||||
"base_url": endpoint.base_url,
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
"api_format": endpoint.api_format,
|
continue
|
||||||
"extra_headers": endpoint.headers,
|
endpoint_configs.append({
|
||||||
})
|
"api_key": api_key_value,
|
||||||
break # 只取第一个可用的 Key
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
|
break # 只取第一个可用的 Key
|
||||||
|
|
||||||
if not endpoint_configs:
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||||
@@ -214,7 +239,6 @@ async def query_available_models(
|
|||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,17 +253,14 @@ async def test_model(
|
|||||||
测试模型连接性
|
测试模型连接性
|
||||||
|
|
||||||
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 测试请求
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
测试结果
|
|
||||||
"""
|
"""
|
||||||
# 获取提供商及其端点
|
# 获取提供商及其端点和 Keys
|
||||||
provider = (
|
provider = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
.options(
|
||||||
|
joinedload(Provider.endpoints),
|
||||||
|
joinedload(Provider.api_keys),
|
||||||
|
)
|
||||||
.filter(Provider.id == request.provider_id)
|
.filter(Provider.id == request.provider_id)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@@ -247,28 +268,38 @@ async def test_model(
|
|||||||
if not provider:
|
if not provider:
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
# 找到合适的端点和API Key
|
# 构建 api_format -> endpoint 映射
|
||||||
endpoint_config = None
|
format_to_endpoint: dict[str, ProviderEndpoint] = {}
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
if ep.is_active:
|
||||||
|
format_to_endpoint[ep.api_format] = ep
|
||||||
|
|
||||||
|
# 找到合适的端点和 API Key
|
||||||
endpoint = None
|
endpoint = None
|
||||||
api_key = None
|
api_key = None
|
||||||
|
|
||||||
if request.api_key_id:
|
if request.api_key_id:
|
||||||
# 使用指定的API Key
|
# 使用指定的 API Key
|
||||||
for ep in provider.endpoints:
|
api_key = next(
|
||||||
for key in ep.api_keys:
|
(key for key in provider.api_keys if key.id == request.api_key_id and key.is_active),
|
||||||
if key.id == request.api_key_id and key.is_active and ep.is_active:
|
None
|
||||||
endpoint = ep
|
)
|
||||||
api_key = key
|
if api_key:
|
||||||
|
# 找到该 Key 支持的第一个活跃 Endpoint
|
||||||
|
for fmt in (api_key.api_formats or []):
|
||||||
|
if fmt in format_to_endpoint:
|
||||||
|
endpoint = format_to_endpoint[fmt]
|
||||||
break
|
break
|
||||||
if endpoint:
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
# 使用第一个可用的端点和密钥
|
# 使用第一个可用的端点和密钥
|
||||||
for ep in provider.endpoints:
|
for ep in provider.endpoints:
|
||||||
if not ep.is_active or not ep.api_keys:
|
if not ep.is_active:
|
||||||
continue
|
continue
|
||||||
for key in ep.api_keys:
|
# 找支持该格式的第一个可用 Key
|
||||||
if key.is_active:
|
for key in provider.api_keys:
|
||||||
|
if not key.is_active:
|
||||||
|
continue
|
||||||
|
if ep.api_format in (key.api_formats or []):
|
||||||
endpoint = ep
|
endpoint = ep
|
||||||
api_key = key
|
api_key = key
|
||||||
break
|
break
|
||||||
@@ -284,14 +315,14 @@ async def test_model(
|
|||||||
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
|
||||||
# 构建请求配置
|
# 构建请求配置(timeout 从 Provider 读取)
|
||||||
endpoint_config = {
|
endpoint_config = {
|
||||||
"api_key": api_key_value,
|
"api_key": api_key_value,
|
||||||
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
||||||
"base_url": endpoint.base_url,
|
"base_url": endpoint.base_url,
|
||||||
"api_format": endpoint.api_format,
|
"api_format": endpoint.api_format,
|
||||||
"extra_headers": endpoint.headers,
|
"extra_headers": endpoint.headers,
|
||||||
"timeout": endpoint.timeout or 30.0,
|
"timeout": provider.timeout or TimeoutDefaults.HTTP_REQUEST,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -304,7 +335,6 @@ async def test_model(
|
|||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
},
|
||||||
"model": request.model_name,
|
"model": request.model_name,
|
||||||
}
|
}
|
||||||
@@ -325,7 +355,6 @@ async def test_model(
|
|||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
},
|
||||||
"model": request.model_name,
|
"model": request.model_name,
|
||||||
}
|
}
|
||||||
@@ -415,7 +444,6 @@ async def test_model(
|
|||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
},
|
||||||
"model": request.model_name,
|
"model": request.model_name,
|
||||||
"endpoint": {
|
"endpoint": {
|
||||||
@@ -433,7 +461,6 @@ async def test_model(
|
|||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
},
|
||||||
"model": request.model_name,
|
"model": request.model_name,
|
||||||
"endpoint": {
|
"endpoint": {
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ async def get_provider_stats(
|
|||||||
"""
|
"""
|
||||||
获取提供商统计数据
|
获取提供商统计数据
|
||||||
|
|
||||||
获取指定提供商的计费信息、RPM 使用情况和使用统计数据。
|
获取指定提供商的计费信息和使用统计数据。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `provider_id`: 提供商 ID
|
- `provider_id`: 提供商 ID
|
||||||
@@ -96,10 +96,6 @@ async def get_provider_stats(
|
|||||||
- `monthly_used_usd`: 月度已使用
|
- `monthly_used_usd`: 月度已使用
|
||||||
- `quota_remaining_usd`: 剩余配额
|
- `quota_remaining_usd`: 剩余配额
|
||||||
- `quota_expires_at`: 配额过期时间
|
- `quota_expires_at`: 配额过期时间
|
||||||
- `rpm_info`: RPM 信息
|
|
||||||
- `rpm_limit`: RPM 限制
|
|
||||||
- `rpm_used`: 已使用 RPM
|
|
||||||
- `rpm_reset_at`: RPM 重置时间
|
|
||||||
- `usage_stats`: 使用统计
|
- `usage_stats`: 使用统计
|
||||||
- `total_requests`: 总请求数
|
- `total_requests`: 总请求数
|
||||||
- `successful_requests`: 成功请求数
|
- `successful_requests`: 成功请求数
|
||||||
@@ -165,7 +161,6 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
|||||||
provider.billing_type = config.billing_type
|
provider.billing_type = config.billing_type
|
||||||
provider.monthly_quota_usd = config.monthly_quota_usd
|
provider.monthly_quota_usd = config.monthly_quota_usd
|
||||||
provider.quota_reset_day = config.quota_reset_day
|
provider.quota_reset_day = config.quota_reset_day
|
||||||
provider.rpm_limit = config.rpm_limit
|
|
||||||
provider.provider_priority = config.provider_priority
|
provider.provider_priority = config.provider_priority
|
||||||
|
|
||||||
from dateutil import parser
|
from dateutil import parser
|
||||||
@@ -262,13 +257,6 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
|
|||||||
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
|
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
"rpm_info": {
|
|
||||||
"rpm_limit": provider.rpm_limit,
|
|
||||||
"rpm_used": provider.rpm_used,
|
|
||||||
"rpm_reset_at": (
|
|
||||||
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"usage_stats": {
|
"usage_stats": {
|
||||||
"total_requests": total_requests,
|
"total_requests": total_requests,
|
||||||
"successful_requests": total_success,
|
"successful_requests": total_success,
|
||||||
@@ -296,8 +284,6 @@ class AdminProviderResetQuotaAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
old_used = provider.monthly_used_usd
|
old_used = provider.monthly_used_usd
|
||||||
provider.monthly_used_usd = 0.0
|
provider.monthly_used_usd = 0.0
|
||||||
provider.rpm_used = 0
|
|
||||||
provider.rpm_reset_at = None
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
logger.info(f"Manually reset quota for provider {provider.name}")
|
logger.info(f"Manually reset quota for provider {provider.name}")
|
||||||
|
|||||||
@@ -338,27 +338,29 @@ async def import_models_from_upstream(
|
|||||||
"""
|
"""
|
||||||
从上游提供商导入模型
|
从上游提供商导入模型
|
||||||
|
|
||||||
从上游提供商导入模型列表。如果全局模型不存在,将自动创建。
|
从上游提供商导入模型列表。导入的模型作为独立的 ProviderModel 存储,
|
||||||
|
不会自动创建 GlobalModel。后续需要手动关联 GlobalModel 才能参与路由。
|
||||||
|
|
||||||
**流程说明**:
|
**流程说明**:
|
||||||
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
|
1. 检查模型是否已存在于当前 Provider(按 provider_model_name 匹配)
|
||||||
2. 如不存在,自动创建新的 GlobalModel(使用默认免费配置)
|
2. 创建新的 ProviderModel(global_model_id = NULL)
|
||||||
3. 创建 Model 关联到当前 Provider
|
3. 支持设置价格覆盖(tiered_pricing, price_per_request)
|
||||||
4. 如模型已关联,则记录到成功列表中
|
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `provider_id`: 提供商 ID
|
- `provider_id`: 提供商 ID
|
||||||
|
|
||||||
**请求体字段**:
|
**请求体字段**:
|
||||||
- `model_ids`: 模型 ID 数组(必填,每个 ID 长度 1-100 字符)
|
- `model_ids`: 模型 ID 数组(必填,每个 ID 长度 1-100 字符)
|
||||||
|
- `tiered_pricing`: 可选的阶梯计费配置(应用于所有导入的模型)
|
||||||
|
- `price_per_request`: 可选的按次计费价格(应用于所有导入的模型)
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `success`: 成功导入的模型数组,每项包含:
|
- `success`: 成功导入的模型数组,每项包含:
|
||||||
- `model_id`: 模型 ID
|
- `model_id`: 模型 ID
|
||||||
- `global_model_id`: 全局模型 ID
|
|
||||||
- `global_model_name`: 全局模型名称
|
|
||||||
- `provider_model_id`: 提供商模型 ID
|
- `provider_model_id`: 提供商模型 ID
|
||||||
- `created_global_model`: 是否新创建了全局模型
|
- `global_model_id`: 全局模型 ID(如果已关联)
|
||||||
|
- `global_model_name`: 全局模型名称(如果已关联)
|
||||||
|
- `created_global_model`: 是否新创建了全局模型(始终为 false)
|
||||||
- `errors`: 失败的模型数组,每项包含:
|
- `errors`: 失败的模型数组,每项包含:
|
||||||
- `model_id`: 模型 ID
|
- `model_id`: 模型 ID
|
||||||
- `error`: 错误信息
|
- `error`: 错误信息
|
||||||
@@ -638,7 +640,7 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||||
"""从上游提供商导入模型"""
|
"""从上游提供商导入模型(不创建 GlobalModel,作为独立 ProviderModel)"""
|
||||||
|
|
||||||
provider_id: str
|
provider_id: str
|
||||||
payload: ImportFromUpstreamRequest
|
payload: ImportFromUpstreamRequest
|
||||||
@@ -652,16 +654,13 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
|||||||
success: list[ImportFromUpstreamSuccessItem] = []
|
success: list[ImportFromUpstreamSuccessItem] = []
|
||||||
errors: list[ImportFromUpstreamErrorItem] = []
|
errors: list[ImportFromUpstreamErrorItem] = []
|
||||||
|
|
||||||
# 默认阶梯计费配置(免费)
|
# 获取价格覆盖配置
|
||||||
default_tiered_pricing = {
|
tiered_pricing = None
|
||||||
"tiers": [
|
price_per_request = None
|
||||||
{
|
if hasattr(self.payload, 'tiered_pricing') and self.payload.tiered_pricing:
|
||||||
"up_to": None,
|
tiered_pricing = self.payload.tiered_pricing
|
||||||
"input_price_per_1m": 0.0,
|
if hasattr(self.payload, 'price_per_request') and self.payload.price_per_request is not None:
|
||||||
"output_price_per_1m": 0.0,
|
price_per_request = self.payload.price_per_request
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
for model_id in self.payload.model_ids:
|
for model_id in self.payload.model_ids:
|
||||||
# 输入验证:检查 model_id 长度
|
# 输入验证:检查 model_id 长度
|
||||||
@@ -678,56 +677,37 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
|||||||
# 使用 savepoint 确保单个模型导入的原子性
|
# 使用 savepoint 确保单个模型导入的原子性
|
||||||
savepoint = db.begin_nested()
|
savepoint = db.begin_nested()
|
||||||
try:
|
try:
|
||||||
# 1. 检查是否已存在同名的 GlobalModel
|
# 1. 检查是否已存在同名的 ProviderModel
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
|
|
||||||
)
|
|
||||||
created_global_model = False
|
|
||||||
|
|
||||||
if not global_model:
|
|
||||||
# 2. 创建新的 GlobalModel
|
|
||||||
global_model = GlobalModel(
|
|
||||||
name=model_id,
|
|
||||||
display_name=model_id,
|
|
||||||
default_tiered_pricing=default_tiered_pricing,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
db.add(global_model)
|
|
||||||
db.flush()
|
|
||||||
created_global_model = True
|
|
||||||
logger.info(
|
|
||||||
f"Created new GlobalModel: {model_id} during upstream import"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. 检查是否已存在关联
|
|
||||||
existing = (
|
existing = (
|
||||||
db.query(Model)
|
db.query(Model)
|
||||||
.filter(
|
.filter(
|
||||||
Model.provider_id == self.provider_id,
|
Model.provider_id == self.provider_id,
|
||||||
Model.global_model_id == global_model.id,
|
Model.provider_model_name == model_id,
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
# 已存在关联,提交 savepoint 并记录成功
|
# 已存在,提交 savepoint 并记录成功
|
||||||
savepoint.commit()
|
savepoint.commit()
|
||||||
success.append(
|
success.append(
|
||||||
ImportFromUpstreamSuccessItem(
|
ImportFromUpstreamSuccessItem(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
global_model_id=global_model.id,
|
global_model_id=existing.global_model_id or "",
|
||||||
global_model_name=global_model.name,
|
global_model_name=existing.global_model.name if existing.global_model else "",
|
||||||
provider_model_id=existing.id,
|
provider_model_id=existing.id,
|
||||||
created_global_model=created_global_model,
|
created_global_model=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 4. 创建新的 Model 记录
|
# 2. 创建新的 Model 记录(不关联 GlobalModel)
|
||||||
new_model = Model(
|
new_model = Model(
|
||||||
provider_id=self.provider_id,
|
provider_id=self.provider_id,
|
||||||
global_model_id=global_model.id,
|
global_model_id=None, # 独立模型,不关联 GlobalModel
|
||||||
provider_model_name=global_model.name,
|
provider_model_name=model_id,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
|
tiered_pricing=tiered_pricing,
|
||||||
|
price_per_request=price_per_request,
|
||||||
)
|
)
|
||||||
db.add(new_model)
|
db.add(new_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
@@ -737,12 +717,15 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
|||||||
success.append(
|
success.append(
|
||||||
ImportFromUpstreamSuccessItem(
|
ImportFromUpstreamSuccessItem(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
global_model_id=global_model.id,
|
global_model_id="", # 未关联
|
||||||
global_model_name=global_model.name,
|
global_model_name="", # 未关联
|
||||||
provider_model_id=new_model.id,
|
provider_model_id=new_model.id,
|
||||||
created_global_model=created_global_model,
|
created_global_model=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Created independent ProviderModel: {model_id} for provider {provider.name}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 回滚到 savepoint
|
# 回滚到 savepoint
|
||||||
savepoint.rollback()
|
savepoint.rollback()
|
||||||
@@ -753,11 +736,9 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
|
f"Imported {len(success)} independent models to provider {provider.name} by {context.user.username}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 清除 /v1/models 列表缓存
|
# 不需要清除 /v1/models 缓存,因为独立模型不参与路由
|
||||||
if success:
|
|
||||||
await invalidate_models_list_cache()
|
|
||||||
|
|
||||||
return ImportFromUpstreamResponse(success=success, errors=errors)
|
return ImportFromUpstreamResponse(success=success, errors=errors)
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
|
|||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
from src.core.enums import ProviderBillingType
|
from src.core.enums import ProviderBillingType
|
||||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||||
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||||||
from src.models.database import Provider
|
from src.models.database import Provider
|
||||||
|
from src.services.cache.provider_cache import ProviderCacheService
|
||||||
|
|
||||||
router = APIRouter(tags=["Provider CRUD"])
|
router = APIRouter(tags=["Provider CRUD"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
@@ -39,8 +41,7 @@ async def list_providers(
|
|||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `id`: 提供商 ID
|
- `id`: 提供商 ID
|
||||||
- `name`: 提供商名称(唯一标识)
|
- `name`: 提供商名称(唯一)
|
||||||
- `display_name`: 显示名称
|
|
||||||
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
||||||
- `base_url`: API 基础 URL
|
- `base_url`: API 基础 URL
|
||||||
- `api_key`: API 密钥(脱敏显示)
|
- `api_key`: API 密钥(脱敏显示)
|
||||||
@@ -61,8 +62,7 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
|
|||||||
创建一个新的 AI 模型提供商配置。
|
创建一个新的 AI 模型提供商配置。
|
||||||
|
|
||||||
**请求体字段**:
|
**请求体字段**:
|
||||||
- `name`: 提供商名称(必填,唯一,用于系统标识)
|
- `name`: 提供商名称(必填,唯一)
|
||||||
- `display_name`: 显示名称(必填)
|
|
||||||
- `description`: 描述信息(可选)
|
- `description`: 描述信息(可选)
|
||||||
- `website`: 官网地址(可选)
|
- `website`: 官网地址(可选)
|
||||||
- `billing_type`: 计费类型(可选,pay_as_you_go/subscription/prepaid,默认 pay_as_you_go)
|
- `billing_type`: 计费类型(可选,pay_as_you_go/subscription/prepaid,默认 pay_as_you_go)
|
||||||
@@ -70,16 +70,17 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
|
|||||||
- `quota_reset_day`: 配额重置日期(1-31)(可选)
|
- `quota_reset_day`: 配额重置日期(1-31)(可选)
|
||||||
- `quota_last_reset_at`: 上次配额重置时间(可选)
|
- `quota_last_reset_at`: 上次配额重置时间(可选)
|
||||||
- `quota_expires_at`: 配额过期时间(可选)
|
- `quota_expires_at`: 配额过期时间(可选)
|
||||||
- `rpm_limit`: 每分钟请求数限制(可选)
|
|
||||||
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100)
|
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100)
|
||||||
- `is_active`: 是否启用(默认 true)
|
- `is_active`: 是否启用(默认 true)
|
||||||
- `concurrent_limit`: 并发限制(可选)
|
- `concurrent_limit`: 并发限制(可选)
|
||||||
|
- `timeout`: 请求超时(秒,可选)
|
||||||
|
- `max_retries`: 最大重试次数(可选)
|
||||||
|
- `proxy`: 代理配置(可选)
|
||||||
- `config`: 额外配置信息(JSON,可选)
|
- `config`: 额外配置信息(JSON,可选)
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `id`: 新创建的提供商 ID
|
- `id`: 新创建的提供商 ID
|
||||||
- `name`: 提供商名称
|
- `name`: 提供商名称
|
||||||
- `display_name`: 显示名称
|
|
||||||
- `message`: 成功提示信息
|
- `message`: 成功提示信息
|
||||||
"""
|
"""
|
||||||
adapter = AdminCreateProviderAdapter()
|
adapter = AdminCreateProviderAdapter()
|
||||||
@@ -98,7 +99,6 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
|
|||||||
|
|
||||||
**请求体字段**(所有字段可选):
|
**请求体字段**(所有字段可选):
|
||||||
- `name`: 提供商名称
|
- `name`: 提供商名称
|
||||||
- `display_name`: 显示名称
|
|
||||||
- `description`: 描述信息
|
- `description`: 描述信息
|
||||||
- `website`: 官网地址
|
- `website`: 官网地址
|
||||||
- `billing_type`: 计费类型(pay_as_you_go/subscription/prepaid)
|
- `billing_type`: 计费类型(pay_as_you_go/subscription/prepaid)
|
||||||
@@ -106,10 +106,12 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
|
|||||||
- `quota_reset_day`: 配额重置日期(1-31)
|
- `quota_reset_day`: 配额重置日期(1-31)
|
||||||
- `quota_last_reset_at`: 上次配额重置时间
|
- `quota_last_reset_at`: 上次配额重置时间
|
||||||
- `quota_expires_at`: 配额过期时间
|
- `quota_expires_at`: 配额过期时间
|
||||||
- `rpm_limit`: 每分钟请求数限制
|
|
||||||
- `provider_priority`: 提供商优先级
|
- `provider_priority`: 提供商优先级
|
||||||
- `is_active`: 是否启用
|
- `is_active`: 是否启用
|
||||||
- `concurrent_limit`: 并发限制
|
- `concurrent_limit`: 并发限制
|
||||||
|
- `timeout`: 请求超时(秒)
|
||||||
|
- `max_retries`: 最大重试次数
|
||||||
|
- `proxy`: 代理配置
|
||||||
- `config`: 额外配置信息(JSON)
|
- `config`: 额外配置信息(JSON)
|
||||||
|
|
||||||
**返回字段**:
|
**返回字段**:
|
||||||
@@ -163,7 +165,6 @@ class AdminListProvidersAdapter(AdminApiAdapter):
|
|||||||
{
|
{
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
"api_format": api_format.value if api_format else None,
|
"api_format": api_format.value if api_format else None,
|
||||||
"base_url": base_url,
|
"base_url": base_url,
|
||||||
"api_key": "***" if api_key else None,
|
"api_key": "***" if api_key else None,
|
||||||
@@ -215,7 +216,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
|||||||
# 创建 Provider 对象
|
# 创建 Provider 对象
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
name=validated_data.name,
|
name=validated_data.name,
|
||||||
display_name=validated_data.display_name,
|
|
||||||
description=validated_data.description,
|
description=validated_data.description,
|
||||||
website=validated_data.website,
|
website=validated_data.website,
|
||||||
billing_type=billing_type,
|
billing_type=billing_type,
|
||||||
@@ -223,10 +223,12 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
|||||||
quota_reset_day=validated_data.quota_reset_day,
|
quota_reset_day=validated_data.quota_reset_day,
|
||||||
quota_last_reset_at=validated_data.quota_last_reset_at,
|
quota_last_reset_at=validated_data.quota_last_reset_at,
|
||||||
quota_expires_at=validated_data.quota_expires_at,
|
quota_expires_at=validated_data.quota_expires_at,
|
||||||
rpm_limit=validated_data.rpm_limit,
|
|
||||||
provider_priority=validated_data.provider_priority,
|
provider_priority=validated_data.provider_priority,
|
||||||
is_active=validated_data.is_active,
|
is_active=validated_data.is_active,
|
||||||
concurrent_limit=validated_data.concurrent_limit,
|
concurrent_limit=validated_data.concurrent_limit,
|
||||||
|
timeout=validated_data.timeout,
|
||||||
|
max_retries=validated_data.max_retries,
|
||||||
|
proxy=validated_data.proxy.model_dump() if validated_data.proxy else None,
|
||||||
config=validated_data.config,
|
config=validated_data.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,7 +248,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
|||||||
return {
|
return {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
"message": "提供商创建成功",
|
"message": "提供商创建成功",
|
||||||
}
|
}
|
||||||
except InvalidRequestException:
|
except InvalidRequestException:
|
||||||
@@ -289,6 +290,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
|
|||||||
if field == "billing_type" and value is not None:
|
if field == "billing_type" and value is not None:
|
||||||
# billing_type 需要转换为枚举
|
# billing_type 需要转换为枚举
|
||||||
setattr(provider, field, ProviderBillingType(value))
|
setattr(provider, field, ProviderBillingType(value))
|
||||||
|
elif field == "proxy" and value is not None:
|
||||||
|
# proxy 需要转换为 dict(如果是 Pydantic 模型)
|
||||||
|
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
|
||||||
else:
|
else:
|
||||||
setattr(provider, field, value)
|
setattr(provider, field, value)
|
||||||
|
|
||||||
@@ -296,6 +300,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(provider)
|
db.refresh(provider)
|
||||||
|
|
||||||
|
# 如果更新了 billing_type,清除缓存
|
||||||
|
if "billing_type" in update_data:
|
||||||
|
await ProviderCacheService.invalidate_provider_cache(provider.id)
|
||||||
|
logger.debug(f"已清除 Provider 缓存: {provider.id}")
|
||||||
|
|
||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
action="update_provider",
|
action="update_provider",
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ async def get_providers_summary(
|
|||||||
**返回字段**(数组,每项包含):
|
**返回字段**(数组,每项包含):
|
||||||
- `id`: 提供商 ID
|
- `id`: 提供商 ID
|
||||||
- `name`: 提供商名称
|
- `name`: 提供商名称
|
||||||
- `display_name`: 显示名称
|
|
||||||
- `description`: 描述信息
|
- `description`: 描述信息
|
||||||
- `website`: 官网地址
|
- `website`: 官网地址
|
||||||
- `provider_priority`: 优先级
|
- `provider_priority`: 优先级
|
||||||
@@ -59,9 +58,9 @@ async def get_providers_summary(
|
|||||||
- `quota_reset_day`: 配额重置日期
|
- `quota_reset_day`: 配额重置日期
|
||||||
- `quota_last_reset_at`: 上次配额重置时间
|
- `quota_last_reset_at`: 上次配额重置时间
|
||||||
- `quota_expires_at`: 配额过期时间
|
- `quota_expires_at`: 配额过期时间
|
||||||
- `rpm_limit`: RPM 限制
|
- `timeout`: 默认请求超时(秒)
|
||||||
- `rpm_used`: 已使用 RPM
|
- `max_retries`: 默认最大重试次数
|
||||||
- `rpm_reset_at`: RPM 重置时间
|
- `proxy`: 默认代理配置
|
||||||
- `total_endpoints`: 端点总数
|
- `total_endpoints`: 端点总数
|
||||||
- `active_endpoints`: 活跃端点数
|
- `active_endpoints`: 活跃端点数
|
||||||
- `total_keys`: 密钥总数
|
- `total_keys`: 密钥总数
|
||||||
@@ -96,7 +95,6 @@ async def get_provider_summary(
|
|||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `id`: 提供商 ID
|
- `id`: 提供商 ID
|
||||||
- `name`: 提供商名称
|
- `name`: 提供商名称
|
||||||
- `display_name`: 显示名称
|
|
||||||
- `description`: 描述信息
|
- `description`: 描述信息
|
||||||
- `website`: 官网地址
|
- `website`: 官网地址
|
||||||
- `provider_priority`: 优先级
|
- `provider_priority`: 优先级
|
||||||
@@ -107,9 +105,9 @@ async def get_provider_summary(
|
|||||||
- `quota_reset_day`: 配额重置日期
|
- `quota_reset_day`: 配额重置日期
|
||||||
- `quota_last_reset_at`: 上次配额重置时间
|
- `quota_last_reset_at`: 上次配额重置时间
|
||||||
- `quota_expires_at`: 配额过期时间
|
- `quota_expires_at`: 配额过期时间
|
||||||
- `rpm_limit`: RPM 限制
|
- `timeout`: 默认请求超时(秒)
|
||||||
- `rpm_used`: 已使用 RPM
|
- `max_retries`: 默认最大重试次数
|
||||||
- `rpm_reset_at`: RPM 重置时间
|
- `proxy`: 默认代理配置
|
||||||
- `total_endpoints`: 端点总数
|
- `total_endpoints`: 端点总数
|
||||||
- `active_endpoints`: 活跃端点数
|
- `active_endpoints`: 活跃端点数
|
||||||
- `total_keys`: 密钥总数
|
- `total_keys`: 密钥总数
|
||||||
@@ -185,13 +183,13 @@ async def update_provider_settings(
|
|||||||
"""
|
"""
|
||||||
更新提供商基础配置
|
更新提供商基础配置
|
||||||
|
|
||||||
更新提供商的基础配置信息,如显示名称、描述、优先级等。只需传入需要更新的字段。
|
更新提供商的基础配置信息,如名称、描述、优先级等。只需传入需要更新的字段。
|
||||||
|
|
||||||
**路径参数**:
|
**路径参数**:
|
||||||
- `provider_id`: 提供商 ID
|
- `provider_id`: 提供商 ID
|
||||||
|
|
||||||
**请求体字段**(所有字段可选):
|
**请求体字段**(所有字段可选):
|
||||||
- `display_name`: 显示名称
|
- `name`: 提供商名称
|
||||||
- `description`: 描述信息
|
- `description`: 描述信息
|
||||||
- `website`: 官网地址
|
- `website`: 官网地址
|
||||||
- `provider_priority`: 优先级
|
- `provider_priority`: 优先级
|
||||||
@@ -199,9 +197,10 @@ async def update_provider_settings(
|
|||||||
- `billing_type`: 计费类型
|
- `billing_type`: 计费类型
|
||||||
- `monthly_quota_usd`: 月度配额(美元)
|
- `monthly_quota_usd`: 月度配额(美元)
|
||||||
- `quota_reset_day`: 配额重置日期
|
- `quota_reset_day`: 配额重置日期
|
||||||
- `quota_last_reset_at`: 上次配额重置时间
|
|
||||||
- `quota_expires_at`: 配额过期时间
|
- `quota_expires_at`: 配额过期时间
|
||||||
- `rpm_limit`: RPM 限制
|
- `timeout`: 默认请求超时(秒)
|
||||||
|
- `max_retries`: 默认最大重试次数
|
||||||
|
- `proxy`: 默认代理配置
|
||||||
|
|
||||||
**返回字段**: 返回更新后的提供商摘要信息(与 GET /summary 接口返回格式相同)
|
**返回字段**: 返回更新后的提供商摘要信息(与 GET /summary 接口返回格式相同)
|
||||||
"""
|
"""
|
||||||
@@ -215,18 +214,18 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
|||||||
|
|
||||||
total_endpoints = len(endpoints)
|
total_endpoints = len(endpoints)
|
||||||
active_endpoints = sum(1 for e in endpoints if e.is_active)
|
active_endpoints = sum(1 for e in endpoints if e.is_active)
|
||||||
endpoint_ids = [e.id for e in endpoints]
|
|
||||||
|
|
||||||
# Key 统计(合并为单个查询)
|
# Key 统计(合并为单个查询)
|
||||||
total_keys = 0
|
key_stats = (
|
||||||
active_keys = 0
|
db.query(
|
||||||
if endpoint_ids:
|
|
||||||
key_stats = db.query(
|
|
||||||
func.count(ProviderAPIKey.id).label("total"),
|
func.count(ProviderAPIKey.id).label("total"),
|
||||||
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
|
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
|
||||||
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
|
)
|
||||||
total_keys = key_stats.total or 0
|
.filter(ProviderAPIKey.provider_id == provider.id)
|
||||||
active_keys = int(key_stats.active or 0)
|
.first()
|
||||||
|
)
|
||||||
|
total_keys = key_stats.total or 0
|
||||||
|
active_keys = int(key_stats.active or 0)
|
||||||
|
|
||||||
# Model 统计(合并为单个查询)
|
# Model 统计(合并为单个查询)
|
||||||
model_stats = db.query(
|
model_stats = db.query(
|
||||||
@@ -238,25 +237,34 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
|||||||
|
|
||||||
api_formats = [e.api_format for e in endpoints]
|
api_formats = [e.api_format for e in endpoints]
|
||||||
|
|
||||||
# 优化: 一次性加载所有 endpoint 的 keys,避免 N+1 查询
|
# 优化: 一次性加载 Provider 的 keys,避免 N+1 查询
|
||||||
all_keys = []
|
all_keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.provider_id == provider.id).all()
|
||||||
if endpoint_ids:
|
|
||||||
all_keys = (
|
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 按 endpoint_id 分组 keys
|
# 按 api_formats 分组 keys(通过 api_formats 关联)
|
||||||
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
|
format_to_endpoint_id: dict[str, str] = {e.api_format: e.id for e in endpoints}
|
||||||
|
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {e.id: [] for e in endpoints}
|
||||||
for key in all_keys:
|
for key in all_keys:
|
||||||
if key.endpoint_id not in keys_by_endpoint:
|
formats = key.api_formats or []
|
||||||
keys_by_endpoint[key.endpoint_id] = []
|
for fmt in formats:
|
||||||
keys_by_endpoint[key.endpoint_id].append(key)
|
endpoint_id = format_to_endpoint_id.get(fmt)
|
||||||
|
if endpoint_id:
|
||||||
|
keys_by_endpoint[endpoint_id].append(key)
|
||||||
|
|
||||||
endpoint_health_map: dict[str, float] = {}
|
endpoint_health_map: dict[str, float] = {}
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
keys = keys_by_endpoint.get(endpoint.id, [])
|
keys = keys_by_endpoint.get(endpoint.id, [])
|
||||||
if keys:
|
if keys:
|
||||||
health_scores = [k.health_score for k in keys if k.health_score is not None]
|
# 从 health_by_format 获取对应格式的健康度
|
||||||
|
api_fmt = endpoint.api_format
|
||||||
|
health_scores = []
|
||||||
|
for k in keys:
|
||||||
|
health_by_format = k.health_by_format or {}
|
||||||
|
if api_fmt in health_by_format:
|
||||||
|
score = health_by_format[api_fmt].get("health_score")
|
||||||
|
if score is not None:
|
||||||
|
health_scores.append(float(score))
|
||||||
|
else:
|
||||||
|
health_scores.append(1.0) # 默认健康度
|
||||||
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
|
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
|
||||||
endpoint_health_map[endpoint.id] = avg_health
|
endpoint_health_map[endpoint.id] = avg_health
|
||||||
else:
|
else:
|
||||||
@@ -284,7 +292,6 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
|||||||
return ProviderWithEndpointsSummary(
|
return ProviderWithEndpointsSummary(
|
||||||
id=provider.id,
|
id=provider.id,
|
||||||
name=provider.name,
|
name=provider.name,
|
||||||
display_name=provider.display_name,
|
|
||||||
description=provider.description,
|
description=provider.description,
|
||||||
website=provider.website,
|
website=provider.website,
|
||||||
provider_priority=provider.provider_priority,
|
provider_priority=provider.provider_priority,
|
||||||
@@ -295,9 +302,9 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
|
|||||||
quota_reset_day=provider.quota_reset_day,
|
quota_reset_day=provider.quota_reset_day,
|
||||||
quota_last_reset_at=provider.quota_last_reset_at,
|
quota_last_reset_at=provider.quota_last_reset_at,
|
||||||
quota_expires_at=provider.quota_expires_at,
|
quota_expires_at=provider.quota_expires_at,
|
||||||
rpm_limit=provider.rpm_limit,
|
timeout=provider.timeout,
|
||||||
rpm_used=provider.rpm_used,
|
max_retries=provider.max_retries,
|
||||||
rpm_reset_at=provider.rpm_reset_at,
|
proxy=provider.proxy,
|
||||||
total_endpoints=total_endpoints,
|
total_endpoints=total_endpoints,
|
||||||
active_endpoints=active_endpoints,
|
active_endpoints=active_endpoints,
|
||||||
total_keys=total_keys,
|
total_keys=total_keys,
|
||||||
@@ -341,7 +348,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
|
|||||||
if not endpoint_ids:
|
if not endpoint_ids:
|
||||||
response = ProviderEndpointHealthMonitorResponse(
|
response = ProviderEndpointHealthMonitorResponse(
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
provider_name=provider.display_name or provider.name,
|
provider_name=provider.name,
|
||||||
generated_at=now,
|
generated_at=now,
|
||||||
endpoints=[],
|
endpoints=[],
|
||||||
)
|
)
|
||||||
@@ -416,7 +423,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
response = ProviderEndpointHealthMonitorResponse(
|
response = ProviderEndpointHealthMonitorResponse(
|
||||||
provider_id=provider.id,
|
provider_id=provider.id,
|
||||||
provider_name=provider.display_name or provider.name,
|
provider_name=provider.name,
|
||||||
generated_at=now,
|
generated_at=now,
|
||||||
endpoints=endpoint_monitors,
|
endpoints=endpoint_monitors,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,6 +42,42 @@ def _get_version_from_git() -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_version() -> str:
|
||||||
|
"""获取当前版本号"""
|
||||||
|
version = _get_version_from_git()
|
||||||
|
if version:
|
||||||
|
return version
|
||||||
|
try:
|
||||||
|
from src._version import __version__
|
||||||
|
|
||||||
|
return __version__
|
||||||
|
except ImportError:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_version(version_str: str) -> tuple:
|
||||||
|
"""解析版本号为可比较的元组,支持 3-4 段版本号
|
||||||
|
|
||||||
|
例如:
|
||||||
|
- '0.2.5' -> (0, 2, 5, 0)
|
||||||
|
- '0.2.5.1' -> (0, 2, 5, 1)
|
||||||
|
- 'v0.2.5-4-g1234567' -> (0, 2, 5, 0)
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
version_str = version_str.lstrip("v")
|
||||||
|
main_version = re.split(r"[-+]", version_str)[0]
|
||||||
|
try:
|
||||||
|
parts = main_version.split(".")
|
||||||
|
# 标准化为 4 段,便于比较
|
||||||
|
int_parts = [int(p) for p in parts]
|
||||||
|
while len(int_parts) < 4:
|
||||||
|
int_parts.append(0)
|
||||||
|
return tuple(int_parts[:4])
|
||||||
|
except ValueError:
|
||||||
|
return (0, 0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/version")
|
@router.get("/version")
|
||||||
async def get_system_version():
|
async def get_system_version():
|
||||||
"""
|
"""
|
||||||
@@ -52,18 +88,111 @@ async def get_system_version():
|
|||||||
**返回字段**:
|
**返回字段**:
|
||||||
- `version`: 版本号字符串
|
- `version`: 版本号字符串
|
||||||
"""
|
"""
|
||||||
# 优先从 git 获取
|
return {"version": _get_current_version()}
|
||||||
version = _get_version_from_git()
|
|
||||||
if version:
|
|
||||||
return {"version": version}
|
@router.get("/check-update")
|
||||||
|
async def check_update():
|
||||||
|
"""
|
||||||
|
检查系统更新
|
||||||
|
|
||||||
|
从 GitHub Tags 获取最新版本并与当前版本对比。
|
||||||
|
|
||||||
|
**返回字段**:
|
||||||
|
- `current_version`: 当前版本号
|
||||||
|
- `latest_version`: 最新版本号
|
||||||
|
- `has_update`: 是否有更新可用
|
||||||
|
- `release_url`: 最新版本的 GitHub 页面链接
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
|
current_version = _get_current_version()
|
||||||
|
github_repo = "Aethersailor/Aether"
|
||||||
|
github_tags_url = f"https://api.github.com/repos/{github_repo}/tags"
|
||||||
|
|
||||||
# 回退到静态版本文件
|
|
||||||
try:
|
try:
|
||||||
from src._version import __version__
|
async with HTTPClientPool.get_temp_client(
|
||||||
|
timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
|
||||||
|
) as client:
|
||||||
|
response = await client.get(
|
||||||
|
github_tags_url,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/vnd.github.v3+json",
|
||||||
|
"User-Agent": f"Aether/{current_version}",
|
||||||
|
},
|
||||||
|
params={"per_page": 10},
|
||||||
|
)
|
||||||
|
|
||||||
return {"version": __version__}
|
if response.status_code != 200:
|
||||||
except ImportError:
|
return {
|
||||||
return {"version": "unknown"}
|
"current_version": current_version,
|
||||||
|
"latest_version": None,
|
||||||
|
"has_update": False,
|
||||||
|
"release_url": None,
|
||||||
|
"error": f"GitHub API 返回错误: {response.status_code}",
|
||||||
|
}
|
||||||
|
|
||||||
|
tags = response.json()
|
||||||
|
if not tags:
|
||||||
|
return {
|
||||||
|
"current_version": current_version,
|
||||||
|
"latest_version": None,
|
||||||
|
"has_update": False,
|
||||||
|
"release_url": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 找到最新的版本 tag(按版本号排序,而非时间)
|
||||||
|
version_tags = []
|
||||||
|
for tag in tags:
|
||||||
|
tag_name = tag.get("name", "")
|
||||||
|
if tag_name.startswith("v") or tag_name[0].isdigit():
|
||||||
|
version_tags.append((tag_name, _parse_version(tag_name)))
|
||||||
|
|
||||||
|
if not version_tags:
|
||||||
|
return {
|
||||||
|
"current_version": current_version,
|
||||||
|
"latest_version": None,
|
||||||
|
"has_update": False,
|
||||||
|
"release_url": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 按版本号排序,取最大的
|
||||||
|
version_tags.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
latest_tag = version_tags[0][0]
|
||||||
|
latest_version = latest_tag.lstrip("v")
|
||||||
|
|
||||||
|
current_tuple = _parse_version(current_version)
|
||||||
|
latest_tuple = _parse_version(latest_version)
|
||||||
|
has_update = latest_tuple > current_tuple
|
||||||
|
|
||||||
|
return {
|
||||||
|
"current_version": current_version,
|
||||||
|
"latest_version": latest_version,
|
||||||
|
"has_update": has_update,
|
||||||
|
"release_url": f"https://github.com/{github_repo}/releases/tag/{latest_tag}",
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
return {
|
||||||
|
"current_version": current_version,
|
||||||
|
"latest_version": None,
|
||||||
|
"has_update": False,
|
||||||
|
"release_url": None,
|
||||||
|
"error": "检查更新超时",
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"current_version": current_version,
|
||||||
|
"latest_version": None,
|
||||||
|
"has_update": False,
|
||||||
|
"release_url": None,
|
||||||
|
"error": f"检查更新失败: {str(e)}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
@@ -601,36 +730,6 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
endpoints_data = []
|
endpoints_data = []
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
# 导出 Endpoint Keys
|
|
||||||
keys = (
|
|
||||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
|
|
||||||
)
|
|
||||||
keys_data = []
|
|
||||||
for key in keys:
|
|
||||||
# 解密 API Key
|
|
||||||
try:
|
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
|
||||||
except Exception:
|
|
||||||
decrypted_key = ""
|
|
||||||
|
|
||||||
keys_data.append(
|
|
||||||
{
|
|
||||||
"api_key": decrypted_key,
|
|
||||||
"name": key.name,
|
|
||||||
"note": key.note,
|
|
||||||
"rate_multiplier": key.rate_multiplier,
|
|
||||||
"internal_priority": key.internal_priority,
|
|
||||||
"global_priority": key.global_priority,
|
|
||||||
"max_concurrent": key.max_concurrent,
|
|
||||||
"rate_limit": key.rate_limit,
|
|
||||||
"daily_limit": key.daily_limit,
|
|
||||||
"monthly_limit": key.monthly_limit,
|
|
||||||
"allowed_models": key.allowed_models,
|
|
||||||
"capabilities": key.capabilities,
|
|
||||||
"is_active": key.is_active,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
endpoints_data.append(
|
endpoints_data.append(
|
||||||
{
|
{
|
||||||
"api_format": ep.api_format,
|
"api_format": ep.api_format,
|
||||||
@@ -638,12 +737,44 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
|||||||
"headers": ep.headers,
|
"headers": ep.headers,
|
||||||
"timeout": ep.timeout,
|
"timeout": ep.timeout,
|
||||||
"max_retries": ep.max_retries,
|
"max_retries": ep.max_retries,
|
||||||
"max_concurrent": ep.max_concurrent,
|
|
||||||
"rate_limit": ep.rate_limit,
|
|
||||||
"is_active": ep.is_active,
|
"is_active": ep.is_active,
|
||||||
"custom_path": ep.custom_path,
|
"custom_path": ep.custom_path,
|
||||||
"config": ep.config,
|
"config": ep.config,
|
||||||
"keys": keys_data,
|
"proxy": ep.proxy,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导出 Provider Keys(按 provider_id 归属,包含 api_formats)
|
||||||
|
keys = (
|
||||||
|
db.query(ProviderAPIKey)
|
||||||
|
.filter(ProviderAPIKey.provider_id == provider.id)
|
||||||
|
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
keys_data = []
|
||||||
|
for key in keys:
|
||||||
|
# 解密 API Key
|
||||||
|
try:
|
||||||
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
|
except Exception:
|
||||||
|
decrypted_key = ""
|
||||||
|
|
||||||
|
keys_data.append(
|
||||||
|
{
|
||||||
|
"api_key": decrypted_key,
|
||||||
|
"name": key.name,
|
||||||
|
"note": key.note,
|
||||||
|
"api_formats": key.api_formats or [],
|
||||||
|
"rate_multiplier": key.rate_multiplier,
|
||||||
|
"rate_multipliers": key.rate_multipliers,
|
||||||
|
"internal_priority": key.internal_priority,
|
||||||
|
"global_priority": key.global_priority,
|
||||||
|
"rpm_limit": key.rpm_limit,
|
||||||
|
"allowed_models": key.allowed_models,
|
||||||
|
"capabilities": key.capabilities,
|
||||||
|
"cache_ttl_minutes": key.cache_ttl_minutes,
|
||||||
|
"max_probe_interval_minutes": key.max_probe_interval_minutes,
|
||||||
|
"is_active": key.is_active,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -675,24 +806,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
|||||||
providers_data.append(
|
providers_data.append(
|
||||||
{
|
{
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
"description": provider.description,
|
"description": provider.description,
|
||||||
"website": provider.website,
|
"website": provider.website,
|
||||||
"billing_type": provider.billing_type.value if provider.billing_type else None,
|
"billing_type": provider.billing_type.value if provider.billing_type else None,
|
||||||
"monthly_quota_usd": provider.monthly_quota_usd,
|
"monthly_quota_usd": provider.monthly_quota_usd,
|
||||||
"quota_reset_day": provider.quota_reset_day,
|
"quota_reset_day": provider.quota_reset_day,
|
||||||
"rpm_limit": provider.rpm_limit,
|
|
||||||
"provider_priority": provider.provider_priority,
|
"provider_priority": provider.provider_priority,
|
||||||
"is_active": provider.is_active,
|
"is_active": provider.is_active,
|
||||||
"concurrent_limit": provider.concurrent_limit,
|
"concurrent_limit": provider.concurrent_limit,
|
||||||
|
"timeout": provider.timeout,
|
||||||
|
"max_retries": provider.max_retries,
|
||||||
|
"proxy": provider.proxy,
|
||||||
"config": provider.config,
|
"config": provider.config,
|
||||||
"endpoints": endpoints_data,
|
"endpoints": endpoints_data,
|
||||||
|
"api_keys": keys_data,
|
||||||
"models": models_data,
|
"models": models_data,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"version": "1.0",
|
"version": "2.0",
|
||||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||||
"global_models": global_models_data,
|
"global_models": global_models_data,
|
||||||
"providers": providers_data,
|
"providers": providers_data,
|
||||||
@@ -721,7 +854,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
# 验证配置版本
|
# 验证配置版本
|
||||||
version = payload.get("version")
|
version = payload.get("version")
|
||||||
if version != "1.0":
|
if version != "2.0":
|
||||||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||||
|
|
||||||
# 获取导入选项
|
# 获取导入选项
|
||||||
@@ -810,8 +943,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
elif merge_mode == "overwrite":
|
elif merge_mode == "overwrite":
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
existing_provider.display_name = prov_data.get(
|
existing_provider.name = prov_data.get(
|
||||||
"display_name", existing_provider.display_name
|
"name", existing_provider.name
|
||||||
)
|
)
|
||||||
existing_provider.description = prov_data.get("description")
|
existing_provider.description = prov_data.get("description")
|
||||||
existing_provider.website = prov_data.get("website")
|
existing_provider.website = prov_data.get("website")
|
||||||
@@ -825,7 +958,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
existing_provider.quota_reset_day = prov_data.get(
|
existing_provider.quota_reset_day = prov_data.get(
|
||||||
"quota_reset_day", 30
|
"quota_reset_day", 30
|
||||||
)
|
)
|
||||||
existing_provider.rpm_limit = prov_data.get("rpm_limit")
|
|
||||||
existing_provider.provider_priority = prov_data.get(
|
existing_provider.provider_priority = prov_data.get(
|
||||||
"provider_priority", 100
|
"provider_priority", 100
|
||||||
)
|
)
|
||||||
@@ -833,6 +965,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
existing_provider.concurrent_limit = prov_data.get(
|
existing_provider.concurrent_limit = prov_data.get(
|
||||||
"concurrent_limit"
|
"concurrent_limit"
|
||||||
)
|
)
|
||||||
|
existing_provider.timeout = prov_data.get("timeout", existing_provider.timeout)
|
||||||
|
existing_provider.max_retries = prov_data.get(
|
||||||
|
"max_retries", existing_provider.max_retries
|
||||||
|
)
|
||||||
|
existing_provider.proxy = prov_data.get("proxy", existing_provider.proxy)
|
||||||
existing_provider.config = prov_data.get("config")
|
existing_provider.config = prov_data.get("config")
|
||||||
existing_provider.updated_at = datetime.now(timezone.utc)
|
existing_provider.updated_at = datetime.now(timezone.utc)
|
||||||
stats["providers"]["updated"] += 1
|
stats["providers"]["updated"] += 1
|
||||||
@@ -845,16 +982,17 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
new_provider = Provider(
|
new_provider = Provider(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
name=prov_data["name"],
|
name=prov_data["name"],
|
||||||
display_name=prov_data.get("display_name", prov_data["name"]),
|
|
||||||
description=prov_data.get("description"),
|
description=prov_data.get("description"),
|
||||||
website=prov_data.get("website"),
|
website=prov_data.get("website"),
|
||||||
billing_type=billing_type,
|
billing_type=billing_type,
|
||||||
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
|
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
|
||||||
quota_reset_day=prov_data.get("quota_reset_day", 30),
|
quota_reset_day=prov_data.get("quota_reset_day", 30),
|
||||||
rpm_limit=prov_data.get("rpm_limit"),
|
|
||||||
provider_priority=prov_data.get("provider_priority", 100),
|
provider_priority=prov_data.get("provider_priority", 100),
|
||||||
is_active=prov_data.get("is_active", True),
|
is_active=prov_data.get("is_active", True),
|
||||||
concurrent_limit=prov_data.get("concurrent_limit"),
|
concurrent_limit=prov_data.get("concurrent_limit"),
|
||||||
|
timeout=prov_data.get("timeout"),
|
||||||
|
max_retries=prov_data.get("max_retries"),
|
||||||
|
proxy=prov_data.get("proxy"),
|
||||||
config=prov_data.get("config"),
|
config=prov_data.get("config"),
|
||||||
)
|
)
|
||||||
db.add(new_provider)
|
db.add(new_provider)
|
||||||
@@ -874,7 +1012,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if existing_ep:
|
if existing_ep:
|
||||||
endpoint_id = existing_ep.id
|
|
||||||
if merge_mode == "skip":
|
if merge_mode == "skip":
|
||||||
stats["endpoints"]["skipped"] += 1
|
stats["endpoints"]["skipped"] += 1
|
||||||
elif merge_mode == "error":
|
elif merge_mode == "error":
|
||||||
@@ -887,12 +1024,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
existing_ep.headers = ep_data.get("headers")
|
existing_ep.headers = ep_data.get("headers")
|
||||||
existing_ep.timeout = ep_data.get("timeout", 300)
|
existing_ep.timeout = ep_data.get("timeout", 300)
|
||||||
existing_ep.max_retries = ep_data.get("max_retries", 3)
|
existing_ep.max_retries = ep_data.get("max_retries", 2)
|
||||||
existing_ep.max_concurrent = ep_data.get("max_concurrent")
|
|
||||||
existing_ep.rate_limit = ep_data.get("rate_limit")
|
|
||||||
existing_ep.is_active = ep_data.get("is_active", True)
|
existing_ep.is_active = ep_data.get("is_active", True)
|
||||||
existing_ep.custom_path = ep_data.get("custom_path")
|
existing_ep.custom_path = ep_data.get("custom_path")
|
||||||
existing_ep.config = ep_data.get("config")
|
existing_ep.config = ep_data.get("config")
|
||||||
|
existing_ep.proxy = ep_data.get("proxy")
|
||||||
existing_ep.updated_at = datetime.now(timezone.utc)
|
existing_ep.updated_at = datetime.now(timezone.utc)
|
||||||
stats["endpoints"]["updated"] += 1
|
stats["endpoints"]["updated"] += 1
|
||||||
else:
|
else:
|
||||||
@@ -903,69 +1039,107 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
base_url=ep_data["base_url"],
|
base_url=ep_data["base_url"],
|
||||||
headers=ep_data.get("headers"),
|
headers=ep_data.get("headers"),
|
||||||
timeout=ep_data.get("timeout", 300),
|
timeout=ep_data.get("timeout", 300),
|
||||||
max_retries=ep_data.get("max_retries", 3),
|
max_retries=ep_data.get("max_retries", 2),
|
||||||
max_concurrent=ep_data.get("max_concurrent"),
|
|
||||||
rate_limit=ep_data.get("rate_limit"),
|
|
||||||
is_active=ep_data.get("is_active", True),
|
is_active=ep_data.get("is_active", True),
|
||||||
custom_path=ep_data.get("custom_path"),
|
custom_path=ep_data.get("custom_path"),
|
||||||
config=ep_data.get("config"),
|
config=ep_data.get("config"),
|
||||||
|
proxy=ep_data.get("proxy"),
|
||||||
)
|
)
|
||||||
db.add(new_ep)
|
db.add(new_ep)
|
||||||
db.flush()
|
db.flush()
|
||||||
endpoint_id = new_ep.id
|
|
||||||
stats["endpoints"]["created"] += 1
|
stats["endpoints"]["created"] += 1
|
||||||
|
|
||||||
# 导入 Keys
|
# 导入 Provider Keys(按 provider_id 归属)
|
||||||
# 获取当前 endpoint 下所有已有的 keys,用于去重
|
endpoint_format_rows = (
|
||||||
existing_keys = (
|
db.query(ProviderEndpoint.api_format)
|
||||||
db.query(ProviderAPIKey)
|
.filter(ProviderEndpoint.provider_id == provider_id)
|
||||||
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
|
.all()
|
||||||
.all()
|
)
|
||||||
)
|
endpoint_formats: set[str] = set()
|
||||||
# 解密已有 keys 用于比对
|
for (api_format,) in endpoint_format_rows:
|
||||||
existing_key_values = set()
|
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
|
||||||
for ek in existing_keys:
|
endpoint_formats.add(fmt.strip().upper())
|
||||||
try:
|
existing_keys = (
|
||||||
decrypted = crypto_service.decrypt(ek.api_key)
|
db.query(ProviderAPIKey)
|
||||||
existing_key_values.add(decrypted)
|
.filter(ProviderAPIKey.provider_id == provider_id)
|
||||||
except Exception:
|
.all()
|
||||||
pass
|
)
|
||||||
|
existing_key_values = set()
|
||||||
|
for ek in existing_keys:
|
||||||
|
try:
|
||||||
|
decrypted = crypto_service.decrypt(ek.api_key)
|
||||||
|
existing_key_values.add(decrypted)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
for key_data in ep_data.get("keys", []):
|
for key_data in prov_data.get("api_keys", []):
|
||||||
if not key_data.get("api_key"):
|
if not key_data.get("api_key"):
|
||||||
stats["errors"].append(
|
stats["errors"].append(
|
||||||
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
|
f"跳过空 API Key (Provider: {prov_data['name']})"
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 检查是否已存在相同的 Key(通过明文比对)
|
|
||||||
if key_data["api_key"] in existing_key_values:
|
|
||||||
stats["keys"]["skipped"] += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
encrypted_key = crypto_service.encrypt(key_data["api_key"])
|
|
||||||
|
|
||||||
new_key = ProviderAPIKey(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
endpoint_id=endpoint_id,
|
|
||||||
api_key=encrypted_key,
|
|
||||||
name=key_data.get("name"),
|
|
||||||
note=key_data.get("note"),
|
|
||||||
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
|
||||||
internal_priority=key_data.get("internal_priority", 100),
|
|
||||||
global_priority=key_data.get("global_priority"),
|
|
||||||
max_concurrent=key_data.get("max_concurrent"),
|
|
||||||
rate_limit=key_data.get("rate_limit"),
|
|
||||||
daily_limit=key_data.get("daily_limit"),
|
|
||||||
monthly_limit=key_data.get("monthly_limit"),
|
|
||||||
allowed_models=key_data.get("allowed_models"),
|
|
||||||
capabilities=key_data.get("capabilities"),
|
|
||||||
is_active=key_data.get("is_active", True),
|
|
||||||
)
|
)
|
||||||
db.add(new_key)
|
continue
|
||||||
# 添加到已有集合,防止同一批导入中重复
|
|
||||||
existing_key_values.add(key_data["api_key"])
|
plaintext_key = key_data["api_key"]
|
||||||
stats["keys"]["created"] += 1
|
if plaintext_key in existing_key_values:
|
||||||
|
stats["keys"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw_formats = key_data.get("api_formats") or []
|
||||||
|
if not isinstance(raw_formats, list) or len(raw_formats) == 0:
|
||||||
|
stats["errors"].append(
|
||||||
|
f"跳过无 api_formats 的 Key (Provider: {prov_data['name']})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_formats: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
missing_formats: list[str] = []
|
||||||
|
for fmt in raw_formats:
|
||||||
|
if not isinstance(fmt, str):
|
||||||
|
continue
|
||||||
|
fmt_upper = fmt.strip().upper()
|
||||||
|
if not fmt_upper or fmt_upper in seen:
|
||||||
|
continue
|
||||||
|
seen.add(fmt_upper)
|
||||||
|
if endpoint_formats and fmt_upper not in endpoint_formats:
|
||||||
|
missing_formats.append(fmt_upper)
|
||||||
|
continue
|
||||||
|
normalized_formats.append(fmt_upper)
|
||||||
|
|
||||||
|
if missing_formats:
|
||||||
|
stats["errors"].append(
|
||||||
|
f"Key (Provider: {prov_data['name']}) 的 api_formats 未配置对应 Endpoint,已跳过: {missing_formats}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(normalized_formats) == 0:
|
||||||
|
stats["keys"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
encrypted_key = crypto_service.encrypt(plaintext_key)
|
||||||
|
|
||||||
|
new_key = ProviderAPIKey(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_formats=normalized_formats,
|
||||||
|
api_key=encrypted_key,
|
||||||
|
name=key_data.get("name") or "Imported Key",
|
||||||
|
note=key_data.get("note"),
|
||||||
|
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
||||||
|
rate_multipliers=key_data.get("rate_multipliers"),
|
||||||
|
internal_priority=key_data.get("internal_priority", 50),
|
||||||
|
global_priority=key_data.get("global_priority"),
|
||||||
|
rpm_limit=key_data.get("rpm_limit"),
|
||||||
|
allowed_models=key_data.get("allowed_models"),
|
||||||
|
capabilities=key_data.get("capabilities"),
|
||||||
|
cache_ttl_minutes=key_data.get("cache_ttl_minutes", 5),
|
||||||
|
max_probe_interval_minutes=key_data.get("max_probe_interval_minutes", 32),
|
||||||
|
is_active=key_data.get("is_active", True),
|
||||||
|
health_by_format={},
|
||||||
|
circuit_breaker_by_format={},
|
||||||
|
)
|
||||||
|
db.add(new_key)
|
||||||
|
existing_key_values.add(plaintext_key)
|
||||||
|
stats["keys"]["created"] += 1
|
||||||
|
|
||||||
# 导入 Models
|
# 导入 Models
|
||||||
for model_data in prov_data.get("models", []):
|
for model_data in prov_data.get("models", []):
|
||||||
|
|||||||
@@ -247,7 +247,8 @@ async def get_usage_detail(
|
|||||||
- `request_headers`: 请求头
|
- `request_headers`: 请求头
|
||||||
- `request_body`: 请求体
|
- `request_body`: 请求体
|
||||||
- `provider_request_headers`: 提供商请求头
|
- `provider_request_headers`: 提供商请求头
|
||||||
- `response_headers`: 响应头
|
- `response_headers`: 提供商响应头
|
||||||
|
- `client_response_headers`: 返回给客户端的响应头
|
||||||
- `response_body`: 响应体
|
- `response_body`: 响应体
|
||||||
- `metadata`: 提供商响应元数据
|
- `metadata`: 提供商响应元数据
|
||||||
- `tiered_pricing`: 阶梯计费信息(如适用)
|
- `tiered_pricing`: 阶梯计费信息(如适用)
|
||||||
@@ -916,6 +917,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
|||||||
"request_body": usage_record.get_request_body(),
|
"request_body": usage_record.get_request_body(),
|
||||||
"provider_request_headers": usage_record.provider_request_headers,
|
"provider_request_headers": usage_record.provider_request_headers,
|
||||||
"response_headers": usage_record.response_headers,
|
"response_headers": usage_record.response_headers,
|
||||||
|
"client_response_headers": usage_record.client_response_headers,
|
||||||
"response_body": usage_record.get_response_body(),
|
"response_body": usage_record.get_response_body(),
|
||||||
"metadata": usage_record.request_metadata,
|
"metadata": usage_record.request_metadata,
|
||||||
"tiered_pricing": tiered_pricing_info,
|
"tiered_pricing": tiered_pricing_info,
|
||||||
|
|||||||
@@ -202,20 +202,59 @@ def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
|||||||
条件:
|
条件:
|
||||||
- 端点 api_format 匹配
|
- 端点 api_format 匹配
|
||||||
- 端点是活跃的
|
- 端点是活跃的
|
||||||
- 端点下有活跃的 Key
|
- Provider 下有活跃的 Key 且支持该 api_format(Key 直属 Provider,通过 api_formats 过滤)
|
||||||
"""
|
"""
|
||||||
rows = (
|
target_formats = {f.upper() for f in api_formats}
|
||||||
db.query(ProviderEndpoint.provider_id)
|
|
||||||
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
# 1) 先找出有活跃端点的 Provider(记录每个 Provider 支持的格式集合)
|
||||||
|
endpoint_rows = (
|
||||||
|
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
|
||||||
.filter(
|
.filter(
|
||||||
ProviderEndpoint.api_format.in_(api_formats),
|
ProviderEndpoint.api_format.in_(list(target_formats)),
|
||||||
ProviderEndpoint.is_active.is_(True),
|
ProviderEndpoint.is_active.is_(True),
|
||||||
ProviderAPIKey.is_active.is_(True),
|
|
||||||
)
|
)
|
||||||
.distinct()
|
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
return {row[0] for row in rows}
|
|
||||||
|
if not endpoint_rows:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
provider_to_formats: dict[str, set[str]] = {}
|
||||||
|
for provider_id, fmt in endpoint_rows:
|
||||||
|
if not provider_id or not fmt:
|
||||||
|
continue
|
||||||
|
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
|
||||||
|
|
||||||
|
provider_ids_with_endpoints = set(provider_to_formats.keys())
|
||||||
|
if not provider_ids_with_endpoints:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# 2) 再检查这些 Provider 是否至少有一个活跃 Key 支持对应格式
|
||||||
|
key_rows = (
|
||||||
|
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
|
||||||
|
.filter(
|
||||||
|
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
|
||||||
|
ProviderAPIKey.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
available_provider_ids: set[str] = set()
|
||||||
|
for provider_id, key_formats in key_rows:
|
||||||
|
if not provider_id:
|
||||||
|
continue
|
||||||
|
endpoint_formats = provider_to_formats.get(provider_id)
|
||||||
|
if not endpoint_formats:
|
||||||
|
continue
|
||||||
|
|
||||||
|
formats_list = key_formats if isinstance(key_formats, list) else []
|
||||||
|
key_formats_upper = {str(f).upper() for f in formats_list}
|
||||||
|
|
||||||
|
# 只有同时满足:请求格式 ∩ Provider 端点格式 ∩ Key 支持格式 非空,才算可用
|
||||||
|
if key_formats_upper & endpoint_formats & target_formats:
|
||||||
|
available_provider_ids.add(provider_id)
|
||||||
|
|
||||||
|
return available_provider_ids
|
||||||
|
|
||||||
|
|
||||||
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
|
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
|
||||||
@@ -228,36 +267,64 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
|
|||||||
3. **该端点的 Provider 关联了该模型**
|
3. **该端点的 Provider 关联了该模型**
|
||||||
4. Key 的 allowed_models 允许该模型(null = 允许该 Provider 关联的所有模型)
|
4. Key 的 allowed_models 允许该模型(null = 允许该 Provider 关联的所有模型)
|
||||||
"""
|
"""
|
||||||
# 查询所有匹配格式的活跃端点及其活跃 Key,同时获取 endpoint_id
|
target_formats = {f.upper() for f in api_formats}
|
||||||
endpoint_keys = (
|
|
||||||
db.query(
|
# 1) 找出有活跃端点的 Provider(记录每个 Provider 支持的格式集合)
|
||||||
ProviderEndpoint.id.label("endpoint_id"),
|
endpoint_rows = (
|
||||||
ProviderEndpoint.provider_id,
|
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
|
||||||
ProviderAPIKey.allowed_models,
|
|
||||||
)
|
|
||||||
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
|
||||||
.filter(
|
.filter(
|
||||||
ProviderEndpoint.api_format.in_(api_formats),
|
ProviderEndpoint.api_format.in_(list(target_formats)),
|
||||||
ProviderEndpoint.is_active.is_(True),
|
ProviderEndpoint.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not endpoint_rows:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
provider_to_formats: dict[str, set[str]] = {}
|
||||||
|
for provider_id, fmt in endpoint_rows:
|
||||||
|
if not provider_id or not fmt:
|
||||||
|
continue
|
||||||
|
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
|
||||||
|
|
||||||
|
provider_ids_with_endpoints = set(provider_to_formats.keys())
|
||||||
|
if not provider_ids_with_endpoints:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# 2) 收集每个 Provider 下「支持对应格式」的活跃 Key 的 allowed_models
|
||||||
|
# Key 直属 Provider,通过 key.api_formats 与 Provider 端点格式交集筛选
|
||||||
|
key_rows = (
|
||||||
|
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.allowed_models, ProviderAPIKey.api_formats)
|
||||||
|
.filter(
|
||||||
|
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
|
||||||
ProviderAPIKey.is_active.is_(True),
|
ProviderAPIKey.is_active.is_(True),
|
||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not endpoint_keys:
|
# provider_id -> list[(allowed_models, usable_formats)]
|
||||||
|
provider_key_rules: dict[str, list[tuple[object, set[str]]]] = {}
|
||||||
|
for provider_id, allowed_models, key_formats in key_rows:
|
||||||
|
if not provider_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
endpoint_formats = provider_to_formats.get(provider_id)
|
||||||
|
if not endpoint_formats:
|
||||||
|
continue
|
||||||
|
|
||||||
|
formats_list = key_formats if isinstance(key_formats, list) else []
|
||||||
|
key_formats_upper = {str(f).upper() for f in formats_list}
|
||||||
|
usable_formats = key_formats_upper & endpoint_formats & target_formats
|
||||||
|
if not usable_formats:
|
||||||
|
continue
|
||||||
|
|
||||||
|
provider_key_rules.setdefault(provider_id, []).append((allowed_models, usable_formats))
|
||||||
|
|
||||||
|
provider_ids_with_format = set(provider_key_rules.keys())
|
||||||
|
if not provider_ids_with_format:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
# 收集每个 (provider_id, endpoint_id) 对应的 allowed_models
|
|
||||||
# 使用 provider_id 作为 key,因为模型是关联到 Provider 的
|
|
||||||
provider_allowed_models: dict[str, list[Optional[list[str]]]] = {}
|
|
||||||
provider_ids_with_format: set[str] = set()
|
|
||||||
|
|
||||||
for endpoint_id, provider_id, allowed_models in endpoint_keys:
|
|
||||||
provider_ids_with_format.add(provider_id)
|
|
||||||
if provider_id not in provider_allowed_models:
|
|
||||||
provider_allowed_models[provider_id] = []
|
|
||||||
provider_allowed_models[provider_id].append(allowed_models)
|
|
||||||
|
|
||||||
# 只查询那些有匹配格式端点的 Provider 下的模型
|
# 只查询那些有匹配格式端点的 Provider 下的模型
|
||||||
models = (
|
models = (
|
||||||
db.query(Model)
|
db.query(Model)
|
||||||
@@ -285,22 +352,30 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
|
|||||||
if model_provider_id not in provider_ids_with_format:
|
if model_provider_id not in provider_ids_with_format:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查该 provider 下是否有 Key 允许这个模型
|
# 检查该 provider 下是否有 Key 允许这个模型(支持 list/dict 两种 allowed_models)
|
||||||
allowed_lists = provider_allowed_models.get(model_provider_id, [])
|
from src.core.model_permissions import check_model_allowed
|
||||||
for allowed_models in allowed_lists:
|
|
||||||
|
rules = provider_key_rules.get(model_provider_id, [])
|
||||||
|
for allowed_models, usable_formats in rules:
|
||||||
|
# None = 不限制
|
||||||
if allowed_models is None:
|
if allowed_models is None:
|
||||||
# null = 允许该 Provider 关联的所有模型(已通过上面的查询限制)
|
|
||||||
available_model_ids.add(model_id)
|
|
||||||
break
|
|
||||||
elif model_id in allowed_models:
|
|
||||||
# 明确在允许列表中
|
|
||||||
available_model_ids.add(model_id)
|
|
||||||
break
|
|
||||||
elif global_model and model.provider_model_name in allowed_models:
|
|
||||||
# 也检查 provider_model_name
|
|
||||||
available_model_ids.add(model_id)
|
available_model_ids.add(model_id)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 对于支持多个格式的 Key:任意一个可用格式允许即可
|
||||||
|
for fmt in usable_formats:
|
||||||
|
if check_model_allowed(
|
||||||
|
model_name=model_id,
|
||||||
|
allowed_models=allowed_models, # type: ignore[arg-type]
|
||||||
|
api_format=fmt,
|
||||||
|
resolved_model_name=(model.provider_model_name if global_model else None),
|
||||||
|
):
|
||||||
|
available_model_ids.add(model_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
return available_model_ids
|
return available_model_ids
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ async def get_daily_stats(
|
|||||||
class DashboardAdapter(ApiAdapter):
|
class DashboardAdapter(ApiAdapter):
|
||||||
"""需要登录的仪表盘适配器基类。"""
|
"""需要登录的仪表盘适配器基类。"""
|
||||||
|
|
||||||
mode = ApiMode.ADMIN
|
mode = ApiMode.USER # 普通用户也可访问仪表盘
|
||||||
|
|
||||||
def authorize(self, context): # type: ignore[override]
|
def authorize(self, context): # type: ignore[override]
|
||||||
if not context.user:
|
if not context.user:
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ class MessageTelemetry:
|
|||||||
request_headers: Dict[str, Any],
|
request_headers: Dict[str, Any],
|
||||||
response_body: Any,
|
response_body: Any,
|
||||||
response_headers: Dict[str, Any],
|
response_headers: Dict[str, Any],
|
||||||
|
client_response_headers: Optional[Dict[str, Any]] = None,
|
||||||
cache_creation_tokens: int = 0,
|
cache_creation_tokens: int = 0,
|
||||||
cache_read_tokens: int = 0,
|
cache_read_tokens: int = 0,
|
||||||
is_stream: bool = False,
|
is_stream: bool = False,
|
||||||
@@ -143,6 +144,7 @@ class MessageTelemetry:
|
|||||||
request_body=request_body,
|
request_body=request_body,
|
||||||
provider_request_headers=provider_request_headers or {},
|
provider_request_headers=provider_request_headers or {},
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_body,
|
response_body=response_body,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
# Provider 侧追踪信息(用于记录真实成本)
|
# Provider 侧追踪信息(用于记录真实成本)
|
||||||
@@ -192,6 +194,8 @@ class MessageTelemetry:
|
|||||||
cache_creation_tokens: int = 0,
|
cache_creation_tokens: int = 0,
|
||||||
cache_read_tokens: int = 0,
|
cache_read_tokens: int = 0,
|
||||||
response_body: Optional[Dict[str, Any]] = None,
|
response_body: Optional[Dict[str, Any]] = None,
|
||||||
|
response_headers: Optional[Dict[str, Any]] = None,
|
||||||
|
client_response_headers: Optional[Dict[str, Any]] = None,
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model: Optional[str] = None,
|
target_model: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -207,6 +211,8 @@ class MessageTelemetry:
|
|||||||
cache_creation_tokens: 缓存创建 tokens
|
cache_creation_tokens: 缓存创建 tokens
|
||||||
cache_read_tokens: 缓存读取 tokens
|
cache_read_tokens: 缓存读取 tokens
|
||||||
response_body: 响应体(如果有部分响应)
|
response_body: 响应体(如果有部分响应)
|
||||||
|
response_headers: 响应头(Provider 返回的原始响应头)
|
||||||
|
client_response_headers: 返回给客户端的响应头
|
||||||
target_model: 映射后的目标模型名(如果发生了映射)
|
target_model: 映射后的目标模型名(如果发生了映射)
|
||||||
"""
|
"""
|
||||||
provider_name = provider or "unknown"
|
provider_name = provider or "unknown"
|
||||||
@@ -232,7 +238,8 @@ class MessageTelemetry:
|
|||||||
request_headers=request_headers,
|
request_headers=request_headers,
|
||||||
request_body=request_body,
|
request_body=request_body,
|
||||||
provider_request_headers=provider_request_headers or {},
|
provider_request_headers=provider_request_headers or {},
|
||||||
response_headers={},
|
response_headers=response_headers or {},
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_body or {"error": error_message},
|
response_body=response_body or {"error": error_message},
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
|
|||||||
@@ -351,9 +351,9 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
# 确定错误消息
|
# 确定错误消息
|
||||||
if isinstance(e, ProviderAuthException):
|
if isinstance(e, ProviderAuthException):
|
||||||
error_message = (
|
error_message = (
|
||||||
f"提供商认证失败: {str(e)}"
|
"上游服务认证失败"
|
||||||
if result.metadata.provider != "unknown"
|
if result.metadata.provider != "unknown"
|
||||||
else "服务端错误: 无可用提供商"
|
else "服务暂时不可用"
|
||||||
)
|
)
|
||||||
result.error_message = error_message
|
result.error_message = error_message
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
|
|||||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||||
from src.api.handlers.base.utils import build_sse_headers
|
from src.api.handlers.base.utils import build_sse_headers
|
||||||
from src.config.settings import config
|
from src.config.settings import config
|
||||||
from src.core.error_utils import extract_error_message
|
from src.core.error_utils import extract_client_error_message
|
||||||
from src.core.exceptions import (
|
from src.core.exceptions import (
|
||||||
EmbeddedErrorException,
|
EmbeddedErrorException,
|
||||||
ProviderAuthException,
|
ProviderAuthException,
|
||||||
@@ -382,10 +382,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
http_request.is_disconnected,
|
http_request.is_disconnected,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 透传提供商的响应头给客户端
|
||||||
|
# 同时添加必要的 SSE 头以确保流式传输正常工作
|
||||||
|
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
|
||||||
|
# 添加/覆盖 SSE 必需的头
|
||||||
|
client_headers.update(build_sse_headers())
|
||||||
|
client_headers["content-type"] = "text/event-stream"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
monitored_stream,
|
monitored_stream,
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers=build_sse_headers(),
|
headers=client_headers,
|
||||||
background=background_tasks,
|
background=background_tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -463,7 +470,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
# 配置 HTTP 超时
|
# 配置 HTTP 超时
|
||||||
# 注意:read timeout 用于检测连接断开,不是整体请求超时
|
# 注意:read timeout 用于检测连接断开,不是整体请求超时
|
||||||
# 整体请求超时由 asyncio.wait_for 控制,使用 endpoint.timeout
|
# 整体请求超时由 asyncio.wait_for 控制,使用 provider.timeout
|
||||||
timeout_config = httpx.Timeout(
|
timeout_config = httpx.Timeout(
|
||||||
connect=config.http_connect_timeout,
|
connect=config.http_connect_timeout,
|
||||||
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
|
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
|
||||||
@@ -471,14 +478,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
pool=config.http_pool_timeout,
|
pool=config.http_pool_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
# provider.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
||||||
request_timeout = float(endpoint.timeout or 300)
|
request_timeout = float(provider.timeout or 300)
|
||||||
|
|
||||||
# 创建 HTTP 客户端(支持代理配置)
|
# 创建 HTTP 客户端(支持代理配置,从 Provider 读取)
|
||||||
from src.clients.http_client import HTTPClientPool
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
http_client = HTTPClientPool.create_client_with_proxy(
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
proxy_config=endpoint.proxy,
|
proxy_config=provider.proxy,
|
||||||
timeout=timeout_config,
|
timeout=timeout_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -514,7 +521,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
||||||
# endpoint.timeout 控制整体超时,避免上游长时间无响应
|
# provider.timeout 控制整体超时,避免上游长时间无响应
|
||||||
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -590,17 +597,22 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
actual_request_body = ctx.provider_request_body or original_request_body
|
actual_request_body = ctx.provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 失败时返回给客户端的是 JSON 错误响应
|
||||||
|
client_response_headers = {"content-type": "application/json"}
|
||||||
|
|
||||||
await self.telemetry.record_failure(
|
await self.telemetry.record_failure(
|
||||||
provider=ctx.provider_name or "unknown",
|
provider=ctx.provider_name or "unknown",
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=extract_error_message(error),
|
error_message=extract_client_error_message(error),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=True,
|
is_stream=True,
|
||||||
api_format=ctx.api_format,
|
api_format=ctx.api_format,
|
||||||
provider_request_headers=ctx.provider_request_headers,
|
provider_request_headers=ctx.provider_request_headers,
|
||||||
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
target_model=ctx.mapped_model,
|
target_model=ctx.mapped_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -691,64 +703,98 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
|
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 HTTP 客户端(支持代理配置)
|
# 获取复用的 HTTP 客户端(支持代理配置,从 Provider 读取)
|
||||||
# endpoint.timeout 作为整体请求超时
|
# 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
|
||||||
from src.clients.http_client import HTTPClientPool
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
request_timeout = float(endpoint.timeout or 300)
|
request_timeout = float(provider.timeout or 300)
|
||||||
http_client = HTTPClientPool.create_client_with_proxy(
|
http_client = await HTTPClientPool.get_proxy_client(
|
||||||
proxy_config=endpoint.proxy,
|
proxy_config=provider.proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注意:不使用 async with,因为复用的客户端不应该被关闭
|
||||||
|
# 超时通过 timeout 参数控制
|
||||||
|
resp = await http_client.post(
|
||||||
|
url,
|
||||||
|
json=provider_payload,
|
||||||
|
headers=provider_hdrs,
|
||||||
timeout=httpx.Timeout(request_timeout),
|
timeout=httpx.Timeout(request_timeout),
|
||||||
)
|
)
|
||||||
async with http_client:
|
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
response_headers = dict(resp.headers)
|
response_headers = dict(resp.headers)
|
||||||
|
|
||||||
if resp.status_code == 401:
|
if resp.status_code == 401:
|
||||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
raise ProviderAuthException(str(provider.name))
|
||||||
elif resp.status_code == 429:
|
elif resp.status_code == 429:
|
||||||
raise ProviderRateLimitException(
|
raise ProviderRateLimitException(
|
||||||
f"提供商速率限制: {provider.name}",
|
"请求过于频繁,请稍后重试",
|
||||||
provider_name=str(provider.name),
|
provider_name=str(provider.name),
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
|
)
|
||||||
|
elif resp.status_code >= 500:
|
||||||
|
# 记录响应体以便调试
|
||||||
|
error_body = ""
|
||||||
|
try:
|
||||||
|
error_body = resp.text[:1000]
|
||||||
|
logger.error(
|
||||||
|
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
|
||||||
)
|
)
|
||||||
elif resp.status_code >= 500:
|
except Exception:
|
||||||
# 记录响应体以便调试
|
pass
|
||||||
error_body = ""
|
raise ProviderNotAvailableException(
|
||||||
try:
|
f"上游服务暂时不可用 (HTTP {resp.status_code})",
|
||||||
error_body = resp.text[:1000]
|
provider_name=str(provider.name),
|
||||||
logger.error(
|
upstream_status=resp.status_code,
|
||||||
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
|
upstream_response=error_body,
|
||||||
)
|
)
|
||||||
except Exception:
|
elif resp.status_code != 200:
|
||||||
pass
|
# 记录非200响应以便调试
|
||||||
raise ProviderNotAvailableException(
|
error_body = ""
|
||||||
f"提供商服务不可用: {provider.name}",
|
try:
|
||||||
provider_name=str(provider.name),
|
error_body = resp.text[:1000]
|
||||||
upstream_status=resp.status_code,
|
logger.warning(
|
||||||
upstream_response=error_body,
|
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
|
||||||
)
|
|
||||||
elif resp.status_code != 200:
|
|
||||||
# 记录非200响应以便调试
|
|
||||||
error_body = ""
|
|
||||||
try:
|
|
||||||
error_body = resp.text[:1000]
|
|
||||||
logger.warning(
|
|
||||||
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise ProviderNotAvailableException(
|
|
||||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
|
||||||
provider_name=str(provider.name),
|
|
||||||
upstream_status=resp.status_code,
|
|
||||||
upstream_response=error_body,
|
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
f"上游服务返回错误 (HTTP {resp.status_code})",
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=error_body,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 安全解析 JSON 响应,处理可能的编码错误
|
||||||
|
try:
|
||||||
response_json = resp.json()
|
response_json = resp.json()
|
||||||
return response_json if isinstance(response_json, dict) else {}
|
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||||
|
# 获取原始响应内容用于调试(存入 upstream_response)
|
||||||
|
raw_content = ""
|
||||||
|
try:
|
||||||
|
raw_content = resp.text[:500] if resp.text else "(empty)"
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
|
||||||
|
except Exception:
|
||||||
|
raw_content = "(unable to read)"
|
||||||
|
logger.error(
|
||||||
|
f"[{self.request_id}] 无法解析响应 JSON: {e}, 原始内容: {raw_content}"
|
||||||
|
)
|
||||||
|
# 判断错误类型,生成友好的客户端错误消息(不暴露提供商信息)
|
||||||
|
if raw_content == "(empty)" or not raw_content.strip():
|
||||||
|
client_message = "上游服务返回了空响应"
|
||||||
|
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
|
||||||
|
client_message = "上游服务返回了非预期的响应格式"
|
||||||
|
else:
|
||||||
|
client_message = "上游服务返回了无效的响应"
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
client_message,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=raw_content,
|
||||||
|
)
|
||||||
|
return response_json if isinstance(response_json, dict) else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析能力需求
|
# 解析能力需求
|
||||||
@@ -792,6 +838,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
actual_request_body = provider_request_body or original_request_body
|
actual_request_body = provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 非流式成功时,返回给客户端的是提供商响应头(透传)
|
||||||
|
# JSONResponse 会自动设置 content-type,但我们记录实际返回的完整头
|
||||||
|
client_response_headers = dict(response_headers) if response_headers else {}
|
||||||
|
client_response_headers["content-type"] = "application/json"
|
||||||
|
|
||||||
total_cost = await self.telemetry.record_success(
|
total_cost = await self.telemetry.record_success(
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -802,6 +853,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_json,
|
response_body=response_json,
|
||||||
cache_creation_tokens=cache_creation_tokens,
|
cache_creation_tokens=cache_creation_tokens,
|
||||||
cache_read_tokens=cached_tokens,
|
cache_read_tokens=cached_tokens,
|
||||||
@@ -823,7 +875,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
f"in:{input_tokens or 0} out:{output_tokens or 0}"
|
f"in:{input_tokens or 0} out:{output_tokens or 0}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(status_code=status_code, content=response_json)
|
# 透传提供商的响应头
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content=response_json,
|
||||||
|
headers=response_headers if response_headers else None,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_time_ms = self.elapsed_ms()
|
response_time_ms = self.elapsed_ms()
|
||||||
@@ -838,17 +895,27 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
actual_request_body = provider_request_body or original_request_body
|
actual_request_body = provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 尝试从异常中提取响应头
|
||||||
|
error_response_headers: Dict[str, str] = {}
|
||||||
|
if isinstance(e, ProviderRateLimitException) and e.response_headers:
|
||||||
|
error_response_headers = e.response_headers
|
||||||
|
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
|
||||||
|
error_response_headers = dict(e.response.headers)
|
||||||
|
|
||||||
await self.telemetry.record_failure(
|
await self.telemetry.record_failure(
|
||||||
provider=provider_name or "unknown",
|
provider=provider_name or "unknown",
|
||||||
model=model,
|
model=model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=extract_error_message(e),
|
error_message=extract_client_error_message(e),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=False,
|
is_stream=False,
|
||||||
api_format=api_format,
|
api_format=api_format,
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
|
response_headers=error_response_headers,
|
||||||
|
# 非流式失败返回给客户端的是 JSON 错误响应
|
||||||
|
client_response_headers={"content-type": "application/json"},
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model=mapped_model_result,
|
target_model=mapped_model_result,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -306,9 +306,9 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
# 确定错误消息
|
# 确定错误消息
|
||||||
if isinstance(e, ProviderAuthException):
|
if isinstance(e, ProviderAuthException):
|
||||||
error_message = (
|
error_message = (
|
||||||
f"提供商认证失败: {str(e)}"
|
"上游服务认证失败"
|
||||||
if result.metadata.provider != "unknown"
|
if result.metadata.provider != "unknown"
|
||||||
else "服务端错误: 无可用提供商"
|
else "服务暂时不可用"
|
||||||
)
|
)
|
||||||
result.error_message = error_message
|
result.error_message = error_message
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ from src.api.handlers.base.utils import (
|
|||||||
)
|
)
|
||||||
from src.config.constants import StreamDefaults
|
from src.config.constants import StreamDefaults
|
||||||
from src.config.settings import config
|
from src.config.settings import config
|
||||||
from src.core.error_utils import extract_error_message
|
from src.core.error_utils import extract_client_error_message
|
||||||
from src.core.exceptions import (
|
from src.core.exceptions import (
|
||||||
EmbeddedErrorException,
|
EmbeddedErrorException,
|
||||||
ProviderAuthException,
|
ProviderAuthException,
|
||||||
@@ -376,10 +376,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 创建监控流
|
# 创建监控流
|
||||||
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
|
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
|
||||||
|
|
||||||
|
# 透传提供商的响应头给客户端
|
||||||
|
# 同时添加必要的 SSE 头以确保流式传输正常工作
|
||||||
|
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
|
||||||
|
# 添加/覆盖 SSE 必需的头
|
||||||
|
client_headers.update(build_sse_headers())
|
||||||
|
client_headers["content-type"] = "text/event-stream"
|
||||||
|
ctx.client_response_headers = client_headers
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
monitored_stream,
|
monitored_stream,
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers=build_sse_headers(),
|
headers=client_headers,
|
||||||
background=background_tasks,
|
background=background_tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -475,8 +483,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
pool=config.http_pool_timeout,
|
pool=config.http_pool_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
# provider.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
||||||
request_timeout = float(endpoint.timeout or 300)
|
request_timeout = float(provider.timeout or 300)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f" └─ [{self.request_id}] 发送流式请求: "
|
f" └─ [{self.request_id}] 发送流式请求: "
|
||||||
@@ -486,11 +494,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"timeout={request_timeout}s"
|
f"timeout={request_timeout}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 HTTP 客户端(支持代理配置)
|
# 创建 HTTP 客户端(支持代理配置,从 Provider 读取)
|
||||||
from src.clients.http_client import HTTPClientPool
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
http_client = HTTPClientPool.create_client_with_proxy(
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
proxy_config=endpoint.proxy,
|
proxy_config=provider.proxy,
|
||||||
timeout=timeout_config,
|
timeout=timeout_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -524,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
||||||
# endpoint.timeout 控制整体超时,避免上游长时间无响应
|
# provider.timeout 控制整体超时,避免上游长时间无响应
|
||||||
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -636,12 +644,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||||
elapsed = time.time() - last_data_time
|
elapsed = time.time() - last_data_time
|
||||||
if elapsed > self.DATA_TIMEOUT:
|
if elapsed > self.DATA_TIMEOUT:
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 504
|
||||||
|
ctx.error_message = "流式响应超时,未收到有效数据"
|
||||||
|
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "empty_stream_timeout",
|
"type": "empty_stream_timeout",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
self._mark_first_output(ctx, output_state)
|
self._mark_first_output(ctx, output_state)
|
||||||
@@ -682,12 +694,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if ctx.data_count == 0:
|
if ctx.data_count == 0:
|
||||||
# 流已开始,无法抛出异常进行故障转移
|
# 流已开始,无法抛出异常进行故障转移
|
||||||
# 发送错误事件并记录日志
|
# 发送错误事件并记录日志
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应")
|
logger.warning(f"Provider '{ctx.provider_name}' 返回空流式响应")
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 503
|
||||||
|
ctx.error_message = "上游服务返回了空的流式响应"
|
||||||
|
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "empty_response",
|
"type": "empty_response",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
@@ -699,12 +715,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
except httpx.StreamClosed:
|
except httpx.StreamClosed:
|
||||||
if ctx.data_count == 0:
|
if ctx.data_count == 0:
|
||||||
# 流已开始,发送错误事件而不是抛出异常
|
# 流已开始,发送错误事件而不是抛出异常
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
|
logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 503
|
||||||
|
ctx.error_message = "上游服务连接关闭且未返回数据"
|
||||||
|
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "stream_closed",
|
"type": "stream_closed",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
@@ -824,8 +844,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"base_url={endpoint.base_url}"
|
f"base_url={endpoint.base_url}"
|
||||||
)
|
)
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
"上游服务返回了非预期的响应格式",
|
||||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=200,
|
||||||
|
upstream_response=normalized_line[:500] if normalized_line else "(empty)",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not normalized_line or normalized_line.startswith(":"):
|
if not normalized_line or normalized_line.startswith(":"):
|
||||||
@@ -1024,12 +1046,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||||
elapsed = time.time() - last_data_time
|
elapsed = time.time() - last_data_time
|
||||||
if elapsed > self.DATA_TIMEOUT:
|
if elapsed > self.DATA_TIMEOUT:
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 504
|
||||||
|
ctx.error_message = "流式响应超时,未收到有效数据"
|
||||||
|
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "empty_stream_timeout",
|
"type": "empty_stream_timeout",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
self._mark_first_output(ctx, output_state)
|
self._mark_first_output(ctx, output_state)
|
||||||
@@ -1071,14 +1097,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if ctx.data_count == 0:
|
if ctx.data_count == 0:
|
||||||
# 空流通常意味着配置错误(如 base_url 指向了网页而非 API)
|
# 空流通常意味着配置错误(如 base_url 指向了网页而非 API)
|
||||||
logger.error(
|
logger.error(
|
||||||
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
|
f"Provider '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
|
||||||
f"可能是 endpoint base_url 配置错误"
|
f"可能是 endpoint base_url 配置错误"
|
||||||
)
|
)
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 503
|
||||||
|
ctx.error_message = "上游服务返回了空的流式响应"
|
||||||
|
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0, 可能是 base_url 配置错误"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "empty_response",
|
"type": "empty_response",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
@@ -1089,12 +1119,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
raise
|
raise
|
||||||
except httpx.StreamClosed:
|
except httpx.StreamClosed:
|
||||||
if ctx.data_count == 0:
|
if ctx.data_count == 0:
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
|
logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
|
||||||
|
# 设置错误状态用于后续记录
|
||||||
|
ctx.status_code = 503
|
||||||
|
ctx.error_message = "上游服务连接关闭且未返回数据"
|
||||||
|
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"error": {
|
"error": {
|
||||||
"type": "stream_closed",
|
"type": "stream_closed",
|
||||||
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
|
"message": ctx.error_message,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
@@ -1289,6 +1323,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if ctx.status_code and ctx.status_code >= 400:
|
if ctx.status_code and ctx.status_code >= 400:
|
||||||
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
||||||
# 这样即使请求中断,也能记录预估成本
|
# 这样即使请求中断,也能记录预估成本
|
||||||
|
# 失败时返回给客户端的是 JSON 错误响应,如果没有设置则使用默认值
|
||||||
|
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
|
||||||
await bg_telemetry.record_failure(
|
await bg_telemetry.record_failure(
|
||||||
provider=ctx.provider_name or "unknown",
|
provider=ctx.provider_name or "unknown",
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
@@ -1306,6 +1342,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
cache_creation_tokens=ctx.cache_creation_tokens,
|
cache_creation_tokens=ctx.cache_creation_tokens,
|
||||||
cache_read_tokens=ctx.cached_tokens,
|
cache_read_tokens=ctx.cached_tokens,
|
||||||
response_body=response_body,
|
response_body=response_body,
|
||||||
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model=ctx.mapped_model,
|
target_model=ctx.mapped_model,
|
||||||
)
|
)
|
||||||
@@ -1319,6 +1357,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
|
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
|
||||||
self._finalize_stream_metadata(ctx)
|
self._finalize_stream_metadata(ctx)
|
||||||
|
|
||||||
|
# 流式成功时,返回给客户端的是提供商响应头 + SSE 必需头
|
||||||
|
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
|
||||||
|
client_response_headers.update({
|
||||||
|
"Cache-Control": "no-cache, no-transform",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"content-type": "text/event-stream",
|
||||||
|
})
|
||||||
|
|
||||||
total_cost = await bg_telemetry.record_success(
|
total_cost = await bg_telemetry.record_success(
|
||||||
provider=ctx.provider_name,
|
provider=ctx.provider_name,
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
@@ -1330,6 +1376,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
response_headers=ctx.response_headers,
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_body,
|
response_body=response_body,
|
||||||
cache_creation_tokens=ctx.cache_creation_tokens,
|
cache_creation_tokens=ctx.cache_creation_tokens,
|
||||||
cache_read_tokens=ctx.cached_tokens,
|
cache_read_tokens=ctx.cached_tokens,
|
||||||
@@ -1367,13 +1414,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 499 = 客户端断开连接,应标记为失败
|
# 499 = 客户端断开连接,应标记为失败
|
||||||
# 503 = 服务不可用(如流中断),应标记为失败
|
# 503 = 服务不可用(如流中断),应标记为失败
|
||||||
if ctx.status_code and ctx.status_code >= 400:
|
if ctx.status_code and ctx.status_code >= 400:
|
||||||
|
# 请求链路追踪使用 upstream_response(原始响应),回退到 error_message(友好消息)
|
||||||
|
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
|
||||||
RequestCandidateService.mark_candidate_failed(
|
RequestCandidateService.mark_candidate_failed(
|
||||||
db=bg_db,
|
db=bg_db,
|
||||||
candidate_id=ctx.attempt_id,
|
candidate_id=ctx.attempt_id,
|
||||||
error_type=(
|
error_type=(
|
||||||
"client_disconnected" if ctx.status_code == 499 else "stream_error"
|
"client_disconnected" if ctx.status_code == 499 else "stream_error"
|
||||||
),
|
),
|
||||||
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
|
error_message=trace_error_message,
|
||||||
status_code=ctx.status_code,
|
status_code=ctx.status_code,
|
||||||
latency_ms=response_time_ms,
|
latency_ms=response_time_ms,
|
||||||
extra_data={
|
extra_data={
|
||||||
@@ -1426,17 +1475,22 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||||||
actual_request_body = ctx.provider_request_body or original_request_body
|
actual_request_body = ctx.provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 失败时返回给客户端的是 JSON 错误响应
|
||||||
|
client_response_headers = {"content-type": "application/json"}
|
||||||
|
|
||||||
await self.telemetry.record_failure(
|
await self.telemetry.record_failure(
|
||||||
provider=ctx.provider_name or "unknown",
|
provider=ctx.provider_name or "unknown",
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=extract_error_message(error),
|
error_message=extract_client_error_message(error),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=True,
|
is_stream=True,
|
||||||
api_format=ctx.api_format,
|
api_format=ctx.api_format,
|
||||||
provider_request_headers=ctx.provider_request_headers,
|
provider_request_headers=ctx.provider_request_headers,
|
||||||
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model=ctx.mapped_model,
|
target_model=ctx.mapped_model,
|
||||||
)
|
)
|
||||||
@@ -1534,72 +1588,99 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
|
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 HTTP 客户端(支持代理配置)
|
# 获取复用的 HTTP 客户端(支持代理配置,从 Provider 读取)
|
||||||
# endpoint.timeout 作为整体请求超时
|
# 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
|
||||||
from src.clients.http_client import HTTPClientPool
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
request_timeout = float(endpoint.timeout or 300)
|
request_timeout = float(provider.timeout or 300)
|
||||||
http_client = HTTPClientPool.create_client_with_proxy(
|
http_client = await HTTPClientPool.get_proxy_client(
|
||||||
proxy_config=endpoint.proxy,
|
proxy_config=provider.proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注意:不使用 async with,因为复用的客户端不应该被关闭
|
||||||
|
# 超时通过 timeout 参数控制
|
||||||
|
resp = await http_client.post(
|
||||||
|
url,
|
||||||
|
json=provider_payload,
|
||||||
|
headers=provider_headers,
|
||||||
timeout=httpx.Timeout(request_timeout),
|
timeout=httpx.Timeout(request_timeout),
|
||||||
)
|
)
|
||||||
async with http_client:
|
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
response_headers = dict(resp.headers)
|
response_headers = dict(resp.headers)
|
||||||
|
|
||||||
if resp.status_code == 401:
|
if resp.status_code == 401:
|
||||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
raise ProviderAuthException(str(provider.name))
|
||||||
elif resp.status_code == 429:
|
elif resp.status_code == 429:
|
||||||
raise ProviderRateLimitException(
|
raise ProviderRateLimitException(
|
||||||
f"提供商速率限制: {provider.name}",
|
"请求过于频繁,请稍后重试",
|
||||||
provider_name=str(provider.name),
|
provider_name=str(provider.name),
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||||
)
|
)
|
||||||
elif resp.status_code >= 500:
|
elif resp.status_code >= 500:
|
||||||
error_text = resp.text
|
error_text = resp.text
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
f"上游服务暂时不可用 (HTTP {resp.status_code})",
|
||||||
provider_name=str(provider.name),
|
provider_name=str(provider.name),
|
||||||
upstream_status=resp.status_code,
|
upstream_status=resp.status_code,
|
||||||
upstream_response=error_text,
|
upstream_response=error_text,
|
||||||
)
|
)
|
||||||
elif 300 <= resp.status_code < 400:
|
elif 300 <= resp.status_code < 400:
|
||||||
redirect_url = resp.headers.get("location", "unknown")
|
redirect_url = resp.headers.get("location", "unknown")
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
|
"上游服务返回重定向响应",
|
||||||
)
|
provider_name=str(provider.name),
|
||||||
elif resp.status_code != 200:
|
upstream_status=resp.status_code,
|
||||||
error_text = resp.text
|
upstream_response=f"重定向 {resp.status_code} -> {redirect_url}",
|
||||||
raise ProviderNotAvailableException(
|
)
|
||||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
elif resp.status_code != 200:
|
||||||
provider_name=str(provider.name),
|
error_text = resp.text
|
||||||
upstream_status=resp.status_code,
|
raise ProviderNotAvailableException(
|
||||||
upstream_response=error_text,
|
f"上游服务返回错误 (HTTP {resp.status_code})",
|
||||||
)
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=error_text,
|
||||||
|
)
|
||||||
|
|
||||||
# 安全解析 JSON 响应,处理可能的编码错误
|
# 安全解析 JSON 响应,处理可能的编码错误
|
||||||
|
try:
|
||||||
|
response_json = resp.json()
|
||||||
|
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||||
|
# 获取原始响应内容用于调试(存入 upstream_response)
|
||||||
|
content_type = resp.headers.get("content-type", "unknown")
|
||||||
|
content_encoding = resp.headers.get("content-encoding", "none")
|
||||||
|
raw_content = ""
|
||||||
try:
|
try:
|
||||||
response_json = resp.json()
|
raw_content = resp.text[:500] if resp.text else "(empty)"
|
||||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
except Exception:
|
||||||
# 记录原始响应信息用于调试
|
try:
|
||||||
content_type = resp.headers.get("content-type", "unknown")
|
raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
|
||||||
content_encoding = resp.headers.get("content-encoding", "none")
|
except Exception:
|
||||||
logger.error(
|
raw_content = "(unable to read)"
|
||||||
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
logger.error(
|
||||||
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||||||
f"响应长度: {len(resp.content)} bytes"
|
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
||||||
)
|
f"响应长度: {len(resp.content)} bytes, 原始内容: {raw_content}"
|
||||||
raise ProviderNotAvailableException(
|
)
|
||||||
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
|
# 判断错误类型,生成友好的客户端错误消息(不暴露提供商信息)
|
||||||
)
|
if raw_content == "(empty)" or not raw_content.strip():
|
||||||
|
client_message = "上游服务返回了空响应"
|
||||||
|
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
|
||||||
|
client_message = "上游服务返回了非预期的响应格式"
|
||||||
|
else:
|
||||||
|
client_message = "上游服务返回了无效的响应"
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
client_message,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=raw_content,
|
||||||
|
)
|
||||||
|
|
||||||
# 提取 Provider 响应元数据(子类可覆盖)
|
# 提取 Provider 响应元数据(子类可覆盖)
|
||||||
response_metadata_result = self._extract_response_metadata(response_json)
|
response_metadata_result = self._extract_response_metadata(response_json)
|
||||||
|
|
||||||
return response_json if isinstance(response_json, dict) else {}
|
return response_json if isinstance(response_json, dict) else {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析能力需求
|
# 解析能力需求
|
||||||
@@ -1663,6 +1744,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||||||
actual_request_body = provider_request_body or original_request_body
|
actual_request_body = provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 非流式成功时,返回给客户端的是提供商响应头(透传)
|
||||||
|
client_response_headers = dict(response_headers) if response_headers else {}
|
||||||
|
client_response_headers["content-type"] = "application/json"
|
||||||
|
|
||||||
total_cost = await self.telemetry.record_success(
|
total_cost = await self.telemetry.record_success(
|
||||||
provider=provider_name,
|
provider=provider_name,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -1673,6 +1758,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_json,
|
response_body=response_json,
|
||||||
cache_creation_tokens=cache_creation_tokens,
|
cache_creation_tokens=cache_creation_tokens,
|
||||||
cache_read_tokens=cached_tokens,
|
cache_read_tokens=cached_tokens,
|
||||||
@@ -1691,7 +1777,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
|
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
|
||||||
|
|
||||||
return JSONResponse(status_code=status_code, content=response_json)
|
# 透传提供商的响应头
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status_code,
|
||||||
|
content=response_json,
|
||||||
|
headers=response_headers if response_headers else None,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
||||||
@@ -1707,17 +1798,27 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||||||
actual_request_body = provider_request_body or original_request_body
|
actual_request_body = provider_request_body or original_request_body
|
||||||
|
|
||||||
|
# 尝试从异常中提取响应头
|
||||||
|
error_response_headers: Dict[str, str] = {}
|
||||||
|
if isinstance(e, ProviderRateLimitException) and e.response_headers:
|
||||||
|
error_response_headers = e.response_headers
|
||||||
|
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
|
||||||
|
error_response_headers = dict(e.response.headers)
|
||||||
|
|
||||||
await self.telemetry.record_failure(
|
await self.telemetry.record_failure(
|
||||||
provider=provider_name or "unknown",
|
provider=provider_name or "unknown",
|
||||||
model=model,
|
model=model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=extract_error_message(e),
|
error_message=extract_client_error_message(e),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=False,
|
is_stream=False,
|
||||||
api_format=api_format,
|
api_format=api_format,
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
|
response_headers=error_response_headers,
|
||||||
|
# 非流式失败返回给客户端的是 JSON 错误响应
|
||||||
|
client_response_headers={"content-type": "application/json"},
|
||||||
# 模型映射信息
|
# 模型映射信息
|
||||||
target_model=mapped_model_result,
|
target_model=mapped_model_result,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ def build_safe_headers(
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
# 保持向后兼容的run_endpoint_check函数(使用新架构)
|
|
||||||
async def run_endpoint_check(
|
async def run_endpoint_check(
|
||||||
*,
|
*,
|
||||||
client: httpx.AsyncClient, # 保持兼容性,但内部不使用
|
client: httpx.AsyncClient, # 保持兼容性,但内部不使用
|
||||||
@@ -176,10 +175,16 @@ async def _calculate_and_record_usage(
|
|||||||
logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}")
|
logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}")
|
||||||
return {"error": "Provider API Key not found"}
|
return {"error": "Provider API Key not found"}
|
||||||
|
|
||||||
# 获取Provider Endpoint信息
|
# 获取Provider Endpoint信息(通过 api_format 查找)
|
||||||
provider_endpoint = None
|
provider_endpoint = None
|
||||||
if provider_api_key.endpoint_id:
|
if api_format and provider_api_key.provider_id:
|
||||||
provider_endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == provider_api_key.endpoint_id).first()
|
from src.models.database import Provider
|
||||||
|
provider = db.query(Provider).filter(Provider.id == provider_api_key.provider_id).first()
|
||||||
|
if provider:
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
if ep.api_format == api_format and ep.is_active:
|
||||||
|
provider_endpoint = ep
|
||||||
|
break
|
||||||
|
|
||||||
# 获取用户的API Key(用于记录关联,即使实际使用的是Provider API Key)
|
# 获取用户的API Key(用于记录关联,即使实际使用的是Provider API Key)
|
||||||
user_api_key = None
|
user_api_key = None
|
||||||
|
|||||||
@@ -61,11 +61,13 @@ class StreamContext:
|
|||||||
|
|
||||||
# 响应状态
|
# 响应状态
|
||||||
status_code: int = 200
|
status_code: int = 200
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None # 客户端友好的错误消息
|
||||||
|
upstream_response: Optional[str] = None # 原始 Provider 响应(用于请求链路追踪)
|
||||||
has_completion: bool = False
|
has_completion: bool = False
|
||||||
|
|
||||||
# 请求/响应数据
|
# 请求/响应数据
|
||||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
response_headers: Dict[str, str] = field(default_factory=dict) # 提供商响应头
|
||||||
|
client_response_headers: Dict[str, str] = field(default_factory=dict) # 返回给客户端的响应头
|
||||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||||
provider_request_body: Optional[Dict[str, Any]] = None
|
provider_request_body: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@@ -97,9 +99,11 @@ class StreamContext:
|
|||||||
self.cached_tokens = 0
|
self.cached_tokens = 0
|
||||||
self.cache_creation_tokens = 0
|
self.cache_creation_tokens = 0
|
||||||
self.error_message = None
|
self.error_message = None
|
||||||
|
self.upstream_response = None
|
||||||
self.status_code = 200
|
self.status_code = 200
|
||||||
self.first_byte_time_ms = None
|
self.first_byte_time_ms = None
|
||||||
self.response_headers = {}
|
self.response_headers = {}
|
||||||
|
self.client_response_headers = {}
|
||||||
self.provider_request_headers = {}
|
self.provider_request_headers = {}
|
||||||
self.provider_request_body = None
|
self.provider_request_body = None
|
||||||
self.response_id = None
|
self.response_id = None
|
||||||
@@ -174,10 +178,24 @@ class StreamContext:
|
|||||||
):
|
):
|
||||||
self.cache_creation_tokens = cache_creation_tokens
|
self.cache_creation_tokens = cache_creation_tokens
|
||||||
|
|
||||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
def mark_failed(
|
||||||
"""标记请求失败"""
|
self,
|
||||||
|
status_code: int,
|
||||||
|
error_message: str,
|
||||||
|
upstream_response: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
标记请求失败
|
||||||
|
|
||||||
|
Args:
|
||||||
|
status_code: HTTP 状态码
|
||||||
|
error_message: 客户端友好的错误消息
|
||||||
|
upstream_response: 原始 Provider 响应(用于请求链路追踪)
|
||||||
|
"""
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.error_message = error_message
|
self.error_message = error_message
|
||||||
|
if upstream_response:
|
||||||
|
self.upstream_response = upstream_response
|
||||||
|
|
||||||
def record_first_byte_time(self, start_time: float) -> None:
|
def record_first_byte_time(self, start_time: float) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -251,8 +251,10 @@ class StreamProcessor:
|
|||||||
f"base_url={endpoint.base_url}"
|
f"base_url={endpoint.base_url}"
|
||||||
)
|
)
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
"上游服务返回了非预期的响应格式",
|
||||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=200,
|
||||||
|
upstream_response=line[:500] if line else "(empty)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 跳过空行和注释行
|
# 跳过空行和注释行
|
||||||
|
|||||||
@@ -154,6 +154,14 @@ class StreamTelemetryRecorder:
|
|||||||
response_time_ms: int,
|
response_time_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录成功的请求"""
|
"""记录成功的请求"""
|
||||||
|
# 流式成功时,返回给客户端的是提供商响应头 + SSE 必需头
|
||||||
|
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
|
||||||
|
client_response_headers.update({
|
||||||
|
"Cache-Control": "no-cache, no-transform",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"content-type": "text/event-stream",
|
||||||
|
})
|
||||||
|
|
||||||
await telemetry.record_success(
|
await telemetry.record_success(
|
||||||
provider=ctx.provider_name or "unknown",
|
provider=ctx.provider_name or "unknown",
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
@@ -165,6 +173,7 @@ class StreamTelemetryRecorder:
|
|||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
response_headers=ctx.response_headers,
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
response_body=response_body,
|
response_body=response_body,
|
||||||
cache_creation_tokens=ctx.cache_creation_tokens,
|
cache_creation_tokens=ctx.cache_creation_tokens,
|
||||||
cache_read_tokens=ctx.cached_tokens,
|
cache_read_tokens=ctx.cached_tokens,
|
||||||
@@ -190,6 +199,9 @@ class StreamTelemetryRecorder:
|
|||||||
response_time_ms: int,
|
response_time_ms: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录失败的请求"""
|
"""记录失败的请求"""
|
||||||
|
# 失败时返回给客户端的是 JSON 错误响应,如果没有设置则使用默认值
|
||||||
|
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
|
||||||
|
|
||||||
await telemetry.record_failure(
|
await telemetry.record_failure(
|
||||||
provider=ctx.provider_name or "unknown",
|
provider=ctx.provider_name or "unknown",
|
||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
@@ -206,6 +218,8 @@ class StreamTelemetryRecorder:
|
|||||||
cache_creation_tokens=ctx.cache_creation_tokens,
|
cache_creation_tokens=ctx.cache_creation_tokens,
|
||||||
cache_read_tokens=ctx.cached_tokens,
|
cache_read_tokens=ctx.cached_tokens,
|
||||||
response_body=response_body,
|
response_body=response_body,
|
||||||
|
response_headers=ctx.response_headers,
|
||||||
|
client_response_headers=client_response_headers,
|
||||||
target_model=ctx.mapped_model,
|
target_model=ctx.mapped_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -239,11 +253,13 @@ class StreamTelemetryRecorder:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
|
error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
|
||||||
|
# 请求链路追踪使用 upstream_response(原始响应),回退到 error_message(友好消息)
|
||||||
|
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
|
||||||
RequestCandidateService.mark_candidate_failed(
|
RequestCandidateService.mark_candidate_failed(
|
||||||
db=db,
|
db=db,
|
||||||
candidate_id=ctx.attempt_id,
|
candidate_id=ctx.attempt_id,
|
||||||
error_type=error_type,
|
error_type=error_type,
|
||||||
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
|
error_message=trace_error_message,
|
||||||
status_code=ctx.status_code,
|
status_code=ctx.status_code,
|
||||||
latency_ms=response_time_ms,
|
latency_ms=response_time_ms,
|
||||||
extra_data={
|
extra_data={
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
|||||||
3. **旧格式(优先级第三)**:
|
3. **旧格式(优先级第三)**:
|
||||||
usage.cache_creation_input_tokens
|
usage.cache_creation_input_tokens
|
||||||
|
|
||||||
优先使用嵌套格式,如果嵌套格式字段存在但值为 0,则智能 fallback 到旧格式。
|
说明:
|
||||||
扁平格式和嵌套格式互斥,按顺序检查。
|
- 只要检测到新格式字段(嵌套/扁平),即视为权威来源:哪怕值为 0 也不回退到旧字段。
|
||||||
|
- 仅当新格式字段完全不存在时,才回退到旧字段。
|
||||||
|
- 扁平格式和嵌套格式互斥,按顺序检查。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
usage: API 响应中的 usage 字典
|
usage: API 响应中的 usage 字典
|
||||||
@@ -37,27 +39,20 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
|||||||
"""
|
"""
|
||||||
# 1. 检查嵌套格式(最新格式)
|
# 1. 检查嵌套格式(最新格式)
|
||||||
cache_creation = usage.get("cache_creation")
|
cache_creation = usage.get("cache_creation")
|
||||||
if isinstance(cache_creation, dict):
|
has_nested_format = isinstance(cache_creation, dict) and (
|
||||||
|
"ephemeral_5m_input_tokens" in cache_creation
|
||||||
|
or "ephemeral_1h_input_tokens" in cache_creation
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_nested_format:
|
||||||
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
|
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
|
||||||
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
|
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
|
||||||
total = cache_5m + cache_1h
|
total = cache_5m + cache_1h
|
||||||
|
|
||||||
if total > 0:
|
logger.debug(
|
||||||
logger.debug(
|
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||||
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
)
|
||||||
)
|
return total
|
||||||
return total
|
|
||||||
|
|
||||||
# 嵌套格式存在但为 0,fallback 到旧格式
|
|
||||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
|
||||||
if old_format > 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Nested cache_creation is 0, using old format: {old_format}"
|
|
||||||
)
|
|
||||||
return old_format
|
|
||||||
|
|
||||||
# 都是 0,返回 0
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# 2. 检查扁平新格式
|
# 2. 检查扁平新格式
|
||||||
has_flat_format = (
|
has_flat_format = (
|
||||||
@@ -70,22 +65,10 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
|||||||
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
|
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
|
||||||
total = cache_5m + cache_1h
|
total = cache_5m + cache_1h
|
||||||
|
|
||||||
if total > 0:
|
logger.debug(
|
||||||
logger.debug(
|
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||||
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
)
|
||||||
)
|
return total
|
||||||
return total
|
|
||||||
|
|
||||||
# 扁平格式存在但为 0,fallback 到旧格式
|
|
||||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
|
||||||
if old_format > 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Flat cache_creation is 0, using old format: {old_format}"
|
|
||||||
)
|
|
||||||
return old_format
|
|
||||||
|
|
||||||
# 都是 0,返回 0
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# 3. 回退到旧格式
|
# 3. 回退到旧格式
|
||||||
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||||
@@ -173,8 +156,10 @@ def check_prefetched_response_error(
|
|||||||
f"base_url={base_url}"
|
f"base_url={base_url}"
|
||||||
)
|
)
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
|
"上游服务返回了非预期的响应格式",
|
||||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
provider_name=provider_name,
|
||||||
|
upstream_status=200,
|
||||||
|
upstream_response=stripped.decode("utf-8", errors="replace")[:500],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 纯 JSON(可能无换行/多行 JSON)
|
# 纯 JSON(可能无换行/多行 JSON)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
||||||
|
|
||||||
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
||||||
验证新架构的有效性:代码量从数百行减少到 ~80 行。
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|||||||
@@ -102,7 +102,6 @@ async def get_public_models(
|
|||||||
- id: 模型唯一标识符
|
- id: 模型唯一标识符
|
||||||
- provider_id: 所属提供商 ID
|
- provider_id: 所属提供商 ID
|
||||||
- provider_name: 提供商名称
|
- provider_name: 提供商名称
|
||||||
- provider_display_name: 提供商显示名称
|
|
||||||
- name: 模型统一名称(优先使用 GlobalModel 名称)
|
- name: 模型统一名称(优先使用 GlobalModel 名称)
|
||||||
- display_name: 模型显示名称
|
- display_name: 模型显示名称
|
||||||
- description: 模型描述信息
|
- description: 模型描述信息
|
||||||
@@ -300,10 +299,20 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
|||||||
providers = query.offset(self.skip).limit(self.limit).all()
|
providers = query.offset(self.skip).limit(self.limit).all()
|
||||||
result = []
|
result = []
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
models_count = db.query(Model).filter(Model.provider_id == provider.id).count()
|
models_count = (
|
||||||
|
db.query(Model)
|
||||||
|
.filter(Model.provider_id == provider.id, Model.global_model_id.isnot(None))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
active_models_count = (
|
active_models_count = (
|
||||||
db.query(Model)
|
db.query(Model)
|
||||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
.filter(
|
||||||
|
and_(
|
||||||
|
Model.provider_id == provider.id,
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Model.global_model_id.isnot(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||||
@@ -313,7 +322,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
|||||||
provider_data = PublicProviderResponse(
|
provider_data = PublicProviderResponse(
|
||||||
id=provider.id,
|
id=provider.id,
|
||||||
name=provider.name,
|
name=provider.name,
|
||||||
display_name=provider.display_name,
|
|
||||||
description=provider.description,
|
description=provider.description,
|
||||||
is_active=provider.is_active,
|
is_active=provider.is_active,
|
||||||
provider_priority=provider.provider_priority,
|
provider_priority=provider.provider_priority,
|
||||||
@@ -342,7 +350,13 @@ class PublicModelsAdapter(PublicApiAdapter):
|
|||||||
db.query(Model, Provider)
|
db.query(Model, Provider)
|
||||||
.options(joinedload(Model.global_model))
|
.options(joinedload(Model.global_model))
|
||||||
.join(Provider)
|
.join(Provider)
|
||||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
.filter(
|
||||||
|
and_(
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
Model.global_model_id.isnot(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if self.provider_id is not None:
|
if self.provider_id is not None:
|
||||||
query = query.filter(Model.provider_id == self.provider_id)
|
query = query.filter(Model.provider_id == self.provider_id)
|
||||||
@@ -357,7 +371,6 @@ class PublicModelsAdapter(PublicApiAdapter):
|
|||||||
id=model.id,
|
id=model.id,
|
||||||
provider_id=model.provider_id,
|
provider_id=model.provider_id,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
provider_display_name=provider.display_name,
|
|
||||||
name=unified_name,
|
name=unified_name,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||||
@@ -386,7 +399,13 @@ class PublicStatsAdapter(PublicApiAdapter):
|
|||||||
active_models = (
|
active_models = (
|
||||||
db.query(Model)
|
db.query(Model)
|
||||||
.join(Provider)
|
.join(Provider)
|
||||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
.filter(
|
||||||
|
and_(
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
Model.global_model_id.isnot(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
.count()
|
.count()
|
||||||
)
|
)
|
||||||
formats = (
|
formats = (
|
||||||
@@ -418,7 +437,13 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
|||||||
.options(joinedload(Model.global_model))
|
.options(joinedload(Model.global_model))
|
||||||
.join(Provider)
|
.join(Provider)
|
||||||
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
|
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
.filter(
|
||||||
|
and_(
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
Model.global_model_id.isnot(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
search_filter = (
|
search_filter = (
|
||||||
Model.provider_model_name.ilike(f"%{self.query}%")
|
Model.provider_model_name.ilike(f"%{self.query}%")
|
||||||
@@ -439,7 +464,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
|||||||
id=model.id,
|
id=model.id,
|
||||||
provider_id=model.provider_id,
|
provider_id=model.provider_id,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
provider_display_name=provider.display_name,
|
|
||||||
name=unified_name,
|
name=unified_name,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ def _serialize_provider(
|
|||||||
provider_data: Dict[str, Any] = {
|
provider_data: Dict[str, Any] = {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
"is_active": provider.is_active,
|
"is_active": provider.is_active,
|
||||||
"provider_priority": provider.provider_priority,
|
"provider_priority": provider.provider_priority,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1023,7 +1023,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
|||||||
{
|
{
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
|
||||||
"description": provider.description,
|
"description": provider.description,
|
||||||
"provider_priority": provider.provider_priority,
|
"provider_priority": provider.provider_priority,
|
||||||
"endpoints": endpoints_data,
|
"endpoints": endpoints_data,
|
||||||
|
|||||||
@@ -1,10 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
全局HTTP客户端池管理
|
全局HTTP客户端池管理
|
||||||
避免每次请求都创建新的AsyncClient,提高性能
|
避免每次请求都创建新的AsyncClient,提高性能
|
||||||
|
|
||||||
|
性能优化说明:
|
||||||
|
1. 默认客户端:无代理场景,全局复用单一客户端
|
||||||
|
2. 代理客户端缓存:相同代理配置复用同一客户端,避免重复创建
|
||||||
|
3. 连接池复用:Keep-alive 连接减少 TCP 握手开销
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Tuple
|
||||||
from urllib.parse import quote, urlparse
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -12,6 +20,32 @@ import httpx
|
|||||||
from src.config import config
|
from src.config import config
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
# 模块级锁,避免类属性延迟初始化的竞态条件
|
||||||
|
_proxy_clients_lock = asyncio.Lock()
|
||||||
|
_default_client_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
|
||||||
|
"""
|
||||||
|
计算代理配置的缓存键
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存键字符串,无代理时返回 "__no_proxy__"
|
||||||
|
"""
|
||||||
|
if not proxy_config:
|
||||||
|
return "__no_proxy__"
|
||||||
|
|
||||||
|
# 构建代理 URL 作为缓存键的基础
|
||||||
|
proxy_url = build_proxy_url(proxy_config)
|
||||||
|
if not proxy_url:
|
||||||
|
return "__no_proxy__"
|
||||||
|
|
||||||
|
# 使用 MD5 哈希来避免过长的键名
|
||||||
|
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
|
||||||
|
|
||||||
|
|
||||||
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -61,11 +95,20 @@ class HTTPClientPool:
|
|||||||
全局HTTP客户端池单例
|
全局HTTP客户端池单例
|
||||||
|
|
||||||
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
|
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
|
||||||
|
|
||||||
|
性能优化:
|
||||||
|
1. 默认客户端:无代理场景复用
|
||||||
|
2. 代理客户端缓存:相同代理配置复用同一客户端
|
||||||
|
3. LRU 淘汰:代理客户端超过上限时淘汰最久未使用的
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance: Optional["HTTPClientPool"] = None
|
_instance: Optional["HTTPClientPool"] = None
|
||||||
_default_client: Optional[httpx.AsyncClient] = None
|
_default_client: Optional[httpx.AsyncClient] = None
|
||||||
_clients: Dict[str, httpx.AsyncClient] = {}
|
_clients: Dict[str, httpx.AsyncClient] = {}
|
||||||
|
# 代理客户端缓存:{cache_key: (client, last_used_time)}
|
||||||
|
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
|
||||||
|
# 代理客户端缓存上限(避免内存泄漏)
|
||||||
|
_max_proxy_clients: int = 50
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
@@ -73,12 +116,50 @@ class HTTPClientPool:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_default_client(cls) -> httpx.AsyncClient:
|
async def get_default_client_async(cls) -> httpx.AsyncClient:
|
||||||
"""
|
"""
|
||||||
获取默认的HTTP客户端
|
获取默认的HTTP客户端(异步线程安全版本)
|
||||||
|
|
||||||
用于大多数HTTP请求,具有合理的默认配置
|
用于大多数HTTP请求,具有合理的默认配置
|
||||||
"""
|
"""
|
||||||
|
if cls._default_client is not None:
|
||||||
|
return cls._default_client
|
||||||
|
|
||||||
|
async with _default_client_lock:
|
||||||
|
# 双重检查,避免重复创建
|
||||||
|
if cls._default_client is None:
|
||||||
|
cls._default_client = httpx.AsyncClient(
|
||||||
|
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||||
|
verify=True, # 启用SSL验证
|
||||||
|
timeout=httpx.Timeout(
|
||||||
|
connect=config.http_connect_timeout,
|
||||||
|
read=config.http_read_timeout,
|
||||||
|
write=config.http_write_timeout,
|
||||||
|
pool=config.http_pool_timeout,
|
||||||
|
),
|
||||||
|
limits=httpx.Limits(
|
||||||
|
max_connections=config.http_max_connections,
|
||||||
|
max_keepalive_connections=config.http_keepalive_connections,
|
||||||
|
keepalive_expiry=config.http_keepalive_expiry,
|
||||||
|
),
|
||||||
|
follow_redirects=True, # 跟随重定向
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"全局HTTP客户端池已初始化: "
|
||||||
|
f"max_connections={config.http_max_connections}, "
|
||||||
|
f"keepalive={config.http_keepalive_connections}, "
|
||||||
|
f"keepalive_expiry={config.http_keepalive_expiry}s"
|
||||||
|
)
|
||||||
|
return cls._default_client
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_client(cls) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
获取默认的HTTP客户端(同步版本,向后兼容)
|
||||||
|
|
||||||
|
⚠️ 注意:此方法在高并发首次调用时可能存在竞态条件,
|
||||||
|
推荐使用 get_default_client_async() 异步版本。
|
||||||
|
"""
|
||||||
if cls._default_client is None:
|
if cls._default_client is None:
|
||||||
cls._default_client = httpx.AsyncClient(
|
cls._default_client = httpx.AsyncClient(
|
||||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||||
@@ -90,13 +171,18 @@ class HTTPClientPool:
|
|||||||
pool=config.http_pool_timeout,
|
pool=config.http_pool_timeout,
|
||||||
),
|
),
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=100, # 最大连接数
|
max_connections=config.http_max_connections,
|
||||||
max_keepalive_connections=20, # 最大保活连接数
|
max_keepalive_connections=config.http_keepalive_connections,
|
||||||
keepalive_expiry=30.0, # 保活过期时间(秒)
|
keepalive_expiry=config.http_keepalive_expiry,
|
||||||
),
|
),
|
||||||
follow_redirects=True, # 跟随重定向
|
follow_redirects=True, # 跟随重定向
|
||||||
)
|
)
|
||||||
logger.info("全局HTTP客户端池已初始化")
|
logger.info(
|
||||||
|
f"全局HTTP客户端池已初始化: "
|
||||||
|
f"max_connections={config.http_max_connections}, "
|
||||||
|
f"keepalive={config.http_keepalive_connections}, "
|
||||||
|
f"keepalive_expiry={config.http_keepalive_expiry}s"
|
||||||
|
)
|
||||||
return cls._default_client
|
return cls._default_client
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -130,6 +216,101 @@ class HTTPClientPool:
|
|||||||
|
|
||||||
return cls._clients[name]
|
return cls._clients[name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
|
||||||
|
"""获取代理客户端缓存锁(模块级单例,避免竞态条件)"""
|
||||||
|
return _proxy_clients_lock
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _evict_lru_proxy_client(cls) -> None:
|
||||||
|
"""淘汰最久未使用的代理客户端"""
|
||||||
|
if len(cls._proxy_clients) < cls._max_proxy_clients:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 找到最久未使用的客户端
|
||||||
|
oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1])
|
||||||
|
old_client, _ = cls._proxy_clients.pop(oldest_key)
|
||||||
|
|
||||||
|
# 异步关闭旧客户端
|
||||||
|
try:
|
||||||
|
await old_client.aclose()
|
||||||
|
logger.debug(f"淘汰代理客户端: {oldest_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关闭代理客户端失败: {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_proxy_client(
|
||||||
|
cls,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
获取代理客户端(带缓存复用)
|
||||||
|
|
||||||
|
相同代理配置会复用同一个客户端,大幅减少连接建立开销。
|
||||||
|
注意:返回的客户端使用默认超时配置,如需自定义超时请在请求时传递 timeout 参数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典,包含 url, username, password
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
可复用的 httpx.AsyncClient 实例
|
||||||
|
"""
|
||||||
|
cache_key = _compute_proxy_cache_key(proxy_config)
|
||||||
|
|
||||||
|
# 无代理时返回默认客户端
|
||||||
|
if cache_key == "__no_proxy__":
|
||||||
|
return await cls.get_default_client_async()
|
||||||
|
|
||||||
|
lock = cls._get_proxy_clients_lock()
|
||||||
|
async with lock:
|
||||||
|
# 检查缓存
|
||||||
|
if cache_key in cls._proxy_clients:
|
||||||
|
client, _ = cls._proxy_clients[cache_key]
|
||||||
|
# 健康检查:如果客户端已关闭,移除并重新创建
|
||||||
|
if client.is_closed:
|
||||||
|
del cls._proxy_clients[cache_key]
|
||||||
|
logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")
|
||||||
|
else:
|
||||||
|
# 更新最后使用时间
|
||||||
|
cls._proxy_clients[cache_key] = (client, time.time())
|
||||||
|
return client
|
||||||
|
|
||||||
|
# 淘汰旧客户端(如果超过上限)
|
||||||
|
await cls._evict_lru_proxy_client()
|
||||||
|
|
||||||
|
# 创建新客户端(使用默认超时,请求时可覆盖)
|
||||||
|
client_config: Dict[str, Any] = {
|
||||||
|
"http2": False,
|
||||||
|
"verify": True,
|
||||||
|
"follow_redirects": True,
|
||||||
|
"limits": httpx.Limits(
|
||||||
|
max_connections=config.http_max_connections,
|
||||||
|
max_keepalive_connections=config.http_keepalive_connections,
|
||||||
|
keepalive_expiry=config.http_keepalive_expiry,
|
||||||
|
),
|
||||||
|
"timeout": httpx.Timeout(
|
||||||
|
connect=config.http_connect_timeout,
|
||||||
|
read=config.http_read_timeout,
|
||||||
|
write=config.http_write_timeout,
|
||||||
|
pool=config.http_pool_timeout,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加代理配置
|
||||||
|
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||||
|
if proxy_url:
|
||||||
|
client_config["proxy"] = proxy_url
|
||||||
|
|
||||||
|
client = httpx.AsyncClient(**client_config)
|
||||||
|
cls._proxy_clients[cache_key] = (client, time.time())
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
|
||||||
|
f"缓存数量: {len(cls._proxy_clients)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def close_all(cls):
|
async def close_all(cls):
|
||||||
"""关闭所有HTTP客户端"""
|
"""关闭所有HTTP客户端"""
|
||||||
@@ -143,6 +324,16 @@ class HTTPClientPool:
|
|||||||
logger.debug(f"命名HTTP客户端已关闭: {name}")
|
logger.debug(f"命名HTTP客户端已关闭: {name}")
|
||||||
|
|
||||||
cls._clients.clear()
|
cls._clients.clear()
|
||||||
|
|
||||||
|
# 关闭代理客户端缓存
|
||||||
|
for cache_key, (client, _) in cls._proxy_clients.items():
|
||||||
|
try:
|
||||||
|
await client.aclose()
|
||||||
|
logger.debug(f"代理客户端已关闭: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关闭代理客户端失败: {e}")
|
||||||
|
|
||||||
|
cls._proxy_clients.clear()
|
||||||
logger.info("所有HTTP客户端已关闭")
|
logger.info("所有HTTP客户端已关闭")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -185,13 +376,15 @@ class HTTPClientPool:
|
|||||||
"""
|
"""
|
||||||
创建带代理配置的HTTP客户端
|
创建带代理配置的HTTP客户端
|
||||||
|
|
||||||
|
⚠️ 性能警告:此方法每次都创建新客户端,推荐使用 get_proxy_client() 复用连接。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
proxy_config: 代理配置字典,包含 url, username, password
|
proxy_config: 代理配置字典,包含 url, username, password
|
||||||
timeout: 超时配置
|
timeout: 超时配置
|
||||||
**kwargs: 其他 httpx.AsyncClient 配置参数
|
**kwargs: 其他 httpx.AsyncClient 配置参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
配置好的 httpx.AsyncClient 实例
|
配置好的 httpx.AsyncClient 实例(调用者需要负责关闭)
|
||||||
"""
|
"""
|
||||||
client_config: Dict[str, Any] = {
|
client_config: Dict[str, Any] = {
|
||||||
"http2": False,
|
"http2": False,
|
||||||
@@ -213,11 +406,21 @@ class HTTPClientPool:
|
|||||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||||
if proxy_url:
|
if proxy_url:
|
||||||
client_config["proxy"] = proxy_url
|
client_config["proxy"] = proxy_url
|
||||||
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
|
||||||
|
|
||||||
client_config.update(kwargs)
|
client_config.update(kwargs)
|
||||||
return httpx.AsyncClient(**client_config)
|
return httpx.AsyncClient(**client_config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pool_stats(cls) -> Dict[str, Any]:
|
||||||
|
"""获取连接池统计信息"""
|
||||||
|
return {
|
||||||
|
"default_client_active": cls._default_client is not None,
|
||||||
|
"named_clients_count": len(cls._clients),
|
||||||
|
"proxy_clients_count": len(cls._proxy_clients),
|
||||||
|
"max_proxy_clients": cls._max_proxy_clients,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# 便捷访问函数
|
# 便捷访问函数
|
||||||
def get_http_client() -> httpx.AsyncClient:
|
def get_http_client() -> httpx.AsyncClient:
|
||||||
|
|||||||
@@ -52,19 +52,33 @@ class StreamDefaults:
|
|||||||
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
|
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
|
||||||
|
|
||||||
|
|
||||||
class ConcurrencyDefaults:
|
class RPMDefaults:
|
||||||
"""并发控制默认值
|
"""RPM(每分钟请求数)限制默认值
|
||||||
|
|
||||||
算法说明:边界记忆 + 渐进探测
|
算法说明:边界记忆 + 渐进探测
|
||||||
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
|
- 触发 429 时记录边界(last_rpm_peak),新限制 = 边界 - 1
|
||||||
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
|
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
|
||||||
- 这样可以快速收敛到真实限制附近,避免过度保守
|
- 这样可以快速收敛到真实限制附近,避免过度保守
|
||||||
|
|
||||||
|
初始值 50 RPM:
|
||||||
|
- 系统会根据实际使用自动调整
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
# 自适应 RPM 初始限制
|
||||||
INITIAL_LIMIT = 50
|
INITIAL_LIMIT = 50 # 每分钟 50 次请求
|
||||||
|
|
||||||
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
# === 内存模式 RPM 计数器配置 ===
|
||||||
|
# 内存模式下的最大条目限制(防止内存泄漏)
|
||||||
|
# 每个条目约占 100 字节,10000 条目 = ~1MB
|
||||||
|
# 计算依据:1000 Key × 5 API 格式 × 2 (buffer) = 10000
|
||||||
|
# 可通过环境变量 RPM_MAX_MEMORY_ENTRIES 覆盖
|
||||||
|
MAX_MEMORY_RPM_ENTRIES = 10000
|
||||||
|
|
||||||
|
# 内存使用告警阈值(达到此比例时记录警告日志)
|
||||||
|
# 可通过环境变量 RPM_MEMORY_WARNING_THRESHOLD 覆盖
|
||||||
|
MEMORY_WARNING_THRESHOLD = 0.6 # 60%
|
||||||
|
|
||||||
|
# 429错误后的冷却时间(分钟)- 在此期间不会增加 RPM 限制
|
||||||
COOLDOWN_AFTER_429_MINUTES = 5
|
COOLDOWN_AFTER_429_MINUTES = 5
|
||||||
|
|
||||||
# 探测间隔上限(分钟)- 用于长期探测策略
|
# 探测间隔上限(分钟)- 用于长期探测策略
|
||||||
@@ -86,30 +100,30 @@ class ConcurrencyDefaults:
|
|||||||
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
|
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
|
||||||
MIN_SAMPLES_FOR_DECISION = 5
|
MIN_SAMPLES_FOR_DECISION = 5
|
||||||
|
|
||||||
# 扩容步长 - 每次扩容增加的并发数
|
# 扩容步长 - 每次扩容增加的 RPM
|
||||||
INCREASE_STEP = 2
|
INCREASE_STEP = 5 # 每次增加 5 RPM
|
||||||
|
|
||||||
# 最大并发限制上限
|
# 最大 RPM 限制上限(不设上限,让系统自适应学习)
|
||||||
MAX_CONCURRENT_LIMIT = 200
|
MAX_RPM_LIMIT = 10000
|
||||||
|
|
||||||
# 最小并发限制下限
|
# 最小 RPM 限制下限
|
||||||
# 设置为 3 而不是 1,因为预留机制(10%预留给缓存用户)会导致
|
MIN_RPM_LIMIT = 5
|
||||||
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0,永远无法命中
|
|
||||||
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0),这是可接受的
|
# 缓存用户预留比例(默认 10%,新用户可用 90%)
|
||||||
MIN_CONCURRENT_LIMIT = 3
|
# 已被动态预留机制 (AdaptiveReservationDefaults) 替代,保留用于向后兼容
|
||||||
|
CACHE_RESERVATION_RATIO = 0.1
|
||||||
|
|
||||||
# === 探测性扩容参数 ===
|
# === 探测性扩容参数 ===
|
||||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||||
# 探测性扩容可以突破已知边界,尝试更高的并发
|
# 探测性扩容可以突破已知边界,尝试更高的 RPM
|
||||||
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
||||||
|
|
||||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||||
PROBE_INCREASE_MIN_REQUESTS = 10
|
PROBE_INCREASE_MIN_REQUESTS = 10
|
||||||
|
|
||||||
# === 缓存用户预留比例 ===
|
|
||||||
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
|
# 向后兼容别名
|
||||||
# 0.1 表示缓存用户预留 10%,新用户可用 90%
|
ConcurrencyDefaults = RPMDefaults
|
||||||
CACHE_RESERVATION_RATIO = 0.1
|
|
||||||
|
|
||||||
|
|
||||||
class CircuitBreakerDefaults:
|
class CircuitBreakerDefaults:
|
||||||
@@ -193,10 +207,19 @@ class AdaptiveReservationDefaults:
|
|||||||
|
|
||||||
|
|
||||||
class TimeoutDefaults:
|
class TimeoutDefaults:
|
||||||
"""超时配置默认值(秒)"""
|
"""超时配置默认值(秒)
|
||||||
|
|
||||||
# HTTP 请求默认超时
|
超时配置说明:
|
||||||
HTTP_REQUEST = 300 # 5分钟
|
- 全局默认值和 Provider 默认值统一为 120 秒
|
||||||
|
- 120 秒是 LLM API 的合理默认值:
|
||||||
|
* 大多数请求在 30 秒内完成
|
||||||
|
* 复杂推理(如 Claude extended thinking)可能需要 60-90 秒
|
||||||
|
* 120 秒足够覆盖大部分场景,同时避免线程池被长时间占用
|
||||||
|
- 如需更长超时,可在 Provider 级别单独配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
# HTTP 请求默认超时(与 Provider 默认值保持一致)
|
||||||
|
HTTP_REQUEST = 120 # 2分钟
|
||||||
|
|
||||||
# 数据库连接池获取超时
|
# 数据库连接池获取超时
|
||||||
DB_POOL = 30
|
DB_POOL = 30
|
||||||
|
|||||||
@@ -145,6 +145,24 @@ class Config:
|
|||||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||||
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
||||||
|
|
||||||
|
# HTTP 连接池配置
|
||||||
|
# HTTP_MAX_CONNECTIONS: 最大连接数,影响并发能力
|
||||||
|
# - 每个连接占用一个 socket,过多会耗尽系统资源
|
||||||
|
# - 默认根据 Worker 数量自动计算:单 Worker 200,多 Worker 按比例分配
|
||||||
|
# HTTP_KEEPALIVE_CONNECTIONS: 保活连接数,影响连接复用效率
|
||||||
|
# - 高频请求场景应该增大此值
|
||||||
|
# - 默认为 max_connections 的 30%(长连接场景更高效)
|
||||||
|
# HTTP_KEEPALIVE_EXPIRY: 保活过期时间(秒)
|
||||||
|
# - 过短会频繁重建连接,过长会占用资源
|
||||||
|
# - 默认 30 秒,生图等长连接场景可适当增大
|
||||||
|
self.http_max_connections = int(
|
||||||
|
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
|
||||||
|
)
|
||||||
|
self.http_keepalive_connections = int(
|
||||||
|
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
|
||||||
|
)
|
||||||
|
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
|
||||||
|
|
||||||
# 流式处理配置
|
# 流式处理配置
|
||||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||||
@@ -224,6 +242,53 @@ class Config:
|
|||||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||||
return self.db_pool_size
|
return self.db_pool_size
|
||||||
|
|
||||||
|
def _auto_http_max_connections(self) -> int:
|
||||||
|
"""
|
||||||
|
智能计算 HTTP 最大连接数
|
||||||
|
|
||||||
|
计算依据:
|
||||||
|
1. 系统 socket 资源有限(Linux 默认 ulimit -n 通常为 1024)
|
||||||
|
2. 多 Worker 部署时每个进程独立连接池
|
||||||
|
3. 需要为数据库连接、Redis 连接等预留资源
|
||||||
|
|
||||||
|
公式: base_connections / workers
|
||||||
|
- 单 Worker: 200 连接(适合开发/低负载)
|
||||||
|
- 多 Worker: 按比例分配,确保总数不超过系统限制
|
||||||
|
|
||||||
|
范围: 50 - 500
|
||||||
|
"""
|
||||||
|
# 基础连接数:假设系统可用 socket 约 800 个用于 HTTP
|
||||||
|
# (预留给 DB、Redis、内部服务等)
|
||||||
|
base_connections = 800
|
||||||
|
workers = max(self.worker_processes, 1)
|
||||||
|
|
||||||
|
# 每个 Worker 分配的连接数
|
||||||
|
per_worker = base_connections // workers
|
||||||
|
|
||||||
|
# 限制范围:最小 50(保证基本并发),最大 500(避免资源耗尽)
|
||||||
|
return max(50, min(per_worker, 500))
|
||||||
|
|
||||||
|
def _auto_http_keepalive_connections(self) -> int:
|
||||||
|
"""
|
||||||
|
智能计算 HTTP 保活连接数
|
||||||
|
|
||||||
|
计算依据:
|
||||||
|
1. 保活连接用于复用,减少 TCP 握手开销
|
||||||
|
2. 对于 API 网关场景,上游请求频繁,保活比例应较高
|
||||||
|
3. 生图等长连接场景,连接会被长时间占用
|
||||||
|
|
||||||
|
公式: max_connections * 0.3
|
||||||
|
- 30% 的比例在复用效率和资源占用间取得平衡
|
||||||
|
- 长连接场景建议手动调高到 50-70%
|
||||||
|
|
||||||
|
范围: 10 - max_connections
|
||||||
|
"""
|
||||||
|
# 保活连接数为最大连接数的 30%
|
||||||
|
keepalive = int(self.http_max_connections * 0.3)
|
||||||
|
|
||||||
|
# 最小 10 个保活连接,最大不超过 max_connections
|
||||||
|
return max(10, min(keepalive, self.http_max_connections))
|
||||||
|
|
||||||
def _parse_ttfb_timeout(self) -> float:
|
def _parse_ttfb_timeout(self) -> float:
|
||||||
"""
|
"""
|
||||||
解析 TTFB 超时配置,带错误处理和范围限制
|
解析 TTFB 超时配置,带错误处理和范围限制
|
||||||
|
|||||||
@@ -36,6 +36,12 @@ class CacheService:
|
|||||||
try:
|
try:
|
||||||
return json.loads(value)
|
return json.loads(value)
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
# 非 JSON 值:统一返回字符串,避免上层出现 bytes/str 混用
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
try:
|
||||||
|
return value.decode("utf-8")
|
||||||
|
except Exception:
|
||||||
|
return value
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -98,6 +104,48 @@ class CacheService:
|
|||||||
logger.warning(f"缓存删除失败: {key} - {e}")
|
logger.warning(f"缓存删除失败: {key} - {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_pattern(pattern: str, batch_size: int = 100) -> int:
|
||||||
|
"""
|
||||||
|
删除匹配模式的所有缓存
|
||||||
|
|
||||||
|
使用 SCAN 遍历并分批删除,避免阻塞 Redis
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 缓存键模式(支持 * 通配符)
|
||||||
|
batch_size: 每批删除的最大键数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的键数量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
redis = await get_redis_client(require_redis=False)
|
||||||
|
if not redis:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 使用 SCAN 遍历匹配的键
|
||||||
|
deleted_count = 0
|
||||||
|
cursor = 0
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)
|
||||||
|
if keys:
|
||||||
|
# 分批删除,避免单次删除过多键导致 Redis 阻塞
|
||||||
|
for i in range(0, len(keys), batch_size):
|
||||||
|
batch = keys[i : i + batch_size]
|
||||||
|
await redis.delete(*batch)
|
||||||
|
deleted_count += len(batch)
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.debug(f"缓存模式删除成功: {pattern}, 删除 {deleted_count} 个键")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"缓存模式删除失败: {pattern} - {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def exists(key: str) -> bool:
|
async def exists(key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,16 +7,19 @@ from typing import Optional
|
|||||||
|
|
||||||
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
|
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
从异常中提取错误消息,优先使用上游响应内容
|
从异常中提取错误消息,优先使用上游原始响应(用于链路追踪/调试)
|
||||||
|
|
||||||
|
此函数用于 RequestCandidate 表的 error_message 字段,
|
||||||
|
用于请求链路追踪中显示原始 Provider 响应。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error: 异常对象
|
error: 异常对象
|
||||||
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
|
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
错误消息字符串
|
错误消息字符串(原始 Provider 响应)
|
||||||
"""
|
"""
|
||||||
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
|
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误,用于调试)
|
||||||
upstream_response = getattr(error, "upstream_response", None)
|
upstream_response = getattr(error, "upstream_response", None)
|
||||||
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
|
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
|
||||||
return str(upstream_response)
|
return str(upstream_response)
|
||||||
@@ -26,3 +29,25 @@ def extract_error_message(error: Exception, status_code: Optional[int] = None) -
|
|||||||
if status_code is not None:
|
if status_code is not None:
|
||||||
return f"HTTP {status_code}: {error_str}"
|
return f"HTTP {status_code}: {error_str}"
|
||||||
return error_str
|
return error_str
|
||||||
|
|
||||||
|
|
||||||
|
def extract_client_error_message(error: Exception) -> str:
|
||||||
|
"""
|
||||||
|
从异常中提取客户端友好的错误消息(用于返回给客户端/Usage 记录)
|
||||||
|
|
||||||
|
此函数用于 Usage 表的 error_message 字段,
|
||||||
|
用于显示给最终用户的友好错误消息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: 异常对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
友好的错误消息字符串
|
||||||
|
"""
|
||||||
|
# 优先使用 message 属性(已经是友好处理过的消息)
|
||||||
|
message = getattr(error, "message", None)
|
||||||
|
if message and isinstance(message, str) and message.strip():
|
||||||
|
return message
|
||||||
|
|
||||||
|
# 回退到异常的字符串表示
|
||||||
|
return str(error) or repr(error)
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ class ProviderTimeoutException(ProviderException):
|
|||||||
|
|
||||||
def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
|
def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"提供商 '{provider_name}' 请求超时({timeout}秒)",
|
message=f"请求超时({timeout}秒)",
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
request_metadata=request_metadata,
|
request_metadata=request_metadata,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@@ -217,7 +217,7 @@ class ProviderAuthException(ProviderException):
|
|||||||
|
|
||||||
def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
|
def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"提供商 '{provider_name}' 认证失败,请检查API密钥",
|
message="上游服务认证失败",
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
request_metadata=request_metadata,
|
request_metadata=request_metadata,
|
||||||
)
|
)
|
||||||
@@ -292,9 +292,8 @@ class ModelNotSupportedException(ProxyException):
|
|||||||
"""模型不支持"""
|
"""模型不支持"""
|
||||||
|
|
||||||
def __init__(self, model: str, provider_name: Optional[str] = None):
|
def __init__(self, model: str, provider_name: Optional[str] = None):
|
||||||
|
# 客户端消息不暴露提供商信息
|
||||||
message = f"模型 '{model}' 不受支持"
|
message = f"模型 '{model}' 不受支持"
|
||||||
if provider_name:
|
|
||||||
message = f"提供商 '{provider_name}' 不支持模型 '{model}'"
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
error_type="model_not_supported",
|
error_type="model_not_supported",
|
||||||
@@ -307,10 +306,8 @@ class StreamingNotSupportedException(ProxyException):
|
|||||||
"""流式请求不支持"""
|
"""流式请求不支持"""
|
||||||
|
|
||||||
def __init__(self, model: str, provider_name: Optional[str] = None):
|
def __init__(self, model: str, provider_name: Optional[str] = None):
|
||||||
if provider_name:
|
# 客户端消息不暴露提供商信息
|
||||||
message = f"模型 '{model}' 在提供商 '{provider_name}' 上不支持流式请求"
|
message = f"模型 '{model}' 不支持流式请求"
|
||||||
else:
|
|
||||||
message = f"模型 '{model}' 不支持流式请求"
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
error_type="streaming_not_supported",
|
error_type="streaming_not_supported",
|
||||||
@@ -389,7 +386,7 @@ class JSONParseException(ProviderException):
|
|||||||
details["response_content"] = response_content
|
details["response_content"] = response_content
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"提供商 '{provider_name}' 返回了无效的JSON响应",
|
message="上游服务返回了无效的响应",
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
request_metadata=request_metadata,
|
request_metadata=request_metadata,
|
||||||
**details,
|
**details,
|
||||||
@@ -406,7 +403,7 @@ class EmptyStreamException(ProviderException):
|
|||||||
request_metadata: Optional[Any] = None,
|
request_metadata: Optional[Any] = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=f"提供商 '{provider_name}' 返回了空的流式响应(status=200 但无数据)",
|
message="上游服务返回了空的流式响应",
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
request_metadata=request_metadata,
|
request_metadata=request_metadata,
|
||||||
chunk_count=chunk_count,
|
chunk_count=chunk_count,
|
||||||
@@ -428,11 +425,10 @@ class EmbeddedErrorException(ProviderException):
|
|||||||
error_status: Optional[str] = None,
|
error_status: Optional[str] = None,
|
||||||
request_metadata: Optional[Any] = None,
|
request_metadata: Optional[Any] = None,
|
||||||
):
|
):
|
||||||
message = f"提供商 '{provider_name}' 返回了嵌套错误"
|
# 客户端消息不暴露提供商信息
|
||||||
|
message = "上游服务返回了错误"
|
||||||
if error_code:
|
if error_code:
|
||||||
message += f" (code={error_code})"
|
message += f" (code={error_code})"
|
||||||
if error_message:
|
|
||||||
message += f": {error_message}"
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message=message,
|
message=message,
|
||||||
@@ -549,12 +545,14 @@ class ErrorResponse:
|
|||||||
if isinstance(e, ProxyException):
|
if isinstance(e, ProxyException):
|
||||||
details = e.details.copy() if e.details else {}
|
details = e.details.copy() if e.details else {}
|
||||||
status_code = e.status_code
|
status_code = e.status_code
|
||||||
message = e.message
|
message = e.message # 使用友好的错误消息
|
||||||
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
|
# 如果是 ProviderNotAvailableException 且有上游错误信息
|
||||||
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
|
if isinstance(e, ProviderNotAvailableException):
|
||||||
if e.upstream_status:
|
if e.upstream_status:
|
||||||
status_code = e.upstream_status
|
status_code = e.upstream_status
|
||||||
message = e.upstream_response
|
# upstream_response 存入 details 供请求链路追踪使用,不作为客户端消息
|
||||||
|
if e.upstream_response:
|
||||||
|
details["upstream_response"] = e.upstream_response
|
||||||
return ErrorResponse.create(
|
return ErrorResponse.create(
|
||||||
error_type=e.error_type,
|
error_type=e.error_type,
|
||||||
message=message,
|
message=message,
|
||||||
|
|||||||
286
src/core/model_permissions.py
Normal file
286
src/core/model_permissions.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""
|
||||||
|
模型权限工具
|
||||||
|
|
||||||
|
支持两种 allowed_models 格式:
|
||||||
|
1. 简单模式(列表): ["claude-sonnet-4", "gpt-4o"]
|
||||||
|
2. 按格式模式(字典): {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
|
||||||
|
|
||||||
|
使用 None/null 表示不限制(允许所有模型)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
# 类型别名
|
||||||
|
AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_allowed_models(
|
||||||
|
allowed_models: AllowedModels,
|
||||||
|
api_format: Optional[str] = None,
|
||||||
|
) -> Optional[Set[str]]:
|
||||||
|
"""
|
||||||
|
将 allowed_models 规范化为模型名称集合
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models: 允许的模型配置(列表或字典)
|
||||||
|
api_format: 当前请求的 API 格式(用于字典模式)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- None: 不限制(允许所有模型)
|
||||||
|
- Set[str]: 允许的模型名称集合(可能为空集,表示拒绝所有)
|
||||||
|
"""
|
||||||
|
if allowed_models is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 简单模式:直接是列表
|
||||||
|
if isinstance(allowed_models, list):
|
||||||
|
return set(allowed_models)
|
||||||
|
|
||||||
|
# 按格式模式:字典
|
||||||
|
if isinstance(allowed_models, dict):
|
||||||
|
if api_format is None:
|
||||||
|
# 没有指定格式,合并所有格式的模型
|
||||||
|
all_models: Set[str] = set()
|
||||||
|
for models in allowed_models.values():
|
||||||
|
if isinstance(models, list):
|
||||||
|
all_models.update(models)
|
||||||
|
return all_models if all_models else None
|
||||||
|
|
||||||
|
# 查找指定格式的模型列表
|
||||||
|
api_format_upper = api_format.upper()
|
||||||
|
models = allowed_models.get(api_format_upper)
|
||||||
|
if models is None:
|
||||||
|
# 该格式未配置,检查是否有通配符 "*"
|
||||||
|
models = allowed_models.get("*")
|
||||||
|
|
||||||
|
if models is None:
|
||||||
|
# 字典模式下未配置的格式 = 不限制该格式
|
||||||
|
return None
|
||||||
|
|
||||||
|
return set(models) if isinstance(models, list) else None
|
||||||
|
|
||||||
|
# 未知类型,视为不限制
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_allowed(
|
||||||
|
model_name: str,
|
||||||
|
allowed_models: AllowedModels,
|
||||||
|
api_format: Optional[str] = None,
|
||||||
|
resolved_model_name: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
检查模型是否被允许
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 请求的模型名称
|
||||||
|
allowed_models: 允许的模型配置
|
||||||
|
api_format: 当前请求的 API 格式
|
||||||
|
resolved_model_name: 解析后的 GlobalModel.name(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True: 允许使用该模型
|
||||||
|
False: 不允许使用该模型
|
||||||
|
"""
|
||||||
|
allowed_set = normalize_allowed_models(allowed_models, api_format)
|
||||||
|
|
||||||
|
if allowed_set is None:
|
||||||
|
# 不限制
|
||||||
|
return True
|
||||||
|
|
||||||
|
if len(allowed_set) == 0:
|
||||||
|
# 空集合 = 拒绝所有
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查请求的模型名或解析后的名称是否在白名单中
|
||||||
|
if model_name in allowed_set:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if resolved_model_name and resolved_model_name in allowed_set:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def merge_allowed_models(
|
||||||
|
allowed_models_1: AllowedModels,
|
||||||
|
allowed_models_2: AllowedModels,
|
||||||
|
) -> AllowedModels:
|
||||||
|
"""
|
||||||
|
合并两个 allowed_models 配置,取交集
|
||||||
|
|
||||||
|
规则:
|
||||||
|
- 如果任一为 None,返回另一个
|
||||||
|
- 如果都有值,取交集
|
||||||
|
- 如果都是列表,取列表交集
|
||||||
|
- 如果有字典,按 API 格式分别取交集(保持字典语义,不丢失格式区分信息)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models_1: 第一个配置
|
||||||
|
allowed_models_2: 第二个配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合并后的配置
|
||||||
|
"""
|
||||||
|
if allowed_models_1 is None:
|
||||||
|
return allowed_models_2
|
||||||
|
if allowed_models_2 is None:
|
||||||
|
return allowed_models_1
|
||||||
|
|
||||||
|
# 两个都是简单列表:直接取交集(返回确定性顺序)
|
||||||
|
if isinstance(allowed_models_1, list) and isinstance(allowed_models_2, list):
|
||||||
|
intersection = set(allowed_models_1) & set(allowed_models_2)
|
||||||
|
return sorted(intersection) if intersection else []
|
||||||
|
|
||||||
|
# 任一为字典模式:按 API 格式分别取交集,避免把 dict 合并成 list 导致权限过宽
|
||||||
|
from src.core.enums import APIFormat
|
||||||
|
|
||||||
|
def merge_sets(a: Optional[Set[str]], b: Optional[Set[str]]) -> Optional[Set[str]]:
|
||||||
|
# None 表示不限制:交集规则下等价于“只受另一方限制”
|
||||||
|
if a is None:
|
||||||
|
return b
|
||||||
|
if b is None:
|
||||||
|
return a
|
||||||
|
return a & b
|
||||||
|
|
||||||
|
known_formats = [fmt.value for fmt in APIFormat]
|
||||||
|
|
||||||
|
per_format: Dict[str, Optional[Set[str]]] = {}
|
||||||
|
for fmt in known_formats:
|
||||||
|
s1 = normalize_allowed_models(allowed_models_1, api_format=fmt)
|
||||||
|
s2 = normalize_allowed_models(allowed_models_2, api_format=fmt)
|
||||||
|
per_format[fmt] = merge_sets(s1, s2)
|
||||||
|
|
||||||
|
# 计算默认(未知格式)的交集,用 "*" 作为默认值以覆盖未枚举的格式
|
||||||
|
default_s1 = normalize_allowed_models(allowed_models_1, api_format="__DEFAULT__")
|
||||||
|
default_s2 = normalize_allowed_models(allowed_models_2, api_format="__DEFAULT__")
|
||||||
|
default_set = merge_sets(default_s1, default_s2)
|
||||||
|
|
||||||
|
# 如果 default_set 非 None 且不存在“某些格式不限制”的情况,可用 "*" 作为默认规则并按需覆盖
|
||||||
|
can_use_wildcard = default_set is not None and all(v is not None for v in per_format.values())
|
||||||
|
|
||||||
|
merged_dict: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
if can_use_wildcard and default_set is not None:
|
||||||
|
merged_dict["*"] = sorted(default_set)
|
||||||
|
for fmt, s in per_format.items():
|
||||||
|
# can_use_wildcard 保证 s 非 None
|
||||||
|
if s is not None and s != default_set:
|
||||||
|
merged_dict[fmt] = sorted(s)
|
||||||
|
else:
|
||||||
|
for fmt, s in per_format.items():
|
||||||
|
if s is None:
|
||||||
|
continue
|
||||||
|
merged_dict[fmt] = sorted(s)
|
||||||
|
|
||||||
|
if not merged_dict:
|
||||||
|
# 全部不限制
|
||||||
|
return None
|
||||||
|
|
||||||
|
return merged_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_models_preview(
|
||||||
|
allowed_models: AllowedModels,
|
||||||
|
max_items: int = 3,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
获取 allowed_models 的预览字符串(用于日志和错误消息)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models: 允许的模型配置
|
||||||
|
max_items: 最多显示的模型数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预览字符串,如 "gpt-4o, claude-sonnet-4, ..."
|
||||||
|
"""
|
||||||
|
if allowed_models is None:
|
||||||
|
return "(不限制)"
|
||||||
|
|
||||||
|
all_models: Set[str] = set()
|
||||||
|
|
||||||
|
if isinstance(allowed_models, list):
|
||||||
|
all_models = set(allowed_models)
|
||||||
|
elif isinstance(allowed_models, dict):
|
||||||
|
for models in allowed_models.values():
|
||||||
|
if isinstance(models, list):
|
||||||
|
all_models.update(models)
|
||||||
|
|
||||||
|
if not all_models:
|
||||||
|
return "(无)"
|
||||||
|
|
||||||
|
sorted_models = sorted(all_models)
|
||||||
|
preview = ", ".join(sorted_models[:max_items])
|
||||||
|
if len(sorted_models) > max_items:
|
||||||
|
preview += f", ...共{len(sorted_models)}个"
|
||||||
|
|
||||||
|
return preview
|
||||||
|
|
||||||
|
|
||||||
|
def is_format_mode(allowed_models: AllowedModels) -> bool:
|
||||||
|
"""
|
||||||
|
判断 allowed_models 是否为按格式模式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models: 允许的模型配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True: 按格式模式(字典)
|
||||||
|
False: 简单模式(列表或 None)
|
||||||
|
"""
|
||||||
|
return isinstance(allowed_models, dict)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_format_mode(
|
||||||
|
allowed_models: AllowedModels,
|
||||||
|
api_formats: Optional[List[str]] = None,
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
将 allowed_models 转换为按格式模式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models: 原始配置
|
||||||
|
api_formats: 要应用的 API 格式列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
按格式模式的配置
|
||||||
|
"""
|
||||||
|
if allowed_models is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if isinstance(allowed_models, dict):
|
||||||
|
return allowed_models
|
||||||
|
|
||||||
|
# 简单列表模式 -> 按格式模式
|
||||||
|
if isinstance(allowed_models, list):
|
||||||
|
if not api_formats:
|
||||||
|
return {"*": allowed_models}
|
||||||
|
return {fmt.upper(): list(allowed_models) for fmt in api_formats}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_simple_mode(allowed_models: AllowedModels) -> Optional[List[str]]:
|
||||||
|
"""
|
||||||
|
将 allowed_models 转换为简单列表模式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_models: 原始配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
简单列表或 None
|
||||||
|
"""
|
||||||
|
if allowed_models is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(allowed_models, list):
|
||||||
|
return allowed_models
|
||||||
|
|
||||||
|
if isinstance(allowed_models, dict):
|
||||||
|
all_models: Set[str] = set()
|
||||||
|
for models in allowed_models.values():
|
||||||
|
if isinstance(models, list):
|
||||||
|
all_models.update(models)
|
||||||
|
return sorted(all_models) if all_models else None
|
||||||
|
|
||||||
|
return None
|
||||||
@@ -369,7 +369,6 @@ def init_db():
|
|||||||
_ensure_engine()
|
_ensure_engine()
|
||||||
|
|
||||||
# 数据库表结构由 Alembic 迁移管理
|
# 数据库表结构由 Alembic 迁移管理
|
||||||
# 首次部署或更新后请运行: ./migrate.sh
|
|
||||||
|
|
||||||
db = _SessionLocal()
|
db = _SessionLocal()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -52,15 +52,31 @@ class ProxyConfig(BaseModel):
|
|||||||
class CreateProviderRequest(BaseModel):
|
class CreateProviderRequest(BaseModel):
|
||||||
"""创建 Provider 请求"""
|
"""创建 Provider 请求"""
|
||||||
|
|
||||||
name: str = Field(
|
name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
|
||||||
...,
|
|
||||||
min_length=1,
|
|
||||||
max_length=100,
|
|
||||||
description="Provider 名称(英文字母、数字、下划线、连字符)",
|
|
||||||
)
|
|
||||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
|
||||||
description: Optional[str] = Field(None, max_length=1000, description="描述")
|
description: Optional[str] = Field(None, max_length=1000, description="描述")
|
||||||
website: Optional[str] = Field(None, max_length=500, description="官网地址")
|
website: Optional[str] = Field(None, max_length=500, description="官网地址")
|
||||||
|
|
||||||
|
@field_validator("name")
|
||||||
|
@classmethod
|
||||||
|
def validate_name(cls, v: str) -> str:
|
||||||
|
"""验证名称格式,防止注入攻击"""
|
||||||
|
v = v.strip()
|
||||||
|
|
||||||
|
# 只允许安全的字符:字母、数字、下划线、连字符、空格、中文
|
||||||
|
if not re.match(r"^[\w\s\u4e00-\u9fff-]+$", v):
|
||||||
|
raise ValueError("名称只能包含字母、数字、下划线、连字符、空格和中文")
|
||||||
|
|
||||||
|
# 检查 SQL 注入关键字(不区分大小写)
|
||||||
|
sql_keywords = [
|
||||||
|
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
|
||||||
|
"ALTER", "TRUNCATE", "UNION", "EXEC", "EXECUTE", "--", "/*", "*/"
|
||||||
|
]
|
||||||
|
v_upper = v.upper()
|
||||||
|
for keyword in sql_keywords:
|
||||||
|
if keyword in v_upper:
|
||||||
|
raise ValueError(f"名称包含非法关键字: {keyword}")
|
||||||
|
|
||||||
|
return v
|
||||||
billing_type: Optional[str] = Field(
|
billing_type: Optional[str] = Field(
|
||||||
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
|
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
|
||||||
)
|
)
|
||||||
@@ -68,47 +84,16 @@ class CreateProviderRequest(BaseModel):
|
|||||||
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
|
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
|
||||||
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
|
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
|
||||||
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
||||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
|
||||||
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
||||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||||
|
# 请求配置(从 Endpoint 迁移)
|
||||||
|
timeout: Optional[int] = Field(300, ge=1, le=600, description="请求超时(秒)")
|
||||||
|
max_retries: Optional[int] = Field(2, ge=0, le=10, description="最大重试次数")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||||
|
|
||||||
@field_validator("name")
|
@field_validator("name", "description")
|
||||||
@classmethod
|
|
||||||
def validate_name(cls, v: str) -> str:
|
|
||||||
"""验证名称格式"""
|
|
||||||
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
|
|
||||||
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
|
|
||||||
|
|
||||||
# SQL 注入防护:检查危险关键字
|
|
||||||
dangerous_keywords = [
|
|
||||||
"SELECT",
|
|
||||||
"INSERT",
|
|
||||||
"UPDATE",
|
|
||||||
"DELETE",
|
|
||||||
"DROP",
|
|
||||||
"CREATE",
|
|
||||||
"ALTER",
|
|
||||||
"EXEC",
|
|
||||||
"UNION",
|
|
||||||
"OR",
|
|
||||||
"AND",
|
|
||||||
"--",
|
|
||||||
";",
|
|
||||||
"'",
|
|
||||||
'"',
|
|
||||||
"<",
|
|
||||||
">",
|
|
||||||
]
|
|
||||||
upper_name = v.upper()
|
|
||||||
for keyword in dangerous_keywords:
|
|
||||||
if keyword in upper_name:
|
|
||||||
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
|
|
||||||
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("display_name", "description")
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
|
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
|
||||||
"""清理文本输入,防止 XSS"""
|
"""清理文本输入,防止 XSS"""
|
||||||
@@ -162,7 +147,7 @@ class CreateProviderRequest(BaseModel):
|
|||||||
class UpdateProviderRequest(BaseModel):
|
class UpdateProviderRequest(BaseModel):
|
||||||
"""更新 Provider 请求"""
|
"""更新 Provider 请求"""
|
||||||
|
|
||||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||||
description: Optional[str] = Field(None, max_length=1000)
|
description: Optional[str] = Field(None, max_length=1000)
|
||||||
website: Optional[str] = Field(None, max_length=500)
|
website: Optional[str] = Field(None, max_length=500)
|
||||||
billing_type: Optional[str] = None
|
billing_type: Optional[str] = None
|
||||||
@@ -170,14 +155,17 @@ class UpdateProviderRequest(BaseModel):
|
|||||||
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
|
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
|
||||||
quota_last_reset_at: Optional[datetime] = None
|
quota_last_reset_at: Optional[datetime] = None
|
||||||
quota_expires_at: Optional[datetime] = None
|
quota_expires_at: Optional[datetime] = None
|
||||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
|
||||||
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||||
|
# 请求配置(从 Endpoint 迁移)
|
||||||
|
timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒)")
|
||||||
|
max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
# 复用相同的验证器
|
# 复用相同的验证器
|
||||||
_sanitize_text = field_validator("display_name", "description")(
|
_sanitize_text = field_validator("name", "description")(
|
||||||
CreateProviderRequest.sanitize_text.__func__
|
CreateProviderRequest.sanitize_text.__func__
|
||||||
)
|
)
|
||||||
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
|
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
|
||||||
@@ -196,7 +184,6 @@ class CreateEndpointRequest(BaseModel):
|
|||||||
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
|
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
|
||||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||||
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
@@ -252,7 +239,6 @@ class UpdateEndpointRequest(BaseModel):
|
|||||||
custom_path: Optional[str] = Field(None, max_length=200)
|
custom_path: Optional[str] = Field(None, max_length=200)
|
||||||
priority: Optional[int] = Field(None, ge=0, le=1000)
|
priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
@@ -277,8 +263,7 @@ class CreateAPIKeyRequest(BaseModel):
|
|||||||
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
|
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
|
||||||
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
|
||||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||||
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
|
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制(NULL=自适应)")
|
||||||
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
|
|
||||||
notes: Optional[str] = Field(None, max_length=500, description="备注")
|
notes: Optional[str] = Field(None, max_length=500, description="备注")
|
||||||
|
|
||||||
@field_validator("api_key")
|
@field_validator("api_key")
|
||||||
|
|||||||
@@ -376,14 +376,13 @@ class ApiKeyResponse(BaseModel):
|
|||||||
class ProviderCreate(BaseModel):
|
class ProviderCreate(BaseModel):
|
||||||
"""创建提供商请求
|
"""创建提供商请求
|
||||||
|
|
||||||
新架构说明:
|
架构说明:
|
||||||
- Provider 仅包含提供商的元数据和计费配置
|
- Provider 仅包含提供商的元数据和计费配置
|
||||||
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置
|
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置
|
||||||
- API密钥应在 ProviderAPIKey 中设置
|
- API密钥应在 ProviderAPIKey 中设置
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = Field(..., min_length=1, max_length=100, description="提供商唯一标识")
|
name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
|
||||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
|
||||||
description: Optional[str] = Field(None, description="提供商描述")
|
description: Optional[str] = Field(None, description="提供商描述")
|
||||||
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
||||||
|
|
||||||
@@ -397,7 +396,7 @@ class ProviderCreate(BaseModel):
|
|||||||
class ProviderUpdate(BaseModel):
|
class ProviderUpdate(BaseModel):
|
||||||
"""更新提供商请求"""
|
"""更新提供商请求"""
|
||||||
|
|
||||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
website: Optional[str] = Field(None, max_length=500)
|
website: Optional[str] = Field(None, max_length=500)
|
||||||
api_format: Optional[str] = None
|
api_format: Optional[str] = None
|
||||||
@@ -418,7 +417,6 @@ class ProviderResponse(BaseModel):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
display_name: str
|
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
website: Optional[str]
|
website: Optional[str]
|
||||||
api_format: str
|
api_format: str
|
||||||
@@ -609,7 +607,6 @@ class PublicProviderResponse(BaseModel):
|
|||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
display_name: str
|
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
website: Optional[str]
|
website: Optional[str]
|
||||||
is_active: bool
|
is_active: bool
|
||||||
@@ -627,7 +624,6 @@ class PublicModelResponse(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
provider_id: str
|
provider_id: str
|
||||||
provider_name: str
|
provider_name: str
|
||||||
provider_display_name: str
|
|
||||||
name: str
|
name: str
|
||||||
display_name: str
|
display_name: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ class ProviderAPIKeyBase(BaseModel):
|
|||||||
|
|
||||||
name: Optional[str] = Field(None, description="密钥名称/备注")
|
name: Optional[str] = Field(None, description="密钥名称/备注")
|
||||||
api_key: str = Field(..., description="API密钥")
|
api_key: str = Field(..., description="API密钥")
|
||||||
rate_limit: Optional[int] = Field(None, description="速率限制(每分钟请求数)")
|
rpm_limit: Optional[int] = Field(None, description="RPM限制(每分钟请求数),NULL=自适应模式")
|
||||||
daily_limit: Optional[int] = Field(None, description="每日请求限制")
|
|
||||||
monthly_limit: Optional[int] = Field(None, description="每月请求限制")
|
|
||||||
priority: int = Field(0, description="优先级(越高越优先使用)")
|
priority: int = Field(0, description="优先级(越高越优先使用)")
|
||||||
is_active: bool = Field(True, description="是否启用")
|
is_active: bool = Field(True, description="是否启用")
|
||||||
expires_at: Optional[datetime] = Field(None, description="过期时间")
|
expires_at: Optional[datetime] = Field(None, description="过期时间")
|
||||||
@@ -32,9 +30,7 @@ class ProviderAPIKeyUpdate(BaseModel):
|
|||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
rate_limit: Optional[int] = None
|
rpm_limit: Optional[int] = None
|
||||||
daily_limit: Optional[int] = None
|
|
||||||
monthly_limit: Optional[int] = None
|
|
||||||
priority: Optional[int] = None
|
priority: Optional[int] = None
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
expires_at: Optional[datetime] = None
|
expires_at: Optional[datetime] = None
|
||||||
@@ -67,5 +63,3 @@ class ProviderAPIKeyStats(BaseModel):
|
|||||||
last_used_at: Optional[datetime]
|
last_used_at: Optional[datetime]
|
||||||
is_active: bool
|
is_active: bool
|
||||||
is_expired: bool
|
is_expired: bool
|
||||||
remaining_daily: Optional[int] = Field(None, description="今日剩余请求数")
|
|
||||||
remaining_monthly: Optional[int] = Field(None, description="本月剩余请求数")
|
|
||||||
|
|||||||
@@ -338,7 +338,8 @@ class Usage(Base):
|
|||||||
request_headers = Column(JSON, nullable=True) # 客户端请求头
|
request_headers = Column(JSON, nullable=True) # 客户端请求头
|
||||||
request_body = Column(JSON, nullable=True) # 请求体(7天内未压缩)
|
request_body = Column(JSON, nullable=True) # 请求体(7天内未压缩)
|
||||||
provider_request_headers = Column(JSON, nullable=True) # 向提供商发送的请求头
|
provider_request_headers = Column(JSON, nullable=True) # 向提供商发送的请求头
|
||||||
response_headers = Column(JSON, nullable=True) # 响应头
|
response_headers = Column(JSON, nullable=True) # 提供商响应头
|
||||||
|
client_response_headers = Column(JSON, nullable=True) # 返回给客户端的响应头
|
||||||
response_body = Column(JSON, nullable=True) # 响应体(7天内未压缩)
|
response_body = Column(JSON, nullable=True) # 响应体(7天内未压缩)
|
||||||
|
|
||||||
# 压缩存储字段(7天后自动压缩到这里)
|
# 压缩存储字段(7天后自动压缩到这里)
|
||||||
@@ -513,8 +514,7 @@ class Provider(Base):
|
|||||||
__tablename__ = "providers"
|
__tablename__ = "providers"
|
||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||||
name = Column(String(100), unique=True, nullable=False, index=True) # 提供商唯一标识
|
name = Column(String(100), unique=True, nullable=False, index=True) # 提供商名称(唯一)
|
||||||
display_name = Column(String(100), nullable=False) # 显示名称
|
|
||||||
description = Column(Text, nullable=True) # 提供商描述
|
description = Column(Text, nullable=True) # 提供商描述
|
||||||
website = Column(String(500), nullable=True) # 主站网站
|
website = Column(String(500), nullable=True) # 主站网站
|
||||||
|
|
||||||
@@ -537,11 +537,6 @@ class Provider(Base):
|
|||||||
quota_last_reset_at = Column(DateTime(timezone=True), nullable=True) # 上次额度重置时间
|
quota_last_reset_at = Column(DateTime(timezone=True), nullable=True) # 上次额度重置时间
|
||||||
quota_expires_at = Column(DateTime(timezone=True), nullable=True) # 月卡过期时间
|
quota_expires_at = Column(DateTime(timezone=True), nullable=True) # 月卡过期时间
|
||||||
|
|
||||||
# RPM限制:NULL=无限制,0=禁止请求
|
|
||||||
rpm_limit = Column(Integer, nullable=True) # 每分钟请求数限制(NULL=无限制,0=禁止请求)
|
|
||||||
rpm_used = Column(Integer, default=0) # 当前分钟已用请求数
|
|
||||||
rpm_reset_at = Column(DateTime(timezone=True), nullable=True) # RPM重置时间
|
|
||||||
|
|
||||||
# 提供商优先级 (数字越小越优先,用于提供商优先模式下的 Provider 排序)
|
# 提供商优先级 (数字越小越优先,用于提供商优先模式下的 Provider 排序)
|
||||||
# 0-10: 急需消耗(如即将过期的月卡)
|
# 0-10: 急需消耗(如即将过期的月卡)
|
||||||
# 11-50: 优先消耗(月卡)
|
# 11-50: 优先消耗(月卡)
|
||||||
@@ -555,6 +550,15 @@ class Provider(Base):
|
|||||||
# 限制
|
# 限制
|
||||||
concurrent_limit = Column(Integer, nullable=True) # 并发请求限制
|
concurrent_limit = Column(Integer, nullable=True) # 并发请求限制
|
||||||
|
|
||||||
|
# 请求配置(从 Endpoint 迁移,作为全局默认值)
|
||||||
|
# 超时 300 秒对于 LLM API 是合理的默认值:
|
||||||
|
# - 大多数请求在 30 秒内完成
|
||||||
|
# - 复杂推理(如 Claude thinking)可能需要 60-120 秒
|
||||||
|
# - 300 秒足够覆盖极端场景(如超长上下文、复杂工具调用)
|
||||||
|
timeout = Column(Integer, default=300, nullable=True) # 请求超时(秒)
|
||||||
|
max_retries = Column(Integer, default=2, nullable=True) # 最大重试次数
|
||||||
|
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password, enabled}
|
||||||
|
|
||||||
# 配置
|
# 配置
|
||||||
config = Column(JSON, nullable=True) # 额外配置(如Azure deployment name等)
|
config = Column(JSON, nullable=True) # 额外配置(如Azure deployment name等)
|
||||||
|
|
||||||
@@ -574,6 +578,9 @@ class Provider(Base):
|
|||||||
endpoints = relationship(
|
endpoints = relationship(
|
||||||
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
|
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
api_keys = relationship(
|
||||||
|
"ProviderAPIKey", back_populates="provider", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
api_key_mappings = relationship(
|
api_key_mappings = relationship(
|
||||||
"ApiKeyProviderMapping", back_populates="provider", cascade="all, delete-orphan"
|
"ApiKeyProviderMapping", back_populates="provider", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
@@ -599,12 +606,6 @@ class ProviderEndpoint(Base):
|
|||||||
timeout = Column(Integer, default=300) # 超时(秒)
|
timeout = Column(Integer, default=300) # 超时(秒)
|
||||||
max_retries = Column(Integer, default=2) # 最大重试次数
|
max_retries = Column(Integer, default=2) # 最大重试次数
|
||||||
|
|
||||||
# 限制
|
|
||||||
max_concurrent = Column(
|
|
||||||
Integer, nullable=True, default=None
|
|
||||||
) # 该端点的最大并发数(NULL=不限制)
|
|
||||||
rate_limit = Column(Integer, nullable=True) # 每分钟请求限制
|
|
||||||
|
|
||||||
# 状态
|
# 状态
|
||||||
is_active = Column(Boolean, default=True, nullable=False)
|
is_active = Column(Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
@@ -632,9 +633,6 @@ class ProviderEndpoint(Base):
|
|||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
provider = relationship("Provider", back_populates="endpoints")
|
provider = relationship("Provider", back_populates="endpoints")
|
||||||
api_keys = relationship(
|
|
||||||
"ProviderAPIKey", back_populates="endpoint", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 唯一约束和索引在表定义后
|
# 唯一约束和索引在表定义后
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
@@ -734,9 +732,11 @@ class GlobalModel(Base):
|
|||||||
class Model(Base):
|
class Model(Base):
|
||||||
"""Provider 模型配置表 - Provider 如何使用某个 GlobalModel
|
"""Provider 模型配置表 - Provider 如何使用某个 GlobalModel
|
||||||
|
|
||||||
设计原则 (方案 A):
|
设计原则:
|
||||||
- 每个 Model 必须关联一个 GlobalModel (global_model_id 不可为空)
|
- Model 表示 Provider 对某个模型的具体实现
|
||||||
- Model 表示 Provider 对某个 GlobalModel 的具体实现
|
- global_model_id 可为空:
|
||||||
|
- 为空时:模型尚未关联到 GlobalModel,不参与路由
|
||||||
|
- 不为空时:模型已关联 GlobalModel,参与路由
|
||||||
- provider_model_name 是 Provider 侧的实际模型名称 (可能与 GlobalModel.name 不同)
|
- provider_model_name 是 Provider 侧的实际模型名称 (可能与 GlobalModel.name 不同)
|
||||||
- 价格和能力配置可为空,为空时使用 GlobalModel 的默认值
|
- 价格和能力配置可为空,为空时使用 GlobalModel 的默认值
|
||||||
"""
|
"""
|
||||||
@@ -745,7 +745,8 @@ class Model(Base):
|
|||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||||
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=False)
|
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=False)
|
||||||
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
|
# 可为空:NULL 表示未关联,不参与路由;非 NULL 表示已关联,参与路由
|
||||||
|
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=True, index=True)
|
||||||
|
|
||||||
# Provider 映射配置
|
# Provider 映射配置
|
||||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
||||||
@@ -983,17 +984,20 @@ class Model(Base):
|
|||||||
|
|
||||||
|
|
||||||
class ProviderAPIKey(Base):
|
class ProviderAPIKey(Base):
|
||||||
"""Provider API密钥表 - 归属于特定 ProviderEndpoint"""
|
"""Provider API密钥表 - 直接归属于 Provider,支持多种 API 格式"""
|
||||||
|
|
||||||
__tablename__ = "provider_api_keys"
|
__tablename__ = "provider_api_keys"
|
||||||
|
|
||||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||||
|
|
||||||
# 外键关系
|
# 外键关系 - 直接关联 Provider
|
||||||
endpoint_id = Column(
|
provider_id = Column(
|
||||||
String(36), ForeignKey("provider_endpoints.id", ondelete="CASCADE"), nullable=False
|
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# API 格式支持列表(核心字段)
|
||||||
|
api_formats = Column(JSON, nullable=False, default=list) # ["CLAUDE", "CLAUDE_CLI"]
|
||||||
|
|
||||||
# API密钥信息
|
# API密钥信息
|
||||||
api_key = Column(String(500), nullable=False) # API密钥(加密存储)
|
api_key = Column(String(500), nullable=False) # API密钥(加密存储)
|
||||||
name = Column(String(100), nullable=False) # 密钥名称(必填,用于识别)
|
name = Column(String(100), nullable=False) # 密钥名称(必填,用于识别)
|
||||||
@@ -1002,7 +1006,10 @@ class ProviderAPIKey(Base):
|
|||||||
# 成本计算
|
# 成本计算
|
||||||
rate_multiplier = Column(
|
rate_multiplier = Column(
|
||||||
Float, default=1.0, nullable=False
|
Float, default=1.0, nullable=False
|
||||||
) # 成本倍率(真实成本 = 表面成本 × 倍率)
|
) # 默认成本倍率(真实成本 = 表面成本 × 倍率)
|
||||||
|
rate_multipliers = Column(
|
||||||
|
JSON, nullable=True
|
||||||
|
) # 按 API 格式的成本倍率 {"CLAUDE": 1.0, "OPENAI": 0.8}
|
||||||
|
|
||||||
# 优先级配置 (数字越小越优先)
|
# 优先级配置 (数字越小越优先)
|
||||||
internal_priority = Column(
|
internal_priority = Column(
|
||||||
@@ -1012,14 +1019,11 @@ class ProviderAPIKey(Base):
|
|||||||
Integer, nullable=True
|
Integer, nullable=True
|
||||||
) # 全局 Key 优先级(用于全局 Key 优先模式,跨 Provider 的 Key 排序,NULL=未配置使用默认排序)
|
) # 全局 Key 优先级(用于全局 Key 优先模式,跨 Provider 的 Key 排序,NULL=未配置使用默认排序)
|
||||||
|
|
||||||
# 并发限制配置
|
# RPM 限制配置(自适应学习)
|
||||||
# max_concurrent 决定并发控制模式:
|
# rpm_limit 决定 RPM 控制模式:
|
||||||
# - NULL: 自适应模式,系统自动学习并调整(使用 learned_max_concurrent)
|
# - NULL: 自适应模式,系统自动学习并调整(使用 learned_rpm_limit)
|
||||||
# - 数字: 固定限制模式,使用用户指定的值
|
# - 数字: 固定限制模式,使用用户指定的值
|
||||||
max_concurrent = Column(Integer, nullable=True, default=None)
|
rpm_limit = Column(Integer, nullable=True, default=None)
|
||||||
rate_limit = Column(Integer, nullable=True) # 速率限制(每分钟请求数)
|
|
||||||
daily_limit = Column(Integer, nullable=True) # 每日请求限制
|
|
||||||
monthly_limit = Column(Integer, nullable=True) # 每月请求限制
|
|
||||||
|
|
||||||
# 模型权限控制
|
# 模型权限控制
|
||||||
allowed_models = Column(JSON, nullable=True) # 允许使用的模型列表(null = 支持所有模型)
|
allowed_models = Column(JSON, nullable=True) # 允许使用的模型列表(null = 支持所有模型)
|
||||||
@@ -1028,16 +1032,16 @@ class ProviderAPIKey(Base):
|
|||||||
capabilities = Column(JSON, nullable=True) # Key 拥有的能力
|
capabilities = Column(JSON, nullable=True) # Key 拥有的能力
|
||||||
# 示例: {"cache_1h": true, "context_1m": true}
|
# 示例: {"cache_1h": true, "context_1m": true}
|
||||||
|
|
||||||
# 自适应并发调整(仅当 max_concurrent = NULL 时生效)
|
# 自适应 RPM 调整(仅当 rpm_limit = NULL 时生效)
|
||||||
learned_max_concurrent = Column(
|
learned_rpm_limit = Column(
|
||||||
Integer, nullable=True
|
Integer, nullable=True
|
||||||
) # 学习到的并发限制(自适应模式下的有效值)
|
) # 学习到的 RPM 限制(自适应模式下的有效值)
|
||||||
concurrent_429_count = Column(Integer, default=0, nullable=False) # 因并发导致的429次数
|
concurrent_429_count = Column(Integer, default=0, nullable=False) # 因并发导致的429次数
|
||||||
rpm_429_count = Column(Integer, default=0, nullable=False) # 因RPM导致的429次数
|
rpm_429_count = Column(Integer, default=0, nullable=False) # 因RPM导致的429次数
|
||||||
last_429_at = Column(DateTime(timezone=True), nullable=True) # 最后429时间
|
last_429_at = Column(DateTime(timezone=True), nullable=True) # 最后429时间
|
||||||
last_429_type = Column(String(50), nullable=True) # 最后429类型: concurrent/rpm/unknown
|
last_429_type = Column(String(50), nullable=True) # 最后429类型: concurrent/rpm/unknown
|
||||||
last_concurrent_peak = Column(Integer, nullable=True) # 触发429时的并发数
|
last_rpm_peak = Column(Integer, nullable=True) # 触发429时的RPM峰值
|
||||||
adjustment_history = Column(JSON, nullable=True) # 并发调整历史
|
adjustment_history = Column(JSON, nullable=True) # RPM调整历史
|
||||||
# 基于滑动窗口的利用率追踪
|
# 基于滑动窗口的利用率追踪
|
||||||
utilization_samples = Column(
|
utilization_samples = Column(
|
||||||
JSON, nullable=True
|
JSON, nullable=True
|
||||||
@@ -1046,12 +1050,9 @@ class ProviderAPIKey(Base):
|
|||||||
DateTime(timezone=True), nullable=True
|
DateTime(timezone=True), nullable=True
|
||||||
) # 上次探测性扩容时间
|
) # 上次探测性扩容时间
|
||||||
|
|
||||||
# 健康度追踪(基于滑动窗口)
|
# 健康度追踪(按 API 格式存储)
|
||||||
health_score = Column(Float, default=1.0) # 0.0-1.0(保留用于展示,实际熔断基于滑动窗口)
|
# 结构: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, "last_failure_at": null, "request_results_window": []}, ...}
|
||||||
consecutive_failures = Column(Integer, default=0)
|
health_by_format = Column(JSON, nullable=True, default=dict)
|
||||||
last_failure_at = Column(DateTime(timezone=True), nullable=True) # 最后失败时间
|
|
||||||
# 滑动窗口:记录最近 N 次请求的结果 [{"ts": timestamp, "ok": true/false}, ...]
|
|
||||||
request_results_window = Column(JSON, nullable=True)
|
|
||||||
|
|
||||||
# 缓存与熔断配置
|
# 缓存与熔断配置
|
||||||
cache_ttl_minutes = Column(
|
cache_ttl_minutes = Column(
|
||||||
@@ -1061,14 +1062,9 @@ class ProviderAPIKey(Base):
|
|||||||
Integer, default=32, nullable=False
|
Integer, default=32, nullable=False
|
||||||
) # 最大探测间隔(分钟),默认32分钟(硬上限)
|
) # 最大探测间隔(分钟),默认32分钟(硬上限)
|
||||||
|
|
||||||
# 熔断器字段(滑动窗口 + 半开状态模式)
|
# 熔断器状态(按 API 格式存储)
|
||||||
circuit_breaker_open = Column(Boolean, default=False, nullable=False) # 熔断器是否打开
|
# 结构: {"CLAUDE": {"open": false, "open_at": null, "next_probe_at": null, "half_open_until": null, "half_open_successes": 0, "half_open_failures": 0}, ...}
|
||||||
circuit_breaker_open_at = Column(DateTime(timezone=True), nullable=True) # 熔断器打开时间
|
circuit_breaker_by_format = Column(JSON, nullable=True, default=dict)
|
||||||
next_probe_at = Column(DateTime(timezone=True), nullable=True) # 下次探测时间
|
|
||||||
# 半开状态:允许少量请求通过验证服务是否恢复
|
|
||||||
half_open_until = Column(DateTime(timezone=True), nullable=True) # 半开状态结束时间
|
|
||||||
half_open_successes = Column(Integer, default=0) # 半开状态下的成功次数
|
|
||||||
half_open_failures = Column(Integer, default=0) # 半开状态下的失败次数
|
|
||||||
|
|
||||||
# 使用统计
|
# 使用统计
|
||||||
request_count = Column(Integer, default=0) # 请求次数
|
request_count = Column(Integer, default=0) # 请求次数
|
||||||
@@ -1095,7 +1091,7 @@ class ProviderAPIKey(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
endpoint = relationship("ProviderEndpoint", back_populates="api_keys")
|
provider = relationship("Provider", back_populates="api_keys")
|
||||||
|
|
||||||
|
|
||||||
class UserPreference(Base):
|
class UserPreference(Base):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ ProviderEndpoint 相关的 API 模型定义
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
@@ -26,10 +26,6 @@ class ProviderEndpointCreate(BaseModel):
|
|||||||
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
|
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
|
||||||
max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数")
|
max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数")
|
||||||
|
|
||||||
# 限制
|
|
||||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
|
||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制(请求/秒)")
|
|
||||||
|
|
||||||
# 额外配置
|
# 额外配置
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
||||||
|
|
||||||
@@ -67,8 +63,6 @@ class ProviderEndpointUpdate(BaseModel):
|
|||||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||||
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
||||||
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
||||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
|
||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
|
||||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
||||||
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||||
@@ -103,10 +97,6 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
timeout: int
|
timeout: int
|
||||||
max_retries: int
|
max_retries: int
|
||||||
|
|
||||||
# 限制
|
|
||||||
max_concurrent: Optional[int] = None
|
|
||||||
rate_limit: Optional[int] = None
|
|
||||||
|
|
||||||
# 状态
|
# 状态
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
|
||||||
@@ -127,32 +117,37 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
# ========== ProviderAPIKey 相关(新架构) ==========
|
# ========== ProviderAPIKey 相关 ==========
|
||||||
|
|
||||||
|
|
||||||
class EndpointAPIKeyCreate(BaseModel):
|
class EndpointAPIKeyCreate(BaseModel):
|
||||||
"""为 Endpoint 添加 API Key"""
|
"""为 Provider 添加 API Key"""
|
||||||
|
|
||||||
endpoint_id: str = Field(..., description="Endpoint ID")
|
provider_id: Optional[str] = Field(default=None, description="Provider ID(从 URL 获取)")
|
||||||
api_key: str = Field(..., min_length=10, max_length=500, description="API Key(将自动加密)")
|
api_formats: Optional[List[str]] = Field(
|
||||||
|
default=None, min_length=1, description="支持的 API 格式列表(必填,路由层校验)"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key: str = Field(..., min_length=3, max_length=500, description="API Key(将自动加密)")
|
||||||
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
|
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
|
||||||
|
|
||||||
# 成本计算
|
# 成本计算
|
||||||
rate_multiplier: float = Field(
|
rate_multiplier: float = Field(
|
||||||
default=1.0, ge=0.01, description="成本倍率(真实成本 = 表面成本 × 倍率)"
|
default=1.0, ge=0.01, description="默认成本倍率(真实成本 = 表面成本 × 倍率)"
|
||||||
|
)
|
||||||
|
rate_multipliers: Optional[Dict[str, float]] = Field(
|
||||||
|
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 优先级和限制(数字越小越优先)
|
# 优先级和限制(数字越小越优先)
|
||||||
internal_priority: int = Field(default=50, description="Endpoint 内部优先级(提供商优先模式)")
|
internal_priority: int = Field(default=50, description="Key 内部优先级(提供商优先模式)")
|
||||||
# max_concurrent: NULL=自适应模式(系统自动学习),数字=固定限制模式
|
# rpm_limit: NULL=自适应模式(系统自动学习),数字=固定限制模式
|
||||||
max_concurrent: Optional[int] = Field(
|
rpm_limit: Optional[int] = Field(
|
||||||
default=None, ge=1, description="最大并发数(NULL=自适应模式)"
|
default=None, ge=1, le=10000, description="RPM 限制(NULL=自适应模式)"
|
||||||
)
|
)
|
||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
|
||||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
default=None,
|
||||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
description="允许使用的模型列表(null=不限制,列表=简单白名单,字典=按API格式区分)",
|
||||||
allowed_models: Optional[List[str]] = Field(
|
|
||||||
default=None, description="允许使用的模型列表(null = 支持所有模型)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 能力标签
|
# 能力标签
|
||||||
@@ -171,17 +166,99 @@ class EndpointAPIKeyCreate(BaseModel):
|
|||||||
# 备注
|
# 备注
|
||||||
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
|
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
|
||||||
|
|
||||||
|
@field_validator("api_formats")
|
||||||
|
@classmethod
|
||||||
|
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
|
||||||
|
"""验证 API 格式列表"""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
from src.core.enums import APIFormat
|
||||||
|
|
||||||
|
allowed = [fmt.value for fmt in APIFormat]
|
||||||
|
validated = []
|
||||||
|
seen = set()
|
||||||
|
for fmt in v:
|
||||||
|
fmt_upper = fmt.upper()
|
||||||
|
if fmt_upper not in allowed:
|
||||||
|
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
|
||||||
|
if fmt_upper in seen:
|
||||||
|
continue # 静默去重
|
||||||
|
seen.add(fmt_upper)
|
||||||
|
validated.append(fmt_upper)
|
||||||
|
return validated
|
||||||
|
|
||||||
|
@field_validator("allowed_models")
|
||||||
|
@classmethod
|
||||||
|
def validate_allowed_models(
|
||||||
|
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
|
||||||
|
) -> Optional[Union[List[str], Dict[str, List[str]]]]:
|
||||||
|
"""
|
||||||
|
规范化 allowed_models:
|
||||||
|
- 列表模式:去空、去重、保留顺序
|
||||||
|
- 字典模式:key 统一大写(支持 "*"),value 去空、去重、保留顺序
|
||||||
|
"""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
cleaned: List[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for item in v:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError("allowed_models 列表必须为字符串数组")
|
||||||
|
name = item.strip()
|
||||||
|
if not name or name in seen:
|
||||||
|
continue
|
||||||
|
seen.add(name)
|
||||||
|
cleaned.append(name)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
from src.core.enums import APIFormat
|
||||||
|
|
||||||
|
allowed_formats = {fmt.value for fmt in APIFormat}
|
||||||
|
normalized: Dict[str, List[str]] = {}
|
||||||
|
for raw_key, models in v.items():
|
||||||
|
if not isinstance(raw_key, str):
|
||||||
|
raise ValueError("allowed_models 字典的 key 必须为字符串")
|
||||||
|
|
||||||
|
key = raw_key.upper()
|
||||||
|
if key != "*" and key not in allowed_formats:
|
||||||
|
raise ValueError(
|
||||||
|
f"allowed_models 字典的 key 必须是 {sorted(allowed_formats)} 或 '*',当前值: {raw_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if models is None:
|
||||||
|
# null 表示该格式不限制,跳过(不加入字典)
|
||||||
|
continue
|
||||||
|
if not isinstance(models, list):
|
||||||
|
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
|
||||||
|
|
||||||
|
cleaned: List[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for item in models:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
|
||||||
|
name = item.strip()
|
||||||
|
if not name or name in seen:
|
||||||
|
continue
|
||||||
|
seen.add(name)
|
||||||
|
cleaned.append(name)
|
||||||
|
|
||||||
|
normalized[key] = cleaned
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
raise ValueError("allowed_models 必须是列表或字典")
|
||||||
|
|
||||||
@field_validator("api_key")
|
@field_validator("api_key")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_api_key(cls, v: str) -> str:
|
def validate_api_key(cls, v: str) -> str:
|
||||||
"""验证 API Key 安全性"""
|
"""验证 API Key 安全性"""
|
||||||
# 移除首尾空白
|
# 移除首尾空白(长度校验由 Field min_length 处理)
|
||||||
v = v.strip()
|
v = v.strip()
|
||||||
|
|
||||||
# 检查最小长度
|
|
||||||
if len(v) < 10:
|
|
||||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
|
||||||
|
|
||||||
# 检查危险字符(SQL 注入防护)
|
# 检查危险字符(SQL 注入防护)
|
||||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||||
for char in dangerous_chars:
|
for char in dangerous_chars:
|
||||||
@@ -218,26 +295,35 @@ class EndpointAPIKeyCreate(BaseModel):
|
|||||||
class EndpointAPIKeyUpdate(BaseModel):
|
class EndpointAPIKeyUpdate(BaseModel):
|
||||||
"""更新 Endpoint API Key"""
|
"""更新 Endpoint API Key"""
|
||||||
|
|
||||||
|
api_formats: Optional[List[str]] = Field(
|
||||||
|
default=None, min_length=1, description="支持的 API 格式列表"
|
||||||
|
)
|
||||||
|
|
||||||
api_key: Optional[str] = Field(
|
api_key: Optional[str] = Field(
|
||||||
default=None, min_length=10, max_length=500, description="API Key(将自动加密)"
|
default=None, min_length=3, max_length=500, description="API Key(将自动加密)"
|
||||||
)
|
)
|
||||||
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
|
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
|
||||||
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")
|
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="默认成本倍率")
|
||||||
|
rate_multipliers: Optional[Dict[str, float]] = Field(
|
||||||
|
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
|
||||||
|
)
|
||||||
internal_priority: Optional[int] = Field(
|
internal_priority: Optional[int] = Field(
|
||||||
default=None, description="Endpoint 内部优先级(提供商优先模式,数字越小越优先)"
|
default=None, description="Key 内部优先级(提供商优先模式,数字越小越优先)"
|
||||||
)
|
)
|
||||||
global_priority: Optional[int] = Field(
|
global_priority: Optional[int] = Field(
|
||||||
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
||||||
)
|
)
|
||||||
# max_concurrent: 使用特殊标记区分"未提供"和"设置为 null(自适应模式)"
|
# rpm_limit: 使用特殊标记区分"未提供"和"设置为 null(自适应模式)"
|
||||||
# - 不提供字段:不更新
|
# - 不提供字段:不更新
|
||||||
# - 提供 null:切换为自适应模式
|
# - 提供 null:切换为自适应模式
|
||||||
# - 提供数字:设置固定并发限制
|
# - 提供数字:设置固定 RPM 限制
|
||||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数(null=自适应模式)")
|
rpm_limit: Optional[int] = Field(
|
||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
default=None, ge=1, le=10000, description="RPM 限制(null=自适应模式)"
|
||||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
)
|
||||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
|
||||||
allowed_models: Optional[List[str]] = Field(default=None, description="允许使用的模型列表")
|
default=None,
|
||||||
|
description="允许使用的模型列表(null=不限制,列表=简单白名单,字典=按API格式区分)",
|
||||||
|
)
|
||||||
capabilities: Optional[Dict[str, bool]] = Field(
|
capabilities: Optional[Dict[str, bool]] = Field(
|
||||||
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
|
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
|
||||||
)
|
)
|
||||||
@@ -250,6 +336,36 @@ class EndpointAPIKeyUpdate(BaseModel):
|
|||||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||||
note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
|
note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
|
||||||
|
|
||||||
|
@field_validator("api_formats")
|
||||||
|
@classmethod
|
||||||
|
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
|
||||||
|
"""验证 API 格式列表"""
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
|
||||||
|
from src.core.enums import APIFormat
|
||||||
|
|
||||||
|
allowed = [fmt.value for fmt in APIFormat]
|
||||||
|
validated = []
|
||||||
|
seen = set()
|
||||||
|
for fmt in v:
|
||||||
|
fmt_upper = fmt.upper()
|
||||||
|
if fmt_upper not in allowed:
|
||||||
|
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
|
||||||
|
if fmt_upper in seen:
|
||||||
|
continue # 静默去重
|
||||||
|
seen.add(fmt_upper)
|
||||||
|
validated.append(fmt_upper)
|
||||||
|
return validated
|
||||||
|
|
||||||
|
@field_validator("allowed_models")
|
||||||
|
@classmethod
|
||||||
|
def validate_allowed_models(
|
||||||
|
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
|
||||||
|
) -> Optional[Union[List[str], Dict[str, List[str]]]]:
|
||||||
|
# 与 EndpointAPIKeyCreate 保持一致
|
||||||
|
return EndpointAPIKeyCreate.validate_allowed_models(v)
|
||||||
|
|
||||||
@field_validator("api_key")
|
@field_validator("api_key")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
|
||||||
@@ -299,7 +415,9 @@ class EndpointAPIKeyResponse(BaseModel):
|
|||||||
"""Endpoint API Key 响应"""
|
"""Endpoint API Key 响应"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
endpoint_id: str
|
|
||||||
|
provider_id: str = Field(..., description="Provider ID")
|
||||||
|
api_formats: List[str] = Field(default=[], description="支持的 API 格式列表")
|
||||||
|
|
||||||
# Key 信息(脱敏)
|
# Key 信息(脱敏)
|
||||||
api_key_masked: str = Field(..., description="脱敏后的 Key")
|
api_key_masked: str = Field(..., description="脱敏后的 Key")
|
||||||
@@ -307,31 +425,37 @@ class EndpointAPIKeyResponse(BaseModel):
|
|||||||
name: str = Field(..., description="密钥名称")
|
name: str = Field(..., description="密钥名称")
|
||||||
|
|
||||||
# 成本计算
|
# 成本计算
|
||||||
rate_multiplier: float = Field(default=1.0, description="成本倍率")
|
rate_multiplier: float = Field(default=1.0, description="默认成本倍率")
|
||||||
|
rate_multipliers: Optional[Dict[str, float]] = Field(
|
||||||
|
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
|
||||||
|
)
|
||||||
|
|
||||||
# 优先级和限制
|
# 优先级和限制
|
||||||
internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
|
internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
|
||||||
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
|
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
|
||||||
max_concurrent: Optional[int] = None
|
rpm_limit: Optional[int] = None
|
||||||
rate_limit: Optional[int] = None
|
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = None
|
||||||
daily_limit: Optional[int] = None
|
capabilities: Optional[Dict[str, bool]] = Field(default=None, description="Key 能力标签")
|
||||||
monthly_limit: Optional[int] = None
|
|
||||||
allowed_models: Optional[List[str]] = None
|
|
||||||
capabilities: Optional[Dict[str, bool]] = Field(
|
|
||||||
default=None, description="Key 能力标签"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 缓存与熔断配置
|
# 缓存与熔断配置
|
||||||
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL(分钟),0=禁用")
|
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL(分钟),0=禁用")
|
||||||
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
|
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
|
||||||
|
|
||||||
# 健康度
|
# 按格式的健康度数据
|
||||||
health_score: float
|
health_by_format: Optional[Dict[str, Any]] = Field(
|
||||||
consecutive_failures: int
|
default=None, description="按 API 格式存储的健康度数据"
|
||||||
|
)
|
||||||
|
circuit_breaker_by_format: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, description="按 API 格式存储的熔断器状态"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 聚合字段(从 health_by_format 计算,用于列表显示)
|
||||||
|
health_score: float = Field(default=1.0, description="健康度(所有格式中的最低值)")
|
||||||
|
consecutive_failures: int = Field(default=0, description="连续失败次数")
|
||||||
last_failure_at: Optional[datetime] = None
|
last_failure_at: Optional[datetime] = None
|
||||||
|
|
||||||
# 熔断器状态(滑动窗口 + 半开模式)
|
# 聚合熔断器字段
|
||||||
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开")
|
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开(任何格式)")
|
||||||
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
|
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
|
||||||
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
|
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
|
||||||
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
|
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
|
||||||
@@ -349,9 +473,9 @@ class EndpointAPIKeyResponse(BaseModel):
|
|||||||
# 状态
|
# 状态
|
||||||
is_active: bool
|
is_active: bool
|
||||||
|
|
||||||
# 自适应并发信息
|
# 自适应 RPM 信息
|
||||||
is_adaptive: bool = Field(default=False, description="是否为自适应模式(max_concurrent=NULL)")
|
is_adaptive: bool = Field(default=False, description="是否为自适应模式(rpm_limit=NULL)")
|
||||||
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
|
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
|
||||||
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
effective_limit: Optional[int] = Field(None, description="当前有效限制")
|
||||||
# 滑动窗口利用率采样
|
# 滑动窗口利用率采样
|
||||||
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
|
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
|
||||||
@@ -375,22 +499,42 @@ class EndpointAPIKeyResponse(BaseModel):
|
|||||||
# ========== 健康监控相关 ==========
|
# ========== 健康监控相关 ==========
|
||||||
|
|
||||||
|
|
||||||
class HealthStatusResponse(BaseModel):
|
class FormatHealthData(BaseModel):
|
||||||
"""健康状态响应(仅 Key 级别)"""
|
"""单个 API 格式的健康度数据"""
|
||||||
|
|
||||||
# Key 健康状态
|
health_score: float = 1.0
|
||||||
|
error_rate: float = 0.0
|
||||||
|
window_size: int = 0
|
||||||
|
consecutive_failures: int = 0
|
||||||
|
last_failure_at: Optional[str] = None
|
||||||
|
circuit_breaker: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthStatusResponse(BaseModel):
|
||||||
|
"""健康状态响应(支持按格式查询)"""
|
||||||
|
|
||||||
|
# 基础信息
|
||||||
key_id: str
|
key_id: str
|
||||||
key_health_score: float
|
|
||||||
key_consecutive_failures: int
|
|
||||||
key_last_failure_at: Optional[datetime] = None
|
|
||||||
key_is_active: bool
|
key_is_active: bool
|
||||||
key_statistics: Optional[Dict[str, Any]] = None
|
key_statistics: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
# 熔断器状态(滑动窗口 + 半开模式)
|
# 整体健康度(取所有格式中的最低值)
|
||||||
|
key_health_score: float = 1.0
|
||||||
|
any_circuit_open: bool = False
|
||||||
|
|
||||||
|
# 按格式的健康度数据
|
||||||
|
health_by_format: Optional[Dict[str, FormatHealthData]] = None
|
||||||
|
|
||||||
|
# 单格式查询时的字段
|
||||||
|
api_format: Optional[str] = None
|
||||||
|
key_consecutive_failures: Optional[int] = None
|
||||||
|
key_last_failure_at: Optional[str] = None
|
||||||
|
|
||||||
|
# 单格式查询时的熔断器状态
|
||||||
circuit_breaker_open: bool = False
|
circuit_breaker_open: bool = False
|
||||||
circuit_breaker_open_at: Optional[datetime] = None
|
circuit_breaker_open_at: Optional[str] = None
|
||||||
next_probe_at: Optional[datetime] = None
|
next_probe_at: Optional[str] = None
|
||||||
half_open_until: Optional[datetime] = None
|
half_open_until: Optional[str] = None
|
||||||
half_open_successes: int = 0
|
half_open_successes: int = 0
|
||||||
half_open_failures: int = 0
|
half_open_failures: int = 0
|
||||||
|
|
||||||
@@ -402,33 +546,22 @@ class HealthSummaryResponse(BaseModel):
|
|||||||
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
|
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
|
||||||
|
|
||||||
|
|
||||||
# ========== 并发控制相关 ==========
|
# ========== RPM 控制相关 ==========
|
||||||
|
|
||||||
|
|
||||||
class ConcurrencyStatusResponse(BaseModel):
|
class KeyRpmStatusResponse(BaseModel):
|
||||||
"""并发状态响应"""
|
"""Key RPM 状态响应"""
|
||||||
|
|
||||||
endpoint_id: Optional[str] = None
|
key_id: str = Field(..., description="Key ID")
|
||||||
endpoint_current_concurrency: int = Field(default=0, description="Endpoint 当前并发数")
|
current_rpm: int = Field(default=0, description="当前 RPM 计数")
|
||||||
endpoint_max_concurrent: Optional[int] = Field(default=None, description="Endpoint 最大并发数")
|
rpm_limit: Optional[int] = Field(default=None, description="RPM 限制")
|
||||||
|
|
||||||
key_id: Optional[str] = None
|
|
||||||
key_current_concurrency: int = Field(default=0, description="Key 当前并发数")
|
|
||||||
key_max_concurrent: Optional[int] = Field(default=None, description="Key 最大并发数")
|
|
||||||
|
|
||||||
|
|
||||||
class ResetConcurrencyRequest(BaseModel):
|
|
||||||
"""重置并发计数请求"""
|
|
||||||
|
|
||||||
endpoint_id: Optional[str] = Field(default=None, description="Endpoint ID(可选)")
|
|
||||||
key_id: Optional[str] = Field(default=None, description="Key ID(可选)")
|
|
||||||
|
|
||||||
|
|
||||||
class KeyPriorityItem(BaseModel):
|
class KeyPriorityItem(BaseModel):
|
||||||
"""单个 Key 优先级项"""
|
"""单个 Key 优先级项"""
|
||||||
|
|
||||||
key_id: str = Field(..., description="Key ID")
|
key_id: str = Field(..., description="Key ID")
|
||||||
internal_priority: int = Field(..., ge=0, description="Endpoint 内部优先级(数字越小越优先)")
|
internal_priority: int = Field(..., ge=0, description="Key 内部优先级(数字越小越优先)")
|
||||||
|
|
||||||
|
|
||||||
class BatchUpdateKeyPriorityRequest(BaseModel):
|
class BatchUpdateKeyPriorityRequest(BaseModel):
|
||||||
@@ -443,11 +576,9 @@ class BatchUpdateKeyPriorityRequest(BaseModel):
|
|||||||
class ProviderUpdateRequest(BaseModel):
|
class ProviderUpdateRequest(BaseModel):
|
||||||
"""Provider 基础配置更新请求"""
|
"""Provider 基础配置更新请求"""
|
||||||
|
|
||||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
website: Optional[str] = Field(None, max_length=500, description="主站网站")
|
||||||
priority: Optional[int] = None
|
|
||||||
weight: Optional[float] = Field(None, gt=0)
|
|
||||||
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
|
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
billing_type: Optional[str] = Field(
|
billing_type: Optional[str] = Field(
|
||||||
@@ -456,9 +587,10 @@ class ProviderUpdateRequest(BaseModel):
|
|||||||
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
|
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
|
||||||
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日(1-31)")
|
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日(1-31)")
|
||||||
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
|
||||||
rpm_limit: Optional[int] = Field(
|
# 请求配置(从 Endpoint 迁移)
|
||||||
None, ge=0, description="每分钟请求数限制(NULL=无限制,0=禁止请求)"
|
timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒)")
|
||||||
)
|
max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
|
||||||
|
proxy: Optional[Dict[str, Any]] = Field(None, description="代理配置")
|
||||||
|
|
||||||
|
|
||||||
class ProviderWithEndpointsSummary(BaseModel):
|
class ProviderWithEndpointsSummary(BaseModel):
|
||||||
@@ -467,7 +599,6 @@ class ProviderWithEndpointsSummary(BaseModel):
|
|||||||
# Provider 基本信息
|
# Provider 基本信息
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
display_name: str
|
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
website: Optional[str] = None
|
website: Optional[str] = None
|
||||||
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
|
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
|
||||||
@@ -481,12 +612,10 @@ class ProviderWithEndpointsSummary(BaseModel):
|
|||||||
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
|
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
|
||||||
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
|
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
|
||||||
|
|
||||||
# RPM 限制
|
# 请求配置(从 Endpoint 迁移)
|
||||||
rpm_limit: Optional[int] = Field(
|
timeout: Optional[int] = Field(default=300, description="请求超时(秒)")
|
||||||
default=None, description="每分钟请求数限制(NULL=无限制,0=禁止请求)"
|
max_retries: Optional[int] = Field(default=2, description="最大重试次数")
|
||||||
)
|
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置")
|
||||||
rpm_used: Optional[int] = Field(default=None, description="当前分钟已用请求数")
|
|
||||||
rpm_reset_at: Optional[datetime] = Field(default=None, description="RPM 重置时间")
|
|
||||||
|
|
||||||
# Endpoint 统计
|
# Endpoint 统计
|
||||||
total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
|
total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
|
||||||
@@ -621,12 +750,8 @@ class PublicApiFormatHealthMonitor(BaseModel):
|
|||||||
default_factory=list,
|
default_factory=list,
|
||||||
description="Usage 表生成的健康时间线(healthy/warning/unhealthy/unknown)",
|
description="Usage 表生成的健康时间线(healthy/warning/unhealthy/unknown)",
|
||||||
)
|
)
|
||||||
time_range_start: Optional[datetime] = Field(
|
time_range_start: Optional[datetime] = Field(default=None, description="时间线覆盖区间开始时间")
|
||||||
default=None, description="时间线覆盖区间开始时间"
|
time_range_end: Optional[datetime] = Field(default=None, description="时间线覆盖区间结束时间")
|
||||||
)
|
|
||||||
time_range_end: Optional[datetime] = Field(
|
|
||||||
default=None, description="时间线覆盖区间结束时间"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PublicApiFormatHealthMonitorResponse(BaseModel):
|
class PublicApiFormatHealthMonitorResponse(BaseModel):
|
||||||
|
|||||||
@@ -114,7 +114,6 @@ class ModelCatalogProviderDetail(BaseModel):
|
|||||||
|
|
||||||
provider_id: str
|
provider_id: str
|
||||||
provider_name: str
|
provider_name: str
|
||||||
provider_display_name: Optional[str]
|
|
||||||
model_id: Optional[str]
|
model_id: Optional[str]
|
||||||
target_model: str
|
target_model: str
|
||||||
input_price_per_1m: Optional[float]
|
input_price_per_1m: Optional[float]
|
||||||
@@ -312,16 +311,26 @@ class ImportFromUpstreamRequest(BaseModel):
|
|||||||
"""从上游提供商导入模型请求"""
|
"""从上游提供商导入模型请求"""
|
||||||
|
|
||||||
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
|
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
|
||||||
|
# 价格覆盖配置(应用于所有导入的模型)
|
||||||
|
tiered_pricing: Optional[Dict] = Field(
|
||||||
|
None,
|
||||||
|
description="阶梯计费配置(可选),格式: {tiers: [{up_to, input_price_per_1m, output_price_per_1m, ...}]}"
|
||||||
|
)
|
||||||
|
price_per_request: Optional[float] = Field(
|
||||||
|
None,
|
||||||
|
ge=0,
|
||||||
|
description="按次计费价格(可选,单位:美元)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImportFromUpstreamSuccessItem(BaseModel):
|
class ImportFromUpstreamSuccessItem(BaseModel):
|
||||||
"""导入成功的模型信息"""
|
"""导入成功的模型信息"""
|
||||||
|
|
||||||
model_id: str = Field(..., description="上游模型 ID")
|
model_id: str = Field(..., description="上游模型 ID")
|
||||||
global_model_id: str = Field(..., description="GlobalModel ID")
|
|
||||||
global_model_name: str = Field(..., description="GlobalModel 名称")
|
|
||||||
provider_model_id: str = Field(..., description="Provider Model ID")
|
provider_model_id: str = Field(..., description="Provider Model ID")
|
||||||
created_global_model: bool = Field(..., description="是否新创建了 GlobalModel")
|
global_model_id: Optional[str] = Field("", description="GlobalModel ID(如果已关联)")
|
||||||
|
global_model_name: Optional[str] = Field("", description="GlobalModel 名称(如果已关联)")
|
||||||
|
created_global_model: bool = Field(False, description="是否新创建了 GlobalModel(始终为 false)")
|
||||||
|
|
||||||
|
|
||||||
class ImportFromUpstreamErrorItem(BaseModel):
|
class ImportFromUpstreamErrorItem(BaseModel):
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ import hashlib
|
|||||||
import secrets
|
import secrets
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import OrderedDict
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
|
|||||||
from src.services.user.apikey import ApiKeyService
|
from src.services.user.apikey import ApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
# API Key last_used_at 更新节流配置
|
||||||
|
# 同一个 API Key 在此时间间隔内只会更新一次 last_used_at
|
||||||
|
_LAST_USED_UPDATE_INTERVAL = 60 # 秒
|
||||||
|
_LAST_USED_CACHE_MAX_SIZE = 10000 # LRU 缓存最大条目数
|
||||||
|
|
||||||
|
# 进程内缓存:记录每个 API Key 最后一次更新 last_used_at 的时间
|
||||||
|
# 使用 OrderedDict 实现 LRU,避免内存无限增长
|
||||||
|
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
|
||||||
|
_last_update_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _should_update_last_used(api_key_id: str) -> bool:
|
||||||
|
"""判断是否应该更新 API Key 的 last_used_at
|
||||||
|
|
||||||
|
使用节流策略,同一个 Key 在指定间隔内只更新一次。
|
||||||
|
线程安全,使用 LRU 策略限制缓存大小。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 表示应该更新,False 表示跳过
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
with _last_update_lock:
|
||||||
|
last_update = _api_key_last_update_times.get(api_key_id, 0)
|
||||||
|
|
||||||
|
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
|
||||||
|
_api_key_last_update_times[api_key_id] = now
|
||||||
|
# LRU: 移到末尾(最近使用)
|
||||||
|
_api_key_last_update_times.move_to_end(api_key_id)
|
||||||
|
|
||||||
|
# 超过最大容量时,移除最旧的条目
|
||||||
|
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
|
||||||
|
_api_key_last_update_times.popitem(last=False)
|
||||||
|
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# JWT配置从config读取
|
# JWT配置从config读取
|
||||||
if not config.jwt_secret_key:
|
if not config.jwt_secret_key:
|
||||||
# 如果没有配置,生成一个随机密钥并警告
|
# 如果没有配置,生成一个随机密钥并警告
|
||||||
@@ -367,9 +407,10 @@ class AuthService:
|
|||||||
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
|
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新最后使用时间
|
# 更新最后使用时间(使用节流策略,减少数据库写入)
|
||||||
key_record.last_used_at = datetime.now(timezone.utc)
|
if _should_update_last_used(key_record.id):
|
||||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
key_record.last_used_at = datetime.now(timezone.utc)
|
||||||
|
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||||
|
|
||||||
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||||
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||||
|
|||||||
289
src/services/cache/aware_scheduler.py
vendored
289
src/services/cache/aware_scheduler.py
vendored
@@ -34,7 +34,7 @@ import hashlib
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session, selectinload
|
from sqlalchemy.orm import Session, selectinload
|
||||||
|
|
||||||
@@ -80,8 +80,6 @@ class ProviderCandidate:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConcurrencySnapshot:
|
class ConcurrencySnapshot:
|
||||||
endpoint_current: int
|
|
||||||
endpoint_limit: Optional[int]
|
|
||||||
key_current: int
|
key_current: int
|
||||||
key_limit: Optional[int]
|
key_limit: Optional[int]
|
||||||
is_cached_user: bool = False
|
is_cached_user: bool = False
|
||||||
@@ -91,11 +89,9 @@ class ConcurrencySnapshot:
|
|||||||
reservation_confidence: float = 0.0
|
reservation_confidence: float = 0.0
|
||||||
|
|
||||||
def describe(self) -> str:
|
def describe(self) -> str:
|
||||||
endpoint_limit_text = str(self.endpoint_limit) if self.endpoint_limit is not None else "inf"
|
|
||||||
key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
|
key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
|
||||||
reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
|
reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
|
||||||
return (
|
return (
|
||||||
f"endpoint={self.endpoint_current}/{endpoint_limit_text}, "
|
|
||||||
f"key={self.key_current}/{key_limit_text}, "
|
f"key={self.key_current}/{key_limit_text}, "
|
||||||
f"cached={self.is_cached_user}, "
|
f"cached={self.is_cached_user}, "
|
||||||
f"reserve={reservation_text}({self.reservation_phase})"
|
f"reserve={reservation_text}({self.reservation_phase})"
|
||||||
@@ -121,11 +117,13 @@ class CacheAwareScheduler:
|
|||||||
PRIORITY_MODE_GLOBAL_KEY,
|
PRIORITY_MODE_GLOBAL_KEY,
|
||||||
}
|
}
|
||||||
# 调度模式常量
|
# 调度模式常量
|
||||||
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
|
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式:严格按优先级,忽略缓存
|
||||||
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
|
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式:优先缓存,同优先级哈希分散
|
||||||
|
SCHEDULING_MODE_LOAD_BALANCE = "load_balance" # 负载均衡模式:忽略缓存,同优先级随机轮换
|
||||||
ALLOWED_SCHEDULING_MODES = {
|
ALLOWED_SCHEDULING_MODES = {
|
||||||
SCHEDULING_MODE_FIXED_ORDER,
|
SCHEDULING_MODE_FIXED_ORDER,
|
||||||
SCHEDULING_MODE_CACHE_AFFINITY,
|
SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
SCHEDULING_MODE_LOAD_BALANCE,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -244,9 +242,8 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
if not candidates:
|
if not candidates:
|
||||||
if provider_offset == 0:
|
if provider_offset == 0:
|
||||||
# 没有找到任何候选,提供友好的错误提示
|
# 没有找到任何候选,提供友好的错误提示(不暴露内部信息)
|
||||||
error_msg = f"模型 '{model_name}' 不可用"
|
raise ProviderNotAvailableException("请求的模型当前不可用")
|
||||||
raise ProviderNotAvailableException(error_msg)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
self._metrics["total_batches"] += 1
|
self._metrics["total_batches"] += 1
|
||||||
@@ -268,7 +265,6 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
is_cached_user = bool(candidate.is_cached)
|
is_cached_user = bool(candidate.is_cached)
|
||||||
can_use, snapshot = await self._check_concurrent_available(
|
can_use, snapshot = await self._check_concurrent_available(
|
||||||
endpoint,
|
|
||||||
key,
|
key,
|
||||||
is_cached_user=is_cached_user,
|
is_cached_user=is_cached_user,
|
||||||
)
|
)
|
||||||
@@ -310,47 +306,51 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
provider_offset += provider_batch_size
|
provider_offset += provider_batch_size
|
||||||
|
|
||||||
raise ProviderNotAvailableException(f"所有Provider的资源当前不可用 (model={model_name})")
|
raise ProviderNotAvailableException("服务暂时繁忙,请稍后重试")
|
||||||
|
|
||||||
def _get_effective_concurrent_limit(self, key: ProviderAPIKey) -> Optional[int]:
|
def _get_effective_rpm_limit(self, key: ProviderAPIKey) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
获取有效的并发限制
|
获取有效的 RPM 限制
|
||||||
|
|
||||||
新逻辑:
|
新逻辑:
|
||||||
- max_concurrent=NULL: 启用自适应,使用 learned_max_concurrent(如无学习记录则为 None)
|
- rpm_limit=NULL: 启用自适应,使用 learned_rpm_limit(如无学习记录则使用默认初始值)
|
||||||
- max_concurrent=数字: 固定限制,直接使用该值
|
- rpm_limit=数字: 固定限制,直接使用该值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: API Key对象
|
key: API Key对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
有效的并发限制(None 表示不限制)
|
有效的 RPM 限制(None 表示不限制)
|
||||||
"""
|
"""
|
||||||
if key.max_concurrent is None:
|
if key.rpm_limit is None:
|
||||||
# 自适应模式:使用学习到的值
|
# 自适应模式:使用学习到的值
|
||||||
learned = key.learned_max_concurrent
|
learned = key.learned_rpm_limit
|
||||||
return int(learned) if learned is not None else None
|
if learned is not None:
|
||||||
|
return int(learned)
|
||||||
|
|
||||||
|
# 未学习到值时,使用默认初始限制,避免无限制打爆上游
|
||||||
|
from src.config.constants import RPMDefaults
|
||||||
|
|
||||||
|
return int(RPMDefaults.INITIAL_LIMIT)
|
||||||
else:
|
else:
|
||||||
# 固定限制模式
|
# 固定限制模式
|
||||||
return int(key.max_concurrent)
|
return int(key.rpm_limit)
|
||||||
|
|
||||||
async def _check_concurrent_available(
|
async def _check_concurrent_available(
|
||||||
self,
|
self,
|
||||||
endpoint: ProviderEndpoint,
|
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
is_cached_user: bool = False,
|
is_cached_user: bool = False,
|
||||||
) -> Tuple[bool, ConcurrencySnapshot]:
|
) -> Tuple[bool, ConcurrencySnapshot]:
|
||||||
"""
|
"""
|
||||||
检查并发是否可用(使用动态预留机制)
|
检查 RPM 限制是否可用(使用动态预留机制)
|
||||||
|
|
||||||
核心逻辑 - 动态缓存预留机制:
|
核心逻辑 - 动态缓存预留机制:
|
||||||
- 总槽位: 有效并发限制(固定值或学习到的值)
|
- 总槽位: 有效 RPM 限制(固定值或学习到的值)
|
||||||
- 预留比例: 由 AdaptiveReservationManager 根据置信度和负载动态计算
|
- 预留比例: 由 AdaptiveReservationManager 根据置信度和负载动态计算
|
||||||
- 缓存用户可用: 全部槽位
|
- 缓存用户可用: 全部槽位
|
||||||
- 新用户可用: 总槽位 × (1 - 动态预留比例)
|
- 新用户可用: 总槽位 × (1 - 动态预留比例)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
endpoint: ProviderEndpoint对象
|
|
||||||
key: ProviderAPIKey对象
|
key: ProviderAPIKey对象
|
||||||
is_cached_user: 是否是缓存用户
|
is_cached_user: 是否是缓存用户
|
||||||
|
|
||||||
@@ -358,7 +358,7 @@ class CacheAwareScheduler:
|
|||||||
(是否可用, 并发快照)
|
(是否可用, 并发快照)
|
||||||
"""
|
"""
|
||||||
# 获取有效的并发限制
|
# 获取有效的并发限制
|
||||||
effective_key_limit = self._get_effective_concurrent_limit(key)
|
effective_key_limit = self._get_effective_rpm_limit(key)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f" -> 并发检查: _concurrency_manager={self._concurrency_manager is not None}, "
|
f" -> 并发检查: _concurrency_manager={self._concurrency_manager is not None}, "
|
||||||
@@ -369,33 +369,23 @@ class CacheAwareScheduler:
|
|||||||
# 并发管理器不可用,直接返回True
|
# 并发管理器不可用,直接返回True
|
||||||
logger.debug(f" -> 无并发管理器,直接通过")
|
logger.debug(f" -> 无并发管理器,直接通过")
|
||||||
snapshot = ConcurrencySnapshot(
|
snapshot = ConcurrencySnapshot(
|
||||||
endpoint_current=0,
|
|
||||||
endpoint_limit=(
|
|
||||||
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
|
|
||||||
),
|
|
||||||
key_current=0,
|
key_current=0,
|
||||||
key_limit=effective_key_limit,
|
key_limit=effective_key_limit,
|
||||||
is_cached_user=is_cached_user,
|
is_cached_user=is_cached_user,
|
||||||
)
|
)
|
||||||
return True, snapshot
|
return True, snapshot
|
||||||
|
|
||||||
# 获取当前并发数
|
# 获取当前 RPM 计数
|
||||||
endpoint_count, key_count = await self._concurrency_manager.get_current_concurrency(
|
key_count = await self._concurrency_manager.get_key_rpm_count(
|
||||||
endpoint_id=str(endpoint.id),
|
|
||||||
key_id=str(key.id),
|
key_id=str(key.id),
|
||||||
)
|
)
|
||||||
|
|
||||||
can_use = True
|
can_use = True
|
||||||
|
|
||||||
# 检查Endpoint级别限制
|
|
||||||
if endpoint.max_concurrent is not None:
|
|
||||||
if endpoint_count >= endpoint.max_concurrent:
|
|
||||||
can_use = False
|
|
||||||
|
|
||||||
# 计算动态预留比例
|
# 计算动态预留比例
|
||||||
reservation_result = self._reservation_manager.calculate_reservation(
|
reservation_result = self._reservation_manager.calculate_reservation(
|
||||||
key=key,
|
key=key,
|
||||||
current_concurrent=key_count,
|
current_usage=key_count,
|
||||||
effective_limit=effective_key_limit,
|
effective_limit=effective_key_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -438,7 +428,8 @@ class CacheAwareScheduler:
|
|||||||
# 使用 max 确保至少有 1 个槽位可用
|
# 使用 max 确保至少有 1 个槽位可用
|
||||||
import math
|
import math
|
||||||
|
|
||||||
available_for_new = max(1, math.ceil(effective_key_limit * (1 - reservation_ratio)))
|
# 与 ConcurrencyManager 的 Lua 脚本保持一致:使用 floor 计算新用户可用槽位
|
||||||
|
available_for_new = max(1, math.floor(effective_key_limit * (1 - reservation_ratio)))
|
||||||
if key_count >= available_for_new:
|
if key_count >= available_for_new:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Key {key.id[:8]}... 新用户配额已满 "
|
f"Key {key.id[:8]}... 新用户配额已满 "
|
||||||
@@ -458,8 +449,6 @@ class CacheAwareScheduler:
|
|||||||
key_limit_for_snapshot = None
|
key_limit_for_snapshot = None
|
||||||
|
|
||||||
snapshot = ConcurrencySnapshot(
|
snapshot = ConcurrencySnapshot(
|
||||||
endpoint_current=endpoint_count,
|
|
||||||
endpoint_limit=endpoint.max_concurrent,
|
|
||||||
key_current=key_count,
|
key_current=key_count,
|
||||||
key_limit=key_limit_for_snapshot,
|
key_limit=key_limit_for_snapshot,
|
||||||
is_cached_user=is_cached_user,
|
is_cached_user=is_cached_user,
|
||||||
@@ -473,7 +462,7 @@ class CacheAwareScheduler:
|
|||||||
def _get_effective_restrictions(
|
def _get_effective_restrictions(
|
||||||
self,
|
self,
|
||||||
user_api_key: Optional[ApiKey],
|
user_api_key: Optional[ApiKey],
|
||||||
) -> Dict[str, Optional[set]]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取有效的访问限制(合并 ApiKey 和 User 的限制)
|
获取有效的访问限制(合并 ApiKey 和 User 的限制)
|
||||||
|
|
||||||
@@ -534,7 +523,10 @@ class CacheAwareScheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 合并 allowed_models
|
# 合并 allowed_models
|
||||||
result["allowed_models"] = merge_restrictions(
|
# allowed_models 支持 list/dict 两种结构,不能转成 set 否则会导致权限校验失效
|
||||||
|
from src.core.model_permissions import merge_allowed_models
|
||||||
|
|
||||||
|
result["allowed_models"] = merge_allowed_models(
|
||||||
user_api_key.allowed_models, user.allowed_models if user else None
|
user_api_key.allowed_models, user.allowed_models if user else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -615,22 +607,25 @@ class CacheAwareScheduler:
|
|||||||
)
|
)
|
||||||
return [], global_model_id
|
return [], global_model_id
|
||||||
|
|
||||||
# 0.2 检查模型是否被允许
|
# 0.2 检查模型是否被允许(支持简单列表和按格式字典两种模式)
|
||||||
if allowed_models is not None:
|
from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
|
||||||
if (
|
|
||||||
requested_model_name not in allowed_models
|
if not check_model_allowed(
|
||||||
and resolved_model_name not in allowed_models
|
model_name=requested_model_name,
|
||||||
):
|
allowed_models=allowed_models,
|
||||||
resolved_note = (
|
api_format=target_format.value,
|
||||||
f" (解析为 {resolved_model_name})"
|
resolved_model_name=resolved_model_name,
|
||||||
if resolved_model_name != requested_model_name
|
):
|
||||||
else ""
|
resolved_note = (
|
||||||
)
|
f" (解析为 {resolved_model_name})"
|
||||||
logger.debug(
|
if resolved_model_name != requested_model_name
|
||||||
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
|
else ""
|
||||||
f"允许的模型: {allowed_models}"
|
)
|
||||||
)
|
logger.debug(
|
||||||
return [], global_model_id
|
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
|
||||||
|
f"允许的模型: {get_allowed_models_preview(allowed_models)}"
|
||||||
|
)
|
||||||
|
return [], global_model_id
|
||||||
|
|
||||||
# 1. 查询 Providers
|
# 1. 查询 Providers
|
||||||
providers = self._query_providers(
|
providers = self._query_providers(
|
||||||
@@ -680,8 +675,9 @@ class CacheAwareScheduler:
|
|||||||
f"(api_format={target_format.value}, model={model_name})"
|
f"(api_format={target_format.value}, model={model_name})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
|
# 4. 根据调度模式应用不同的排序策略
|
||||||
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
||||||
|
# 缓存亲和模式:优先使用缓存的,同优先级内哈希分散
|
||||||
if affinity_key and candidates:
|
if affinity_key and candidates:
|
||||||
candidates = await self._apply_cache_affinity(
|
candidates = await self._apply_cache_affinity(
|
||||||
candidates=candidates,
|
candidates=candidates,
|
||||||
@@ -689,8 +685,13 @@ class CacheAwareScheduler:
|
|||||||
api_format=target_format,
|
api_format=target_format,
|
||||||
global_model_id=global_model_id,
|
global_model_id=global_model_id,
|
||||||
)
|
)
|
||||||
|
elif self.scheduling_mode == self.SCHEDULING_MODE_LOAD_BALANCE:
|
||||||
|
# 负载均衡模式:忽略缓存,同优先级内随机轮换
|
||||||
|
candidates = self._apply_load_balance(candidates)
|
||||||
|
for candidate in candidates:
|
||||||
|
candidate.is_cached = False
|
||||||
else:
|
else:
|
||||||
# 固定顺序模式:标记所有候选为非缓存
|
# 固定顺序模式:严格按优先级,忽略缓存
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
candidate.is_cached = False
|
candidate.is_cached = False
|
||||||
|
|
||||||
@@ -716,8 +717,11 @@ class CacheAwareScheduler:
|
|||||||
provider_query = (
|
provider_query = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.options(
|
.options(
|
||||||
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
|
# 预加载 Provider 级别的 api_keys
|
||||||
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
|
selectinload(Provider.api_keys),
|
||||||
|
# 预加载 endpoints(用于按 api_format 选择请求配置)
|
||||||
|
selectinload(Provider.endpoints),
|
||||||
|
# 同时加载 models 和 global_model 关系
|
||||||
selectinload(Provider.models).selectinload(Model.global_model),
|
selectinload(Provider.models).selectinload(Model.global_model),
|
||||||
)
|
)
|
||||||
.filter(Provider.is_active == True)
|
.filter(Provider.is_active == True)
|
||||||
@@ -844,6 +848,7 @@ class CacheAwareScheduler:
|
|||||||
def _check_key_availability(
|
def _check_key_availability(
|
||||||
self,
|
self,
|
||||||
key: ProviderAPIKey,
|
key: ProviderAPIKey,
|
||||||
|
api_format: Optional[str],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
capability_requirements: Optional[Dict[str, bool]] = None,
|
capability_requirements: Optional[Dict[str, bool]] = None,
|
||||||
resolved_model_name: Optional[str] = None,
|
resolved_model_name: Optional[str] = None,
|
||||||
@@ -863,20 +868,24 @@ class CacheAwareScheduler:
|
|||||||
Returns:
|
Returns:
|
||||||
(is_available, skip_reason)
|
(is_available, skip_reason)
|
||||||
"""
|
"""
|
||||||
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因)
|
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因,按 API 格式)
|
||||||
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(key)
|
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(
|
||||||
|
key, api_format=api_format
|
||||||
|
)
|
||||||
if not is_available:
|
if not is_available:
|
||||||
return False, circuit_reason or "熔断器已打开"
|
return False, circuit_reason or "熔断器已打开"
|
||||||
|
|
||||||
# 模型权限检查:使用 allowed_models 白名单
|
# 模型权限检查:使用 allowed_models 白名单(支持简单列表和按格式字典两种模式)
|
||||||
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
|
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
|
||||||
if key.allowed_models is not None and (
|
from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
|
||||||
model_name not in key.allowed_models
|
|
||||||
and (not resolved_model_name or resolved_model_name not in key.allowed_models)
|
if not check_model_allowed(
|
||||||
|
model_name=model_name,
|
||||||
|
allowed_models=key.allowed_models,
|
||||||
|
api_format=api_format,
|
||||||
|
resolved_model_name=resolved_model_name,
|
||||||
):
|
):
|
||||||
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)"
|
return False, f"模型权限不匹配(允许: {get_allowed_models_preview(key.allowed_models)})"
|
||||||
suffix = "..." if len(key.allowed_models) > 3 else ""
|
|
||||||
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
|
|
||||||
|
|
||||||
# Key 级别的能力匹配检查
|
# Key 级别的能力匹配检查
|
||||||
# 注意:模型级别的能力检查已在 _check_model_support 中完成
|
# 注意:模型级别的能力检查已在 _check_model_support 中完成
|
||||||
@@ -906,6 +915,8 @@ class CacheAwareScheduler:
|
|||||||
"""
|
"""
|
||||||
构建候选列表
|
构建候选列表
|
||||||
|
|
||||||
|
Key 直属 Provider,通过 api_formats 筛选符合目标格式的 Key。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
providers: Provider 列表
|
providers: Provider 列表
|
||||||
@@ -921,10 +932,10 @@ class CacheAwareScheduler:
|
|||||||
候选列表
|
候选列表
|
||||||
"""
|
"""
|
||||||
candidates: List[ProviderCandidate] = []
|
candidates: List[ProviderCandidate] = []
|
||||||
|
target_format_str = target_format.value
|
||||||
|
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
# 检查模型支持(同时检查流式支持和模型能力需求)
|
# 检查模型支持(同时检查流式支持和模型能力需求)
|
||||||
# 模型能力检查在 Provider 级别进行,如果模型不支持所需能力,整个 Provider 被跳过
|
|
||||||
supports_model, skip_reason, _model_caps = await self._check_model_support(
|
supports_model, skip_reason, _model_caps = await self._check_model_support(
|
||||||
db, provider, model_name, is_stream, capability_requirements
|
db, provider, model_name, is_stream, capability_requirements
|
||||||
)
|
)
|
||||||
@@ -932,49 +943,63 @@ class CacheAwareScheduler:
|
|||||||
logger.debug(f"Provider {provider.name} 不支持模型 {model_name}: {skip_reason}")
|
logger.debug(f"Provider {provider.name} 不支持模型 {model_name}: {skip_reason}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 查找目标格式对应的 Endpoint(获取请求配置)
|
||||||
|
target_endpoint = None
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
# endpoint.api_format 是字符串,target_format 是枚举
|
|
||||||
endpoint_format_str = (
|
endpoint_format_str = (
|
||||||
endpoint.api_format
|
endpoint.api_format
|
||||||
if isinstance(endpoint.api_format, str)
|
if isinstance(endpoint.api_format, str)
|
||||||
else endpoint.api_format.value
|
else endpoint.api_format.value
|
||||||
)
|
)
|
||||||
if not endpoint.is_active or endpoint_format_str != target_format.value:
|
if endpoint.is_active and endpoint_format_str == target_format_str:
|
||||||
continue
|
target_endpoint = endpoint
|
||||||
|
break
|
||||||
|
|
||||||
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
if not target_endpoint:
|
||||||
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
logger.debug(f"Provider {provider.name} 没有活跃的 {target_format_str} 端点")
|
||||||
# 检查是否所有 Key 都是 TTL=0(轮换模式)
|
continue
|
||||||
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None,则使用随机排序
|
|
||||||
use_random = all(
|
|
||||||
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
|
|
||||||
) if active_keys else False
|
|
||||||
if use_random and len(active_keys) > 1:
|
|
||||||
logger.debug(
|
|
||||||
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
|
|
||||||
)
|
|
||||||
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
|
|
||||||
|
|
||||||
for key in keys:
|
# Key 直属 Provider,通过 api_formats 筛选
|
||||||
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
active_keys = [
|
||||||
is_available, skip_reason = self._check_key_availability(
|
key for key in provider.api_keys
|
||||||
key,
|
if key.is_active and target_format_str in (key.api_formats or [])
|
||||||
model_name,
|
]
|
||||||
capability_requirements,
|
|
||||||
resolved_model_name=resolved_model_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
candidate = ProviderCandidate(
|
if not active_keys:
|
||||||
provider=provider,
|
logger.debug(f"Provider {provider.name} 没有支持 {target_format_str} 的活跃 Key")
|
||||||
endpoint=endpoint,
|
continue
|
||||||
key=key,
|
|
||||||
is_skipped=not is_available,
|
|
||||||
skip_reason=skip_reason,
|
|
||||||
)
|
|
||||||
candidates.append(candidate)
|
|
||||||
|
|
||||||
if max_candidates and len(candidates) >= max_candidates:
|
# 检查是否所有 Key 都是 TTL=0(轮换模式)
|
||||||
return candidates
|
use_random = all(
|
||||||
|
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
|
||||||
|
) if active_keys else False
|
||||||
|
if use_random and len(active_keys) > 1:
|
||||||
|
logger.debug(
|
||||||
|
f" Provider {provider.name} 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
|
||||||
|
)
|
||||||
|
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
# Key 级别的能力检查
|
||||||
|
is_available, skip_reason = self._check_key_availability(
|
||||||
|
key,
|
||||||
|
target_format_str,
|
||||||
|
model_name,
|
||||||
|
capability_requirements,
|
||||||
|
resolved_model_name=resolved_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
candidate = ProviderCandidate(
|
||||||
|
provider=provider,
|
||||||
|
endpoint=target_endpoint,
|
||||||
|
key=key,
|
||||||
|
is_skipped=not is_available,
|
||||||
|
skip_reason=skip_reason,
|
||||||
|
)
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
if max_candidates and len(candidates) >= max_candidates:
|
||||||
|
return candidates
|
||||||
|
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
@@ -1163,6 +1188,56 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _apply_load_balance(
|
||||||
|
self, candidates: List[ProviderCandidate]
|
||||||
|
) -> List[ProviderCandidate]:
|
||||||
|
"""
|
||||||
|
负载均衡模式:同优先级内随机轮换
|
||||||
|
|
||||||
|
排序逻辑:
|
||||||
|
1. 按优先级分组(provider_priority, internal_priority 或 global_priority)
|
||||||
|
2. 同优先级组内随机打乱
|
||||||
|
3. 不考虑缓存亲和性
|
||||||
|
"""
|
||||||
|
if not candidates:
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)
|
||||||
|
|
||||||
|
# 根据优先级模式选择分组方式
|
||||||
|
if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
|
||||||
|
# 全局 Key 优先模式:按 global_priority 分组
|
||||||
|
for candidate in candidates:
|
||||||
|
global_priority = (
|
||||||
|
candidate.key.global_priority
|
||||||
|
if candidate.key and candidate.key.global_priority is not None
|
||||||
|
else 999999
|
||||||
|
)
|
||||||
|
priority_groups[(global_priority,)].append(candidate)
|
||||||
|
else:
|
||||||
|
# 提供商优先模式:按 (provider_priority, internal_priority) 分组
|
||||||
|
for candidate in candidates:
|
||||||
|
key = (
|
||||||
|
candidate.provider.provider_priority or 999999,
|
||||||
|
candidate.key.internal_priority if candidate.key else 999999,
|
||||||
|
)
|
||||||
|
priority_groups[key].append(candidate)
|
||||||
|
|
||||||
|
result: List[ProviderCandidate] = []
|
||||||
|
for priority in sorted(priority_groups.keys()):
|
||||||
|
group = priority_groups[priority]
|
||||||
|
if len(group) > 1:
|
||||||
|
# 同优先级内随机打乱
|
||||||
|
shuffled = list(group)
|
||||||
|
random.shuffle(shuffled)
|
||||||
|
result.extend(shuffled)
|
||||||
|
else:
|
||||||
|
result.extend(group)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def _shuffle_keys_by_internal_priority(
|
def _shuffle_keys_by_internal_priority(
|
||||||
self,
|
self,
|
||||||
keys: List[ProviderAPIKey],
|
keys: List[ProviderAPIKey],
|
||||||
|
|||||||
185
src/services/cache/provider_cache.py
vendored
Normal file
185
src/services/cache/provider_cache.py
vendored
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""
|
||||||
|
Provider 缓存服务 - 减少 Provider 和 ProviderAPIKey 查询
|
||||||
|
|
||||||
|
用于缓存 Provider 的 billing_type 和 ProviderAPIKey 的 rate_multiplier,
|
||||||
|
这些数据在 UsageService.record_usage() 中被频繁查询但变化不频繁。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from src.config.constants import CacheTTL
|
||||||
|
from src.core.cache_service import CacheService
|
||||||
|
from src.core.enums import ProviderBillingType
|
||||||
|
from src.core.logger import logger
|
||||||
|
from src.models.database import Provider, ProviderAPIKey
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCacheService:
|
||||||
|
"""Provider 缓存服务
|
||||||
|
|
||||||
|
提供 Provider 和 ProviderAPIKey 的缓存查询功能,减少数据库访问。
|
||||||
|
主要用于 UsageService 中获取费率倍数和计费类型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
CACHE_TTL = CacheTTL.PROVIDER # 5 分钟
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_provider_api_key_rate_multiplier(
|
||||||
|
db: Session, provider_api_key_id: str, api_format: Optional[str] = None
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
获取 ProviderAPIKey 的 rate_multiplier(带缓存)
|
||||||
|
|
||||||
|
优先返回指定 API 格式的倍率,如果没有则返回默认倍率。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider_api_key_id: ProviderAPIKey ID
|
||||||
|
api_format: API 格式(可选),如 "CLAUDE"、"OPENAI"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
rate_multiplier 或 None(如果找不到)
|
||||||
|
"""
|
||||||
|
# 缓存键包含 api_format
|
||||||
|
format_suffix = api_format.upper() if api_format else "default"
|
||||||
|
cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}:{format_suffix}"
|
||||||
|
|
||||||
|
# 1. 尝试从缓存获取
|
||||||
|
cached_data = await CacheService.get(cache_key)
|
||||||
|
if cached_data is not None:
|
||||||
|
logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}... format={format_suffix}")
|
||||||
|
# 缓存的 "NOT_FOUND" 表示数据库中不存在
|
||||||
|
if cached_data == "NOT_FOUND":
|
||||||
|
return None
|
||||||
|
return float(cached_data)
|
||||||
|
|
||||||
|
# 2. 缓存未命中,查询数据库
|
||||||
|
provider_key = (
|
||||||
|
db.query(ProviderAPIKey.rate_multiplier, ProviderAPIKey.rate_multipliers)
|
||||||
|
.filter(ProviderAPIKey.id == provider_api_key_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 计算倍率并写入缓存
|
||||||
|
if provider_key:
|
||||||
|
# 优先使用 rate_multipliers[api_format],回退到 rate_multiplier
|
||||||
|
rate_multiplier = provider_key.rate_multiplier or 1.0
|
||||||
|
if api_format and provider_key.rate_multipliers:
|
||||||
|
format_upper = api_format.upper()
|
||||||
|
if format_upper in provider_key.rate_multipliers:
|
||||||
|
rate_multiplier = provider_key.rate_multipliers[format_upper]
|
||||||
|
|
||||||
|
await CacheService.set(
|
||||||
|
cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||||
|
)
|
||||||
|
logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}... format={format_suffix} value={rate_multiplier}")
|
||||||
|
return rate_multiplier
|
||||||
|
else:
|
||||||
|
# 缓存负结果
|
||||||
|
await CacheService.set(
|
||||||
|
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_provider_billing_type(
|
||||||
|
db: Session, provider_id: str
|
||||||
|
) -> Optional[ProviderBillingType]:
|
||||||
|
"""
|
||||||
|
获取 Provider 的 billing_type(带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider_id: Provider ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
billing_type 或 None(如果找不到)
|
||||||
|
"""
|
||||||
|
cache_key = f"provider:billing_type:{provider_id}"
|
||||||
|
|
||||||
|
# 1. 尝试从缓存获取
|
||||||
|
cached_data = await CacheService.get(cache_key)
|
||||||
|
if cached_data is not None:
|
||||||
|
logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
|
||||||
|
if cached_data == "NOT_FOUND":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return ProviderBillingType(cached_data)
|
||||||
|
except ValueError:
|
||||||
|
# 缓存值无效,删除并重新查询
|
||||||
|
await CacheService.delete(cache_key)
|
||||||
|
|
||||||
|
# 2. 缓存未命中,查询数据库
|
||||||
|
provider = (
|
||||||
|
db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 写入缓存
|
||||||
|
if provider:
|
||||||
|
billing_type = provider.billing_type
|
||||||
|
await CacheService.set(
|
||||||
|
cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||||
|
)
|
||||||
|
logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
|
||||||
|
return billing_type
|
||||||
|
else:
|
||||||
|
# 缓存负结果
|
||||||
|
await CacheService.set(
|
||||||
|
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_rate_multiplier_and_free_tier(
|
||||||
|
db: Session,
|
||||||
|
provider_api_key_id: Optional[str],
|
||||||
|
provider_id: Optional[str],
|
||||||
|
api_format: Optional[str] = None,
|
||||||
|
) -> Tuple[float, bool]:
|
||||||
|
"""
|
||||||
|
获取费率倍数和是否免费套餐(带缓存)
|
||||||
|
|
||||||
|
这是 UsageService._get_rate_multiplier_and_free_tier 的缓存版本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
provider_api_key_id: ProviderAPIKey ID(可选)
|
||||||
|
provider_id: Provider ID(可选)
|
||||||
|
api_format: API 格式(可选),用于获取按格式配置的倍率
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(rate_multiplier, is_free_tier) 元组
|
||||||
|
"""
|
||||||
|
actual_rate_multiplier = 1.0
|
||||||
|
is_free_tier = False
|
||||||
|
|
||||||
|
# 获取费率倍数(支持按 API 格式查询)
|
||||||
|
if provider_api_key_id:
|
||||||
|
rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
|
||||||
|
db, provider_api_key_id, api_format
|
||||||
|
)
|
||||||
|
if rate_multiplier is not None:
|
||||||
|
actual_rate_multiplier = rate_multiplier
|
||||||
|
|
||||||
|
# 获取计费类型
|
||||||
|
if provider_id:
|
||||||
|
billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
|
||||||
|
if billing_type == ProviderBillingType.FREE_TIER:
|
||||||
|
is_free_tier = True
|
||||||
|
|
||||||
|
return actual_rate_multiplier, is_free_tier
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
|
||||||
|
"""清除 ProviderAPIKey 缓存(包括所有 API 格式的缓存)"""
|
||||||
|
# 使用模式匹配删除所有格式的缓存
|
||||||
|
await CacheService.delete_pattern(f"provider_api_key:rate_multiplier:{provider_api_key_id}:*")
|
||||||
|
logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def invalidate_provider_cache(provider_id: str) -> None:
|
||||||
|
"""清除 Provider 缓存"""
|
||||||
|
await CacheService.delete(f"provider:billing_type:{provider_id}")
|
||||||
|
logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")
|
||||||
@@ -70,20 +70,21 @@ class EndpointHealthService:
|
|||||||
db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
|
db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 收集所有 endpoint_ids
|
# 收集所有 provider_ids
|
||||||
all_endpoint_ids = [ep.id for ep in endpoints]
|
all_provider_ids = list(set(ep.provider_id for ep in endpoints))
|
||||||
|
|
||||||
# 批量查询所有密钥
|
# 批量查询所有密钥(通过 provider_id 关联)
|
||||||
all_keys = (
|
all_keys = (
|
||||||
db.query(ProviderAPIKey)
|
db.query(ProviderAPIKey)
|
||||||
.filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids))
|
.filter(ProviderAPIKey.provider_id.in_(all_provider_ids))
|
||||||
.all()
|
.all()
|
||||||
) if all_endpoint_ids else []
|
) if all_provider_ids else []
|
||||||
|
|
||||||
# 按 endpoint_id 分组密钥
|
# 按 api_format 分组密钥(通过 api_formats 字段)
|
||||||
keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
|
keys_by_format: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
|
||||||
for key in all_keys:
|
for key in all_keys:
|
||||||
keys_by_endpoint[key.endpoint_id].append(key)
|
for fmt in (key.api_formats or []):
|
||||||
|
keys_by_format[fmt].append(key)
|
||||||
|
|
||||||
# 按 API 格式聚合
|
# 按 API 格式聚合
|
||||||
format_stats = defaultdict(
|
format_stats = defaultdict(
|
||||||
@@ -106,18 +107,36 @@ class EndpointHealthService:
|
|||||||
format_stats[api_format]["endpoint_ids"].append(ep.id)
|
format_stats[api_format]["endpoint_ids"].append(ep.id)
|
||||||
format_stats[api_format]["provider_ids"].add(ep.provider_id)
|
format_stats[api_format]["provider_ids"].add(ep.provider_id)
|
||||||
|
|
||||||
# 从预加载的密钥中获取
|
# 统计每个格式的密钥(直接从 keys_by_format 获取)
|
||||||
keys = keys_by_endpoint.get(ep.id, [])
|
for api_format, keys in keys_by_format.items():
|
||||||
format_stats[api_format]["total_keys"] += len(keys)
|
if api_format not in format_stats:
|
||||||
|
# 如果有 Key 但没有对应的 Endpoint,跳过
|
||||||
|
continue
|
||||||
|
|
||||||
# 统计活跃密钥和健康度
|
# 去重(同一个 Key 可能支持多个格式)
|
||||||
if ep.is_active:
|
seen_key_ids = set()
|
||||||
for key in keys:
|
unique_keys = []
|
||||||
format_stats[api_format]["key_ids"].append(key.id)
|
for key in keys:
|
||||||
if key.is_active and not key.circuit_breaker_open:
|
if key.id not in seen_key_ids:
|
||||||
format_stats[api_format]["active_keys"] += 1
|
seen_key_ids.add(key.id)
|
||||||
health_score = key.health_score if key.health_score is not None else 1.0
|
unique_keys.append(key)
|
||||||
format_stats[api_format]["health_scores"].append(health_score)
|
|
||||||
|
format_stats[api_format]["total_keys"] = len(unique_keys)
|
||||||
|
|
||||||
|
for key in unique_keys:
|
||||||
|
format_stats[api_format]["key_ids"].append(key.id)
|
||||||
|
# 检查该格式的熔断器状态
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
format_circuit = circuit_by_format.get(api_format, {})
|
||||||
|
is_circuit_open = format_circuit.get("open", False)
|
||||||
|
|
||||||
|
if key.is_active and not is_circuit_open:
|
||||||
|
format_stats[api_format]["active_keys"] += 1
|
||||||
|
# 获取该格式的健康度
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
format_health = health_by_format.get(api_format, {})
|
||||||
|
health_score = float(format_health.get("health_score") or 1.0)
|
||||||
|
format_stats[api_format]["health_scores"].append(health_score)
|
||||||
|
|
||||||
# 批量生成所有格式的时间线数据
|
# 批量生成所有格式的时间线数据
|
||||||
all_key_ids = []
|
all_key_ids = []
|
||||||
@@ -372,7 +391,7 @@ class EndpointHealthService:
|
|||||||
segments: int = 100,
|
segments: int = 100,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
从真实使用记录生成时间线数据(兼容旧接口,使用批量查询优化)
|
从真实使用记录生成时间线数据(使用批量查询优化)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
@@ -391,13 +410,34 @@ class EndpointHealthService:
|
|||||||
"time_range_end": None,
|
"time_range_end": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 先查询该 API 格式下的所有密钥
|
# 基于 endpoint_ids 反推 provider_ids 与 api_format,再选出支持该格式的 keys
|
||||||
key_ids = [
|
endpoint_rows = (
|
||||||
k.id
|
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
|
||||||
for k in db.query(ProviderAPIKey.id)
|
.filter(ProviderEndpoint.id.in_(endpoint_ids))
|
||||||
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
|
|
||||||
.all()
|
.all()
|
||||||
]
|
)
|
||||||
|
|
||||||
|
if not endpoint_rows:
|
||||||
|
return {
|
||||||
|
"timeline": ["unknown"] * 100,
|
||||||
|
"time_range_start": None,
|
||||||
|
"time_range_end": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
provider_ids = {str(pid) for pid, _fmt in endpoint_rows}
|
||||||
|
# 同一调用中 endpoint_ids 来自同一 api_format(上层已按格式分组)
|
||||||
|
api_format = (
|
||||||
|
endpoint_rows[0][1].value
|
||||||
|
if hasattr(endpoint_rows[0][1], "value")
|
||||||
|
else str(endpoint_rows[0][1])
|
||||||
|
)
|
||||||
|
|
||||||
|
keys = (
|
||||||
|
db.query(ProviderAPIKey.id, ProviderAPIKey.api_formats)
|
||||||
|
.filter(ProviderAPIKey.provider_id.in_(provider_ids))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
key_ids = [str(key_id) for key_id, formats in keys if api_format in (formats or [])]
|
||||||
|
|
||||||
if not key_ids:
|
if not key_ids:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
健康监控器 - Endpoint 和 Key 的健康度追踪
|
健康监控器 - Endpoint 和 Key 的健康度追踪(按 API 格式区分)
|
||||||
|
|
||||||
功能:
|
功能:
|
||||||
1. 基于滑动窗口的错误率计算
|
1. 基于滑动窗口的错误率计算(按 API 格式独立)
|
||||||
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭
|
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭(按 API 格式独立)
|
||||||
3. 半开状态允许少量请求验证服务恢复
|
3. 半开状态允许少量请求验证服务恢复
|
||||||
4. 提供健康度查询和管理 API
|
4. 提供健康度查询和管理 API
|
||||||
|
|
||||||
|
数据结构:
|
||||||
|
- health_by_format: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, ...}, ...}
|
||||||
|
- circuit_breaker_by_format: {"CLAUDE": {"open": false, "open_at": null, ...}, ...}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -30,8 +34,30 @@ class CircuitState:
|
|||||||
HALF_OPEN = "half_open" # 半开(验证恢复)
|
HALF_OPEN = "half_open" # 半开(验证恢复)
|
||||||
|
|
||||||
|
|
||||||
|
# 默认健康度数据结构
|
||||||
|
def _default_health_data() -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"health_score": 1.0,
|
||||||
|
"consecutive_failures": 0,
|
||||||
|
"last_failure_at": None,
|
||||||
|
"request_results_window": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 默认熔断器数据结构
|
||||||
|
def _default_circuit_data() -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"open": False,
|
||||||
|
"open_at": None,
|
||||||
|
"next_probe_at": None,
|
||||||
|
"half_open_until": None,
|
||||||
|
"half_open_successes": 0,
|
||||||
|
"half_open_failures": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class HealthMonitor:
|
class HealthMonitor:
|
||||||
"""健康监控器(滑动窗口 + 半开状态模式)"""
|
"""健康监控器(滑动窗口 + 半开状态模式,按 API 格式区分)"""
|
||||||
|
|
||||||
# === 滑动窗口配置 ===
|
# === 滑动窗口配置 ===
|
||||||
WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE)))
|
WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE)))
|
||||||
@@ -96,6 +122,38 @@ class HealthMonitor:
|
|||||||
_circuit_history: List[Dict[str, Any]] = []
|
_circuit_history: List[Dict[str, Any]] = []
|
||||||
_open_circuit_keys: int = 0
|
_open_circuit_keys: int = 0
|
||||||
|
|
||||||
|
# ==================== 数据访问辅助方法 ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_health_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
|
||||||
|
"""获取指定格式的健康度数据,不存在则返回默认值"""
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
if api_format not in health_by_format:
|
||||||
|
return _default_health_data()
|
||||||
|
return health_by_format[api_format]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_health_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""设置指定格式的健康度数据"""
|
||||||
|
health_by_format = dict(key.health_by_format or {})
|
||||||
|
health_by_format[api_format] = data
|
||||||
|
key.health_by_format = health_by_format # type: ignore[assignment]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_circuit_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
|
||||||
|
"""获取指定格式的熔断器数据,不存在则返回默认值"""
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
if api_format not in circuit_by_format:
|
||||||
|
return _default_circuit_data()
|
||||||
|
return circuit_by_format[api_format]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_circuit_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
|
||||||
|
"""设置指定格式的熔断器数据"""
|
||||||
|
circuit_by_format = dict(key.circuit_breaker_by_format or {})
|
||||||
|
circuit_by_format[api_format] = data
|
||||||
|
key.circuit_breaker_by_format = circuit_by_format # type: ignore[assignment]
|
||||||
|
|
||||||
# ==================== 核心方法 ====================
|
# ==================== 核心方法 ====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -103,9 +161,21 @@ class HealthMonitor:
|
|||||||
cls,
|
cls,
|
||||||
db: Session,
|
db: Session,
|
||||||
key_id: Optional[str] = None,
|
key_id: Optional[str] = None,
|
||||||
|
api_format: Optional[str] = None,
|
||||||
response_time_ms: Optional[int] = None,
|
response_time_ms: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录成功请求"""
|
"""记录成功请求(按 API 格式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
key_id: Key ID(必需)
|
||||||
|
api_format: API 格式(必需,用于区分不同格式的健康度)
|
||||||
|
response_time_ms: 响应时间(可选)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
api_format 在逻辑上是必需的,但为了向后兼容保持 Optional 签名。
|
||||||
|
如果未提供,会尝试从 Key 的 api_formats 中获取第一个格式作为 fallback。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if not key_id:
|
if not key_id:
|
||||||
return
|
return
|
||||||
@@ -114,39 +184,96 @@ class HealthMonitor:
|
|||||||
if not key:
|
if not key:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# api_format 兼容处理:如果未提供,尝试使用 Key 的第一个格式
|
||||||
|
effective_api_format = api_format
|
||||||
|
if not effective_api_format:
|
||||||
|
if key.api_formats and len(key.api_formats) > 0:
|
||||||
|
effective_api_format = key.api_formats[0]
|
||||||
|
logger.debug(
|
||||||
|
f"record_success: api_format 未提供,使用默认格式 {effective_api_format}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"record_success: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
now_ts = now.timestamp()
|
now_ts = now.timestamp()
|
||||||
|
|
||||||
|
# 获取当前格式的健康度数据
|
||||||
|
health_data = cls._get_health_data(key, effective_api_format)
|
||||||
|
circuit_data = cls._get_circuit_data(key, effective_api_format)
|
||||||
|
|
||||||
# 1. 更新滑动窗口
|
# 1. 更新滑动窗口
|
||||||
cls._add_to_window(key, now_ts, success=True)
|
window = health_data.get("request_results_window") or []
|
||||||
|
window.append({"ts": now_ts, "ok": True})
|
||||||
|
cutoff_ts = now_ts - cls.WINDOW_SECONDS
|
||||||
|
window = [r for r in window if r["ts"] > cutoff_ts]
|
||||||
|
if len(window) > cls.WINDOW_SIZE:
|
||||||
|
window = window[-cls.WINDOW_SIZE :]
|
||||||
|
health_data["request_results_window"] = window
|
||||||
|
|
||||||
# 2. 更新健康度(用于展示)
|
# 2. 更新健康度(用于展示)
|
||||||
new_score = min(float(key.health_score or 0) + cls.SUCCESS_INCREMENT, 1.0)
|
current_score = float(health_data.get("health_score") or 0)
|
||||||
key.health_score = new_score # type: ignore[assignment]
|
new_score = min(current_score + cls.SUCCESS_INCREMENT, 1.0)
|
||||||
|
health_data["health_score"] = new_score
|
||||||
|
|
||||||
# 3. 更新统计
|
# 3. 更新统计
|
||||||
key.consecutive_failures = 0 # type: ignore[assignment]
|
health_data["consecutive_failures"] = 0
|
||||||
key.last_failure_at = None # type: ignore[assignment]
|
health_data["last_failure_at"] = None
|
||||||
|
|
||||||
|
# 4. 处理熔断器状态
|
||||||
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
|
||||||
|
if state == CircuitState.HALF_OPEN:
|
||||||
|
# 半开状态:记录成功
|
||||||
|
circuit_data["half_open_successes"] = int(
|
||||||
|
circuit_data.get("half_open_successes") or 0
|
||||||
|
) + 1
|
||||||
|
|
||||||
|
if circuit_data["half_open_successes"] >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
|
||||||
|
# 达到成功阈值,关闭熔断器
|
||||||
|
cls._close_circuit_data(circuit_data, health_data, reason="半开状态验证成功")
|
||||||
|
cls._push_circuit_event(
|
||||||
|
{
|
||||||
|
"event": "closed",
|
||||||
|
"key_id": key.id,
|
||||||
|
"api_format": effective_api_format,
|
||||||
|
"reason": "半开状态验证成功",
|
||||||
|
"timestamp": now.isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[CLOSED] Key 熔断器关闭: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif state == CircuitState.OPEN:
|
||||||
|
# 打开状态下的成功(探测成功),进入半开状态
|
||||||
|
cls._enter_half_open_data(circuit_data, now)
|
||||||
|
cls._push_circuit_event(
|
||||||
|
{
|
||||||
|
"event": "half_open",
|
||||||
|
"key_id": key.id,
|
||||||
|
"api_format": effective_api_format,
|
||||||
|
"timestamp": now.isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}.../{effective_api_format} | "
|
||||||
|
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
cls._set_health_data(key, effective_api_format, health_data)
|
||||||
|
cls._set_circuit_data(key, effective_api_format, circuit_data)
|
||||||
|
|
||||||
|
# 更新全局统计
|
||||||
key.success_count = int(key.success_count or 0) + 1 # type: ignore[assignment]
|
key.success_count = int(key.success_count or 0) + 1 # type: ignore[assignment]
|
||||||
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
|
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
|
||||||
if response_time_ms:
|
if response_time_ms:
|
||||||
key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms # type: ignore[assignment]
|
key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms # type: ignore[assignment]
|
||||||
|
|
||||||
# 4. 处理熔断器状态
|
|
||||||
state = cls._get_circuit_state(key, now)
|
|
||||||
|
|
||||||
if state == CircuitState.HALF_OPEN:
|
|
||||||
# 半开状态:记录成功
|
|
||||||
key.half_open_successes = int(key.half_open_successes or 0) + 1 # type: ignore[assignment]
|
|
||||||
|
|
||||||
if int(key.half_open_successes or 0) >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
|
|
||||||
# 达到成功阈值,关闭熔断器
|
|
||||||
cls._close_circuit(key, now, reason="半开状态验证成功")
|
|
||||||
|
|
||||||
elif state == CircuitState.OPEN:
|
|
||||||
# 打开状态下的成功(探测成功),进入半开状态
|
|
||||||
cls._enter_half_open(key, now)
|
|
||||||
|
|
||||||
db.flush()
|
db.flush()
|
||||||
get_batch_committer().mark_dirty(db)
|
get_batch_committer().mark_dirty(db)
|
||||||
|
|
||||||
@@ -159,9 +286,21 @@ class HealthMonitor:
|
|||||||
cls,
|
cls,
|
||||||
db: Session,
|
db: Session,
|
||||||
key_id: Optional[str] = None,
|
key_id: Optional[str] = None,
|
||||||
|
api_format: Optional[str] = None,
|
||||||
error_type: Optional[str] = None,
|
error_type: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录失败请求"""
|
"""记录失败请求(按 API 格式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
key_id: Key ID(必需)
|
||||||
|
api_format: API 格式(必需,用于区分不同格式的健康度)
|
||||||
|
error_type: 错误类型(可选)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
api_format 在逻辑上是必需的,但为了向后兼容保持 Optional 签名。
|
||||||
|
如果未提供,会尝试从 Key 的 api_formats 中获取第一个格式作为 fallback。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if not key_id:
|
if not key_id:
|
||||||
return
|
return
|
||||||
@@ -170,46 +309,117 @@ class HealthMonitor:
|
|||||||
if not key:
|
if not key:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# api_format 兼容处理:如果未提供,尝试使用 Key 的第一个格式
|
||||||
|
effective_api_format = api_format
|
||||||
|
if not effective_api_format:
|
||||||
|
if key.api_formats and len(key.api_formats) > 0:
|
||||||
|
effective_api_format = key.api_formats[0]
|
||||||
|
logger.debug(
|
||||||
|
f"record_failure: api_format 未提供,使用默认格式 {effective_api_format}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"record_failure: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
now_ts = now.timestamp()
|
now_ts = now.timestamp()
|
||||||
|
|
||||||
|
# 获取当前格式的健康度数据
|
||||||
|
health_data = cls._get_health_data(key, effective_api_format)
|
||||||
|
circuit_data = cls._get_circuit_data(key, effective_api_format)
|
||||||
|
|
||||||
# 1. 更新滑动窗口
|
# 1. 更新滑动窗口
|
||||||
cls._add_to_window(key, now_ts, success=False)
|
window = health_data.get("request_results_window") or []
|
||||||
|
window.append({"ts": now_ts, "ok": False})
|
||||||
|
cutoff_ts = now_ts - cls.WINDOW_SECONDS
|
||||||
|
window = [r for r in window if r["ts"] > cutoff_ts]
|
||||||
|
if len(window) > cls.WINDOW_SIZE:
|
||||||
|
window = window[-cls.WINDOW_SIZE :]
|
||||||
|
health_data["request_results_window"] = window
|
||||||
|
|
||||||
# 2. 更新健康度(用于展示)
|
# 2. 更新健康度(用于展示)
|
||||||
new_score = max(float(key.health_score or 1) - cls.FAILURE_DECREMENT, 0.0)
|
current_score = float(health_data.get("health_score") or 1)
|
||||||
key.health_score = new_score # type: ignore[assignment]
|
new_score = max(current_score - cls.FAILURE_DECREMENT, 0.0)
|
||||||
|
health_data["health_score"] = new_score
|
||||||
|
|
||||||
# 3. 更新统计
|
# 3. 更新统计
|
||||||
key.consecutive_failures = int(key.consecutive_failures or 0) + 1 # type: ignore[assignment]
|
health_data["consecutive_failures"] = (
|
||||||
key.last_failure_at = now # type: ignore[assignment]
|
int(health_data.get("consecutive_failures") or 0) + 1
|
||||||
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment]
|
)
|
||||||
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
|
health_data["last_failure_at"] = now.isoformat()
|
||||||
|
|
||||||
# 4. 处理熔断器状态
|
# 4. 处理熔断器状态
|
||||||
state = cls._get_circuit_state(key, now)
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
|
||||||
if state == CircuitState.HALF_OPEN:
|
if state == CircuitState.HALF_OPEN:
|
||||||
# 半开状态:记录失败
|
# 半开状态:记录失败
|
||||||
key.half_open_failures = int(key.half_open_failures or 0) + 1 # type: ignore[assignment]
|
circuit_data["half_open_failures"] = int(
|
||||||
|
circuit_data.get("half_open_failures") or 0
|
||||||
|
) + 1
|
||||||
|
|
||||||
if int(key.half_open_failures or 0) >= cls.HALF_OPEN_FAILURE_THRESHOLD:
|
if circuit_data["half_open_failures"] >= cls.HALF_OPEN_FAILURE_THRESHOLD:
|
||||||
# 达到失败阈值,重新打开熔断器
|
# 达到失败阈值,重新打开熔断器
|
||||||
cls._open_circuit(key, now, reason="半开状态验证失败")
|
# 注意:半开状态本身就是打开状态的子状态,不需要增加计数
|
||||||
|
consecutive = int(health_data.get("consecutive_failures") or 0)
|
||||||
|
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
|
||||||
|
cls._open_circuit_data(
|
||||||
|
circuit_data, now, recovery_seconds, reason="半开状态验证失败"
|
||||||
|
)
|
||||||
|
cls._push_circuit_event(
|
||||||
|
{
|
||||||
|
"event": "opened",
|
||||||
|
"key_id": key.id,
|
||||||
|
"api_format": effective_api_format,
|
||||||
|
"reason": "半开状态验证失败",
|
||||||
|
"recovery_seconds": recovery_seconds,
|
||||||
|
"timestamp": now.isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证失败 | "
|
||||||
|
f"{recovery_seconds}秒后进入半开状态"
|
||||||
|
)
|
||||||
|
|
||||||
elif state == CircuitState.CLOSED:
|
elif state == CircuitState.CLOSED:
|
||||||
# 关闭状态:检查是否需要打开熔断器
|
# 关闭状态:检查是否需要打开熔断器
|
||||||
error_rate = cls._calculate_error_rate(key, now_ts)
|
error_rate = cls._calculate_error_rate_from_window(window, now_ts)
|
||||||
window = key.request_results_window or []
|
|
||||||
|
|
||||||
if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD:
|
if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD:
|
||||||
cls._open_circuit(
|
consecutive = int(health_data.get("consecutive_failures") or 0)
|
||||||
key, now, reason=f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
|
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
|
||||||
|
reason = f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
|
||||||
|
cls._open_circuit_data(circuit_data, now, recovery_seconds, reason=reason)
|
||||||
|
cls._open_circuit_keys += 1
|
||||||
|
health_open_circuits.set(cls._open_circuit_keys)
|
||||||
|
cls._push_circuit_event(
|
||||||
|
{
|
||||||
|
"event": "opened",
|
||||||
|
"key_id": key.id,
|
||||||
|
"api_format": effective_api_format,
|
||||||
|
"reason": reason,
|
||||||
|
"recovery_seconds": recovery_seconds,
|
||||||
|
"timestamp": now.isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: {reason} | "
|
||||||
|
f"{recovery_seconds}秒后进入半开状态"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 保存数据
|
||||||
|
cls._set_health_data(key, effective_api_format, health_data)
|
||||||
|
cls._set_circuit_data(key, effective_api_format, circuit_data)
|
||||||
|
|
||||||
|
# 更新全局统计
|
||||||
|
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment]
|
||||||
|
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
|
||||||
|
key.last_error_at = now # type: ignore[assignment]
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[WARN] Key 健康度下降: {key_id[:8]}... -> {new_score:.2f} "
|
f"[WARN] Key 健康度下降: {key_id[:8]}.../{effective_api_format} -> {new_score:.2f} "
|
||||||
f"(连续失败 {key.consecutive_failures} 次, error_type={error_type})"
|
f"(连续失败 {health_data['consecutive_failures']} 次, error_type={error_type})"
|
||||||
)
|
)
|
||||||
|
|
||||||
db.flush()
|
db.flush()
|
||||||
@@ -222,31 +432,13 @@ class HealthMonitor:
|
|||||||
# ==================== 滑动窗口方法 ====================
|
# ==================== 滑动窗口方法 ====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _add_to_window(cls, key: ProviderAPIKey, now_ts: float, success: bool) -> None:
|
def _calculate_error_rate_from_window(
|
||||||
"""添加请求结果到滑动窗口"""
|
cls, window: List[Dict[str, Any]], now_ts: float
|
||||||
window: List[Dict[str, Any]] = key.request_results_window or []
|
) -> float:
|
||||||
|
"""从窗口数据计算错误率"""
|
||||||
# 添加新记录
|
|
||||||
window.append({"ts": now_ts, "ok": success})
|
|
||||||
|
|
||||||
# 清理过期记录
|
|
||||||
cutoff_ts = now_ts - cls.WINDOW_SECONDS
|
|
||||||
window = [r for r in window if r["ts"] > cutoff_ts]
|
|
||||||
|
|
||||||
# 限制窗口大小
|
|
||||||
if len(window) > cls.WINDOW_SIZE:
|
|
||||||
window = window[-cls.WINDOW_SIZE :]
|
|
||||||
|
|
||||||
key.request_results_window = window # type: ignore[assignment]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _calculate_error_rate(cls, key: ProviderAPIKey, now_ts: float) -> float:
|
|
||||||
"""计算滑动窗口内的错误率"""
|
|
||||||
window: List[Dict[str, Any]] = key.request_results_window or []
|
|
||||||
if not window:
|
if not window:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# 过滤过期记录
|
|
||||||
cutoff_ts = now_ts - cls.WINDOW_SECONDS
|
cutoff_ts = now_ts - cls.WINDOW_SECONDS
|
||||||
valid_records = [r for r in window if r["ts"] > cutoff_ts]
|
valid_records = [r for r in window if r["ts"] > cutoff_ts]
|
||||||
|
|
||||||
@@ -256,157 +448,158 @@ class HealthMonitor:
|
|||||||
failures = sum(1 for r in valid_records if not r["ok"])
|
failures = sum(1 for r in valid_records if not r["ok"])
|
||||||
return failures / len(valid_records)
|
return failures / len(valid_records)
|
||||||
|
|
||||||
# ==================== 熔断器状态方法 ====================
|
# ==================== 熔断器状态方法(操作数据字典)====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_circuit_state(cls, key: ProviderAPIKey, now: datetime) -> str:
|
def _get_circuit_state_from_data(cls, circuit_data: Dict[str, Any], now: datetime) -> str:
|
||||||
"""获取当前熔断器状态"""
|
"""从数据字典获取当前熔断器状态"""
|
||||||
if not key.circuit_breaker_open:
|
if not circuit_data.get("open"):
|
||||||
return CircuitState.CLOSED
|
return CircuitState.CLOSED
|
||||||
|
|
||||||
# 检查是否在半开状态
|
# 检查是否在半开状态
|
||||||
if key.half_open_until and now < key.half_open_until:
|
half_open_until_str = circuit_data.get("half_open_until")
|
||||||
return CircuitState.HALF_OPEN
|
if half_open_until_str:
|
||||||
|
half_open_until = datetime.fromisoformat(half_open_until_str)
|
||||||
|
if now < half_open_until:
|
||||||
|
return CircuitState.HALF_OPEN
|
||||||
|
|
||||||
# 检查是否到了探测时间(进入半开)
|
# 检查是否到了探测时间(进入半开)
|
||||||
if key.next_probe_at and now >= key.next_probe_at:
|
next_probe_str = circuit_data.get("next_probe_at")
|
||||||
return CircuitState.HALF_OPEN
|
if next_probe_str:
|
||||||
|
next_probe_at = datetime.fromisoformat(next_probe_str)
|
||||||
|
if now >= next_probe_at:
|
||||||
|
return CircuitState.HALF_OPEN
|
||||||
|
|
||||||
return CircuitState.OPEN
|
return CircuitState.OPEN
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _open_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
|
def _open_circuit_data(
|
||||||
"""打开熔断器"""
|
cls,
|
||||||
was_open = key.circuit_breaker_open
|
circuit_data: Dict[str, Any],
|
||||||
|
now: datetime,
|
||||||
key.circuit_breaker_open = True # type: ignore[assignment]
|
recovery_seconds: int,
|
||||||
key.circuit_breaker_open_at = now # type: ignore[assignment]
|
reason: str,
|
||||||
key.half_open_until = None # type: ignore[assignment]
|
) -> None:
|
||||||
key.half_open_successes = 0 # type: ignore[assignment]
|
"""打开熔断器(操作数据字典)"""
|
||||||
key.half_open_failures = 0 # type: ignore[assignment]
|
circuit_data["open"] = True
|
||||||
|
circuit_data["open_at"] = now.isoformat()
|
||||||
# 计算下次探测时间(进入半开状态的时间)
|
circuit_data["half_open_until"] = None
|
||||||
consecutive = int(key.consecutive_failures or 0)
|
circuit_data["half_open_successes"] = 0
|
||||||
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
|
circuit_data["half_open_failures"] = 0
|
||||||
key.next_probe_at = now + timedelta(seconds=recovery_seconds) # type: ignore[assignment]
|
circuit_data["next_probe_at"] = (now + timedelta(seconds=recovery_seconds)).isoformat()
|
||||||
|
|
||||||
if not was_open:
|
|
||||||
cls._open_circuit_keys += 1
|
|
||||||
health_open_circuits.set(cls._open_circuit_keys)
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"[OPEN] Key 熔断器打开: {key.id[:8]}... | 原因: {reason} | "
|
|
||||||
f"{recovery_seconds}秒后进入半开状态"
|
|
||||||
)
|
|
||||||
|
|
||||||
cls._push_circuit_event(
|
|
||||||
{
|
|
||||||
"event": "opened",
|
|
||||||
"key_id": key.id,
|
|
||||||
"reason": reason,
|
|
||||||
"recovery_seconds": recovery_seconds,
|
|
||||||
"timestamp": now.isoformat(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _enter_half_open(cls, key: ProviderAPIKey, now: datetime) -> None:
|
def _enter_half_open_data(cls, circuit_data: Dict[str, Any], now: datetime) -> None:
|
||||||
"""进入半开状态"""
|
"""进入半开状态(操作数据字典)"""
|
||||||
key.half_open_until = now + timedelta(seconds=cls.HALF_OPEN_DURATION) # type: ignore[assignment]
|
circuit_data["half_open_until"] = (
|
||||||
key.half_open_successes = 0 # type: ignore[assignment]
|
now + timedelta(seconds=cls.HALF_OPEN_DURATION)
|
||||||
key.half_open_failures = 0 # type: ignore[assignment]
|
).isoformat()
|
||||||
|
circuit_data["half_open_successes"] = 0
|
||||||
logger.info(
|
circuit_data["half_open_failures"] = 0
|
||||||
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}... | "
|
|
||||||
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
|
|
||||||
)
|
|
||||||
|
|
||||||
cls._push_circuit_event(
|
|
||||||
{
|
|
||||||
"event": "half_open",
|
|
||||||
"key_id": key.id,
|
|
||||||
"timestamp": now.isoformat(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _close_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
|
def _close_circuit_data(
|
||||||
"""关闭熔断器"""
|
cls, circuit_data: Dict[str, Any], health_data: Dict[str, Any], reason: str
|
||||||
key.circuit_breaker_open = False # type: ignore[assignment]
|
) -> None:
|
||||||
key.circuit_breaker_open_at = None # type: ignore[assignment]
|
"""关闭熔断器(操作数据字典)"""
|
||||||
key.next_probe_at = None # type: ignore[assignment]
|
circuit_data["open"] = False
|
||||||
key.half_open_until = None # type: ignore[assignment]
|
circuit_data["open_at"] = None
|
||||||
key.half_open_successes = 0 # type: ignore[assignment]
|
circuit_data["next_probe_at"] = None
|
||||||
key.half_open_failures = 0 # type: ignore[assignment]
|
circuit_data["half_open_until"] = None
|
||||||
|
circuit_data["half_open_successes"] = 0
|
||||||
|
circuit_data["half_open_failures"] = 0
|
||||||
|
|
||||||
# 快速恢复健康度
|
# 快速恢复健康度
|
||||||
key.health_score = max(float(key.health_score or 0), cls.PROBE_RECOVERY_SCORE) # type: ignore[assignment]
|
current_score = float(health_data.get("health_score") or 0)
|
||||||
|
health_data["health_score"] = max(current_score, cls.PROBE_RECOVERY_SCORE)
|
||||||
|
|
||||||
cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1)
|
cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1)
|
||||||
health_open_circuits.set(cls._open_circuit_keys)
|
health_open_circuits.set(cls._open_circuit_keys)
|
||||||
|
|
||||||
logger.info(f"[CLOSED] Key 熔断器关闭: {key.id[:8]}... | 原因: {reason}")
|
|
||||||
|
|
||||||
cls._push_circuit_event(
|
|
||||||
{
|
|
||||||
"event": "closed",
|
|
||||||
"key_id": key.id,
|
|
||||||
"reason": reason,
|
|
||||||
"timestamp": now.isoformat(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int:
|
def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int:
|
||||||
"""计算恢复等待时间(指数退避)"""
|
"""计算恢复等待时间(指数退避)"""
|
||||||
# 指数退避:30s -> 60s -> 120s -> 240s -> 300s(上限)
|
exponent = min(consecutive_failures // 5, 4)
|
||||||
exponent = min(consecutive_failures // 5, 4) # 每5次失败增加一级
|
|
||||||
seconds = cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**exponent)
|
seconds = cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**exponent)
|
||||||
return min(int(seconds), cls.MAX_RECOVERY_SECONDS)
|
return min(int(seconds), cls.MAX_RECOVERY_SECONDS)
|
||||||
|
|
||||||
# ==================== 状态查询方法 ====================
|
# ==================== 状态查询方法 ====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_circuit_breaker_closed(cls, resource: ProviderAPIKey) -> bool:
|
def is_circuit_breaker_closed(
|
||||||
"""检查熔断器是否允许请求通过"""
|
cls, resource: ProviderAPIKey, api_format: Optional[str] = None
|
||||||
if not resource.circuit_breaker_open:
|
) -> bool:
|
||||||
|
"""检查熔断器是否允许请求通过(按 API 格式)"""
|
||||||
|
if not api_format:
|
||||||
|
# 兼容旧调用:检查是否有任何格式的熔断器开启
|
||||||
|
circuit_by_format = resource.circuit_breaker_by_format or {}
|
||||||
|
for fmt, circuit_data in circuit_by_format.items():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
circuit_data = cls._get_circuit_data(resource, api_format)
|
||||||
|
|
||||||
|
if not circuit_data.get("open"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
state = cls._get_circuit_state(resource, now)
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
|
||||||
# 半开状态允许请求通过
|
# 半开状态允许请求通过
|
||||||
if state == CircuitState.HALF_OPEN:
|
if state == CircuitState.HALF_OPEN:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查是否到了探测时间
|
# 检查是否到了探测时间
|
||||||
if resource.next_probe_at and now >= resource.next_probe_at:
|
next_probe_str = circuit_data.get("next_probe_at")
|
||||||
# 自动进入半开状态
|
if next_probe_str:
|
||||||
cls._enter_half_open(resource, now)
|
next_probe_at = datetime.fromisoformat(next_probe_str)
|
||||||
return True
|
if now >= next_probe_at:
|
||||||
|
# 自动进入半开状态
|
||||||
|
cls._enter_half_open_data(circuit_data, now)
|
||||||
|
cls._set_circuit_data(resource, api_format, circuit_data)
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_circuit_breaker_status(
|
def get_circuit_breaker_status(
|
||||||
cls, resource: ProviderAPIKey
|
cls, resource: ProviderAPIKey, api_format: Optional[str] = None
|
||||||
) -> Tuple[bool, Optional[str]]:
|
) -> Tuple[bool, Optional[str]]:
|
||||||
"""获取熔断器详细状态"""
|
"""获取熔断器详细状态(按 API 格式)"""
|
||||||
if not resource.circuit_breaker_open:
|
if not api_format:
|
||||||
|
# 兼容旧调用:返回第一个开启的熔断器状态
|
||||||
|
circuit_by_format = resource.circuit_breaker_by_format or {}
|
||||||
|
for fmt, circuit_data in circuit_by_format.items():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
return cls._get_status_from_circuit_data(circuit_data)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
circuit_data = cls._get_circuit_data(resource, api_format)
|
||||||
|
return cls._get_status_from_circuit_data(circuit_data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_status_from_circuit_data(
|
||||||
|
cls, circuit_data: Dict[str, Any]
|
||||||
|
) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""从熔断器数据获取状态描述"""
|
||||||
|
if not circuit_data.get("open"):
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
state = cls._get_circuit_state(resource, now)
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
|
||||||
if state == CircuitState.HALF_OPEN:
|
if state == CircuitState.HALF_OPEN:
|
||||||
successes = int(resource.half_open_successes or 0)
|
successes = int(circuit_data.get("half_open_successes") or 0)
|
||||||
return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)"
|
return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)"
|
||||||
|
|
||||||
if resource.next_probe_at:
|
next_probe_str = circuit_data.get("next_probe_at")
|
||||||
if now >= resource.next_probe_at:
|
if next_probe_str:
|
||||||
|
next_probe_at = datetime.fromisoformat(next_probe_str)
|
||||||
|
if now >= next_probe_at:
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
remaining = resource.next_probe_at - now
|
remaining = next_probe_at - now
|
||||||
remaining_seconds = int(remaining.total_seconds())
|
remaining_seconds = int(remaining.total_seconds())
|
||||||
if remaining_seconds >= 60:
|
if remaining_seconds >= 60:
|
||||||
time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s"
|
time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s"
|
||||||
@@ -417,8 +610,10 @@ class HealthMonitor:
|
|||||||
return False, "熔断中"
|
return False, "熔断中"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_key_health(cls, db: Session, key_id: str) -> Optional[Dict[str, Any]]:
|
def get_key_health(
|
||||||
"""获取 Key 健康状态"""
|
cls, db: Session, key_id: str, api_format: Optional[str] = None
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""获取 Key 健康状态(支持按格式查询)"""
|
||||||
try:
|
try:
|
||||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
||||||
if not key:
|
if not key:
|
||||||
@@ -427,24 +622,15 @@ class HealthMonitor:
|
|||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
now_ts = now.timestamp()
|
now_ts = now.timestamp()
|
||||||
|
|
||||||
# 计算当前错误率
|
|
||||||
error_rate = cls._calculate_error_rate(key, now_ts)
|
|
||||||
window = key.request_results_window or []
|
|
||||||
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
|
|
||||||
|
|
||||||
avg_response_time_ms = (
|
avg_response_time_ms = (
|
||||||
int(key.total_response_time_ms or 0) / int(key.success_count or 1)
|
int(key.total_response_time_ms or 0) / int(key.success_count or 1)
|
||||||
if key.success_count
|
if key.success_count
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
# 全局统计
|
||||||
|
result = {
|
||||||
"key_id": key.id,
|
"key_id": key.id,
|
||||||
"health_score": float(key.health_score or 1.0),
|
|
||||||
"error_rate": error_rate,
|
|
||||||
"window_size": len(valid_window),
|
|
||||||
"consecutive_failures": int(key.consecutive_failures or 0),
|
|
||||||
"last_failure_at": key.last_failure_at.isoformat() if key.last_failure_at else None,
|
|
||||||
"is_active": key.is_active,
|
"is_active": key.is_active,
|
||||||
"statistics": {
|
"statistics": {
|
||||||
"request_count": int(key.request_count or 0),
|
"request_count": int(key.request_count or 0),
|
||||||
@@ -457,25 +643,84 @@ class HealthMonitor:
|
|||||||
),
|
),
|
||||||
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
"avg_response_time_ms": round(avg_response_time_ms, 2),
|
||||||
},
|
},
|
||||||
"circuit_breaker": {
|
|
||||||
"state": cls._get_circuit_state(key, now),
|
|
||||||
"open": key.circuit_breaker_open,
|
|
||||||
"open_at": (
|
|
||||||
key.circuit_breaker_open_at.isoformat()
|
|
||||||
if key.circuit_breaker_open_at
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
"next_probe_at": (
|
|
||||||
key.next_probe_at.isoformat() if key.next_probe_at else None
|
|
||||||
),
|
|
||||||
"half_open_until": (
|
|
||||||
key.half_open_until.isoformat() if key.half_open_until else None
|
|
||||||
),
|
|
||||||
"half_open_successes": int(key.half_open_successes or 0),
|
|
||||||
"half_open_failures": int(key.half_open_failures or 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 按格式的健康度数据
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
|
||||||
|
if api_format:
|
||||||
|
# 查询单个格式
|
||||||
|
health_data = cls._get_health_data(key, api_format)
|
||||||
|
circuit_data = cls._get_circuit_data(key, api_format)
|
||||||
|
window = health_data.get("request_results_window") or []
|
||||||
|
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
|
||||||
|
|
||||||
|
result["api_format"] = api_format
|
||||||
|
result["health_score"] = float(health_data.get("health_score") or 1.0)
|
||||||
|
result["error_rate"] = cls._calculate_error_rate_from_window(window, now_ts)
|
||||||
|
result["window_size"] = len(valid_window)
|
||||||
|
result["consecutive_failures"] = int(
|
||||||
|
health_data.get("consecutive_failures") or 0
|
||||||
|
)
|
||||||
|
result["last_failure_at"] = health_data.get("last_failure_at")
|
||||||
|
result["circuit_breaker"] = {
|
||||||
|
"state": cls._get_circuit_state_from_data(circuit_data, now),
|
||||||
|
"open": circuit_data.get("open", False),
|
||||||
|
"open_at": circuit_data.get("open_at"),
|
||||||
|
"next_probe_at": circuit_data.get("next_probe_at"),
|
||||||
|
"half_open_until": circuit_data.get("half_open_until"),
|
||||||
|
"half_open_successes": int(circuit_data.get("half_open_successes") or 0),
|
||||||
|
"half_open_failures": int(circuit_data.get("half_open_failures") or 0),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 返回所有格式的健康度数据
|
||||||
|
formats_health = {}
|
||||||
|
for fmt in (key.api_formats or []):
|
||||||
|
health_data = health_by_format.get(fmt, _default_health_data())
|
||||||
|
circuit_data = circuit_by_format.get(fmt, _default_circuit_data())
|
||||||
|
window = health_data.get("request_results_window") or []
|
||||||
|
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
|
||||||
|
|
||||||
|
formats_health[fmt] = {
|
||||||
|
"health_score": float(health_data.get("health_score") or 1.0),
|
||||||
|
"error_rate": cls._calculate_error_rate_from_window(window, now_ts),
|
||||||
|
"window_size": len(valid_window),
|
||||||
|
"consecutive_failures": int(
|
||||||
|
health_data.get("consecutive_failures") or 0
|
||||||
|
),
|
||||||
|
"last_failure_at": health_data.get("last_failure_at"),
|
||||||
|
"circuit_breaker": {
|
||||||
|
"state": cls._get_circuit_state_from_data(circuit_data, now),
|
||||||
|
"open": circuit_data.get("open", False),
|
||||||
|
"open_at": circuit_data.get("open_at"),
|
||||||
|
"next_probe_at": circuit_data.get("next_probe_at"),
|
||||||
|
"half_open_until": circuit_data.get("half_open_until"),
|
||||||
|
"half_open_successes": int(
|
||||||
|
circuit_data.get("half_open_successes") or 0
|
||||||
|
),
|
||||||
|
"half_open_failures": int(
|
||||||
|
circuit_data.get("half_open_failures") or 0
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result["health_by_format"] = formats_health
|
||||||
|
|
||||||
|
# 计算整体健康度(取最低值)
|
||||||
|
if formats_health:
|
||||||
|
result["health_score"] = min(
|
||||||
|
h["health_score"] for h in formats_health.values()
|
||||||
|
)
|
||||||
|
result["any_circuit_open"] = any(
|
||||||
|
h["circuit_breaker"]["open"] for h in formats_health.values()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result["health_score"] = 1.0
|
||||||
|
result["any_circuit_open"] = False
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取 Key 健康状态失败: {e}")
|
logger.error(f"获取 Key 健康状态失败: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -507,23 +752,24 @@ class HealthMonitor:
|
|||||||
# ==================== 管理方法 ====================
|
# ==================== 管理方法 ====================
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reset_health(cls, db: Session, key_id: Optional[str] = None) -> bool:
|
def reset_health(
|
||||||
"""重置健康度"""
|
cls, db: Session, key_id: Optional[str] = None, api_format: Optional[str] = None
|
||||||
|
) -> bool:
|
||||||
|
"""重置健康度(支持按格式重置)"""
|
||||||
try:
|
try:
|
||||||
if key_id:
|
if key_id:
|
||||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
||||||
if key:
|
if key:
|
||||||
key.health_score = 1.0 # type: ignore[assignment]
|
if api_format:
|
||||||
key.consecutive_failures = 0 # type: ignore[assignment]
|
# 重置单个格式
|
||||||
key.last_failure_at = None # type: ignore[assignment]
|
cls._set_health_data(key, api_format, _default_health_data())
|
||||||
key.request_results_window = [] # type: ignore[assignment]
|
cls._set_circuit_data(key, api_format, _default_circuit_data())
|
||||||
key.circuit_breaker_open = False # type: ignore[assignment]
|
logger.info(f"[RESET] 重置 Key 健康度: {key_id}/{api_format}")
|
||||||
key.circuit_breaker_open_at = None # type: ignore[assignment]
|
else:
|
||||||
key.next_probe_at = None # type: ignore[assignment]
|
# 重置所有格式
|
||||||
key.half_open_until = None # type: ignore[assignment]
|
key.health_by_format = {} # type: ignore[assignment]
|
||||||
key.half_open_successes = 0 # type: ignore[assignment]
|
key.circuit_breaker_by_format = {} # type: ignore[assignment]
|
||||||
key.half_open_failures = 0 # type: ignore[assignment]
|
logger.info(f"[RESET] 重置 Key 所有格式健康度: {key_id}")
|
||||||
logger.info(f"[RESET] 重置 Key 健康度: {key_id}")
|
|
||||||
|
|
||||||
db.flush()
|
db.flush()
|
||||||
get_batch_committer().mark_dirty(db)
|
get_batch_committer().mark_dirty(db)
|
||||||
@@ -542,7 +788,9 @@ class HealthMonitor:
|
|||||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
||||||
if key and not key.is_active:
|
if key and not key.is_active:
|
||||||
key.is_active = True # type: ignore[assignment]
|
key.is_active = True # type: ignore[assignment]
|
||||||
key.consecutive_failures = 0 # type: ignore[assignment]
|
# 重置所有格式的健康度
|
||||||
|
key.health_by_format = {} # type: ignore[assignment]
|
||||||
|
key.circuit_breaker_by_format = {} # type: ignore[assignment]
|
||||||
logger.info(f"[OK] 手动启用 Key: {key_id}")
|
logger.info(f"[OK] 手动启用 Key: {key_id}")
|
||||||
|
|
||||||
db.flush()
|
db.flush()
|
||||||
@@ -566,14 +814,28 @@ class HealthMonitor:
|
|||||||
),
|
),
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
key_stats = db.query(
|
# 统计 Key(需要遍历 JSON 字段计算熔断状态)
|
||||||
func.count(ProviderAPIKey.id).label("total"),
|
keys = db.query(ProviderAPIKey).all()
|
||||||
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
|
total_keys = len(keys)
|
||||||
func.sum(case((ProviderAPIKey.health_score < 0.5, 1), else_=0)).label("unhealthy"),
|
active_keys = sum(1 for k in keys if k.is_active)
|
||||||
func.sum(case((ProviderAPIKey.circuit_breaker_open == True, 1), else_=0)).label(
|
unhealthy_keys = 0
|
||||||
"circuit_open"
|
circuit_open_keys = 0
|
||||||
),
|
|
||||||
).first()
|
for key in keys:
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
|
||||||
|
# 检查是否有任何格式健康度低于 0.5
|
||||||
|
for fmt, health_data in health_by_format.items():
|
||||||
|
if float(health_data.get("health_score") or 1.0) < 0.5:
|
||||||
|
unhealthy_keys += 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# 检查是否有任何格式熔断器开启
|
||||||
|
for fmt, circuit_data in circuit_by_format.items():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
circuit_open_keys += 1
|
||||||
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"endpoints": {
|
"endpoints": {
|
||||||
@@ -582,10 +844,10 @@ class HealthMonitor:
|
|||||||
"unhealthy": int(endpoint_stats.unhealthy or 0) if endpoint_stats else 0,
|
"unhealthy": int(endpoint_stats.unhealthy or 0) if endpoint_stats else 0,
|
||||||
},
|
},
|
||||||
"keys": {
|
"keys": {
|
||||||
"total": key_stats.total or 0 if key_stats else 0,
|
"total": total_keys,
|
||||||
"active": int(key_stats.active or 0) if key_stats else 0,
|
"active": active_keys,
|
||||||
"unhealthy": int(key_stats.unhealthy or 0) if key_stats else 0,
|
"unhealthy": unhealthy_keys,
|
||||||
"circuit_open": int(key_stats.circuit_open or 0) if key_stats else 0,
|
"circuit_open": circuit_open_keys,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,8 +880,9 @@ class HealthMonitor:
|
|||||||
db: Session,
|
db: Session,
|
||||||
endpoint_id: Optional[str] = None,
|
endpoint_id: Optional[str] = None,
|
||||||
key_id: Optional[str] = None,
|
key_id: Optional[str] = None,
|
||||||
|
api_format: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""检查是否有资格进行探测(兼容旧接口)"""
|
"""检查是否有资格进行探测(按 API 格式)"""
|
||||||
if not cls.ALLOW_AUTO_RECOVER:
|
if not cls.ALLOW_AUTO_RECOVER:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -628,13 +891,53 @@ class HealthMonitor:
|
|||||||
|
|
||||||
if key_id:
|
if key_id:
|
||||||
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
|
||||||
if key and key.circuit_breaker_open:
|
if key:
|
||||||
now = datetime.now(timezone.utc)
|
if api_format:
|
||||||
state = cls._get_circuit_state(key, now)
|
circuit_data = cls._get_circuit_data(key, api_format)
|
||||||
return state == CircuitState.HALF_OPEN
|
if circuit_data.get("open"):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
return state == CircuitState.HALF_OPEN
|
||||||
|
else:
|
||||||
|
# 兼容旧调用:检查是否有任何格式处于半开状态
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
for fmt, circuit_data in circuit_by_format.items():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
state = cls._get_circuit_state_from_data(circuit_data, now)
|
||||||
|
if state == CircuitState.HALF_OPEN:
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# ==================== 便捷方法 ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_health_score(
|
||||||
|
cls, key: ProviderAPIKey, api_format: Optional[str] = None
|
||||||
|
) -> float:
|
||||||
|
"""获取指定格式的健康度分数"""
|
||||||
|
if not api_format:
|
||||||
|
# 返回所有格式中的最低健康度
|
||||||
|
health_by_format = key.health_by_format or {}
|
||||||
|
if not health_by_format:
|
||||||
|
return 1.0
|
||||||
|
return min(
|
||||||
|
float(h.get("health_score") or 1.0) for h in health_by_format.values()
|
||||||
|
)
|
||||||
|
|
||||||
|
health_data = cls._get_health_data(key, api_format)
|
||||||
|
return float(health_data.get("health_score") or 1.0)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_any_circuit_open(cls, key: ProviderAPIKey) -> bool:
|
||||||
|
"""检查是否有任何格式的熔断器开启"""
|
||||||
|
circuit_by_format = key.circuit_breaker_by_format or {}
|
||||||
|
for circuit_data in circuit_by_format.values():
|
||||||
|
if circuit_data.get("open"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# 全局健康监控器实例
|
# 全局健康监控器实例
|
||||||
health_monitor = HealthMonitor()
|
health_monitor = HealthMonitor()
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class ModelService:
|
|||||||
def delete_model(db: Session, model_id: str): # UUID
|
def delete_model(db: Session, model_id: str): # UUID
|
||||||
"""删除模型
|
"""删除模型
|
||||||
|
|
||||||
新架构删除逻辑:
|
删除逻辑:
|
||||||
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
|
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
|
||||||
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
|
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
|
||||||
"""
|
"""
|
||||||
@@ -384,7 +384,7 @@ class ModelService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_to_response(model: Model) -> ModelResponse:
|
def convert_to_response(model: Model) -> ModelResponse:
|
||||||
"""转换为响应模型(新架构:从 GlobalModel 获取显示信息和默认值)"""
|
"""转换为响应模型(从 GlobalModel 获取显示信息和默认值)"""
|
||||||
return ModelResponse(
|
return ModelResponse(
|
||||||
id=model.id,
|
id=model.id,
|
||||||
provider_id=model.provider_id,
|
provider_id=model.provider_id,
|
||||||
|
|||||||
@@ -171,7 +171,8 @@ class CandidateResolver:
|
|||||||
)
|
)
|
||||||
candidate_record_map[(candidate_index, 0)] = record_id
|
candidate_record_map[(candidate_index, 0)] = record_id
|
||||||
else:
|
else:
|
||||||
max_retries_for_candidate = endpoint.max_retries if candidate.is_cached else 1
|
# max_retries 已从 Endpoint 迁移到 Provider(Endpoint 仍可能保留旧字段用于兼容)
|
||||||
|
max_retries_for_candidate = int(provider.max_retries or 2) if candidate.is_cached else 1
|
||||||
|
|
||||||
for retry_index in range(max_retries_for_candidate):
|
for retry_index in range(max_retries_for_candidate):
|
||||||
record_id = str(uuid.uuid4())
|
record_id = str(uuid.uuid4())
|
||||||
@@ -236,7 +237,7 @@ class CandidateResolver:
|
|||||||
total = 0
|
total = 0
|
||||||
for candidate in all_candidates:
|
for candidate in all_candidates:
|
||||||
if not candidate.is_skipped:
|
if not candidate.is_skipped:
|
||||||
endpoint = candidate.endpoint
|
provider = candidate.provider
|
||||||
max_retries = int(endpoint.max_retries) if candidate.is_cached else 1
|
max_retries = int(provider.max_retries or 2) if candidate.is_cached else 1
|
||||||
total += max_retries
|
total += max_retries
|
||||||
return total
|
return total
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user