mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8033fb6ab | ||
|
|
e33d5b952c | ||
|
|
4345ac2ba2 | ||
|
|
a12b43ce5c | ||
|
|
6885cf1f6d | ||
|
|
00f6fafcfc | ||
|
|
42dc64246c | ||
|
|
fbe303a3cd | ||
|
|
373845450b | ||
|
|
084bbc0bef | ||
|
|
0061fc04b7 | ||
|
|
f6a6410626 | ||
|
|
835be3d329 | ||
|
|
2395093394 | ||
|
|
28209e1c2a | ||
|
|
00562dd1d4 | ||
|
|
0f78d5cbf3 | ||
|
|
431c6de8d2 | ||
|
|
142e15bbcc | ||
|
|
31acc5c607 | ||
|
|
bfa0a26d41 | ||
|
|
93ab9b6a5e | ||
|
|
35e29d46bd | ||
|
|
3e4309eba3 | ||
|
|
414f45aa71 | ||
|
|
ebdc76346f | ||
|
|
64bfa955f4 | ||
|
|
612992fa1f | ||
|
|
9bfb295238 |
@@ -39,7 +39,18 @@ COPY alembic.ini ./
|
||||
COPY alembic/ ./alembic/
|
||||
|
||||
# Nginx 配置模板
|
||||
# 智能处理 IP:有外层代理头就透传,没有就用直连 IP
|
||||
RUN printf '%s\n' \
|
||||
'map $http_x_real_ip $real_ip {' \
|
||||
' default $http_x_real_ip;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'map $http_x_forwarded_for $forwarded_for {' \
|
||||
' default $http_x_forwarded_for;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'server {' \
|
||||
' listen 80;' \
|
||||
' server_name _;' \
|
||||
@@ -47,6 +58,15 @@ RUN printf '%s\n' \
|
||||
' index index.html;' \
|
||||
' client_max_body_size 100M;' \
|
||||
'' \
|
||||
' # gzip 压缩配置(对 base64 图片等非流式响应有效)' \
|
||||
' gzip on;' \
|
||||
' gzip_min_length 256;' \
|
||||
' gzip_comp_level 5;' \
|
||||
' gzip_vary on;' \
|
||||
' gzip_proxied any;' \
|
||||
' gzip_types application/json text/plain text/css text/javascript application/javascript application/octet-stream;' \
|
||||
' gzip_disable "msie6";' \
|
||||
'' \
|
||||
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
|
||||
' expires 1y;' \
|
||||
' add_header Cache-Control "public, no-transform";' \
|
||||
@@ -62,6 +82,15 @@ RUN printf '%s\n' \
|
||||
' try_files $uri $uri/ /index.html;' \
|
||||
' }' \
|
||||
'' \
|
||||
' location ~ ^/(docs|redoc|openapi\\.json)$ {' \
|
||||
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
|
||||
' proxy_http_version 1.1;' \
|
||||
' proxy_set_header Host $host;' \
|
||||
' proxy_set_header X-Real-IP $real_ip;' \
|
||||
' proxy_set_header X-Forwarded-For $forwarded_for;' \
|
||||
' proxy_set_header X-Forwarded-Proto $scheme;' \
|
||||
' }' \
|
||||
'' \
|
||||
' location / {' \
|
||||
' try_files $uri $uri/ @backend;' \
|
||||
' }' \
|
||||
@@ -70,8 +99,8 @@ RUN printf '%s\n' \
|
||||
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
|
||||
' proxy_http_version 1.1;' \
|
||||
' proxy_set_header Host $host;' \
|
||||
' proxy_set_header X-Real-IP $remote_addr;' \
|
||||
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
|
||||
' proxy_set_header X-Real-IP $real_ip;' \
|
||||
' proxy_set_header X-Forwarded-For $forwarded_for;' \
|
||||
' proxy_set_header X-Forwarded-Proto $scheme;' \
|
||||
' proxy_set_header Connection "";' \
|
||||
' proxy_set_header Accept $http_accept;' \
|
||||
@@ -124,7 +153,8 @@ ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONIOENCODING=utf-8 \
|
||||
LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8 \
|
||||
PORT=8084
|
||||
PORT=8084 \
|
||||
GUNICORN_WORKERS=4
|
||||
|
||||
EXPOSE 80
|
||||
|
||||
|
||||
@@ -40,7 +40,18 @@ COPY alembic.ini ./
|
||||
COPY alembic/ ./alembic/
|
||||
|
||||
# Nginx 配置模板
|
||||
# 智能处理 IP:有外层代理头就透传,没有就用直连 IP
|
||||
RUN printf '%s\n' \
|
||||
'map $http_x_real_ip $real_ip {' \
|
||||
' default $http_x_real_ip;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'map $http_x_forwarded_for $forwarded_for {' \
|
||||
' default $http_x_forwarded_for;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'server {' \
|
||||
' listen 80;' \
|
||||
' server_name _;' \
|
||||
@@ -48,6 +59,15 @@ RUN printf '%s\n' \
|
||||
' index index.html;' \
|
||||
' client_max_body_size 100M;' \
|
||||
'' \
|
||||
' # gzip 压缩配置(对 base64 图片等非流式响应有效)' \
|
||||
' gzip on;' \
|
||||
' gzip_min_length 256;' \
|
||||
' gzip_comp_level 5;' \
|
||||
' gzip_vary on;' \
|
||||
' gzip_proxied any;' \
|
||||
' gzip_types application/json text/plain text/css text/javascript application/javascript application/octet-stream;' \
|
||||
' gzip_disable "msie6";' \
|
||||
'' \
|
||||
' location ~* \.(js|css|png|jpg|jpeg|gif|ico|svg|woff|woff2|ttf|eot)$ {' \
|
||||
' expires 1y;' \
|
||||
' add_header Cache-Control "public, no-transform";' \
|
||||
@@ -71,8 +91,8 @@ RUN printf '%s\n' \
|
||||
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
|
||||
' proxy_http_version 1.1;' \
|
||||
' proxy_set_header Host $host;' \
|
||||
' proxy_set_header X-Real-IP $remote_addr;' \
|
||||
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
|
||||
' proxy_set_header X-Real-IP $real_ip;' \
|
||||
' proxy_set_header X-Forwarded-For $forwarded_for;' \
|
||||
' proxy_set_header X-Forwarded-Proto $scheme;' \
|
||||
' proxy_set_header Connection "";' \
|
||||
' proxy_set_header Accept $http_accept;' \
|
||||
|
||||
@@ -51,7 +51,7 @@ Aether 是一个自托管的 AI API 网关,为团队和个人提供多租户
|
||||
```bash
|
||||
# 1. 克隆代码
|
||||
git clone https://github.com/fawney19/Aether.git
|
||||
cd aether
|
||||
cd Aether
|
||||
|
||||
# 2. 配置环境变量
|
||||
cp .env.example .env
|
||||
@@ -72,7 +72,7 @@ docker compose pull && docker compose up -d && ./migrate.sh
|
||||
```bash
|
||||
# 1. 克隆代码
|
||||
git clone https://github.com/fawney19/Aether.git
|
||||
cd aether
|
||||
cd Aether
|
||||
|
||||
# 2. 配置环境变量
|
||||
cp .env.example .env
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"""add ldap authentication support
|
||||
|
||||
Revision ID: c3d4e5f6g7h8
|
||||
Revises: b2c3d4e5f6g7
|
||||
Create Date: 2026-01-01 14:00:00.000000+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c3d4e5f6g7h8'
|
||||
down_revision = 'b2c3d4e5f6g7'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _type_exists(conn, type_name: str) -> bool:
|
||||
"""检查 PostgreSQL 类型是否存在"""
|
||||
result = conn.execute(
|
||||
text("SELECT 1 FROM pg_type WHERE typname = :name"),
|
||||
{"name": type_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _column_exists(conn, table_name: str, column_name: str) -> bool:
|
||||
"""检查列是否存在"""
|
||||
result = conn.execute(
|
||||
text("""
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = :table AND column_name = :column
|
||||
"""),
|
||||
{"table": table_name, "column": column_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _index_exists(conn, index_name: str) -> bool:
|
||||
"""检查索引是否存在"""
|
||||
result = conn.execute(
|
||||
text("SELECT 1 FROM pg_indexes WHERE indexname = :name"),
|
||||
{"name": index_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _table_exists(conn, table_name: str) -> bool:
|
||||
"""检查表是否存在"""
|
||||
result = conn.execute(
|
||||
text("""
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_name = :name AND table_schema = 'public'
|
||||
"""),
|
||||
{"name": table_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""添加 LDAP 认证支持
|
||||
|
||||
1. 创建 authsource 枚举类型
|
||||
2. 在 users 表添加 auth_source 字段和 LDAP 标识字段
|
||||
3. 创建 ldap_configs 表
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 1. 创建 authsource 枚举类型(幂等)
|
||||
if not _type_exists(conn, 'authsource'):
|
||||
conn.execute(text("CREATE TYPE authsource AS ENUM ('local', 'ldap')"))
|
||||
|
||||
# 2. 在 users 表添加字段(幂等)
|
||||
if not _column_exists(conn, 'users', 'auth_source'):
|
||||
op.add_column('users', sa.Column(
|
||||
'auth_source',
|
||||
sa.Enum('local', 'ldap', name='authsource', create_type=False),
|
||||
nullable=False,
|
||||
server_default='local'
|
||||
))
|
||||
|
||||
if not _column_exists(conn, 'users', 'ldap_dn'):
|
||||
op.add_column('users', sa.Column('ldap_dn', sa.String(length=512), nullable=True))
|
||||
|
||||
if not _column_exists(conn, 'users', 'ldap_username'):
|
||||
op.add_column('users', sa.Column('ldap_username', sa.String(length=255), nullable=True))
|
||||
|
||||
# 创建索引(幂等)
|
||||
if not _index_exists(conn, 'ix_users_ldap_dn'):
|
||||
op.create_index('ix_users_ldap_dn', 'users', ['ldap_dn'])
|
||||
|
||||
if not _index_exists(conn, 'ix_users_ldap_username'):
|
||||
op.create_index('ix_users_ldap_username', 'users', ['ldap_username'])
|
||||
|
||||
# 3. 创建 ldap_configs 表(幂等)
|
||||
if not _table_exists(conn, 'ldap_configs'):
|
||||
op.create_table(
|
||||
'ldap_configs',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('server_url', sa.String(length=255), nullable=False),
|
||||
sa.Column('bind_dn', sa.String(length=255), nullable=False),
|
||||
sa.Column('bind_password_encrypted', sa.Text(), nullable=True),
|
||||
sa.Column('base_dn', sa.String(length=255), nullable=False),
|
||||
sa.Column('user_search_filter', sa.String(length=500), nullable=False, server_default='(uid={username})'),
|
||||
sa.Column('username_attr', sa.String(length=50), nullable=False, server_default='uid'),
|
||||
sa.Column('email_attr', sa.String(length=50), nullable=False, server_default='mail'),
|
||||
sa.Column('display_name_attr', sa.String(length=50), nullable=False, server_default='cn'),
|
||||
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('connect_timeout', sa.Integer(), nullable=False, server_default='10'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚 LDAP 认证支持
|
||||
|
||||
警告:回滚前请确保:
|
||||
1. 已备份数据库
|
||||
2. 没有 LDAP 用户需要保留
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 检查是否存在 LDAP 用户,防止数据丢失
|
||||
if _column_exists(conn, 'users', 'auth_source'):
|
||||
result = conn.execute(text("SELECT COUNT(*) FROM users WHERE auth_source = 'ldap'"))
|
||||
ldap_user_count = result.scalar()
|
||||
if ldap_user_count and ldap_user_count > 0:
|
||||
raise RuntimeError(
|
||||
f"无法回滚:存在 {ldap_user_count} 个 LDAP 用户。"
|
||||
f"请先删除或转换这些用户,或使用 --force 参数强制回滚(将丢失数据)。"
|
||||
)
|
||||
|
||||
# 1. 删除 ldap_configs 表(幂等)
|
||||
if _table_exists(conn, 'ldap_configs'):
|
||||
op.drop_table('ldap_configs')
|
||||
|
||||
# 2. 删除 users 表的 LDAP 相关字段(幂等)
|
||||
if _index_exists(conn, 'ix_users_ldap_username'):
|
||||
op.drop_index('ix_users_ldap_username', table_name='users')
|
||||
|
||||
if _index_exists(conn, 'ix_users_ldap_dn'):
|
||||
op.drop_index('ix_users_ldap_dn', table_name='users')
|
||||
|
||||
if _column_exists(conn, 'users', 'ldap_username'):
|
||||
op.drop_column('users', 'ldap_username')
|
||||
|
||||
if _column_exists(conn, 'users', 'ldap_dn'):
|
||||
op.drop_column('users', 'ldap_dn')
|
||||
|
||||
if _column_exists(conn, 'users', 'auth_source'):
|
||||
op.drop_column('users', 'auth_source')
|
||||
|
||||
# 3. 删除 authsource 枚举类型(幂等)
|
||||
# 注意:不使用 CASCADE,因为此时所有依赖应该已被删除
|
||||
if _type_exists(conn, 'authsource'):
|
||||
conn.execute(text("DROP TYPE authsource"))
|
||||
@@ -0,0 +1,131 @@
|
||||
"""add_management_tokens_table
|
||||
|
||||
Revision ID: ad55f1d008b7
|
||||
Revises: c3d4e5f6g7h8
|
||||
Create Date: 2026-01-06 15:24:10.660394+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'ad55f1d008b7'
|
||||
down_revision = 'c3d4e5f6g7h8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def table_exists(table_name: str) -> bool:
|
||||
"""检查表是否存在"""
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
return table_name in inspector.get_table_names()
|
||||
|
||||
|
||||
def index_exists(table_name: str, index_name: str) -> bool:
|
||||
"""检查索引是否存在"""
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
try:
|
||||
indexes = inspector.get_indexes(table_name)
|
||||
return any(idx["name"] == index_name for idx in indexes)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def constraint_exists(table_name: str, constraint_name: str) -> bool:
|
||||
"""检查约束是否存在"""
|
||||
conn = op.get_bind()
|
||||
inspector = inspect(conn)
|
||||
try:
|
||||
constraints = inspector.get_unique_constraints(table_name)
|
||||
if any(c["name"] == constraint_name for c in constraints):
|
||||
return True
|
||||
# 也检查 check 约束
|
||||
check_constraints = inspector.get_check_constraints(table_name)
|
||||
if any(c["name"] == constraint_name for c in check_constraints):
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""应用迁移:创建 management_tokens 表"""
|
||||
# 幂等性检查
|
||||
if table_exists("management_tokens"):
|
||||
# 表已存在,检查是否需要添加约束
|
||||
if not constraint_exists("management_tokens", "uq_management_tokens_user_name"):
|
||||
op.create_unique_constraint(
|
||||
"uq_management_tokens_user_name",
|
||||
"management_tokens",
|
||||
["user_id", "name"],
|
||||
)
|
||||
# 添加 IP 白名单非空检查约束
|
||||
if not constraint_exists("management_tokens", "check_allowed_ips_not_empty"):
|
||||
op.create_check_constraint(
|
||||
"check_allowed_ips_not_empty",
|
||||
"management_tokens",
|
||||
"allowed_ips IS NULL OR allowed_ips::text = 'null' OR json_array_length(allowed_ips) > 0",
|
||||
)
|
||||
return
|
||||
|
||||
op.create_table('management_tokens',
|
||||
sa.Column('id', sa.String(length=36), nullable=False),
|
||||
sa.Column('user_id', sa.String(length=36), nullable=False),
|
||||
sa.Column('token_hash', sa.String(length=64), nullable=False),
|
||||
sa.Column('token_prefix', sa.String(length=12), nullable=True),
|
||||
sa.Column('name', sa.String(length=100), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('allowed_ips', sa.JSON(), nullable=True),
|
||||
sa.Column('expires_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('last_used_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('last_used_ip', sa.String(length=45), nullable=True),
|
||||
sa.Column('usage_count', sa.Integer(), server_default='0', nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), server_default='true', nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index('idx_management_tokens_is_active', 'management_tokens', ['is_active'], unique=False)
|
||||
op.create_index('idx_management_tokens_user_id', 'management_tokens', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_management_tokens_token_hash'), 'management_tokens', ['token_hash'], unique=True)
|
||||
# 添加用户名称唯一约束
|
||||
op.create_unique_constraint(
|
||||
"uq_management_tokens_user_name",
|
||||
"management_tokens",
|
||||
["user_id", "name"],
|
||||
)
|
||||
# 添加 IP 白名单非空检查约束
|
||||
# 注意:JSON 类型的 NULL 可能被序列化为 JSON 'null',需要同时处理
|
||||
op.create_check_constraint(
|
||||
"check_allowed_ips_not_empty",
|
||||
"management_tokens",
|
||||
"allowed_ips IS NULL OR allowed_ips::text = 'null' OR json_array_length(allowed_ips) > 0",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚迁移:删除 management_tokens 表"""
|
||||
# 幂等性检查
|
||||
if not table_exists("management_tokens"):
|
||||
return
|
||||
|
||||
# 删除约束
|
||||
if constraint_exists("management_tokens", "check_allowed_ips_not_empty"):
|
||||
op.drop_constraint("check_allowed_ips_not_empty", "management_tokens", type_="check")
|
||||
if constraint_exists("management_tokens", "uq_management_tokens_user_name"):
|
||||
op.drop_constraint("uq_management_tokens_user_name", "management_tokens", type_="unique")
|
||||
|
||||
# 删除索引
|
||||
if index_exists("management_tokens", "ix_management_tokens_token_hash"):
|
||||
op.drop_index(op.f('ix_management_tokens_token_hash'), table_name='management_tokens')
|
||||
if index_exists("management_tokens", "idx_management_tokens_user_id"):
|
||||
op.drop_index('idx_management_tokens_user_id', table_name='management_tokens')
|
||||
if index_exists("management_tokens", "idx_management_tokens_is_active"):
|
||||
op.drop_index('idx_management_tokens_is_active', table_name='management_tokens')
|
||||
|
||||
# 删除表
|
||||
op.drop_table('management_tokens')
|
||||
@@ -0,0 +1,73 @@
|
||||
"""cleanup ambiguous database fields
|
||||
|
||||
Revision ID: 02a45b66b7c4
|
||||
Revises: ad55f1d008b7
|
||||
Create Date: 2026-01-07 11:20:12.684426+00:00
|
||||
|
||||
变更内容:
|
||||
1. users 表:重命名 allowed_endpoints 为 allowed_api_formats(修正历史命名错误)
|
||||
2. api_keys 表:删除 allowed_endpoints 字段(未使用的功能)
|
||||
3. providers 表:删除 rate_limit 字段(与 rpm_limit 功能重复,且未使用)
|
||||
4. usage 表:重命名 provider 为 provider_name(避免与 provider_id 外键混淆)
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '02a45b66b7c4'
|
||||
down_revision = 'ad55f1d008b7'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _column_exists(table_name: str, column_name: str) -> bool:
|
||||
"""检查列是否存在"""
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
columns = [col['name'] for col in inspector.get_columns(table_name)]
|
||||
return column_name in columns
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
1. users.allowed_endpoints -> allowed_api_formats(重命名)
|
||||
2. api_keys.allowed_endpoints 删除
|
||||
3. providers.rate_limit 删除(与 rpm_limit 重复)
|
||||
4. usage.provider -> provider_name(重命名)
|
||||
"""
|
||||
# 1. users 表:重命名 allowed_endpoints 为 allowed_api_formats
|
||||
if _column_exists('users', 'allowed_endpoints'):
|
||||
op.alter_column('users', 'allowed_endpoints', new_column_name='allowed_api_formats')
|
||||
|
||||
# 2. api_keys 表:删除 allowed_endpoints 字段
|
||||
if _column_exists('api_keys', 'allowed_endpoints'):
|
||||
op.drop_column('api_keys', 'allowed_endpoints')
|
||||
|
||||
# 3. providers 表:删除 rate_limit 字段(与 rpm_limit 功能重复)
|
||||
if _column_exists('providers', 'rate_limit'):
|
||||
op.drop_column('providers', 'rate_limit')
|
||||
|
||||
# 4. usage 表:重命名 provider 为 provider_name
|
||||
if _column_exists('usage', 'provider'):
|
||||
op.alter_column('usage', 'provider', new_column_name='provider_name')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚:恢复原字段"""
|
||||
# 4. usage 表:将 provider_name 改回 provider
|
||||
if _column_exists('usage', 'provider_name'):
|
||||
op.alter_column('usage', 'provider_name', new_column_name='provider')
|
||||
|
||||
# 3. providers 表:恢复 rate_limit 字段
|
||||
if not _column_exists('providers', 'rate_limit'):
|
||||
op.add_column('providers', sa.Column('rate_limit', sa.Integer(), nullable=True))
|
||||
|
||||
# 2. api_keys 表:恢复 allowed_endpoints 字段
|
||||
if not _column_exists('api_keys', 'allowed_endpoints'):
|
||||
op.add_column('api_keys', sa.Column('allowed_endpoints', sa.JSON(), nullable=True))
|
||||
|
||||
# 1. users 表:将 allowed_api_formats 改回 allowed_endpoints
|
||||
if _column_exists('users', 'allowed_api_formats'):
|
||||
op.alter_column('users', 'allowed_api_formats', new_column_name='allowed_endpoints')
|
||||
@@ -17,7 +17,7 @@ services:
|
||||
ports:
|
||||
- "${DB_PORT:-5432}:5432"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
test: [ "CMD-SHELL", "pg_isready -U postgres" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
@@ -32,7 +32,7 @@ services:
|
||||
ports:
|
||||
- "${REDIS_PORT:-6379}:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
test: [ "CMD", "redis-cli", "--raw", "incr", "ping" ]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
@@ -44,20 +44,15 @@ services:
|
||||
dockerfile: Dockerfile.app.local
|
||||
image: aether-app:latest
|
||||
container_name: aether-app
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# 需要组合的变量
|
||||
DATABASE_URL: postgresql://postgres:${DB_PASSWORD}@postgres:5432/aether
|
||||
REDIS_URL: redis://:${REDIS_PASSWORD}@redis:6379/0
|
||||
PORT: 8084
|
||||
JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
||||
ENCRYPTION_KEY: ${ENCRYPTION_KEY}
|
||||
JWT_ALGORITHM: HS256
|
||||
JWT_EXPIRATION_DELTA: 86400
|
||||
LOG_LEVEL: ${LOG_LEVEL:-INFO}
|
||||
ADMIN_EMAIL: ${ADMIN_EMAIL}
|
||||
ADMIN_USERNAME: ${ADMIN_USERNAME}
|
||||
ADMIN_PASSWORD: ${ADMIN_PASSWORD}
|
||||
API_KEY_PREFIX: ${API_KEY_PREFIX:-sk}
|
||||
# Supervisor 需要的变量
|
||||
GUNICORN_WORKERS: ${GUNICORN_WORKERS:-4}
|
||||
# 容器级别设置
|
||||
TZ: Asia/Shanghai
|
||||
PYTHONIOENCODING: utf-8
|
||||
LANG: C.UTF-8
|
||||
|
||||
@@ -13,7 +13,7 @@ services:
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
test: [ "CMD-SHELL", "pg_isready -U postgres" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
@@ -26,7 +26,7 @@ services:
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
test: [ "CMD", "redis-cli", "--raw", "incr", "ping" ]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
@@ -35,20 +35,15 @@ services:
|
||||
app:
|
||||
image: ghcr.io/fawney19/aether:latest
|
||||
container_name: aether-app
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# 需要组合的变量
|
||||
DATABASE_URL: postgresql://postgres:${DB_PASSWORD}@postgres:5432/aether
|
||||
REDIS_URL: redis://:${REDIS_PASSWORD}@redis:6379/0
|
||||
PORT: 8084
|
||||
JWT_SECRET_KEY: ${JWT_SECRET_KEY}
|
||||
ENCRYPTION_KEY: ${ENCRYPTION_KEY}
|
||||
JWT_ALGORITHM: HS256
|
||||
JWT_EXPIRATION_DELTA: 86400
|
||||
LOG_LEVEL: ${LOG_LEVEL:-INFO}
|
||||
ADMIN_EMAIL: ${ADMIN_EMAIL}
|
||||
ADMIN_USERNAME: ${ADMIN_USERNAME}
|
||||
ADMIN_PASSWORD: ${ADMIN_PASSWORD}
|
||||
API_KEY_PREFIX: ${API_KEY_PREFIX:-sk}
|
||||
# Supervisor 需要的变量
|
||||
GUNICORN_WORKERS: ${GUNICORN_WORKERS:-4}
|
||||
# 容器级别设置
|
||||
TZ: Asia/Shanghai
|
||||
PYTHONIOENCODING: utf-8
|
||||
LANG: C.UTF-8
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface UsersExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
users: UserExport[]
|
||||
standalone_keys?: StandaloneKeyExport[]
|
||||
}
|
||||
|
||||
export interface UserExport {
|
||||
@@ -21,7 +22,7 @@ export interface UserExport {
|
||||
password_hash: string
|
||||
role: string
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
model_capability_settings?: any
|
||||
quota_usd?: number | null
|
||||
@@ -39,18 +40,21 @@ export interface UserApiKeyExport {
|
||||
balance_used_usd?: number
|
||||
current_balance_usd?: number | null
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
rate_limit?: number | null // null = 无限制
|
||||
concurrent_limit?: number | null
|
||||
force_capabilities?: any
|
||||
is_active: boolean
|
||||
expires_at?: string | null
|
||||
auto_delete_on_expiry?: boolean
|
||||
total_requests?: number
|
||||
total_cost_usd?: number
|
||||
}
|
||||
|
||||
// 独立余额 Key 导出结构(与 UserApiKeyExport 相同,但不包含 is_standalone)
|
||||
export type StandaloneKeyExport = Omit<UserApiKeyExport, 'is_standalone'>
|
||||
|
||||
export interface GlobalModelExport {
|
||||
name: string
|
||||
display_name: string
|
||||
@@ -155,6 +159,44 @@ export interface EmailTemplateResetResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// LDAP 配置响应
|
||||
export interface LdapConfigResponse {
|
||||
server_url: string | null
|
||||
bind_dn: string | null
|
||||
base_dn: string | null
|
||||
has_bind_password: boolean
|
||||
user_search_filter: string
|
||||
username_attr: string
|
||||
email_attr: string
|
||||
display_name_attr: string
|
||||
is_enabled: boolean
|
||||
is_exclusive: boolean
|
||||
use_starttls: boolean
|
||||
connect_timeout: number
|
||||
}
|
||||
|
||||
// LDAP 配置更新请求
|
||||
export interface LdapConfigUpdateRequest {
|
||||
server_url: string
|
||||
bind_dn: string
|
||||
bind_password?: string
|
||||
base_dn: string
|
||||
user_search_filter?: string
|
||||
username_attr?: string
|
||||
email_attr?: string
|
||||
display_name_attr?: string
|
||||
is_enabled?: boolean
|
||||
is_exclusive?: boolean
|
||||
use_starttls?: boolean
|
||||
connect_timeout?: number
|
||||
}
|
||||
|
||||
// LDAP 连接测试响应
|
||||
export interface LdapTestResponse {
|
||||
success: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
// Provider 模型查询响应
|
||||
export interface ProviderModelsQueryResponse {
|
||||
success: boolean
|
||||
@@ -189,6 +231,7 @@ export interface UsersImportResponse {
|
||||
stats: {
|
||||
users: { created: number; updated: number; skipped: number }
|
||||
api_keys: { created: number; skipped: number }
|
||||
standalone_keys?: { created: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
@@ -473,5 +516,35 @@ export const adminApi = {
|
||||
`/api/admin/system/email/templates/${templateType}/reset`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 获取系统版本信息
|
||||
async getSystemVersion(): Promise<{ version: string }> {
|
||||
const response = await apiClient.get<{ version: string }>(
|
||||
'/api/admin/system/version'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// LDAP 配置相关
|
||||
// 获取 LDAP 配置
|
||||
async getLdapConfig(): Promise<LdapConfigResponse> {
|
||||
const response = await apiClient.get<LdapConfigResponse>('/api/admin/ldap/config')
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 更新 LDAP 配置
|
||||
async updateLdapConfig(config: LdapConfigUpdateRequest): Promise<{ message: string }> {
|
||||
const response = await apiClient.put<{ message: string }>(
|
||||
'/api/admin/ldap/config',
|
||||
config
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 测试 LDAP 连接
|
||||
async testLdapConnection(config: LdapConfigUpdateRequest): Promise<LdapTestResponse> {
|
||||
const response = await apiClient.post<LdapTestResponse>('/api/admin/ldap/test', config)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { log } from '@/utils/logger'
|
||||
export interface LoginRequest {
|
||||
email: string
|
||||
password: string
|
||||
auth_type?: 'local' | 'ldap'
|
||||
}
|
||||
|
||||
export interface LoginResponse {
|
||||
@@ -81,6 +82,12 @@ export interface RegistrationSettingsResponse {
|
||||
require_email_verification: boolean
|
||||
}
|
||||
|
||||
export interface AuthSettingsResponse {
|
||||
local_enabled: boolean
|
||||
ldap_enabled: boolean
|
||||
ldap_exclusive: boolean
|
||||
}
|
||||
|
||||
export interface User {
|
||||
id: string // UUID
|
||||
username: string
|
||||
@@ -91,7 +98,7 @@ export interface User {
|
||||
used_usd?: number
|
||||
total_usd?: number
|
||||
allowed_providers?: string[] | null // 允许使用的提供商 ID 列表
|
||||
allowed_endpoints?: string[] | null // 允许使用的端点 ID 列表
|
||||
allowed_api_formats?: string[] | null // 允许使用的 API 格式列表
|
||||
allowed_models?: string[] | null // 允许使用的模型名称列表
|
||||
created_at: string
|
||||
last_login_at?: string
|
||||
@@ -173,5 +180,10 @@ export const authApi = {
|
||||
{ email }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async getAuthSettings(): Promise<AuthSettingsResponse> {
|
||||
const response = await apiClient.get<AuthSettingsResponse>('/api/auth/settings')
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
203
frontend/src/api/management-tokens.ts
Normal file
203
frontend/src/api/management-tokens.ts
Normal file
@@ -0,0 +1,203 @@
|
||||
/**
|
||||
* Management Token API
|
||||
*/
|
||||
|
||||
import apiClient from './client'
|
||||
|
||||
// ============== 类型定义 ==============
|
||||
|
||||
export interface ManagementToken {
|
||||
id: string
|
||||
user_id: string
|
||||
name: string
|
||||
description?: string
|
||||
token_display: string
|
||||
allowed_ips?: string[] | null
|
||||
expires_at?: string | null
|
||||
last_used_at?: string | null
|
||||
last_used_ip?: string | null
|
||||
usage_count: number
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
user?: {
|
||||
id: string
|
||||
email: string
|
||||
username: string
|
||||
role: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface CreateManagementTokenRequest {
|
||||
name: string
|
||||
description?: string
|
||||
allowed_ips?: string[]
|
||||
expires_at?: string | null
|
||||
}
|
||||
|
||||
export interface CreateManagementTokenResponse {
|
||||
message: string
|
||||
token: string
|
||||
data: ManagementToken
|
||||
}
|
||||
|
||||
export interface UpdateManagementTokenRequest {
|
||||
name?: string
|
||||
description?: string | null
|
||||
allowed_ips?: string[] | null
|
||||
expires_at?: string | null
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface ManagementTokenListResponse {
|
||||
items: ManagementToken[]
|
||||
total: number
|
||||
skip: number
|
||||
limit: number
|
||||
quota?: {
|
||||
used: number
|
||||
max: number
|
||||
}
|
||||
}
|
||||
|
||||
// ============== 用户自助管理 API ==============
|
||||
|
||||
export const managementTokenApi = {
|
||||
/**
|
||||
* 列出当前用户的 Management Tokens
|
||||
*/
|
||||
async listTokens(params?: {
|
||||
is_active?: boolean
|
||||
skip?: number
|
||||
limit?: number
|
||||
}): Promise<ManagementTokenListResponse> {
|
||||
const response = await apiClient.get<ManagementTokenListResponse>(
|
||||
'/api/me/management-tokens',
|
||||
{ params }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 创建 Management Token
|
||||
*/
|
||||
async createToken(
|
||||
data: CreateManagementTokenRequest
|
||||
): Promise<CreateManagementTokenResponse> {
|
||||
const response = await apiClient.post<CreateManagementTokenResponse>(
|
||||
'/api/me/management-tokens',
|
||||
data
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取 Token 详情
|
||||
*/
|
||||
async getToken(tokenId: string): Promise<ManagementToken> {
|
||||
const response = await apiClient.get<ManagementToken>(
|
||||
`/api/me/management-tokens/${tokenId}`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 更新 Token
|
||||
*/
|
||||
async updateToken(
|
||||
tokenId: string,
|
||||
data: UpdateManagementTokenRequest
|
||||
): Promise<{ message: string; data: ManagementToken }> {
|
||||
const response = await apiClient.put<{ message: string; data: ManagementToken }>(
|
||||
`/api/me/management-tokens/${tokenId}`,
|
||||
data
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除 Token
|
||||
*/
|
||||
async deleteToken(tokenId: string): Promise<{ message: string }> {
|
||||
const response = await apiClient.delete<{ message: string }>(
|
||||
`/api/me/management-tokens/${tokenId}`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 切换 Token 状态
|
||||
*/
|
||||
async toggleToken(
|
||||
tokenId: string
|
||||
): Promise<{ message: string; data: ManagementToken }> {
|
||||
const response = await apiClient.patch<{ message: string; data: ManagementToken }>(
|
||||
`/api/me/management-tokens/${tokenId}/status`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 重新生成 Token
|
||||
*/
|
||||
async regenerateToken(
|
||||
tokenId: string
|
||||
): Promise<{ token: string; data: ManagementToken }> {
|
||||
const response = await apiClient.post<{ token: string; data: ManagementToken }>(
|
||||
`/api/me/management-tokens/${tokenId}/regenerate`
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
// ============== 管理员 API ==============
|
||||
|
||||
export const adminManagementTokenApi = {
|
||||
/**
|
||||
* 列出所有 Management Tokens(管理员)
|
||||
*/
|
||||
async listAllTokens(params?: {
|
||||
user_id?: string
|
||||
is_active?: boolean
|
||||
skip?: number
|
||||
limit?: number
|
||||
}): Promise<ManagementTokenListResponse> {
|
||||
const response = await apiClient.get<ManagementTokenListResponse>(
|
||||
'/api/admin/management-tokens',
|
||||
{ params }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取 Token 详情(管理员)
|
||||
*/
|
||||
async getToken(tokenId: string): Promise<ManagementToken> {
|
||||
const response = await apiClient.get<ManagementToken>(
|
||||
`/api/admin/management-tokens/${tokenId}`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 删除任意 Token(管理员)
|
||||
*/
|
||||
async deleteToken(tokenId: string): Promise<{ message: string }> {
|
||||
const response = await apiClient.delete<{ message: string }>(
|
||||
`/api/admin/management-tokens/${tokenId}`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 切换任意 Token 状态(管理员)
|
||||
*/
|
||||
async toggleToken(
|
||||
tokenId: string
|
||||
): Promise<{ message: string; data: ManagementToken }> {
|
||||
const response = await apiClient.patch<{ message: string; data: ManagementToken }>(
|
||||
`/api/admin/management-tokens/${tokenId}/status`
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
@@ -62,6 +62,11 @@ export interface UsageRecordDetail {
|
||||
cache_creation_price_per_1m?: number
|
||||
cache_read_price_per_1m?: number
|
||||
price_per_request?: number // 按次计费价格
|
||||
api_key?: {
|
||||
id: string
|
||||
name: string
|
||||
display: string
|
||||
}
|
||||
}
|
||||
|
||||
// 模型统计接口
|
||||
@@ -192,6 +197,7 @@ export const meApi = {
|
||||
async getUsage(params?: {
|
||||
start_date?: string
|
||||
end_date?: string
|
||||
search?: string // 通用搜索:密钥名、模型名
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<UsageResponse> {
|
||||
|
||||
@@ -192,10 +192,17 @@ export async function getModelsDevList(officialOnly: boolean = true): Promise<Mo
|
||||
}
|
||||
}
|
||||
|
||||
// 按 provider 名称和模型名称排序
|
||||
// 按 provider 名称排序,provider 中的模型按 release_date 从近到远排序
|
||||
items.sort((a, b) => {
|
||||
const providerCompare = a.providerName.localeCompare(b.providerName)
|
||||
if (providerCompare !== 0) return providerCompare
|
||||
|
||||
// 模型按 release_date 从近到远排序(没有日期的排到最后)
|
||||
const aDate = a.releaseDate ? new Date(a.releaseDate).getTime() : 0
|
||||
const bDate = b.releaseDate ? new Date(b.releaseDate).getTime() : 0
|
||||
if (aDate !== bDate) return bDate - aDate // 降序:新的在前
|
||||
|
||||
// 日期相同或都没有日期时,按模型名称排序
|
||||
return a.modelName.localeCompare(b.modelName)
|
||||
})
|
||||
|
||||
|
||||
@@ -164,6 +164,7 @@ export const usageApi = {
|
||||
async getAllUsageRecords(params?: {
|
||||
start_date?: string
|
||||
end_date?: string
|
||||
search?: string // 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
user_id?: string // UUID
|
||||
username?: string
|
||||
model?: string
|
||||
|
||||
@@ -10,7 +10,7 @@ export interface User {
|
||||
used_usd: number
|
||||
total_usd: number
|
||||
allowed_providers: string[] | null // 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: string[] | null // 允许使用的端点 ID 列表
|
||||
allowed_api_formats: string[] | null // 允许使用的 API 格式列表
|
||||
allowed_models: string[] | null // 允许使用的模型名称列表
|
||||
created_at: string
|
||||
updated_at?: string
|
||||
@@ -23,7 +23,7 @@ export interface CreateUserRequest {
|
||||
role?: 'admin' | 'user'
|
||||
quota_usd?: number | null
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ export interface UpdateUserRequest {
|
||||
quota_usd?: number | null
|
||||
password?: string
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
}
|
||||
|
||||
|
||||
13
frontend/src/components/icons/GithubIcon.vue
Normal file
13
frontend/src/components/icons/GithubIcon.vue
Normal file
@@ -0,0 +1,13 @@
|
||||
<template>
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4" />
|
||||
<path d="M9 18c-4.51 2-5-2-7-2" />
|
||||
</svg>
|
||||
</template>
|
||||
@@ -18,7 +18,7 @@
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
|
||||
:style="{ zIndex: backdropZIndex }"
|
||||
@click="handleClose"
|
||||
@click="handleBackdropClick"
|
||||
/>
|
||||
</Transition>
|
||||
|
||||
@@ -106,6 +106,7 @@ const props = defineProps<{
|
||||
iconClass?: string // Custom icon color class
|
||||
zIndex?: number // Custom z-index for nested dialogs (default: 60)
|
||||
noPadding?: boolean // Disable default content padding
|
||||
persistent?: boolean // Prevent closing on backdrop click
|
||||
}>()
|
||||
|
||||
// Emits 定义
|
||||
@@ -138,6 +139,13 @@ function handleClose() {
|
||||
}
|
||||
}
|
||||
|
||||
// 处理背景点击
|
||||
function handleBackdropClick() {
|
||||
if (!props.persistent) {
|
||||
handleClose()
|
||||
}
|
||||
}
|
||||
|
||||
const maxWidthClass = computed(() => {
|
||||
const sizeValue = props.maxWidth || props.size || 'md'
|
||||
const sizes = {
|
||||
@@ -162,7 +170,7 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
|
||||
|
||||
// 添加 ESC 键监听
|
||||
useEscapeKey(() => {
|
||||
if (isOpen.value) {
|
||||
if (isOpen.value && !props.persistent) {
|
||||
handleClose()
|
||||
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
|
||||
}
|
||||
|
||||
@@ -66,19 +66,61 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 认证方式切换 -->
|
||||
<div
|
||||
v-if="showAuthTypeTabs"
|
||||
class="auth-type-tabs"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
class="auth-tab"
|
||||
:class="[authType === 'local' && 'active']"
|
||||
@click="authType = 'local'"
|
||||
>
|
||||
本地登录
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="auth-tab"
|
||||
:class="[authType === 'ldap' && 'active']"
|
||||
@click="authType = 'ldap'"
|
||||
>
|
||||
LDAP 登录
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- 登录表单 -->
|
||||
<form
|
||||
class="space-y-4"
|
||||
@submit.prevent="handleLogin"
|
||||
>
|
||||
<div class="space-y-2">
|
||||
<Label for="login-email">邮箱</Label>
|
||||
<div class="flex items-center justify-between">
|
||||
<Label for="login-email">{{ emailLabel }}</Label>
|
||||
<button
|
||||
v-if="ldapExclusive && authType === 'ldap'"
|
||||
type="button"
|
||||
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
|
||||
@click="authType = 'local'"
|
||||
>
|
||||
管理员本地登录
|
||||
</button>
|
||||
<button
|
||||
v-if="ldapExclusive && authType === 'local'"
|
||||
type="button"
|
||||
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
|
||||
@click="authType = 'ldap'"
|
||||
>
|
||||
返回 LDAP 登录
|
||||
</button>
|
||||
</div>
|
||||
<Input
|
||||
id="login-email"
|
||||
v-model="form.email"
|
||||
type="email"
|
||||
type="text"
|
||||
required
|
||||
placeholder="hello@example.com"
|
||||
placeholder="username 或 email"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
@@ -180,6 +222,30 @@ const showRegisterDialog = ref(false)
|
||||
const requireEmailVerification = ref(false)
|
||||
const allowRegistration = ref(false) // 由系统配置控制,默认关闭
|
||||
|
||||
// LDAP authentication settings
|
||||
const PREFERRED_AUTH_TYPE_KEY = 'aether_preferred_auth_type'
|
||||
function getStoredAuthType(): 'local' | 'ldap' {
|
||||
const stored = localStorage.getItem(PREFERRED_AUTH_TYPE_KEY)
|
||||
return (stored === 'ldap' || stored === 'local') ? stored : 'local'
|
||||
}
|
||||
const authType = ref<'local' | 'ldap'>(getStoredAuthType())
|
||||
const localEnabled = ref(true)
|
||||
const ldapEnabled = ref(false)
|
||||
const ldapExclusive = ref(false)
|
||||
|
||||
// 保存用户的认证类型偏好
|
||||
watch(authType, (newType) => {
|
||||
localStorage.setItem(PREFERRED_AUTH_TYPE_KEY, newType)
|
||||
})
|
||||
|
||||
const showAuthTypeTabs = computed(() => {
|
||||
return localEnabled.value && ldapEnabled.value && !ldapExclusive.value
|
||||
})
|
||||
|
||||
const emailLabel = computed(() => {
|
||||
return '用户名/邮箱'
|
||||
})
|
||||
|
||||
watch(() => props.modelValue, (val) => {
|
||||
isOpen.value = val
|
||||
// 打开对话框时重置表单
|
||||
@@ -212,7 +278,7 @@ async function handleLogin() {
|
||||
return
|
||||
}
|
||||
|
||||
const success = await authStore.login(form.value.email, form.value.password)
|
||||
const success = await authStore.login(form.value.email, form.value.password, authType.value)
|
||||
if (success) {
|
||||
showSuccess('登录成功,正在跳转...')
|
||||
|
||||
@@ -246,16 +312,84 @@ function handleSwitchToLogin() {
|
||||
isOpen.value = true
|
||||
}
|
||||
|
||||
// Load registration settings on mount
|
||||
// Load authentication and registration settings on mount
|
||||
onMounted(async () => {
|
||||
try {
|
||||
const settings = await authApi.getRegistrationSettings()
|
||||
allowRegistration.value = !!settings.enable_registration
|
||||
requireEmailVerification.value = !!settings.require_email_verification
|
||||
// Load registration settings
|
||||
const regSettings = await authApi.getRegistrationSettings()
|
||||
allowRegistration.value = !!regSettings.enable_registration
|
||||
requireEmailVerification.value = !!regSettings.require_email_verification
|
||||
|
||||
// Load authentication settings
|
||||
const authSettings = await authApi.getAuthSettings()
|
||||
localEnabled.value = authSettings.local_enabled
|
||||
ldapEnabled.value = authSettings.ldap_enabled
|
||||
ldapExclusive.value = authSettings.ldap_exclusive
|
||||
// 若仅允许 LDAP 登录,则禁用本地注册入口
|
||||
if (ldapExclusive.value) {
|
||||
allowRegistration.value = false
|
||||
}
|
||||
|
||||
// Set default auth type based on settings
|
||||
if (authSettings.ldap_exclusive) {
|
||||
authType.value = 'ldap'
|
||||
} else if (!authSettings.local_enabled && authSettings.ldap_enabled) {
|
||||
authType.value = 'ldap'
|
||||
} else {
|
||||
authType.value = 'local'
|
||||
}
|
||||
} catch (error) {
|
||||
// If获取失败,保持默认:关闭注册 & 关闭邮箱验证
|
||||
// If获取失败,保持默认:关闭注册 & 关闭邮箱验证 & 使用本地认证
|
||||
allowRegistration.value = false
|
||||
requireEmailVerification.value = false
|
||||
localEnabled.value = true
|
||||
ldapEnabled.value = false
|
||||
ldapExclusive.value = false
|
||||
authType.value = 'local'
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.auth-type-tabs {
|
||||
display: flex;
|
||||
border-bottom: 1px solid hsl(var(--border));
|
||||
}
|
||||
|
||||
.auth-tab {
|
||||
flex: 1;
|
||||
padding: 0.625rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
color: hsl(var(--muted-foreground));
|
||||
background: transparent;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
transition: color 0.15s ease;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.auth-tab::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
bottom: -1px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 2px;
|
||||
background: transparent;
|
||||
transition: background 0.15s ease;
|
||||
}
|
||||
|
||||
.auth-tab:hover:not(.active) {
|
||||
color: hsl(var(--foreground));
|
||||
}
|
||||
|
||||
.auth-tab.active {
|
||||
color: var(--book-cloth);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.auth-tab.active::after {
|
||||
background: var(--book-cloth);
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -32,6 +32,17 @@
|
||||
<!-- 分隔线 -->
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 通用搜索 -->
|
||||
<div class="relative">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
|
||||
<Input
|
||||
id="usage-records-search"
|
||||
v-model="localSearch"
|
||||
:placeholder="isAdmin ? '搜索用户/密钥/模型/提供商' : '搜索密钥/模型'"
|
||||
class="w-32 sm:w-48 h-8 text-xs border-border/60 pl-8"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- 用户筛选(仅管理员可见) -->
|
||||
<Select
|
||||
v-if="isAdmin && availableUsers.length > 0"
|
||||
@@ -164,6 +175,12 @@
|
||||
>
|
||||
用户
|
||||
</TableHead>
|
||||
<TableHead
|
||||
v-if="!isAdmin"
|
||||
class="h-12 font-semibold w-[100px]"
|
||||
>
|
||||
密钥
|
||||
</TableHead>
|
||||
<TableHead class="h-12 font-semibold w-[140px]">
|
||||
模型
|
||||
</TableHead>
|
||||
@@ -196,7 +213,7 @@
|
||||
<TableBody>
|
||||
<TableRow v-if="records.length === 0">
|
||||
<TableCell
|
||||
:colspan="isAdmin ? 9 : 7"
|
||||
:colspan="isAdmin ? 9 : 8"
|
||||
class="text-center py-12 text-muted-foreground"
|
||||
>
|
||||
暂无请求记录
|
||||
@@ -218,7 +235,34 @@
|
||||
class="py-4 w-[100px] truncate"
|
||||
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
|
||||
>
|
||||
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
||||
<div class="flex flex-col text-xs gap-0.5">
|
||||
<span class="truncate">
|
||||
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
||||
</span>
|
||||
<span
|
||||
v-if="record.api_key?.name"
|
||||
class="text-muted-foreground truncate"
|
||||
:title="record.api_key.name"
|
||||
>
|
||||
{{ record.api_key.name }}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<!-- 用户页面的密钥列 -->
|
||||
<TableCell
|
||||
v-if="!isAdmin"
|
||||
class="py-4 w-[100px]"
|
||||
:title="record.api_key?.name || '-'"
|
||||
>
|
||||
<div class="flex flex-col text-xs gap-0.5">
|
||||
<span class="truncate">{{ record.api_key?.name || '-' }}</span>
|
||||
<span
|
||||
v-if="record.api_key?.display"
|
||||
class="text-muted-foreground truncate"
|
||||
>
|
||||
{{ record.api_key.display }}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell
|
||||
class="font-medium py-4 w-[140px]"
|
||||
@@ -438,6 +482,7 @@ import {
|
||||
TableCard,
|
||||
Badge,
|
||||
Button,
|
||||
Input,
|
||||
Select,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
@@ -451,7 +496,7 @@ import {
|
||||
TableCell,
|
||||
Pagination,
|
||||
} from '@/components/ui'
|
||||
import { RefreshCcw } from 'lucide-vue-next'
|
||||
import { RefreshCcw, Search } from 'lucide-vue-next'
|
||||
import { formatTokens, formatCurrency } from '@/utils/format'
|
||||
import { formatDateTime } from '../composables'
|
||||
import { useRowClick } from '@/composables/useRowClick'
|
||||
@@ -471,6 +516,7 @@ const props = defineProps<{
|
||||
// 时间段
|
||||
selectedPeriod: string
|
||||
// 筛选
|
||||
filterSearch: string
|
||||
filterUser: string
|
||||
filterModel: string
|
||||
filterProvider: string
|
||||
@@ -489,6 +535,7 @@ const props = defineProps<{
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:selectedPeriod': [value: string]
|
||||
'update:filterSearch': [value: string]
|
||||
'update:filterUser': [value: string]
|
||||
'update:filterModel': [value: string]
|
||||
'update:filterProvider': [value: string]
|
||||
@@ -507,6 +554,23 @@ const filterModelSelectOpen = ref(false)
|
||||
const filterProviderSelectOpen = ref(false)
|
||||
const filterStatusSelectOpen = ref(false)
|
||||
|
||||
// 通用搜索(输入防抖)
|
||||
const localSearch = ref(props.filterSearch)
|
||||
let searchDebounceTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
watch(() => props.filterSearch, (value) => {
|
||||
if (value !== localSearch.value) {
|
||||
localSearch.value = value
|
||||
}
|
||||
})
|
||||
|
||||
watch(localSearch, (value) => {
|
||||
if (searchDebounceTimer) clearTimeout(searchDebounceTimer)
|
||||
searchDebounceTimer = setTimeout(() => {
|
||||
emit('update:filterSearch', value)
|
||||
}, 300)
|
||||
})
|
||||
|
||||
// 动态计时器相关
|
||||
const now = ref(Date.now())
|
||||
let timerInterval: ReturnType<typeof setInterval> | null = null
|
||||
@@ -574,6 +638,10 @@ function handleRowClick(event: MouseEvent, id: string) {
|
||||
// 组件卸载时清理
|
||||
onUnmounted(() => {
|
||||
stopTimer()
|
||||
if (searchDebounceTimer) {
|
||||
clearTimeout(searchDebounceTimer)
|
||||
searchDebounceTimer = null
|
||||
}
|
||||
})
|
||||
|
||||
// 格式化 API 格式显示名称
|
||||
|
||||
@@ -23,6 +23,7 @@ export interface PaginationParams {
|
||||
}
|
||||
|
||||
export interface FilterParams {
|
||||
search?: string
|
||||
user_id?: string
|
||||
model?: string
|
||||
provider?: string
|
||||
@@ -234,11 +235,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
pagination: PaginationParams,
|
||||
filters?: FilterParams
|
||||
): Promise<void> {
|
||||
if (!isAdminPage.value) {
|
||||
// 用户页面不需要分页加载,记录已在 loadStats 中获取
|
||||
return
|
||||
}
|
||||
|
||||
isLoadingRecords.value = true
|
||||
|
||||
try {
|
||||
@@ -252,24 +248,34 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
}
|
||||
|
||||
// 添加筛选条件
|
||||
if (filters?.user_id) {
|
||||
params.user_id = filters.user_id
|
||||
}
|
||||
if (filters?.model) {
|
||||
params.model = filters.model
|
||||
}
|
||||
if (filters?.provider) {
|
||||
params.provider = filters.provider
|
||||
}
|
||||
if (filters?.status) {
|
||||
params.status = filters.status
|
||||
if (filters?.search?.trim()) {
|
||||
params.search = filters.search.trim()
|
||||
}
|
||||
|
||||
const response = await usageApi.getAllUsageRecords(params)
|
||||
|
||||
currentRecords.value = (response.records || []) as UsageRecord[]
|
||||
totalRecords.value = response.total || 0
|
||||
if (isAdminPage.value) {
|
||||
// 管理员页面:使用管理员 API
|
||||
if (filters?.user_id) {
|
||||
params.user_id = filters.user_id
|
||||
}
|
||||
if (filters?.model) {
|
||||
params.model = filters.model
|
||||
}
|
||||
if (filters?.provider) {
|
||||
params.provider = filters.provider
|
||||
}
|
||||
if (filters?.status) {
|
||||
params.status = filters.status
|
||||
}
|
||||
|
||||
const response = await usageApi.getAllUsageRecords(params)
|
||||
currentRecords.value = (response.records || []) as UsageRecord[]
|
||||
totalRecords.value = response.total || 0
|
||||
} else {
|
||||
// 用户页面:使用用户 API
|
||||
const userData = await meApi.getUsage(params)
|
||||
currentRecords.value = (userData.records || []) as UsageRecord[]
|
||||
totalRecords.value = userData.pagination?.total || currentRecords.value.length
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('加载记录失败:', error)
|
||||
currentRecords.value = []
|
||||
|
||||
@@ -61,6 +61,11 @@ export interface UsageRecord {
|
||||
user_id?: string
|
||||
username?: string
|
||||
user_email?: string
|
||||
api_key?: {
|
||||
id: string | null
|
||||
name: string | null
|
||||
display: string | null
|
||||
} | null
|
||||
provider: string
|
||||
api_key_name?: string
|
||||
rate_multiplier?: number
|
||||
|
||||
@@ -273,8 +273,8 @@
|
||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
||||
@click="endpointDropdownOpen = !endpointDropdownOpen"
|
||||
>
|
||||
<span :class="form.allowed_endpoints.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||
{{ form.allowed_endpoints.length ? `已选择 ${form.allowed_endpoints.length} 个` : '全部可用' }}
|
||||
<span :class="form.allowed_api_formats.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||
{{ form.allowed_api_formats.length ? `已选择 ${form.allowed_api_formats.length} 个` : '全部可用' }}
|
||||
</span>
|
||||
<ChevronDown
|
||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
||||
@@ -294,14 +294,14 @@
|
||||
v-for="format in apiFormats"
|
||||
:key="format.value"
|
||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
||||
@click="toggleSelection('allowed_endpoints', format.value)"
|
||||
@click="toggleSelection('allowed_api_formats', format.value)"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.allowed_endpoints.includes(format.value)"
|
||||
:checked="form.allowed_api_formats.includes(format.value)"
|
||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||
@click.stop
|
||||
@change="toggleSelection('allowed_endpoints', format.value)"
|
||||
@change="toggleSelection('allowed_api_formats', format.value)"
|
||||
>
|
||||
<span class="text-sm">{{ format.label }}</span>
|
||||
</div>
|
||||
@@ -374,7 +374,7 @@ export interface UserFormData {
|
||||
role: 'admin' | 'user'
|
||||
is_active?: boolean
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
}
|
||||
|
||||
@@ -414,7 +414,7 @@ const form = ref({
|
||||
unlimited: false,
|
||||
is_active: true,
|
||||
allowed_providers: [] as string[],
|
||||
allowed_endpoints: [] as string[],
|
||||
allowed_api_formats: [] as string[],
|
||||
allowed_models: [] as string[]
|
||||
})
|
||||
|
||||
@@ -435,7 +435,7 @@ function resetForm() {
|
||||
unlimited: false,
|
||||
is_active: true,
|
||||
allowed_providers: [],
|
||||
allowed_endpoints: [],
|
||||
allowed_api_formats: [],
|
||||
allowed_models: []
|
||||
}
|
||||
}
|
||||
@@ -454,7 +454,7 @@ function loadUserData() {
|
||||
unlimited: props.user.quota_usd == null,
|
||||
is_active: props.user.is_active ?? true,
|
||||
allowed_providers: props.user.allowed_providers || [],
|
||||
allowed_endpoints: props.user.allowed_endpoints || [],
|
||||
allowed_api_formats: props.user.allowed_api_formats || [],
|
||||
allowed_models: props.user.allowed_models || []
|
||||
}
|
||||
}
|
||||
@@ -495,7 +495,7 @@ async function loadAccessControlOptions() {
|
||||
}
|
||||
|
||||
// 切换选择
|
||||
function toggleSelection(field: 'allowed_providers' | 'allowed_endpoints' | 'allowed_models', value: string) {
|
||||
function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'allowed_models', value: string) {
|
||||
const arr = form.value[field]
|
||||
const index = arr.indexOf(value)
|
||||
if (index === -1) {
|
||||
@@ -520,7 +520,7 @@ async function handleSubmit() {
|
||||
quota_usd: form.value.unlimited ? null : form.value.quota,
|
||||
role: form.value.role,
|
||||
allowed_providers: form.value.allowed_providers.length > 0 ? form.value.allowed_providers : null,
|
||||
allowed_endpoints: form.value.allowed_endpoints.length > 0 ? form.value.allowed_endpoints : null,
|
||||
allowed_api_formats: form.value.allowed_api_formats.length > 0 ? form.value.allowed_api_formats : null,
|
||||
allowed_models: form.value.allowed_models.length > 0 ? form.value.allowed_models : null
|
||||
}
|
||||
|
||||
|
||||
@@ -280,6 +280,16 @@
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- GitHub Link -->
|
||||
<a
|
||||
href="https://github.com/fawney19/Aether"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
title="GitHub 仓库"
|
||||
>
|
||||
<GithubIcon class="h-4 w-4" />
|
||||
</a>
|
||||
</div>
|
||||
</header>
|
||||
</template>
|
||||
@@ -302,6 +312,7 @@ import {
|
||||
Home,
|
||||
Users,
|
||||
Key,
|
||||
KeyRound,
|
||||
BarChart3,
|
||||
Cog,
|
||||
Settings,
|
||||
@@ -322,6 +333,7 @@ import {
|
||||
X,
|
||||
Mail,
|
||||
} from 'lucide-vue-next'
|
||||
import GithubIcon from '@/components/icons/GithubIcon.vue'
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
@@ -387,6 +399,7 @@ const navigation = computed(() => {
|
||||
items: [
|
||||
{ name: '模型目录', href: '/dashboard/models', icon: Box },
|
||||
{ name: 'API 密钥', href: '/dashboard/api-keys', icon: Key },
|
||||
{ name: '访问令牌', href: '/dashboard/management-tokens', icon: KeyRound },
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -412,6 +425,7 @@ const navigation = computed(() => {
|
||||
{ name: '提供商', href: '/admin/providers', icon: FolderTree },
|
||||
{ name: '模型管理', href: '/admin/models', icon: Layers },
|
||||
{ name: '独立密钥', href: '/admin/keys', icon: Key },
|
||||
{ name: '访问令牌', href: '/admin/management-tokens', icon: KeyRound },
|
||||
{ name: '使用记录', href: '/admin/usage', icon: BarChart3 },
|
||||
]
|
||||
},
|
||||
@@ -423,6 +437,7 @@ const navigation = computed(() => {
|
||||
{ name: 'IP 安全', href: '/admin/ip-security', icon: Shield },
|
||||
{ name: '审计日志', href: '/admin/audit-logs', icon: AlertTriangle },
|
||||
{ name: '邮件配置', href: '/admin/email', icon: Mail },
|
||||
{ name: 'LDAP 配置', href: '/admin/ldap', icon: Shield },
|
||||
{ name: '系统设置', href: '/admin/system', icon: Cog },
|
||||
]
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ export const MOCK_ADMIN_USER: User = {
|
||||
used_usd: 156.78,
|
||||
total_usd: 1234.56,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-01-01T00:00:00Z',
|
||||
last_login_at: new Date().toISOString()
|
||||
@@ -38,7 +38,7 @@ export const MOCK_NORMAL_USER: User = {
|
||||
used_usd: 45.32,
|
||||
total_usd: 245.32,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-06-01T00:00:00Z',
|
||||
last_login_at: new Date().toISOString()
|
||||
@@ -274,7 +274,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
||||
used_usd: 156.78,
|
||||
total_usd: 1234.56,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
@@ -288,7 +288,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
||||
used_usd: 45.32,
|
||||
total_usd: 245.32,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-06-01T00:00:00Z'
|
||||
},
|
||||
@@ -302,7 +302,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
||||
used_usd: 23.45,
|
||||
total_usd: 123.45,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-03-15T00:00:00Z'
|
||||
},
|
||||
@@ -316,7 +316,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
||||
used_usd: 89.12,
|
||||
total_usd: 589.12,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-02-20T00:00:00Z'
|
||||
},
|
||||
@@ -330,7 +330,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
||||
used_usd: 30.00,
|
||||
total_usd: 30.00,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: '2024-04-10T00:00:00Z'
|
||||
}
|
||||
|
||||
@@ -367,6 +367,11 @@ function generateMockUsageRecords(count: number = 100) {
|
||||
user_id: user.id,
|
||||
username: user.username,
|
||||
user_email: user.email,
|
||||
api_key: {
|
||||
id: `key-${user.id}-${Math.ceil(Math.random() * 2)}`,
|
||||
name: `${user.username} Key ${Math.ceil(Math.random() * 3)}`,
|
||||
display: `sk-ae...${String(1000 + Math.floor(Math.random() * 9000))}`
|
||||
},
|
||||
provider: model.provider,
|
||||
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
|
||||
rate_multiplier: 1.0,
|
||||
@@ -685,7 +690,7 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
||||
used_usd: 0,
|
||||
total_usd: 0,
|
||||
allowed_providers: null,
|
||||
allowed_endpoints: null,
|
||||
allowed_api_formats: null,
|
||||
allowed_models: null,
|
||||
created_at: new Date().toISOString()
|
||||
}
|
||||
@@ -835,10 +840,26 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
||||
'GET /api/admin/usage/records': async (config) => {
|
||||
await delay()
|
||||
requireAdmin()
|
||||
const records = getUsageRecords()
|
||||
let records = getUsageRecords()
|
||||
const params = config.params || {}
|
||||
const limit = parseInt(params.limit) || 20
|
||||
const offset = parseInt(params.offset) || 0
|
||||
|
||||
// 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
// 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
if (typeof params.search === 'string' && params.search.trim()) {
|
||||
const keywords = params.search.trim().toLowerCase().split(/\s+/)
|
||||
records = records.filter(r => {
|
||||
// 每个关键词都要匹配至少一个字段
|
||||
return keywords.every((keyword: string) =>
|
||||
(r.username || '').toLowerCase().includes(keyword) ||
|
||||
(r.api_key?.name || '').toLowerCase().includes(keyword) ||
|
||||
(r.model || '').toLowerCase().includes(keyword) ||
|
||||
(r.provider || '').toLowerCase().includes(keyword)
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
return createMockResponse({
|
||||
records: records.slice(offset, offset + limit),
|
||||
total: records.length,
|
||||
|
||||
@@ -34,6 +34,11 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'MyApiKeys',
|
||||
component: () => importWithRetry(() => import('@/views/user/MyApiKeys.vue'))
|
||||
},
|
||||
{
|
||||
path: 'management-tokens',
|
||||
name: 'ManagementTokens',
|
||||
component: () => importWithRetry(() => import('@/views/user/ManagementTokens.vue'))
|
||||
},
|
||||
{
|
||||
path: 'announcements',
|
||||
name: 'Announcements',
|
||||
@@ -81,6 +86,11 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'ApiKeys',
|
||||
component: () => importWithRetry(() => import('@/views/admin/ApiKeys.vue'))
|
||||
},
|
||||
{
|
||||
path: 'management-tokens',
|
||||
name: 'AdminManagementTokens',
|
||||
component: () => importWithRetry(() => import('@/views/user/ManagementTokens.vue'))
|
||||
},
|
||||
{
|
||||
path: 'providers',
|
||||
name: 'ProviderManagement',
|
||||
@@ -111,6 +121,11 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'EmailSettings',
|
||||
component: () => importWithRetry(() => import('@/views/admin/EmailSettings.vue'))
|
||||
},
|
||||
{
|
||||
path: 'ldap',
|
||||
name: 'LdapSettings',
|
||||
component: () => importWithRetry(() => import('@/views/admin/LdapSettings.vue'))
|
||||
},
|
||||
{
|
||||
path: 'audit-logs',
|
||||
name: 'AuditLogs',
|
||||
|
||||
@@ -31,12 +31,12 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
}
|
||||
const isAdmin = computed(() => user.value?.role === 'admin')
|
||||
|
||||
async function login(email: string, password: string) {
|
||||
async function login(email: string, password: string, authType: 'local' | 'ldap' = 'local') {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const response = await authApi.login({ email, password })
|
||||
const response = await authApi.login({ email, password, auth_type: authType })
|
||||
token.value = response.access_token
|
||||
|
||||
// 获取用户信息
|
||||
|
||||
@@ -106,23 +106,23 @@
|
||||
type="text"
|
||||
:placeholder="smtpPasswordIsSet ? '已设置(留空保持不变)' : '请输入密码'"
|
||||
class="-webkit-text-security-disc"
|
||||
:class="smtpPasswordIsSet ? 'pr-8' : ''"
|
||||
:class="(smtpPasswordIsSet || emailConfig.smtp_password) ? 'pr-10' : ''"
|
||||
autocomplete="one-time-code"
|
||||
data-lpignore="true"
|
||||
data-1p-ignore="true"
|
||||
data-form-type="other"
|
||||
/>
|
||||
<button
|
||||
v-if="smtpPasswordIsSet"
|
||||
v-if="smtpPasswordIsSet || emailConfig.smtp_password"
|
||||
type="button"
|
||||
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground transition-colors"
|
||||
title="清除已保存的密码"
|
||||
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
|
||||
title="清除密码"
|
||||
@click="handleClearSmtpPassword"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -498,6 +498,7 @@ const smtpEncryptionSelectOpen = ref(false)
|
||||
const emailSuffixModeSelectOpen = ref(false)
|
||||
const testSmtpLoading = ref(false)
|
||||
const smtpPasswordIsSet = ref(false)
|
||||
const clearSmtpPassword = ref(false) // 标记是否要清除密码
|
||||
|
||||
// 邮件模板相关状态
|
||||
const templateLoading = ref(false)
|
||||
@@ -710,6 +711,7 @@ async function loadEmailConfig() {
|
||||
// 配置不存在时使用默认值,无需处理
|
||||
}
|
||||
}
|
||||
clearSmtpPassword.value = false
|
||||
} catch (err) {
|
||||
error('加载邮件配置失败')
|
||||
log.error('加载邮件配置失败:', err)
|
||||
@@ -720,6 +722,12 @@ async function loadEmailConfig() {
|
||||
async function saveSmtpConfig() {
|
||||
smtpSaveLoading.value = true
|
||||
try {
|
||||
const passwordAction: 'unchanged' | 'updated' | 'cleared' = emailConfig.value.smtp_password
|
||||
? 'updated'
|
||||
: clearSmtpPassword.value
|
||||
? 'cleared'
|
||||
: 'unchanged'
|
||||
|
||||
const configItems = [
|
||||
{
|
||||
key: 'smtp_host',
|
||||
@@ -737,7 +745,7 @@ async function saveSmtpConfig() {
|
||||
description: 'SMTP 用户名'
|
||||
},
|
||||
// 只有输入了新密码才提交(空值表示保持原密码)
|
||||
...(emailConfig.value.smtp_password
|
||||
...(passwordAction === 'updated'
|
||||
? [{
|
||||
key: 'smtp_password',
|
||||
value: emailConfig.value.smtp_password,
|
||||
@@ -770,8 +778,23 @@ async function saveSmtpConfig() {
|
||||
adminApi.updateSystemConfig(item.key, item.value, item.description)
|
||||
)
|
||||
|
||||
// 如果标记了清除密码,删除密码配置
|
||||
if (passwordAction === 'cleared') {
|
||||
promises.push(adminApi.deleteSystemConfig('smtp_password'))
|
||||
}
|
||||
|
||||
await Promise.all(promises)
|
||||
success('SMTP 配置已保存')
|
||||
|
||||
// 更新状态
|
||||
if (passwordAction === 'cleared') {
|
||||
clearSmtpPassword.value = false
|
||||
smtpPasswordIsSet.value = false
|
||||
} else if (passwordAction === 'updated') {
|
||||
clearSmtpPassword.value = false
|
||||
smtpPasswordIsSet.value = true
|
||||
}
|
||||
emailConfig.value.smtp_password = null
|
||||
} catch (err) {
|
||||
error('保存配置失败')
|
||||
log.error('保存 SMTP 配置失败:', err)
|
||||
@@ -812,15 +835,16 @@ async function saveEmailSuffixConfig() {
|
||||
}
|
||||
|
||||
// 清除 SMTP 密码
|
||||
async function handleClearSmtpPassword() {
|
||||
try {
|
||||
await adminApi.deleteSystemConfig('smtp_password')
|
||||
smtpPasswordIsSet.value = false
|
||||
function handleClearSmtpPassword() {
|
||||
// 如果有输入内容,先清空输入框
|
||||
if (emailConfig.value.smtp_password) {
|
||||
emailConfig.value.smtp_password = null
|
||||
success('SMTP 密码已清除')
|
||||
} catch (err) {
|
||||
error('清除密码失败')
|
||||
log.error('清除 SMTP 密码失败:', err)
|
||||
return
|
||||
}
|
||||
// 标记要清除服务端密码(保存时生效)
|
||||
if (smtpPasswordIsSet.value) {
|
||||
clearSmtpPassword.value = true
|
||||
smtpPasswordIsSet.value = false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
426
frontend/src/views/admin/LdapSettings.vue
Normal file
426
frontend/src/views/admin/LdapSettings.vue
Normal file
@@ -0,0 +1,426 @@
|
||||
<template>
|
||||
<PageContainer>
|
||||
<PageHeader
|
||||
title="LDAP 配置"
|
||||
description="配置 LDAP 认证服务"
|
||||
/>
|
||||
|
||||
<div class="mt-6 space-y-6">
|
||||
<CardSection
|
||||
title="LDAP 服务器配置"
|
||||
description="配置 LDAP 服务器连接参数"
|
||||
>
|
||||
<template #actions>
|
||||
<div class="flex gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
:disabled="testLoading"
|
||||
@click="handleTestConnection"
|
||||
>
|
||||
{{ testLoading ? '测试中...' : '测试连接' }}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
:disabled="saveLoading"
|
||||
@click="handleSave"
|
||||
>
|
||||
{{ saveLoading ? '保存中...' : '保存' }}
|
||||
</Button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
<div>
|
||||
<Label
|
||||
for="server-url"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
服务器地址
|
||||
</Label>
|
||||
<Input
|
||||
id="server-url"
|
||||
v-model="ldapConfig.server_url"
|
||||
type="text"
|
||||
placeholder="ldap://ldap.example.com:389"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
格式: ldap://host:389 或 ldaps://host:636
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="bind-dn"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
绑定 DN
|
||||
</Label>
|
||||
<Input
|
||||
id="bind-dn"
|
||||
v-model="ldapConfig.bind_dn"
|
||||
type="text"
|
||||
placeholder="cn=admin,dc=example,dc=com"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
用于连接 LDAP 服务器的管理员 DN
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="bind-password"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
绑定密码
|
||||
</Label>
|
||||
<div class="relative mt-1">
|
||||
<Input
|
||||
id="bind-password"
|
||||
v-model="ldapConfig.bind_password"
|
||||
type="password"
|
||||
:placeholder="hasPassword ? '已设置(留空保持不变)' : '请输入密码'"
|
||||
:class="(hasPassword || ldapConfig.bind_password) ? 'pr-10' : ''"
|
||||
autocomplete="new-password"
|
||||
/>
|
||||
<button
|
||||
v-if="hasPassword || ldapConfig.bind_password"
|
||||
type="button"
|
||||
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
|
||||
title="清除密码"
|
||||
@click="handleClearPassword"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
><line
|
||||
x1="18"
|
||||
y1="6"
|
||||
x2="6"
|
||||
y2="18"
|
||||
/><line
|
||||
x1="6"
|
||||
y1="6"
|
||||
x2="18"
|
||||
y2="18"
|
||||
/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
绑定账号的密码
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="base-dn"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
基础 DN
|
||||
</Label>
|
||||
<Input
|
||||
id="base-dn"
|
||||
v-model="ldapConfig.base_dn"
|
||||
type="text"
|
||||
placeholder="ou=users,dc=example,dc=com"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
用户搜索的基础 DN
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="user-search-filter"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
用户搜索过滤器
|
||||
</Label>
|
||||
<Input
|
||||
id="user-search-filter"
|
||||
v-model="ldapConfig.user_search_filter"
|
||||
type="text"
|
||||
placeholder="(uid={username})"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{username} 会被替换为登录用户名
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="username-attr"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
用户名属性
|
||||
</Label>
|
||||
<Input
|
||||
id="username-attr"
|
||||
v-model="ldapConfig.username_attr"
|
||||
type="text"
|
||||
placeholder="uid"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
常用: uid (OpenLDAP), sAMAccountName (AD)
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="email-attr"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
邮箱属性
|
||||
</Label>
|
||||
<Input
|
||||
id="email-attr"
|
||||
v-model="ldapConfig.email_attr"
|
||||
type="text"
|
||||
placeholder="mail"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="display-name-attr"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
显示名称属性
|
||||
</Label>
|
||||
<Input
|
||||
id="display-name-attr"
|
||||
v-model="ldapConfig.display_name_attr"
|
||||
type="text"
|
||||
placeholder="cn"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="connect-timeout"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
连接超时 (秒)
|
||||
</Label>
|
||||
<Input
|
||||
id="connect-timeout"
|
||||
v-model.number="ldapConfig.connect_timeout"
|
||||
type="number"
|
||||
min="1"
|
||||
max="60"
|
||||
placeholder="10"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
单次 LDAP 操作超时时间 (1-60秒),跨国网络建议 15-30 秒
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-6 space-y-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">使用 STARTTLS</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
在非 SSL 连接上启用 TLS 加密
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.use_starttls" />
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">启用 LDAP 认证</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
允许用户使用 LDAP 账号登录
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.is_enabled" />
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">仅允许 LDAP 登录</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
禁用本地账号登录,仅允许 LDAP 认证
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.is_exclusive" />
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
</PageContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { PageContainer, PageHeader, CardSection } from '@/components/layout'
|
||||
import { Button, Input, Label, Switch } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { adminApi, type LdapConfigUpdateRequest } from '@/api/admin'
|
||||
|
||||
const { success, error } = useToast()
|
||||
|
||||
const loading = ref(false)
|
||||
const saveLoading = ref(false)
|
||||
const testLoading = ref(false)
|
||||
const hasPassword = ref(false)
|
||||
const clearPassword = ref(false) // 标记是否要清除密码
|
||||
|
||||
const ldapConfig = ref({
|
||||
server_url: '',
|
||||
bind_dn: '',
|
||||
bind_password: '',
|
||||
base_dn: '',
|
||||
user_search_filter: '(uid={username})',
|
||||
username_attr: 'uid',
|
||||
email_attr: 'mail',
|
||||
display_name_attr: 'cn',
|
||||
is_enabled: false,
|
||||
is_exclusive: false,
|
||||
use_starttls: false,
|
||||
connect_timeout: 10,
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await loadConfig()
|
||||
})
|
||||
|
||||
async function loadConfig() {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await adminApi.getLdapConfig()
|
||||
ldapConfig.value = {
|
||||
server_url: response.server_url || '',
|
||||
bind_dn: response.bind_dn || '',
|
||||
bind_password: '',
|
||||
base_dn: response.base_dn || '',
|
||||
user_search_filter: response.user_search_filter || '(uid={username})',
|
||||
username_attr: response.username_attr || 'uid',
|
||||
email_attr: response.email_attr || 'mail',
|
||||
display_name_attr: response.display_name_attr || 'cn',
|
||||
is_enabled: response.is_enabled || false,
|
||||
is_exclusive: response.is_exclusive || false,
|
||||
use_starttls: response.use_starttls || false,
|
||||
connect_timeout: response.connect_timeout || 10,
|
||||
}
|
||||
hasPassword.value = !!response.has_bind_password
|
||||
clearPassword.value = false
|
||||
} catch (err) {
|
||||
error('加载 LDAP 配置失败')
|
||||
console.error('加载 LDAP 配置失败:', err)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSave() {
|
||||
saveLoading.value = true
|
||||
try {
|
||||
const payload: LdapConfigUpdateRequest = {
|
||||
server_url: ldapConfig.value.server_url,
|
||||
bind_dn: ldapConfig.value.bind_dn,
|
||||
base_dn: ldapConfig.value.base_dn,
|
||||
user_search_filter: ldapConfig.value.user_search_filter,
|
||||
username_attr: ldapConfig.value.username_attr,
|
||||
email_attr: ldapConfig.value.email_attr,
|
||||
display_name_attr: ldapConfig.value.display_name_attr,
|
||||
is_enabled: ldapConfig.value.is_enabled,
|
||||
is_exclusive: ldapConfig.value.is_exclusive,
|
||||
use_starttls: ldapConfig.value.use_starttls,
|
||||
connect_timeout: ldapConfig.value.connect_timeout,
|
||||
}
|
||||
|
||||
// 优先使用输入的新密码;否则如果标记清除则发送空字符串
|
||||
let passwordAction: 'unchanged' | 'updated' | 'cleared' = 'unchanged'
|
||||
if (ldapConfig.value.bind_password) {
|
||||
payload.bind_password = ldapConfig.value.bind_password
|
||||
passwordAction = 'updated'
|
||||
} else if (clearPassword.value) {
|
||||
payload.bind_password = ''
|
||||
passwordAction = 'cleared'
|
||||
}
|
||||
|
||||
await adminApi.updateLdapConfig(payload)
|
||||
success('LDAP 配置保存成功')
|
||||
|
||||
if (passwordAction === 'cleared') {
|
||||
hasPassword.value = false
|
||||
clearPassword.value = false
|
||||
} else if (passwordAction === 'updated') {
|
||||
hasPassword.value = true
|
||||
clearPassword.value = false
|
||||
}
|
||||
ldapConfig.value.bind_password = ''
|
||||
} catch (err) {
|
||||
error('保存 LDAP 配置失败')
|
||||
console.error('保存 LDAP 配置失败:', err)
|
||||
} finally {
|
||||
saveLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleTestConnection() {
|
||||
if (clearPassword.value && !ldapConfig.value.bind_password) {
|
||||
error('已标记清除绑定密码,请先保存或输入新的绑定密码再测试')
|
||||
return
|
||||
}
|
||||
|
||||
testLoading.value = true
|
||||
try {
|
||||
const payload: LdapConfigUpdateRequest = {
|
||||
server_url: ldapConfig.value.server_url,
|
||||
bind_dn: ldapConfig.value.bind_dn,
|
||||
base_dn: ldapConfig.value.base_dn,
|
||||
user_search_filter: ldapConfig.value.user_search_filter,
|
||||
username_attr: ldapConfig.value.username_attr,
|
||||
email_attr: ldapConfig.value.email_attr,
|
||||
display_name_attr: ldapConfig.value.display_name_attr,
|
||||
is_enabled: ldapConfig.value.is_enabled,
|
||||
is_exclusive: ldapConfig.value.is_exclusive,
|
||||
use_starttls: ldapConfig.value.use_starttls,
|
||||
connect_timeout: ldapConfig.value.connect_timeout,
|
||||
...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }),
|
||||
}
|
||||
const response = await adminApi.testLdapConnection(payload)
|
||||
if (response.success) {
|
||||
success('LDAP 连接测试成功')
|
||||
} else {
|
||||
error(`LDAP 连接测试失败: ${response.message}`)
|
||||
}
|
||||
} catch (err) {
|
||||
error('LDAP 连接测试失败')
|
||||
console.error('LDAP 连接测试失败:', err)
|
||||
} finally {
|
||||
testLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handleClearPassword() {
|
||||
// 如果有输入内容,先清空输入框
|
||||
if (ldapConfig.value.bind_password) {
|
||||
ldapConfig.value.bind_password = ''
|
||||
return
|
||||
}
|
||||
// 标记要清除服务端密码(保存时生效)
|
||||
if (hasPassword.value) {
|
||||
clearPassword.value = true
|
||||
hasPassword.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -464,6 +464,30 @@
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 系统版本信息 -->
|
||||
<CardSection
|
||||
title="系统信息"
|
||||
description="当前系统版本和构建信息"
|
||||
>
|
||||
<div class="flex items-center gap-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<Label class="text-sm font-medium text-muted-foreground">版本:</Label>
|
||||
<span
|
||||
v-if="systemVersion"
|
||||
class="text-sm font-mono"
|
||||
>
|
||||
{{ systemVersion }}
|
||||
</span>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>
|
||||
加载中...
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
|
||||
<!-- 导入配置对话框 -->
|
||||
@@ -475,7 +499,7 @@
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
class="text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
配置预览
|
||||
@@ -557,7 +581,7 @@
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
全局模型
|
||||
</p>
|
||||
@@ -567,7 +591,7 @@
|
||||
跳过: {{ importResult.stats.global_models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
提供商
|
||||
</p>
|
||||
@@ -577,7 +601,7 @@
|
||||
跳过: {{ importResult.stats.providers.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
端点
|
||||
</p>
|
||||
@@ -587,7 +611,7 @@
|
||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
@@ -596,7 +620,7 @@
|
||||
跳过: {{ importResult.stats.keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||
<div class="col-span-2">
|
||||
<p class="font-medium">
|
||||
模型配置
|
||||
</p>
|
||||
@@ -642,7 +666,7 @@
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importUsersPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
class="text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
数据预览
|
||||
@@ -652,6 +676,9 @@
|
||||
<li>
|
||||
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||
</li>
|
||||
<li v-if="importUsersPreview.standalone_keys?.length">
|
||||
独立余额 Keys: {{ importUsersPreview.standalone_keys.length }} 个
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
@@ -720,7 +747,7 @@
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
用户
|
||||
</p>
|
||||
@@ -730,7 +757,7 @@
|
||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
@@ -739,6 +766,18 @@
|
||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
v-if="importUsersResult.stats.standalone_keys"
|
||||
class="col-span-2"
|
||||
>
|
||||
<p class="font-medium">
|
||||
独立余额 Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.standalone_keys.created }},
|
||||
跳过: {{ importUsersResult.stats.standalone_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
@@ -839,6 +878,9 @@ const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const usersMergeModeSelectOpen = ref(false)
|
||||
|
||||
// 系统版本信息
|
||||
const systemVersion = ref<string>('')
|
||||
|
||||
const systemConfig = ref<SystemConfig>({
|
||||
// 基础配置
|
||||
default_user_quota_usd: 10.0,
|
||||
@@ -890,9 +932,21 @@ const sensitiveHeadersStr = computed({
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await loadSystemConfig()
|
||||
await Promise.all([
|
||||
loadSystemConfig(),
|
||||
loadSystemVersion()
|
||||
])
|
||||
})
|
||||
|
||||
async function loadSystemVersion() {
|
||||
try {
|
||||
const data = await adminApi.getSystemVersion()
|
||||
systemVersion.value = data.version
|
||||
} catch (err) {
|
||||
log.error('加载系统版本失败:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function loadSystemConfig() {
|
||||
try {
|
||||
const configs = [
|
||||
@@ -1178,12 +1232,6 @@ function handleUsersFileSelect(event: Event) {
|
||||
const content = e.target?.result as string
|
||||
const data = JSON.parse(content) as UsersExportData
|
||||
|
||||
// 验证版本
|
||||
if (data.version !== '1.0') {
|
||||
error(`不支持的配置版本: ${data.version}`)
|
||||
return
|
||||
}
|
||||
|
||||
importUsersPreview.value = data
|
||||
usersMergeMode.value = 'skip'
|
||||
importUsersDialogOpen.value = true
|
||||
|
||||
@@ -907,7 +907,7 @@ function editUser(user: any) {
|
||||
role: user.role,
|
||||
is_active: user.is_active,
|
||||
allowed_providers: user.allowed_providers || [],
|
||||
allowed_endpoints: user.allowed_endpoints || [],
|
||||
allowed_api_formats: user.allowed_api_formats || [],
|
||||
allowed_models: user.allowed_models || []
|
||||
}
|
||||
showUserFormDialog.value = true
|
||||
@@ -929,7 +929,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
|
||||
quota_usd: data.quota_usd,
|
||||
role: data.role,
|
||||
allowed_providers: data.allowed_providers,
|
||||
allowed_endpoints: data.allowed_endpoints,
|
||||
allowed_api_formats: data.allowed_api_formats,
|
||||
allowed_models: data.allowed_models
|
||||
}
|
||||
if (data.password) {
|
||||
@@ -946,7 +946,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
|
||||
quota_usd: data.quota_usd,
|
||||
role: data.role,
|
||||
allowed_providers: data.allowed_providers,
|
||||
allowed_endpoints: data.allowed_endpoints,
|
||||
allowed_api_formats: data.allowed_api_formats,
|
||||
allowed_models: data.allowed_models
|
||||
})
|
||||
// 如果创建时指定为禁用,则更新状态
|
||||
|
||||
@@ -20,10 +20,11 @@
|
||||
</nav>
|
||||
|
||||
<!-- Header -->
|
||||
<header class="fixed top-0 left-0 right-0 z-50 border-b border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-xl transition-all">
|
||||
<div class="mx-auto max-w-7xl px-6 py-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<!-- Logo & Brand -->
|
||||
<header class="sticky top-0 z-50 border-b border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-xl transition-all">
|
||||
<div class="h-16 flex items-center">
|
||||
<!-- Centered content container (max-w-7xl) -->
|
||||
<div class="mx-auto max-w-7xl w-full px-6 flex items-center justify-between">
|
||||
<!-- Left: Logo & Brand -->
|
||||
<div
|
||||
class="flex items-center gap-3 group/logo cursor-pointer"
|
||||
@click="scrollToSection(0)"
|
||||
@@ -40,7 +41,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Center Navigation -->
|
||||
<!-- Center: Navigation -->
|
||||
<nav class="hidden md:flex items-center gap-2">
|
||||
<button
|
||||
v-for="(section, index) in sections"
|
||||
@@ -59,42 +60,54 @@
|
||||
</button>
|
||||
</nav>
|
||||
|
||||
<!-- Right Actions -->
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
:title="themeMode === 'system' ? '跟随系统' : themeMode === 'dark' ? '深色模式' : '浅色模式'"
|
||||
@click="toggleDarkMode"
|
||||
>
|
||||
<SunMoon
|
||||
v-if="themeMode === 'system'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Sun
|
||||
v-else-if="themeMode === 'light'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Moon
|
||||
v-else
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- Right: Login/Dashboard Button -->
|
||||
<RouterLink
|
||||
v-if="authStore.isAuthenticated"
|
||||
:to="dashboardPath"
|
||||
class="min-w-[72px] text-center rounded-xl bg-[#191919] dark:bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-sm transition hover:bg-[#262625] dark:hover:bg-[#b86d52] whitespace-nowrap"
|
||||
>
|
||||
控制台
|
||||
</RouterLink>
|
||||
<button
|
||||
v-else
|
||||
class="min-w-[72px] text-center rounded-xl bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-lg shadow-[#cc785c]/30 transition hover:bg-[#d4a27f] whitespace-nowrap"
|
||||
@click="showLoginDialog = true"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<RouterLink
|
||||
v-if="authStore.isAuthenticated"
|
||||
:to="dashboardPath"
|
||||
class="rounded-xl bg-[#191919] dark:bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-sm transition hover:bg-[#262625] dark:hover:bg-[#b86d52]"
|
||||
>
|
||||
控制台
|
||||
</RouterLink>
|
||||
<button
|
||||
<!-- Fixed right icons (px-8 to match dashboard) -->
|
||||
<div class="absolute right-8 top-1/2 -translate-y-1/2 flex items-center gap-2">
|
||||
<!-- Theme Toggle -->
|
||||
<button
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
:title="themeMode === 'system' ? '跟随系统' : themeMode === 'dark' ? '深色模式' : '浅色模式'"
|
||||
@click="toggleDarkMode"
|
||||
>
|
||||
<SunMoon
|
||||
v-if="themeMode === 'system'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Sun
|
||||
v-else-if="themeMode === 'light'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Moon
|
||||
v-else
|
||||
class="rounded-xl bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-lg shadow-[#cc785c]/30 transition hover:bg-[#d4a27f]"
|
||||
@click="showLoginDialog = true"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
</div>
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- GitHub Link -->
|
||||
<a
|
||||
href="https://github.com/fawney19/Aether"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
title="GitHub 仓库"
|
||||
>
|
||||
<GithubIcon class="h-4 w-4" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
@@ -336,31 +349,6 @@
|
||||
</section>
|
||||
</main>
|
||||
|
||||
<!-- Footer -->
|
||||
<footer class="relative z-10 border-t border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-md py-8">
|
||||
<div class="mx-auto max-w-7xl px-6">
|
||||
<div class="flex flex-col items-center justify-between gap-4 sm:flex-row">
|
||||
<p class="text-sm text-[#91918d] dark:text-muted-foreground">
|
||||
© 2025 Aether. 团队内部使用
|
||||
</p>
|
||||
<div class="flex items-center gap-6 text-sm text-[#91918d] dark:text-muted-foreground">
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>使用条款</a>
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>隐私政策</a>
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>技术支持</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
<LoginDialog v-model="showLoginDialog" />
|
||||
</div>
|
||||
</template>
|
||||
@@ -378,6 +366,7 @@ import {
|
||||
SunMoon,
|
||||
Terminal
|
||||
} from 'lucide-vue-next'
|
||||
import GithubIcon from '@/components/icons/GithubIcon.vue'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { useDarkMode } from '@/composables/useDarkMode'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
|
||||
@@ -56,6 +56,7 @@
|
||||
:show-actual-cost="authStore.isAdmin"
|
||||
:loading="isLoadingRecords"
|
||||
:selected-period="selectedPeriod"
|
||||
:filter-search="filterSearch"
|
||||
:filter-user="filterUser"
|
||||
:filter-model="filterModel"
|
||||
:filter-provider="filterProvider"
|
||||
@@ -69,6 +70,7 @@
|
||||
:page-size-options="pageSizeOptions"
|
||||
:auto-refresh="globalAutoRefresh"
|
||||
@update:selected-period="handlePeriodChange"
|
||||
@update:filter-search="handleFilterSearchChange"
|
||||
@update:filter-user="handleFilterUserChange"
|
||||
@update:filter-model="handleFilterModelChange"
|
||||
@update:filter-provider="handleFilterProviderChange"
|
||||
@@ -133,6 +135,7 @@ const pageSize = ref(20)
|
||||
const pageSizeOptions = [10, 20, 50, 100]
|
||||
|
||||
// 筛选状态
|
||||
const filterSearch = ref('')
|
||||
const filterUser = ref('__all__')
|
||||
const filterModel = ref('__all__')
|
||||
const filterProvider = ref('__all__')
|
||||
@@ -392,14 +395,17 @@ onMounted(async () => {
|
||||
// 热力图加载失败不提示,因为 UI 已显示占位符
|
||||
}
|
||||
|
||||
// 管理员页面加载用户列表和第一页记录
|
||||
// 加载记录和用户列表
|
||||
if (isAdminPage.value) {
|
||||
// 并行加载用户列表和记录
|
||||
// 管理员页面:并行加载用户列表和记录
|
||||
const [users] = await Promise.all([
|
||||
usersApi.getAllUsers(),
|
||||
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
])
|
||||
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
|
||||
} else {
|
||||
// 用户页面:加载记录
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -410,34 +416,26 @@ async function handlePeriodChange(value: string) {
|
||||
|
||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||
await loadStats(dateRange)
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 处理分页变化
|
||||
async function handlePageChange(page: number) {
|
||||
currentPage.value = page
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 处理每页大小变化
|
||||
async function handlePageSizeChange(size: number) {
|
||||
pageSize.value = size
|
||||
currentPage.value = 1 // 重置到第一页
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 获取当前筛选参数
|
||||
function getCurrentFilters() {
|
||||
return {
|
||||
search: filterSearch.value.trim() || undefined,
|
||||
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
|
||||
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
|
||||
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
|
||||
@@ -446,6 +444,13 @@ function getCurrentFilters() {
|
||||
}
|
||||
|
||||
// 处理筛选变化
|
||||
async function handleFilterSearchChange(value: string) {
|
||||
filterSearch.value = value
|
||||
currentPage.value = 1
|
||||
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
async function handleFilterUserChange(value: string) {
|
||||
filterUser.value = value
|
||||
currentPage.value = 1 // 重置到第一页
|
||||
@@ -486,10 +491,7 @@ async function handleFilterStatusChange(value: string) {
|
||||
async function refreshData() {
|
||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||
await loadStats(dateRange)
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 显示请求详情
|
||||
|
||||
859
frontend/src/views/user/ManagementTokens.vue
Normal file
859
frontend/src/views/user/ManagementTokens.vue
Normal file
@@ -0,0 +1,859 @@
|
||||
<template>
|
||||
<div class="space-y-6 pb-8">
|
||||
<!-- 访问令牌表格 -->
|
||||
<Card
|
||||
variant="default"
|
||||
class="overflow-hidden"
|
||||
>
|
||||
<!-- 标题和操作栏 -->
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
<div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3 sm:gap-4">
|
||||
<div>
|
||||
<h3 class="text-sm sm:text-base font-semibold">
|
||||
访问令牌
|
||||
</h3>
|
||||
<p class="text-xs text-muted-foreground mt-0.5">
|
||||
<template v-if="quota">
|
||||
已创建 {{ quota.used }}/{{ quota.max }} 个令牌
|
||||
<span
|
||||
v-if="quota.used >= quota.max"
|
||||
class="text-destructive font-medium"
|
||||
>(已达上限)</span>
|
||||
</template>
|
||||
<template v-else>
|
||||
用于程序化访问管理 API 的令牌
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<div class="flex items-center gap-2">
|
||||
<!-- 新增按钮 -->
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="创建新令牌"
|
||||
:disabled="quota ? quota.used >= quota.max : false"
|
||||
@click="showCreateDialog = true"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
|
||||
<!-- 刷新按钮 -->
|
||||
<RefreshButton
|
||||
:loading="loading"
|
||||
@click="loadTokens"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div
|
||||
v-if="loading"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<LoadingState message="加载中..." />
|
||||
</div>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<div
|
||||
v-else-if="tokens.length === 0"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<EmptyState
|
||||
title="暂无访问令牌"
|
||||
description="创建你的第一个访问令牌开始使用管理 API"
|
||||
:icon="KeyRound"
|
||||
>
|
||||
<template #actions>
|
||||
<Button
|
||||
size="lg"
|
||||
class="shadow-lg shadow-primary/20"
|
||||
@click="showCreateDialog = true"
|
||||
>
|
||||
<Plus class="mr-2 h-4 w-4" />
|
||||
创建访问令牌
|
||||
</Button>
|
||||
</template>
|
||||
</EmptyState>
|
||||
</div>
|
||||
|
||||
<!-- 桌面端表格 -->
|
||||
<div
|
||||
v-else
|
||||
class="hidden md:block overflow-x-auto"
|
||||
>
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="min-w-[180px] h-12 font-semibold">
|
||||
名称
|
||||
</TableHead>
|
||||
<TableHead class="min-w-[160px] h-12 font-semibold">
|
||||
令牌
|
||||
</TableHead>
|
||||
<TableHead class="min-w-[80px] h-12 font-semibold text-center">
|
||||
使用次数
|
||||
</TableHead>
|
||||
<TableHead class="min-w-[70px] h-12 font-semibold text-center">
|
||||
状态
|
||||
</TableHead>
|
||||
<TableHead class="min-w-[100px] h-12 font-semibold">
|
||||
时间
|
||||
</TableHead>
|
||||
<TableHead class="min-w-[100px] h-12 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="token in paginatedTokens"
|
||||
:key="token.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<!-- 名称 -->
|
||||
<TableCell class="py-4">
|
||||
<div class="flex-1 min-w-0">
|
||||
<div
|
||||
class="text-sm font-semibold truncate"
|
||||
:title="token.name"
|
||||
>
|
||||
{{ token.name }}
|
||||
</div>
|
||||
<div
|
||||
v-if="token.description"
|
||||
class="text-xs text-muted-foreground mt-0.5 truncate"
|
||||
:title="token.description"
|
||||
>
|
||||
{{ token.description }}
|
||||
</div>
|
||||
</div>
|
||||
</TableCell>
|
||||
|
||||
<!-- Token 显示 -->
|
||||
<TableCell class="py-4">
|
||||
<div class="flex items-center gap-1.5">
|
||||
<code class="text-xs font-mono text-muted-foreground bg-muted/30 px-2 py-1 rounded">
|
||||
{{ token.token_display }}
|
||||
</code>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-6 w-6"
|
||||
title="重新生成令牌"
|
||||
@click="confirmRegenerate(token)"
|
||||
>
|
||||
<RefreshCw class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
|
||||
<!-- 使用次数 -->
|
||||
<TableCell class="py-4 text-center">
|
||||
<span class="text-sm font-medium">
|
||||
{{ formatNumber(token.usage_count || 0) }}
|
||||
</span>
|
||||
</TableCell>
|
||||
|
||||
<!-- 状态 -->
|
||||
<TableCell class="py-4 text-center">
|
||||
<Badge
|
||||
:variant="getStatusVariant(token)"
|
||||
class="font-medium px-3 py-1"
|
||||
>
|
||||
{{ getStatusText(token) }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
|
||||
<!-- 时间 -->
|
||||
<TableCell class="py-4 text-sm text-muted-foreground">
|
||||
<div class="text-xs">
|
||||
创建于 {{ formatDate(token.created_at) }}
|
||||
</div>
|
||||
<div class="text-xs mt-1">
|
||||
{{ token.last_used_at ? `最后使用 ${formatRelativeTime(token.last_used_at)}` : '从未使用' }}
|
||||
</div>
|
||||
</TableCell>
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<TableCell class="py-4">
|
||||
<div class="flex justify-center gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="编辑"
|
||||
@click="openEditDialog(token)"
|
||||
>
|
||||
<Pencil class="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
:title="token.is_active ? '禁用' : '启用'"
|
||||
@click="toggleToken(token)"
|
||||
>
|
||||
<Power class="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="删除"
|
||||
@click="confirmDelete(token)"
|
||||
>
|
||||
<Trash2 class="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div
|
||||
v-if="!loading && tokens.length > 0"
|
||||
class="md:hidden space-y-3 p-4"
|
||||
>
|
||||
<Card
|
||||
v-for="token in paginatedTokens"
|
||||
:key="token.id"
|
||||
variant="default"
|
||||
class="group hover:shadow-md hover:border-primary/30 transition-all duration-200"
|
||||
>
|
||||
<div class="p-4">
|
||||
<!-- 第一行:名称、状态、操作 -->
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<div class="flex items-center gap-2 min-w-0 flex-1">
|
||||
<h3 class="text-sm font-semibold text-foreground truncate">
|
||||
{{ token.name }}
|
||||
</h3>
|
||||
<Badge
|
||||
:variant="getStatusVariant(token)"
|
||||
class="text-xs px-1.5 py-0"
|
||||
>
|
||||
{{ getStatusText(token) }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-0.5 flex-shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑"
|
||||
@click="openEditDialog(token)"
|
||||
>
|
||||
<Pencil class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="token.is_active ? '禁用' : '启用'"
|
||||
@click="toggleToken(token)"
|
||||
>
|
||||
<Power class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除"
|
||||
@click="confirmDelete(token)"
|
||||
>
|
||||
<Trash2 class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Token 显示 -->
|
||||
<div class="flex items-center gap-2 text-xs mb-2">
|
||||
<code class="font-mono text-muted-foreground">{{ token.token_display }}</code>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-5 w-5"
|
||||
title="重新生成"
|
||||
@click="confirmRegenerate(token)"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<!-- 统计信息 -->
|
||||
<div class="flex items-center gap-3 text-xs text-muted-foreground">
|
||||
<span>{{ formatNumber(token.usage_count || 0) }} 次使用</span>
|
||||
<span>·</span>
|
||||
<span>{{ token.last_used_at ? formatRelativeTime(token.last_used_at) : '从未使用' }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<!-- 分页 -->
|
||||
<Pagination
|
||||
v-if="totalTokens > 0"
|
||||
:current="currentPage"
|
||||
:total="totalTokens"
|
||||
:page-size="pageSize"
|
||||
@update:current="currentPage = $event"
|
||||
@update:page-size="handlePageSizeChange"
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<!-- 创建/编辑 Token 对话框 -->
|
||||
<Dialog
|
||||
v-model="showCreateDialog"
|
||||
size="lg"
|
||||
>
|
||||
<template #header>
|
||||
<div class="border-b border-border px-6 py-4">
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="flex h-9 w-9 items-center justify-center rounded-lg bg-primary/10 flex-shrink-0">
|
||||
<KeyRound class="h-5 w-5 text-primary" />
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<h3 class="text-lg font-semibold text-foreground leading-tight">
|
||||
{{ editingToken ? '编辑访问令牌' : '创建访问令牌' }}
|
||||
</h3>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ editingToken ? '修改令牌配置' : '创建一个新的令牌用于访问管理 API' }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="space-y-4">
|
||||
<!-- 名称 -->
|
||||
<div class="space-y-2">
|
||||
<Label
|
||||
for="token-name"
|
||||
class="text-sm font-semibold"
|
||||
>名称 *</Label>
|
||||
<Input
|
||||
id="token-name"
|
||||
v-model="formData.name"
|
||||
placeholder="例如:CI/CD 自动化"
|
||||
class="h-11 border-border/60"
|
||||
autocomplete="off"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- 描述 -->
|
||||
<div class="space-y-2">
|
||||
<Label
|
||||
for="token-description"
|
||||
class="text-sm font-semibold"
|
||||
>描述</Label>
|
||||
<Input
|
||||
id="token-description"
|
||||
v-model="formData.description"
|
||||
placeholder="用途说明(可选)"
|
||||
class="h-11 border-border/60"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- IP 白名单 -->
|
||||
<div class="space-y-2">
|
||||
<Label
|
||||
for="token-ips"
|
||||
class="text-sm font-semibold"
|
||||
>IP 白名单</Label>
|
||||
<Input
|
||||
id="token-ips"
|
||||
v-model="formData.allowedIpsText"
|
||||
placeholder="例如:192.168.1.0/24, 10.0.0.1(逗号分隔,留空不限制)"
|
||||
class="h-11 border-border/60"
|
||||
autocomplete="off"
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
限制只能从指定 IP 地址使用此令牌,支持 CIDR 格式
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 过期时间 -->
|
||||
<div class="space-y-2">
|
||||
<Label
|
||||
for="token-expires"
|
||||
class="text-sm font-semibold"
|
||||
>过期时间</Label>
|
||||
<Input
|
||||
id="token-expires"
|
||||
v-model="formData.expiresAt"
|
||||
type="datetime-local"
|
||||
class="h-11 border-border/60"
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
留空表示永不过期
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
class="h-11 px-6"
|
||||
@click="closeDialog"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
class="h-11 px-6 shadow-lg shadow-primary/20"
|
||||
:disabled="saving || !isFormValid"
|
||||
@click="saveToken"
|
||||
>
|
||||
<Loader2
|
||||
v-if="saving"
|
||||
class="animate-spin h-4 w-4 mr-2"
|
||||
/>
|
||||
{{ saving ? '保存中...' : (editingToken ? '保存' : '创建') }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 新 Token 创建成功对话框 -->
|
||||
<Dialog
|
||||
v-model="showTokenDialog"
|
||||
size="lg"
|
||||
persistent
|
||||
>
|
||||
<template #header>
|
||||
<div class="border-b border-border px-6 py-4">
|
||||
<div class="flex items-center gap-3">
|
||||
<div class="flex h-9 w-9 items-center justify-center rounded-lg bg-emerald-100 dark:bg-emerald-900/30 flex-shrink-0">
|
||||
<CheckCircle class="h-5 w-5 text-emerald-600 dark:text-emerald-400" />
|
||||
</div>
|
||||
<div class="flex-1 min-w-0">
|
||||
<h3 class="text-lg font-semibold text-foreground leading-tight">
|
||||
{{ isRegenerating ? '令牌已重新生成' : '创建成功' }}
|
||||
</h3>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
请妥善保管,此令牌只会显示一次
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="space-y-4">
|
||||
<div class="space-y-2">
|
||||
<Label class="text-sm font-medium">访问令牌</Label>
|
||||
<div class="flex items-center gap-2">
|
||||
<Input
|
||||
type="text"
|
||||
:value="newTokenValue"
|
||||
readonly
|
||||
class="flex-1 font-mono text-sm bg-muted/50 h-11"
|
||||
@click="($event.target as HTMLInputElement)?.select()"
|
||||
/>
|
||||
<Button
|
||||
class="h-11"
|
||||
@click="copyToken(newTokenValue)"
|
||||
>
|
||||
复制
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="p-3 rounded-lg bg-amber-50 dark:bg-amber-950/30 border border-amber-200 dark:border-amber-800">
|
||||
<div class="flex gap-2">
|
||||
<AlertTriangle class="h-4 w-4 text-amber-600 dark:text-amber-400 flex-shrink-0 mt-0.5" />
|
||||
<p class="text-sm text-amber-800 dark:text-amber-200">
|
||||
此令牌只会显示一次,关闭后将无法再次查看,请妥善保管。
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
class="h-10 px-5"
|
||||
@click="showTokenDialog = false"
|
||||
>
|
||||
我已安全保存
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 删除确认对话框 -->
|
||||
<AlertDialog
|
||||
v-model="showDeleteDialog"
|
||||
type="danger"
|
||||
title="确认删除"
|
||||
:description="`确定要删除令牌「${tokenToDelete?.name}」吗?此操作不可恢复。`"
|
||||
confirm-text="删除"
|
||||
:loading="deleting"
|
||||
@confirm="deleteToken"
|
||||
@cancel="showDeleteDialog = false"
|
||||
/>
|
||||
|
||||
<!-- 重新生成确认对话框 -->
|
||||
<AlertDialog
|
||||
v-model="showRegenerateDialog"
|
||||
type="warning"
|
||||
title="确认重新生成"
|
||||
:description="`重新生成后,原令牌将立即失效。确定要重新生成「${tokenToRegenerate?.name}」吗?`"
|
||||
confirm-text="重新生成"
|
||||
:loading="regenerating"
|
||||
@confirm="regenerateToken"
|
||||
@cancel="showRegenerateDialog = false"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, reactive, watch } from 'vue'
|
||||
import {
|
||||
managementTokenApi,
|
||||
type ManagementToken
|
||||
} from '@/api/management-tokens'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Input from '@/components/ui/input.vue'
|
||||
import Label from '@/components/ui/label.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
import { Dialog, Pagination } from '@/components/ui'
|
||||
import { LoadingState, AlertDialog, EmptyState } from '@/components/common'
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow
|
||||
} from '@/components/ui'
|
||||
import RefreshButton from '@/components/ui/refresh-button.vue'
|
||||
import {
|
||||
Plus,
|
||||
KeyRound,
|
||||
Trash2,
|
||||
Loader2,
|
||||
CheckCircle,
|
||||
Power,
|
||||
Pencil,
|
||||
RefreshCw,
|
||||
AlertTriangle
|
||||
} from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { log } from '@/utils/logger'
|
||||
|
||||
const { success, error: showError } = useToast()
|
||||
|
||||
// 数据
|
||||
const tokens = ref<ManagementToken[]>([])
|
||||
const totalTokens = ref(0)
|
||||
const loading = ref(false)
|
||||
const saving = ref(false)
|
||||
const deleting = ref(false)
|
||||
const regenerating = ref(false)
|
||||
|
||||
// 配额信息
|
||||
const quota = ref<{ used: number; max: number } | null>(null)
|
||||
|
||||
// 分页
|
||||
const currentPage = ref(1)
|
||||
const pageSize = ref(10)
|
||||
|
||||
const paginatedTokens = computed(() => tokens.value)
|
||||
|
||||
// 监听分页变化
|
||||
watch([currentPage, pageSize], () => {
|
||||
loadTokens()
|
||||
})
|
||||
|
||||
function handlePageSizeChange(newSize: number) {
|
||||
pageSize.value = newSize
|
||||
currentPage.value = 1
|
||||
}
|
||||
|
||||
// 对话框状态
|
||||
const showCreateDialog = ref(false)
|
||||
const showTokenDialog = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const showRegenerateDialog = ref(false)
|
||||
|
||||
// 表单数据
|
||||
const editingToken = ref<ManagementToken | null>(null)
|
||||
const formData = reactive({
|
||||
name: '',
|
||||
description: '',
|
||||
allowedIpsText: '',
|
||||
expiresAt: ''
|
||||
})
|
||||
|
||||
const newTokenValue = ref('')
|
||||
const isRegenerating = ref(false)
|
||||
const tokenToDelete = ref<ManagementToken | null>(null)
|
||||
const tokenToRegenerate = ref<ManagementToken | null>(null)
|
||||
|
||||
// 表单验证
|
||||
const isFormValid = computed(() => {
|
||||
return formData.name.trim().length > 0
|
||||
})
|
||||
|
||||
function getStatusVariant(token: ManagementToken): 'success' | 'secondary' | 'destructive' {
|
||||
if (token.expires_at && isExpired(token.expires_at)) {
|
||||
return 'destructive'
|
||||
}
|
||||
return token.is_active ? 'success' : 'secondary'
|
||||
}
|
||||
|
||||
function getStatusText(token: ManagementToken): string {
|
||||
if (token.expires_at && isExpired(token.expires_at)) {
|
||||
return '已过期'
|
||||
}
|
||||
return token.is_active ? '活跃' : '禁用'
|
||||
}
|
||||
|
||||
function isExpired(dateString: string): boolean {
|
||||
return new Date(dateString) < new Date()
|
||||
}
|
||||
|
||||
// 加载数据
|
||||
onMounted(() => {
|
||||
loadTokens()
|
||||
})
|
||||
|
||||
async function loadTokens() {
|
||||
loading.value = true
|
||||
try {
|
||||
const skip = (currentPage.value - 1) * pageSize.value
|
||||
const response = await managementTokenApi.listTokens({ skip, limit: pageSize.value })
|
||||
|
||||
tokens.value = response.items
|
||||
totalTokens.value = response.total
|
||||
|
||||
if (response.quota) {
|
||||
quota.value = response.quota
|
||||
}
|
||||
|
||||
// 如果当前页超出范围,重置到第一页
|
||||
if (tokens.value.length === 0 && currentPage.value > 1) {
|
||||
currentPage.value = 1
|
||||
}
|
||||
} catch (err: any) {
|
||||
log.error('加载 Management Tokens 失败:', err)
|
||||
if (!err.response) {
|
||||
showError('无法连接到服务器')
|
||||
} else {
|
||||
showError(`加载失败:${err.response?.data?.detail || err.message}`)
|
||||
}
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 打开编辑对话框
|
||||
function openEditDialog(token: ManagementToken) {
|
||||
editingToken.value = token
|
||||
formData.name = token.name
|
||||
formData.description = token.description || ''
|
||||
formData.allowedIpsText = (token.allowed_ips && token.allowed_ips.length > 0)
|
||||
? token.allowed_ips.join(', ')
|
||||
: ''
|
||||
formData.expiresAt = token.expires_at
|
||||
? toLocalDatetimeString(new Date(token.expires_at))
|
||||
: ''
|
||||
showCreateDialog.value = true
|
||||
}
|
||||
|
||||
// 关闭对话框
|
||||
function closeDialog() {
|
||||
showCreateDialog.value = false
|
||||
editingToken.value = null
|
||||
formData.name = ''
|
||||
formData.description = ''
|
||||
formData.allowedIpsText = ''
|
||||
formData.expiresAt = ''
|
||||
}
|
||||
|
||||
// 保存 Token
|
||||
async function saveToken() {
|
||||
if (!isFormValid.value) return
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
const allowedIps = formData.allowedIpsText
|
||||
.split(',')
|
||||
.map(ip => ip.trim())
|
||||
.filter(ip => ip)
|
||||
|
||||
// 将本地时间转换为 UTC ISO 字符串
|
||||
const expiresAtUtc = formData.expiresAt
|
||||
? new Date(formData.expiresAt).toISOString()
|
||||
: null
|
||||
|
||||
if (editingToken.value) {
|
||||
// 更新
|
||||
await managementTokenApi.updateToken(editingToken.value.id, {
|
||||
name: formData.name,
|
||||
description: formData.description.trim() || null,
|
||||
allowed_ips: allowedIps.length > 0 ? allowedIps : null,
|
||||
expires_at: expiresAtUtc
|
||||
})
|
||||
success('令牌更新成功')
|
||||
} else {
|
||||
// 创建
|
||||
const result = await managementTokenApi.createToken({
|
||||
name: formData.name,
|
||||
description: formData.description || undefined,
|
||||
allowed_ips: allowedIps.length > 0 ? allowedIps : undefined,
|
||||
expires_at: expiresAtUtc
|
||||
})
|
||||
newTokenValue.value = result.token
|
||||
isRegenerating.value = false
|
||||
showTokenDialog.value = true
|
||||
success('令牌创建成功')
|
||||
}
|
||||
|
||||
closeDialog()
|
||||
await loadTokens()
|
||||
} catch (err: any) {
|
||||
log.error('保存 Token 失败:', err)
|
||||
const message = err.response?.data?.error?.message
|
||||
|| err.response?.data?.detail
|
||||
|| '保存失败'
|
||||
showError(message)
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 切换状态
|
||||
async function toggleToken(token: ManagementToken) {
|
||||
try {
|
||||
const result = await managementTokenApi.toggleToken(token.id)
|
||||
|
||||
const index = tokens.value.findIndex(t => t.id === token.id)
|
||||
if (index !== -1) {
|
||||
tokens.value[index] = result.data
|
||||
}
|
||||
success(result.data.is_active ? '令牌已启用' : '令牌已禁用')
|
||||
} catch (err: any) {
|
||||
log.error('切换状态失败:', err)
|
||||
showError('操作失败')
|
||||
}
|
||||
}
|
||||
|
||||
// 删除
|
||||
function confirmDelete(token: ManagementToken) {
|
||||
tokenToDelete.value = token
|
||||
showDeleteDialog.value = true
|
||||
}
|
||||
|
||||
async function deleteToken() {
|
||||
if (!tokenToDelete.value) return
|
||||
|
||||
deleting.value = true
|
||||
try {
|
||||
await managementTokenApi.deleteToken(tokenToDelete.value.id)
|
||||
|
||||
showDeleteDialog.value = false
|
||||
success('令牌已删除')
|
||||
await loadTokens()
|
||||
} catch (err: any) {
|
||||
log.error('删除 Token 失败:', err)
|
||||
showError('删除失败')
|
||||
} finally {
|
||||
deleting.value = false
|
||||
tokenToDelete.value = null
|
||||
}
|
||||
}
|
||||
|
||||
// 重新生成
|
||||
function confirmRegenerate(token: ManagementToken) {
|
||||
tokenToRegenerate.value = token
|
||||
showRegenerateDialog.value = true
|
||||
}
|
||||
|
||||
async function regenerateToken() {
|
||||
if (!tokenToRegenerate.value) return
|
||||
|
||||
regenerating.value = true
|
||||
try {
|
||||
const result = await managementTokenApi.regenerateToken(tokenToRegenerate.value.id)
|
||||
newTokenValue.value = result.token
|
||||
isRegenerating.value = true
|
||||
showRegenerateDialog.value = false
|
||||
showTokenDialog.value = true
|
||||
await loadTokens()
|
||||
success('令牌已重新生成')
|
||||
} catch (err: any) {
|
||||
log.error('重新生成失败:', err)
|
||||
showError('重新生成失败')
|
||||
} finally {
|
||||
regenerating.value = false
|
||||
tokenToRegenerate.value = null
|
||||
}
|
||||
}
|
||||
|
||||
// 复制 Token
|
||||
async function copyToken(text: string) {
|
||||
try {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(text)
|
||||
success('已复制到剪贴板')
|
||||
} else {
|
||||
const textArea = document.createElement('textarea')
|
||||
textArea.value = text
|
||||
textArea.style.position = 'fixed'
|
||||
textArea.style.left = '-999999px'
|
||||
document.body.appendChild(textArea)
|
||||
textArea.select()
|
||||
document.execCommand('copy')
|
||||
document.body.removeChild(textArea)
|
||||
success('已复制到剪贴板')
|
||||
}
|
||||
} catch (err) {
|
||||
log.error('复制失败:', err)
|
||||
showError('复制失败')
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化
|
||||
function formatNumber(num: number): string {
|
||||
return num.toLocaleString('zh-CN')
|
||||
}
|
||||
|
||||
function toLocalDatetimeString(date: Date): string {
|
||||
const year = date.getFullYear()
|
||||
const month = String(date.getMonth() + 1).padStart(2, '0')
|
||||
const day = String(date.getDate()).padStart(2, '0')
|
||||
const hours = String(date.getHours()).padStart(2, '0')
|
||||
const minutes = String(date.getMinutes()).padStart(2, '0')
|
||||
return `${year}-${month}-${day}T${hours}:${minutes}`
|
||||
}
|
||||
|
||||
function formatDate(dateString: string): string {
|
||||
const date = new Date(dateString)
|
||||
return date.toLocaleDateString('zh-CN', {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit'
|
||||
})
|
||||
}
|
||||
|
||||
function formatRelativeTime(dateString: string): string {
|
||||
const date = new Date(dateString)
|
||||
const now = new Date()
|
||||
const diffMs = now.getTime() - date.getTime()
|
||||
const diffMins = Math.floor(diffMs / (1000 * 60))
|
||||
const diffHours = Math.floor(diffMs / (1000 * 60 * 60))
|
||||
const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24))
|
||||
|
||||
if (diffMins < 1) return '刚刚'
|
||||
if (diffMins < 60) return `${diffMins}分钟前`
|
||||
if (diffHours < 24) return `${diffHours}小时前`
|
||||
if (diffDays < 7) return `${diffDays}天前`
|
||||
|
||||
return formatDate(dateString)
|
||||
}
|
||||
</script>
|
||||
@@ -47,6 +47,7 @@ dependencies = [
|
||||
"redis>=5.0.0",
|
||||
"prometheus-client>=0.20.0",
|
||||
"apscheduler>=3.10.0",
|
||||
"ldap3>=2.9.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.1.1.dev0+g393d4d13f.d20251213'
|
||||
__version_tuple__ = version_tuple = (0, 1, 1, 'dev0', 'g393d4d13f.d20251213')
|
||||
__version__ = version = '0.2.5'
|
||||
__version_tuple__ = version_tuple = (0, 2, 5)
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
|
||||
@@ -5,6 +5,8 @@ from fastapi import APIRouter
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .ldap import router as ldap_router
|
||||
from .management_tokens import router as management_tokens_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
@@ -28,5 +30,7 @@ router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
router.include_router(ldap_router)
|
||||
router.include_router(management_tokens_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -73,7 +73,26 @@ async def list_standalone_api_keys(
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出所有独立余额API Keys"""
|
||||
"""
|
||||
列出所有独立余额 API Keys
|
||||
|
||||
获取系统中所有独立余额 API Key 的列表。独立余额 Key 不关联用户配额,
|
||||
有独立的余额限制,主要用于给非注册用户使用。
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数(分页偏移量),默认 0
|
||||
- `limit`: 返回的记录数(分页限制),默认 100,最大 500
|
||||
- `is_active`: 可选,根据启用状态筛选(true/false)
|
||||
|
||||
**返回字段**:
|
||||
- `api_keys`: API Key 列表,包含 id, name, key_display, is_active, current_balance_usd,
|
||||
balance_used_usd, total_requests, total_cost_usd, rate_limit, allowed_providers,
|
||||
allowed_api_formats, allowed_models, last_used_at, expires_at, created_at, updated_at,
|
||||
auto_delete_on_expiry 等字段
|
||||
- `total`: 符合条件的总记录数
|
||||
- `limit`: 当前分页限制
|
||||
- `skip`: 当前分页偏移量
|
||||
"""
|
||||
adapter = AdminListStandaloneKeysAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -84,7 +103,35 @@ async def create_standalone_api_key(
|
||||
key_data: CreateApiKeyRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建独立余额API Key(必须设置余额限制)"""
|
||||
"""
|
||||
创建独立余额 API Key
|
||||
|
||||
创建一个新的独立余额 API Key。独立余额 Key 必须设置初始余额限制。
|
||||
|
||||
**请求体字段**:
|
||||
- `name`: API Key 的名称
|
||||
- `initial_balance_usd`: 必需,初始余额(美元),必须大于 0
|
||||
- `allowed_providers`: 可选,允许使用的提供商列表
|
||||
- `allowed_api_formats`: 可选,允许使用的 API 格式列表
|
||||
- `allowed_models`: 可选,允许使用的模型列表
|
||||
- `rate_limit`: 可选,速率限制配置(请求数/秒)
|
||||
- `expire_days`: 可选,过期天数(兼容旧版)
|
||||
- `expires_at`: 可选,过期时间(ISO 格式或 YYYY-MM-DD 格式,优先级高于 expire_days)
|
||||
- `auto_delete_on_expiry`: 可选,过期后是否自动删除
|
||||
|
||||
**返回字段**:
|
||||
- `id`: API Key ID
|
||||
- `key`: 完整的 API Key(仅在创建时返回一次)
|
||||
- `name`: API Key 名称
|
||||
- `key_display`: 脱敏显示的 Key
|
||||
- `is_standalone`: 是否为独立余额 Key(始终为 true)
|
||||
- `current_balance_usd`: 当前余额
|
||||
- `balance_used_usd`: 已使用余额
|
||||
- `rate_limit`: 速率限制配置
|
||||
- `expires_at`: 过期时间
|
||||
- `created_at`: 创建时间
|
||||
- `message`: 提示信息
|
||||
"""
|
||||
adapter = AdminCreateStandaloneKeyAdapter(key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -93,20 +140,72 @@ async def create_standalone_api_key(
|
||||
async def update_api_key(
|
||||
key_id: str, request: Request, key_data: CreateApiKeyRequest, db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新独立余额Key(可修改名称、过期时间、余额限制等)"""
|
||||
"""
|
||||
更新独立余额 API Key
|
||||
|
||||
更新指定 ID 的独立余额 API Key 的配置信息。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**请求体字段**:
|
||||
- `name`: 可选,API Key 的名称
|
||||
- `rate_limit`: 可选,速率限制配置(null 表示无限制)
|
||||
- `allowed_providers`: 可选,允许使用的提供商列表
|
||||
- `allowed_api_formats`: 可选,允许使用的 API 格式列表
|
||||
- `allowed_models`: 可选,允许使用的模型列表
|
||||
- `expire_days`: 可选,过期天数(兼容旧版)
|
||||
- `expires_at`: 可选,过期时间(ISO 格式或 YYYY-MM-DD 格式,优先级高于 expire_days,null 或空字符串表示永不过期)
|
||||
- `auto_delete_on_expiry`: 可选,过期后是否自动删除
|
||||
|
||||
**返回字段**:
|
||||
- `id`: API Key ID
|
||||
- `name`: API Key 名称
|
||||
- `key_display`: 脱敏显示的 Key
|
||||
- `is_active`: 是否启用
|
||||
- `current_balance_usd`: 当前余额
|
||||
- `balance_used_usd`: 已使用余额
|
||||
- `rate_limit`: 速率限制配置
|
||||
- `expires_at`: 过期时间
|
||||
- `updated_at`: 更新时间
|
||||
- `message`: 提示信息
|
||||
"""
|
||||
adapter = AdminUpdateApiKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{key_id}")
|
||||
async def toggle_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Toggle API key active status (PATCH with is_active in body)"""
|
||||
"""
|
||||
切换 API Key 启用状态
|
||||
|
||||
切换指定 API Key 的启用/禁用状态。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `id`: API Key ID
|
||||
- `is_active`: 新的启用状态
|
||||
- `message`: 提示信息
|
||||
"""
|
||||
adapter = AdminToggleApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{key_id}")
|
||||
async def delete_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
删除 API Key
|
||||
|
||||
删除指定的 API Key。此操作不可逆。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 提示信息
|
||||
"""
|
||||
adapter = AdminDeleteApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -117,7 +216,24 @@ async def add_balance_to_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Adjust balance for standalone API key (positive to add, negative to deduct)"""
|
||||
"""
|
||||
调整独立余额 API Key 的余额
|
||||
|
||||
为指定的独立余额 API Key 增加或扣除余额。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**请求体字段**:
|
||||
- `amount_usd`: 调整金额(美元),正数为充值,负数为扣除
|
||||
|
||||
**返回字段**:
|
||||
- `id`: API Key ID
|
||||
- `name`: API Key 名称
|
||||
- `current_balance_usd`: 调整后的当前余额
|
||||
- `balance_used_usd`: 已使用余额
|
||||
- `message`: 提示信息
|
||||
"""
|
||||
# 从请求体获取调整金额
|
||||
body = await request.json()
|
||||
amount_usd = body.get("amount_usd")
|
||||
@@ -162,7 +278,24 @@ async def get_api_key_detail(
|
||||
include_key: bool = Query(False, description="Include full decrypted key in response"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get API key detail, optionally include full key"""
|
||||
"""
|
||||
获取 API Key 详情
|
||||
|
||||
获取指定 API Key 的详细信息。可选择是否返回完整的解密密钥。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**查询参数**:
|
||||
- `include_key`: 是否包含完整的解密密钥,默认 false
|
||||
|
||||
**返回字段**:
|
||||
- 当 include_key=false 时,返回基本信息:id, user_id, name, key_display, is_active,
|
||||
is_standalone, current_balance_usd, balance_used_usd, total_requests, total_cost_usd,
|
||||
rate_limit, allowed_providers, allowed_api_formats, allowed_models, last_used_at,
|
||||
expires_at, created_at, updated_at
|
||||
- 当 include_key=true 时,返回完整密钥:key
|
||||
"""
|
||||
if include_key:
|
||||
adapter = AdminGetFullKeyAdapter(key_id=key_id)
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ from .health import router as health_router
|
||||
from .keys import router as keys_router
|
||||
from .routes import router as routes_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
|
||||
router = APIRouter(prefix="/api/admin/endpoints", tags=["Admin - Endpoints"])
|
||||
|
||||
# Endpoint CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
@@ -29,7 +29,19 @@ async def get_endpoint_concurrency(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Endpoint 当前并发状态"""
|
||||
"""
|
||||
获取 Endpoint 当前并发状态
|
||||
|
||||
查询指定 Endpoint 的实时并发使用情况,包括当前并发数和最大并发限制。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**返回字段**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `endpoint_current_concurrency`: 当前并发数
|
||||
- `endpoint_max_concurrent`: 最大并发限制
|
||||
"""
|
||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -40,7 +52,19 @@ async def get_key_concurrency(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Key 当前并发状态"""
|
||||
"""
|
||||
获取 Key 当前并发状态
|
||||
|
||||
查询指定 API Key 的实时并发使用情况,包括当前并发数和最大并发限制。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `key_id`: API Key ID
|
||||
- `key_current_concurrency`: 当前并发数
|
||||
- `key_max_concurrent`: 最大并发限制
|
||||
"""
|
||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -51,7 +75,19 @@ async def reset_concurrency(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Reset concurrency counters (admin function, use with caution)"""
|
||||
"""
|
||||
重置并发计数器
|
||||
|
||||
重置指定 Endpoint 或 Key 的并发计数器,用于解决计数不准确的问题。
|
||||
管理员功能,请谨慎使用。
|
||||
|
||||
**请求体字段**:
|
||||
- `endpoint_id`: Endpoint ID(可选)
|
||||
- `key_id`: API Key ID(可选)
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
"""
|
||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -36,7 +36,20 @@ async def get_health_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthSummaryResponse:
|
||||
"""获取健康状态摘要"""
|
||||
"""
|
||||
获取健康状态摘要
|
||||
|
||||
获取系统整体健康状态摘要,包括所有 Provider、Endpoint 和 Key 的健康状态统计。
|
||||
|
||||
**返回字段**:
|
||||
- `total_providers`: Provider 总数
|
||||
- `active_providers`: 活跃 Provider 数量
|
||||
- `total_endpoints`: Endpoint 总数
|
||||
- `active_endpoints`: 活跃 Endpoint 数量
|
||||
- `total_keys`: Key 总数
|
||||
- `active_keys`: 活跃 Key 数量
|
||||
- `circuit_breaker_open_keys`: 熔断的 Key 数量
|
||||
"""
|
||||
adapter = AdminHealthSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -50,9 +63,21 @@ async def get_endpoint_health_status(
|
||||
"""
|
||||
获取端点健康状态(简化视图,与用户端点统一)
|
||||
|
||||
获取按 API 格式聚合的端点健康状态时间线,基于 Usage 表统计,
|
||||
返回 50 个时间段的聚合状态,适用于快速查看整体健康趋势。
|
||||
|
||||
与 /health/api-formats 的区别:
|
||||
- /health/status: 返回聚合的时间线状态(50个时间段),基于 Usage 表
|
||||
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
|
||||
|
||||
**查询参数**:
|
||||
- `lookback_hours`: 回溯的小时数(1-72),默认 6
|
||||
|
||||
**返回字段**:
|
||||
- `api_format`: API 格式名称
|
||||
- `timeline`: 时间线数据(50个时间段)
|
||||
- `time_range_start`: 时间范围起始
|
||||
- `time_range_end`: 时间范围结束
|
||||
"""
|
||||
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -65,7 +90,33 @@ async def get_api_format_health_monitor(
|
||||
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ApiFormatHealthMonitorResponse:
|
||||
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
|
||||
"""
|
||||
获取按 API 格式聚合的健康监控时间线(详细事件列表)
|
||||
|
||||
获取每个 API 格式的详细健康监控数据,包括请求事件列表、成功率统计、
|
||||
时间线数据等,基于 RequestCandidate 表查询,适用于详细分析。
|
||||
|
||||
**查询参数**:
|
||||
- `lookback_hours`: 回溯的小时数(1-72),默认 6
|
||||
- `per_format_limit`: 每个 API 格式返回的事件数量(10-200),默认 60
|
||||
|
||||
**返回字段**:
|
||||
- `generated_at`: 数据生成时间
|
||||
- `formats`: API 格式健康监控数据列表
|
||||
- `api_format`: API 格式名称
|
||||
- `total_attempts`: 总请求数
|
||||
- `success_count`: 成功请求数
|
||||
- `failed_count`: 失败请求数
|
||||
- `skipped_count`: 跳过请求数
|
||||
- `success_rate`: 成功率
|
||||
- `provider_count`: Provider 数量
|
||||
- `key_count`: Key 数量
|
||||
- `last_event_at`: 最后事件时间
|
||||
- `events`: 事件列表
|
||||
- `timeline`: 时间线数据
|
||||
- `time_range_start`: 时间范围起始
|
||||
- `time_range_end`: 时间范围结束
|
||||
"""
|
||||
adapter = AdminApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
@@ -79,7 +130,26 @@ async def get_key_health(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthStatusResponse:
|
||||
"""获取 Key 健康状态"""
|
||||
"""
|
||||
获取 Key 健康状态
|
||||
|
||||
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
|
||||
熔断器状态等信息。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `key_id`: API Key ID
|
||||
- `key_health_score`: 健康分数(0.0-1.0)
|
||||
- `key_consecutive_failures`: 连续失败次数
|
||||
- `key_last_failure_at`: 最后失败时间
|
||||
- `key_is_active`: 是否活跃
|
||||
- `key_statistics`: 统计信息
|
||||
- `circuit_breaker_open`: 熔断器是否打开
|
||||
- `circuit_breaker_open_at`: 熔断器打开时间
|
||||
- `next_probe_at`: 下次探测时间
|
||||
"""
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -91,13 +161,20 @@ async def recover_key_health(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Recover key health status
|
||||
恢复 Key 健康状态
|
||||
|
||||
Resets health_score to 1.0, closes circuit breaker,
|
||||
cancels auto-disable, and resets all failure counts.
|
||||
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
|
||||
取消自动禁用,并重置所有失败计数。
|
||||
|
||||
Parameters:
|
||||
- key_id: Key ID (path parameter)
|
||||
**路径参数**:
|
||||
- `key_id`: API Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `details`: 详细信息
|
||||
- `health_score`: 健康分数
|
||||
- `circuit_breaker_open`: 熔断器状态
|
||||
- `is_active`: 是否活跃
|
||||
"""
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -109,12 +186,21 @@ async def recover_all_keys_health(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Batch recover all circuit-broken keys
|
||||
批量恢复所有熔断 Key 的健康状态
|
||||
|
||||
Finds all keys with circuit_breaker_open=True and:
|
||||
1. Resets health_score to 1.0
|
||||
2. Closes circuit breaker
|
||||
3. Resets failure counts
|
||||
查找所有处于熔断状态的 Key(circuit_breaker_open=True),
|
||||
并批量执行以下操作:
|
||||
1. 将健康分数重置为 1.0
|
||||
2. 关闭熔断器
|
||||
3. 重置失败计数
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `recovered_count`: 恢复的 Key 数量
|
||||
- `recovered_keys`: 恢复的 Key 列表
|
||||
- `key_id`: Key ID
|
||||
- `key_name`: Key 名称
|
||||
- `endpoint_id`: Endpoint ID
|
||||
"""
|
||||
adapter = AdminRecoverAllKeysHealthAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -37,7 +37,33 @@ async def list_endpoint_keys(
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""获取 Endpoint 的所有 Keys"""
|
||||
"""
|
||||
获取 Endpoint 的所有 Keys
|
||||
|
||||
获取指定 Endpoint 下的所有 API Key 列表,包括 Key 的配置、统计信息等。
|
||||
结果按优先级和创建时间排序。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数,用于分页(默认 0)
|
||||
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: Key ID
|
||||
- `name`: Key 名称
|
||||
- `api_key_masked`: 脱敏后的 API Key
|
||||
- `internal_priority`: 内部优先级
|
||||
- `global_priority`: 全局优先级
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
||||
- `is_adaptive`: 是否为自适应并发模式
|
||||
- `effective_limit`: 有效并发限制
|
||||
- `success_rate`: 成功率
|
||||
- `avg_response_time_ms`: 平均响应时间(毫秒)
|
||||
- 其他配置和统计字段
|
||||
"""
|
||||
adapter = AdminListEndpointKeysAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
skip=skip,
|
||||
@@ -53,7 +79,32 @@ async def add_endpoint_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""为 Endpoint 添加 Key"""
|
||||
"""
|
||||
为 Endpoint 添加 Key
|
||||
|
||||
为指定 Endpoint 添加新的 API Key,支持配置并发限制、速率倍数、
|
||||
优先级、配额限制、能力限制等。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**请求体字段**:
|
||||
- `endpoint_id`: Endpoint ID(必须与路径参数一致)
|
||||
- `api_key`: API Key 原文(将被加密存储)
|
||||
- `name`: Key 名称
|
||||
- `note`: 备注(可选)
|
||||
- `rate_multiplier`: 速率倍数(默认 1.0)
|
||||
- `internal_priority`: 内部优先级(默认 100)
|
||||
- `max_concurrent`: 最大并发数(null 表示自适应模式)
|
||||
- `rate_limit`: 每分钟请求限制(可选)
|
||||
- `daily_limit`: 每日请求限制(可选)
|
||||
- `monthly_limit`: 每月请求限制(可选)
|
||||
- `allowed_models`: 允许的模型列表(可选)
|
||||
- `capabilities`: 能力配置(可选)
|
||||
|
||||
**返回字段**:
|
||||
- 包含完整的 Key 信息,其中 `api_key_plain` 为原文(仅在创建时返回)
|
||||
"""
|
||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -65,7 +116,32 @@ async def update_endpoint_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""更新 Endpoint Key"""
|
||||
"""
|
||||
更新 Endpoint Key
|
||||
|
||||
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
|
||||
配额限制、能力限制等。支持部分更新。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: Key ID
|
||||
|
||||
**请求体字段**(均为可选):
|
||||
- `api_key`: 新的 API Key 原文
|
||||
- `name`: Key 名称
|
||||
- `note`: 备注
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `internal_priority`: 内部优先级
|
||||
- `max_concurrent`: 最大并发数(设置为 null 可切换到自适应模式)
|
||||
- `rate_limit`: 每分钟请求限制
|
||||
- `daily_limit`: 每日请求限制
|
||||
- `monthly_limit`: 每月请求限制
|
||||
- `allowed_models`: 允许的模型列表
|
||||
- `capabilities`: 能力配置
|
||||
- `is_active`: 是否活跃
|
||||
|
||||
**返回字段**:
|
||||
- 包含更新后的完整 Key 信息
|
||||
"""
|
||||
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -75,7 +151,31 @@ async def get_keys_grouped_by_format(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取按 API 格式分组的所有 Keys(用于全局优先级管理)"""
|
||||
"""
|
||||
获取按 API 格式分组的所有 Keys
|
||||
|
||||
获取所有活跃的 Key,按 API 格式分组返回,用于全局优先级管理。
|
||||
每个 Key 包含基本信息、健康度指标、能力标签等。
|
||||
|
||||
**返回字段**:
|
||||
- 返回一个字典,键为 API 格式,值为该格式下的 Key 列表
|
||||
- 每个 Key 包含:
|
||||
- `id`: Key ID
|
||||
- `name`: Key 名称
|
||||
- `api_key_masked`: 脱敏后的 API Key
|
||||
- `internal_priority`: 内部优先级
|
||||
- `global_priority`: 全局优先级
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `is_active`: 是否活跃
|
||||
- `circuit_breaker_open`: 熔断器状态
|
||||
- `provider_name`: Provider 名称
|
||||
- `endpoint_base_url`: Endpoint 基础 URL
|
||||
- `api_format`: API 格式
|
||||
- `capabilities`: 能力简称列表
|
||||
- `success_rate`: 成功率
|
||||
- `avg_response_time_ms`: 平均响应时间
|
||||
- `request_count`: 请求总数
|
||||
"""
|
||||
adapter = AdminGetKeysGroupedByFormatAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -86,7 +186,18 @@ async def reveal_endpoint_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取完整的 API Key(用于查看和复制)"""
|
||||
"""
|
||||
获取完整的 API Key
|
||||
|
||||
解密并返回指定 Key 的完整原文,用于查看和复制。
|
||||
此操作会被记录到审计日志。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `api_key`: 完整的 API Key 原文
|
||||
"""
|
||||
adapter = AdminRevealEndpointKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -97,7 +208,17 @@ async def delete_endpoint_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint Key"""
|
||||
"""
|
||||
删除 Endpoint Key
|
||||
|
||||
删除指定的 API Key。此操作不可逆,请谨慎使用。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: Key ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
"""
|
||||
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -109,7 +230,24 @@ async def batch_update_key_priority(
|
||||
priority_data: BatchUpdateKeyPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
|
||||
"""
|
||||
批量更新 Endpoint 下 Keys 的优先级
|
||||
|
||||
批量更新指定 Endpoint 下多个 Key 的内部优先级,用于拖动排序。
|
||||
所有 Key 必须属于指定的 Endpoint。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**请求体字段**:
|
||||
- `priorities`: 优先级列表
|
||||
- `key_id`: Key ID
|
||||
- `internal_priority`: 新的内部优先级
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `updated_count`: 实际更新的 Key 数量
|
||||
"""
|
||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -45,7 +45,36 @@ async def list_provider_endpoints(
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderEndpointResponse]:
|
||||
"""获取指定 Provider 的所有 Endpoints"""
|
||||
"""
|
||||
获取指定 Provider 的所有 Endpoints
|
||||
|
||||
获取指定 Provider 下的所有 Endpoint 列表,包括配置、统计信息等。
|
||||
结果按创建时间倒序排列。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: Provider ID
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数,用于分页(默认 0)
|
||||
- `limit`: 返回的最大记录数(1-1000,默认 100)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: Endpoint ID
|
||||
- `provider_id`: Provider ID
|
||||
- `provider_name`: Provider 名称
|
||||
- `api_format`: API 格式
|
||||
- `base_url`: 基础 URL
|
||||
- `custom_path`: 自定义路径
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `total_keys`: Key 总数
|
||||
- `active_keys`: 活跃 Key 数量
|
||||
- `proxy`: 代理配置(密码已脱敏)
|
||||
- 其他配置字段
|
||||
"""
|
||||
adapter = AdminListProviderEndpointsAdapter(
|
||||
provider_id=provider_id,
|
||||
skip=skip,
|
||||
@@ -61,7 +90,31 @@ async def create_provider_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""为 Provider 创建新的 Endpoint"""
|
||||
"""
|
||||
为 Provider 创建新的 Endpoint
|
||||
|
||||
为指定 Provider 创建新的 Endpoint,每个 Provider 的每种 API 格式
|
||||
只能创建一个 Endpoint。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: Provider ID
|
||||
|
||||
**请求体字段**:
|
||||
- `provider_id`: Provider ID(必须与路径参数一致)
|
||||
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
||||
- `base_url`: 基础 URL
|
||||
- `custom_path`: 自定义路径(可选)
|
||||
- `headers`: 自定义请求头(可选)
|
||||
- `timeout`: 超时时间(秒,默认 300)
|
||||
- `max_retries`: 最大重试次数(默认 2)
|
||||
- `max_concurrent`: 最大并发数(可选)
|
||||
- `rate_limit`: 速率限制(可选)
|
||||
- `config`: 额外配置(可选)
|
||||
- `proxy`: 代理配置(可选)
|
||||
|
||||
**返回字段**:
|
||||
- 包含完整的 Endpoint 信息
|
||||
"""
|
||||
adapter = AdminCreateProviderEndpointAdapter(
|
||||
provider_id=provider_id,
|
||||
endpoint_data=endpoint_data,
|
||||
@@ -75,7 +128,31 @@ async def get_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""获取 Endpoint 详情"""
|
||||
"""
|
||||
获取 Endpoint 详情
|
||||
|
||||
获取指定 Endpoint 的详细信息,包括配置、统计信息等。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**返回字段**:
|
||||
- `id`: Endpoint ID
|
||||
- `provider_id`: Provider ID
|
||||
- `provider_name`: Provider 名称
|
||||
- `api_format`: API 格式
|
||||
- `base_url`: 基础 URL
|
||||
- `custom_path`: 自定义路径
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `total_keys`: Key 总数
|
||||
- `active_keys`: 活跃 Key 数量
|
||||
- `proxy`: 代理配置(密码已脱敏)
|
||||
- 其他配置字段
|
||||
"""
|
||||
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -87,7 +164,29 @@ async def update_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointResponse:
|
||||
"""更新 Endpoint"""
|
||||
"""
|
||||
更新 Endpoint
|
||||
|
||||
更新指定 Endpoint 的配置。支持部分更新。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**请求体字段**(均为可选):
|
||||
- `base_url`: 基础 URL
|
||||
- `custom_path`: 自定义路径
|
||||
- `headers`: 自定义请求头
|
||||
- `timeout`: 超时时间(秒)
|
||||
- `max_retries`: 最大重试次数
|
||||
- `max_concurrent`: 最大并发数
|
||||
- `rate_limit`: 速率限制
|
||||
- `is_active`: 是否活跃
|
||||
- `config`: 额外配置
|
||||
- `proxy`: 代理配置(设置为 null 可清除代理)
|
||||
|
||||
**返回字段**:
|
||||
- 包含更新后的完整 Endpoint 信息
|
||||
"""
|
||||
adapter = AdminUpdateProviderEndpointAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
endpoint_data=endpoint_data,
|
||||
@@ -101,7 +200,19 @@ async def delete_endpoint(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint(级联删除所有关联的 Keys)"""
|
||||
"""
|
||||
删除 Endpoint
|
||||
|
||||
删除指定的 Endpoint,同时级联删除所有关联的 API Keys。
|
||||
此操作不可逆,请谨慎使用。
|
||||
|
||||
**路径参数**:
|
||||
- `endpoint_id`: Endpoint ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
- `deleted_keys_count`: 同时删除的 Key 数量
|
||||
"""
|
||||
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
501
src/api/admin/ldap.py
Normal file
501
src/api/admin/ldap.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""LDAP配置管理API端点。"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.enums import AuthSource
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, LDAPConfig, User, UserRole
|
||||
from src.services.system.audit import AuditService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
# bcrypt 哈希格式正则:$2a$, $2b$, $2y$ + 2位cost + $ + 53字符(22位salt + 31位hash)
|
||||
BCRYPT_HASH_PATTERN = re.compile(r"^\$2[aby]\$\d{2}\$.{53}$")
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class LDAPConfigResponse(BaseModel):
|
||||
"""LDAP配置响应(不返回密码)"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
bind_dn: Optional[str] = None
|
||||
base_dn: Optional[str] = None
|
||||
has_bind_password: bool = False
|
||||
user_search_filter: str
|
||||
username_attr: str
|
||||
email_attr: str
|
||||
display_name_attr: str
|
||||
is_enabled: bool
|
||||
is_exclusive: bool
|
||||
use_starttls: bool
|
||||
connect_timeout: int
|
||||
|
||||
|
||||
class LDAPConfigUpdate(BaseModel):
|
||||
"""LDAP配置更新请求"""
|
||||
|
||||
server_url: str = Field(..., min_length=1, max_length=255)
|
||||
bind_dn: str = Field(..., min_length=1, max_length=255)
|
||||
# 允许空字符串表示"清除密码";非空时自动 strip 并校验不能为空
|
||||
bind_password: Optional[str] = Field(None, max_length=1024)
|
||||
base_dn: str = Field(..., min_length=1, max_length=255)
|
||||
user_search_filter: str = Field(default="(uid={username})", max_length=500)
|
||||
username_attr: str = Field(default="uid", max_length=50)
|
||||
email_attr: str = Field(default="mail", max_length=50)
|
||||
display_name_attr: str = Field(default="cn", max_length=50)
|
||||
is_enabled: bool = False
|
||||
is_exclusive: bool = False
|
||||
use_starttls: bool = False
|
||||
connect_timeout: int = Field(default=10, ge=1, le=60) # 单次操作超时,跨国网络建议 15-30 秒
|
||||
|
||||
@field_validator("bind_password")
|
||||
@classmethod
|
||||
def validate_bind_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None or v == "":
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("绑定密码不能为空")
|
||||
return v
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: str) -> str:
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度,防止构造复杂查询
|
||||
# 检查嵌套层数而非括号总数
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
class LDAPTestResponse(BaseModel):
|
||||
"""LDAP连接测试响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class LDAPConfigTest(BaseModel):
|
||||
"""LDAP配置测试请求(全部可选,用于临时覆盖)"""
|
||||
|
||||
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_password: Optional[str] = Field(None, min_length=1)
|
||||
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
user_search_filter: Optional[str] = Field(None, max_length=500)
|
||||
username_attr: Optional[str] = Field(None, max_length=50)
|
||||
email_attr: Optional[str] = Field(None, max_length=50)
|
||||
display_name_attr: Optional[str] = Field(None, max_length=50)
|
||||
is_enabled: Optional[bool] = None
|
||||
is_exclusive: Optional[bool] = None
|
||||
use_starttls: Optional[bool] = None
|
||||
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度(检查嵌套层数而非括号总数)
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
# ========== API Endpoints ==========
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""
|
||||
获取 LDAP 配置
|
||||
|
||||
获取系统当前的 LDAP 认证配置信息,用于管理界面显示和编辑。
|
||||
密码字段不会返回原文,仅返回是否已设置的标志。
|
||||
|
||||
**返回字段**:
|
||||
- `server_url`: LDAP 服务器地址(如:ldap://ldap.example.com:389)
|
||||
- `bind_dn`: 绑定 DN(如:cn=admin,dc=example,dc=com)
|
||||
- `base_dn`: 搜索基准 DN(如:ou=users,dc=example,dc=com)
|
||||
- `has_bind_password`: 是否已设置绑定密码(布尔值)
|
||||
- `user_search_filter`: 用户搜索过滤器(默认:(uid={username}))
|
||||
- `username_attr`: 用户名属性(默认:uid)
|
||||
- `email_attr`: 邮箱属性(默认:mail)
|
||||
- `display_name_attr`: 显示名称属性(默认:cn)
|
||||
- `is_enabled`: 是否启用 LDAP 认证
|
||||
- `is_exclusive`: 是否仅允许 LDAP 登录(独占模式)
|
||||
- `use_starttls`: 是否使用 STARTTLS 加密连接
|
||||
- `connect_timeout`: 连接超时时间(秒,1-60)
|
||||
"""
|
||||
adapter = AdminGetLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""
|
||||
更新 LDAP 配置
|
||||
|
||||
更新系统的 LDAP 认证配置。支持完整配置更新,包括连接参数、
|
||||
搜索过滤器、属性映射等。提供多重安全校验,防止误锁定管理员。
|
||||
|
||||
**请求体字段**:
|
||||
- `server_url`: LDAP 服务器地址(必填,1-255字符)
|
||||
- `bind_dn`: 绑定 DN(必填,1-255字符)
|
||||
- `bind_password`: 绑定密码(可选,设为空字符串可清除密码)
|
||||
- `base_dn`: 搜索基准 DN(必填,1-255字符)
|
||||
- `user_search_filter`: 用户搜索过滤器(必须包含 {username} 占位符,默认:(uid={username}))
|
||||
- `username_attr`: 用户名属性(默认:uid)
|
||||
- `email_attr`: 邮箱属性(默认:mail)
|
||||
- `display_name_attr`: 显示名称属性(默认:cn)
|
||||
- `is_enabled`: 是否启用 LDAP 认证
|
||||
- `is_exclusive`: 是否仅允许 LDAP 登录(需先启用 LDAP)
|
||||
- `use_starttls`: 是否使用 STARTTLS 加密连接
|
||||
- `connect_timeout`: 连接超时时间(秒,1-60,默认 10)
|
||||
|
||||
**安全校验**:
|
||||
- 启用 LDAP 时必须设置有效的绑定密码
|
||||
- 启用独占模式前会检查是否有至少 1 个有效的本地管理员账户
|
||||
- 独占模式要求先启用 LDAP 认证
|
||||
- 搜索过滤器必须包含 {username} 占位符且括号匹配
|
||||
- 搜索过滤器嵌套层数不超过 5 层,长度不超过 200 字符
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果消息
|
||||
"""
|
||||
adapter = AdminUpdateLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
在保存配置前测试 LDAP 服务器连接是否正常。支持使用已保存的配置,
|
||||
也支持通过请求体覆盖任意配置项进行临时测试,而不影响已保存的配置。
|
||||
|
||||
**请求体字段**(均为可选,用于临时覆盖):
|
||||
- `server_url`: LDAP 服务器地址(覆盖已保存的配置)
|
||||
- `bind_dn`: 绑定 DN(覆盖已保存的配置)
|
||||
- `bind_password`: 绑定密码(覆盖已保存的密码)
|
||||
- `base_dn`: 搜索基准 DN(覆盖已保存的配置)
|
||||
- `user_search_filter`: 用户搜索过滤器(覆盖已保存的配置)
|
||||
- `username_attr`: 用户名属性(覆盖已保存的配置)
|
||||
- `email_attr`: 邮箱属性(覆盖已保存的配置)
|
||||
- `display_name_attr`: 显示名称属性(覆盖已保存的配置)
|
||||
- `use_starttls`: 是否使用 STARTTLS(覆盖已保存的配置)
|
||||
- `connect_timeout`: 连接超时时间(覆盖已保存的配置)
|
||||
|
||||
**测试逻辑**:
|
||||
- 未提供的字段使用已保存的配置值
|
||||
- `bind_password` 优先使用请求体中的值,否则使用已保存的加密密码
|
||||
- 测试时会尝试连接 LDAP 服务器并验证绑定 DN
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 测试是否成功(布尔值)
|
||||
- `message`: 测试结果消息(成功或失败原因)
|
||||
"""
|
||||
adapter = AdminTestLDAPConnectionAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
return LDAPConfigResponse(
|
||||
server_url=None,
|
||||
bind_dn=None,
|
||||
base_dn=None,
|
||||
has_bind_password=False,
|
||||
user_search_filter="(uid={username})",
|
||||
username_attr="uid",
|
||||
email_attr="mail",
|
||||
display_name_attr="cn",
|
||||
is_enabled=False,
|
||||
is_exclusive=False,
|
||||
use_starttls=False,
|
||||
connect_timeout=10,
|
||||
).model_dump()
|
||||
|
||||
return LDAPConfigResponse(
|
||||
server_url=config.server_url,
|
||||
bind_dn=config.bind_dn,
|
||||
base_dn=config.base_dn,
|
||||
has_bind_password=bool(config.bind_password_encrypted),
|
||||
user_search_filter=config.user_search_filter,
|
||||
username_attr=config.username_attr,
|
||||
email_attr=config.email_attr,
|
||||
display_name_attr=config.display_name_attr,
|
||||
is_enabled=config.is_enabled,
|
||||
is_exclusive=config.is_exclusive,
|
||||
use_starttls=config.use_starttls,
|
||||
connect_timeout=config.connect_timeout,
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, str]: # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
config_update = LDAPConfigUpdate.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
# 使用行级锁防止并发修改导致的竞态条件
|
||||
config = db.query(LDAPConfig).with_for_update().first()
|
||||
is_new_config = config is None
|
||||
|
||||
if is_new_config:
|
||||
# 首次创建配置时必须提供密码
|
||||
if not config_update.bind_password:
|
||||
raise InvalidRequestException("首次配置 LDAP 时必须设置绑定密码")
|
||||
config = LDAPConfig()
|
||||
db.add(config)
|
||||
|
||||
# 需要启用 LDAP 且未提交新密码时,验证已保存密码可解密(避免开启后不可用)
|
||||
if config_update.is_enabled and config_update.bind_password is None:
|
||||
try:
|
||||
if not config.get_bind_password():
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
except InvalidRequestException:
|
||||
raise
|
||||
except Exception:
|
||||
raise InvalidRequestException("绑定密码解密失败,请重新设置绑定密码")
|
||||
|
||||
# 计算更新后的密码状态(用于校验是否可启用/独占)
|
||||
if config_update.bind_password is None:
|
||||
will_have_password = bool(config.bind_password_encrypted)
|
||||
elif config_update.bind_password == "":
|
||||
will_have_password = False
|
||||
else:
|
||||
will_have_password = True
|
||||
|
||||
# 独占模式必须启用 LDAP 且必须有绑定密码(防止误锁定)
|
||||
if config_update.is_exclusive and not config_update.is_enabled:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先启用 LDAP 认证")
|
||||
if config_update.is_enabled and not will_have_password:
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
if config_update.is_exclusive and not will_have_password:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先设置绑定密码")
|
||||
|
||||
config.server_url = config_update.server_url
|
||||
config.bind_dn = config_update.bind_dn
|
||||
config.base_dn = config_update.base_dn
|
||||
config.user_search_filter = config_update.user_search_filter
|
||||
config.username_attr = config_update.username_attr
|
||||
config.email_attr = config_update.email_attr
|
||||
config.display_name_attr = config_update.display_name_attr
|
||||
config.is_enabled = config_update.is_enabled
|
||||
config.is_exclusive = config_update.is_exclusive
|
||||
config.use_starttls = config_update.use_starttls
|
||||
config.connect_timeout = config_update.connect_timeout
|
||||
|
||||
# 启用独占模式前检查是否有足够的本地管理员(防止锁定)
|
||||
# 使用 with_for_update() 阻塞锁防止竞态条件(移除 nowait 确保并发安全)
|
||||
if config_update.is_enabled and config_update.is_exclusive:
|
||||
local_admins = (
|
||||
db.query(User)
|
||||
.filter(
|
||||
User.role == UserRole.ADMIN,
|
||||
User.auth_source == AuthSource.LOCAL,
|
||||
User.is_active.is_(True),
|
||||
User.is_deleted.is_(False),
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
# 验证至少有一个管理员有有效的密码哈希(可以登录)
|
||||
# 使用严格的 bcrypt 格式校验:$2a$/$2b$/$2y$ + 2位cost + $ + 53字符
|
||||
valid_admin_count = sum(
|
||||
1
|
||||
for admin in local_admins
|
||||
if admin.password_hash
|
||||
and isinstance(admin.password_hash, str)
|
||||
and BCRYPT_HASH_PATTERN.match(admin.password_hash)
|
||||
)
|
||||
if valid_admin_count < 1:
|
||||
raise InvalidRequestException(
|
||||
"启用 LDAP 独占模式前,必须至少保留 1 个有效的本地管理员账户(含有效密码)作为紧急恢复通道"
|
||||
)
|
||||
|
||||
if config_update.bind_password is not None:
|
||||
if config_update.bind_password == "":
|
||||
# 显式清除密码(设置为 NULL)
|
||||
config.bind_password_encrypted = None
|
||||
password_changed = "cleared"
|
||||
else:
|
||||
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
|
||||
password_changed = "updated"
|
||||
else:
|
||||
password_changed = None
|
||||
|
||||
db.commit()
|
||||
|
||||
# 记录审计日志
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.CONFIG_CHANGED,
|
||||
description=f"LDAP 配置已更新 (enabled={config_update.is_enabled}, exclusive={config_update.is_exclusive})",
|
||||
user_id=str(context.user.id) if context.user else None,
|
||||
metadata={
|
||||
"server_url": config_update.server_url,
|
||||
"is_enabled": config_update.is_enabled,
|
||||
"is_exclusive": config_update.is_exclusive,
|
||||
"password_changed": password_changed,
|
||||
"is_new_config": is_new_config,
|
||||
},
|
||||
)
|
||||
|
||||
return {"message": "LDAP配置更新成功"}
|
||||
|
||||
|
||||
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
from src.services.auth.ldap import LDAPService
|
||||
|
||||
db = context.db
|
||||
if context.json_body is not None:
|
||||
payload = context.json_body
|
||||
elif not context.raw_body:
|
||||
payload = {}
|
||||
else:
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
saved_config = db.query(LDAPConfig).first()
|
||||
|
||||
try:
|
||||
overrides = LDAPConfigTest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
config_data: Dict[str, Any] = {}
|
||||
|
||||
if saved_config:
|
||||
config_data = {
|
||||
"server_url": saved_config.server_url,
|
||||
"bind_dn": saved_config.bind_dn,
|
||||
"base_dn": saved_config.base_dn,
|
||||
"user_search_filter": saved_config.user_search_filter,
|
||||
"username_attr": saved_config.username_attr,
|
||||
"email_attr": saved_config.email_attr,
|
||||
"display_name_attr": saved_config.display_name_attr,
|
||||
"use_starttls": saved_config.use_starttls,
|
||||
"connect_timeout": saved_config.connect_timeout,
|
||||
}
|
||||
|
||||
# 应用前端传入的覆盖值
|
||||
for field in [
|
||||
"server_url",
|
||||
"bind_dn",
|
||||
"base_dn",
|
||||
"user_search_filter",
|
||||
"username_attr",
|
||||
"email_attr",
|
||||
"display_name_attr",
|
||||
"use_starttls",
|
||||
"is_enabled",
|
||||
"is_exclusive",
|
||||
"connect_timeout",
|
||||
]:
|
||||
value = getattr(overrides, field)
|
||||
if value is not None:
|
||||
config_data[field] = value
|
||||
|
||||
# bind_password 优先使用 overrides;否则使用已保存的密码(允许保存密码无法解密时依然用 overrides 测试)
|
||||
if overrides.bind_password is not None:
|
||||
config_data["bind_password"] = overrides.bind_password
|
||||
elif saved_config and saved_config.bind_password_encrypted:
|
||||
try:
|
||||
config_data["bind_password"] = crypto_service.decrypt(
|
||||
saved_config.bind_password_encrypted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"绑定密码解密失败: {type(e).__name__}: {e}")
|
||||
return LDAPTestResponse(
|
||||
success=False, message="绑定密码解密失败,请检查配置或重新设置密码"
|
||||
).model_dump()
|
||||
|
||||
# 必填字段检查
|
||||
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
|
||||
missing = [f for f in required_fields if not config_data.get(f)]
|
||||
if missing:
|
||||
return LDAPTestResponse(
|
||||
success=False, message=f"缺少必要字段: {', '.join(missing)}"
|
||||
).model_dump()
|
||||
|
||||
success, message = LDAPService.test_connection_with_config(config_data)
|
||||
return LDAPTestResponse(success=success, message=message).model_dump()
|
||||
10
src/api/admin/management_tokens/__init__.py
Normal file
10
src/api/admin/management_tokens/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Management Token 管理员路由模块"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import router as management_tokens_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(management_tokens_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
300
src/api/admin/management_tokens/routes.py
Normal file
300
src/api/admin/management_tokens/routes.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""管理员 Management Token 管理端点"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, ManagementToken, User
|
||||
from src.services.management_token import ManagementTokenService, token_to_dict
|
||||
|
||||
router = APIRouter(prefix="/api/admin/management-tokens", tags=["Admin - Management Tokens"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ============== 安全基类 ==============
|
||||
|
||||
|
||||
class AdminManagementTokenApiAdapter(AdminApiAdapter):
|
||||
"""管理员 Management Token 管理 API 的基类
|
||||
|
||||
安全限制:禁止使用 Management Token 调用这些接口。
|
||||
"""
|
||||
|
||||
def authorize(self, context: ApiRequestContext) -> None:
|
||||
# 先调用父类的认证和权限检查
|
||||
super().authorize(context)
|
||||
|
||||
# 禁止使用 Management Token 调用 management-tokens 相关接口
|
||||
if context.management_token is not None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="不允许使用 Management Token 管理其他 Token,请使用 Web 界面或 JWT 认证",
|
||||
)
|
||||
|
||||
|
||||
# ============== 路由 ==============
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_all_management_tokens(
|
||||
request: Request,
|
||||
user_id: Optional[str] = Query(None, description="筛选用户 ID"),
|
||||
is_active: Optional[bool] = Query(None, description="筛选激活状态"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出所有 Management Tokens(管理员)
|
||||
|
||||
管理员查看所有用户的 Management Tokens,支持筛选和分页。
|
||||
|
||||
**查询参数**
|
||||
- user_id (Optional[str]): 筛选指定用户 ID 的 tokens
|
||||
- is_active (Optional[bool]): 筛选激活状态(true/false)
|
||||
- skip (int): 分页偏移量,默认 0
|
||||
- limit (int): 每页数量,范围 1-100,默认 50
|
||||
|
||||
**返回字段**
|
||||
- items (List[dict]): Token 列表
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- user (dict): 用户信息(包含 id, username, email 等)
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值(不返回明文)
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间(ISO 8601 格式)
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
- total (int): 总数量
|
||||
- skip (int): 当前偏移量
|
||||
- limit (int): 当前每页数量
|
||||
"""
|
||||
adapter = AdminListManagementTokensAdapter(
|
||||
user_id=user_id, is_active=is_active, skip=skip, limit=limit
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{token_id}")
|
||||
async def get_management_token(
|
||||
token_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 Management Token 详情(管理员)
|
||||
|
||||
管理员查看任意 Management Token 的详细信息。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**返回字段**
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- user (dict): 用户信息(包含 id, username, email 等)
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值(不返回明文)
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间(ISO 8601 格式)
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = AdminGetManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{token_id}")
|
||||
async def delete_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除任意 Management Token(管理员)
|
||||
|
||||
管理员可以删除任意用户的 Management Token。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): 要删除的 Token ID
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息
|
||||
"""
|
||||
adapter = AdminDeleteManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{token_id}/status")
|
||||
async def toggle_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""切换任意 Management Token 状态(管理员)
|
||||
|
||||
管理员可以启用/禁用任意用户的 Management Token。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息("Token 已启用" 或 "Token 已禁用")
|
||||
- data (dict): 更新后的 Token 信息
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- user (dict): 用户信息
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值
|
||||
- is_active (bool): 是否激活(已切换后的状态)
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = AdminToggleManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 适配器 ==============
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminListManagementTokensAdapter(AdminManagementTokenApiAdapter):
|
||||
"""列出所有 Management Tokens"""
|
||||
|
||||
name: str = "admin_list_management_tokens"
|
||||
user_id: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
skip: int = 0
|
||||
limit: int = 50
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
# 构建查询
|
||||
query = context.db.query(ManagementToken)
|
||||
|
||||
if self.user_id:
|
||||
query = query.filter(ManagementToken.user_id == self.user_id)
|
||||
if self.is_active is not None:
|
||||
query = query.filter(ManagementToken.is_active == self.is_active)
|
||||
|
||||
total = query.count()
|
||||
tokens = (
|
||||
query.order_by(ManagementToken.created_at.desc())
|
||||
.offset(self.skip)
|
||||
.limit(self.limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 预加载用户信息
|
||||
user_ids = list(set(t.user_id for t in tokens))
|
||||
users = {u.id: u for u in context.db.query(User).filter(User.id.in_(user_ids)).all()}
|
||||
for token in tokens:
|
||||
token.user = users.get(token.user_id)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"items": [token_to_dict(t, include_user=True) for t in tokens],
|
||||
"total": total,
|
||||
"skip": self.skip,
|
||||
"limit": self.limit,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetManagementTokenAdapter(AdminManagementTokenApiAdapter):
|
||||
"""获取 Management Token 详情"""
|
||||
|
||||
name: str = "admin_get_management_token"
|
||||
token_id: str = ""
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
token = ManagementTokenService.get_token_by_id(
|
||||
db=context.db, token_id=self.token_id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
# 加载用户信息
|
||||
token.user = context.db.query(User).filter(User.id == token.user_id).first()
|
||||
|
||||
return JSONResponse(content=token_to_dict(token, include_user=True))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteManagementTokenAdapter(AdminManagementTokenApiAdapter):
|
||||
"""删除 Management Token"""
|
||||
|
||||
name: str = "admin_delete_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_DELETED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
# 先获取 token 信息用于审计
|
||||
token = ManagementTokenService.get_token_by_id(
|
||||
db=context.db, token_id=self.token_id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
context.add_audit_metadata(
|
||||
token_id=token.id,
|
||||
token_name=token.name,
|
||||
owner_user_id=token.user_id,
|
||||
)
|
||||
|
||||
success = ManagementTokenService.delete_token(
|
||||
db=context.db, token_id=self.token_id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
return JSONResponse(content={"message": "删除成功"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminToggleManagementTokenAdapter(AdminManagementTokenApiAdapter):
|
||||
"""切换 Management Token 状态"""
|
||||
|
||||
name: str = "admin_toggle_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_UPDATED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
token = ManagementTokenService.toggle_status(
|
||||
db=context.db, token_id=self.token_id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
# 加载用户信息
|
||||
token.user = context.db.query(User).filter(User.id == token.user_id).first()
|
||||
|
||||
context.add_audit_metadata(
|
||||
token_id=token.id,
|
||||
token_name=token.name,
|
||||
owner_user_id=token.user_id,
|
||||
is_active=token.is_active,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": f"Token 已{'启用' if token.is_active else '禁用'}",
|
||||
"data": token_to_dict(token, include_user=True),
|
||||
}
|
||||
)
|
||||
@@ -8,7 +8,7 @@ from .catalog import router as catalog_router
|
||||
from .external import router as external_router
|
||||
from .global_models import router as global_models_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Models"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
|
||||
@@ -31,6 +31,22 @@ async def get_model_catalog(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelCatalogResponse:
|
||||
"""
|
||||
获取统一模型目录
|
||||
|
||||
基于 GlobalModel 聚合所有活跃模型及其关联提供商的信息,返回完整的模型目录视图。
|
||||
|
||||
**返回字段**:
|
||||
- `models`: 模型列表,每个模型包含:
|
||||
- `global_model_name`: GlobalModel 名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 模型描述
|
||||
- `providers`: 提供商列表,包含提供商名称、价格、能力等详细信息
|
||||
- `price_range`: 价格区间(基于 GlobalModel 第一阶梯价格)
|
||||
- `total_providers`: 关联提供商数量
|
||||
- `capabilities`: 模型能力标志(视觉、函数调用、流式输出)
|
||||
- `total`: 模型总数
|
||||
"""
|
||||
adapter = AdminGetModelCatalogAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -82,9 +82,21 @@ def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
|
||||
@router.get("/external")
|
||||
async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
|
||||
"""
|
||||
获取 models.dev 的模型数据(代理请求,解决跨域问题)
|
||||
数据缓存 15 分钟(使用 Redis,多 worker 共享)
|
||||
每个提供商会标记 official 字段,前端可据此过滤
|
||||
获取外部模型数据
|
||||
|
||||
从 models.dev 获取第三方模型数据,用于导入新模型或参考定价信息。
|
||||
该接口作为代理请求解决跨域问题,并提供缓存优化。
|
||||
|
||||
**功能特性**:
|
||||
- 代理 models.dev API,解决前端跨域问题
|
||||
- 使用 Redis 缓存 15 分钟,多 worker 共享缓存
|
||||
- 自动标记官方提供商(official 字段),前端可据此过滤第三方转售商
|
||||
|
||||
**返回字段**:
|
||||
- 键为提供商 ID(如 "anthropic"、"openai")
|
||||
- 值为提供商详细信息,包含:
|
||||
- `official`: 是否为官方提供商(true/false)
|
||||
- 其他 models.dev 提供的原始字段(模型列表、定价等)
|
||||
"""
|
||||
# 检查缓存
|
||||
cached = await _get_cached_data()
|
||||
@@ -130,7 +142,16 @@ async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
|
||||
|
||||
@router.delete("/external/cache")
|
||||
async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
|
||||
"""清除 models.dev 缓存"""
|
||||
"""
|
||||
清除外部模型数据缓存
|
||||
|
||||
手动清除 models.dev 的 Redis 缓存,强制下次请求重新获取最新数据。
|
||||
通常用于需要立即更新外部模型数据的场景。
|
||||
|
||||
**返回字段**:
|
||||
- `cleared`: 是否成功清除缓存(true/false)
|
||||
- `message`: 提示信息(仅在 Redis 未启用时返回)
|
||||
"""
|
||||
redis = await get_redis_client()
|
||||
if redis is None:
|
||||
return {"cleared": False, "message": "Redis 未启用"}
|
||||
|
||||
@@ -40,7 +40,27 @@ async def list_global_models(
|
||||
search: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelListResponse:
|
||||
"""获取 GlobalModel 列表"""
|
||||
"""
|
||||
获取 GlobalModel 列表
|
||||
|
||||
查询系统中的全局模型列表,支持分页、过滤和搜索功能。
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过记录数,用于分页(默认 0)
|
||||
- `limit`: 返回记录数,用于分页(默认 100,最大 1000)
|
||||
- `is_active`: 过滤活跃状态(true/false/null,null 表示不过滤)
|
||||
- `search`: 搜索关键词,支持按名称或显示名称模糊搜索
|
||||
|
||||
**返回字段**:
|
||||
- `models`: GlobalModel 列表,每个包含:
|
||||
- `id`: GlobalModel ID
|
||||
- `name`: 模型名称(唯一)
|
||||
- `display_name`: 显示名称
|
||||
- `is_active`: 是否活跃
|
||||
- `provider_count`: 关联提供商数量
|
||||
- 定价和能力配置等其他字段
|
||||
- `total`: 返回的模型总数
|
||||
"""
|
||||
adapter = AdminListGlobalModelsAdapter(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
@@ -56,7 +76,21 @@ async def get_global_model(
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelWithStats:
|
||||
"""获取单个 GlobalModel 详情(含统计信息)"""
|
||||
"""
|
||||
获取单个 GlobalModel 详情
|
||||
|
||||
查询指定 GlobalModel 的详细信息,包含关联的提供商和价格统计数据。
|
||||
|
||||
**路径参数**:
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**返回字段**:
|
||||
- 基础字段:`id`, `name`, `display_name`, `is_active` 等
|
||||
- 统计字段:
|
||||
- `total_models`: 关联的 Model 实现数量
|
||||
- `total_providers`: 关联的提供商数量
|
||||
- `price_range`: 价格区间统计(最低/最高输入输出价格)
|
||||
"""
|
||||
adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -67,7 +101,24 @@ async def create_global_model(
|
||||
payload: GlobalModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelResponse:
|
||||
"""创建 GlobalModel"""
|
||||
"""
|
||||
创建 GlobalModel
|
||||
|
||||
创建一个新的全局模型定义,作为多个提供商实现的统一抽象。
|
||||
|
||||
**请求体字段**:
|
||||
- `name`: 模型名称(唯一标识,如 "claude-3-5-sonnet-20241022")
|
||||
- `display_name`: 显示名称(如 "Claude 3.5 Sonnet")
|
||||
- `is_active`: 是否活跃(默认 true)
|
||||
- `default_price_per_request`: 默认按次计费价格(可选)
|
||||
- `default_tiered_pricing`: 默认阶梯定价配置(包含多个价格阶梯)
|
||||
- `supported_capabilities`: 支持的能力标志(vision、function_calling、streaming)
|
||||
- `config`: 额外配置(JSON 格式,如 description、context_window 等)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 创建的 GlobalModel ID
|
||||
- 其他请求体中的所有字段
|
||||
"""
|
||||
adapter = AdminCreateGlobalModelAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -79,7 +130,26 @@ async def update_global_model(
|
||||
payload: GlobalModelUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelResponse:
|
||||
"""更新 GlobalModel"""
|
||||
"""
|
||||
更新 GlobalModel
|
||||
|
||||
更新指定 GlobalModel 的配置信息,支持部分字段更新。
|
||||
更新后会自动失效相关缓存。
|
||||
|
||||
**路径参数**:
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**请求体字段**(均为可选):
|
||||
- `display_name`: 显示名称
|
||||
- `is_active`: 是否活跃
|
||||
- `default_price_per_request`: 默认按次计费价格
|
||||
- `default_tiered_pricing`: 默认阶梯定价配置
|
||||
- `supported_capabilities`: 支持的能力标志
|
||||
- `config`: 额外配置
|
||||
|
||||
**返回字段**:
|
||||
- 更新后的完整 GlobalModel 信息
|
||||
"""
|
||||
adapter = AdminUpdateGlobalModelAdapter(global_model_id=global_model_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -90,7 +160,18 @@ async def delete_global_model(
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除 GlobalModel(级联删除所有关联的 Provider 模型实现)"""
|
||||
"""
|
||||
删除 GlobalModel
|
||||
|
||||
删除指定的 GlobalModel,会级联删除所有关联的 Provider 模型实现。
|
||||
删除后会自动失效相关缓存。
|
||||
|
||||
**路径参数**:
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**返回**:
|
||||
- 成功删除返回 204 状态码,无响应体
|
||||
"""
|
||||
adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
|
||||
await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
return None
|
||||
@@ -105,7 +186,29 @@ async def batch_assign_to_providers(
|
||||
payload: BatchAssignToProvidersRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignToProvidersResponse:
|
||||
"""批量为多个 Provider 添加 GlobalModel 实现"""
|
||||
"""
|
||||
批量为提供商添加模型实现
|
||||
|
||||
为指定的 GlobalModel 批量创建多个 Provider 的模型实现(Model 记录)。
|
||||
用于快速将一个统一模型分配给多个提供商。
|
||||
|
||||
**路径参数**:
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**请求体字段**:
|
||||
- `provider_ids`: 提供商 ID 列表
|
||||
- `create_models`: Model 创建配置列表,每个包含:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `provider_model_name`: 提供商侧的模型名称(如 "claude-3-5-sonnet-20241022")
|
||||
- 其他可选字段(价格覆盖、能力覆盖等)
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 成功创建的 Model 列表
|
||||
- `errors`: 失败的提供商及错误信息列表
|
||||
- `total_requested`: 请求处理的总数
|
||||
- `total_success`: 成功创建的数量
|
||||
- `total_errors`: 失败的数量
|
||||
"""
|
||||
adapter = AdminBatchAssignToProvidersAdapter(global_model_id=global_model_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -116,7 +219,27 @@ async def get_global_model_providers(
|
||||
global_model_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> GlobalModelProvidersResponse:
|
||||
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
|
||||
"""
|
||||
获取 GlobalModel 的关联提供商
|
||||
|
||||
查询指定 GlobalModel 的所有关联提供商及其模型实现详情,包括非活跃的提供商。
|
||||
用于查看某个统一模型在各个提供商上的具体配置。
|
||||
|
||||
**路径参数**:
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**返回字段**:
|
||||
- `providers`: 提供商列表,每个包含:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `provider_name`: 提供商名称
|
||||
- `provider_display_name`: 提供商显示名称
|
||||
- `model_id`: Model 实现 ID
|
||||
- `target_model`: 提供商侧的模型名称
|
||||
- 价格信息(input_price_per_1m、output_price_per_1m 等)
|
||||
- 能力标志(supports_vision、supports_function_calling、supports_streaming)
|
||||
- `is_active`: 是否活跃
|
||||
- `total`: 关联提供商总数
|
||||
"""
|
||||
adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -39,6 +39,34 @@ async def get_audit_logs(
|
||||
offset: int = Query(0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取审计日志
|
||||
|
||||
获取系统审计日志列表,支持按用户、事件类型、时间范围筛选。需要管理员权限。
|
||||
|
||||
**查询参数**:
|
||||
- `user_id`: 可选,用户 ID 筛选(UUID 格式)
|
||||
- `event_type`: 可选,事件类型筛选
|
||||
- `days`: 查询最近多少天的日志,默认 7 天
|
||||
- `limit`: 返回数量限制,默认 100
|
||||
- `offset`: 分页偏移量,默认 0
|
||||
|
||||
**返回字段**:
|
||||
- `items`: 审计日志列表,每条日志包含:
|
||||
- `id`: 日志 ID
|
||||
- `event_type`: 事件类型
|
||||
- `user_id`: 用户 ID
|
||||
- `user_email`: 用户邮箱
|
||||
- `user_username`: 用户名
|
||||
- `description`: 事件描述
|
||||
- `ip_address`: IP 地址
|
||||
- `status_code`: HTTP 状态码
|
||||
- `error_message`: 错误信息
|
||||
- `metadata`: 事件元数据
|
||||
- `created_at`: 创建时间
|
||||
- `meta`: 分页元数据(total, limit, offset, count)
|
||||
- `filters`: 筛选条件
|
||||
"""
|
||||
adapter = AdminGetAuditLogsAdapter(
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
@@ -51,6 +79,19 @@ async def get_audit_logs(
|
||||
|
||||
@router.get("/system-status")
|
||||
async def get_system_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取系统状态
|
||||
|
||||
获取系统当前的运行状态和关键指标。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `timestamp`: 当前时间戳
|
||||
- `users`: 用户统计(total: 总用户数, active: 活跃用户数)
|
||||
- `providers`: 提供商统计(total: 总提供商数, active: 活跃提供商数)
|
||||
- `api_keys`: API Key 统计(total: 总数, active: 活跃数)
|
||||
- `today_stats`: 今日统计(requests: 请求数, tokens: token 数, cost_usd: 成本)
|
||||
- `recent_errors`: 最近 1 小时内的错误数
|
||||
"""
|
||||
adapter = AdminSystemStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -61,6 +102,26 @@ async def get_suspicious_activities(
|
||||
hours: int = Query(24, description="时间范围(小时)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取可疑活动记录
|
||||
|
||||
获取系统检测到的可疑活动记录。需要管理员权限。
|
||||
|
||||
**查询参数**:
|
||||
- `hours`: 时间范围(小时),默认 24 小时
|
||||
|
||||
**返回字段**:
|
||||
- `activities`: 可疑活动列表,每条记录包含:
|
||||
- `id`: 记录 ID
|
||||
- `event_type`: 事件类型
|
||||
- `user_id`: 用户 ID
|
||||
- `description`: 事件描述
|
||||
- `ip_address`: IP 地址
|
||||
- `metadata`: 事件元数据
|
||||
- `created_at`: 创建时间
|
||||
- `count`: 活动总数
|
||||
- `time_range_hours`: 查询的时间范围(小时)
|
||||
"""
|
||||
adapter = AdminSuspiciousActivitiesAdapter(hours=hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -72,19 +133,56 @@ async def analyze_user_behavior(
|
||||
days: int = Query(30, description="分析天数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
分析用户行为
|
||||
|
||||
分析指定用户的行为模式和使用情况。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID
|
||||
|
||||
**查询参数**:
|
||||
- `days`: 分析最近多少天的数据,默认 30 天
|
||||
|
||||
**返回字段**:
|
||||
- 用户行为分析结果,包括活动频率、使用模式、异常行为等
|
||||
"""
|
||||
adapter = AdminUserBehaviorAdapter(user_id=user_id, days=days)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/resilience-status")
|
||||
async def get_resilience_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取韧性系统状态
|
||||
|
||||
获取系统韧性管理的当前状态,包括错误统计、熔断器状态等。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `timestamp`: 当前时间戳
|
||||
- `health_score`: 健康评分(0-100)
|
||||
- `status`: 系统状态(healthy: 健康,degraded: 降级,critical: 严重)
|
||||
- `error_statistics`: 错误统计信息
|
||||
- `recent_errors`: 最近的错误列表(最多 10 条)
|
||||
- `recommendations`: 系统建议
|
||||
"""
|
||||
adapter = AdminResilienceStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/resilience/error-stats")
|
||||
async def reset_error_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""Reset resilience error statistics"""
|
||||
"""
|
||||
重置错误统计
|
||||
|
||||
重置韧性系统的错误统计数据。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
- `previous_stats`: 重置前的统计数据
|
||||
- `reset_by`: 执行重置的管理员邮箱
|
||||
- `reset_at`: 重置时间
|
||||
"""
|
||||
adapter = AdminResetErrorStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -95,6 +193,18 @@ async def get_circuit_history(
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取熔断器历史记录
|
||||
|
||||
获取熔断器的状态变更历史记录。需要管理员权限。
|
||||
|
||||
**查询参数**:
|
||||
- `limit`: 返回数量限制,默认 50,最大 200
|
||||
|
||||
**返回字段**:
|
||||
- `items`: 熔断器历史记录列表
|
||||
- `count`: 记录总数
|
||||
"""
|
||||
adapter = AdminCircuitHistoryAdapter(limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -107,6 +217,9 @@ class AdminGetAuditLogsAdapter(AdminApiAdapter):
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
# 查看审计日志本身不应该产生审计记录,避免刷新页面时产生大量无意义的日志
|
||||
audit_log_enabled: bool = False
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
|
||||
|
||||
@@ -117,12 +117,21 @@ async def get_cache_stats(
|
||||
"""
|
||||
获取缓存亲和性统计信息
|
||||
|
||||
返回:
|
||||
- 缓存命中率
|
||||
- 缓存用户数
|
||||
- Provider切换次数
|
||||
- Key切换次数
|
||||
- 缓存预留配置
|
||||
获取缓存调度器的运行统计数据,包括命中率、切换次数、调度器配置等。
|
||||
用于监控缓存亲和性功能的运行状态和性能指标。
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `data`: 统计数据对象
|
||||
- `scheduler`: 调度器名称(cache_aware 或 random)
|
||||
- `total_affinities`: 总缓存亲和性数量
|
||||
- `cache_hit_rate`: 缓存命中率(0.0-1.0)
|
||||
- `provider_switches`: Provider 切换次数
|
||||
- `key_switches`: Key 切换次数
|
||||
- `cache_hits`: 缓存命中次数
|
||||
- `cache_misses`: 缓存未命中次数
|
||||
- `scheduler_metrics`: 调度器详细指标
|
||||
- `affinity_stats`: 亲和性统计数据
|
||||
"""
|
||||
adapter = AdminCacheStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -137,16 +146,33 @@ async def get_user_affinity(
|
||||
"""
|
||||
查询指定用户的所有缓存亲和性
|
||||
|
||||
参数:
|
||||
- user_identifier: 用户标识符,支持以下格式:
|
||||
* 用户名 (username),如: yuanhonghu
|
||||
* 邮箱 (email),如: user@example.com
|
||||
* 用户UUID (user_id),如: 550e8400-e29b-41d4-a716-446655440000
|
||||
* API Key ID,如: 660e8400-e29b-41d4-a716-446655440000
|
||||
根据用户标识符查询该用户在各个端点上的缓存亲和性记录。
|
||||
支持多种标识符格式的自动识别和解析。
|
||||
|
||||
返回:
|
||||
- 用户信息
|
||||
- 所有端点的缓存亲和性列表(每个端点一条记录)
|
||||
**路径参数**:
|
||||
- `user_identifier`: 用户标识符,支持以下格式:
|
||||
- 用户名(username),如:yuanhonghu
|
||||
- 邮箱(email),如:user@example.com
|
||||
- 用户 UUID(user_id),如:550e8400-e29b-41d4-a716-446655440000
|
||||
- API Key ID,如:660e8400-e29b-41d4-a716-446655440000
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok 或 not_found)
|
||||
- `message`: 提示消息(当无缓存时)
|
||||
- `user_info`: 用户信息
|
||||
- `user_id`: 用户 ID
|
||||
- `username`: 用户名
|
||||
- `email`: 邮箱
|
||||
- `affinities`: 缓存亲和性列表
|
||||
- `provider_id`: Provider ID
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `key_id`: Key ID
|
||||
- `api_format`: API 格式
|
||||
- `model_name`: 模型名称(global_model_id)
|
||||
- `created_at`: 创建时间
|
||||
- `expire_at`: 过期时间
|
||||
- `request_count`: 请求计数
|
||||
- `total_endpoints`: 缓存的端点数量
|
||||
"""
|
||||
adapter = AdminGetUserAffinityAdapter(user_identifier=user_identifier)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -161,10 +187,50 @@ async def list_affinities(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有缓存亲和性列表,可选按关键词过滤
|
||||
获取所有缓存亲和性列表
|
||||
|
||||
参数:
|
||||
- keyword: 可选,支持用户名/邮箱/User ID/API Key ID 或模糊匹配
|
||||
查询系统中所有的缓存亲和性记录,支持按关键词过滤和分页。
|
||||
返回详细的用户、Provider、Endpoint、Key 信息。
|
||||
|
||||
**查询参数**:
|
||||
- `keyword`: 可选,支持以下过滤方式(可选)
|
||||
- 用户名/邮箱/User ID/API Key ID(精确匹配)
|
||||
- 任意字段的模糊匹配(affinity_key、user_id、username、email、provider_id、key_id)
|
||||
- `limit`: 返回数量限制(1-1000,默认 100)
|
||||
- `offset`: 偏移量(用于分页,默认 0)
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `data`: 分页数据对象
|
||||
- `items`: 缓存亲和性列表
|
||||
- `affinity_key`: API Key ID(用于缓存键)
|
||||
- `user_api_key_name`: 用户 API Key 名称
|
||||
- `user_api_key_prefix`: 脱敏后的用户 API Key
|
||||
- `is_standalone`: 是否为独立 API Key
|
||||
- `user_id`: 用户 ID
|
||||
- `username`: 用户名
|
||||
- `email`: 邮箱
|
||||
- `provider_id`: Provider ID
|
||||
- `provider_name`: Provider 显示名称
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `endpoint_api_format`: Endpoint API 格式
|
||||
- `endpoint_url`: Endpoint 基础 URL
|
||||
- `key_id`: Key ID
|
||||
- `key_name`: Key 名称
|
||||
- `key_prefix`: 脱敏后的 Provider Key
|
||||
- `rate_multiplier`: 速率倍数
|
||||
- `global_model_id`: GlobalModel ID
|
||||
- `model_name`: 模型名称
|
||||
- `model_display_name`: 模型显示名称
|
||||
- `api_format`: API 格式
|
||||
- `created_at`: 创建时间
|
||||
- `expire_at`: 过期时间
|
||||
- `request_count`: 请求计数
|
||||
- `meta`: 分页元数据
|
||||
- `count`: 总数量
|
||||
- `limit`: 每页数量
|
||||
- `offset`: 当前偏移量
|
||||
- `matched_user_id`: 匹配到的用户 ID(当关键词为用户标识时)
|
||||
"""
|
||||
adapter = AdminListAffinitiesAdapter(keyword=keyword, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -177,10 +243,27 @@ async def clear_user_cache(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Clear cache affinity for a specific user
|
||||
清除指定用户的缓存亲和性
|
||||
|
||||
Parameters:
|
||||
- user_identifier: User identifier (username, email, user_id, or API Key ID)
|
||||
清除指定用户或 API Key 的所有缓存亲和性记录。
|
||||
支持按用户维度或单个 API Key 维度清除。
|
||||
|
||||
**路径参数**:
|
||||
- `user_identifier`: 用户标识符,支持以下格式:
|
||||
- 用户名(username)
|
||||
- 邮箱(email)
|
||||
- 用户 UUID(user_id)
|
||||
- API Key ID(清除该 API Key 的缓存)
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `user_info`: 用户信息
|
||||
- `user_id`: 用户 ID
|
||||
- `username`: 用户名
|
||||
- `email`: 邮箱
|
||||
- `api_key_id`: API Key ID(当清除单个 API Key 时)
|
||||
- `api_key_name`: API Key 名称(当清除单个 API Key 时)
|
||||
"""
|
||||
adapter = AdminClearUserCacheAdapter(user_identifier=user_identifier)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -196,13 +279,23 @@ async def clear_single_affinity(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Clear a single cache affinity entry
|
||||
清除单条缓存亲和性记录
|
||||
|
||||
Parameters:
|
||||
- affinity_key: API Key ID
|
||||
- endpoint_id: Endpoint ID
|
||||
- model_id: Model ID (GlobalModel ID)
|
||||
- api_format: API format (claude/openai)
|
||||
根据精确的缓存键(affinity_key + endpoint_id + model_id + api_format)
|
||||
清除单条缓存亲和性记录。用于精确控制缓存清除。
|
||||
|
||||
**路径参数**:
|
||||
- `affinity_key`: API Key ID(用于缓存的键)
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `model_id`: GlobalModel ID
|
||||
- `api_format`: API 格式(如:claude、openai、gemini)
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `affinity_key`: API Key ID
|
||||
- `endpoint_id`: Endpoint ID
|
||||
- `model_id`: GlobalModel ID
|
||||
"""
|
||||
adapter = AdminClearSingleAffinityAdapter(
|
||||
affinity_key=affinity_key, endpoint_id=endpoint_id, model_id=model_id, api_format=api_format
|
||||
@@ -216,9 +309,17 @@ async def clear_all_cache(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Clear all cache affinities
|
||||
清除所有缓存亲和性
|
||||
|
||||
Warning: This affects all users, use with caution
|
||||
清除系统中所有用户的缓存亲和性记录。此操作会影响所有用户,
|
||||
下次请求时将重新建立缓存亲和性。请谨慎使用。
|
||||
|
||||
**警告**: 此操作影响所有用户,使用前请确认
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `count`: 清除的缓存数量
|
||||
"""
|
||||
adapter = AdminClearAllCacheAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -231,10 +332,19 @@ async def clear_provider_cache(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Clear cache affinities for a specific provider
|
||||
清除指定 Provider 的缓存亲和性
|
||||
|
||||
Parameters:
|
||||
- provider_id: Provider ID
|
||||
清除与指定 Provider 相关的所有缓存亲和性记录。
|
||||
当 Provider 配置变更或下线时使用。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: Provider ID
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `provider_id`: Provider ID
|
||||
- `count`: 清除的缓存数量
|
||||
"""
|
||||
adapter = AdminClearProviderCacheAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -248,9 +358,25 @@ async def get_cache_config(
|
||||
"""
|
||||
获取缓存相关配置
|
||||
|
||||
返回:
|
||||
- 缓存TTL
|
||||
- 缓存预留比例
|
||||
获取缓存亲和性功能的配置参数,包括缓存 TTL、预留比例、
|
||||
动态预留机制配置等。
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `data`: 配置数据
|
||||
- `cache_ttl_seconds`: 缓存亲和性有效期(秒)
|
||||
- `cache_reservation_ratio`: 静态预留比例(已被动态预留替代)
|
||||
- `dynamic_reservation`: 动态预留机制配置
|
||||
- `enabled`: 是否启用
|
||||
- `config`: 配置参数
|
||||
- `probe_phase_requests`: 探测阶段请求数阈值
|
||||
- `probe_reservation`: 探测阶段预留比例
|
||||
- `stable_min_reservation`: 稳定阶段最小预留比例
|
||||
- `stable_max_reservation`: 稳定阶段最大预留比例
|
||||
- `low_load_threshold`: 低负载阈值
|
||||
- `high_load_threshold`: 高负载阈值
|
||||
- `description`: 各参数说明
|
||||
- `description`: 配置说明
|
||||
"""
|
||||
adapter = AdminCacheConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -262,7 +388,31 @@ async def get_cache_metrics(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
|
||||
获取缓存调度指标(Prometheus 格式)
|
||||
|
||||
以 Prometheus 文本格式输出缓存调度器的监控指标,
|
||||
方便接入 Prometheus/Grafana 等监控系统。
|
||||
|
||||
**返回格式**: Prometheus 文本格式(Content-Type: text/plain)
|
||||
|
||||
**指标列表**:
|
||||
- `cache_scheduler_total_batches`: 总批次数
|
||||
- `cache_scheduler_last_batch_size`: 最后一批候选数
|
||||
- `cache_scheduler_total_candidates`: 总候选数
|
||||
- `cache_scheduler_last_candidate_count`: 最后一批候选计数
|
||||
- `cache_scheduler_cache_hits`: 缓存命中次数
|
||||
- `cache_scheduler_cache_misses`: 缓存未命中次数
|
||||
- `cache_scheduler_cache_hit_rate`: 缓存命中率
|
||||
- `cache_scheduler_concurrency_denied`: 并发拒绝次数
|
||||
- `cache_scheduler_avg_candidates_per_batch`: 平均每批候选数
|
||||
- `cache_affinity_total`: 总缓存亲和性数量
|
||||
- `cache_affinity_hits`: 亲和性命中次数
|
||||
- `cache_affinity_misses`: 亲和性未命中次数
|
||||
- `cache_affinity_hit_rate`: 亲和性命中率
|
||||
- `cache_affinity_invalidations`: 亲和性失效次数
|
||||
- `cache_affinity_provider_switches`: Provider 切换次数
|
||||
- `cache_affinity_key_switches`: Key 切换次数
|
||||
- `cache_scheduler_info`: 调度器信息(label: scheduler)
|
||||
"""
|
||||
adapter = AdminCacheMetricsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -998,10 +1148,39 @@ async def get_model_mapping_cache_stats(
|
||||
"""
|
||||
获取模型映射缓存统计信息
|
||||
|
||||
返回:
|
||||
- 缓存键数量
|
||||
- 缓存 TTL 配置
|
||||
- 各类型缓存数量
|
||||
获取模型解析缓存的详细统计信息,包括各类型缓存键数量、
|
||||
映射关系列表、Provider 级别的模型映射缓存等。
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `data`: 统计数据
|
||||
- `available`: Redis 是否可用
|
||||
- `message`: 提示消息(当 Redis 未启用时)
|
||||
- `ttl_seconds`: 缓存 TTL(秒)
|
||||
- `total_keys`: 总缓存键数量
|
||||
- `breakdown`: 各类型缓存键数量分解
|
||||
- `model_by_id`: Model ID 缓存数量
|
||||
- `model_by_provider_global`: Provider-GlobalModel 缓存数量
|
||||
- `global_model_by_id`: GlobalModel ID 缓存数量
|
||||
- `global_model_by_name`: GlobalModel 名称缓存数量
|
||||
- `global_model_resolve`: GlobalModel 解析缓存数量
|
||||
- `mappings`: 模型映射列表(最多 100 条)
|
||||
- `mapping_name`: 映射名称(别名)
|
||||
- `global_model_name`: GlobalModel 名称
|
||||
- `global_model_display_name`: GlobalModel 显示名称
|
||||
- `providers`: 使用该映射的 Provider 列表
|
||||
- `ttl`: 缓存剩余 TTL(秒)
|
||||
- `provider_model_mappings`: Provider 级别的模型映射(最多 100 条)
|
||||
- `provider_id`: Provider ID
|
||||
- `provider_name`: Provider 名称
|
||||
- `global_model_id`: GlobalModel ID
|
||||
- `global_model_name`: GlobalModel 名称
|
||||
- `global_model_display_name`: GlobalModel 显示名称
|
||||
- `provider_model_name`: Provider 侧的模型名称
|
||||
- `aliases`: 别名列表
|
||||
- `ttl`: 缓存剩余 TTL(秒)
|
||||
- `hit_count`: 缓存命中次数
|
||||
- `unmapped`: 未映射或无效的缓存条目
|
||||
"""
|
||||
adapter = AdminModelMappingCacheStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -1015,7 +1194,15 @@ async def clear_all_model_mapping_cache(
|
||||
"""
|
||||
清除所有模型映射缓存
|
||||
|
||||
警告: 这会影响所有模型解析,请谨慎使用
|
||||
清除系统中所有的模型映射缓存,包括 Model、GlobalModel、
|
||||
模型解析等所有相关缓存。下次请求时将重新从数据库查询。
|
||||
|
||||
**警告**: 此操作会影响所有模型解析,请谨慎使用
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `deleted_count`: 删除的缓存键数量
|
||||
"""
|
||||
adapter = AdminClearAllModelMappingCacheAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -1030,8 +1217,17 @@ async def clear_model_mapping_cache_by_name(
|
||||
"""
|
||||
清除指定模型名称的映射缓存
|
||||
|
||||
参数:
|
||||
- model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
|
||||
根据模型名称清除相关的映射缓存,包括 resolve 缓存和 name 缓存。
|
||||
用于更新单个模型的配置后刷新缓存。
|
||||
|
||||
**路径参数**:
|
||||
- `model_name`: 模型名称(可以是 GlobalModel.name 或映射名称)
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `model_name`: 模型名称
|
||||
- `deleted_keys`: 删除的缓存键列表
|
||||
"""
|
||||
adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -1047,9 +1243,19 @@ async def clear_provider_model_mapping_cache(
|
||||
"""
|
||||
清除指定 Provider 和 GlobalModel 的模型映射缓存
|
||||
|
||||
参数:
|
||||
- provider_id: Provider ID
|
||||
- global_model_id: GlobalModel ID
|
||||
清除特定 Provider 和 GlobalModel 组合的映射缓存及其命中次数统计。
|
||||
用于 Provider 模型配置更新后刷新缓存。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: Provider ID
|
||||
- `global_model_id`: GlobalModel ID
|
||||
|
||||
**返回字段**:
|
||||
- `status`: 状态(ok)
|
||||
- `message`: 操作结果消息
|
||||
- `provider_id`: Provider ID
|
||||
- `global_model_id`: GlobalModel ID
|
||||
- `deleted_keys`: 删除的缓存键列表
|
||||
"""
|
||||
adapter = AdminClearProviderModelMappingCacheAdapter(
|
||||
provider_id=provider_id, global_model_id=global_model_id
|
||||
|
||||
@@ -71,7 +71,47 @@ async def get_request_trace(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取特定请求的完整追踪信息"""
|
||||
"""
|
||||
获取请求的完整追踪信息
|
||||
|
||||
获取指定请求的完整链路追踪信息,包括所有候选(candidates)的执行情况。
|
||||
|
||||
**路径参数**:
|
||||
- `request_id`: 请求 ID
|
||||
|
||||
**返回字段**:
|
||||
- `request_id`: 请求 ID
|
||||
- `total_candidates`: 候选总数
|
||||
- `final_status`: 最终状态(success: 成功,failed: 失败,streaming: 流式传输中,pending: 等待中)
|
||||
- `total_latency_ms`: 总延迟(毫秒)
|
||||
- `candidates`: 候选列表,每个候选包含:
|
||||
- `id`: 候选 ID
|
||||
- `request_id`: 请求 ID
|
||||
- `candidate_index`: 候选索引
|
||||
- `retry_index`: 重试序号
|
||||
- `provider_id`: 提供商 ID
|
||||
- `provider_name`: 提供商名称
|
||||
- `provider_website`: 提供商官网
|
||||
- `endpoint_id`: 端点 ID
|
||||
- `endpoint_name`: 端点名称(API 格式)
|
||||
- `key_id`: 密钥 ID
|
||||
- `key_name`: 密钥名称
|
||||
- `key_preview`: 密钥脱敏预览
|
||||
- `key_capabilities`: 密钥支持的能力
|
||||
- `required_capabilities`: 请求需要的能力标签
|
||||
- `status`: 状态(pending, success, failed, skipped)
|
||||
- `skip_reason`: 跳过原因
|
||||
- `is_cached`: 是否缓存命中
|
||||
- `status_code`: HTTP 状态码
|
||||
- `error_type`: 错误类型
|
||||
- `error_message`: 错误信息
|
||||
- `latency_ms`: 延迟(毫秒)
|
||||
- `concurrent_requests`: 并发请求数
|
||||
- `extra_data`: 额外数据
|
||||
- `created_at`: 创建时间
|
||||
- `started_at`: 开始时间
|
||||
- `finished_at`: 完成时间
|
||||
"""
|
||||
|
||||
adapter = AdminGetRequestTraceAdapter(request_id=request_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -85,9 +125,23 @@ async def get_provider_failure_rate(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取某个 Provider 的失败率统计
|
||||
获取提供商的失败率统计
|
||||
|
||||
需要管理员权限
|
||||
获取指定提供商最近的失败率统计信息。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**查询参数**:
|
||||
- `limit`: 统计最近的尝试数量,默认 100,最大 1000
|
||||
|
||||
**返回字段**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `total_attempts`: 总尝试次数
|
||||
- `success_count`: 成功次数
|
||||
- `failed_count`: 失败次数
|
||||
- `failure_rate`: 失败率(百分比)
|
||||
- `avg_latency_ms`: 平均延迟(毫秒)
|
||||
"""
|
||||
adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -39,6 +39,31 @@ async def update_provider_billing(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新提供商计费配置
|
||||
|
||||
更新指定提供商的计费策略、配额设置和优先级配置。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**:
|
||||
- `billing_type`: 计费类型(pay_as_you_go、subscription、prepaid、monthly_quota)
|
||||
- `monthly_quota_usd`: 月度配额(美元),可选
|
||||
- `quota_reset_day`: 配额重置周期(天数,1-365),默认 30
|
||||
- `quota_last_reset_at`: 当前周期开始时间,可选(设置后会自动同步该周期内的历史使用量)
|
||||
- `quota_expires_at`: 配额过期时间,可选
|
||||
- `rpm_limit`: 每分钟请求数限制,可选
|
||||
- `provider_priority`: 提供商优先级(0-200),默认 100
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
- `provider`: 更新后的提供商信息
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `billing_type`: 计费类型
|
||||
- `provider_priority`: 提供商优先级
|
||||
"""
|
||||
adapter = AdminProviderBillingAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -50,6 +75,39 @@ async def get_provider_stats(
|
||||
hours: int = 24,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取提供商统计数据
|
||||
|
||||
获取指定提供商的计费信息、RPM 使用情况和使用统计数据。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**查询参数**:
|
||||
- `hours`: 统计时间范围(小时),默认 24
|
||||
|
||||
**返回字段**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `provider_name`: 提供商名称
|
||||
- `period_hours`: 统计时间范围
|
||||
- `billing_info`: 计费信息
|
||||
- `billing_type`: 计费类型
|
||||
- `monthly_quota_usd`: 月度配额
|
||||
- `monthly_used_usd`: 月度已使用
|
||||
- `quota_remaining_usd`: 剩余配额
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_info`: RPM 信息
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `usage_stats`: 使用统计
|
||||
- `total_requests`: 总请求数
|
||||
- `successful_requests`: 成功请求数
|
||||
- `failed_requests`: 失败请求数
|
||||
- `success_rate`: 成功率
|
||||
- `avg_response_time_ms`: 平均响应时间(毫秒)
|
||||
- `total_cost_usd`: 总成本(美元)
|
||||
"""
|
||||
adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -67,6 +125,20 @@ async def reset_provider_quota(
|
||||
|
||||
@router.get("/strategies")
|
||||
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取可用负载均衡策略列表
|
||||
|
||||
列出系统中所有已注册的负载均衡策略插件。
|
||||
|
||||
**返回字段**:
|
||||
- `strategies`: 策略列表
|
||||
- `name`: 策略名称
|
||||
- `priority`: 策略优先级
|
||||
- `version`: 策略版本
|
||||
- `description`: 策略描述
|
||||
- `author`: 策略作者
|
||||
- `total`: 策略总数
|
||||
"""
|
||||
adapter = AdminListStrategiesAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -49,7 +49,36 @@ async def list_provider_models(
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""获取提供商的所有模型(管理员)"""
|
||||
"""
|
||||
获取提供商的所有模型
|
||||
|
||||
获取指定提供商的模型列表,支持分页和状态过滤。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**查询参数**:
|
||||
- `is_active`: 可选的活跃状态过滤,true 仅返回活跃模型,false 返回禁用模型,不传则返回全部
|
||||
- `skip`: 跳过的记录数,默认为 0
|
||||
- `limit`: 返回的最大记录数,默认为 100
|
||||
|
||||
**返回字段**(数组,每项包含):
|
||||
- `id`: 模型 ID
|
||||
- `provider_id`: 提供商 ID
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `provider_model_name`: 提供商模型名称
|
||||
- `is_active`: 是否启用
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)
|
||||
- `price_per_request`: 每次请求价格
|
||||
- `supports_vision`: 是否支持视觉
|
||||
- `supports_function_calling`: 是否支持函数调用
|
||||
- `supports_streaming`: 是否支持流式输出
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
adapter = AdminListProviderModelsAdapter(
|
||||
provider_id=provider_id,
|
||||
is_active=is_active,
|
||||
@@ -66,7 +95,29 @@ async def create_provider_model(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""创建模型(管理员)"""
|
||||
"""
|
||||
创建模型
|
||||
|
||||
为指定提供商创建一个新的模型配置。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**:
|
||||
- `provider_model_name`: 提供商模型名称(必填)
|
||||
- `global_model_id`: 全局模型 ID(可选,关联到全局模型)
|
||||
- `is_active`: 是否启用(默认 true)
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)(可选)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)(可选)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)(可选)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)(可选)
|
||||
- `price_per_request`: 每次请求价格(可选)
|
||||
- `supports_vision`: 是否支持视觉(可选)
|
||||
- `supports_function_calling`: 是否支持函数调用(可选)
|
||||
- `supports_streaming`: 是否支持流式输出(可选)
|
||||
|
||||
**返回字段**: 返回创建的模型详细信息(与 GET 单个模型接口返回格式相同)
|
||||
"""
|
||||
adapter = AdminCreateProviderModelAdapter(provider_id=provider_id, model_data=model_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -78,7 +129,32 @@ async def get_provider_model(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""获取模型详情(管理员)"""
|
||||
"""
|
||||
获取模型详情
|
||||
|
||||
获取指定模型的详细配置信息。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `model_id`: 模型 ID
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 模型 ID
|
||||
- `provider_id`: 提供商 ID
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `provider_model_name`: 提供商模型名称
|
||||
- `is_active`: 是否启用
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)
|
||||
- `price_per_request`: 每次请求价格
|
||||
- `supports_vision`: 是否支持视觉
|
||||
- `supports_function_calling`: 是否支持函数调用
|
||||
- `supports_streaming`: 是否支持流式输出
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
adapter = AdminGetProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -91,7 +167,30 @@ async def update_provider_model(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""更新模型(管理员)"""
|
||||
"""
|
||||
更新模型配置
|
||||
|
||||
更新指定模型的配置信息。只需传入需要更新的字段,未传入的字段保持不变。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `model_id`: 模型 ID
|
||||
|
||||
**请求体字段**(所有字段可选):
|
||||
- `provider_model_name`: 提供商模型名称
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `is_active`: 是否启用
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)
|
||||
- `price_per_request`: 每次请求价格
|
||||
- `supports_vision`: 是否支持视觉
|
||||
- `supports_function_calling`: 是否支持函数调用
|
||||
- `supports_streaming`: 是否支持流式输出
|
||||
|
||||
**返回字段**: 返回更新后的模型详细信息(与 GET 单个模型接口返回格式相同)
|
||||
"""
|
||||
adapter = AdminUpdateProviderModelAdapter(
|
||||
provider_id=provider_id,
|
||||
model_id=model_id,
|
||||
@@ -107,7 +206,18 @@ async def delete_provider_model(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除模型(管理员)"""
|
||||
"""
|
||||
删除模型
|
||||
|
||||
删除指定的模型配置。注意:此操作不可逆。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `model_id`: 模型 ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 删除成功提示信息
|
||||
"""
|
||||
adapter = AdminDeleteProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -119,7 +229,29 @@ async def batch_create_provider_models(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""批量创建模型(管理员)"""
|
||||
"""
|
||||
批量创建模型
|
||||
|
||||
为指定提供商批量创建多个模型配置。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体**: 模型数据数组,每项包含:
|
||||
- `provider_model_name`: 提供商模型名称(必填)
|
||||
- `global_model_id`: 全局模型 ID(可选)
|
||||
- `is_active`: 是否启用(默认 true)
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)(可选)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)(可选)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)(可选)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)(可选)
|
||||
- `price_per_request`: 每次请求价格(可选)
|
||||
- `supports_vision`: 是否支持视觉(可选)
|
||||
- `supports_function_calling`: 是否支持函数调用(可选)
|
||||
- `supports_streaming`: 是否支持流式输出(可选)
|
||||
|
||||
**返回字段**: 返回创建的模型列表(与 GET 模型列表接口返回格式相同)
|
||||
"""
|
||||
adapter = AdminBatchCreateModelsAdapter(provider_id=provider_id, models_data=models_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -134,10 +266,23 @@ async def get_provider_available_source_models(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取该 Provider 支持的所有统一模型名(source_model)
|
||||
获取提供商支持的可用源模型
|
||||
|
||||
包括:
|
||||
1. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
获取该提供商支持的所有统一模型名(source_model),包含价格和能力信息。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**返回字段**:
|
||||
- `models`: 可用源模型数组,每项包含:
|
||||
- `global_model_name`: 全局模型名称
|
||||
- `display_name`: 显示名称
|
||||
- `provider_model_name`: 提供商模型名称
|
||||
- `model_id`: 模型 ID
|
||||
- `price`: 价格信息(包含 input_price_per_1m, output_price_per_1m, cache_creation_price_per_1m, cache_read_price_per_1m, price_per_request)
|
||||
- `capabilities`: 能力信息(包含 supports_vision, supports_function_calling, supports_streaming)
|
||||
- `is_active`: 是否启用
|
||||
- `total`: 总数
|
||||
"""
|
||||
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -153,7 +298,27 @@ async def batch_assign_global_models_to_provider(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelsToProviderResponse:
|
||||
"""批量为 Provider 关联 GlobalModels(自动继承价格和能力配置)"""
|
||||
"""
|
||||
批量关联全局模型
|
||||
|
||||
批量为提供商关联全局模型,自动继承全局模型的价格和能力配置。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**:
|
||||
- `global_model_ids`: 全局模型 ID 数组(必填)
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 成功关联的模型数组,每项包含:
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `global_model_name`: 全局模型名称
|
||||
- `model_id`: 新创建的模型 ID
|
||||
- `errors`: 失败的模型数组,每项包含:
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `global_model_name`: 全局模型名称(如果可用)
|
||||
- `error`: 错误信息
|
||||
"""
|
||||
adapter = AdminBatchAssignModelsToProviderAdapter(
|
||||
provider_id=provider_id, payload=payload
|
||||
)
|
||||
@@ -173,10 +338,30 @@ async def import_models_from_upstream(
|
||||
"""
|
||||
从上游提供商导入模型
|
||||
|
||||
流程:
|
||||
从上游提供商导入模型列表。如果全局模型不存在,将自动创建。
|
||||
|
||||
**流程说明**:
|
||||
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
|
||||
2. 如不存在,自动创建新的 GlobalModel(使用默认配置)
|
||||
2. 如不存在,自动创建新的 GlobalModel(使用默认免费配置)
|
||||
3. 创建 Model 关联到当前 Provider
|
||||
4. 如模型已关联,则记录到成功列表中
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**:
|
||||
- `model_ids`: 模型 ID 数组(必填,每个 ID 长度 1-100 字符)
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 成功导入的模型数组,每项包含:
|
||||
- `model_id`: 模型 ID
|
||||
- `global_model_id`: 全局模型 ID
|
||||
- `global_model_name`: 全局模型名称
|
||||
- `provider_model_id`: 提供商模型 ID
|
||||
- `created_global_model`: 是否新创建了全局模型
|
||||
- `errors`: 失败的模型数组,每项包含:
|
||||
- `model_id`: 模型 ID
|
||||
- `error`: 错误信息
|
||||
"""
|
||||
adapter = AdminImportFromUpstreamAdapter(provider_id=provider_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -27,24 +27,114 @@ async def list_providers(
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取提供商列表
|
||||
|
||||
获取所有提供商的基本信息列表,支持分页和状态过滤。
|
||||
|
||||
**查询参数**:
|
||||
- `skip`: 跳过的记录数,用于分页,默认为 0
|
||||
- `limit`: 返回的最大记录数,范围 1-500,默认为 100
|
||||
- `is_active`: 可选的活跃状态过滤,true 仅返回活跃提供商,false 返回禁用提供商,不传则返回全部
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称(唯一标识)
|
||||
- `display_name`: 显示名称
|
||||
- `api_format`: API 格式(如 claude、openai、gemini 等)
|
||||
- `base_url`: API 基础 URL
|
||||
- `api_key`: API 密钥(脱敏显示)
|
||||
- `priority`: 优先级
|
||||
- `is_active`: 是否活跃
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
创建新提供商
|
||||
|
||||
创建一个新的 AI 模型提供商配置。
|
||||
|
||||
**请求体字段**:
|
||||
- `name`: 提供商名称(必填,唯一,用于系统标识)
|
||||
- `display_name`: 显示名称(必填)
|
||||
- `description`: 描述信息(可选)
|
||||
- `website`: 官网地址(可选)
|
||||
- `billing_type`: 计费类型(可选,pay_as_you_go/subscription/prepaid,默认 pay_as_you_go)
|
||||
- `monthly_quota_usd`: 月度配额(美元)(可选)
|
||||
- `quota_reset_day`: 配额重置日期(1-31)(可选)
|
||||
- `quota_last_reset_at`: 上次配额重置时间(可选)
|
||||
- `quota_expires_at`: 配额过期时间(可选)
|
||||
- `rpm_limit`: 每分钟请求数限制(可选)
|
||||
- `provider_priority`: 提供商优先级(数字越小优先级越高,默认 100)
|
||||
- `is_active`: 是否启用(默认 true)
|
||||
- `concurrent_limit`: 并发限制(可选)
|
||||
- `config`: 额外配置信息(JSON,可选)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 新创建的提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `message`: 成功提示信息
|
||||
"""
|
||||
adapter = AdminCreateProviderAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{provider_id}")
|
||||
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
更新提供商配置
|
||||
|
||||
更新指定提供商的配置信息。只需传入需要更新的字段,未传入的字段保持不变。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**(所有字段可选):
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `billing_type`: 计费类型(pay_as_you_go/subscription/prepaid)
|
||||
- `monthly_quota_usd`: 月度配额(美元)
|
||||
- `quota_reset_day`: 配额重置日期(1-31)
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: 每分钟请求数限制
|
||||
- `provider_priority`: 提供商优先级
|
||||
- `is_active`: 是否启用
|
||||
- `concurrent_limit`: 并发限制
|
||||
- `config`: 额外配置信息(JSON)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `is_active`: 是否启用
|
||||
- `message`: 成功提示信息
|
||||
"""
|
||||
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}")
|
||||
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
删除提供商
|
||||
|
||||
删除指定的提供商。注意:此操作会级联删除关联的端点、密钥和模型配置。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 删除成功提示信息
|
||||
"""
|
||||
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -136,7 +226,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
|
||||
rpm_limit=validated_data.rpm_limit,
|
||||
provider_priority=validated_data.provider_priority,
|
||||
is_active=validated_data.is_active,
|
||||
rate_limit=validated_data.rate_limit,
|
||||
concurrent_limit=validated_data.concurrent_limit,
|
||||
config=validated_data.config,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,41 @@ async def get_providers_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderWithEndpointsSummary]:
|
||||
"""获取所有 Providers 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
"""
|
||||
获取所有提供商摘要信息
|
||||
|
||||
获取所有提供商的详细摘要信息,包含端点、密钥、模型统计和健康状态。
|
||||
|
||||
**返回字段**(数组,每项包含):
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
- `is_active`: 是否启用
|
||||
- `billing_type`: 计费类型
|
||||
- `monthly_quota_usd`: 月度配额(美元)
|
||||
- `monthly_used_usd`: 本月已使用金额(美元)
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `total_endpoints`: 端点总数
|
||||
- `active_endpoints`: 活跃端点数
|
||||
- `total_keys`: 密钥总数
|
||||
- `active_keys`: 活跃密钥数
|
||||
- `total_models`: 模型总数
|
||||
- `active_models`: 活跃模型数
|
||||
- `avg_health_score`: 平均健康分数(0-1)
|
||||
- `unhealthy_endpoints`: 不健康端点数(健康分数 < 0.5)
|
||||
- `api_formats`: 支持的 API 格式列表
|
||||
- `endpoint_health_details`: 端点健康详情(包含 api_format, health_score, is_active, active_keys)
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
adapter = AdminProviderSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -51,7 +85,44 @@ async def get_provider_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""获取单个 Provider 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
"""
|
||||
获取单个提供商摘要信息
|
||||
|
||||
获取指定提供商的详细摘要信息,包含端点、密钥、模型统计和健康状态。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 提供商 ID
|
||||
- `name`: 提供商名称
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
- `is_active`: 是否启用
|
||||
- `billing_type`: 计费类型
|
||||
- `monthly_quota_usd`: 月度配额(美元)
|
||||
- `monthly_used_usd`: 本月已使用金额(美元)
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
- `rpm_used`: 已使用 RPM
|
||||
- `rpm_reset_at`: RPM 重置时间
|
||||
- `total_endpoints`: 端点总数
|
||||
- `active_endpoints`: 活跃端点数
|
||||
- `total_keys`: 密钥总数
|
||||
- `active_keys`: 活跃密钥数
|
||||
- `total_models`: 模型总数
|
||||
- `active_models`: 活跃模型数
|
||||
- `avg_health_score`: 平均健康分数(0-1)
|
||||
- `unhealthy_endpoints`: 不健康端点数(健康分数 < 0.5)
|
||||
- `api_formats`: 支持的 API 格式列表
|
||||
- `endpoint_health_details`: 端点健康详情(包含 api_format, health_score, is_active, active_keys)
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {provider_id} not found")
|
||||
@@ -67,7 +138,34 @@ async def get_provider_health_monitor(
|
||||
per_endpoint_limit: int = Query(48, ge=10, le=200, description="每个端点的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointHealthMonitorResponse:
|
||||
"""获取 Provider 下所有端点的健康监控时间线"""
|
||||
"""
|
||||
获取提供商健康监控数据
|
||||
|
||||
获取指定提供商下所有端点的健康监控时间线,包含请求成功率、延迟、错误信息等。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**查询参数**:
|
||||
- `lookback_hours`: 回溯的小时数,范围 1-72,默认为 6
|
||||
- `per_endpoint_limit`: 每个端点返回的事件数量,范围 10-200,默认为 48
|
||||
|
||||
**返回字段**:
|
||||
- `provider_id`: 提供商 ID
|
||||
- `provider_name`: 提供商名称
|
||||
- `generated_at`: 生成时间
|
||||
- `endpoints`: 端点健康监控数据数组,每项包含:
|
||||
- `endpoint_id`: 端点 ID
|
||||
- `api_format`: API 格式
|
||||
- `is_active`: 是否活跃
|
||||
- `total_attempts`: 总请求次数
|
||||
- `success_count`: 成功次数
|
||||
- `failed_count`: 失败次数
|
||||
- `skipped_count`: 跳过次数
|
||||
- `success_rate`: 成功率(0-1)
|
||||
- `last_event_at`: 最后事件时间
|
||||
- `events`: 事件详情数组(包含 timestamp, status, status_code, latency_ms, error_type, error_message)
|
||||
"""
|
||||
|
||||
adapter = AdminProviderHealthMonitorAdapter(
|
||||
provider_id=provider_id,
|
||||
@@ -84,7 +182,29 @@ async def update_provider_settings(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""更新 Provider 基础配置(display_name, description, priority, weight 等)"""
|
||||
"""
|
||||
更新提供商基础配置
|
||||
|
||||
更新提供商的基础配置信息,如显示名称、描述、优先级等。只需传入需要更新的字段。
|
||||
|
||||
**路径参数**:
|
||||
- `provider_id`: 提供商 ID
|
||||
|
||||
**请求体字段**(所有字段可选):
|
||||
- `display_name`: 显示名称
|
||||
- `description`: 描述信息
|
||||
- `website`: 官网地址
|
||||
- `provider_priority`: 优先级
|
||||
- `is_active`: 是否启用
|
||||
- `billing_type`: 计费类型
|
||||
- `monthly_quota_usd`: 月度配额(美元)
|
||||
- `quota_reset_day`: 配额重置日期
|
||||
- `quota_last_reset_at`: 上次配额重置时间
|
||||
- `quota_expires_at`: 配额过期时间
|
||||
- `rpm_limit`: RPM 限制
|
||||
|
||||
**返回字段**: 返回更新后的提供商摘要信息(与 GET /summary 接口返回格式相同)
|
||||
"""
|
||||
|
||||
adapter = AdminUpdateProviderSettingsAdapter(provider_id=provider_id, update_data=update_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -18,7 +18,7 @@ from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
|
||||
router = APIRouter(prefix="/api/admin/security/ip", tags=["IP Security"])
|
||||
router = APIRouter(prefix="/api/admin/security/ip", tags=["Admin - Security"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@@ -56,42 +56,110 @@ class RemoveIPFromWhitelistRequest(BaseModel):
|
||||
|
||||
@router.post("/blacklist")
|
||||
async def add_to_blacklist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Add IP to blacklist"""
|
||||
"""
|
||||
添加 IP 到黑名单
|
||||
|
||||
将指定 IP 地址添加到黑名单,被加入黑名单的 IP 将无法访问系统。需要管理员权限。
|
||||
|
||||
**请求体字段**:
|
||||
- `ip_address`: IP 地址
|
||||
- `reason`: 加入黑名单的原因
|
||||
- `ttl`: 可选,过期时间(秒),不指定表示永久
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 是否成功
|
||||
- `message`: 操作结果信息
|
||||
- `reason`: 加入黑名单的原因
|
||||
- `ttl`: 过期时间(秒或"永久")
|
||||
"""
|
||||
adapter = AddToBlacklistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.delete("/blacklist/{ip_address}")
|
||||
async def remove_from_blacklist(ip_address: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Remove IP from blacklist"""
|
||||
"""
|
||||
从黑名单移除 IP
|
||||
|
||||
将指定 IP 地址从黑名单中移除。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `ip_address`: IP 地址
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 是否成功
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = RemoveFromBlacklistAdapter(ip_address=ip_address)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.get("/blacklist/stats")
|
||||
async def get_blacklist_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""Get blacklist statistics"""
|
||||
"""
|
||||
获取黑名单统计信息
|
||||
|
||||
获取黑名单的统计信息和列表。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `total`: 黑名单总数
|
||||
- `items`: 黑名单列表,每个项包含:
|
||||
- `ip`: IP 地址
|
||||
- `reason`: 加入原因
|
||||
- `added_at`: 添加时间
|
||||
- `ttl`: 剩余有效时间(秒)
|
||||
"""
|
||||
adapter = GetBlacklistStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.post("/whitelist")
|
||||
async def add_to_whitelist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Add IP to whitelist"""
|
||||
"""
|
||||
添加 IP 到白名单
|
||||
|
||||
将指定 IP 地址或 CIDR 网段添加到白名单,白名单中的 IP 将跳过速率限制检查。需要管理员权限。
|
||||
|
||||
**请求体字段**:
|
||||
- `ip_address`: IP 地址或 CIDR 格式(如 192.168.1.0/24)
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 是否成功
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = AddToWhitelistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.delete("/whitelist/{ip_address}")
|
||||
async def remove_from_whitelist(ip_address: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Remove IP from whitelist"""
|
||||
"""
|
||||
从白名单移除 IP
|
||||
|
||||
将指定 IP 地址从白名单中移除。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `ip_address`: IP 地址
|
||||
|
||||
**返回字段**:
|
||||
- `success`: 是否成功
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = RemoveFromWhitelistAdapter(ip_address=ip_address)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@router.get("/whitelist")
|
||||
async def get_whitelist(request: Request, db: Session = Depends(get_db)):
|
||||
"""Get whitelist"""
|
||||
"""
|
||||
获取白名单
|
||||
|
||||
获取当前的 IP 白名单列表。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `whitelist`: 白名单 IP 地址列表
|
||||
- `total`: 白名单总数
|
||||
"""
|
||||
adapter = GetWhitelistAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""系统设置API端点。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,12 +19,68 @@ from src.services.email.email_template import EmailTemplate
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||||
|
||||
|
||||
def _get_version_from_git() -> str | None:
|
||||
"""从 git describe 获取版本号"""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "describe", "--tags", "--always"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
if version.startswith("v"):
|
||||
version = version[1:]
|
||||
return version
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def get_system_version():
|
||||
"""
|
||||
获取系统版本信息
|
||||
|
||||
获取当前系统的版本号。优先从 git describe 获取,回退到静态版本文件。
|
||||
|
||||
**返回字段**:
|
||||
- `version`: 版本号字符串
|
||||
"""
|
||||
# 优先从 git 获取
|
||||
version = _get_version_from_git()
|
||||
if version:
|
||||
return {"version": version}
|
||||
|
||||
# 回退到静态版本文件
|
||||
try:
|
||||
from src._version import __version__
|
||||
|
||||
return {"version": __version__}
|
||||
except ImportError:
|
||||
return {"version": "unknown"}
|
||||
|
||||
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取系统设置(管理员)"""
|
||||
"""
|
||||
获取系统设置
|
||||
|
||||
获取系统的全局设置信息。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `default_provider`: 默认提供商名称
|
||||
- `default_model`: 默认模型名称
|
||||
- `enable_usage_tracking`: 是否启用使用情况追踪
|
||||
"""
|
||||
|
||||
adapter = AdminGetSystemSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -30,7 +88,19 @@ async def get_system_settings(request: Request, db: Session = Depends(get_db)):
|
||||
|
||||
@router.put("/settings")
|
||||
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
|
||||
"""更新系统设置(管理员)"""
|
||||
"""
|
||||
更新系统设置
|
||||
|
||||
更新系统的全局设置。需要管理员权限。
|
||||
|
||||
**请求体字段**:
|
||||
- `default_provider`: 可选,默认提供商名称(空字符串表示清除设置)
|
||||
- `default_model`: 可选,默认模型名称(空字符串表示清除设置)
|
||||
- `enable_usage_tracking`: 可选,是否启用使用情况追踪
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
|
||||
adapter = AdminUpdateSystemSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
@@ -38,7 +108,14 @@ async def update_system_settings(http_request: Request, db: Session = Depends(ge
|
||||
|
||||
@router.get("/configs")
|
||||
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取所有系统配置(管理员)"""
|
||||
"""
|
||||
获取所有系统配置
|
||||
|
||||
获取系统中所有的配置项。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- 配置项的键值对字典
|
||||
"""
|
||||
|
||||
adapter = AdminGetAllConfigsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -46,7 +123,19 @@ async def get_all_system_configs(request: Request, db: Session = Depends(get_db)
|
||||
|
||||
@router.get("/configs/{key}")
|
||||
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""获取特定系统配置(管理员)"""
|
||||
"""
|
||||
获取特定系统配置
|
||||
|
||||
获取指定配置项的值。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `key`: 配置项键名
|
||||
|
||||
**返回字段**:
|
||||
- `key`: 配置项键名
|
||||
- `value`: 配置项的值(敏感配置项不返回实际值)
|
||||
- `is_set`: 可选,对于敏感配置项,指示是否已设置
|
||||
"""
|
||||
|
||||
adapter = AdminGetSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -58,7 +147,24 @@ async def set_system_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""设置系统配置(管理员)"""
|
||||
"""
|
||||
设置系统配置
|
||||
|
||||
设置或更新指定配置项的值。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `key`: 配置项键名
|
||||
|
||||
**请求体字段**:
|
||||
- `value`: 配置项的值
|
||||
- `description`: 可选,配置项描述
|
||||
|
||||
**返回字段**:
|
||||
- `key`: 配置项键名
|
||||
- `value`: 配置项的值(敏感配置项显示为 ********)
|
||||
- `description`: 配置项描述
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
|
||||
adapter = AdminSetSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -66,7 +172,17 @@ async def set_system_config(
|
||||
|
||||
@router.delete("/configs/{key}")
|
||||
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""删除系统配置(管理员)"""
|
||||
"""
|
||||
删除系统配置
|
||||
|
||||
删除指定的配置项。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `key`: 配置项键名
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
|
||||
adapter = AdminDeleteSystemConfigAdapter(key=key)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -74,20 +190,54 @@ async def delete_system_config(key: str, request: Request, db: Session = Depends
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取系统统计信息
|
||||
|
||||
获取系统的整体统计数据。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `users`: 用户统计(total: 总用户数, active: 活跃用户数)
|
||||
- `providers`: 提供商统计(total: 总提供商数, active: 活跃提供商数)
|
||||
- `api_keys`: API Key 总数
|
||||
- `requests`: 请求总数
|
||||
"""
|
||||
adapter = AdminSystemStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
|
||||
"""Manually trigger usage record cleanup task"""
|
||||
"""
|
||||
手动触发清理任务
|
||||
|
||||
手动触发使用记录清理任务,清理过期的请求/响应数据。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
- `stats`: 清理统计信息
|
||||
- `total_records`: 总记录数统计(before, after, deleted)
|
||||
- `body_fields`: 请求/响应体字段清理统计(before, after, cleaned)
|
||||
- `header_fields`: 请求/响应头字段清理统计(before, after, cleaned)
|
||||
- `timestamp`: 清理完成时间
|
||||
"""
|
||||
adapter = AdminTriggerCleanupAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/api-formats")
|
||||
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取所有可用的API格式列表"""
|
||||
"""
|
||||
获取所有可用的 API 格式列表
|
||||
|
||||
获取系统支持的所有 API 格式及其元数据。需要管理员权限。
|
||||
|
||||
**返回字段**:
|
||||
- `formats`: API 格式列表,每个格式包含:
|
||||
- `value`: 格式值
|
||||
- `label`: 显示名称
|
||||
- `default_path`: 默认路径
|
||||
- `aliases`: 别名列表
|
||||
"""
|
||||
adapter = AdminGetApiFormatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -534,7 +684,6 @@ class AdminExportConfigAdapter(AdminApiAdapter):
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"provider_priority": provider.provider_priority,
|
||||
"is_active": provider.is_active,
|
||||
"rate_limit": provider.rate_limit,
|
||||
"concurrent_limit": provider.concurrent_limit,
|
||||
"config": provider.config,
|
||||
"endpoints": endpoints_data,
|
||||
@@ -681,7 +830,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
"provider_priority", 100
|
||||
)
|
||||
existing_provider.is_active = prov_data.get("is_active", True)
|
||||
existing_provider.rate_limit = prov_data.get("rate_limit")
|
||||
existing_provider.concurrent_limit = prov_data.get(
|
||||
"concurrent_limit"
|
||||
)
|
||||
@@ -706,7 +854,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
rpm_limit=prov_data.get("rpm_limit"),
|
||||
provider_priority=prov_data.get("provider_priority", 100),
|
||||
is_active=prov_data.get("is_active", True),
|
||||
rate_limit=prov_data.get("rate_limit"),
|
||||
concurrent_limit=prov_data.get("concurrent_limit"),
|
||||
config=prov_data.get("config"),
|
||||
)
|
||||
@@ -950,6 +1097,30 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
|
||||
db = context.db
|
||||
|
||||
def _serialize_api_key(key: ApiKey, include_is_standalone: bool = False) -> dict:
|
||||
"""序列化 API Key 为导出格式"""
|
||||
data = {
|
||||
"key_hash": key.key_hash,
|
||||
"key_encrypted": key.key_encrypted,
|
||||
"name": key.name,
|
||||
"balance_used_usd": key.balance_used_usd,
|
||||
"current_balance_usd": key.current_balance_usd,
|
||||
"allowed_providers": key.allowed_providers,
|
||||
"allowed_api_formats": key.allowed_api_formats,
|
||||
"allowed_models": key.allowed_models,
|
||||
"rate_limit": key.rate_limit,
|
||||
"concurrent_limit": key.concurrent_limit,
|
||||
"force_capabilities": key.force_capabilities,
|
||||
"is_active": key.is_active,
|
||||
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
|
||||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||
"total_requests": key.total_requests,
|
||||
"total_cost_usd": key.total_cost_usd,
|
||||
}
|
||||
if include_is_standalone:
|
||||
data["is_standalone"] = key.is_standalone
|
||||
return data
|
||||
|
||||
# 导出 Users(排除管理员)
|
||||
users = db.query(User).filter(
|
||||
User.is_deleted.is_(False),
|
||||
@@ -957,31 +1128,12 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
).all()
|
||||
users_data = []
|
||||
for user in users:
|
||||
# 导出用户的 API Keys(保留加密数据)
|
||||
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
|
||||
api_keys_data = []
|
||||
for key in api_keys:
|
||||
api_keys_data.append(
|
||||
{
|
||||
"key_hash": key.key_hash,
|
||||
"key_encrypted": key.key_encrypted,
|
||||
"name": key.name,
|
||||
"is_standalone": key.is_standalone,
|
||||
"balance_used_usd": key.balance_used_usd,
|
||||
"current_balance_usd": key.current_balance_usd,
|
||||
"allowed_providers": key.allowed_providers,
|
||||
"allowed_endpoints": key.allowed_endpoints,
|
||||
"allowed_api_formats": key.allowed_api_formats,
|
||||
"allowed_models": key.allowed_models,
|
||||
"rate_limit": key.rate_limit,
|
||||
"concurrent_limit": key.concurrent_limit,
|
||||
"force_capabilities": key.force_capabilities,
|
||||
"is_active": key.is_active,
|
||||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||
"total_requests": key.total_requests,
|
||||
"total_cost_usd": key.total_cost_usd,
|
||||
}
|
||||
)
|
||||
# 导出用户的 API Keys(排除独立余额Key,独立Key单独导出)
|
||||
api_keys = db.query(ApiKey).filter(
|
||||
ApiKey.user_id == user.id,
|
||||
ApiKey.is_standalone.is_(False)
|
||||
).all()
|
||||
api_keys_data = [_serialize_api_key(key, include_is_standalone=True) for key in api_keys]
|
||||
|
||||
users_data.append(
|
||||
{
|
||||
@@ -990,7 +1142,7 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
"password_hash": user.password_hash,
|
||||
"role": user.role.value if user.role else "user",
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_api_formats": user.allowed_api_formats,
|
||||
"allowed_models": user.allowed_models,
|
||||
"model_capability_settings": user.model_capability_settings,
|
||||
"quota_usd": user.quota_usd,
|
||||
@@ -1001,10 +1153,15 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
}
|
||||
)
|
||||
|
||||
# 导出独立余额 Keys(管理员创建的,不属于普通用户)
|
||||
standalone_keys = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True)).all()
|
||||
standalone_keys_data = [_serialize_api_key(key) for key in standalone_keys]
|
||||
|
||||
return {
|
||||
"version": "1.0",
|
||||
"version": "1.1",
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"users": users_data,
|
||||
"standalone_keys": standalone_keys_data,
|
||||
}
|
||||
|
||||
|
||||
@@ -1024,21 +1181,71 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
# 验证配置版本
|
||||
version = payload.get("version")
|
||||
if version != "1.0":
|
||||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||
|
||||
# 获取导入选项
|
||||
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||||
users_data = payload.get("users", [])
|
||||
standalone_keys_data = payload.get("standalone_keys", [])
|
||||
|
||||
stats = {
|
||||
"users": {"created": 0, "updated": 0, "skipped": 0},
|
||||
"api_keys": {"created": 0, "skipped": 0},
|
||||
"standalone_keys": {"created": 0, "skipped": 0},
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
def _create_api_key_from_data(
|
||||
key_data: dict,
|
||||
owner_id: str,
|
||||
is_standalone: bool = False,
|
||||
) -> tuple[ApiKey | None, str]:
|
||||
"""从导入数据创建 ApiKey 对象
|
||||
|
||||
Returns:
|
||||
(ApiKey, "created"): 成功创建
|
||||
(None, "skipped"): key 已存在,跳过
|
||||
(None, "invalid"): 数据无效,跳过
|
||||
"""
|
||||
key_hash = key_data.get("key_hash", "").strip()
|
||||
if not key_hash:
|
||||
return None, "invalid"
|
||||
|
||||
# 检查是否已存在
|
||||
existing = db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
|
||||
if existing:
|
||||
return None, "skipped"
|
||||
|
||||
# 解析 expires_at
|
||||
expires_at = None
|
||||
if key_data.get("expires_at"):
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(key_data["expires_at"])
|
||||
except ValueError:
|
||||
stats["errors"].append(
|
||||
f"API Key '{key_data.get('name', key_hash[:8])}' 的 expires_at 格式无效"
|
||||
)
|
||||
|
||||
return ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=owner_id,
|
||||
key_hash=key_hash,
|
||||
key_encrypted=key_data.get("key_encrypted"),
|
||||
name=key_data.get("name"),
|
||||
is_standalone=is_standalone or key_data.get("is_standalone", False),
|
||||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||
current_balance_usd=key_data.get("current_balance_usd"),
|
||||
allowed_providers=key_data.get("allowed_providers"),
|
||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
rate_limit=key_data.get("rate_limit"),
|
||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||
force_capabilities=key_data.get("force_capabilities"),
|
||||
is_active=key_data.get("is_active", True),
|
||||
expires_at=expires_at,
|
||||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||
total_requests=key_data.get("total_requests", 0),
|
||||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||
), "created"
|
||||
|
||||
try:
|
||||
for user_data in users_data:
|
||||
# 跳过管理员角色的导入(不区分大小写)
|
||||
@@ -1070,7 +1277,7 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
if user_data.get("role"):
|
||||
existing_user.role = UserRole(user_data["role"])
|
||||
existing_user.allowed_providers = user_data.get("allowed_providers")
|
||||
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
|
||||
existing_user.allowed_api_formats = user_data.get("allowed_api_formats")
|
||||
existing_user.allowed_models = user_data.get("allowed_models")
|
||||
existing_user.model_capability_settings = user_data.get(
|
||||
"model_capability_settings"
|
||||
@@ -1094,7 +1301,7 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
password_hash=user_data.get("password_hash", ""),
|
||||
role=role,
|
||||
allowed_providers=user_data.get("allowed_providers"),
|
||||
allowed_endpoints=user_data.get("allowed_endpoints"),
|
||||
allowed_api_formats=user_data.get("allowed_api_formats"),
|
||||
allowed_models=user_data.get("allowed_models"),
|
||||
model_capability_settings=user_data.get("model_capability_settings"),
|
||||
quota_usd=user_data.get("quota_usd"),
|
||||
@@ -1109,40 +1316,31 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
|
||||
# 导入 API Keys
|
||||
for key_data in user_data.get("api_keys", []):
|
||||
# 检查是否已存在相同的 key_hash
|
||||
if key_data.get("key_hash"):
|
||||
existing_key = (
|
||||
db.query(ApiKey)
|
||||
.filter(ApiKey.key_hash == key_data["key_hash"])
|
||||
.first()
|
||||
)
|
||||
if existing_key:
|
||||
stats["api_keys"]["skipped"] += 1
|
||||
continue
|
||||
new_key, status = _create_api_key_from_data(key_data, user_id)
|
||||
if new_key:
|
||||
db.add(new_key)
|
||||
stats["api_keys"]["created"] += 1
|
||||
elif status == "skipped":
|
||||
stats["api_keys"]["skipped"] += 1
|
||||
# invalid 数据不计入统计
|
||||
|
||||
new_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
key_hash=key_data.get("key_hash", ""),
|
||||
key_encrypted=key_data.get("key_encrypted"),
|
||||
name=key_data.get("name"),
|
||||
is_standalone=key_data.get("is_standalone", False),
|
||||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||
current_balance_usd=key_data.get("current_balance_usd"),
|
||||
allowed_providers=key_data.get("allowed_providers"),
|
||||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
rate_limit=key_data.get("rate_limit"), # None = 无限制
|
||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||
force_capabilities=key_data.get("force_capabilities"),
|
||||
is_active=key_data.get("is_active", True),
|
||||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||
total_requests=key_data.get("total_requests", 0),
|
||||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||
)
|
||||
db.add(new_key)
|
||||
stats["api_keys"]["created"] += 1
|
||||
# 导入独立余额 Keys(需要找一个管理员用户作为 owner)
|
||||
if standalone_keys_data:
|
||||
# 查找一个管理员用户作为独立Key的owner
|
||||
admin_user = db.query(User).filter(User.role == UserRole.ADMIN).first()
|
||||
if not admin_user:
|
||||
stats["errors"].append("无法导入独立余额Key: 系统中没有管理员用户")
|
||||
else:
|
||||
for key_data in standalone_keys_data:
|
||||
new_key, status = _create_api_key_from_data(
|
||||
key_data, admin_user.id, is_standalone=True
|
||||
)
|
||||
if new_key:
|
||||
db.add(new_key)
|
||||
stats["standalone_keys"]["created"] += 1
|
||||
elif status == "skipped":
|
||||
stats["standalone_keys"]["skipped"] += 1
|
||||
# invalid 数据不计入统计
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -39,12 +39,21 @@ async def get_usage_aggregation(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get usage aggregation by specified dimension.
|
||||
获取使用情况聚合统计
|
||||
|
||||
- group_by=model: Aggregate by model
|
||||
- group_by=user: Aggregate by user
|
||||
- group_by=provider: Aggregate by provider
|
||||
- group_by=api_format: Aggregate by API format
|
||||
按指定维度聚合使用情况统计数据。
|
||||
|
||||
**查询参数**:
|
||||
- `group_by`: 必需,聚合维度,可选值:model(按模型)、user(按用户)、provider(按提供商)、api_format(按 API 格式)
|
||||
- `start_date`: 可选,开始日期(ISO 格式)
|
||||
- `end_date`: 可选,结束日期(ISO 格式)
|
||||
- `limit`: 返回数量限制,默认 20,最大 100
|
||||
|
||||
**返回字段**:
|
||||
- 按模型聚合时:model, request_count, total_tokens, total_cost, actual_cost
|
||||
- 按用户聚合时:user_id, email, username, request_count, total_tokens, total_cost
|
||||
- 按提供商聚合时:provider_id, provider, request_count, total_tokens, total_cost, actual_cost, avg_response_time_ms, success_rate, error_count
|
||||
- 按 API 格式聚合时:api_format, request_count, total_tokens, total_cost, actual_cost, avg_response_time_ms
|
||||
"""
|
||||
if group_by == "model":
|
||||
adapter = AdminUsageByModelAdapter(start_date=start_date, end_date=end_date, limit=limit)
|
||||
@@ -69,6 +78,25 @@ async def get_usage_stats(
|
||||
end_date: Optional[datetime] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取使用情况总体统计
|
||||
|
||||
获取指定时间范围内的使用情况总体统计数据。
|
||||
|
||||
**查询参数**:
|
||||
- `start_date`: 可选,开始日期(ISO 格式)
|
||||
- `end_date`: 可选,结束日期(ISO 格式)
|
||||
|
||||
**返回字段**:
|
||||
- `total_requests`: 总请求数
|
||||
- `total_tokens`: 总 token 数
|
||||
- `total_cost`: 总成本(美元)
|
||||
- `total_actual_cost`: 实际总成本(美元)
|
||||
- `avg_response_time`: 平均响应时间(秒)
|
||||
- `error_count`: 错误请求数
|
||||
- `error_rate`: 错误率(百分比)
|
||||
- `cache_stats`: 缓存统计信息(cache_creation_tokens, cache_read_tokens, cache_creation_cost, cache_read_cost)
|
||||
"""
|
||||
adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -79,9 +107,12 @@ async def get_activity_heatmap(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get activity heatmap data for the past 365 days.
|
||||
获取活动热力图数据
|
||||
|
||||
This endpoint is cached for 5 minutes to reduce database load.
|
||||
获取过去 365 天的活动热力图数据。此接口缓存 5 分钟以减少数据库负载。
|
||||
|
||||
**返回字段**:
|
||||
- 按日期聚合的请求数、token 数、成本等统计数据
|
||||
"""
|
||||
adapter = AdminActivityHeatmapAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -92,6 +123,7 @@ async def get_usage_records(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
search: Optional[str] = None, # 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
@@ -101,9 +133,37 @@ async def get_usage_records(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取使用记录列表
|
||||
|
||||
获取详细的使用记录列表,支持多种筛选条件。
|
||||
|
||||
**查询参数**:
|
||||
- `start_date`: 可选,开始日期(ISO 格式)
|
||||
- `end_date`: 可选,结束日期(ISO 格式)
|
||||
- `search`: 可选,通用搜索关键词(支持用户名、密钥名、模型名、提供商名模糊搜索,多个关键词用空格分隔)
|
||||
- `user_id`: 可选,用户 ID 筛选
|
||||
- `username`: 可选,用户名模糊搜索
|
||||
- `model`: 可选,模型名模糊搜索
|
||||
- `provider`: 可选,提供商名称搜索
|
||||
- `status`: 可选,状态筛选(stream: 流式请求,standard: 标准请求,error: 错误请求,pending: 等待中,streaming: 流式中,completed: 已完成,failed: 失败,active: 活跃请求)
|
||||
- `limit`: 返回数量限制,默认 100,最大 500
|
||||
- `offset`: 分页偏移量,默认 0
|
||||
|
||||
**返回字段**:
|
||||
- `records`: 使用记录列表,包含 id, user_id, user_email, username, api_key, provider, model, target_model,
|
||||
input_tokens, output_tokens, cache_creation_input_tokens, cache_read_input_tokens, total_tokens,
|
||||
cost, actual_cost, rate_multiplier, response_time_ms, first_byte_time_ms, created_at, is_stream,
|
||||
input_price_per_1m, output_price_per_1m, cache_creation_price_per_1m, cache_read_price_per_1m,
|
||||
status_code, error_message, status, has_fallback, api_format, api_key_name, request_metadata
|
||||
- `total`: 符合条件的总记录数
|
||||
- `limit`: 当前分页限制
|
||||
- `offset`: 当前分页偏移量
|
||||
"""
|
||||
adapter = AdminUsageRecordsAdapter(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
search=search,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
model=model,
|
||||
@@ -122,10 +182,19 @@ async def get_active_requests(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取活跃请求的状态(轻量级接口,用于前端轮询)
|
||||
获取活跃请求的状态
|
||||
|
||||
获取当前活跃(pending/streaming 状态)请求的状态信息。这是一个轻量级接口,适合前端轮询。
|
||||
|
||||
**查询参数**:
|
||||
- `ids`: 可选,逗号分隔的请求 ID 列表,用于查询特定请求的状态
|
||||
|
||||
**行为说明**:
|
||||
- 如果提供 ids 参数,只返回这些 ID 对应请求的最新状态
|
||||
- 如果不提供 ids,返回所有 pending/streaming 状态的请求
|
||||
|
||||
**返回字段**:
|
||||
- `requests`: 活跃请求列表,包含请求状态信息
|
||||
"""
|
||||
adapter = AdminActiveRequestsAdapter(ids=ids)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -140,9 +209,48 @@ async def get_usage_detail(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get detailed information of a specific usage record.
|
||||
获取使用记录详情
|
||||
|
||||
Includes request/response headers and body.
|
||||
获取指定使用记录的详细信息,包括请求/响应的头部和正文。
|
||||
|
||||
**路径参数**:
|
||||
- `usage_id`: 使用记录 ID
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 记录 ID
|
||||
- `request_id`: 请求 ID
|
||||
- `user`: 用户信息(id, username, email)
|
||||
- `api_key`: API Key 信息(id, name, display)
|
||||
- `provider`: 提供商名称
|
||||
- `api_format`: API 格式
|
||||
- `model`: 请求的模型名称
|
||||
- `target_model`: 映射后的目标模型名称
|
||||
- `tokens`: Token 统计(input, output, total)
|
||||
- `cost`: 成本统计(input, output, total)
|
||||
- `cache_creation_input_tokens`: 缓存创建输入 token 数
|
||||
- `cache_read_input_tokens`: 缓存读取输入 token 数
|
||||
- `cache_creation_cost`: 缓存创建成本
|
||||
- `cache_read_cost`: 缓存读取成本
|
||||
- `request_cost`: 请求成本
|
||||
- `input_price_per_1m`: 输入价格(每百万 token)
|
||||
- `output_price_per_1m`: 输出价格(每百万 token)
|
||||
- `cache_creation_price_per_1m`: 缓存创建价格(每百万 token)
|
||||
- `cache_read_price_per_1m`: 缓存读取价格(每百万 token)
|
||||
- `price_per_request`: 每请求价格
|
||||
- `request_type`: 请求类型
|
||||
- `is_stream`: 是否为流式请求
|
||||
- `status_code`: HTTP 状态码
|
||||
- `error_message`: 错误信息
|
||||
- `response_time_ms`: 响应时间(毫秒)
|
||||
- `first_byte_time_ms`: 首字节时间(TTFB,毫秒)
|
||||
- `created_at`: 创建时间
|
||||
- `request_headers`: 请求头
|
||||
- `request_body`: 请求体
|
||||
- `provider_request_headers`: 提供商请求头
|
||||
- `response_headers`: 响应头
|
||||
- `response_body`: 响应体
|
||||
- `metadata`: 提供商响应元数据
|
||||
- `tiered_pricing`: 阶梯计费信息(如适用)
|
||||
"""
|
||||
adapter = AdminUsageDetailAdapter(usage_id=usage_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -245,8 +353,8 @@ class AdminUsageByModelAdapter(AdminApiAdapter):
|
||||
)
|
||||
# 过滤掉 pending/streaming 状态的请求(尚未完成的请求不应计入统计)
|
||||
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
|
||||
# 过滤掉 unknown/pending provider(请求未到达任何提供商)
|
||||
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
|
||||
# 过滤掉 unknown/pending provider_name(请求未到达任何提供商)
|
||||
query = query.filter(Usage.provider_name.notin_(["unknown", "pending"]))
|
||||
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
@@ -457,8 +565,8 @@ class AdminUsageByApiFormatAdapter(AdminApiAdapter):
|
||||
)
|
||||
# 过滤掉 pending/streaming 状态的请求
|
||||
query = query.filter(Usage.status.notin_(["pending", "streaming"]))
|
||||
# 过滤掉 unknown/pending provider
|
||||
query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
|
||||
# 过滤掉 unknown/pending provider_name
|
||||
query = query.filter(Usage.provider_name.notin_(["unknown", "pending"]))
|
||||
# 只统计有 api_format 的记录
|
||||
query = query.filter(Usage.api_format.isnot(None))
|
||||
|
||||
@@ -500,6 +608,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
self,
|
||||
start_date: Optional[datetime],
|
||||
end_date: Optional[datetime],
|
||||
search: Optional[str],
|
||||
user_id: Optional[str],
|
||||
username: Optional[str],
|
||||
model: Optional[str],
|
||||
@@ -510,6 +619,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.search = search
|
||||
self.user_id = user_id
|
||||
self.username = username
|
||||
self.model = model
|
||||
@@ -519,25 +629,54 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
self.offset = offset
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import or_
|
||||
|
||||
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||
|
||||
db = context.db
|
||||
query = (
|
||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
|
||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey, ApiKey)
|
||||
.outerjoin(User, Usage.user_id == User.id)
|
||||
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||
)
|
||||
|
||||
# 如果需要按 Provider 名称搜索/筛选,统一在这里 JOIN
|
||||
if self.search or self.provider:
|
||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||
|
||||
# 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
# 限制:最多 10 个关键词,转义后每个关键词最长 100 字符
|
||||
if self.search:
|
||||
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||
for keyword in keywords:
|
||||
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||
search_pattern = f"%{escaped}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.username.ilike(search_pattern, escape="\\"),
|
||||
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||
Usage.model.ilike(search_pattern, escape="\\"),
|
||||
Provider.name.ilike(search_pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_id:
|
||||
query = query.filter(Usage.user_id == self.user_id)
|
||||
if self.username:
|
||||
# 支持用户名模糊搜索
|
||||
query = query.filter(User.username.ilike(f"%{self.username}%"))
|
||||
escaped = escape_like_pattern(self.username)
|
||||
query = query.filter(User.username.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.model:
|
||||
# 支持模型名模糊搜索
|
||||
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
|
||||
escaped = escape_like_pattern(self.model)
|
||||
query = query.filter(Usage.model.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.provider:
|
||||
# 支持提供商名称搜索(通过 Provider 表)
|
||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
|
||||
# 支持提供商名称搜索
|
||||
escaped = escape_like_pattern(self.provider)
|
||||
query = query.filter(Provider.name.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.status:
|
||||
# 状态筛选
|
||||
# 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error
|
||||
@@ -575,7 +714,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||
)
|
||||
|
||||
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
|
||||
request_ids = [usage.request_id for usage, _, _, _, _ in records if usage.request_id]
|
||||
fallback_map = {}
|
||||
if request_ids:
|
||||
# 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available
|
||||
@@ -595,6 +734,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
action="usage_records",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
search=self.search,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
model=self.model,
|
||||
@@ -606,7 +746,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
)
|
||||
|
||||
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
|
||||
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
|
||||
provider_ids = [usage.provider_id for usage, _, _, _, _ in records if usage.provider_id]
|
||||
provider_map = {}
|
||||
if provider_ids:
|
||||
providers_data = (
|
||||
@@ -615,7 +755,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
provider_map = {str(p.id): p.name for p in providers_data}
|
||||
|
||||
data = []
|
||||
for usage, user, endpoint, api_key in records:
|
||||
for usage, user, endpoint, provider_api_key, user_api_key in records:
|
||||
actual_cost = (
|
||||
float(usage.actual_total_cost_usd)
|
||||
if usage.actual_total_cost_usd is not None
|
||||
@@ -625,8 +765,8 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
|
||||
)
|
||||
|
||||
# 提供商名称优先级:关联的 Provider 表 > usage.provider 字段
|
||||
provider_name = usage.provider
|
||||
# 提供商名称优先级:关联的 Provider 表 > usage.provider_name 字段
|
||||
provider_name = usage.provider_name
|
||||
if usage.provider_id and str(usage.provider_id) in provider_map:
|
||||
provider_name = provider_map[str(usage.provider_id)]
|
||||
|
||||
@@ -636,6 +776,15 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"user_id": user.id if user else None,
|
||||
"user_email": user.email if user else "已删除用户",
|
||||
"username": user.username if user else "已删除用户",
|
||||
"api_key": (
|
||||
{
|
||||
"id": user_api_key.id,
|
||||
"name": user_api_key.name,
|
||||
"display": user_api_key.get_display_key(),
|
||||
}
|
||||
if user_api_key
|
||||
else None
|
||||
),
|
||||
"provider": provider_name,
|
||||
"model": usage.model,
|
||||
"target_model": usage.target_model, # 映射后的目标模型名
|
||||
@@ -661,7 +810,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"has_fallback": fallback_map.get(usage.request_id, False),
|
||||
"api_format": usage.api_format
|
||||
or (endpoint.api_format if endpoint and endpoint.api_format else None),
|
||||
"api_key_name": api_key.name if api_key else None,
|
||||
"api_key_name": provider_api_key.name if provider_api_key else None,
|
||||
"request_metadata": usage.request_metadata, # Provider 响应元数据
|
||||
}
|
||||
)
|
||||
@@ -732,7 +881,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
"name": api_key.name if api_key else None,
|
||||
"display": api_key.get_display_key() if api_key else None,
|
||||
},
|
||||
"provider": usage_record.provider,
|
||||
"provider": usage_record.provider_name,
|
||||
"api_format": usage_record.api_format,
|
||||
"model": usage_record.model,
|
||||
"target_model": usage_record.target_model,
|
||||
@@ -785,7 +934,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
# 尝试获取模型的阶梯配置(带来源信息)
|
||||
cost_service = ModelCostService(db)
|
||||
pricing_result = await cost_service.get_tiered_pricing_with_source_async(
|
||||
usage_record.provider, usage_record.model
|
||||
usage_record.provider_name, usage_record.model
|
||||
)
|
||||
|
||||
if not pricing_result:
|
||||
|
||||
@@ -26,6 +26,18 @@ pipeline = ApiRequestPipeline()
|
||||
# 管理员端点
|
||||
@router.post("")
|
||||
async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
创建用户
|
||||
|
||||
创建新用户账号(管理员专用)。
|
||||
|
||||
**请求体**:
|
||||
- `email`: 邮箱地址
|
||||
- `username`: 用户名
|
||||
- `password`: 密码
|
||||
- `role`: 角色(user/admin)
|
||||
- `quota_usd`: 配额(USD)
|
||||
"""
|
||||
adapter = AdminCreateUserAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -33,18 +45,33 @@ async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
|
||||
@router.get("")
|
||||
async def list_users(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
role: Optional[str] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
|
||||
role: Optional[str] = Query(None, description="按角色筛选(user/admin)"),
|
||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取用户列表
|
||||
|
||||
分页获取用户列表,支持按角色和状态筛选。
|
||||
|
||||
**返回字段**: id, email, username, role, quota_usd, used_usd, is_active, created_at 等
|
||||
"""
|
||||
adapter = AdminListUsersAdapter(skip=skip, limit=limit, role=role, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{user_id}")
|
||||
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
|
||||
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取用户详情
|
||||
|
||||
获取指定用户的详细信息。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
"""
|
||||
adapter = AdminGetUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -55,19 +82,51 @@ async def update_user(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
更新指定用户的信息,包括角色、配额、权限等。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
|
||||
**请求体** (均为可选):
|
||||
- `email`: 邮箱地址
|
||||
- `username`: 用户名
|
||||
- `role`: 角色
|
||||
- `quota_usd`: 配额
|
||||
- `is_active`: 是否启用
|
||||
- `allowed_providers`: 允许的提供商列表
|
||||
- `allowed_models`: 允许的模型列表
|
||||
"""
|
||||
adapter = AdminUpdateUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)): # UUID
|
||||
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
删除用户
|
||||
|
||||
永久删除指定用户。不能删除最后一个管理员账户。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
"""
|
||||
adapter = AdminDeleteUserAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{user_id}/quota")
|
||||
async def reset_user_quota(user_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Reset user quota (set used_usd to 0)"""
|
||||
"""
|
||||
重置用户配额
|
||||
|
||||
将用户的已用配额(used_usd)重置为 0。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
"""
|
||||
adapter = AdminResetUserQuotaAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -76,10 +135,17 @@ async def reset_user_quota(user_id: str, request: Request, db: Session = Depends
|
||||
async def get_user_api_keys(
|
||||
user_id: str,
|
||||
request: Request,
|
||||
is_active: Optional[bool] = None,
|
||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户的所有API Keys(不包括独立Keys)"""
|
||||
"""
|
||||
获取用户的 API 密钥列表
|
||||
|
||||
获取指定用户的所有 API 密钥(不包括独立密钥)。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
"""
|
||||
adapter = AdminGetUserKeysAdapter(user_id=user_id, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -90,7 +156,23 @@ async def create_user_api_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""为用户创建API Key"""
|
||||
"""
|
||||
为用户创建 API 密钥
|
||||
|
||||
为指定用户创建新的 API 密钥。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
|
||||
**请求体**:
|
||||
- `name`: 密钥名称
|
||||
- `allowed_providers`: 允许的提供商(可选)
|
||||
- `allowed_models`: 允许的模型(可选)
|
||||
- `rate_limit`: 速率限制(可选)
|
||||
- `expire_days`: 过期天数(可选)
|
||||
|
||||
**返回**: 包含完整密钥值的响应(仅此一次显示)
|
||||
"""
|
||||
adapter = AdminCreateUserKeyAdapter(user_id=user_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -102,7 +184,15 @@ async def delete_user_api_key(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除用户的API Key"""
|
||||
"""
|
||||
删除用户的 API 密钥
|
||||
|
||||
删除指定用户的指定 API 密钥。
|
||||
|
||||
**路径参数**:
|
||||
- `user_id`: 用户 ID (UUID)
|
||||
- `key_id`: 密钥 ID
|
||||
"""
|
||||
adapter = AdminDeleteUserKeyAdapter(user_id=user_id, key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -156,7 +246,7 @@ class AdminCreateUserAdapter(AdminApiAdapter):
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_api_formats": user.allowed_api_formats,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
@@ -183,6 +273,9 @@ class AdminListUsersAdapter(AdminApiAdapter):
|
||||
"email": u.email,
|
||||
"username": u.username,
|
||||
"role": u.role.value,
|
||||
"allowed_providers": u.allowed_providers,
|
||||
"allowed_api_formats": u.allowed_api_formats,
|
||||
"allowed_models": u.allowed_models,
|
||||
"quota_usd": u.quota_usd,
|
||||
"used_usd": u.used_usd,
|
||||
"total_usd": getattr(u, "total_usd", 0),
|
||||
@@ -216,7 +309,7 @@ class AdminGetUserAdapter(AdminApiAdapter):
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_api_formats": user.allowed_api_formats,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
@@ -282,7 +375,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
|
||||
"username": user.username,
|
||||
"role": user.role.value,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_api_formats": user.allowed_api_formats,
|
||||
"allowed_models": user.allowed_models,
|
||||
"quota_usd": user.quota_usd,
|
||||
"used_usd": user.used_usd,
|
||||
|
||||
@@ -35,7 +35,32 @@ async def list_announcements(
|
||||
offset: int = Query(0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取公告列表(包含已读状态)"""
|
||||
"""
|
||||
获取公告列表
|
||||
|
||||
获取公告列表,支持分页和筛选。如果用户已登录,返回包含已读状态。
|
||||
|
||||
**查询参数**:
|
||||
- `active_only`: 是否只返回有效公告,默认 true
|
||||
- `limit`: 返回数量限制,默认 50
|
||||
- `offset`: 分页偏移量,默认 0
|
||||
|
||||
**返回字段**:
|
||||
- `items`: 公告列表,每条公告包含:
|
||||
- `id`: 公告 ID
|
||||
- `title`: 标题
|
||||
- `content`: 内容
|
||||
- `type`: 类型(info/warning/error/success)
|
||||
- `priority`: 优先级
|
||||
- `is_pinned`: 是否置顶
|
||||
- `is_read`: 是否已读(仅登录用户)
|
||||
- `author`: 作者信息
|
||||
- `start_time`: 生效开始时间
|
||||
- `end_time`: 生效结束时间
|
||||
- `created_at`: 创建时间
|
||||
- `total`: 总数
|
||||
- `unread_count`: 未读数量(仅登录用户)
|
||||
"""
|
||||
adapter = ListAnnouncementsAdapter(active_only=active_only, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -45,7 +70,16 @@ async def get_active_announcements(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前有效的公告(首页展示)"""
|
||||
"""
|
||||
获取当前有效的公告
|
||||
|
||||
获取当前时间范围内有效的公告列表,用于首页展示。
|
||||
|
||||
**返回字段**:
|
||||
- `items`: 有效公告列表
|
||||
- `total`: 有效公告总数
|
||||
- `unread_count`: 未读数量(仅登录用户)
|
||||
"""
|
||||
adapter = GetActiveAnnouncementsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -56,7 +90,27 @@ async def get_announcement(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取单个公告详情"""
|
||||
"""
|
||||
获取单个公告详情
|
||||
|
||||
获取指定公告的详细信息。
|
||||
|
||||
**路径参数**:
|
||||
- `announcement_id`: 公告 ID(UUID)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 公告 ID
|
||||
- `title`: 标题
|
||||
- `content`: 内容
|
||||
- `type`: 类型(info/warning/error/success)
|
||||
- `priority`: 优先级
|
||||
- `is_pinned`: 是否置顶
|
||||
- `author`: 作者信息(id, username)
|
||||
- `start_time`: 生效开始时间
|
||||
- `end_time`: 生效结束时间
|
||||
- `created_at`: 创建时间
|
||||
- `updated_at`: 更新时间
|
||||
"""
|
||||
adapter = GetAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -67,7 +121,17 @@ async def mark_announcement_as_read(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Mark announcement as read"""
|
||||
"""
|
||||
标记公告为已读
|
||||
|
||||
将指定公告标记为当前用户已读。需要登录。
|
||||
|
||||
**路径参数**:
|
||||
- `announcement_id`: 公告 ID(UUID)
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = MarkAnnouncementReadAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -80,7 +144,25 @@ async def create_announcement(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建公告(管理员权限)"""
|
||||
"""
|
||||
创建公告
|
||||
|
||||
创建新的系统公告。需要管理员权限。
|
||||
|
||||
**请求体字段**:
|
||||
- `title`: 公告标题(必填)
|
||||
- `content`: 公告内容(必填)
|
||||
- `type`: 公告类型(info/warning/error/success),默认 info
|
||||
- `priority`: 优先级(0-100),默认 0
|
||||
- `is_pinned`: 是否置顶,默认 false
|
||||
- `start_time`: 生效开始时间(可选)
|
||||
- `end_time`: 生效结束时间(可选)
|
||||
|
||||
**返回字段**:
|
||||
- `id`: 新创建的公告 ID
|
||||
- `title`: 公告标题
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = CreateAnnouncementAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -91,7 +173,27 @@ async def update_announcement(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新公告(管理员权限)"""
|
||||
"""
|
||||
更新公告
|
||||
|
||||
更新指定公告的信息。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `announcement_id`: 公告 ID(UUID)
|
||||
|
||||
**请求体字段(均为可选)**:
|
||||
- `title`: 公告标题
|
||||
- `content`: 公告内容
|
||||
- `type`: 公告类型(info/warning/error/success)
|
||||
- `priority`: 优先级(0-100)
|
||||
- `is_active`: 是否启用
|
||||
- `is_pinned`: 是否置顶
|
||||
- `start_time`: 生效开始时间
|
||||
- `end_time`: 生效结束时间
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = UpdateAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -102,7 +204,17 @@ async def delete_announcement(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除公告(管理员权限)"""
|
||||
"""
|
||||
删除公告
|
||||
|
||||
删除指定的公告。需要管理员权限。
|
||||
|
||||
**路径参数**:
|
||||
- `announcement_id`: 公告 ID(UUID)
|
||||
|
||||
**返回字段**:
|
||||
- `message`: 操作结果信息
|
||||
"""
|
||||
adapter = DeleteAnnouncementAdapter(announcement_id=announcement_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -115,7 +227,14 @@ async def get_my_unread_announcement_count(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取我的未读公告数量"""
|
||||
"""
|
||||
获取我的未读公告数量
|
||||
|
||||
获取当前用户的未读公告数量。需要登录。
|
||||
|
||||
**返回字段**:
|
||||
- `unread_count`: 未读公告数量
|
||||
"""
|
||||
adapter = UnreadAnnouncementCountAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from src.models.api import (
|
||||
)
|
||||
from src.models.database import AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.system.config import SystemConfigService
|
||||
@@ -94,65 +95,142 @@ pipeline = ApiRequestPipeline()
|
||||
# API端点
|
||||
@router.get("/registration-settings", response_model=RegistrationSettingsResponse)
|
||||
async def registration_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""公开获取注册相关配置"""
|
||||
"""
|
||||
获取注册相关配置
|
||||
|
||||
返回系统注册配置,包括是否开放注册、是否需要邮箱验证等。
|
||||
此接口为公开接口,无需认证。
|
||||
"""
|
||||
adapter = AuthRegistrationSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def auth_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取认证设置
|
||||
|
||||
返回系统支持的认证方式,如本地认证、LDAP 认证等。
|
||||
前端据此判断显示哪些登录选项。此接口为公开接口,无需认证。
|
||||
"""
|
||||
adapter = AuthSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
使用邮箱和密码登录,成功后返回 JWT access_token 和 refresh_token。
|
||||
|
||||
- **access_token**: 用于后续 API 调用,有效期 24 小时
|
||||
- **refresh_token**: 用于刷新 access_token
|
||||
|
||||
速率限制: 5次/分钟/IP
|
||||
"""
|
||||
adapter = AuthLoginAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=RefreshTokenResponse)
|
||||
async def refresh_token(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
使用 refresh_token 获取新的 access_token 和 refresh_token。
|
||||
原 refresh_token 刷新后失效。
|
||||
"""
|
||||
adapter = AuthRefreshAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/register", response_model=RegisterResponse)
|
||||
async def register(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
用户注册
|
||||
|
||||
创建新用户账号。需要系统开放注册功能。
|
||||
如果系统开启了邮箱验证,需先通过 /send-verification-code 和 /verify-email 完成邮箱验证。
|
||||
|
||||
速率限制: 3次/分钟/IP
|
||||
"""
|
||||
adapter = AuthRegisterAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取当前用户信息
|
||||
|
||||
返回当前登录用户的基本信息,包括邮箱、用户名、角色、配额等。
|
||||
需要 Bearer Token 认证。
|
||||
"""
|
||||
adapter = AuthCurrentUserAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/password")
|
||||
async def change_password(request: Request, db: Session = Depends(get_db)):
|
||||
"""Change current user's password"""
|
||||
"""
|
||||
修改密码
|
||||
|
||||
修改当前用户的登录密码,需提供旧密码验证。
|
||||
密码长度至少 6 位。
|
||||
"""
|
||||
adapter = AuthChangePasswordAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=LogoutResponse)
|
||||
async def logout(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
用户登出
|
||||
|
||||
将当前 Token 加入黑名单,使其失效。
|
||||
"""
|
||||
adapter = AuthLogoutAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/send-verification-code", response_model=SendVerificationCodeResponse)
|
||||
async def send_verification_code(request: Request, db: Session = Depends(get_db)):
|
||||
"""发送邮箱验证码"""
|
||||
"""
|
||||
发送邮箱验证码
|
||||
|
||||
向指定邮箱发送验证码,用于注册前的邮箱验证。
|
||||
验证码有效期 5 分钟,同一邮箱 60 秒内只能发送一次。
|
||||
|
||||
速率限制: 3次/分钟/IP
|
||||
"""
|
||||
adapter = AuthSendVerificationCodeAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/verify-email", response_model=VerifyEmailResponse)
|
||||
async def verify_email(request: Request, db: Session = Depends(get_db)):
|
||||
"""验证邮箱验证码"""
|
||||
"""
|
||||
验证邮箱验证码
|
||||
|
||||
验证邮箱收到的验证码是否正确。
|
||||
验证成功后,邮箱会被标记为已验证状态,可用于注册。
|
||||
|
||||
速率限制: 10次/分钟/IP
|
||||
"""
|
||||
adapter = AuthVerifyEmailAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/verification-status", response_model=VerificationStatusResponse)
|
||||
async def verification_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""查询邮箱验证状态"""
|
||||
"""
|
||||
查询邮箱验证状态
|
||||
|
||||
查询指定邮箱的验证状态,包括是否有待验证的验证码、是否已验证等。
|
||||
|
||||
速率限制: 20次/分钟/IP
|
||||
"""
|
||||
adapter = AuthVerificationStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -193,7 +271,9 @@ class AuthLoginAdapter(AuthPublicAdapter):
|
||||
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_request.email, login_request.password, login_request.auth_type
|
||||
)
|
||||
if not user:
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
@@ -305,6 +385,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter):
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AuthSettingsAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""公开返回认证设置"""
|
||||
db = context.db
|
||||
|
||||
ldap_enabled = LDAPService.is_ldap_enabled(db)
|
||||
ldap_exclusive = LDAPService.is_ldap_exclusive(db)
|
||||
|
||||
return {
|
||||
"local_enabled": not ldap_exclusive,
|
||||
"ldap_enabled": ldap_enabled,
|
||||
"ldap_exclusive": ldap_exclusive,
|
||||
}
|
||||
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.models.database import SystemConfig
|
||||
@@ -324,6 +419,12 @@ class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
# 仅允许 LDAP 登录时拒绝本地注册
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="系统已启用 LDAP 专属登录,禁止本地注册"
|
||||
)
|
||||
|
||||
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
|
||||
if allow_registration and not allow_registration.value:
|
||||
AuditService.log_event(
|
||||
@@ -427,7 +528,7 @@ class AuthCurrentUserAdapter(AuthenticatedApiAdapter):
|
||||
"used_usd": user.used_usd,
|
||||
"total_usd": user.total_usd,
|
||||
"allowed_providers": user.allowed_providers,
|
||||
"allowed_endpoints": user.allowed_endpoints,
|
||||
"allowed_api_formats": user.allowed_api_formats,
|
||||
"allowed_models": user.allowed_models,
|
||||
"created_at": user.created_at.isoformat(),
|
||||
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
|
||||
|
||||
@@ -15,6 +15,7 @@ class ApiMode(str, Enum):
|
||||
ADMIN = "admin"
|
||||
USER = "user" # JWT 认证的普通用户(不要求管理员权限)
|
||||
PUBLIC = "public"
|
||||
MANAGEMENT = "management" # Management Token 认证
|
||||
|
||||
|
||||
class ApiAdapter(ABC):
|
||||
|
||||
@@ -10,7 +10,8 @@ from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, User
|
||||
from src.models.database import ApiKey, ManagementToken, User
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +38,9 @@ class ApiRequestContext:
|
||||
# URL 路径参数(如 Gemini API 的 /v1beta/models/{model}:generateContent)
|
||||
path_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Management Token(用于管理 API 认证)
|
||||
management_token: Optional[ManagementToken] = None
|
||||
|
||||
# 供适配器扩展的状态存储
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
audit_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
@@ -86,7 +90,7 @@ class ApiRequestContext:
|
||||
setattr(request.state, "request_id", request_id)
|
||||
|
||||
start_time = time.time()
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
client_ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
context = cls(
|
||||
|
||||
@@ -143,12 +143,13 @@ class AccessRestrictions:
|
||||
allowed_api_formats = api_key.allowed_api_formats
|
||||
|
||||
# 如果 API Key 没有限制,检查 User 的限制
|
||||
# 注意: User 没有 allowed_api_formats 字段
|
||||
if user:
|
||||
if allowed_providers is None and user.allowed_providers is not None:
|
||||
allowed_providers = user.allowed_providers
|
||||
if allowed_models is None and user.allowed_models is not None:
|
||||
allowed_models = user.allowed_models
|
||||
if allowed_api_formats is None and user.allowed_api_formats is not None:
|
||||
allowed_api_formats = user.allowed_api_formats
|
||||
|
||||
return cls(
|
||||
allowed_providers=allowed_providers,
|
||||
|
||||
@@ -2,18 +2,23 @@ from __future__ import annotations
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.settings import config
|
||||
from src.core.enums import UserRole
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.models.database import ApiKey, AuditEventType, User
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import ManagementToken
|
||||
|
||||
from .adapter import ApiAdapter, ApiMode
|
||||
from .context import ApiRequestContext
|
||||
|
||||
@@ -46,17 +51,22 @@ class ApiRequestPipeline:
|
||||
logger.debug(f"[Pipeline] Running with mode={mode}, adapter={adapter.__class__.__name__}, "
|
||||
f"adapter.mode={adapter.mode}, path={http_request.url.path}")
|
||||
if mode == ApiMode.ADMIN:
|
||||
user = await self._authenticate_admin(http_request, db)
|
||||
user, management_token = await self._authenticate_admin(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.USER:
|
||||
user = await self._authenticate_user(http_request, db)
|
||||
user, management_token = await self._authenticate_user(http_request, db)
|
||||
api_key = None
|
||||
elif mode == ApiMode.PUBLIC:
|
||||
user = None
|
||||
api_key = None
|
||||
management_token = None
|
||||
elif mode == ApiMode.MANAGEMENT:
|
||||
user, management_token = await self._authenticate_management(http_request, db)
|
||||
api_key = None
|
||||
else:
|
||||
logger.debug("[Pipeline] 调用 _authenticate_client")
|
||||
user, api_key = self._authenticate_client(http_request, db, adapter)
|
||||
management_token = None
|
||||
logger.debug(f"[Pipeline] 认证完成 | user={user.username if user else None}")
|
||||
|
||||
raw_body = None
|
||||
@@ -64,13 +74,17 @@ class ApiRequestPipeline:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# 添加30秒超时防止卡死
|
||||
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
|
||||
# 添加超时防止卡死
|
||||
raw_body = await asyncio.wait_for(
|
||||
http_request.body(), timeout=config.request_body_timeout
|
||||
)
|
||||
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
|
||||
timeout_sec = int(config.request_body_timeout)
|
||||
logger.error(f"读取请求体超时({timeout_sec}s),可能客户端未发送完整请求体")
|
||||
raise HTTPException(
|
||||
status_code=408, detail="Request timeout: body not received within 30 seconds"
|
||||
status_code=408,
|
||||
detail=f"Request timeout: body not received within {timeout_sec} seconds",
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
|
||||
@@ -85,6 +99,9 @@ class ApiRequestPipeline:
|
||||
api_format_hint=api_format_hint,
|
||||
path_params=path_params,
|
||||
)
|
||||
# 存储 management_token 到 context(用于权限检查)
|
||||
if management_token:
|
||||
context.management_token = management_token
|
||||
logger.debug(f"[Pipeline] Context构建完成 | adapter={adapter.name} | request_id={context.request_id}")
|
||||
|
||||
if mode != ApiMode.ADMIN and user:
|
||||
@@ -172,12 +189,41 @@ class ApiRequestPipeline:
|
||||
|
||||
return user, api_key
|
||||
|
||||
async def _authenticate_admin(self, request: Request, db: Session) -> User:
|
||||
async def _authenticate_admin(
|
||||
self, request: Request, db: Session
|
||||
) -> Tuple[User, Optional["ManagementToken"]]:
|
||||
"""管理员认证,支持 JWT 和 Management Token 两种方式"""
|
||||
from src.models.database import ManagementToken
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||
|
||||
token = authorization[7:].strip()
|
||||
|
||||
# 检查是否为 Management Token(ae_ 前缀)
|
||||
if token.startswith(ManagementToken.TOKEN_PREFIX):
|
||||
client_ip = get_client_ip(request)
|
||||
result = await self.auth_service.authenticate_management_token(db, token, client_ip)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="无效或过期的 Management Token")
|
||||
|
||||
user, management_token = result
|
||||
|
||||
# 检查管理员权限
|
||||
if user.role != UserRole.ADMIN:
|
||||
logger.warning(f"非管理员尝试通过 Management Token 访问管理端点: {user.email}")
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
|
||||
# 存储到 request.state
|
||||
request.state.user_id = user.id
|
||||
request.state.management_token_id = management_token.id
|
||||
|
||||
return user, management_token
|
||||
|
||||
# JWT 认证
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
@@ -195,16 +241,43 @@ class ApiRequestPipeline:
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
# 检查管理员权限
|
||||
if user.role != UserRole.ADMIN:
|
||||
logger.warning(f"非管理员尝试通过 JWT 访问管理端点: {user.email}")
|
||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user, None
|
||||
|
||||
async def _authenticate_user(
|
||||
self, request: Request, db: Session
|
||||
) -> Tuple[User, Optional["ManagementToken"]]:
|
||||
"""用户认证,支持 JWT 和 Management Token 两种方式"""
|
||||
from src.models.database import ManagementToken
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
async def _authenticate_user(self, request: Request, db: Session) -> User:
|
||||
"""JWT 认证普通用户(不要求管理员权限)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||
|
||||
token = authorization[7:].strip()
|
||||
|
||||
# 检查是否为 Management Token(ae_ 前缀)
|
||||
if token.startswith(ManagementToken.TOKEN_PREFIX):
|
||||
client_ip = get_client_ip(request)
|
||||
result = await self.auth_service.authenticate_management_token(db, token, client_ip)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="无效或过期的 Management Token")
|
||||
|
||||
user, management_token = result
|
||||
|
||||
request.state.user_id = user.id
|
||||
request.state.management_token_id = management_token.id
|
||||
|
||||
return user, management_token
|
||||
|
||||
# JWT 认证
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
@@ -217,13 +290,47 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
request.state.user_id = user.id
|
||||
return user
|
||||
return user, None
|
||||
|
||||
async def _authenticate_management(
|
||||
self, request: Request, db: Session
|
||||
) -> Tuple[User, "ManagementToken"]:
|
||||
"""Management Token 认证"""
|
||||
from src.models.database import ManagementToken
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
authorization = request.headers.get("authorization")
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少 Management Token")
|
||||
|
||||
token = authorization[7:].strip()
|
||||
|
||||
# 检查是否为 Management Token 格式
|
||||
if not token.startswith(ManagementToken.TOKEN_PREFIX):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"无效的 Token 格式,需要 Management Token ({ManagementToken.TOKEN_PREFIX}xxx)",
|
||||
)
|
||||
|
||||
client_ip = get_client_ip(request)
|
||||
|
||||
result = await self.auth_service.authenticate_management_token(db, token, client_ip)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=401, detail="无效或过期的 Management Token")
|
||||
|
||||
user, management_token = result
|
||||
|
||||
# 存储到 request.state
|
||||
request.state.user_id = user.id
|
||||
request.state.management_token_id = management_token.id
|
||||
|
||||
return user, management_token
|
||||
|
||||
def _calculate_quota_remaining(self, user: Optional[User]) -> Optional[float]:
|
||||
if not user:
|
||||
|
||||
@@ -45,6 +45,29 @@ def format_tokens(num: int) -> str:
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_dashboard_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
根据用户角色返回不同的统计数据。管理员可以看到全局数据,普通用户只能看到自己的数据。
|
||||
|
||||
**返回字段(管理员)**:
|
||||
- `stats`: 统计卡片数组,包含总请求、总费用、总Token、总缓存等信息
|
||||
- `today`: 今日统计(requests, cost, actual_cost, tokens, cache_creation_tokens, cache_read_tokens)
|
||||
- `api_keys`: API Key 统计(total, active)
|
||||
- `tokens`: 本月 Token 统计
|
||||
- `token_breakdown`: Token 详细分类(input, output, cache_creation, cache_read)
|
||||
- `system_health`: 系统健康指标(avg_response_time, error_rate, error_requests, fallback_count, total_requests)
|
||||
- `cost_stats`: 成本统计(total_cost, total_actual_cost, cost_savings)
|
||||
- `cache_stats`: 缓存统计信息
|
||||
- `users`: 用户统计(total, active)
|
||||
|
||||
**返回字段(普通用户)**:
|
||||
- `stats`: 统计卡片数组,包含 API 密钥、本月请求、配额使用、总Token 等信息
|
||||
- `today`: 今日统计
|
||||
- `token_breakdown`: Token 详细分类
|
||||
- `cache_stats`: 缓存统计信息
|
||||
- `monthly_cost`: 本月费用
|
||||
"""
|
||||
adapter = DashboardStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -55,6 +78,23 @@ async def get_recent_requests(
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取最近请求列表
|
||||
|
||||
获取最近的 API 请求记录。管理员可以看到所有用户的请求,普通用户只能看到自己的请求。
|
||||
|
||||
**查询参数**:
|
||||
- `limit`: 返回记录数,默认 10,最大 100
|
||||
|
||||
**返回字段**:
|
||||
- `requests`: 请求列表,每条记录包含:
|
||||
- `id`: 请求 ID
|
||||
- `user`: 用户名
|
||||
- `model`: 使用的模型
|
||||
- `tokens`: Token 数量
|
||||
- `time`: 请求时间(HH:MM 格式)
|
||||
- `is_stream`: 是否为流式请求
|
||||
"""
|
||||
adapter = DashboardRecentRequestsAdapter(limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -65,6 +105,17 @@ async def get_recent_requests(
|
||||
|
||||
@router.get("/provider-status")
|
||||
async def get_provider_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取提供商状态
|
||||
|
||||
获取所有活跃提供商的状态和最近 24 小时的请求统计。
|
||||
|
||||
**返回字段**:
|
||||
- `providers`: 提供商列表,每个提供商包含:
|
||||
- `name`: 提供商名称
|
||||
- `status`: 状态(active/inactive)
|
||||
- `requests`: 最近 24 小时的请求数
|
||||
"""
|
||||
adapter = DashboardProviderStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -75,6 +126,28 @@ async def get_daily_stats(
|
||||
days: int = Query(7, ge=1, le=30),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取每日统计数据
|
||||
|
||||
获取指定天数的每日使用统计数据,用于生成图表。
|
||||
|
||||
**查询参数**:
|
||||
- `days`: 统计天数,默认 7 天,最大 30 天
|
||||
|
||||
**返回字段**:
|
||||
- `daily_stats`: 每日统计数组,每天包含:
|
||||
- `date`: 日期(ISO 格式)
|
||||
- `requests`: 请求数
|
||||
- `tokens`: Token 数量
|
||||
- `cost`: 费用(USD)
|
||||
- `avg_response_time`: 平均响应时间(秒)
|
||||
- `unique_models`: 使用的模型数量(仅管理员)
|
||||
- `unique_providers`: 使用的提供商数量(仅管理员)
|
||||
- `fallback_count`: 故障转移次数(仅管理员)
|
||||
- `model_breakdown`: 按模型分解的统计(仅管理员)
|
||||
- `model_summary`: 模型使用汇总,按费用排序
|
||||
- `period`: 统计周期信息(start_date, end_date, days)
|
||||
"""
|
||||
adapter = DashboardDailyStatsAdapter(days=days)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -693,7 +766,7 @@ class DashboardProviderStatusAdapter(DashboardAdapter):
|
||||
for provider in providers:
|
||||
count = (
|
||||
db.query(func.count(Usage.id))
|
||||
.filter(and_(Usage.provider == provider.name, Usage.created_at >= since))
|
||||
.filter(and_(Usage.provider_name == provider.name, Usage.created_at >= since))
|
||||
.scalar()
|
||||
)
|
||||
entries.append(
|
||||
@@ -781,7 +854,7 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
.scalar() or 0
|
||||
)
|
||||
today_unique_providers = (
|
||||
db.query(func.count(func.distinct(Usage.provider)))
|
||||
db.query(func.count(func.distinct(Usage.provider_name)))
|
||||
.filter(Usage.created_at >= today)
|
||||
.scalar() or 0
|
||||
)
|
||||
|
||||
@@ -40,6 +40,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||
@classmethod
|
||||
def build_endpoint_url(cls, base_url: str) -> str:
|
||||
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -19,6 +19,7 @@ Chat Handler Base - Chat API 格式的通用基类
|
||||
- StreamTelemetryRecorder: 统计记录(Usage、Audit、Candidate)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Optional
|
||||
|
||||
@@ -55,7 +56,6 @@ from src.models.database import (
|
||||
from src.services.provider.transport import build_provider_url
|
||||
|
||||
|
||||
|
||||
class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
"""
|
||||
Chat Handler 基类
|
||||
@@ -89,7 +89,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list] = None,
|
||||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||||
adapter_detector: Optional[
|
||||
Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
|
||||
] = None,
|
||||
):
|
||||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||||
super().__init__(
|
||||
@@ -459,14 +461,19 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
f"模型={ctx.model} -> {mapped_model or '无映射'}"
|
||||
)
|
||||
|
||||
# 发送请求(使用配置中的超时设置)
|
||||
# 配置 HTTP 超时
|
||||
# 注意:read timeout 用于检测连接断开,不是整体请求超时
|
||||
# 整体请求超时由 asyncio.wait_for 控制,使用 endpoint.timeout
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=float(endpoint.timeout),
|
||||
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
)
|
||||
|
||||
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
@@ -474,7 +481,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=timeout_config,
|
||||
)
|
||||
try:
|
||||
|
||||
# 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用)
|
||||
byte_iterator: Any = None
|
||||
prefetched_chunks: Any = None
|
||||
response_ctx: Any = None
|
||||
|
||||
async def _connect_and_prefetch() -> None:
|
||||
"""建立连接并预读首字节(受整体超时控制)"""
|
||||
nonlocal byte_iterator, prefetched_chunks, response_ctx
|
||||
response_ctx = http_client.stream(
|
||||
"POST", url, json=provider_payload, headers=provider_headers
|
||||
)
|
||||
@@ -497,6 +512,28 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
max_prefetch_lines=config.stream_prefetch_lines,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
||||
# endpoint.timeout 控制整体超时,避免上游长时间无响应
|
||||
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 整体请求超时(建立连接 + 获取首字节)
|
||||
# 清理可能已建立的连接上下文
|
||||
if response_ctx is not None:
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
await http_client.aclose()
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s"
|
||||
)
|
||||
raise ProviderTimeoutException(
|
||||
provider_name=str(provider.name),
|
||||
timeout=int(request_timeout),
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_text = await self._extract_error_text(e)
|
||||
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
|
||||
@@ -507,7 +544,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
except EmbeddedErrorException:
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
if response_ctx is not None:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
await http_client.aclose()
|
||||
@@ -517,6 +555,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
await http_client.aclose()
|
||||
raise
|
||||
|
||||
# 类型断言:成功执行后这些变量不会为 None
|
||||
assert byte_iterator is not None
|
||||
assert prefetched_chunks is not None
|
||||
assert response_ctx is not None
|
||||
|
||||
# 创建流生成器(传入字节流迭代器)
|
||||
return stream_processor.create_response_stream(
|
||||
ctx,
|
||||
@@ -639,17 +682,23 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
is_stream=False,
|
||||
)
|
||||
|
||||
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||
f"模型={model} -> {mapped_model or '无映射'}")
|
||||
logger.info(
|
||||
f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||
f"模型={model} -> {mapped_model or '无映射'}"
|
||||
)
|
||||
logger.debug(f" [{self.request_id}] 请求URL: {url}")
|
||||
logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}")
|
||||
logger.debug(
|
||||
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
|
||||
)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
# endpoint.timeout 作为整体请求超时
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||
timeout=httpx.Timeout(request_timeout),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
||||
@@ -670,7 +719,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}")
|
||||
logger.error(
|
||||
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
@@ -684,7 +735,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}")
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
@@ -765,8 +818,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
logger.debug(f"{self.FORMAT_ID} 非流式响应完成")
|
||||
|
||||
# 简洁的请求完成摘要
|
||||
logger.info(f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | "
|
||||
f"in:{input_tokens or 0} out:{output_tokens or 0}")
|
||||
logger.info(
|
||||
f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | "
|
||||
f"in:{input_tokens or 0} out:{output_tokens or 0}"
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=status_code, content=response_json)
|
||||
|
||||
@@ -807,8 +862,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
error_bytes = await e.response.aread()
|
||||
return error_bytes.decode("utf-8", errors="replace")
|
||||
else:
|
||||
return (
|
||||
e.response.text if hasattr(e.response, "_content") else "Unable to read"
|
||||
)
|
||||
return e.response.text if hasattr(e.response, "_content") else "Unable to read"
|
||||
except Exception as decode_error:
|
||||
return f"Unable to read error: {decode_error}"
|
||||
|
||||
@@ -38,6 +38,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -33,19 +33,21 @@ from src.api.handlers.base.base_handler import (
|
||||
)
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import (
|
||||
build_sse_headers,
|
||||
check_html_response,
|
||||
check_prefetched_response_error,
|
||||
)
|
||||
from src.core.error_utils import extract_error_message
|
||||
|
||||
# 直接从具体模块导入,避免循环依赖
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import (
|
||||
build_sse_headers,
|
||||
check_html_response,
|
||||
check_prefetched_response_error,
|
||||
)
|
||||
from src.config.constants import StreamDefaults
|
||||
from src.config.settings import config
|
||||
from src.core.error_utils import extract_error_message
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
ProviderAuthException,
|
||||
@@ -62,8 +64,6 @@ from src.models.database import (
|
||||
ProviderEndpoint,
|
||||
User,
|
||||
)
|
||||
from src.config.constants import StreamDefaults
|
||||
from src.config.settings import config
|
||||
from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||
@@ -100,7 +100,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list] = None,
|
||||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||||
adapter_detector: Optional[
|
||||
Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
|
||||
] = None,
|
||||
):
|
||||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||||
super().__init__(
|
||||
@@ -158,7 +160,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
mapper = ModelMapperMiddleware(self.db)
|
||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
logger.debug(
|
||||
f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}"
|
||||
)
|
||||
|
||||
if mapping and mapping.model:
|
||||
# 使用 select_provider_model_name 支持模型映射功能
|
||||
@@ -168,7 +172,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
logger.debug(
|
||||
f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)"
|
||||
)
|
||||
return mapped_name
|
||||
|
||||
logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}")
|
||||
@@ -459,18 +465,26 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
is_stream=True, # CLI handler 处理流式请求
|
||||
)
|
||||
|
||||
# 配置超时
|
||||
# 配置 HTTP 超时
|
||||
# 注意:read timeout 用于检测连接断开,不是整体请求超时
|
||||
# 整体请求超时由 _connect_and_prefetch 内部的 asyncio.wait_for 控制
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=10.0,
|
||||
read=float(endpoint.timeout),
|
||||
write=60.0, # 写入超时增加到60秒,支持大请求体(如包含图片的长对话)
|
||||
pool=10.0,
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
)
|
||||
|
||||
logger.debug(f" └─ [{self.request_id}] 发送流式请求: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"Key=***{key.api_key[-4:]}, "
|
||||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
|
||||
logger.debug(
|
||||
f" └─ [{self.request_id}] 发送流式请求: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., "
|
||||
f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, "
|
||||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}, "
|
||||
f"timeout={request_timeout}s"
|
||||
)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
@@ -479,7 +493,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=timeout_config,
|
||||
)
|
||||
try:
|
||||
|
||||
# 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用)
|
||||
byte_iterator: Any = None
|
||||
prefetched_chunks: Any = None
|
||||
response_ctx: Any = None
|
||||
|
||||
async def _connect_and_prefetch() -> None:
|
||||
"""建立连接并预读首字节(受整体超时控制)"""
|
||||
nonlocal byte_iterator, prefetched_chunks, response_ctx
|
||||
response_ctx = http_client.stream(
|
||||
"POST", url, json=provider_payload, headers=provider_headers
|
||||
)
|
||||
@@ -500,9 +522,33 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
byte_iterator, provider, endpoint, ctx
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
|
||||
# endpoint.timeout 控制整体超时,避免上游长时间无响应
|
||||
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 整体请求超时(建立连接 + 获取首字节)
|
||||
# 清理可能已建立的连接上下文
|
||||
if response_ctx is not None:
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
await http_client.aclose()
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s"
|
||||
)
|
||||
raise ProviderTimeoutException(
|
||||
provider_name=str(provider.name),
|
||||
timeout=int(request_timeout),
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_text = await self._extract_error_text(e)
|
||||
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
|
||||
logger.error(
|
||||
f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}"
|
||||
)
|
||||
await http_client.aclose()
|
||||
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
|
||||
e.upstream_response = error_text # type: ignore[attr-defined]
|
||||
@@ -511,7 +557,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
except EmbeddedErrorException:
|
||||
# 嵌套错误需要触发重试,关闭连接后重新抛出
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
if response_ctx is not None:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
await http_client.aclose()
|
||||
@@ -521,6 +568,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
await http_client.aclose()
|
||||
raise
|
||||
|
||||
# 类型断言:成功执行后这些变量不会为 None
|
||||
assert byte_iterator is not None
|
||||
assert prefetched_chunks is not None
|
||||
assert response_ctx is not None
|
||||
|
||||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||||
return self._create_response_stream_with_prefetch(
|
||||
ctx,
|
||||
@@ -593,7 +645,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
},
|
||||
}
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
return # 结束生成器
|
||||
|
||||
# 格式转换或直接透传
|
||||
@@ -801,10 +855,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
@@ -849,14 +905,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
# 网络 I/O 异常:记录警告,可能需要重试
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||
)
|
||||
logger.warning(f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}")
|
||||
except Exception as e:
|
||||
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||
logger.error(
|
||||
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||
exc_info=True
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -979,7 +1033,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
},
|
||||
}
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode(
|
||||
"utf-8"
|
||||
)
|
||||
return
|
||||
|
||||
# 格式转换或直接透传
|
||||
@@ -1255,8 +1311,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
|
||||
# 简洁的请求失败摘要(包含预估 token 信息)
|
||||
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||||
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}")
|
||||
logger.info(
|
||||
f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||||
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}"
|
||||
)
|
||||
else:
|
||||
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
|
||||
self._finalize_stream_metadata(ctx)
|
||||
@@ -1289,9 +1347,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||||
# 简洁的请求完成摘要(两行格式)
|
||||
line1 = (
|
||||
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
|
||||
)
|
||||
line1 = f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
|
||||
if ctx.first_byte_time_ms:
|
||||
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
|
||||
|
||||
@@ -1314,7 +1370,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=bg_db,
|
||||
candidate_id=ctx.attempt_id,
|
||||
error_type="client_disconnected" if ctx.status_code == 499 else "stream_error",
|
||||
error_type=(
|
||||
"client_disconnected" if ctx.status_code == 499 else "stream_error"
|
||||
),
|
||||
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
|
||||
status_code=ctx.status_code,
|
||||
latency_ms=response_time_ms,
|
||||
@@ -1469,17 +1527,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
is_stream=False, # 非流式请求
|
||||
)
|
||||
|
||||
logger.info(f" └─ [{self.request_id}] 发送非流式请求: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"Key=***{key.api_key[-4:]}, "
|
||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||
logger.info(
|
||||
f" └─ [{self.request_id}] 发送非流式请求: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., "
|
||||
f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, "
|
||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
|
||||
)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
# endpoint.timeout 作为整体请求超时
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||
timeout=httpx.Timeout(request_timeout),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||||
@@ -1497,8 +1559,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
elif 300 <= resp.status_code < 400:
|
||||
redirect_url = resp.headers.get("location", "unknown")
|
||||
@@ -1508,7 +1574,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
elif resp.status_code != 200:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
|
||||
# 安全解析 JSON 响应,处理可能的编码错误
|
||||
@@ -1518,9 +1587,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 记录原始响应信息用于调试
|
||||
content_type = resp.headers.get("content-type", "unknown")
|
||||
content_encoding = resp.headers.get("content-encoding", "none")
|
||||
logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||||
logger.error(
|
||||
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||||
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
||||
f"响应长度: {len(resp.content)} bytes")
|
||||
f"响应长度: {len(resp.content)} bytes"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
|
||||
)
|
||||
|
||||
@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -28,12 +28,48 @@ async def get_my_audit_logs(
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取我的审计日志
|
||||
|
||||
获取当前用户的审计日志记录。需要登录。
|
||||
|
||||
**查询参数**:
|
||||
- `event_type`: 可选,事件类型筛选
|
||||
- `days`: 查询最近多少天的日志,默认 30 天
|
||||
- `limit`: 返回数量限制,默认 50
|
||||
- `offset`: 分页偏移量,默认 0
|
||||
|
||||
**返回字段**:
|
||||
- `items`: 审计日志列表,每条日志包含:
|
||||
- `id`: 日志 ID
|
||||
- `event_type`: 事件类型
|
||||
- `description`: 事件描述
|
||||
- `ip_address`: IP 地址
|
||||
- `status_code`: HTTP 状态码
|
||||
- `created_at`: 创建时间
|
||||
- `meta`: 分页元数据(total, limit, offset, count)
|
||||
- `filters`: 筛选条件
|
||||
"""
|
||||
adapter = UserAuditLogsAdapter(event_type=event_type, days=days, limit=limit, offset=offset)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/rate-limit-status")
|
||||
async def get_rate_limit_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取速率限制状态
|
||||
|
||||
获取当前用户所有活跃 API Key 的速率限制状态。需要登录。
|
||||
|
||||
**返回字段**:
|
||||
- `user_id`: 用户 ID
|
||||
- `api_keys`: API Key 限流状态列表,每个包含:
|
||||
- `api_key_name`: API Key 名称
|
||||
- `limit`: 速率限制上限
|
||||
- `remaining`: 剩余可用次数
|
||||
- `reset_time`: 限制重置时间
|
||||
- `window`: 时间窗口
|
||||
"""
|
||||
adapter = UserRateLimitStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -13,12 +13,26 @@ from src.core.key_capabilities import (
|
||||
)
|
||||
from src.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/api/capabilities", tags=["Capabilities"])
|
||||
router = APIRouter(prefix="/api/capabilities", tags=["System Catalog"])
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_capabilities():
|
||||
"""获取所有能力定义"""
|
||||
"""
|
||||
获取所有能力定义
|
||||
|
||||
返回系统中定义的所有能力(capabilities),包括用户可配置和系统内部使用的能力。
|
||||
能力用于描述模型支持的功能特性,如视觉输入、函数调用、流式输出等。
|
||||
|
||||
**返回字段**
|
||||
- capabilities: 能力列表,每个能力包含:
|
||||
- name: 能力的唯一标识符(如 vision、function_calling)
|
||||
- display_name: 能力的显示名称(如"视觉输入"、"函数调用")
|
||||
- short_name: 能力的简短名称(如"视觉"、"函数")
|
||||
- description: 能力的详细描述
|
||||
- match_mode: 匹配模式(exact 精确匹配,fuzzy 模糊匹配,prefix 前缀匹配等)
|
||||
- config_mode: 配置模式(user_configurable 用户可配置,system_only 仅系统使用)
|
||||
"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
@@ -36,7 +50,21 @@ async def list_capabilities():
|
||||
|
||||
@router.get("/user-configurable")
|
||||
async def list_user_configurable_capabilities():
|
||||
"""获取用户可配置的能力列表(用于前端展示配置选项)"""
|
||||
"""
|
||||
获取用户可配置的能力列表
|
||||
|
||||
返回允许用户在 API Key 中配置的能力列表,用于前端展示配置选项。
|
||||
用户可以通过配置这些能力来限制或指定 API Key 可以访问的模型功能。
|
||||
|
||||
**返回字段**
|
||||
- capabilities: 用户可配置的能力列表,每个能力包含:
|
||||
- name: 能力的唯一标识符
|
||||
- display_name: 能力的显示名称
|
||||
- short_name: 能力的简短名称
|
||||
- description: 能力的详细描述
|
||||
- match_mode: 匹配模式(exact、fuzzy、prefix 等)
|
||||
- config_mode: 配置模式(此接口返回的都是 user_configurable)
|
||||
"""
|
||||
return {
|
||||
"capabilities": [
|
||||
{
|
||||
@@ -60,11 +88,24 @@ async def get_model_supported_capabilities(
|
||||
"""
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||
根据全局模型名称(GlobalModel.name)查询该模型支持的能力,
|
||||
并返回每个能力的详细定义。只查询活跃的全局模型。
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
**路径参数**
|
||||
- model_name: 全局模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||
|
||||
**返回字段**
|
||||
- model: 查询的模型名称
|
||||
- global_model_id: 全局模型的 UUID
|
||||
- global_model_name: 全局模型的标准名称
|
||||
- supported_capabilities: 该模型支持的能力名称列表
|
||||
- capability_details: 支持的能力详细信息列表,每个能力包含:
|
||||
- name: 能力标识符
|
||||
- display_name: 能力显示名称
|
||||
- description: 能力描述
|
||||
- match_mode: 匹配模式
|
||||
- config_mode: 配置模式
|
||||
- error: 错误信息(仅在模型不存在时返回)
|
||||
"""
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from src.models.endpoint_models import (
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
|
||||
router = APIRouter(prefix="/api/public", tags=["Public Catalog"])
|
||||
router = APIRouter(prefix="/api/public", tags=["System Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@@ -49,7 +49,29 @@ async def get_public_providers(
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取提供商列表(用户视图)。"""
|
||||
"""
|
||||
获取提供商列表(用户视图)
|
||||
|
||||
返回系统中可用的提供商列表,包含提供商的基本信息和统计数据。
|
||||
默认只返回活跃的提供商。
|
||||
|
||||
**查询参数**
|
||||
- is_active: 可选,过滤活跃状态。None 表示只返回活跃提供商,True 返回活跃,False 返回非活跃
|
||||
- skip: 跳过的记录数,用于分页,默认 0
|
||||
- limit: 返回记录数限制,默认 100,最大 100
|
||||
|
||||
**返回字段**
|
||||
- id: 提供商唯一标识符
|
||||
- name: 提供商名称(英文标识)
|
||||
- display_name: 提供商显示名称
|
||||
- description: 提供商描述信息
|
||||
- is_active: 是否活跃
|
||||
- provider_priority: 提供商优先级
|
||||
- models_count: 该提供商下的模型总数
|
||||
- active_models_count: 该提供商下活跃的模型数
|
||||
- endpoints_count: 该提供商下的端点总数
|
||||
- active_endpoints_count: 该提供商下活跃的端点数
|
||||
"""
|
||||
|
||||
adapter = PublicProvidersAdapter(is_active=is_active, skip=skip, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
@@ -64,6 +86,37 @@ async def get_public_models(
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取模型列表(用户视图)
|
||||
|
||||
返回系统中可用的模型列表,包含模型的详细信息和定价。
|
||||
默认只返回活跃提供商下的活跃模型。
|
||||
|
||||
**查询参数**
|
||||
- provider_id: 可选,按提供商 ID 过滤,只返回该提供商下的模型
|
||||
- is_active: 可选,过滤活跃状态(当前未使用,始终返回活跃模型)
|
||||
- skip: 跳过的记录数,用于分页,默认 0
|
||||
- limit: 返回记录数限制,默认 100,最大 100
|
||||
|
||||
**返回字段**
|
||||
- id: 模型唯一标识符
|
||||
- provider_id: 所属提供商 ID
|
||||
- provider_name: 提供商名称
|
||||
- provider_display_name: 提供商显示名称
|
||||
- name: 模型统一名称(优先使用 GlobalModel 名称)
|
||||
- display_name: 模型显示名称
|
||||
- description: 模型描述信息
|
||||
- tags: 模型标签(当前为 null)
|
||||
- icon_url: 模型图标 URL
|
||||
- input_price_per_1m: 输入价格(每 100 万 token)
|
||||
- output_price_per_1m: 输出价格(每 100 万 token)
|
||||
- cache_creation_price_per_1m: 缓存创建价格(每 100 万 token)
|
||||
- cache_read_price_per_1m: 缓存读取价格(每 100 万 token)
|
||||
- supports_vision: 是否支持视觉输入
|
||||
- supports_function_calling: 是否支持函数调用
|
||||
- supports_streaming: 是否支持流式输出
|
||||
- is_active: 是否活跃
|
||||
"""
|
||||
adapter = PublicModelsAdapter(
|
||||
provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
|
||||
)
|
||||
@@ -72,6 +125,19 @@ async def get_public_models(
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取系统统计信息
|
||||
|
||||
返回系统的整体统计数据,包括提供商数量、模型数量和支持的 API 格式。
|
||||
只统计活跃的提供商和模型。
|
||||
|
||||
**返回字段**
|
||||
- total_providers: 活跃提供商总数
|
||||
- active_providers: 活跃提供商数量(与 total_providers 相同)
|
||||
- total_models: 活跃模型总数
|
||||
- active_models: 活跃模型数量(与 total_models 相同)
|
||||
- supported_formats: 支持的 API 格式列表(如 claude、openai、gemini 等)
|
||||
"""
|
||||
adapter = PublicStatsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
@@ -84,6 +150,37 @@ async def search_models(
|
||||
limit: int = Query(20, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
搜索模型
|
||||
|
||||
根据关键词搜索模型,支持按模型名称、显示名称等字段进行模糊匹配。
|
||||
只返回活跃提供商下的活跃模型。
|
||||
|
||||
**查询参数**
|
||||
- q: 必填,搜索关键词,支持模糊匹配模型的 provider_model_name、GlobalModel.name 或 GlobalModel.display_name
|
||||
- provider_id: 可选,按提供商 ID 过滤,只在该提供商下搜索
|
||||
- limit: 返回记录数限制,默认 20,最大值取决于系统配置
|
||||
|
||||
**返回字段**
|
||||
返回符合条件的模型列表,字段与 /api/public/models 接口相同:
|
||||
- id: 模型唯一标识符
|
||||
- provider_id: 所属提供商 ID
|
||||
- provider_name: 提供商名称
|
||||
- provider_display_name: 提供商显示名称
|
||||
- name: 模型统一名称
|
||||
- display_name: 模型显示名称
|
||||
- description: 模型描述
|
||||
- tags: 模型标签
|
||||
- icon_url: 模型图标 URL
|
||||
- input_price_per_1m: 输入价格(每 100 万 token)
|
||||
- output_price_per_1m: 输出价格(每 100 万 token)
|
||||
- cache_creation_price_per_1m: 缓存创建价格(每 100 万 token)
|
||||
- cache_read_price_per_1m: 缓存读取价格(每 100 万 token)
|
||||
- supports_vision: 是否支持视觉
|
||||
- supports_function_calling: 是否支持函数调用
|
||||
- supports_streaming: 是否支持流式输出
|
||||
- is_active: 是否活跃
|
||||
"""
|
||||
adapter = PublicSearchModelsAdapter(query=q, provider_id=provider_id, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
@@ -95,7 +192,37 @@ async def get_public_api_format_health(
|
||||
per_format_limit: int = Query(100, ge=10, le=500, description="每个格式的事件数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取各 API 格式的健康监控数据(公开版,不含敏感信息)"""
|
||||
"""
|
||||
获取各 API 格式的健康监控数据
|
||||
|
||||
返回系统中各 API 格式(如 Claude、OpenAI、Gemini)的健康状态和历史事件。
|
||||
公开版本,不包含敏感信息(如 provider_id、key_id 等)。
|
||||
|
||||
**查询参数**
|
||||
- lookback_hours: 回溯的时间范围(小时),默认 6 小时,范围 1-168(7 天)
|
||||
- per_format_limit: 每个 API 格式返回的历史事件数量上限,默认 100,范围 10-500
|
||||
|
||||
**返回字段**
|
||||
- generated_at: 响应生成时间
|
||||
- formats: API 格式健康监控数据列表,每个格式包含:
|
||||
- api_format: API 格式名称(如 claude、openai、gemini)
|
||||
- api_path: 本站入口路径
|
||||
- total_attempts: 总请求尝试次数
|
||||
- success_count: 成功次数
|
||||
- failed_count: 失败次数
|
||||
- skipped_count: 跳过次数
|
||||
- success_rate: 成功率(success / (success + failed))
|
||||
- last_event_at: 最后事件时间
|
||||
- events: 历史事件列表,按时间倒序,每个事件包含:
|
||||
- timestamp: 事件时间
|
||||
- status: 状态(success、failed、skipped)
|
||||
- status_code: HTTP 状态码
|
||||
- latency_ms: 延迟(毫秒)
|
||||
- error_type: 错误类型(如果失败)
|
||||
- timeline: 时间线数据,用于展示请求量趋势
|
||||
- time_range_start: 时间范围起始
|
||||
- time_range_end: 时间范围结束
|
||||
"""
|
||||
adapter = PublicApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
@@ -112,7 +239,30 @@ async def get_public_global_models(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 GlobalModel 列表(用户视图,只读)"""
|
||||
"""
|
||||
获取全局模型(GlobalModel)列表
|
||||
|
||||
返回系统定义的全局模型列表,用于统一不同提供商的模型标识。
|
||||
默认只返回活跃的全局模型。
|
||||
|
||||
**查询参数**
|
||||
- skip: 跳过的记录数,用于分页,默认 0,最小 0
|
||||
- limit: 返回记录数限制,默认 100,范围 1-1000
|
||||
- is_active: 可选,过滤活跃状态。None 表示只返回活跃模型,True 返回活跃,False 返回非活跃
|
||||
- search: 可选,搜索关键词,支持模糊匹配模型名称(name)和显示名称(display_name)
|
||||
|
||||
**返回字段**
|
||||
- models: 全局模型列表,每个模型包含:
|
||||
- id: 全局模型唯一标识符(UUID)
|
||||
- name: 模型名称(统一标识符)
|
||||
- display_name: 模型显示名称
|
||||
- is_active: 是否活跃
|
||||
- default_price_per_request: 默认的按请求计价配置
|
||||
- default_tiered_pricing: 默认的阶梯定价配置
|
||||
- supported_capabilities: 支持的能力列表(如 vision、function_calling 等)
|
||||
- config: 模型配置信息(如 description、icon_url 等)
|
||||
- total: 符合条件的模型总数
|
||||
"""
|
||||
adapter = PublicGlobalModelsAdapter(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
|
||||
@@ -29,7 +29,27 @@ async def create_message(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""统一入口:根据 x-app 自动在标准/Claude Code 之间切换。"""
|
||||
"""
|
||||
Claude Messages API
|
||||
|
||||
兼容 Anthropic Claude Messages API 格式的代理接口。
|
||||
根据请求头 `x-app` 自动在标准 API 和 Claude Code CLI 模式之间切换。
|
||||
|
||||
**认证方式**: x-api-key 请求头
|
||||
|
||||
**请求格式**:
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 1024,
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}
|
||||
```
|
||||
|
||||
**必需请求头**:
|
||||
- `x-api-key`: API 密钥
|
||||
- `anthropic-version`: API 版本(如 2023-06-01)
|
||||
"""
|
||||
adapter = build_claude_adapter(http_request.headers.get("x-app", ""))
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
@@ -45,6 +65,13 @@ async def count_tokens(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Claude Token Count API
|
||||
|
||||
计算消息的 Token 数量,用于预估请求成本。
|
||||
|
||||
**认证方式**: x-api-key 请求头
|
||||
"""
|
||||
adapter = ClaudeTokenCountAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
|
||||
@@ -56,9 +56,23 @@ async def generate_content(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini generateContent 端点
|
||||
Gemini generateContent API
|
||||
|
||||
非流式生成内容请求
|
||||
兼容 Google Gemini API 格式的代理接口(非流式)。
|
||||
|
||||
**认证方式**:
|
||||
- `x-goog-api-key` 请求头,或
|
||||
- `?key=` URL 参数
|
||||
|
||||
**请求格式**:
|
||||
```json
|
||||
{
|
||||
"contents": [{"parts": [{"text": "Hello"}]}]
|
||||
}
|
||||
```
|
||||
|
||||
**路径参数**:
|
||||
- `model`: 模型名称,如 gemini-2.0-flash
|
||||
"""
|
||||
# 根据 user-agent 或 x-app header 选择适配器
|
||||
if _is_cli_request(http_request):
|
||||
@@ -84,9 +98,16 @@ async def stream_generate_content(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Gemini streamGenerateContent 端点
|
||||
Gemini streamGenerateContent API
|
||||
|
||||
流式生成内容请求
|
||||
兼容 Google Gemini API 格式的代理接口(流式)。
|
||||
|
||||
**认证方式**:
|
||||
- `x-goog-api-key` 请求头,或
|
||||
- `?key=` URL 参数
|
||||
|
||||
**路径参数**:
|
||||
- `model`: 模型名称,如 gemini-2.0-flash
|
||||
|
||||
注意: Gemini API 通过 URL 端点区分流式/非流式,不需要在请求体中添加 stream 字段
|
||||
"""
|
||||
@@ -114,7 +135,11 @@ async def generate_content_v1(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
"""
|
||||
Gemini generateContent API (v1 兼容)
|
||||
|
||||
v1 版本 API 端点,兼容部分使用旧版路径的 SDK。
|
||||
"""
|
||||
return await generate_content(model, http_request, db)
|
||||
|
||||
|
||||
@@ -124,5 +149,9 @@ async def stream_generate_content_v1(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""v1 兼容端点"""
|
||||
"""
|
||||
Gemini streamGenerateContent API (v1 兼容)
|
||||
|
||||
v1 版本流式 API 端点,兼容部分使用旧版路径的 SDK。
|
||||
"""
|
||||
return await stream_generate_content(model, http_request, db)
|
||||
|
||||
@@ -27,7 +27,7 @@ from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.auth.service import AuthService
|
||||
|
||||
router = APIRouter(tags=["Models API"])
|
||||
router = APIRouter(tags=["System Catalog"])
|
||||
|
||||
# 各格式对应的 API 格式列表
|
||||
# 注意: CLI 格式是透传格式,Models API 只返回非 CLI 格式的端点支持的模型
|
||||
@@ -126,7 +126,9 @@ def _filter_formats_by_restrictions(
|
||||
"""
|
||||
if restrictions.allowed_api_formats is None:
|
||||
return formats, None
|
||||
filtered = [f for f in formats if f in restrictions.allowed_api_formats]
|
||||
# 统一转为大写比较,兼容数据库中存储的大小写
|
||||
allowed_upper = {f.upper() for f in restrictions.allowed_api_formats}
|
||||
filtered = [f for f in formats if f.upper() in allowed_upper]
|
||||
if not filtered:
|
||||
logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
|
||||
return [], _build_empty_list_response(api_format)
|
||||
@@ -395,11 +397,65 @@ async def list_models(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Union[dict, JSONResponse]:
|
||||
"""
|
||||
List models - 根据请求头认证方式返回对应格式
|
||||
列出可用模型(统一端点)
|
||||
|
||||
- x-api-key -> Claude 格式
|
||||
- x-goog-api-key 或 ?key= -> Gemini 格式
|
||||
- Authorization: Bearer -> OpenAI 格式
|
||||
根据请求头中的认证方式自动检测 API 格式,并返回相应格式的模型列表。
|
||||
此接口兼容 Claude、OpenAI 和 Gemini 三种 API 格式。
|
||||
|
||||
**格式检测规则**
|
||||
- x-api-key + anthropic-version → Claude 格式
|
||||
- x-goog-api-key 或 ?key= → Gemini 格式
|
||||
- Authorization: Bearer → OpenAI 格式(默认)
|
||||
|
||||
**查询参数**
|
||||
|
||||
Claude 格式:
|
||||
- before_id: 返回此 ID 之前的结果,用于向前分页
|
||||
- after_id: 返回此 ID 之后的结果,用于向后分页
|
||||
- limit: 返回数量限制,默认 20,范围 1-1000
|
||||
|
||||
Gemini 格式:
|
||||
- pageSize: 每页数量,默认 50,范围 1-1000
|
||||
- pageToken: 分页 token,用于获取下一页
|
||||
|
||||
**返回字段**
|
||||
|
||||
Claude 格式:
|
||||
- data: 模型列表,每个模型包含:
|
||||
- id: 模型标识符
|
||||
- type: "model"
|
||||
- display_name: 显示名称
|
||||
- created_at: 创建时间(ISO 8601 格式)
|
||||
- has_more: 是否有更多结果
|
||||
- first_id: 当前页第一个模型 ID
|
||||
- last_id: 当前页最后一个模型 ID
|
||||
|
||||
OpenAI 格式:
|
||||
- object: "list"
|
||||
- data: 模型列表,每个模型包含:
|
||||
- id: 模型标识符
|
||||
- object: "model"
|
||||
- created: Unix 时间戳
|
||||
- owned_by: 提供商名称
|
||||
|
||||
Gemini 格式:
|
||||
- models: 模型列表,每个模型包含:
|
||||
- name: 模型资源名称(如 models/gemini-pro)
|
||||
- baseModelId: 基础模型 ID
|
||||
- version: 版本号
|
||||
- displayName: 显示名称
|
||||
- description: 描述信息
|
||||
- inputTokenLimit: 输入 token 上限
|
||||
- outputTokenLimit: 输出 token 上限
|
||||
- supportedGenerationMethods: 支持的生成方法
|
||||
- temperature: 默认温度参数
|
||||
- maxTemperature: 最大温度参数
|
||||
- topP: Top-P 参数
|
||||
- topK: Top-K 参数
|
||||
- nextPageToken: 下一页的 token(如果有更多结果)
|
||||
|
||||
**错误响应**
|
||||
401: API Key 无效或未提供(格式根据检测到的 API 格式返回)
|
||||
"""
|
||||
api_format, api_key = _detect_api_format_and_key(request)
|
||||
logger.info(f"[Models] GET /v1/models | format={api_format}")
|
||||
@@ -440,7 +496,50 @@ async def retrieve_model(
|
||||
db: Session = Depends(get_db),
|
||||
) -> Union[dict, JSONResponse]:
|
||||
"""
|
||||
Retrieve model - 根据请求头认证方式返回对应格式
|
||||
获取单个模型详情(统一端点)
|
||||
|
||||
根据请求头中的认证方式自动检测 API 格式,并返回相应格式的模型详情。
|
||||
此接口兼容 Claude、OpenAI 和 Gemini 三种 API 格式。
|
||||
|
||||
**格式检测规则**
|
||||
- x-api-key + anthropic-version → Claude 格式
|
||||
- x-goog-api-key 或 ?key= → Gemini 格式
|
||||
- Authorization: Bearer → OpenAI 格式(默认)
|
||||
|
||||
**路径参数**
|
||||
- model_id: 模型标识符(Gemini 格式支持 models/ 前缀,会自动移除)
|
||||
|
||||
**返回字段**
|
||||
|
||||
Claude 格式:
|
||||
- id: 模型标识符
|
||||
- type: "model"
|
||||
- display_name: 显示名称
|
||||
- created_at: 创建时间(ISO 8601 格式)
|
||||
|
||||
OpenAI 格式:
|
||||
- id: 模型标识符
|
||||
- object: "model"
|
||||
- created: Unix 时间戳
|
||||
- owned_by: 提供商名称
|
||||
|
||||
Gemini 格式:
|
||||
- name: 模型资源名称(如 models/gemini-pro)
|
||||
- baseModelId: 基础模型 ID
|
||||
- version: 版本号
|
||||
- displayName: 显示名称
|
||||
- description: 描述信息
|
||||
- inputTokenLimit: 输入 token 上限
|
||||
- outputTokenLimit: 输出 token 上限
|
||||
- supportedGenerationMethods: 支持的生成方法
|
||||
- temperature: 默认温度参数
|
||||
- maxTemperature: 最大温度参数
|
||||
- topP: Top-P 参数
|
||||
- topK: Top-K 参数
|
||||
|
||||
**错误响应**
|
||||
401: API Key 无效或未提供
|
||||
404: 模型不存在或不可访问
|
||||
"""
|
||||
api_format, api_key = _detect_api_format_and_key(request)
|
||||
|
||||
@@ -486,7 +585,35 @@ async def list_models_gemini(
|
||||
page_token: Optional[str] = Query(None, alias="pageToken"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Union[dict, JSONResponse]:
|
||||
"""List models (Gemini v1beta 端点)"""
|
||||
"""
|
||||
列出可用模型(Gemini v1beta 专用端点)
|
||||
|
||||
Gemini API 的专用模型列表端点,使用 x-goog-api-key 或 ?key= 参数进行认证。
|
||||
返回 Gemini 格式的模型列表。
|
||||
|
||||
**查询参数**
|
||||
- pageSize: 每页数量,默认 50,范围 1-1000
|
||||
- pageToken: 分页 token,用于获取下一页
|
||||
|
||||
**返回字段**
|
||||
- models: 模型列表,每个模型包含:
|
||||
- name: 模型资源名称(如 models/gemini-pro)
|
||||
- baseModelId: 基础模型 ID
|
||||
- version: 版本号
|
||||
- displayName: 显示名称
|
||||
- description: 描述信息
|
||||
- inputTokenLimit: 输入 token 上限
|
||||
- outputTokenLimit: 输出 token 上限
|
||||
- supportedGenerationMethods: 支持的生成方法列表
|
||||
- temperature: 默认温度参数
|
||||
- maxTemperature: 最大温度参数
|
||||
- topP: Top-P 参数
|
||||
- topK: Top-K 参数
|
||||
- nextPageToken: 下一页的 token(如果有更多结果)
|
||||
|
||||
**错误响应**
|
||||
401: API Key 无效或未提供
|
||||
"""
|
||||
logger.info("[Models] GET /v1beta/models | format=gemini")
|
||||
|
||||
# 从 x-goog-api-key 或 ?key= 提取 API Key
|
||||
@@ -525,7 +652,33 @@ async def get_model_gemini(
|
||||
model_name: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> Union[dict, JSONResponse]:
|
||||
"""Get model (Gemini v1beta 端点)"""
|
||||
"""
|
||||
获取单个模型详情(Gemini v1beta 专用端点)
|
||||
|
||||
Gemini API 的专用模型详情端点,使用 x-goog-api-key 或 ?key= 参数进行认证。
|
||||
返回 Gemini 格式的模型详情。
|
||||
|
||||
**路径参数**
|
||||
- model_name: 模型名称或资源路径(支持 models/ 前缀,会自动移除)
|
||||
|
||||
**返回字段**
|
||||
- name: 模型资源名称(如 models/gemini-pro)
|
||||
- baseModelId: 基础模型 ID
|
||||
- version: 版本号
|
||||
- displayName: 显示名称
|
||||
- description: 描述信息
|
||||
- inputTokenLimit: 输入 token 上限
|
||||
- outputTokenLimit: 输出 token 上限
|
||||
- supportedGenerationMethods: 支持的生成方法列表
|
||||
- temperature: 默认温度参数
|
||||
- maxTemperature: 最大温度参数
|
||||
- topP: Top-P 参数
|
||||
- topK: Top-K 参数
|
||||
|
||||
**错误响应**
|
||||
401: API Key 无效或未提供
|
||||
404: 模型不存在或不可访问
|
||||
"""
|
||||
# 移除 "models/" 前缀(如果有)
|
||||
model_id = model_name[7:] if model_name.startswith("models/") else model_name
|
||||
logger.info(f"[Models] GET /v1beta/models/{model_id} | format=gemini")
|
||||
|
||||
@@ -27,6 +27,24 @@ async def create_chat_completion(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
OpenAI Chat Completions API
|
||||
|
||||
兼容 OpenAI Chat Completions API 格式的代理接口。
|
||||
|
||||
**认证方式**: Bearer Token(API Key 或 JWT Token)
|
||||
|
||||
**请求格式**:
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**支持的参数**: model, messages, stream, temperature, max_tokens 等标准 OpenAI 参数
|
||||
"""
|
||||
adapter = OpenAIChatAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
@@ -42,6 +60,13 @@ async def create_responses(
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
OpenAI Responses API (CLI)
|
||||
|
||||
兼容 OpenAI Codex CLI 使用的 Responses API 格式,请求透传到上游。
|
||||
|
||||
**认证方式**: Bearer Token(API Key 或 JWT Token)
|
||||
"""
|
||||
adapter = OpenAICliAdapter()
|
||||
return await pipeline.run(
|
||||
adapter=adapter,
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .management_tokens import router as management_tokens_router
|
||||
from .routes import router as me_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(me_router)
|
||||
router.include_router(management_tokens_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
577
src/api/user_me/management_tokens.py
Normal file
577
src/api/user_me/management_tokens.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""用户 Management Token 管理端点"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType
|
||||
from src.services.management_token import (
|
||||
ManagementTokenService,
|
||||
parse_expires_at,
|
||||
token_to_dict,
|
||||
validate_ip_list,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/me/management-tokens", tags=["Management Tokens"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ============== 安全基类 ==============
|
||||
|
||||
|
||||
class ManagementTokenApiAdapter(AuthenticatedApiAdapter):
|
||||
"""Management Token 管理 API 的基类
|
||||
|
||||
安全限制:禁止使用 Management Token 调用这些接口,
|
||||
防止用户通过已有的 Token 再创建/修改/删除其他 Token。
|
||||
"""
|
||||
|
||||
def authorize(self, context: ApiRequestContext):
|
||||
# 先调用父类的认证检查
|
||||
super().authorize(context)
|
||||
|
||||
# 禁止使用 Management Token 调用 management-tokens 相关接口
|
||||
if context.management_token is not None:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="不允许使用 Management Token 管理其他 Token,请使用 Web 界面或 JWT 认证",
|
||||
)
|
||||
|
||||
|
||||
# ============== 请求/响应模型 ==============
|
||||
|
||||
|
||||
class CreateManagementTokenRequest(BaseModel):
|
||||
"""创建 Management Token 请求"""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Token 名称")
|
||||
description: Optional[str] = Field(None, max_length=500, description="描述")
|
||||
allowed_ips: Optional[list[str]] = Field(None, description="IP 白名单")
|
||||
expires_at: Optional[datetime] = Field(None, description="过期时间")
|
||||
|
||||
@field_validator("allowed_ips")
|
||||
@classmethod
|
||||
def validate_allowed_ips(cls, v: Optional[list[str]]) -> Optional[list[str]]:
|
||||
return validate_ip_list(v)
|
||||
|
||||
@field_validator("expires_at", mode="before")
|
||||
@classmethod
|
||||
def parse_expires(cls, v):
|
||||
return parse_expires_at(v)
|
||||
|
||||
|
||||
class UpdateManagementTokenRequest(BaseModel):
|
||||
"""更新 Management Token 请求
|
||||
|
||||
对于 allowed_ips 和 expires_at 字段:
|
||||
- 未提供(字段不在请求中): 不修改
|
||||
- 显式设为 null: 清空该字段
|
||||
- 提供有效值: 更新为新值
|
||||
"""
|
||||
|
||||
model_config = {"extra": "allow"} # 允许额外字段以便检测哪些字段被显式提供
|
||||
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
allowed_ips: Optional[list[str]] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
# 用于追踪哪些字段被显式提供(包括显式设为 null 的情况)
|
||||
_provided_fields: set[str] = set()
|
||||
|
||||
def __init__(self, **data):
|
||||
# 记录实际传入的字段(包括值为 None 的)
|
||||
provided = set(data.keys())
|
||||
super().__init__(**data)
|
||||
object.__setattr__(self, "_provided_fields", provided)
|
||||
|
||||
def is_field_provided(self, field_name: str) -> bool:
|
||||
"""检查字段是否被显式提供(区分未提供和显式设为 null)"""
|
||||
return field_name in self._provided_fields
|
||||
|
||||
@field_validator("allowed_ips")
|
||||
@classmethod
|
||||
def validate_allowed_ips(cls, v: Optional[list[str]]) -> Optional[list[str]]:
|
||||
# 如果是 None,表示要清空,直接返回
|
||||
if v is None:
|
||||
return None
|
||||
return validate_ip_list(v)
|
||||
|
||||
@field_validator("expires_at", mode="before")
|
||||
@classmethod
|
||||
def parse_expires(cls, v):
|
||||
# 如果是 None 或空字符串,表示要清空
|
||||
if v is None or (isinstance(v, str) and not v.strip()):
|
||||
return None
|
||||
return parse_expires_at(v)
|
||||
|
||||
|
||||
# ============== 路由 ==============
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_my_management_tokens(
|
||||
request: Request,
|
||||
is_active: Optional[bool] = Query(None, description="筛选激活状态"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出当前用户的 Management Tokens
|
||||
|
||||
获取当前登录用户创建的所有 Management Tokens,支持按激活状态筛选和分页。
|
||||
|
||||
**查询参数**
|
||||
- is_active (Optional[bool]): 筛选激活状态(true/false),不传则返回全部
|
||||
- skip (int): 分页偏移量,默认 0
|
||||
- limit (int): 每页数量,范围 1-100,默认 50
|
||||
|
||||
**返回字段**
|
||||
- items (List[dict]): Token 列表
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值(不返回明文)
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间(ISO 8601 格式)
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
- total (int): 总数量
|
||||
- skip (int): 当前偏移量
|
||||
- limit (int): 当前每页数量
|
||||
- quota (dict): 配额信息
|
||||
- used (int): 已使用数量
|
||||
- max (int): 最大允许数量
|
||||
"""
|
||||
adapter = ListMyManagementTokensAdapter(is_active=is_active, skip=skip, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_my_management_token(request: Request, db: Session = Depends(get_db)):
|
||||
"""创建 Management Token
|
||||
|
||||
为当前用户创建一个新的 Management Token。
|
||||
|
||||
**请求体字段**
|
||||
- name (str): Token 名称,必填,长度 1-100
|
||||
- description (Optional[str]): 描述,可选,最大长度 500
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单,可选,支持 IPv4/IPv6 和 CIDR 格式
|
||||
- expires_at (Optional[datetime]): 过期时间,可选,支持 ISO 8601 格式字符串或 datetime 对象
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息
|
||||
- token (str): 生成的 Token 明文(仅在创建时返回一次,请妥善保存)
|
||||
- data (dict): Token 信息
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值
|
||||
- is_active (bool): 是否激活(新创建默认为 true)
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = CreateMyManagementTokenAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{token_id}")
|
||||
async def get_my_management_token(
|
||||
token_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 Management Token 详情
|
||||
|
||||
获取当前用户指定 Token 的详细信息。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**返回字段**
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值(不返回明文)
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间(ISO 8601 格式)
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = GetMyManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{token_id}")
|
||||
async def update_my_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新 Management Token
|
||||
|
||||
更新当前用户指定 Token 的信息。支持部分字段更新。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**请求体字段**(所有字段均可选)
|
||||
- name (Optional[str]): Token 名称,长度 1-100
|
||||
- description (Optional[str]): 描述,最大长度 500,传空字符串或 null 可清空
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单,传 null 可清空
|
||||
- expires_at (Optional[datetime]): 过期时间,传 null 可清空
|
||||
|
||||
注意:未提供的字段不会被修改,显式传 null 表示清空该字段。
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息
|
||||
- data (dict): 更新后的 Token 信息
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = UpdateMyManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{token_id}")
|
||||
async def delete_my_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除 Management Token
|
||||
|
||||
删除当前用户指定的 Token。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): 要删除的 Token ID
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息
|
||||
"""
|
||||
adapter = DeleteMyManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{token_id}/status")
|
||||
async def toggle_my_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""切换 Management Token 状态
|
||||
|
||||
启用或禁用当前用户指定的 Token。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息("Token 已启用" 或 "Token 已禁用")
|
||||
- data (dict): 更新后的 Token 信息
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): Token 哈希值
|
||||
- is_active (bool): 是否激活(已切换后的状态)
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间
|
||||
- last_used_at (Optional[str]): 最后使用时间
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = ToggleMyManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{token_id}/regenerate")
|
||||
async def regenerate_my_management_token(
|
||||
token_id: str, request: Request, db: Session = Depends(get_db)
|
||||
):
|
||||
"""重新生成 Management Token
|
||||
|
||||
重新生成当前用户指定 Token 的值,旧 Token 将立即失效。
|
||||
|
||||
**路径参数**
|
||||
- token_id (str): Token ID
|
||||
|
||||
**返回字段**
|
||||
- message (str): 操作结果消息
|
||||
- token (str): 新生成的 Token 明文(仅在重新生成时返回一次,请妥善保存)
|
||||
- data (dict): Token 信息
|
||||
- id (str): Token ID
|
||||
- user_id (str): 所属用户 ID
|
||||
- name (str): Token 名称
|
||||
- description (Optional[str]): 描述
|
||||
- token_hash (str): 新的 Token 哈希值
|
||||
- is_active (bool): 是否激活
|
||||
- allowed_ips (Optional[List[str]]): IP 白名单
|
||||
- expires_at (Optional[str]): 过期时间
|
||||
- last_used_at (Optional[str]): 最后使用时间(重置为 null)
|
||||
- created_at (str): 创建时间
|
||||
- updated_at (str): 更新时间
|
||||
"""
|
||||
adapter = RegenerateMyManagementTokenAdapter(token_id=token_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ============== 适配器 ==============
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListMyManagementTokensAdapter(ManagementTokenApiAdapter):
|
||||
"""列出用户的 Management Tokens"""
|
||||
|
||||
name: str = "list_my_management_tokens"
|
||||
is_active: Optional[bool] = None
|
||||
skip: int = 0
|
||||
limit: int = 50
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
from src.config.settings import config
|
||||
|
||||
tokens, total = ManagementTokenService.list_tokens(
|
||||
db=context.db,
|
||||
user_id=context.user.id,
|
||||
is_active=self.is_active,
|
||||
skip=self.skip,
|
||||
limit=self.limit,
|
||||
)
|
||||
|
||||
# 获取用户 Token 总数(用于配额显示)
|
||||
max_tokens = config.management_token_max_per_user
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"items": [token_to_dict(t) for t in tokens],
|
||||
"total": total,
|
||||
"skip": self.skip,
|
||||
"limit": self.limit,
|
||||
"quota": {
|
||||
"used": total,
|
||||
"max": max_tokens,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""创建 Management Token"""
|
||||
|
||||
name: str = "create_my_management_token"
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_CREATED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
body = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
req = CreateManagementTokenRequest(**body)
|
||||
except Exception as e:
|
||||
raise InvalidRequestException(str(e))
|
||||
|
||||
try:
|
||||
token, raw_token = ManagementTokenService.create_token(
|
||||
db=context.db,
|
||||
user_id=context.user.id,
|
||||
name=req.name,
|
||||
description=req.description,
|
||||
allowed_ips=req.allowed_ips,
|
||||
expires_at=req.expires_at,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise InvalidRequestException(str(e))
|
||||
|
||||
context.add_audit_metadata(token_id=token.id, token_name=token.name)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content={
|
||||
"message": "Management Token 创建成功",
|
||||
"token": raw_token, # 仅在创建时返回一次
|
||||
"data": token_to_dict(token),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""获取 Management Token 详情"""
|
||||
|
||||
name: str = "get_my_management_token"
|
||||
token_id: str = ""
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
token = ManagementTokenService.get_token_by_id(
|
||||
db=context.db, token_id=self.token_id, user_id=context.user.id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
return JSONResponse(content=token_to_dict(token))
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""更新 Management Token"""
|
||||
|
||||
name: str = "update_my_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_UPDATED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
body = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
req = UpdateManagementTokenRequest(**body)
|
||||
except Exception as e:
|
||||
raise InvalidRequestException(str(e))
|
||||
|
||||
# 构建更新参数,只包含显式提供的字段
|
||||
update_kwargs: dict = {
|
||||
"db": context.db,
|
||||
"token_id": self.token_id,
|
||||
"user_id": context.user.id,
|
||||
}
|
||||
|
||||
# 对于普通字段,只有提供了才更新
|
||||
if req.is_field_provided("name"):
|
||||
update_kwargs["name"] = req.name
|
||||
if req.is_field_provided("description"):
|
||||
update_kwargs["description"] = req.description
|
||||
update_kwargs["clear_description"] = req.description is None or req.description == ""
|
||||
|
||||
# 对于可清空字段,需要传递特殊标记
|
||||
if req.is_field_provided("allowed_ips"):
|
||||
update_kwargs["allowed_ips"] = req.allowed_ips
|
||||
update_kwargs["clear_allowed_ips"] = req.allowed_ips is None
|
||||
if req.is_field_provided("expires_at"):
|
||||
update_kwargs["expires_at"] = req.expires_at
|
||||
update_kwargs["clear_expires_at"] = req.expires_at is None
|
||||
|
||||
try:
|
||||
token = ManagementTokenService.update_token(**update_kwargs)
|
||||
except ValueError as e:
|
||||
raise InvalidRequestException(str(e))
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
context.add_audit_metadata(token_id=token.id, token_name=token.name)
|
||||
|
||||
return JSONResponse(
|
||||
content={"message": "更新成功", "data": token_to_dict(token)}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""删除 Management Token"""
|
||||
|
||||
name: str = "delete_my_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_DELETED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
# 先获取 token 信息用于审计
|
||||
token = ManagementTokenService.get_token_by_id(
|
||||
db=context.db, token_id=self.token_id, user_id=context.user.id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
context.add_audit_metadata(token_id=token.id, token_name=token.name)
|
||||
|
||||
success = ManagementTokenService.delete_token(
|
||||
db=context.db, token_id=self.token_id, user_id=context.user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
return JSONResponse(content={"message": "删除成功"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToggleMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""切换 Management Token 状态"""
|
||||
|
||||
name: str = "toggle_my_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_UPDATED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
token = ManagementTokenService.toggle_status(
|
||||
db=context.db, token_id=self.token_id, user_id=context.user.id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
context.add_audit_metadata(
|
||||
token_id=token.id, token_name=token.name, is_active=token.is_active
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": f"Token 已{'启用' if token.is_active else '禁用'}",
|
||||
"data": token_to_dict(token),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegenerateMyManagementTokenAdapter(ManagementTokenApiAdapter):
|
||||
"""重新生成 Management Token"""
|
||||
|
||||
name: str = "regenerate_my_management_token"
|
||||
token_id: str = ""
|
||||
audit_success_event = AuditEventType.MANAGEMENT_TOKEN_UPDATED
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
token, raw_token, old_token_hash = ManagementTokenService.regenerate_token(
|
||||
db=context.db, token_id=self.token_id, user_id=context.user.id
|
||||
)
|
||||
|
||||
if not token:
|
||||
raise NotFoundException("Management Token 不存在")
|
||||
|
||||
context.add_audit_metadata(
|
||||
token_id=token.id,
|
||||
token_name=token.name,
|
||||
regenerated=True,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
"message": "Token 已重新生成",
|
||||
"token": raw_token, # 仅在重新生成时返回一次
|
||||
"data": token_to_dict(token),
|
||||
}
|
||||
)
|
||||
@@ -35,20 +35,43 @@ pipeline = ApiRequestPipeline()
|
||||
|
||||
@router.get("")
|
||||
async def get_my_profile(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取当前用户完整信息(包含偏好设置)"""
|
||||
"""
|
||||
获取当前用户信息
|
||||
|
||||
返回当前登录用户的完整信息,包括基本信息和偏好设置。
|
||||
|
||||
**返回字段**: id, email, username, role, is_active, quota_usd, used_usd, preferences 等
|
||||
"""
|
||||
adapter = MeProfileAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_my_profile(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
更新个人信息
|
||||
|
||||
更新当前用户的邮箱或用户名。
|
||||
|
||||
**请求体**:
|
||||
- `email`: 新邮箱地址(可选)
|
||||
- `username`: 新用户名(可选)
|
||||
"""
|
||||
adapter = UpdateProfileAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/password")
|
||||
async def change_my_password(request: Request, db: Session = Depends(get_db)):
|
||||
"""Change current user's password"""
|
||||
"""
|
||||
修改密码
|
||||
|
||||
修改当前用户的登录密码。
|
||||
|
||||
**请求体**:
|
||||
- `old_password`: 当前密码
|
||||
- `new_password`: 新密码(至少 6 位)
|
||||
"""
|
||||
adapter = ChangePasswordAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -58,12 +81,30 @@ async def change_my_password(request: Request, db: Session = Depends(get_db)):
|
||||
|
||||
@router.get("/api-keys")
|
||||
async def list_my_api_keys(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取 API 密钥列表
|
||||
|
||||
返回当前用户的所有 API 密钥,包含使用统计信息。
|
||||
密钥值仅显示前后几位,完整密钥需通过详情接口获取。
|
||||
|
||||
**返回字段**: id, name, key_display, is_active, total_requests, total_cost_usd, last_used_at 等
|
||||
"""
|
||||
adapter = ListMyApiKeysAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/api-keys")
|
||||
async def create_my_api_key(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
创建 API 密钥
|
||||
|
||||
为当前用户创建新的 API 密钥。创建成功后会返回完整的密钥值,请妥善保存。
|
||||
|
||||
**请求体**:
|
||||
- `name`: 密钥名称
|
||||
|
||||
**返回**: 包含完整密钥值的响应(仅此一次显示完整密钥)
|
||||
"""
|
||||
adapter = CreateMyApiKeyAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -72,10 +113,20 @@ async def create_my_api_key(request: Request, db: Session = Depends(get_db)):
|
||||
async def get_my_api_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
include_key: bool = Query(False, description="Include full decrypted key in response"),
|
||||
include_key: bool = Query(False, description="是否返回完整密钥"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get API key detail, optionally include full key"""
|
||||
"""
|
||||
获取 API 密钥详情
|
||||
|
||||
获取指定 API 密钥的详细信息。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: 密钥 ID
|
||||
|
||||
**查询参数**:
|
||||
- `include_key`: 设为 true 时返回完整解密后的密钥值
|
||||
"""
|
||||
if include_key:
|
||||
adapter = GetMyFullKeyAdapter(key_id=key_id)
|
||||
else:
|
||||
@@ -85,13 +136,28 @@ async def get_my_api_key(
|
||||
|
||||
@router.delete("/api-keys/{key_id}")
|
||||
async def delete_my_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
删除 API 密钥
|
||||
|
||||
永久删除指定的 API 密钥,删除后无法恢复。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: 密钥 ID
|
||||
"""
|
||||
adapter = DeleteMyApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/api-keys/{key_id}")
|
||||
async def toggle_my_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
"""Toggle API key active status"""
|
||||
"""
|
||||
切换 API 密钥状态
|
||||
|
||||
启用或禁用指定的 API 密钥。禁用后该密钥将无法用于 API 调用。
|
||||
|
||||
**路径参数**:
|
||||
- `key_id`: 密钥 ID
|
||||
"""
|
||||
adapter = ToggleMyApiKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -102,23 +168,47 @@ async def toggle_my_api_key(key_id: str, request: Request, db: Session = Depends
|
||||
@router.get("/usage")
|
||||
async def get_my_usage(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
start_date: Optional[datetime] = Query(None, description="开始时间(ISO 格式)"),
|
||||
end_date: Optional[datetime] = Query(None, description="结束时间(ISO 格式)"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词(密钥名、模型名)"),
|
||||
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
|
||||
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date, limit=limit, offset=offset)
|
||||
"""
|
||||
获取使用统计
|
||||
|
||||
获取当前用户的 API 使用统计数据,包括总量汇总、按模型/提供商分组统计及详细记录。
|
||||
|
||||
**返回字段**:
|
||||
- `total_requests`: 总请求数
|
||||
- `total_tokens`: 总 Token 数
|
||||
- `total_cost`: 总成本(USD)
|
||||
- `summary_by_model`: 按模型分组统计
|
||||
- `summary_by_provider`: 按提供商分组统计
|
||||
- `records`: 详细使用记录列表
|
||||
- `pagination`: 分页信息
|
||||
"""
|
||||
adapter = GetUsageAdapter(
|
||||
start_date=start_date, end_date=end_date, search=search, limit=limit, offset=offset
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/usage/active")
|
||||
async def get_my_active_requests(
|
||||
request: Request,
|
||||
ids: Optional[str] = Query(None, description="Comma-separated request IDs to query"),
|
||||
ids: Optional[str] = Query(None, description="请求 ID 列表,逗号分隔"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取用户活跃请求状态(用于轮询更新)"""
|
||||
"""
|
||||
获取活跃请求状态
|
||||
|
||||
查询正在进行中的请求状态,用于前端轮询更新流式请求的进度。
|
||||
|
||||
**查询参数**:
|
||||
- `ids`: 要查询的请求 ID 列表,逗号分隔
|
||||
"""
|
||||
adapter = GetActiveRequestsAdapter(ids=ids)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -130,7 +220,13 @@ async def get_my_interval_timeline(
|
||||
limit: int = Query(5000, ge=100, le=20000, description="最大返回数据点数量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取当前用户的请求间隔时间线数据,用于散点图展示"""
|
||||
"""
|
||||
获取请求间隔时间线
|
||||
|
||||
获取请求间隔时间线数据,用于散点图展示请求分布情况。
|
||||
|
||||
**返回**: 包含时间戳和间隔时间的数据点列表
|
||||
"""
|
||||
adapter = GetMyIntervalTimelineAdapter(hours=hours, limit=limit)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -141,9 +237,12 @@ async def get_my_activity_heatmap(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get user's activity heatmap data for the past 365 days.
|
||||
获取活动热力图数据
|
||||
|
||||
This endpoint is cached for 5 minutes to reduce database load.
|
||||
获取过去 365 天的活动热力图数据,用于展示每日使用频率。
|
||||
此接口有 5 分钟缓存。
|
||||
|
||||
**返回**: 包含日期和请求数量的数据列表
|
||||
"""
|
||||
adapter = GetMyActivityHeatmapAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -151,13 +250,26 @@ async def get_my_activity_heatmap(
|
||||
|
||||
@router.get("/providers")
|
||||
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取可用提供商列表
|
||||
|
||||
获取当前用户可用的所有提供商及其模型信息。
|
||||
|
||||
**返回字段**: id, name, display_name, endpoints, models 等
|
||||
"""
|
||||
adapter = ListAvailableProvidersAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/endpoint-status")
|
||||
async def get_endpoint_status(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取端点状态(简化版,不包含敏感信息)"""
|
||||
"""
|
||||
获取端点健康状态
|
||||
|
||||
获取各 API 格式端点的健康状态(简化版,不包含敏感信息)。
|
||||
|
||||
**返回**: 按 API 格式分组的端点健康状态
|
||||
"""
|
||||
adapter = GetEndpointStatusAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -174,6 +286,17 @@ async def update_api_key_providers(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
更新 API 密钥可用提供商
|
||||
|
||||
设置指定 API 密钥可以使用哪些提供商。未设置时使用用户默认权限。
|
||||
|
||||
**路径参数**:
|
||||
- `api_key_id`: API 密钥 ID
|
||||
|
||||
**请求体**:
|
||||
- `allowed_providers`: 允许的提供商 ID 列表
|
||||
"""
|
||||
adapter = UpdateApiKeyProvidersAdapter(api_key_id=api_key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -184,7 +307,17 @@ async def update_api_key_capabilities(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新 API Key 的强制能力配置"""
|
||||
"""
|
||||
更新 API 密钥能力配置
|
||||
|
||||
设置指定 API 密钥的强制能力配置(如是否启用代码执行等)。
|
||||
|
||||
**路径参数**:
|
||||
- `api_key_id`: API 密钥 ID
|
||||
|
||||
**请求体**:
|
||||
- `force_capabilities`: 能力配置字典,如 `{"code_execution": true}`
|
||||
"""
|
||||
adapter = UpdateApiKeyCapabilitiesAdapter(api_key_id=api_key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -194,26 +327,59 @@ async def update_api_key_capabilities(
|
||||
|
||||
@router.get("/preferences")
|
||||
async def get_my_preferences(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
获取偏好设置
|
||||
|
||||
获取当前用户的偏好设置,包括主题、语言、通知配置等。
|
||||
|
||||
**返回字段**: avatar_url, bio, theme, language, timezone, notifications 等
|
||||
"""
|
||||
adapter = GetPreferencesAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/preferences")
|
||||
async def update_my_preferences(request: Request, db: Session = Depends(get_db)):
|
||||
"""
|
||||
更新偏好设置
|
||||
|
||||
更新当前用户的偏好设置。
|
||||
|
||||
**请求体**:
|
||||
- `theme`: 主题(light/dark)
|
||||
- `language`: 语言
|
||||
- `timezone`: 时区
|
||||
- `email_notifications`: 邮件通知开关
|
||||
- `usage_alerts`: 用量告警开关
|
||||
- 等
|
||||
"""
|
||||
adapter = UpdatePreferencesAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/model-capabilities")
|
||||
async def get_model_capability_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""获取用户的模型能力配置"""
|
||||
"""
|
||||
获取模型能力配置
|
||||
|
||||
获取用户针对各模型的能力配置(如是否启用特定功能)。
|
||||
|
||||
**返回**: model_capability_settings 字典
|
||||
"""
|
||||
adapter = GetModelCapabilitySettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/model-capabilities")
|
||||
async def update_model_capability_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""更新用户的模型能力配置"""
|
||||
"""
|
||||
更新模型能力配置
|
||||
|
||||
更新用户针对各模型的能力配置。
|
||||
|
||||
**请求体**:
|
||||
- `model_capability_settings`: 模型能力配置字典,格式为 `{"model_name": {"capability": true}}`
|
||||
"""
|
||||
adapter = UpdateModelCapabilitySettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
@@ -222,11 +388,15 @@ async def update_model_capability_settings(request: Request, db: Session = Depen
|
||||
|
||||
|
||||
class MeProfileAdapter(AuthenticatedApiAdapter):
|
||||
"""获取当前用户信息的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
return PreferenceService.get_user_with_preferences(context.db, context.user.id)
|
||||
|
||||
|
||||
class UpdateProfileAdapter(AuthenticatedApiAdapter):
|
||||
"""更新用户个人信息的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
@@ -262,6 +432,8 @@ class UpdateProfileAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class ChangePasswordAdapter(AuthenticatedApiAdapter):
|
||||
"""修改用户密码的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
@@ -287,6 +459,8 @@ class ChangePasswordAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class ListMyApiKeysAdapter(AuthenticatedApiAdapter):
|
||||
"""获取用户 API 密钥列表的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
user = context.user
|
||||
@@ -356,6 +530,8 @@ class ListMyApiKeysAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class CreateMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
"""创建 API 密钥的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
@@ -385,6 +561,8 @@ class CreateMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class GetMyFullKeyAdapter(AuthenticatedApiAdapter):
|
||||
"""获取 API 密钥完整密钥值的适配器"""
|
||||
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -417,7 +595,8 @@ class GetMyFullKeyAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class GetMyApiKeyDetailAdapter(AuthenticatedApiAdapter):
|
||||
"""Get API key detail without full key"""
|
||||
"""获取 API 密钥详情的适配器(不包含完整密钥值)"""
|
||||
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -446,6 +625,8 @@ class GetMyApiKeyDetailAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class DeleteMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
"""删除 API 密钥的适配器"""
|
||||
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -463,6 +644,8 @@ class DeleteMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
"""切换 API 密钥启用/禁用状态的适配器"""
|
||||
|
||||
key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -485,12 +668,19 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
"""获取用户使用统计的适配器"""
|
||||
|
||||
start_date: Optional[datetime]
|
||||
end_date: Optional[datetime]
|
||||
search: Optional[str] = None
|
||||
limit: int = 100
|
||||
offset: int = 0
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import or_
|
||||
|
||||
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||
|
||||
db = context.db
|
||||
user = context.user
|
||||
summary_list = UsageService.get_usage_summary(
|
||||
@@ -595,12 +785,30 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
})
|
||||
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
|
||||
|
||||
query = db.query(Usage).filter(Usage.user_id == user.id)
|
||||
query = (
|
||||
db.query(Usage, ApiKey)
|
||||
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||
.filter(Usage.user_id == user.id)
|
||||
)
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
|
||||
# 通用搜索:密钥名、模型名
|
||||
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
if self.search and self.search.strip():
|
||||
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||
for keyword in keywords:
|
||||
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||
search_pattern = f"%{escaped}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||
Usage.model.ilike(search_pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
# 计算总数用于分页
|
||||
total_records = query.count()
|
||||
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||
@@ -639,7 +847,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
"records": [
|
||||
{
|
||||
"id": r.id,
|
||||
"provider": r.provider,
|
||||
"provider": r.provider_name,
|
||||
"model": r.model,
|
||||
"target_model": r.target_model, # 映射后的目标模型名
|
||||
"api_format": r.api_format,
|
||||
@@ -659,8 +867,17 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
"output_price_per_1m": r.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": r.cache_read_price_per_1m,
|
||||
"api_key": (
|
||||
{
|
||||
"id": str(api_key.id),
|
||||
"name": api_key.name,
|
||||
"display": api_key.get_display_key(),
|
||||
}
|
||||
if api_key
|
||||
else None
|
||||
),
|
||||
}
|
||||
for r in usage_records
|
||||
for r, api_key in usage_records
|
||||
],
|
||||
}
|
||||
|
||||
@@ -668,7 +885,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
if user.role == "admin":
|
||||
response_data["total_actual_cost"] = total_actual_cost
|
||||
# 为每条记录添加真实成本和倍率信息
|
||||
for i, r in enumerate(usage_records):
|
||||
for i, (r, _) in enumerate(usage_records):
|
||||
# 确保字段有值,避免前端显示 -
|
||||
actual_cost = (
|
||||
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
|
||||
@@ -731,7 +948,7 @@ class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
|
||||
"""Activity heatmap adapter with Redis caching for user."""
|
||||
"""获取用户活动热力图数据的适配器(带 Redis 缓存)"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
user = context.user
|
||||
@@ -745,6 +962,8 @@ class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
"""获取可用提供商列表的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@@ -816,6 +1035,8 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
@dataclass
|
||||
class UpdateApiKeyProvidersAdapter(AuthenticatedApiAdapter):
|
||||
"""更新 API 密钥可用提供商的适配器"""
|
||||
|
||||
api_key_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -927,6 +1148,8 @@ class UpdateApiKeyCapabilitiesAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class GetPreferencesAdapter(AuthenticatedApiAdapter):
|
||||
"""获取用户偏好设置的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
preferences = PreferenceService.get_or_create_preferences(context.db, context.user.id)
|
||||
return {
|
||||
@@ -948,6 +1171,8 @@ class GetPreferencesAdapter(AuthenticatedApiAdapter):
|
||||
|
||||
|
||||
class UpdatePreferencesAdapter(AuthenticatedApiAdapter):
|
||||
"""更新用户偏好设置的适配器"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
payload = context.ensure_json_body()
|
||||
try:
|
||||
|
||||
@@ -106,13 +106,6 @@ class Config:
|
||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||
|
||||
# 可信代理配置
|
||||
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||
|
||||
# 异常处理配置
|
||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||
# 设置为 False 时,使用全局异常处理器统一处理
|
||||
@@ -161,6 +154,11 @@ class Config:
|
||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||
|
||||
# 请求体读取超时(秒)
|
||||
# REQUEST_BODY_TIMEOUT: 等待客户端发送完整请求体的超时时间
|
||||
# 默认 60 秒,防止客户端发送不完整请求导致连接卡死
|
||||
self.request_body_timeout = float(os.getenv("REQUEST_BODY_TIMEOUT", "60.0"))
|
||||
|
||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||
self.internal_user_agent_claude_cli = os.getenv(
|
||||
@@ -183,6 +181,28 @@ class Config:
|
||||
os.getenv("VERIFICATION_SEND_COOLDOWN", "60")
|
||||
)
|
||||
|
||||
# Management Token 速率限制(每分钟每 IP)
|
||||
self.management_token_rate_limit = int(
|
||||
os.getenv("MANAGEMENT_TOKEN_RATE_LIMIT", "30")
|
||||
)
|
||||
|
||||
# 每个用户最多可创建的 Management Token 数量
|
||||
self.management_token_max_per_user = int(
|
||||
os.getenv("MANAGEMENT_TOKEN_MAX_PER_USER", "20")
|
||||
)
|
||||
|
||||
# API 文档配置
|
||||
# DOCS_ENABLED: 是否启用 API 文档(/docs, /redoc, /openapi.json)
|
||||
# - 未设置: 开发环境启用,生产环境禁用
|
||||
# - true: 强制启用
|
||||
# - false: 强制禁用
|
||||
docs_enabled_env = os.getenv("DOCS_ENABLED")
|
||||
if docs_enabled_env is not None:
|
||||
self.docs_enabled = docs_enabled_env.lower() == "true"
|
||||
else:
|
||||
# 默认:开发环境启用,生产环境禁用
|
||||
self.docs_enabled = self.environment == "development"
|
||||
|
||||
# 验证连接池配置
|
||||
self._validate_pool_config()
|
||||
|
||||
|
||||
@@ -30,3 +30,10 @@ class ProviderBillingType(Enum):
|
||||
MONTHLY_QUOTA = "monthly_quota" # 月卡额度
|
||||
PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费
|
||||
FREE_TIER = "free_tier" # 免费额度
|
||||
|
||||
|
||||
class AuthSource(str, Enum):
|
||||
"""认证来源枚举"""
|
||||
|
||||
LOCAL = "local" # 本地认证
|
||||
LDAP = "ldap" # LDAP 认证
|
||||
|
||||
93
src/main.py
93
src/main.py
@@ -248,11 +248,94 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
from src import __version__ as app_version
|
||||
|
||||
# OpenAPI Tags 元数据定义
|
||||
openapi_tags = [
|
||||
{
|
||||
"name": "Authentication",
|
||||
"description": "用户认证相关接口,包括登录、注册、令牌刷新等",
|
||||
},
|
||||
{
|
||||
"name": "User Profile",
|
||||
"description": "用户个人信息管理,包括 API 密钥、使用统计、偏好设置等",
|
||||
},
|
||||
{
|
||||
"name": "Management Tokens",
|
||||
"description": "管理令牌,用于 CLI 工具等外部应用的认证",
|
||||
},
|
||||
{
|
||||
"name": "Dashboard",
|
||||
"description": "仪表盘统计数据,包括请求量、Token 用量、成本等概览信息",
|
||||
},
|
||||
{
|
||||
"name": "Announcements",
|
||||
"description": "系统公告管理",
|
||||
},
|
||||
{
|
||||
"name": "Monitoring",
|
||||
"description": "用户监控与审计日志查询",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Users",
|
||||
"description": "用户管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Providers",
|
||||
"description": "提供商管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Endpoints",
|
||||
"description": "端点管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Models",
|
||||
"description": "模型管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - API Keys",
|
||||
"description": "API 密钥管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Usage",
|
||||
"description": "使用统计管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Monitoring",
|
||||
"description": "系统监控(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - Security",
|
||||
"description": "安全配置管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Admin - System",
|
||||
"description": "系统配置管理(管理员)",
|
||||
},
|
||||
{
|
||||
"name": "Claude API",
|
||||
"description": "Claude API 代理接口,兼容 Anthropic Claude API 格式",
|
||||
},
|
||||
{
|
||||
"name": "OpenAI API",
|
||||
"description": "OpenAI API 代理接口,兼容 OpenAI Chat Completions API 格式",
|
||||
},
|
||||
{
|
||||
"name": "Gemini API",
|
||||
"description": "Gemini API 代理接口,兼容 Google Gemini API 格式",
|
||||
},
|
||||
{
|
||||
"name": "System Catalog",
|
||||
"description": "系统目录接口,用于获取可用模型列表等",
|
||||
},
|
||||
]
|
||||
|
||||
app = FastAPI(
|
||||
title="AI Proxy with Modular Architecture",
|
||||
title="Aether API Gateway",
|
||||
version=app_version,
|
||||
description="AI代理服务,采用模块化架构,支持插件化扩展",
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs" if config.docs_enabled else None,
|
||||
redoc_url="/redoc" if config.docs_enabled else None,
|
||||
openapi_url="/openapi.json" if config.docs_enabled else None,
|
||||
openapi_tags=openapi_tags
|
||||
)
|
||||
|
||||
# 注册全局异常处理器
|
||||
@@ -272,15 +355,17 @@ app.add_middleware(PluginMiddleware)
|
||||
# 生产环境必须通过 CORS_ORIGINS 环境变量显式指定允许的域名
|
||||
# 开发环境默认允许本地前端访问
|
||||
if config.cors_origins:
|
||||
# CORS_ORIGINS=* 时自动禁用 credentials(浏览器规范要求)
|
||||
allow_credentials = config.cors_allow_credentials and "*" not in config.cors_origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=config.cors_origins, # 使用配置的白名单
|
||||
allow_credentials=config.cors_allow_credentials,
|
||||
allow_credentials=allow_credentials,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
)
|
||||
logger.info(f"CORS已启用,允许的源: {config.cors_origins}")
|
||||
logger.info(f"CORS已启用,允许的源: {config.cors_origins}, credentials: {allow_credentials}")
|
||||
else:
|
||||
# 没有配置CORS源,不允许跨域
|
||||
logger.warning(
|
||||
|
||||
@@ -203,28 +203,21 @@ class PluginMiddleware:
|
||||
"""
|
||||
获取客户端 IP 地址,支持代理头
|
||||
|
||||
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||
优先级:X-Real-IP > X-Forwarded-For > 直连 IP
|
||||
X-Real-IP 由最外层 Nginx 设置,最可靠
|
||||
"""
|
||||
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||
|
||||
# 优先从代理头获取真实 IP
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For,取第一个 IP(原始客户端)
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
# 回退到直连 IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
@@ -71,7 +71,6 @@ class CreateProviderRequest(BaseModel):
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
|
||||
is_active: Optional[bool] = Field(True, description="是否启用")
|
||||
rate_limit: Optional[int] = Field(None, ge=0, description="速率限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
|
||||
@@ -174,7 +173,6 @@ class UpdateProviderRequest(BaseModel):
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
|
||||
is_active: Optional[bool] = None
|
||||
rate_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
@@ -322,7 +320,7 @@ class UpdateUserRequest(BaseModel):
|
||||
is_active: Optional[bool] = None
|
||||
role: Optional[str] = None
|
||||
allowed_providers: Optional[List[str]] = Field(None, description="允许使用的提供商 ID 列表")
|
||||
allowed_endpoints: Optional[List[str]] = Field(None, description="允许使用的端点 ID 列表")
|
||||
allowed_api_formats: Optional[List[str]] = Field(None, description="允许使用的 API 格式列表")
|
||||
allowed_models: Optional[List[str]] = Field(None, description="允许使用的模型名称列表")
|
||||
|
||||
@field_validator("username")
|
||||
|
||||
@@ -4,9 +4,9 @@ API端点请求/响应模型定义
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
@@ -15,17 +15,9 @@ from ..core.enums import UserRole
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
email: str = Field(..., min_length=1, max_length=255, description="邮箱/用户名")
|
||||
password: str = Field(..., min_length=1, max_length=128, description="密码")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型")
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
@@ -36,6 +28,24 @@ class LoginRequest(BaseModel):
|
||||
raise ValueError("密码不能为空")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_login(self):
|
||||
"""根据认证类型校验并规范化登录标识"""
|
||||
identifier = self.email.strip()
|
||||
|
||||
if not identifier:
|
||||
raise ValueError("用户名/邮箱不能为空")
|
||||
|
||||
# 本地和 LDAP 登录都支持用户名或邮箱
|
||||
# 如果是邮箱格式,转换为小写
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if re.match(email_pattern, identifier):
|
||||
self.email = identifier.lower()
|
||||
else:
|
||||
self.email = identifier
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应"""
|
||||
@@ -283,7 +293,7 @@ class UpdateUserRequest(BaseModel):
|
||||
password: Optional[str] = None
|
||||
role: Optional[UserRole] = None
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
quota_usd: Optional[float] = None
|
||||
is_active: Optional[bool] = None
|
||||
@@ -306,7 +316,6 @@ class CreateApiKeyRequest(BaseModel):
|
||||
|
||||
name: Optional[str] = None
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
rate_limit: Optional[int] = None # None = 无限制
|
||||
@@ -329,7 +338,7 @@ class UserResponse(BaseModel):
|
||||
username: str
|
||||
role: UserRole
|
||||
allowed_providers: Optional[List[str]] = None # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
quota_usd: float
|
||||
used_usd: float
|
||||
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
@@ -30,7 +31,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
from ..core.enums import AuthSource, ProviderBillingType, UserRole
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -54,10 +55,24 @@ class User(Base):
|
||||
default=UserRole.USER,
|
||||
nullable=False,
|
||||
)
|
||||
auth_source = Column(
|
||||
Enum(
|
||||
AuthSource,
|
||||
name="authsource",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AuthSource.LOCAL,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# LDAP 标识(仅 auth_source=ldap 时使用,用于在邮箱变更/用户名冲突时稳定关联本地账户)
|
||||
ldap_dn = Column(String(512), nullable=True, index=True)
|
||||
ldap_username = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# 访问限制(NULL 表示不限制,允许访问所有资源)
|
||||
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
|
||||
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
|
||||
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
|
||||
|
||||
# Key 能力配置
|
||||
@@ -87,6 +102,9 @@ class User(Base):
|
||||
|
||||
# 关系 - CASCADE delete: 让数据库处理级联删除
|
||||
api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan")
|
||||
management_tokens = relationship(
|
||||
"ManagementToken", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
preferences = relationship(
|
||||
"UserPreference", back_populates="user", cascade="all, delete-orphan", passive_deletes=True
|
||||
)
|
||||
@@ -147,7 +165,6 @@ class ApiKey(Base):
|
||||
|
||||
# 访问限制(NULL 表示不限制,允许访问所有资源)
|
||||
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
|
||||
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
|
||||
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
|
||||
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
|
||||
rate_limit = Column(Integer, default=None, nullable=True) # 每分钟请求限制,None = 无限制
|
||||
@@ -254,7 +271,7 @@ class Usage(Base):
|
||||
|
||||
# 请求信息
|
||||
request_id = Column(String(100), unique=True, index=True, nullable=False)
|
||||
provider = Column(String(100), nullable=False)
|
||||
provider_name = Column(String(100), nullable=False) # Provider 名称(非外键)
|
||||
model = Column(String(100), nullable=False)
|
||||
target_model = Column(String(100), nullable=True, comment="映射后的目标模型名(若无映射则为空)")
|
||||
|
||||
@@ -428,6 +445,68 @@ class SystemConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class LDAPConfig(Base):
|
||||
"""LDAP认证配置表 - 单行配置"""
|
||||
|
||||
__tablename__ = "ldap_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636
|
||||
bind_dn = Column(String(255), nullable=False) # 绑定账号 DN
|
||||
bind_password_encrypted = Column(Text, nullable=True) # 加密的绑定密码(允许 NULL 表示已清除)
|
||||
base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN
|
||||
user_search_filter = Column(
|
||||
String(500), default="(uid={username})", nullable=False
|
||||
) # 用户搜索过滤器
|
||||
username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName)
|
||||
email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性
|
||||
display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性
|
||||
is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证
|
||||
is_exclusive = Column(
|
||||
Boolean, default=False, nullable=False
|
||||
) # 是否仅允许 LDAP 登录(禁用本地认证)
|
||||
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
|
||||
connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def set_bind_password(self, password: str) -> None:
|
||||
"""
|
||||
设置并加密绑定密码
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
self.bind_password_encrypted = crypto_service.encrypt(password)
|
||||
|
||||
def get_bind_password(self) -> str:
|
||||
"""
|
||||
获取解密后的绑定密码
|
||||
|
||||
Returns:
|
||||
str: 解密后的明文密码
|
||||
|
||||
Raises:
|
||||
DecryptionException: 解密失败时抛出异常
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
if not self.bind_password_encrypted:
|
||||
return ""
|
||||
return crypto_service.decrypt(self.bind_password_encrypted)
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""提供商配置表"""
|
||||
|
||||
@@ -474,7 +553,6 @@ class Provider(Base):
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 限制
|
||||
rate_limit = Column(Integer, nullable=True) # 每分钟请求限制
|
||||
concurrent_limit = Column(Integer, nullable=True) # 并发请求限制
|
||||
|
||||
# 配置
|
||||
@@ -519,7 +597,7 @@ class ProviderEndpoint(Base):
|
||||
# 请求配置
|
||||
headers = Column(JSON, nullable=True) # 额外请求头
|
||||
timeout = Column(Integer, default=300) # 超时(秒)
|
||||
max_retries = Column(Integer, default=3) # 最大重试次数
|
||||
max_retries = Column(Integer, default=2) # 最大重试次数
|
||||
|
||||
# 限制
|
||||
max_concurrent = Column(
|
||||
@@ -1153,6 +1231,192 @@ class AuditEventType(PyEnum):
|
||||
DATA_EXPORT = "data_export"
|
||||
CONFIG_CHANGED = "config_changed"
|
||||
|
||||
# Management Token 相关
|
||||
MANAGEMENT_TOKEN_CREATED = "management_token_created"
|
||||
MANAGEMENT_TOKEN_UPDATED = "management_token_updated"
|
||||
MANAGEMENT_TOKEN_DELETED = "management_token_deleted"
|
||||
MANAGEMENT_TOKEN_USED = "management_token_used"
|
||||
MANAGEMENT_TOKEN_EXPIRED = "management_token_expired"
|
||||
MANAGEMENT_TOKEN_IP_BLOCKED = "management_token_ip_blocked"
|
||||
|
||||
|
||||
class ManagementToken(Base):
|
||||
"""Management Token 模型 - 用于程序化管理 API 调用"""
|
||||
|
||||
__tablename__ = "management_tokens"
|
||||
|
||||
# Token 格式常量
|
||||
TOKEN_PREFIX = "ae_"
|
||||
TOKEN_RANDOM_LENGTH = 40
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
user_id = Column(String(36), ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
|
||||
|
||||
# Token 信息
|
||||
token_hash = Column(String(64), unique=True, index=True, nullable=False) # SHA256 哈希
|
||||
token_prefix = Column(String(12), nullable=True) # Token 前缀用于显示(如 ae_xxxxxxxx)
|
||||
name = Column(String(100), nullable=False) # Token 名称
|
||||
description = Column(Text, nullable=True) # 描述
|
||||
|
||||
# IP 白名单(可选)
|
||||
allowed_ips = Column(JSON, nullable=True) # 允许的 IP 列表,NULL = 不限制
|
||||
# 格式: ["192.168.1.1", "10.0.0.0/24"]
|
||||
|
||||
# 有效期
|
||||
expires_at = Column(DateTime(timezone=True), nullable=True) # NULL = 永不过期
|
||||
|
||||
# 使用统计
|
||||
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_used_ip = Column(String(45), nullable=True)
|
||||
usage_count = Column(Integer, default=0) # 使用次数
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="management_tokens")
|
||||
|
||||
# 索引和约束
|
||||
__table_args__ = (
|
||||
Index("idx_management_tokens_user_id", "user_id"),
|
||||
Index("idx_management_tokens_is_active", "is_active"),
|
||||
UniqueConstraint("user_id", "name", name="uq_management_tokens_user_name"),
|
||||
# IP 白名单必须为 NULL(不限制)或非空数组,禁止空数组
|
||||
# 注意:JSON 类型的 NULL 可能被序列化为 JSON 'null',需要同时处理
|
||||
CheckConstraint(
|
||||
"allowed_ips IS NULL OR allowed_ips::text = 'null' OR json_array_length(allowed_ips) > 0",
|
||||
name="check_allowed_ips_not_empty",
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_token() -> str:
|
||||
"""生成 Management Token(使用加密安全的随机数)"""
|
||||
import string
|
||||
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = "".join(
|
||||
secrets.choice(alphabet) for _ in range(ManagementToken.TOKEN_RANDOM_LENGTH)
|
||||
)
|
||||
return f"{ManagementToken.TOKEN_PREFIX}{random_part}"
|
||||
|
||||
@staticmethod
|
||||
def hash_token(token: str) -> str:
|
||||
"""对 Token 进行 SHA256 哈希
|
||||
|
||||
安全性说明(当前方案是安全的):
|
||||
- Token 熵为 62^40(约 2^238),暴力破解在计算上不可行
|
||||
- 结合速率限制(默认 30 次/分钟/IP),在线攻击不可行
|
||||
- 不需要盐值:盐值用于防止彩虹表攻击,但 Token 是高熵随机值,
|
||||
不存在可预计算的"常见值",因此彩虹表攻击不适用
|
||||
"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
def set_token(self, token: str) -> None:
|
||||
"""设置 Token(只存储哈希和前缀用于显示)"""
|
||||
self.token_hash = self.hash_token(token)
|
||||
# 存储前缀用于显示(ae_ + 4 个字符,共 7 个字符)
|
||||
self.token_prefix = token[:7] if len(token) > 7 else token
|
||||
|
||||
def get_display_token(self) -> str:
|
||||
"""获取用于显示的脱敏 Token(显示前缀 + 掩码)"""
|
||||
if self.token_prefix:
|
||||
return f"{self.token_prefix}...****"
|
||||
return "ae_****"
|
||||
|
||||
def is_ip_allowed(self, client_ip: str) -> bool:
|
||||
"""检查 IP 是否在白名单中
|
||||
|
||||
安全策略:
|
||||
- None 或不设置表示不限制(允许所有 IP)
|
||||
- 非空列表表示只允许列表中的 IP
|
||||
- 无效的白名单条目会被记录并跳过
|
||||
- 无效的客户端 IP 直接拒绝
|
||||
- 支持 IPv4 映射的 IPv6 地址规范化
|
||||
"""
|
||||
if self.allowed_ips is None:
|
||||
return True # 未设置白名单,不限制
|
||||
|
||||
import ipaddress
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
# 防御性检查:空列表应该在数据库层被拒绝,但这里再检查一次
|
||||
if not self.allowed_ips:
|
||||
logger.critical(f"Management Token {self.id} - allowed_ips 为空列表(违反数据库约束)")
|
||||
return False # fail-safe
|
||||
|
||||
def normalize_ip(ip_str: str) -> ipaddress.IPv4Address | ipaddress.IPv6Address | None:
|
||||
"""规范化 IP 地址,将 IPv4 映射的 IPv6 转换为 IPv4"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
|
||||
return ip.ipv4_mapped
|
||||
return ip
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# 规范化客户端 IP
|
||||
client = normalize_ip(client_ip)
|
||||
if client is None:
|
||||
logger.error(f"Management Token {self.id} - 拒绝无效的客户端 IP: {client_ip}")
|
||||
return False
|
||||
|
||||
valid_entries = 0
|
||||
for allowed in self.allowed_ips:
|
||||
try:
|
||||
if "/" in allowed:
|
||||
# CIDR 格式
|
||||
network = ipaddress.ip_network(allowed, strict=False)
|
||||
valid_entries += 1
|
||||
if client in network:
|
||||
return True
|
||||
else:
|
||||
# 精确 IP
|
||||
allowed_ip = normalize_ip(allowed)
|
||||
if allowed_ip is None:
|
||||
logger.error(f"Management Token {self.id} - 白名单包含无效条目: {allowed}")
|
||||
continue
|
||||
valid_entries += 1
|
||||
if client == allowed_ip:
|
||||
return True
|
||||
except ValueError:
|
||||
logger.error(f"Management Token {self.id} - 白名单包含无效条目: {allowed}")
|
||||
continue
|
||||
|
||||
# 如果白名单全部无效,记录严重错误并拒绝
|
||||
if valid_entries == 0:
|
||||
logger.critical(f"Management Token {self.id} - 白名单全部无效,拒绝所有访问")
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""检查 Token 是否已过期(时区安全)"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
|
||||
expires = self.expires_at
|
||||
if expires.tzinfo is None:
|
||||
# 数据库中的时间应该有时区信息,如果没有则表示数据完整性问题
|
||||
from src.core.logger import logger
|
||||
|
||||
logger.error(f"Management Token {self.id} expires_at 缺少时区信息(数据完整性问题)")
|
||||
expires = expires.replace(tzinfo=timezone.utc)
|
||||
|
||||
return expires < datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
"""审计日志模型"""
|
||||
|
||||
@@ -24,7 +24,7 @@ class ProviderEndpointCreate(BaseModel):
|
||||
# 请求配置
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
|
||||
max_retries: int = Field(default=3, ge=0, le=10, description="最大重试次数")
|
||||
max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数")
|
||||
|
||||
# 限制
|
||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Google Gemini API 请求/响应模型
|
||||
|
||||
支持 Gemini 3 Pro 及之前版本的 API 格式
|
||||
参考文档: https://ai.google.dev/gemini-api/docs/gemini-3
|
||||
支持 Gemini API 的请求/响应格式
|
||||
作为 API 网关,采用宽松类型定义以支持 API 新特性透传
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -17,282 +17,23 @@ class BaseModelWithExtras(BaseModel):
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内容块定义
|
||||
# 内容定义 - 使用宽松类型以支持透传
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModelWithExtras):
|
||||
"""文本内容块"""
|
||||
|
||||
text: str
|
||||
thought_signature: Optional[str] = Field(
|
||||
default=None,
|
||||
alias="thoughtSignature",
|
||||
description="Gemini 3 思维签名,用于维护多轮对话中的推理上下文",
|
||||
)
|
||||
|
||||
|
||||
class GeminiInlineData(BaseModelWithExtras):
|
||||
"""内联数据(图片等)"""
|
||||
|
||||
mime_type: str = Field(alias="mimeType")
|
||||
data: str # base64 encoded
|
||||
|
||||
|
||||
class GeminiMediaResolution(BaseModelWithExtras):
|
||||
"""
|
||||
媒体分辨率配置 (Gemini 3 新增)
|
||||
|
||||
控制图片/视频的处理分辨率:
|
||||
- media_resolution_low: 图片 280 tokens, 视频 70 tokens/帧
|
||||
- media_resolution_medium: 图片 560 tokens, 视频 70 tokens/帧
|
||||
- media_resolution_high: 图片 1120 tokens, 视频 280 tokens/帧
|
||||
"""
|
||||
|
||||
level: Literal["media_resolution_low", "media_resolution_medium", "media_resolution_high"]
|
||||
|
||||
|
||||
class GeminiFileData(BaseModelWithExtras):
|
||||
"""文件引用"""
|
||||
|
||||
mime_type: Optional[str] = Field(default=None, alias="mimeType")
|
||||
file_uri: str = Field(alias="fileUri")
|
||||
|
||||
|
||||
class GeminiFunctionCall(BaseModelWithExtras):
|
||||
"""函数调用"""
|
||||
|
||||
name: str
|
||||
args: Dict[str, Any]
|
||||
|
||||
|
||||
class GeminiFunctionResponse(BaseModelWithExtras):
|
||||
"""函数响应"""
|
||||
|
||||
name: str
|
||||
response: Dict[str, Any]
|
||||
|
||||
|
||||
class GeminiPart(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 内容部分 - 支持多种类型
|
||||
|
||||
可以是以下类型之一:
|
||||
- text: 文本内容
|
||||
- inline_data: 内联数据(图片等)
|
||||
- file_data: 文件引用
|
||||
- function_call: 函数调用
|
||||
- function_response: 函数响应
|
||||
|
||||
Gemini 3 新增:
|
||||
- thought_signature: 思维签名,用于维护推理上下文
|
||||
- media_resolution: 媒体分辨率配置
|
||||
"""
|
||||
|
||||
text: Optional[str] = None
|
||||
inline_data: Optional[GeminiInlineData] = Field(default=None, alias="inlineData")
|
||||
file_data: Optional[GeminiFileData] = Field(default=None, alias="fileData")
|
||||
function_call: Optional[GeminiFunctionCall] = Field(default=None, alias="functionCall")
|
||||
function_response: Optional[GeminiFunctionResponse] = Field(
|
||||
default=None, alias="functionResponse"
|
||||
)
|
||||
# Gemini 3 新增
|
||||
thought_signature: Optional[str] = Field(
|
||||
default=None,
|
||||
alias="thoughtSignature",
|
||||
description="思维签名,用于函数调用和图片生成的上下文保持",
|
||||
)
|
||||
media_resolution: Optional[GeminiMediaResolution] = Field(
|
||||
default=None, alias="mediaResolution", description="媒体分辨率配置"
|
||||
)
|
||||
|
||||
|
||||
class GeminiContent(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 消息内容
|
||||
|
||||
对应 Gemini API 的 Content 对象
|
||||
使用宽松类型定义,parts 接受任意字典列表以支持 API 新特性
|
||||
"""
|
||||
|
||||
role: Optional[Literal["user", "model"]] = None
|
||||
parts: List[Union[GeminiPart, Dict[str, Any]]]
|
||||
role: Optional[str] = None
|
||||
parts: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 配置定义
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModelWithExtras):
|
||||
"""
|
||||
图片生成配置 (Gemini 3 Pro Image)
|
||||
|
||||
用于 gemini-3-pro-image-preview 模型
|
||||
"""
|
||||
|
||||
aspect_ratio: Optional[str] = Field(
|
||||
default=None, alias="aspectRatio", description="图片宽高比,如 '16:9', '1:1', '4:3'"
|
||||
)
|
||||
image_size: Optional[Literal["2K", "4K"]] = Field(
|
||||
default=None, alias="imageSize", description="图片尺寸: 2K 或 4K"
|
||||
)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModelWithExtras):
|
||||
"""
|
||||
生成配置
|
||||
|
||||
Gemini 3 新增:
|
||||
- thinking_level: 思考深度 (low/medium/high)
|
||||
- response_json_schema: 结构化输出的 JSON Schema
|
||||
- image_config: 图片生成配置
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = Field(
|
||||
default=None, description="采样温度,Gemini 3 建议保持默认值 1.0"
|
||||
)
|
||||
top_p: Optional[float] = Field(default=None, alias="topP")
|
||||
top_k: Optional[int] = Field(default=None, alias="topK")
|
||||
max_output_tokens: Optional[int] = Field(default=None, alias="maxOutputTokens")
|
||||
stop_sequences: Optional[List[str]] = Field(default=None, alias="stopSequences")
|
||||
candidate_count: Optional[int] = Field(default=None, alias="candidateCount")
|
||||
response_mime_type: Optional[str] = Field(default=None, alias="responseMimeType")
|
||||
response_schema: Optional[Dict[str, Any]] = Field(default=None, alias="responseSchema")
|
||||
# Gemini 3 新增
|
||||
response_json_schema: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="responseJsonSchema", description="结构化输出的 JSON Schema"
|
||||
)
|
||||
thinking_level: Optional[Literal["low", "medium", "high"]] = Field(
|
||||
default=None,
|
||||
alias="thinkingLevel",
|
||||
description="Gemini 3 思考深度: low(快速), medium(平衡), high(深度推理,默认)",
|
||||
)
|
||||
image_config: Optional[GeminiImageConfig] = Field(
|
||||
default=None, alias="imageConfig", description="图片生成配置"
|
||||
)
|
||||
|
||||
|
||||
class GeminiSafetySettings(BaseModelWithExtras):
|
||||
"""安全设置"""
|
||||
|
||||
category: str
|
||||
threshold: str
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModelWithExtras):
|
||||
"""函数声明"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GeminiGoogleSearchTool(BaseModelWithExtras):
|
||||
"""Google Search 工具 (Gemini 3)"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiUrlContextTool(BaseModelWithExtras):
|
||||
"""URL Context 工具 (Gemini 3)"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiCodeExecutionTool(BaseModelWithExtras):
|
||||
"""代码执行工具"""
|
||||
|
||||
pass # 空对象表示启用
|
||||
|
||||
|
||||
class GeminiTool(BaseModelWithExtras):
|
||||
"""
|
||||
工具定义
|
||||
|
||||
支持的工具类型:
|
||||
- function_declarations: 自定义函数
|
||||
- code_execution: 代码执行
|
||||
- google_search: Google 搜索 (Gemini 3)
|
||||
- url_context: URL 上下文 (Gemini 3)
|
||||
"""
|
||||
|
||||
function_declarations: Optional[List[GeminiFunctionDeclaration]] = Field(
|
||||
default=None, alias="functionDeclarations"
|
||||
)
|
||||
code_execution: Optional[Dict[str, Any]] = Field(default=None, alias="codeExecution")
|
||||
# Gemini 3 内置工具
|
||||
google_search: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="googleSearch", description="启用 Google 搜索工具"
|
||||
)
|
||||
url_context: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="urlContext", description="启用 URL 上下文工具"
|
||||
)
|
||||
|
||||
|
||||
class GeminiToolConfig(BaseModelWithExtras):
|
||||
"""工具配置"""
|
||||
|
||||
function_calling_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="functionCallingConfig"
|
||||
)
|
||||
|
||||
|
||||
class GeminiSystemInstruction(BaseModelWithExtras):
|
||||
"""系统指令"""
|
||||
|
||||
parts: List[Union[GeminiPart, Dict[str, Any]]]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini generateContent 请求模型
|
||||
|
||||
对应 POST /v1beta/models/{model}:generateContent 端点
|
||||
"""
|
||||
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
|
||||
|
||||
class GeminiStreamGenerateContentRequest(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini streamGenerateContent 请求模型
|
||||
|
||||
对应 POST /v1beta/models/{model}:streamGenerateContent 端点
|
||||
与 generateContent 相同,但返回流式响应
|
||||
"""
|
||||
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 统一请求模型(用于内部处理)
|
||||
# 请求模型 - 只定义网关需要的字段,其余透传
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -306,158 +47,31 @@ class GeminiRequest(BaseModelWithExtras):
|
||||
- generateContent - 非流式
|
||||
- streamGenerateContent - 流式
|
||||
请求体中不应包含 stream 字段
|
||||
|
||||
采用宽松类型定义,除必要字段外全部透传
|
||||
"""
|
||||
|
||||
model: Optional[str] = Field(default=None, description="模型名称,从 URL 路径提取(内部使用)")
|
||||
contents: List[GeminiContent]
|
||||
system_instruction: Optional[GeminiSystemInstruction] = Field(
|
||||
default=None, alias="systemInstruction"
|
||||
)
|
||||
tools: Optional[List[GeminiTool]] = None
|
||||
tool_config: Optional[GeminiToolConfig] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[GeminiSafetySettings]] = Field(
|
||||
default=None, alias="safetySettings"
|
||||
)
|
||||
generation_config: Optional[GeminiGenerationConfig] = Field(
|
||||
default=None, alias="generationConfig"
|
||||
)
|
||||
# 以下字段全部使用 Dict[str, Any] 透传,不做结构验证
|
||||
system_instruction: Optional[Dict[str, Any]] = Field(default=None, alias="systemInstruction")
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
tool_config: Optional[Dict[str, Any]] = Field(default=None, alias="toolConfig")
|
||||
safety_settings: Optional[List[Dict[str, Any]]] = Field(default=None, alias="safetySettings")
|
||||
generation_config: Optional[Dict[str, Any]] = Field(default=None, alias="generationConfig")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 响应模型
|
||||
# 响应模型 - 用于解析上游响应提取必要信息(如 usage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiUsageMetadata(BaseModelWithExtras):
|
||||
"""Token 使用量"""
|
||||
"""Token 使用量 - 用于计费统计"""
|
||||
|
||||
prompt_token_count: int = Field(default=0, alias="promptTokenCount")
|
||||
candidates_token_count: int = Field(default=0, alias="candidatesTokenCount")
|
||||
total_token_count: int = Field(default=0, alias="totalTokenCount")
|
||||
cached_content_token_count: Optional[int] = Field(default=None, alias="cachedContentTokenCount")
|
||||
|
||||
|
||||
class GeminiSafetyRating(BaseModelWithExtras):
|
||||
"""安全评级"""
|
||||
|
||||
category: str
|
||||
probability: str
|
||||
blocked: Optional[bool] = None
|
||||
|
||||
|
||||
class GeminiCitationSource(BaseModelWithExtras):
|
||||
"""引用来源"""
|
||||
|
||||
start_index: Optional[int] = Field(default=None, alias="startIndex")
|
||||
end_index: Optional[int] = Field(default=None, alias="endIndex")
|
||||
uri: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
|
||||
|
||||
class GeminiCitationMetadata(BaseModelWithExtras):
|
||||
"""引用元数据"""
|
||||
|
||||
citation_sources: Optional[List[GeminiCitationSource]] = Field(
|
||||
default=None, alias="citationSources"
|
||||
)
|
||||
|
||||
|
||||
class GeminiGroundingMetadata(BaseModelWithExtras):
|
||||
"""
|
||||
Grounding 元数据 (Gemini 3)
|
||||
|
||||
当使用 Google Search 工具时返回
|
||||
"""
|
||||
|
||||
search_entry_point: Optional[Dict[str, Any]] = Field(default=None, alias="searchEntryPoint")
|
||||
grounding_chunks: Optional[List[Dict[str, Any]]] = Field(default=None, alias="groundingChunks")
|
||||
grounding_supports: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None, alias="groundingSupports"
|
||||
)
|
||||
web_search_queries: Optional[List[str]] = Field(default=None, alias="webSearchQueries")
|
||||
|
||||
|
||||
class GeminiCandidate(BaseModelWithExtras):
|
||||
"""候选响应"""
|
||||
|
||||
content: Optional[GeminiContent] = None
|
||||
finish_reason: Optional[str] = Field(default=None, alias="finishReason")
|
||||
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
|
||||
citation_metadata: Optional[GeminiCitationMetadata] = Field(
|
||||
default=None, alias="citationMetadata"
|
||||
)
|
||||
grounding_metadata: Optional[GeminiGroundingMetadata] = Field(
|
||||
default=None, alias="groundingMetadata"
|
||||
)
|
||||
token_count: Optional[int] = Field(default=None, alias="tokenCount")
|
||||
index: Optional[int] = None
|
||||
|
||||
|
||||
class GeminiPromptFeedback(BaseModelWithExtras):
|
||||
"""提示反馈"""
|
||||
|
||||
block_reason: Optional[str] = Field(default=None, alias="blockReason")
|
||||
safety_ratings: Optional[List[GeminiSafetyRating]] = Field(default=None, alias="safetyRatings")
|
||||
|
||||
|
||||
class GeminiGenerateContentResponse(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini generateContent 响应模型
|
||||
|
||||
对应 generateContent 端点的响应体
|
||||
"""
|
||||
|
||||
candidates: Optional[List[GeminiCandidate]] = None
|
||||
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
|
||||
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
|
||||
model_version: Optional[str] = Field(default=None, alias="modelVersion")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 流式响应模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiStreamChunk(BaseModelWithExtras):
|
||||
"""
|
||||
Gemini 流式响应块
|
||||
|
||||
流式响应中的单个数据块,结构与完整响应相同
|
||||
"""
|
||||
|
||||
candidates: Optional[List[GeminiCandidate]] = None
|
||||
prompt_feedback: Optional[GeminiPromptFeedback] = Field(default=None, alias="promptFeedback")
|
||||
usage_metadata: Optional[GeminiUsageMetadata] = Field(default=None, alias="usageMetadata")
|
||||
model_version: Optional[str] = Field(default=None, alias="modelVersion")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 错误响应
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GeminiErrorDetail(BaseModelWithExtras):
|
||||
"""错误详情"""
|
||||
|
||||
type: Optional[str] = Field(default=None, alias="@type")
|
||||
reason: Optional[str] = None
|
||||
domain: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GeminiError(BaseModelWithExtras):
|
||||
"""错误信息"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
status: str
|
||||
details: Optional[List[GeminiErrorDetail]] = None
|
||||
|
||||
|
||||
class GeminiErrorResponse(BaseModelWithExtras):
|
||||
"""错误响应"""
|
||||
|
||||
error: GeminiError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
363
src/services/auth/ldap.py
Normal file
363
src/services/auth/ldap.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
# LDAP 连接默认超时时间(秒)
|
||||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||||
|
||||
|
||||
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
|
||||
"""
|
||||
解析 LDAP 服务器地址,支持:
|
||||
- ldap://host:389
|
||||
- ldaps://host:636
|
||||
- host:389(无 scheme 时默认 ldap)
|
||||
|
||||
Returns:
|
||||
(host, port, use_ssl)
|
||||
"""
|
||||
raw = (server_url or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("LDAP server_url is required")
|
||||
|
||||
parsed = urlparse(raw)
|
||||
if parsed.scheme in {"ldap", "ldaps"}:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
use_ssl = parsed.scheme == "ldaps"
|
||||
port = parsed.port or (636 if use_ssl else 389)
|
||||
return host, port, use_ssl
|
||||
|
||||
# 兼容无 scheme:按 ldap:// 解析
|
||||
parsed = urlparse(f"ldap://{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
port = parsed.port or 389
|
||||
return host, port, False
|
||||
|
||||
|
||||
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
|
||||
"""
|
||||
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
|
||||
|
||||
Args:
|
||||
value: 需要转义的字符串
|
||||
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
|
||||
|
||||
Returns:
|
||||
转义后的安全字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 输入值过长
|
||||
"""
|
||||
import unicodedata
|
||||
|
||||
# 先检查原始长度,防止 DoS 攻击
|
||||
# 128 字符足够覆盖大多数企业用户名和邮箱地址
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
|
||||
|
||||
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
|
||||
# 再次检查规范化后的长度(防止规范化后长度突增)
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
|
||||
|
||||
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
|
||||
# 使用显式顺序处理,确保反斜杠首先转义
|
||||
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
|
||||
value = value.replace("*", r"\2a")
|
||||
value = value.replace("(", r"\28")
|
||||
value = value.replace(")", r"\29")
|
||||
value = value.replace("\x00", r"\00") # NUL
|
||||
value = value.replace("&", r"\26")
|
||||
value = value.replace("|", r"\7c")
|
||||
value = value.replace("=", r"\3d")
|
||||
value = value.replace(">", r"\3e")
|
||||
value = value.replace("<", r"\3c")
|
||||
value = value.replace("~", r"\7e")
|
||||
value = value.replace("!", r"\21")
|
||||
return value
|
||||
|
||||
|
||||
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
|
||||
"""
|
||||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||||
"""
|
||||
attr = getattr(entry, attr_name, None)
|
||||
if not attr:
|
||||
return default
|
||||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||||
val = getattr(attr, "value", None)
|
||||
if isinstance(val, list):
|
||||
val = val[0] if val else default
|
||||
if val is None:
|
||||
return default
|
||||
return str(val)
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||||
"""获取 LDAP 配置"""
|
||||
return db.query(LDAPConfig).first()
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_enabled(db: Session) -> bool:
|
||||
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_exclusive(db: Session) -> bool:
|
||||
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_exclusive is not True:
|
||||
return False
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||||
"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_enabled is not True:
|
||||
return None
|
||||
|
||||
try:
|
||||
bind_password = config.get_bind_password()
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||||
return None
|
||||
|
||||
# 绑定密码为空时无法进行 LDAP 认证
|
||||
if not bind_password:
|
||||
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
|
||||
return None
|
||||
|
||||
return {
|
||||
"server_url": config.server_url,
|
||||
"bind_dn": config.bind_dn,
|
||||
"bind_password": bind_password,
|
||||
"base_dn": config.base_dn,
|
||||
"user_search_filter": config.user_search_filter,
|
||||
"username_attr": config.username_attr,
|
||||
"email_attr": config.email_attr,
|
||||
"display_name_attr": config.display_name_attr,
|
||||
"use_starttls": config.use_starttls,
|
||||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
config: 已解密的 LDAP 配置
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户属性 dict {username, email, display_name} 或 None
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection, SUBTREE
|
||||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||||
except ImportError:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
if not config:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
admin_conn = None
|
||||
user_conn = None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config["bind_password"]
|
||||
admin_conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||||
return None
|
||||
|
||||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||||
safe_username = escape_ldap_filter(username)
|
||||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||||
admin_conn.search(
|
||||
search_base=config["base_dn"],
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
size_limit=2, # 防止过滤器误配导致匹配多用户
|
||||
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
|
||||
attributes=[
|
||||
config["username_attr"],
|
||||
config["email_attr"],
|
||||
config["display_name_attr"],
|
||||
],
|
||||
)
|
||||
|
||||
if len(admin_conn.entries) != 1:
|
||||
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
|
||||
logger.warning(
|
||||
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
|
||||
)
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(
|
||||
server,
|
||||
user=user_dn,
|
||||
password=password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
|
||||
bind_result = user_conn.result.get("description", "unknown")
|
||||
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
|
||||
return None
|
||||
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
|
||||
email = _get_attr_value(
|
||||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||||
)
|
||||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
"username": ldap_username,
|
||||
"ldap_username": ldap_username,
|
||||
"ldap_dn": user_dn,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
}
|
||||
|
||||
except LDAPSocketOpenError as e:
|
||||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||||
return None
|
||||
except LDAPBindError as e:
|
||||
logger.error(f"LDAP 绑定失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
finally:
|
||||
# 确保连接关闭,避免失败路径泄漏
|
||||
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
|
||||
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
Returns:
|
||||
(success, message)
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
conn = None
|
||||
try:
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
bind_password = config["bind_password"]
|
||||
conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return False, f"绑定失败: {conn.result}"
|
||||
|
||||
return True, "连接成功"
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
|
||||
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
|
||||
return False, "连接失败,请检查服务器地址、端口和凭据"
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP 测试连接关闭失败: {e}")
|
||||
|
||||
# 兼容旧接口:如果其他代码直接调用
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在或未启用"
|
||||
return LDAPService.test_connection_with_config(config)
|
||||
@@ -2,21 +2,30 @@
|
||||
认证服务
|
||||
"""
|
||||
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config import config
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.core.enums import AuthSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import ManagementToken
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
@@ -92,15 +101,86 @@ class AuthService:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
async def authenticate_user(
|
||||
db: Session, email: str, password: str, auth_type: str = "local"
|
||||
) -> Optional[User]:
|
||||
"""用户登录认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱/用户名
|
||||
password: 密码
|
||||
auth_type: 认证类型 ("local" 或 "ldap")
|
||||
"""
|
||||
if auth_type == "ldap":
|
||||
# LDAP 认证
|
||||
# 预取配置,避免将 Session 传递到线程池
|
||||
config_data = LDAPService.get_config_data(db)
|
||||
if not config_data:
|
||||
logger.warning("登录失败 - LDAP 未启用或配置无效")
|
||||
return None
|
||||
|
||||
# 计算总体超时:LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
|
||||
# 超时策略:
|
||||
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
|
||||
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
|
||||
# - 公式:单次超时 × 4(覆盖 4 次主要网络操作)+ 10% 缓冲
|
||||
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
|
||||
single_timeout = config_data.get("connect_timeout", 10)
|
||||
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
|
||||
|
||||
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
|
||||
# 添加总体超时保护,防止异常场景下请求堆积
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
ldap_user = await asyncio.wait_for(
|
||||
run_in_threadpool(
|
||||
LDAPService.authenticate_with_config, config_data, email, password
|
||||
),
|
||||
timeout=total_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
|
||||
return None
|
||||
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
# 获取或创建本地用户
|
||||
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
|
||||
if not user:
|
||||
# 已有本地账号但来源不匹配等情况
|
||||
return None
|
||||
if not user.is_active:
|
||||
logger.warning(f"登录失败 - 用户已禁用: {email}")
|
||||
return None
|
||||
return user
|
||||
|
||||
# 本地认证
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
# 支持邮箱或用户名登录
|
||||
from sqlalchemy import or_
|
||||
user = db.query(User).filter(
|
||||
or_(User.email == email, User.username == email)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
return None
|
||||
|
||||
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
|
||||
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
|
||||
return None
|
||||
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
|
||||
|
||||
# 检查用户认证来源
|
||||
if user.auth_source == AuthSource.LDAP:
|
||||
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
|
||||
return None
|
||||
|
||||
if not user.verify_password(password):
|
||||
logger.warning(f"登录失败 - 密码错误: {email}")
|
||||
return None
|
||||
@@ -118,6 +198,127 @@ class AuthService:
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
|
||||
"""获取或创建 LDAP 用户
|
||||
|
||||
Args:
|
||||
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
|
||||
|
||||
注意:使用 with_for_update() 防止并发首次登录创建重复用户
|
||||
"""
|
||||
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
|
||||
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
|
||||
email = ldap_user["email"]
|
||||
|
||||
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
|
||||
# 使用 with_for_update() 锁定行,防止并发创建
|
||||
user: Optional[User] = None
|
||||
if ldap_dn:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user and ldap_username:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
|
||||
user = db.query(User).filter(User.email == email).with_for_update().first()
|
||||
|
||||
if user:
|
||||
if user.auth_source != AuthSource.LDAP:
|
||||
# 避免覆盖已有本地账户(不同来源时拒绝登录)
|
||||
logger.warning(
|
||||
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 同步邮箱(LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
|
||||
if user.email != email:
|
||||
email_taken = (
|
||||
db.query(User)
|
||||
.filter(User.email == email, User.id != user.id)
|
||||
.first()
|
||||
)
|
||||
if email_taken:
|
||||
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
|
||||
return None
|
||||
user.email = email
|
||||
|
||||
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
|
||||
if ldap_dn and user.ldap_dn != ldap_dn:
|
||||
user.ldap_dn = ldap_dn
|
||||
if ldap_username and user.ldap_username != ldap_username:
|
||||
user.ldap_username = ldap_username
|
||||
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
|
||||
base_username = ldap_username or ldap_user["username"]
|
||||
username = base_username
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# 检查用户名是否已存在
|
||||
existing_user_with_username = db.query(User).filter(User.username == username).first()
|
||||
if existing_user_with_username:
|
||||
# 如果 username 已存在,使用时间戳+随机数确保唯一性
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
|
||||
|
||||
# 创建新用户
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash="", # LDAP 用户无本地密码
|
||||
auth_source=AuthSource.LDAP,
|
||||
ldap_dn=ldap_dn,
|
||||
ldap_username=ldap_username,
|
||||
role=UserRole.USER,
|
||||
is_active=True,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_str = str(e.orig).lower() if e.orig else str(e).lower()
|
||||
|
||||
# 解析具体冲突类型
|
||||
if "email" in error_str or "ix_users_email" in error_str:
|
||||
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
|
||||
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
|
||||
return None
|
||||
elif "username" in error_str or "ix_users_username" in error_str:
|
||||
# 用户名冲突,重试时会生成新用户名
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
|
||||
return None
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
|
||||
else:
|
||||
# 其他约束冲突,不重试
|
||||
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
||||
"""API密钥认证"""
|
||||
@@ -282,3 +483,137 @@ class AuthService:
|
||||
except Exception as e:
|
||||
logger.error(f"撤销 Token 失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_management_token(
|
||||
db: Session, raw_token: str, client_ip: str
|
||||
) -> Optional[tuple[User, "ManagementToken"]]:
|
||||
"""Management Token 认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
raw_token: Management Token 字符串
|
||||
client_ip: 客户端 IP
|
||||
|
||||
Returns:
|
||||
(User, ManagementToken) 元组,认证失败返回 None
|
||||
|
||||
Raises:
|
||||
RateLimitException: 超过速率限制时抛出(用于返回 429)
|
||||
"""
|
||||
from src.core.exceptions import RateLimitException
|
||||
from src.models.database import AuditEventType, ManagementToken
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
|
||||
# 速率限制检查(防止暴力破解)
|
||||
allowed, remaining, ttl = await IPRateLimiter.check_limit(
|
||||
client_ip,
|
||||
endpoint_type="management_token",
|
||||
limit=config.management_token_rate_limit,
|
||||
)
|
||||
if not allowed:
|
||||
logger.warning(f"Management Token 认证 - IP {client_ip} 超过速率限制")
|
||||
raise RateLimitException(limit=config.management_token_rate_limit, window="分钟")
|
||||
|
||||
# 检查 Token 格式
|
||||
if not raw_token.startswith(ManagementToken.TOKEN_PREFIX):
|
||||
logger.warning("Management Token 认证失败 - 格式错误")
|
||||
return None
|
||||
|
||||
# 哈希查找
|
||||
token_hash = ManagementToken.hash_token(raw_token)
|
||||
token_record = (
|
||||
db.query(ManagementToken)
|
||||
.options(joinedload(ManagementToken.user))
|
||||
.filter(ManagementToken.token_hash == token_hash)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not token_record:
|
||||
logger.warning("Management Token 认证失败 - Token 不存在")
|
||||
return None
|
||||
|
||||
# 注意:数据库查询已通过 token_hash 索引匹配,此处不再需要额外的常量时间比较
|
||||
# Token 的 62^40 熵(约 238 位)加上速率限制已足够防止暴力破解
|
||||
|
||||
# 检查状态
|
||||
if not token_record.is_active:
|
||||
logger.warning(f"Management Token 认证失败 - Token 已禁用: {token_record.id}")
|
||||
return None
|
||||
|
||||
# 检查过期(使用属性方法,确保时区安全)
|
||||
if token_record.is_expired:
|
||||
logger.warning(f"Management Token 认证失败 - Token 已过期: {token_record.id}")
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.MANAGEMENT_TOKEN_EXPIRED,
|
||||
description=f"Management Token 已过期: {token_record.name}",
|
||||
user_id=token_record.user_id,
|
||||
ip_address=client_ip,
|
||||
metadata={
|
||||
"token_id": token_record.id,
|
||||
"token_name": token_record.name,
|
||||
"expired_at": (
|
||||
token_record.expires_at.isoformat() if token_record.expires_at else None
|
||||
),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# 检查 IP 白名单
|
||||
if not token_record.is_ip_allowed(client_ip):
|
||||
logger.warning(
|
||||
f"Management Token IP 限制 - Token: {token_record.id}, IP: {client_ip}"
|
||||
)
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.MANAGEMENT_TOKEN_IP_BLOCKED,
|
||||
description=f"Management Token IP 被拒绝: {token_record.name}",
|
||||
user_id=token_record.user_id,
|
||||
ip_address=client_ip,
|
||||
metadata={
|
||||
"token_id": token_record.id,
|
||||
"token_name": token_record.name,
|
||||
"blocked_ip": client_ip,
|
||||
# 不记录 allowed_ips 以防信息泄露
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# 获取用户
|
||||
user = token_record.user
|
||||
if not user or not user.is_active:
|
||||
logger.warning("Management Token 认证失败 - 用户不存在或已禁用")
|
||||
return None
|
||||
|
||||
# 使用 SQL 原子操作更新使用统计
|
||||
from sqlalchemy import func
|
||||
|
||||
db.query(ManagementToken).filter(ManagementToken.id == token_record.id).update(
|
||||
{
|
||||
ManagementToken.last_used_at: func.now(), # 使用数据库时间确保一致性
|
||||
ManagementToken.last_used_ip: client_ip,
|
||||
ManagementToken.usage_count: ManagementToken.usage_count + 1,
|
||||
ManagementToken.updated_at: func.now(), # 显式更新,因为原子 SQL 绕过 ORM
|
||||
},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
# 记录 Token 使用审计日志
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.MANAGEMENT_TOKEN_USED,
|
||||
description=f"Management Token 认证成功: {token_record.name}",
|
||||
user_id=user.id,
|
||||
ip_address=client_ip,
|
||||
metadata={
|
||||
"token_id": token_record.id,
|
||||
"token_name": token_record.name,
|
||||
},
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
logger.debug(f"Management Token 认证成功: user={user.email}, token={token_record.id}")
|
||||
return user, token_record
|
||||
|
||||
51
src/services/billing/__init__.py
Normal file
51
src/services/billing/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
计费模块
|
||||
|
||||
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
|
||||
- Claude: input + output + cache_creation + cache_read
|
||||
- OpenAI: input + output + cache_read (无缓存创建费用)
|
||||
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
|
||||
- 按次计费: per_request
|
||||
|
||||
使用方式:
|
||||
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
|
||||
|
||||
# 1. 将原始 usage 映射为标准格式
|
||||
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
|
||||
|
||||
# 2. 使用计费计算器计算费用
|
||||
calculator = BillingCalculator(template="openai")
|
||||
result = calculator.calculate(usage, prices)
|
||||
|
||||
# 3. 获取费用明细
|
||||
print(result.total_cost)
|
||||
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
|
||||
"""
|
||||
|
||||
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
|
||||
from src.services.billing.models import (
|
||||
BillingDimension,
|
||||
BillingUnit,
|
||||
CostBreakdown,
|
||||
StandardizedUsage,
|
||||
)
|
||||
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
|
||||
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
|
||||
|
||||
__all__ = [
|
||||
# 数据模型
|
||||
"BillingDimension",
|
||||
"BillingUnit",
|
||||
"CostBreakdown",
|
||||
"StandardizedUsage",
|
||||
# 模板
|
||||
"BillingTemplates",
|
||||
"BILLING_TEMPLATE_REGISTRY",
|
||||
# 计算器
|
||||
"BillingCalculator",
|
||||
"calculate_request_cost",
|
||||
# 映射器
|
||||
"UsageMapper",
|
||||
"map_usage",
|
||||
"map_usage_from_response",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user