35 Commits

Author SHA1 Message Date
fawney19
2395093394 refactor: simplify client IP resolution and make the request-body timeout configurable
- Remove the TRUSTED_PROXY_COUNT setting in favor of the X-Real-IP header
- Add a REQUEST_BODY_TIMEOUT environment variable (default: 60 seconds)
- Unify the get_client_ip logic with priority: X-Real-IP > X-Forwarded-For > direct connection IP
2026-01-06 16:29:03 +08:00
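The resolution order above can be sketched as follows. The header priority comes from the commit message; the function signature and plain-dict header shape are illustrative assumptions, not the repository's actual code.

```python
def get_client_ip(headers: dict[str, str], peer_ip: str) -> str:
    """Resolve the client IP: X-Real-IP > X-Forwarded-For > direct connection."""
    real_ip = headers.get("x-real-ip", "").strip()
    if real_ip:
        return real_ip
    forwarded = headers.get("x-forwarded-for", "")
    if forwarded:
        # X-Forwarded-For holds a comma-separated proxy chain; the first
        # entry is the original client.
        return forwarded.split(",")[0].strip()
    return peer_ip
```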
fawney19
28209e1c2a Merge pull request #72 from fawney19/test-ldap-pr
feat: add LDAP authentication support
2026-01-06 14:45:31 +08:00
fawney19
00562dd1d4 feat: add LDAP authentication support
- Add an LDAP service and API endpoints
- Add an LDAP configuration management page
- Support switching between LDAP and local authentication on the login page
- Add a database migration for the LDAP-related fields
2026-01-06 14:38:42 +08:00
fawney19
0f78d5cbf3 fix: enrich CLI handler error messages with upstream response details 2026-01-05 19:44:38 +08:00
fawney19
431c6de8d2 feat: support pagination, search, and API key details on the user usage page
- Add a search parameter to the user usage API for key-name and model-name search
- Return api_key info (id, name, display) from the user usage API
- Add an API key column to the records table on the user page
- Unify the pagination/search logic between the admin and user pages on the frontend
- Escape special characters in backend LIKE queries to prevent SQL injection
- Add the escape_like_pattern and safe_truncate_escaped utility functions
2026-01-05 19:35:14 +08:00
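The LIKE-escaping helper named above might look roughly like this; the function name comes from the commit, while the body is a hedged sketch.

```python
def escape_like_pattern(value: str, escape_char: str = "\\") -> str:
    """Escape LIKE wildcards so user input is matched literally.

    Pair this with `... LIKE :pattern ESCAPE '\\'` in the SQL statement.
    """
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )
```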
fawney19
142e15bbcc Merge pull request #69 from AoaoMH/feature/Record-optimization
feat: add usage statistics and records feature with new API routes, f…
2026-01-05 19:31:59 +08:00
AAEE86
31acc5c607 feat(models): sort models by release date within each provider
Models are now sorted by release date in descending order (newest first)
within each provider group. Models without release dates are placed at the
end. When release dates are identical or missing, models fall back to
alphabetical sorting by name.
2026-01-05 18:23:04 +08:00
fawney19
bfa0a26d41 feat: support standalone-balance keys in user export, add a system version endpoint
- User export/import now supports standalone-balance keys (standalone_keys)
- Add an expires_at field to the API key export
- Add a /api/admin/system/version endpoint that returns version info
- Show the current version on the frontend system settings page
- Remove the redundant bg-muted background style from the import dialog
2026-01-05 18:18:45 +08:00
AoaoMH
93ab9b6a5e feat: add usage statistics and records feature with new API routes, frontend types, services, and UI components 2026-01-05 17:03:05 +08:00
fawney19
35e29d46bd refactor: extract a unified billing module with config-driven, multi-vendor billing
- Add a src/services/billing/ module containing the billing calculator, templates, and usage mapping
- Refactor the billing logic in ChatAdapterBase and CliAdapterBase to call the billing module
- Add a BILLING_TEMPLATE class attribute to each adapter to select its billing template
- Support Claude/OpenAI/Gemini billing templates, tiered pricing, and cache TTL pricing
- Add unit tests under tests/services/billing/
2026-01-05 16:48:59 +08:00
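Config-driven billing along these lines could look like the sketch below; the template fields and prices are illustrative placeholders, not the repository's actual schema or rates.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BillingTemplate:
    # USD per 1M tokens; the values below are placeholders, not real prices.
    input_price_per_1m: float
    output_price_per_1m: float


TEMPLATES = {
    "openai": BillingTemplate(input_price_per_1m=2.5, output_price_per_1m=10.0),
}


def calculate_cost(template_name: str, input_tokens: int, output_tokens: int) -> float:
    """Compute the request cost from the adapter's declared billing template."""
    t = TEMPLATES[template_name]
    return (
        input_tokens / 1_000_000 * t.input_price_per_1m
        + output_tokens / 1_000_000 * t.output_price_per_1m
    )
```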
fawney19
465da6f818 feat: extract usage info in the OpenAI streaming response parser
Some OpenAI-compatible APIs (such as Doubao) send usage info in the final
chunk; prompt_tokens and completion_tokens are now extracted correctly.
2026-01-05 12:50:05 +08:00
fawney19
e5f12fddd9 feat: harden stream prefetching and refine the adaptive concurrency algorithm
Stream prefetch hardening:
- Add a prefetch byte cap (64KB) to prevent memory growth on responses without newlines
- After prefetching, detect non-SSE error responses (HTML pages, plain JSON errors)
- Extract check_html_response and check_prefetched_response_error into utils.py

Adaptive concurrency (boundary memory + gradual probing):
- Scale-down: replace multiplicative decrease with "boundary - 1", so a single 429 converges near the real limit
- Scale-up: normal growth never exceeds the known boundary; probing growth may cautiously exceed it (+1 each time)
- Record the boundary only on concurrency-limit 429s, so RPM/UNKNOWN types do not overwrite it
2026-01-05 12:17:45 +08:00
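The boundary-memory policy can be illustrated with a small sketch; the class and method names here are hypothetical, and only the back-off and probing rules follow the commit description.

```python
class AdaptiveConcurrency:
    """Sketch of boundary memory + gradual probing for a concurrency limit."""

    def __init__(self, limit: int = 4):
        self.limit = limit
        self.boundary: int | None = None  # last concurrency level that hit a 429

    def on_rate_limited(self, is_concurrency_429: bool) -> None:
        if is_concurrency_429:
            # Only concurrency-limit 429s record the boundary, so RPM/UNKNOWN
            # rate limits cannot overwrite it.
            self.boundary = self.limit
        self.limit = max(1, self.limit - 1)  # converge to boundary - 1

    def on_success(self, probe: bool = False) -> None:
        if self.boundary is None or probe:
            self.limit += 1  # probing may cautiously exceed the boundary (+1)
        elif self.limit < self.boundary - 1:
            self.limit += 1  # normal growth stays below the known boundary
```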
fawney19
4fa9a1303a feat: improve the timing of TTFB and streaming-status recording
Streaming status update mechanism:
- Record TTFB and update the streaming status uniformly on first output
- Refactor the status-update logic in CliMessageHandlerBase to remove duplication
- Ensure provider/key info is already available when the streaming status is updated

Frontend improvements:
- Add support for the first_byte_time_ms field
- Admin endpoints can return the provider/api_key_name fields
- Refine the active-request polling logic to judge more accurately when a full refresh is needed

Database and API:
- Add an include_admin_fields parameter to UsageService.get_active_requests_status
- Enable that parameter in admin endpoint calls to fetch the extra fields
2026-01-05 10:31:34 +08:00
fawney19
43f349d415 fix: ensure provider info is set before the CLI handler updates streaming status
After execute_with_fallback returns, explicitly set the provider info on ctx,
matching the behavior of chat_handler_base.py, so the provider is not empty
when the streaming status is updated.
2026-01-05 09:36:35 +08:00
fawney19
02069954de fix: pass first_byte_time_ms when updating the streaming status 2026-01-05 09:29:38 +08:00
fawney19
2e15875fed feat: support a custom_path field in the endpoint API
- Add a custom_path parameter to ProviderEndpointCreate
- Add a custom_path parameter to ProviderEndpointUpdate
- Return the custom_path field from ProviderEndpointResponse
- Pass custom_path through to the database model when creating an endpoint
2026-01-05 09:22:20 +08:00
fawney19
b34cfb676d fix: pass provider-related IDs when updating the streaming status
Add provider_id, provider_endpoint_id, and provider_api_key_id parameters
to the update_usage_status method so these fields are recorded correctly
when a streaming request enters the streaming state.
2026-01-05 09:12:03 +08:00
fawney19
3064497636 refactor: improve extraction and propagation of upstream error messages
- Add an extract_error_message utility to unify error-message extraction
- Attach an upstream_response attribute to HTTPStatusError exceptions to preserve the original error
- Prefer the upstream response body as the error message over the exception's string representation
- Remove the error-message length limits (500/1000 characters)
- Fix a boundary check by matching the "Unable to read" prefix with startswith
- Simplify the conditional logic in result.py
2026-01-05 03:18:55 +08:00
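A minimal sketch of the extraction preference described above; extract_error_message and upstream_response are named in the commit, while the body is an assumption.

```python
def extract_error_message(exc: Exception) -> str:
    """Prefer the preserved upstream response body over str(exc)."""
    upstream = getattr(exc, "upstream_response", None)
    if upstream:
        return str(upstream)
    return str(exc)
```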
fawney19
dec681fea0 fix: unify timezone handling so all datetimes are timezone-aware
- token_bucket.py: use timezone.utc in get_reset_time and the Redis backend
- sliding_window.py: use timezone.utc in get_reset_time and the retry_after calculation
- provider_strategy.py: ensure timezone info is present after dateutil.parser parsing
2026-01-05 02:23:24 +08:00
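The shared pattern behind these fixes is normalizing naive datetimes to timezone-aware UTC, for example:

```python
from datetime import datetime, timezone


def ensure_utc(dt: datetime) -> datetime:
    """Attach UTC to naive datetimes; leave aware ones untouched."""
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt
```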
fawney19
523e27ba9a fix: use the application timezone instead of UTC for API key expiry
- Backend: parse_expiry_date now uses APP_TIMEZONE (default Asia/Shanghai)
- Frontend: remove "UTC" from the hint text
2026-01-05 02:18:16 +08:00
fawney19
e7db76e581 refactor: switch API key expiry to a date picker; rate_limit supports "unlimited"
- Frontend: replace the "days" input with a date picker for a more intuitive expiry setting
- Backend: add an expires_at field (ISO date format), staying compatible with the legacy expire_days
- The rate_limit field now accepts null to mean unlimited; the default of 100 is removed
- Parsing: the expiry is set to 23:59:59.999999 UTC on the selected day
2026-01-05 02:16:16 +08:00
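The end-of-day rule could be implemented along these lines. This sketch hard-codes UTC as this commit describes (a later commit in this log switches to APP_TIMEZONE) and ignores the legacy expire_days path.

```python
from datetime import date, datetime, time, timezone


def parse_expiry_date(expires_at: str) -> datetime:
    """Parse an ISO date such as "2025-12-31" into an end-of-day UTC datetime."""
    day = date.fromisoformat(expires_at)
    return datetime.combine(day, time(23, 59, 59, 999999), tzinfo=timezone.utc)
```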
fawney19
689339117a refactor: extract a ModelMultiSelect component with invalid-model detection
- Add a ModelMultiSelect component that can display and remove models that no longer exist
- Add a useInvalidModels composable to detect invalid references in allowed_models
- Refactor StandaloneKeyFormDialog and UserFormDialog to use the new component
- Add a design note explaining the GlobalModel deletion logic
2026-01-05 01:20:58 +08:00
fawney19
b202765be4 perf: improve streaming TTFB by moving the database status update after the first yield
- StreamUsageTracker: yield the first chunk before updating the streaming status
- EnhancedStreamUsageTracker: add the same TTFB recording and status-update logic
- Ensure the client's first byte is not delayed by database operations
2026-01-05 00:13:23 +08:00
fawney19
3bbf3073df feat: pass through upstream error details when all providers fail
- FallbackOrchestrator keeps the last error after all candidate combinations fail
- Extract the upstream status code and response body from httpx.HTTPStatusError
- ProviderNotAvailableException carries the upstream error details
- ErrorResponse passes the upstream status code and response through in error replies
2026-01-04 23:50:15 +08:00
fawney19
f46aaa2182 debug: add logging for an empty provider during streaming status updates
- base_handler: detect and log an empty provider when updating the streaming status
- cli_handler_base: fix the streaming status not being updated when prefetched data is empty
- usage service: flag the anomaly where the status becomes streaming while the provider is still pending
2026-01-04 23:16:01 +08:00
fawney19
a2f33a6c35 perf: split the heatmap into a dedicated endpoint with Redis caching
- Add dedicated heatmap API endpoints (/api/admin/usage/heatmap, /api/users/me/usage/heatmap)
- Add a Redis cache layer (5-minute TTL) to reduce database queries
- Invalidate the heatmap cache when a user's role changes
- Load statistics and the heatmap in parallel on the frontend, with loading/error states
- Fix missing JSON parse error handling in cache_decorator
- Update the docker-compose startup command hint
2026-01-04 22:42:58 +08:00
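The cache layer can be pictured as a TTL decorator. This in-memory stand-in only illustrates the idea (the repository uses Redis); it includes the JSON-parse guard the commit mentions.

```python
import json
import time
from functools import wraps

_cache: dict[str, tuple[float, str]] = {}  # stand-in for Redis


def cached(ttl_seconds: int):
    """Tiny TTL cache decorator; a sketch of the idea, not the real code."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args):
            key = f"{fn.__name__}:{args}"
            hit = _cache.get(key)
            if hit and time.monotonic() - hit[0] < ttl_seconds:
                try:
                    return json.loads(hit[1])  # guard against corrupt entries
                except json.JSONDecodeError:
                    pass
            result = fn(*args)
            _cache[key] = (time.monotonic(), json.dumps(result))
            return result
        return wrapper
    return decorator
```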
fawney19
b6bd6357ed perf: fix the N+1 query problem when listing GlobalModel records 2026-01-04 20:05:23 +08:00
fawney19
c3a5878b1b feat: improve usage-query pagination and heatmap performance
- Add limit/offset pagination parameters to the usage query endpoint
- Read heatmap statistics from the precomputed StatsDaily/StatsUserDaily tables instead of querying the Usage table live
- Fix avg_response_time_ms being incorrectly skipped when it is 0
2026-01-04 18:02:47 +08:00
RWDai
3e4309eba3 Enhance LDAP auth config handling 2026-01-04 16:27:02 +08:00
RWDai
414f45aa71 Fix LDAP authentication stability 2026-01-04 13:09:55 +08:00
RWDai
ebdc76346f revert: roll back the version bump in _version.py 2026-01-04 11:25:58 +08:00
RWDai
64bfa955f4 feat(ldap): harden the LDAP authentication feature
- Add LDAP config type definitions and remove any types
- Require a bind password when LDAP is configured for the first time
- Validate the login identifier by auth type (local requires an email; LDAP allows a username)
- Add an LDAP filter escaping function to prevent injection attacks
- Add an LDAP connection timeout setting
- Check for account-source conflicts so LDAP cannot take over a local account
- Automatically rename on username conflicts
2026-01-04 11:18:28 +08:00
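The filter-escaping function presumably implements RFC 4515-style escaping, roughly as below; the exact name and coverage in the repository are assumptions.

```python
def escape_ldap_filter(value: str) -> str:
    """Escape special characters for LDAP search filters (RFC 4515)."""
    replacements = {
        "\\": r"\5c",  # must come first so later escapes are not re-escaped
        "*": r"\2a",
        "(": r"\28",
        ")": r"\29",
        "\x00": r"\00",
    }
    for char, escaped in replacements.items():
        value = value.replace(char, escaped)
    return value
```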
RWDai
612992fa1f Merge remote-tracking branch 'origin/master' into feature/ldap-authentication 2026-01-04 10:48:36 +08:00
fawney19
c02ac56da8 chore: update docker-compose commands to docker compose
Standardize on the modern Docker Compose V2 invocation
2026-01-03 01:39:45 +08:00
RWDai
9bfb295238 feat: add ldap login 2026-01-02 16:17:24 +08:00
92 changed files with 7598 additions and 2960 deletions


@@ -51,20 +51,20 @@ Aether is a self-hosted AI API gateway that provides teams and individuals with multi-tenant
```bash
# 1. Clone the repository
git clone https://github.com/fawney19/Aether.git
cd aether
cd Aether
# 2. Configure environment variables
cp .env.example .env
python generate_keys.py # Generate secrets and copy them into .env
# 3. Deploy
docker-compose up -d
docker compose up -d
# 4. On first deployment, initialize the database
./migrate.sh
# 5. Update
docker-compose pull && docker-compose up -d && ./migrate.sh
docker compose pull && docker compose up -d && ./migrate.sh
```
### Docker Compose (build the image locally)
@@ -72,7 +72,7 @@ docker-compose pull && docker-compose up -d && ./migrate.sh
```bash
# 1. Clone the repository
git clone https://github.com/fawney19/Aether.git
cd aether
cd Aether
# 2. Configure environment variables
cp .env.example .env
@@ -86,7 +86,7 @@ python generate_keys.py # Generate secrets and copy them into .env
```bash
# Start the dependencies
docker-compose -f docker-compose.build.yml up -d postgres redis
docker compose -f docker-compose.build.yml up -d postgres redis
# Backend
uv sync


@@ -30,7 +30,7 @@ from src.models.database import Base
config = context.config
# Get the database URL from environment variables
# Prefer DATABASE_URL; otherwise build it from DB_PASSWORD (consistent with docker-compose)
# Prefer DATABASE_URL; otherwise build it from DB_PASSWORD (consistent with docker compose)
database_url = os.getenv("DATABASE_URL")
if not database_url:
db_password = os.getenv("DB_PASSWORD", "")


@@ -0,0 +1,161 @@
"""add ldap authentication support
Revision ID: c3d4e5f6g7h8
Revises: b2c3d4e5f6g7
Create Date: 2026-01-01 14:00:00.000000+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = 'c3d4e5f6g7h8'
down_revision = 'b2c3d4e5f6g7'
branch_labels = None
depends_on = None
def _type_exists(conn, type_name: str) -> bool:
"""检查 PostgreSQL 类型是否存在"""
result = conn.execute(
text("SELECT 1 FROM pg_type WHERE typname = :name"),
{"name": type_name}
)
return result.scalar() is not None
def _column_exists(conn, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = conn.execute(
text("""
SELECT 1 FROM information_schema.columns
WHERE table_name = :table AND column_name = :column
"""),
{"table": table_name, "column": column_name}
)
return result.scalar() is not None
def _index_exists(conn, index_name: str) -> bool:
"""检查索引是否存在"""
result = conn.execute(
text("SELECT 1 FROM pg_indexes WHERE indexname = :name"),
{"name": index_name}
)
return result.scalar() is not None
def _table_exists(conn, table_name: str) -> bool:
"""检查表是否存在"""
result = conn.execute(
text("""
SELECT 1 FROM information_schema.tables
WHERE table_name = :name AND table_schema = 'public'
"""),
{"name": table_name}
)
return result.scalar() is not None
def upgrade() -> None:
"""添加 LDAP 认证支持
1. 创建 authsource 枚举类型
2. 在 users 表添加 auth_source 字段和 LDAP 标识字段
3. 创建 ldap_configs 表
"""
conn = op.get_bind()
# 1. 创建 authsource 枚举类型(幂等)
if not _type_exists(conn, 'authsource'):
conn.execute(text("CREATE TYPE authsource AS ENUM ('local', 'ldap')"))
# 2. 在 users 表添加字段(幂等)
if not _column_exists(conn, 'users', 'auth_source'):
op.add_column('users', sa.Column(
'auth_source',
sa.Enum('local', 'ldap', name='authsource', create_type=False),
nullable=False,
server_default='local'
))
if not _column_exists(conn, 'users', 'ldap_dn'):
op.add_column('users', sa.Column('ldap_dn', sa.String(length=512), nullable=True))
if not _column_exists(conn, 'users', 'ldap_username'):
op.add_column('users', sa.Column('ldap_username', sa.String(length=255), nullable=True))
# 创建索引(幂等)
if not _index_exists(conn, 'ix_users_ldap_dn'):
op.create_index('ix_users_ldap_dn', 'users', ['ldap_dn'])
if not _index_exists(conn, 'ix_users_ldap_username'):
op.create_index('ix_users_ldap_username', 'users', ['ldap_username'])
# 3. 创建 ldap_configs 表(幂等)
if not _table_exists(conn, 'ldap_configs'):
op.create_table(
'ldap_configs',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('server_url', sa.String(length=255), nullable=False),
sa.Column('bind_dn', sa.String(length=255), nullable=False),
sa.Column('bind_password_encrypted', sa.Text(), nullable=True),
sa.Column('base_dn', sa.String(length=255), nullable=False),
sa.Column('user_search_filter', sa.String(length=500), nullable=False, server_default='(uid={username})'),
sa.Column('username_attr', sa.String(length=50), nullable=False, server_default='uid'),
sa.Column('email_attr', sa.String(length=50), nullable=False, server_default='mail'),
sa.Column('display_name_attr', sa.String(length=50), nullable=False, server_default='cn'),
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('connect_timeout', sa.Integer(), nullable=False, server_default='10'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('id')
)
def downgrade() -> None:
"""回滚 LDAP 认证支持
警告:回滚前请确保:
1. 已备份数据库
2. 没有 LDAP 用户需要保留
"""
conn = op.get_bind()
# 检查是否存在 LDAP 用户,防止数据丢失
if _column_exists(conn, 'users', 'auth_source'):
result = conn.execute(text("SELECT COUNT(*) FROM users WHERE auth_source = 'ldap'"))
ldap_user_count = result.scalar()
if ldap_user_count and ldap_user_count > 0:
raise RuntimeError(
f"无法回滚:存在 {ldap_user_count} 个 LDAP 用户。"
f"请先删除或转换这些用户,或使用 --force 参数强制回滚(将丢失数据)。"
)
# 1. 删除 ldap_configs 表(幂等)
if _table_exists(conn, 'ldap_configs'):
op.drop_table('ldap_configs')
# 2. 删除 users 表的 LDAP 相关字段(幂等)
if _index_exists(conn, 'ix_users_ldap_username'):
op.drop_index('ix_users_ldap_username', table_name='users')
if _index_exists(conn, 'ix_users_ldap_dn'):
op.drop_index('ix_users_ldap_dn', table_name='users')
if _column_exists(conn, 'users', 'ldap_username'):
op.drop_column('users', 'ldap_username')
if _column_exists(conn, 'users', 'ldap_dn'):
op.drop_column('users', 'ldap_dn')
if _column_exists(conn, 'users', 'auth_source'):
op.drop_column('users', 'auth_source')
# 3. 删除 authsource 枚举类型(幂等)
# 注意:不使用 CASCADE因为此时所有依赖应该已被删除
if _type_exists(conn, 'authsource'):
conn.execute(text("DROP TYPE authsource"))


@@ -1,7 +1,7 @@
# Aether deployment config - local build
# Usage:
# First build of the base image: docker build -f Dockerfile.base -t aether-base:latest .
# Start the services: docker-compose -f docker-compose.build.yml up -d --build
# Start the services: docker compose -f docker-compose.build.yml up -d --build
services:
postgres:


@@ -1,5 +1,5 @@
# Aether deployment config - prebuilt images
# Usage: docker-compose up -d
# Usage: docker compose up -d
services:
postgres:


@@ -13,6 +13,7 @@ export interface UsersExportData {
version: string
exported_at: string
users: UserExport[]
standalone_keys?: StandaloneKeyExport[]
}
export interface UserExport {
@@ -42,15 +43,19 @@ export interface UserApiKeyExport {
allowed_endpoints?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
rate_limit?: number | null // null = unlimited
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
expires_at?: string | null
auto_delete_on_expiry?: boolean
total_requests?: number
total_cost_usd?: number
}
// Standalone-balance key export shape (same as UserApiKeyExport, but without is_standalone)
export type StandaloneKeyExport = Omit<UserApiKeyExport, 'is_standalone'>
export interface GlobalModelExport {
name: string
display_name: string
@@ -155,6 +160,44 @@ export interface EmailTemplateResetResponse {
}
}
// LDAP config response
export interface LdapConfigResponse {
server_url: string | null
bind_dn: string | null
base_dn: string | null
has_bind_password: boolean
user_search_filter: string
username_attr: string
email_attr: string
display_name_attr: string
is_enabled: boolean
is_exclusive: boolean
use_starttls: boolean
connect_timeout: number
}
// LDAP config update request
export interface LdapConfigUpdateRequest {
server_url: string
bind_dn: string
bind_password?: string
base_dn: string
user_search_filter?: string
username_attr?: string
email_attr?: string
display_name_attr?: string
is_enabled?: boolean
is_exclusive?: boolean
use_starttls?: boolean
connect_timeout?: number
}
// LDAP connection test response
export interface LdapTestResponse {
success: boolean
message: string
}
// Provider model query response
export interface ProviderModelsQueryResponse {
success: boolean
@@ -189,6 +232,7 @@ export interface UsersImportResponse {
stats: {
users: { created: number; updated: number; skipped: number }
api_keys: { created: number; skipped: number }
standalone_keys?: { created: number; skipped: number }
errors: string[]
}
}
@@ -220,7 +264,7 @@ export interface AdminApiKey {
total_requests?: number
total_tokens?: number
total_cost_usd?: number
rate_limit?: number
rate_limit?: number | null // null = unlimited
allowed_providers?: string[] | null // allowed provider list
allowed_api_formats?: string[] | null // allowed API format list
allowed_models?: string[] | null // allowed model list
@@ -236,8 +280,8 @@ export interface CreateStandaloneApiKeyRequest {
allowed_providers?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
expire_days?: number | null // null = never expires
rate_limit?: number | null // null = unlimited
expires_at?: string | null // ISO date string, e.g. "2025-12-31"; null = never expires
initial_balance_usd: number // initial balance, must be set
auto_delete_on_expiry?: boolean // whether to delete automatically after expiry
}
@@ -473,5 +517,35 @@ export const adminApi = {
`/api/admin/system/email/templates/${templateType}/reset`
)
return response.data
},
// Fetch the system version
async getSystemVersion(): Promise<{ version: string }> {
const response = await apiClient.get<{ version: string }>(
'/api/admin/system/version'
)
return response.data
},
// LDAP configuration
// Fetch the LDAP config
async getLdapConfig(): Promise<LdapConfigResponse> {
const response = await apiClient.get<LdapConfigResponse>('/api/admin/ldap/config')
return response.data
},
// Update the LDAP config
async updateLdapConfig(config: LdapConfigUpdateRequest): Promise<{ message: string }> {
const response = await apiClient.put<{ message: string }>(
'/api/admin/ldap/config',
config
)
return response.data
},
// Test the LDAP connection
async testLdapConnection(config: LdapConfigUpdateRequest): Promise<LdapTestResponse> {
const response = await apiClient.post<LdapTestResponse>('/api/admin/ldap/test', config)
return response.data
}
}


@@ -4,6 +4,7 @@ import { log } from '@/utils/logger'
export interface LoginRequest {
email: string
password: string
auth_type?: 'local' | 'ldap'
}
export interface LoginResponse {
@@ -81,6 +82,12 @@ export interface RegistrationSettingsResponse {
require_email_verification: boolean
}
export interface AuthSettingsResponse {
local_enabled: boolean
ldap_enabled: boolean
ldap_exclusive: boolean
}
export interface User {
id: string // UUID
username: string
@@ -173,5 +180,10 @@ export const authApi = {
{ email }
)
return response.data
},
async getAuthSettings(): Promise<AuthSettingsResponse> {
const response = await apiClient.get<AuthSettingsResponse>('/api/auth/settings')
return response.data
}
}


@@ -62,6 +62,11 @@ export interface UsageRecordDetail {
cache_creation_price_per_1m?: number
cache_read_price_per_1m?: number
price_per_request?: number // per-request price
api_key?: {
id: string
name: string
display: string
}
}
// Per-model statistics
@@ -75,6 +80,16 @@ export interface ModelSummary {
actual_total_cost_usd?: number // multiplier-adjusted cost (admins only)
}
// Per-provider statistics
export interface ProviderSummary {
provider: string
requests: number
total_tokens: number
total_cost_usd: number
success_rate: number | null
avg_response_time_ms: number | null
}
// Usage statistics response
export interface UsageResponse {
total_requests: number
@@ -87,6 +102,13 @@ export interface UsageResponse {
quota_usd: number | null
used_usd: number
summary_by_model: ModelSummary[]
summary_by_provider?: ProviderSummary[]
pagination?: {
total: number
limit: number
offset: number
has_more: boolean
}
records: UsageRecordDetail[]
activity_heatmap?: ActivityHeatmap | null
}
@@ -175,6 +197,9 @@ export const meApi = {
async getUsage(params?: {
start_date?: string
end_date?: string
search?: string // generic search: key name, model name
limit?: number
offset?: number
}): Promise<UsageResponse> {
const response = await apiClient.get<UsageResponse>('/api/users/me/usage', { params })
return response.data
@@ -184,11 +209,12 @@ export const meApi = {
async getActiveRequests(ids?: string): Promise<{
requests: Array<{
id: string
status: string
status: 'pending' | 'streaming' | 'completed' | 'failed'
input_tokens: number
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
}>
}> {
const params = ids ? { ids } : {}
@@ -267,5 +293,14 @@ export const meApi = {
}> {
const response = await apiClient.get('/api/users/me/usage/interval-timeline', { params })
return response.data
},
/**
* Fetch activity heatmap data (user)
* Cached on the backend for 5 minutes
*/
async getActivityHeatmap(): Promise<ActivityHeatmap> {
const response = await apiClient.get<ActivityHeatmap>('/api/users/me/usage/heatmap')
return response.data
}
}


@@ -192,10 +192,17 @@ export async function getModelsDevList(officialOnly: boolean = true): Promise<Mo
}
}
// Sort by provider name and model name
// Sort by provider name; within each provider, sort models by release_date from newest to oldest
items.sort((a, b) => {
const providerCompare = a.providerName.localeCompare(b.providerName)
if (providerCompare !== 0) return providerCompare
// Sort models by release_date, newest first (models without a date go last)
const aDate = a.releaseDate ? new Date(a.releaseDate).getTime() : 0
const bDate = b.releaseDate ? new Date(b.releaseDate).getTime() : 0
if (aDate !== bDate) return bDate - aDate // descending: newest first
// When dates are equal or both missing, sort by model name
return a.modelName.localeCompare(b.modelName)
})


@@ -164,6 +164,7 @@ export const usageApi = {
async getAllUsageRecords(params?: {
start_date?: string
end_date?: string
search?: string // generic search: username, key name, model name, provider name
user_id?: string // UUID
username?: string
model?: string
@@ -193,10 +194,22 @@ export const usageApi = {
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
provider?: string | null
api_key_name?: string | null
}>
}> {
const params = ids?.length ? { ids: ids.join(',') } : {}
const response = await apiClient.get('/api/admin/usage/active', { params })
return response.data
},
/**
* Fetch activity heatmap data (admin)
* Cached on the backend for 5 minutes
*/
async getActivityHeatmap(): Promise<ActivityHeatmap> {
const response = await apiClient.get<ActivityHeatmap>('/api/admin/usage/heatmap')
return response.data
}
}


@@ -0,0 +1,117 @@
<template>
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="isOpen = !isOpen"
>
<span :class="modelValue.length ? 'text-foreground' : 'text-muted-foreground'">
{{ modelValue.length ? `${modelValue.length} selected` : 'All available' }}
<span
v-if="invalidModels.length"
class="text-destructive"
>({{ invalidModels.length }} invalid)</span>
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="isOpen ? 'rotate-180' : ''"
/>
</button>
<div
v-if="isOpen"
class="fixed inset-0 z-[80]"
@click.stop="isOpen = false"
/>
<div
v-if="isOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<!-- Invalid models are pinned to the top and can only be deselected -->
<div
v-for="modelName in invalidModels"
:key="modelName"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer bg-destructive/5"
@click="removeModel(modelName)"
>
<input
type="checkbox"
:checked="true"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="removeModel(modelName)"
>
<span class="text-sm text-destructive">{{ modelName }}</span>
<span class="text-xs text-destructive/70">(已失效)</span>
</div>
<!-- Valid models -->
<div
v-for="model in models"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleModel(model.name)"
>
<input
type="checkbox"
:checked="modelValue.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleModel(model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="models.length === 0 && invalidModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
No models available
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { Label } from '@/components/ui'
import { ChevronDown } from 'lucide-vue-next'
import { useInvalidModels } from '@/composables/useInvalidModels'
export interface ModelWithName {
name: string
}
const props = defineProps<{
modelValue: string[]
models: ModelWithName[]
}>()
const emit = defineEmits<{
'update:modelValue': [value: string[]]
}>()
const isOpen = ref(false)
// Detect invalid models
const { invalidModels } = useInvalidModels(
computed(() => props.modelValue),
computed(() => props.models)
)
function toggleModel(name: string) {
const newValue = [...props.modelValue]
const index = newValue.indexOf(name)
if (index === -1) {
newValue.push(name)
} else {
newValue.splice(index, 1)
}
emit('update:modelValue', newValue)
}
function removeModel(name: string) {
const newValue = props.modelValue.filter(m => m !== name)
emit('update:modelValue', newValue)
}
</script>


@@ -7,3 +7,6 @@
export { default as EmptyState } from './EmptyState.vue'
export { default as AlertDialog } from './AlertDialog.vue'
export { default as LoadingState } from './LoadingState.vue'
// Form components
export { default as ModelMultiSelect } from './ModelMultiSelect.vue'


@@ -0,0 +1,34 @@
import { computed, type Ref, type ComputedRef } from 'vue'
/**
* Composable for detecting invalid models.
*
* Detects model names in allowed_models that no longer exist in globalModels;
* these models may have been deleted while references to them were never cleaned up.
*
* @example
* ```typescript
* const { invalidModels } = useInvalidModels(
* computed(() => form.value.allowed_models),
* globalModels
* )
* ```
*/
export interface ModelWithName {
name: string
}
export function useInvalidModels<T extends ModelWithName>(
allowedModels: Ref<string[]> | ComputedRef<string[]>,
globalModels: Ref<T[]>
): { invalidModels: ComputedRef<string[]> } {
const validModelNames = computed(() =>
new Set(globalModels.value.map(m => m.name))
)
const invalidModels = computed(() =>
allowedModels.value.filter(name => !validModelNames.value.has(name))
)
return { invalidModels }
}


@@ -79,45 +79,45 @@
<div class="space-y-2">
<Label
for="form-expire-days"
for="form-expires-at"
class="text-sm font-medium"
>Expiry settings</Label>
<div class="flex items-center gap-2">
<div class="relative flex-1">
<Input
id="form-expire-days"
:model-value="form.expire_days ?? ''"
type="number"
min="1"
max="3650"
placeholder="天数"
:class="form.never_expire ? 'flex-1 h-9 opacity-50' : 'flex-1 h-9'"
:disabled="form.never_expire"
@update:model-value="(v) => form.expire_days = parseNumberInput(v, { min: 1, max: 3650 })"
id="form-expires-at"
:model-value="form.expires_at || ''"
type="date"
:min="minExpiryDate"
class="h-9 pr-8"
:placeholder="form.expires_at ? '' : '永不过期'"
@update:model-value="(v) => form.expires_at = v || undefined"
/>
<label class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap">
<input
v-model="form.never_expire"
type="checkbox"
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
@change="onNeverExpireChange"
<button
v-if="form.expires_at"
type="button"
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground"
title="清空永不过期"
@click="clearExpiryDate"
>
Never expires
</label>
<X class="h-4 w-4" />
</button>
</div>
<label
class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap"
:class="form.never_expire ? 'opacity-50' : ''"
:class="!form.expires_at ? 'opacity-50 cursor-not-allowed' : ''"
>
<input
v-model="form.auto_delete_on_expiry"
type="checkbox"
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
:disabled="form.never_expire"
:disabled="!form.expires_at"
>
Delete on expiry
</label>
</div>
<p class="text-xs text-muted-foreground">
If "Delete on expiry" is unchecked, the key is only disabled
{{ form.expires_at ? 'On expiry, the key will be ' + (form.auto_delete_on_expiry ? 'deleted automatically' : 'disabled only') + ' (effective at 23:59 that day)' : 'Leave empty to never expire' }}
</p>
</div>
@@ -244,55 +244,10 @@
</div>
<!-- Model multi-select dropdown -->
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="modelDropdownOpen = !modelDropdownOpen"
>
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
{{ form.allowed_models.length ? `${form.allowed_models.length} selected` : 'All available' }}
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="modelDropdownOpen ? 'rotate-180' : ''"
<ModelMultiSelect
v-model="form.allowed_models"
:models="globalModels"
/>
</button>
<div
v-if="modelDropdownOpen"
class="fixed inset-0 z-[80]"
@click.stop="modelDropdownOpen = false"
/>
<div
v-if="modelDropdownOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<div
v-for="model in globalModels"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleSelection('allowed_models', model.name)"
>
<input
type="checkbox"
:checked="form.allowed_models.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleSelection('allowed_models', model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="globalModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
No models available
</div>
</div>
</div>
</div>
</div>
</div>
</form>
@@ -325,8 +280,9 @@ import {
Input,
Label,
} from '@/components/ui'
import { Plus, SquarePen, Key, Shield, ChevronDown } from 'lucide-vue-next'
import { Plus, SquarePen, Key, Shield, ChevronDown, X } from 'lucide-vue-next'
import { useFormDialog } from '@/composables/useFormDialog'
import { ModelMultiSelect } from '@/components/common'
import { getProvidersSummary } from '@/api/endpoints/providers'
import { getGlobalModels } from '@/api/global-models'
import { adminApi } from '@/api/admin'
@@ -338,8 +294,7 @@ export interface StandaloneKeyFormData {
id?: string
name: string
initial_balance_usd?: number
expire_days?: number
never_expire: boolean
expires_at?: string // ISO date string, e.g. "2025-12-31"; undefined = never expires
rate_limit?: number
auto_delete_on_expiry: boolean
allowed_providers: string[]
@@ -363,7 +318,6 @@ const saving = ref(false)
// Dropdown state
const providerDropdownOpen = ref(false)
const apiFormatDropdownOpen = ref(false)
const modelDropdownOpen = ref(false)
// Option data
const providers = ref<ProviderWithEndpointsSummary[]>([])
@@ -374,8 +328,7 @@ const allApiFormats = ref<string[]>([])
const form = ref<StandaloneKeyFormData>({
name: '',
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
expires_at: undefined,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
@@ -383,12 +336,18 @@ const form = ref<StandaloneKeyFormData>({
allowed_models: []
})
// Compute the earliest selectable date (tomorrow)
const minExpiryDate = computed(() => {
const tomorrow = new Date()
tomorrow.setDate(tomorrow.getDate() + 1)
return tomorrow.toISOString().split('T')[0]
})
function resetForm() {
form.value = {
name: '',
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
expires_at: undefined,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
@@ -397,7 +356,6 @@ function resetForm() {
}
providerDropdownOpen.value = false
apiFormatDropdownOpen.value = false
modelDropdownOpen.value = false
}
function loadKeyData() {
@@ -406,8 +364,7 @@ function loadKeyData() {
id: props.apiKey.id,
name: props.apiKey.name || '',
initial_balance_usd: props.apiKey.initial_balance_usd,
expire_days: props.apiKey.expire_days,
never_expire: props.apiKey.never_expire,
expires_at: props.apiKey.expires_at,
rate_limit: props.apiKey.rate_limit,
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
allowed_providers: props.apiKey.allowed_providers || [],
@@ -452,13 +409,11 @@ function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'a
}
}
// Never-expire toggle
function onNeverExpireChange() {
if (form.value.never_expire) {
form.value.expire_days = undefined
// Clear the expiry date (and reset the delete-on-expiry option)
function clearExpiryDate() {
form.value.expires_at = undefined
form.value.auto_delete_on_expiry = false
}
}
// Submit the form
function handleSubmit() {


@@ -66,19 +66,59 @@
</div>
</div>
<!-- Auth type switcher -->
<div
v-if="showAuthTypeTabs"
class="auth-type-tabs"
>
<button
type="button"
:class="['auth-tab', authType === 'local' && 'active']"
@click="authType = 'local'"
>
Local login
</button>
<button
type="button"
:class="['auth-tab', authType === 'ldap' && 'active']"
@click="authType = 'ldap'"
>
LDAP login
</button>
</div>
<!-- Login form -->
<form
class="space-y-4"
@submit.prevent="handleLogin"
>
<div class="space-y-2">
<Label for="login-email">邮箱</Label>
<div class="flex items-center justify-between">
<Label for="login-email">{{ emailLabel }}</Label>
<button
v-if="ldapExclusive && authType === 'ldap'"
type="button"
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
@click="authType = 'local'"
>
管理员本地登录
</button>
<button
v-if="ldapExclusive && authType === 'local'"
type="button"
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
@click="authType = 'ldap'"
>
返回 LDAP 登录
</button>
</div>
<Input
id="login-email"
v-model="form.email"
type="email"
type="text"
required
placeholder="hello@example.com"
placeholder="username 或 email"
autocomplete="off"
/>
</div>
@@ -180,6 +220,30 @@ const showRegisterDialog = ref(false)
const requireEmailVerification = ref(false)
const allowRegistration = ref(false) // Controlled by system config; disabled by default
// LDAP authentication settings
const PREFERRED_AUTH_TYPE_KEY = 'aether_preferred_auth_type'
function getStoredAuthType(): 'local' | 'ldap' {
const stored = localStorage.getItem(PREFERRED_AUTH_TYPE_KEY)
return (stored === 'ldap' || stored === 'local') ? stored : 'local'
}
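The `getStoredAuthType` guard above illustrates a general rule: never trust a raw `localStorage` value, validate it against the allowed set and fall back to a default. A framework-free sketch of the same idea (the generic `parseStoredEnum` helper is an assumption for illustration, not from the diff; the raw value is passed in so the function stays pure):

```typescript
// Parse a raw stored string into one of the allowed values,
// falling back to a default for null, stale, or tampered entries.
function parseStoredEnum<T extends string>(
  raw: string | null,
  allowed: readonly T[],
  fallback: T
): T {
  if (raw !== null && (allowed as readonly string[]).includes(raw)) {
    return raw as T
  }
  return fallback
}

type AuthType = 'local' | 'ldap'
const AUTH_TYPES = ['local', 'ldap'] as const

// Mirrors getStoredAuthType: anything unexpected degrades to 'local'.
function storedAuthType(raw: string | null): AuthType {
  return parseStoredEnum<AuthType>(raw, AUTH_TYPES, 'local')
}
```

In the component, `raw` would come from `localStorage.getItem(PREFERRED_AUTH_TYPE_KEY)`.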
const authType = ref<'local' | 'ldap'>(getStoredAuthType())
const localEnabled = ref(true)
const ldapEnabled = ref(false)
const ldapExclusive = ref(false)
// Persist the user's preferred auth type
watch(authType, (newType) => {
localStorage.setItem(PREFERRED_AUTH_TYPE_KEY, newType)
})
const showAuthTypeTabs = computed(() => {
return localEnabled.value && ldapEnabled.value && !ldapExclusive.value
})
const emailLabel = computed(() => {
return '用户名/邮箱'
})
watch(() => props.modelValue, (val) => {
isOpen.value = val
// Reset the form when the dialog opens
@@ -212,7 +276,7 @@ async function handleLogin() {
return
}
const success = await authStore.login(form.value.email, form.value.password)
const success = await authStore.login(form.value.email, form.value.password, authType.value)
if (success) {
showSuccess('登录成功,正在跳转...')
@@ -246,16 +310,84 @@ function handleSwitchToLogin() {
isOpen.value = true
}
// Load registration settings on mount
// Load authentication and registration settings on mount
onMounted(async () => {
try {
const settings = await authApi.getRegistrationSettings()
allowRegistration.value = !!settings.enable_registration
requireEmailVerification.value = !!settings.require_email_verification
// Load registration settings
const regSettings = await authApi.getRegistrationSettings()
allowRegistration.value = !!regSettings.enable_registration
requireEmailVerification.value = !!regSettings.require_email_verification
// Load authentication settings
const authSettings = await authApi.getAuthSettings()
localEnabled.value = authSettings.local_enabled
ldapEnabled.value = authSettings.ldap_enabled
ldapExclusive.value = authSettings.ldap_exclusive
// If LDAP-only login is enforced, disable the local registration entry
if (ldapExclusive.value) {
allowRegistration.value = false
}
// Set default auth type based on settings
if (authSettings.ldap_exclusive) {
authType.value = 'ldap'
} else if (!authSettings.local_enabled && authSettings.ldap_enabled) {
authType.value = 'ldap'
} else {
authType.value = 'local'
}
} catch (error) {
// If the fetch fails, keep registration and email verification disabled by default
// If the fetch fails, keep registration and email verification disabled and fall back to local auth
allowRegistration.value = false
requireEmailVerification.value = false
localEnabled.value = true
ldapEnabled.value = false
ldapExclusive.value = false
authType.value = 'local'
}
})
</script>
<style scoped>
.auth-type-tabs {
display: flex;
border-bottom: 1px solid hsl(var(--border));
}
.auth-tab {
flex: 1;
padding: 0.625rem 1rem;
font-size: 0.875rem;
font-weight: 500;
color: hsl(var(--muted-foreground));
background: transparent;
border: none;
cursor: pointer;
transition: color 0.15s ease;
position: relative;
}
.auth-tab::after {
content: '';
position: absolute;
bottom: -1px;
left: 0;
right: 0;
height: 2px;
background: transparent;
transition: background 0.15s ease;
}
.auth-tab:hover:not(.active) {
color: hsl(var(--foreground));
}
.auth-tab.active {
color: var(--book-cloth);
font-weight: 600;
}
.auth-tab.active::after {
background: var(--book-cloth);
}
</style>

View File

@@ -18,8 +18,22 @@
<span class="flex-shrink-0"></span>
</div>
</div>
<div
v-if="isLoading"
class="h-full min-h-[160px] flex items-center justify-center text-sm text-muted-foreground"
>
<Loader2 class="h-5 w-5 animate-spin mr-2" />
加载中...
</div>
<div
v-else-if="hasError"
class="h-full min-h-[160px] flex items-center justify-center text-sm text-destructive"
>
<AlertCircle class="h-4 w-4 mr-1.5" />
加载失败
</div>
<ActivityHeatmap
v-if="hasData"
v-else-if="hasData"
:data="data"
:show-header="false"
/>
@@ -34,6 +48,7 @@
<script setup lang="ts">
import { computed } from 'vue'
import { Loader2, AlertCircle } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import ActivityHeatmap from '@/components/stats/ActivityHeatmap.vue'
import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
@@ -41,6 +56,8 @@ import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
const props = defineProps<{
data: ActivityHeatmapData | null
title: string
isLoading?: boolean
hasError?: boolean
}>()
const legendLevels = [0.08, 0.25, 0.45, 0.65, 0.85]

View File

@@ -32,6 +32,17 @@
<!-- Divider -->
<div class="hidden sm:block h-4 w-px bg-border" />
<!-- Global search -->
<div class="relative">
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
<Input
id="usage-records-search"
v-model="localSearch"
:placeholder="isAdmin ? '搜索用户/密钥/模型/提供商' : '搜索密钥/模型'"
class="w-32 sm:w-48 h-8 text-xs border-border/60 pl-8"
/>
</div>
<!-- User filter (admin only) -->
<Select
v-if="isAdmin && availableUsers.length > 0"
@@ -164,6 +175,12 @@
>
用户
</TableHead>
<TableHead
v-if="!isAdmin"
class="h-12 font-semibold w-[100px]"
>
密钥
</TableHead>
<TableHead class="h-12 font-semibold w-[140px]">
模型
</TableHead>
@@ -196,7 +213,7 @@
<TableBody>
<TableRow v-if="records.length === 0">
<TableCell
:colspan="isAdmin ? 9 : 7"
:colspan="isAdmin ? 9 : 8"
class="text-center py-12 text-muted-foreground"
>
暂无请求记录
@@ -218,7 +235,34 @@
class="py-4 w-[100px] truncate"
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
>
<div class="flex flex-col text-xs gap-0.5">
<span class="truncate">
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
</span>
<span
v-if="record.api_key?.name"
class="text-muted-foreground truncate"
:title="record.api_key.name"
>
{{ record.api_key.name }}
</span>
</div>
</TableCell>
<!-- API key column on the user page -->
<TableCell
v-if="!isAdmin"
class="py-4 w-[100px]"
:title="record.api_key?.name || '-'"
>
<div class="flex flex-col text-xs gap-0.5">
<span class="truncate">{{ record.api_key?.name || '-' }}</span>
<span
v-if="record.api_key?.display"
class="text-muted-foreground truncate"
>
{{ record.api_key.display }}
</span>
</div>
</TableCell>
<TableCell
class="font-medium py-4 w-[140px]"
@@ -438,6 +482,7 @@ import {
TableCard,
Badge,
Button,
Input,
Select,
SelectTrigger,
SelectValue,
@@ -451,7 +496,7 @@ import {
TableCell,
Pagination,
} from '@/components/ui'
import { RefreshCcw } from 'lucide-vue-next'
import { RefreshCcw, Search } from 'lucide-vue-next'
import { formatTokens, formatCurrency } from '@/utils/format'
import { formatDateTime } from '../composables'
import { useRowClick } from '@/composables/useRowClick'
@@ -471,6 +516,7 @@ const props = defineProps<{
// Time period
selectedPeriod: string
// Filters
filterSearch: string
filterUser: string
filterModel: string
filterProvider: string
@@ -489,6 +535,7 @@ const props = defineProps<{
const emit = defineEmits<{
'update:selectedPeriod': [value: string]
'update:filterSearch': [value: string]
'update:filterUser': [value: string]
'update:filterModel': [value: string]
'update:filterProvider': [value: string]
@@ -507,6 +554,23 @@ const filterModelSelectOpen = ref(false)
const filterProviderSelectOpen = ref(false)
const filterStatusSelectOpen = ref(false)
// Global search (debounced input)
const localSearch = ref(props.filterSearch)
let searchDebounceTimer: ReturnType<typeof setTimeout> | null = null
watch(() => props.filterSearch, (value) => {
if (value !== localSearch.value) {
localSearch.value = value
}
})
watch(localSearch, (value) => {
if (searchDebounceTimer) clearTimeout(searchDebounceTimer)
searchDebounceTimer = setTimeout(() => {
emit('update:filterSearch', value)
}, 300)
})
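The debounce wiring above (a 300 ms timer reset on each keystroke, plus explicit cleanup on unmount) can be reduced to a small standalone helper. A sketch of the same mechanism, not the component's actual implementation:

```typescript
// Return a debounced caller: only the last call within `delayMs`
// actually invokes `fn`, mirroring the 300 ms search debounce above.
// `cancel` mirrors the onUnmounted() cleanup of the pending timer.
function debounce<A extends unknown[]>(fn: (...args: A) => void, delayMs: number) {
  let timer: ReturnType<typeof setTimeout> | null = null
  const call = (...args: A): void => {
    if (timer) clearTimeout(timer)
    timer = setTimeout(() => {
      timer = null
      fn(...args)
    }, delayMs)
  }
  const cancel = (): void => {
    if (timer) {
      clearTimeout(timer)
      timer = null
    }
  }
  return { call, cancel }
}
```

In the component this corresponds to calling `debounced.call(value)` from the `watch(localSearch, …)` handler and `debounced.cancel()` in `onUnmounted`.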
// Live timer state
const now = ref(Date.now())
let timerInterval: ReturnType<typeof setInterval> | null = null
@@ -574,6 +638,10 @@ function handleRowClick(event: MouseEvent, id: string) {
// Clean up on unmount
onUnmounted(() => {
stopTimer()
if (searchDebounceTimer) {
clearTimeout(searchDebounceTimer)
searchDebounceTimer = null
}
})
// Format the API format display name

View File

@@ -23,6 +23,7 @@ export interface PaginationParams {
}
export interface FilterParams {
search?: string
user_id?: string
model?: string
provider?: string
@@ -64,9 +65,6 @@ export function useUsageData(options: UseUsageDataOptions) {
}))
})
// Activity heatmap data
const activityHeatmapData = computed(() => stats.value.activity_heatmap)
// Load stats (without records)
async function loadStats(dateRange?: DateRangeParams) {
isLoadingStats.value = true
@@ -93,7 +91,7 @@ export function useUsageData(options: UseUsageDataOptions) {
cache_stats: (statsData as any).cache_stats,
period_start: '',
period_end: '',
activity_heatmap: statsData.activity_heatmap || null
activity_heatmap: null
}
modelStats.value = modelData.map(item => ({
@@ -143,7 +141,7 @@ export function useUsageData(options: UseUsageDataOptions) {
avg_response_time: userData.avg_response_time || 0,
period_start: '',
period_end: '',
activity_heatmap: userData.activity_heatmap || null
activity_heatmap: null
}
modelStats.value = (userData.summary_by_model || []).map((item: any) => ({
@@ -237,11 +235,6 @@ export function useUsageData(options: UseUsageDataOptions) {
pagination: PaginationParams,
filters?: FilterParams
): Promise<void> {
if (!isAdminPage.value) {
// The user page doesn't need paginated loading; records were already fetched in loadStats
return
}
isLoadingRecords.value = true
try {
@@ -255,6 +248,12 @@ export function useUsageData(options: UseUsageDataOptions) {
}
// Apply filter conditions
if (filters?.search?.trim()) {
params.search = filters.search.trim()
}
if (isAdminPage.value) {
// Admin page: use the admin API
if (filters?.user_id) {
params.user_id = filters.user_id
}
@@ -269,10 +268,14 @@ export function useUsageData(options: UseUsageDataOptions) {
}
const response = await usageApi.getAllUsageRecords(params)
currentRecords.value = (response.records || []) as UsageRecord[]
totalRecords.value = response.total || 0
} else {
// User page: use the user API
const userData = await meApi.getUsage(params)
currentRecords.value = (userData.records || []) as UsageRecord[]
totalRecords.value = userData.pagination?.total || currentRecords.value.length
}
} catch (error) {
log.error('加载记录失败:', error)
currentRecords.value = []
@@ -305,7 +308,6 @@ export function useUsageData(options: UseUsageDataOptions) {
// Computed properties
enhancedModelStats,
activityHeatmapData,
// Methods
loadStats,

View File

@@ -1,5 +1,3 @@
import type { ActivityHeatmap } from '@/types/activity'
// Usage stats state
export interface UsageStatsState {
total_requests: number
@@ -17,7 +15,6 @@ export interface UsageStatsState {
}
period_start: string
period_end: string
activity_heatmap: ActivityHeatmap | null
}
// Per-model stats
@@ -64,6 +61,11 @@ export interface UsageRecord {
user_id?: string
username?: string
user_email?: string
api_key?: {
id: string | null
name: string | null
display: string | null
} | null
provider: string
api_key_name?: string
rate_multiplier?: number
@@ -115,7 +117,6 @@ export function createDefaultStats(): UsageStatsState {
error_rate: undefined,
cache_stats: undefined,
period_start: '',
period_end: '',
activity_heatmap: null
period_end: ''
}
}

View File

@@ -316,55 +316,10 @@
</div>
<!-- Model multi-select dropdown -->
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="modelDropdownOpen = !modelDropdownOpen"
>
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length}` : '全部可用' }}
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="modelDropdownOpen ? 'rotate-180' : ''"
<ModelMultiSelect
v-model="form.allowed_models"
:models="globalModels"
/>
</button>
<div
v-if="modelDropdownOpen"
class="fixed inset-0 z-[80]"
@click.stop="modelDropdownOpen = false"
/>
<div
v-if="modelDropdownOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<div
v-for="model in globalModels"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleSelection('allowed_models', model.name)"
>
<input
type="checkbox"
:checked="form.allowed_models.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleSelection('allowed_models', model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="globalModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
暂无可用模型
</div>
</div>
</div>
</div>
</div>
</div>
</form>
@@ -404,10 +359,12 @@ import {
} from '@/components/ui'
import { UserPlus, SquarePen, ChevronDown } from 'lucide-vue-next'
import { useFormDialog } from '@/composables/useFormDialog'
import { ModelMultiSelect } from '@/components/common'
import { getProvidersSummary } from '@/api/endpoints/providers'
import { getGlobalModels } from '@/api/global-models'
import { adminApi } from '@/api/admin'
import { log } from '@/utils/logger'
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
export interface UserFormData {
id?: string
@@ -440,11 +397,10 @@ const roleSelectOpen = ref(false)
// Dropdown state
const providerDropdownOpen = ref(false)
const endpointDropdownOpen = ref(false)
const modelDropdownOpen = ref(false)
// Option data
const providers = ref<any[]>([])
const globalModels = ref<any[]>([])
const providers = ref<ProviderWithEndpointsSummary[]>([])
const globalModels = ref<GlobalModelResponse[]>([])
const apiFormats = ref<Array<{ value: string; label: string }>>([])
// Form data

View File

@@ -423,6 +423,7 @@ const navigation = computed(() => {
{ name: 'IP 安全', href: '/admin/ip-security', icon: Shield },
{ name: '审计日志', href: '/admin/audit-logs', icon: AlertTriangle },
{ name: '邮件配置', href: '/admin/email', icon: Mail },
{ name: 'LDAP 配置', href: '/admin/ldap', icon: Shield },
{ name: '系统设置', href: '/admin/system', icon: Cog },
]
}

View File

@@ -367,6 +367,11 @@ function generateMockUsageRecords(count: number = 100) {
user_id: user.id,
username: user.username,
user_email: user.email,
api_key: {
id: `key-${user.id}-${Math.ceil(Math.random() * 2)}`,
name: `${user.username} Key ${Math.ceil(Math.random() * 3)}`,
display: `sk-ae...${String(1000 + Math.floor(Math.random() * 9000))}`
},
provider: model.provider,
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
rate_multiplier: 1.0,
@@ -835,10 +840,26 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
'GET /api/admin/usage/records': async (config) => {
await delay()
requireAdmin()
const records = getUsageRecords()
let records = getUsageRecords()
const params = config.params || {}
const limit = parseInt(params.limit) || 20
const offset = parseInt(params.offset) || 0
// Global search across username, key name, model name, and provider name.
// Supports space-separated compound search; keywords are combined with AND.
if (typeof params.search === 'string' && params.search.trim()) {
const keywords = params.search.trim().toLowerCase().split(/\s+/)
records = records.filter(r => {
// Every keyword must match at least one field
return keywords.every((keyword: string) =>
(r.username || '').toLowerCase().includes(keyword) ||
(r.api_key?.name || '').toLowerCase().includes(keyword) ||
(r.model || '').toLowerCase().includes(keyword) ||
(r.provider || '').toLowerCase().includes(keyword)
)
})
}
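The mock handler's search semantics above (space-separated keywords, AND across keywords, OR across fields, case-insensitive) extract cleanly into a pure predicate. A sketch under the same semantics:

```typescript
// True when every whitespace-separated keyword in `search` matches
// at least one of the record's searchable fields (case-insensitive).
// An empty or whitespace-only search matches everything.
function matchesSearch(
  fields: Array<string | null | undefined>,
  search: string
): boolean {
  const keywords = search.trim().toLowerCase().split(/\s+/).filter(Boolean)
  if (keywords.length === 0) return true
  const haystack = fields.map(f => (f ?? '').toLowerCase())
  return keywords.every(kw => haystack.some(f => f.includes(kw)))
}
```

The handler would then call `matchesSearch([r.username, r.api_key?.name, r.model, r.provider], params.search)` inside the `filter`.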
return createMockResponse({
records: records.slice(offset, offset + limit),
total: records.length,

View File

@@ -111,6 +111,11 @@ const routes: RouteRecordRaw[] = [
name: 'EmailSettings',
component: () => importWithRetry(() => import('@/views/admin/EmailSettings.vue'))
},
{
path: 'ldap',
name: 'LdapSettings',
component: () => importWithRetry(() => import('@/views/admin/LdapSettings.vue'))
},
{
path: 'audit-logs',
name: 'AuditLogs',

View File

@@ -31,12 +31,12 @@ export const useAuthStore = defineStore('auth', () => {
}
const isAdmin = computed(() => user.value?.role === 'admin')
async function login(email: string, password: string) {
async function login(email: string, password: string, authType: 'local' | 'ldap' = 'local') {
loading.value = true
error.value = null
try {
const response = await authApi.login({ email, password })
const response = await authApi.login({ email, password, auth_type: authType })
token.value = response.access_token
// Fetch user info

View File

@@ -850,28 +850,20 @@ async function deleteApiKey(apiKey: AdminApiKey) {
}
function editApiKey(apiKey: AdminApiKey) {
// Compute days until expiry
let expireDays: number | undefined = undefined
let neverExpire = true
// Parse the expiry date into YYYY-MM-DD format
// Keep the original date without time filtering (avoids accidentally clearing a key that expires today while editing it)
let expiresAt: string | undefined = undefined
if (apiKey.expires_at) {
const expiresDate = new Date(apiKey.expires_at)
const now = new Date()
const diffMs = expiresDate.getTime() - now.getTime()
const diffDays = Math.ceil(diffMs / (1000 * 60 * 60 * 24))
if (diffDays > 0) {
expireDays = diffDays
neverExpire = false
}
expiresAt = expiresDate.toISOString().split('T')[0]
}
editingKeyData.value = {
id: apiKey.id,
name: apiKey.name || '',
expire_days: expireDays,
never_expire: neverExpire,
rate_limit: apiKey.rate_limit || 100,
expires_at: expiresAt,
rate_limit: apiKey.rate_limit ?? undefined,
auto_delete_on_expiry: apiKey.auto_delete_on_expiry || false,
allowed_providers: apiKey.allowed_providers || [],
allowed_api_formats: apiKey.allowed_api_formats || [],
@@ -1033,14 +1025,25 @@ function closeKeyFormDialog() {
// Unified form submit handler
async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
// Validate the expiry date (if set, it must be later than today)
if (data.expires_at) {
const selectedDate = new Date(data.expires_at)
const today = new Date()
today.setHours(0, 0, 0, 0)
if (selectedDate <= today) {
error('过期日期必须晚于今天')
return
}
}
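The validation above compares the picked date against local midnight so that anything on or before today is rejected. Extracted as a pure function (a sketch with `now` injectable for testing, not the component's actual code):

```typescript
// A YYYY-MM-DD expiry date is valid only if it falls strictly after
// today's local midnight; empty/undefined means "never expires".
function isValidExpiryDate(
  expiresAt: string | undefined,
  now: Date = new Date()
): boolean {
  if (!expiresAt) return true // never expires
  const selected = new Date(expiresAt)
  const today = new Date(now)
  today.setHours(0, 0, 0, 0) // local midnight, as in the component
  return selected.getTime() > today.getTime()
}
```

One caveat worth noting: `new Date('YYYY-MM-DD')` parses as UTC midnight, so in UTC-positive timezones a same-day date can compare as "after" local midnight and slip through; a stricter check would parse both sides in the same timezone.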
keyFormDialogRef.value?.setSaving(true)
try {
if (data.id) {
// Update
const updateData: Partial<CreateStandaloneApiKeyRequest> = {
name: data.name || undefined,
rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null),
rate_limit: data.rate_limit ?? null, // undefined = unlimited; pass null explicitly
expires_at: data.expires_at || null, // undefined/empty = never expires
auto_delete_on_expiry: data.auto_delete_on_expiry,
// An empty array clears the restriction (allow all); the backend stores it as NULL
allowed_providers: data.allowed_providers,
@@ -1058,8 +1061,8 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
const createData: CreateStandaloneApiKeyRequest = {
name: data.name || undefined,
initial_balance_usd: data.initial_balance_usd,
rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null),
rate_limit: data.rate_limit ?? null, // undefined = unlimited; pass null explicitly
expires_at: data.expires_at || null, // undefined/empty = never expires
auto_delete_on_expiry: data.auto_delete_on_expiry,
// An empty array means no restriction (allow all); the backend stores it as NULL
allowed_providers: data.allowed_providers,

View File

@@ -106,23 +106,23 @@
type="text"
:placeholder="smtpPasswordIsSet ? '已设置(留空保持不变)' : '请输入密码'"
class="-webkit-text-security-disc"
:class="smtpPasswordIsSet ? 'pr-8' : ''"
:class="(smtpPasswordIsSet || emailConfig.smtp_password) ? 'pr-10' : ''"
autocomplete="one-time-code"
data-lpignore="true"
data-1p-ignore="true"
data-form-type="other"
/>
<button
v-if="smtpPasswordIsSet"
v-if="smtpPasswordIsSet || emailConfig.smtp_password"
type="button"
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground transition-colors"
title="清除已保存的密码"
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
title="清除密码"
@click="handleClearSmtpPassword"
>
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
width="14"
height="14"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
@@ -498,6 +498,7 @@ const smtpEncryptionSelectOpen = ref(false)
const emailSuffixModeSelectOpen = ref(false)
const testSmtpLoading = ref(false)
const smtpPasswordIsSet = ref(false)
const clearSmtpPassword = ref(false) // Flag: clear the stored password on save
// Email template state
const templateLoading = ref(false)
@@ -710,6 +711,7 @@ async function loadEmailConfig() {
// Use defaults when the config does not exist; nothing to handle
}
}
clearSmtpPassword.value = false
} catch (err) {
error('加载邮件配置失败')
log.error('加载邮件配置失败:', err)
@@ -720,6 +722,12 @@ async function loadEmailConfig() {
async function saveSmtpConfig() {
smtpSaveLoading.value = true
try {
const passwordAction: 'unchanged' | 'updated' | 'cleared' = emailConfig.value.smtp_password
? 'updated'
: clearSmtpPassword.value
? 'cleared'
: 'unchanged'
const configItems = [
{
key: 'smtp_host',
@@ -737,7 +745,7 @@ async function saveSmtpConfig() {
description: 'SMTP 用户名'
},
// Only submit the password when a new one was entered (empty means keep the current one)
...(emailConfig.value.smtp_password
...(passwordAction === 'updated'
? [{
key: 'smtp_password',
value: emailConfig.value.smtp_password,
@@ -770,8 +778,23 @@ async function saveSmtpConfig() {
adminApi.updateSystemConfig(item.key, item.value, item.description)
)
// If clearing was flagged, delete the password config
if (passwordAction === 'cleared') {
promises.push(adminApi.deleteSystemConfig('smtp_password'))
}
await Promise.all(promises)
success('SMTP 配置已保存')
// Update state
if (passwordAction === 'cleared') {
clearSmtpPassword.value = false
smtpPasswordIsSet.value = false
} else if (passwordAction === 'updated') {
clearSmtpPassword.value = false
smtpPasswordIsSet.value = true
}
emailConfig.value.smtp_password = null
} catch (err) {
error('保存配置失败')
log.error('保存 SMTP 配置失败:', err)
@@ -812,15 +835,16 @@ async function saveEmailSuffixConfig() {
}
// Clear the SMTP password
async function handleClearSmtpPassword() {
try {
await adminApi.deleteSystemConfig('smtp_password')
smtpPasswordIsSet.value = false
function handleClearSmtpPassword() {
// If there is typed input, clear the input field first
if (emailConfig.value.smtp_password) {
emailConfig.value.smtp_password = null
success('SMTP 密码已清除')
} catch (err) {
error('清除密码失败')
log.error('清除 SMTP 密码失败:', err)
return
}
// Flag the server-side password for clearing (applied on save)
if (smtpPasswordIsSet.value) {
clearSmtpPassword.value = true
smtpPasswordIsSet.value = false
}
}

View File

@@ -0,0 +1,379 @@
<template>
<PageContainer>
<PageHeader
title="LDAP 配置"
description="配置 LDAP 认证服务"
/>
<div class="mt-6 space-y-6">
<CardSection
title="LDAP 服务器配置"
description="配置 LDAP 服务器连接参数"
>
<template #actions>
<div class="flex gap-2">
<Button
size="sm"
variant="outline"
:disabled="testLoading"
@click="handleTestConnection"
>
{{ testLoading ? '测试中...' : '测试连接' }}
</Button>
<Button
size="sm"
:disabled="saveLoading"
@click="handleSave"
>
{{ saveLoading ? '保存中...' : '保存' }}
</Button>
</div>
</template>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div>
<Label for="server-url" class="block text-sm font-medium">
服务器地址
</Label>
<Input
id="server-url"
v-model="ldapConfig.server_url"
type="text"
placeholder="ldap://ldap.example.com:389"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
格式: ldap://host:389 或 ldaps://host:636
</p>
</div>
<div>
<Label for="bind-dn" class="block text-sm font-medium">
绑定 DN
</Label>
<Input
id="bind-dn"
v-model="ldapConfig.bind_dn"
type="text"
placeholder="cn=admin,dc=example,dc=com"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
用于连接 LDAP 服务器的管理员 DN
</p>
</div>
<div>
<Label for="bind-password" class="block text-sm font-medium">
绑定密码
</Label>
<div class="relative mt-1">
<Input
id="bind-password"
v-model="ldapConfig.bind_password"
type="password"
:placeholder="hasPassword ? '已设置(留空保持不变)' : '请输入密码'"
:class="(hasPassword || ldapConfig.bind_password) ? 'pr-10' : ''"
autocomplete="new-password"
/>
<button
v-if="hasPassword || ldapConfig.bind_password"
type="button"
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
@click="handleClearPassword"
title="清除密码"
>
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
</button>
</div>
<p class="mt-1 text-xs text-muted-foreground">
绑定账号的密码
</p>
</div>
<div>
<Label for="base-dn" class="block text-sm font-medium">
基础 DN
</Label>
<Input
id="base-dn"
v-model="ldapConfig.base_dn"
type="text"
placeholder="ou=users,dc=example,dc=com"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
用户搜索的基础 DN
</p>
</div>
<div>
<Label for="user-search-filter" class="block text-sm font-medium">
用户搜索过滤器
</Label>
<Input
id="user-search-filter"
v-model="ldapConfig.user_search_filter"
type="text"
placeholder="(uid={username})"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
{username} 会被替换为登录用户名
</p>
</div>
<div>
<Label for="username-attr" class="block text-sm font-medium">
用户名属性
</Label>
<Input
id="username-attr"
v-model="ldapConfig.username_attr"
type="text"
placeholder="uid"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
常用: uid (OpenLDAP), sAMAccountName (AD)
</p>
</div>
<div>
<Label for="email-attr" class="block text-sm font-medium">
邮箱属性
</Label>
<Input
id="email-attr"
v-model="ldapConfig.email_attr"
type="text"
placeholder="mail"
class="mt-1"
/>
</div>
<div>
<Label for="display-name-attr" class="block text-sm font-medium">
显示名称属性
</Label>
<Input
id="display-name-attr"
v-model="ldapConfig.display_name_attr"
type="text"
placeholder="cn"
class="mt-1"
/>
</div>
<div>
<Label for="connect-timeout" class="block text-sm font-medium">
连接超时(秒)
</Label>
<Input
id="connect-timeout"
v-model.number="ldapConfig.connect_timeout"
type="number"
min="1"
max="60"
placeholder="10"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
单次 LDAP 操作超时时间(1-60 秒),跨国网络建议 15-30 秒
</p>
</div>
</div>
<div class="mt-6 space-y-4">
<div class="flex items-center justify-between">
<div>
<Label class="text-sm font-medium">使用 STARTTLS</Label>
<p class="text-xs text-muted-foreground">
在非 SSL 连接上启用 TLS 加密
</p>
</div>
<Switch v-model="ldapConfig.use_starttls" />
</div>
<div class="flex items-center justify-between">
<div>
<Label class="text-sm font-medium">启用 LDAP 认证</Label>
<p class="text-xs text-muted-foreground">
允许用户使用 LDAP 账号登录
</p>
</div>
<Switch v-model="ldapConfig.is_enabled" />
</div>
<div class="flex items-center justify-between">
<div>
<Label class="text-sm font-medium">仅允许 LDAP 登录</Label>
<p class="text-xs text-muted-foreground">
禁用本地账号登录,仅允许 LDAP 认证
</p>
</div>
<Switch v-model="ldapConfig.is_exclusive" />
</div>
</div>
</CardSection>
</div>
</PageContainer>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { PageContainer, PageHeader, CardSection } from '@/components/layout'
import { Button, Input, Label, Switch } from '@/components/ui'
import { useToast } from '@/composables/useToast'
import { adminApi, type LdapConfigUpdateRequest } from '@/api/admin'
const { success, error } = useToast()
const loading = ref(false)
const saveLoading = ref(false)
const testLoading = ref(false)
const hasPassword = ref(false)
const clearPassword = ref(false) // Flag: clear the stored password on save
const ldapConfig = ref({
server_url: '',
bind_dn: '',
bind_password: '',
base_dn: '',
user_search_filter: '(uid={username})',
username_attr: 'uid',
email_attr: 'mail',
display_name_attr: 'cn',
is_enabled: false,
is_exclusive: false,
use_starttls: false,
connect_timeout: 10,
})
onMounted(async () => {
await loadConfig()
})
async function loadConfig() {
loading.value = true
try {
const response = await adminApi.getLdapConfig()
ldapConfig.value = {
server_url: response.server_url || '',
bind_dn: response.bind_dn || '',
bind_password: '',
base_dn: response.base_dn || '',
user_search_filter: response.user_search_filter || '(uid={username})',
username_attr: response.username_attr || 'uid',
email_attr: response.email_attr || 'mail',
display_name_attr: response.display_name_attr || 'cn',
is_enabled: response.is_enabled || false,
is_exclusive: response.is_exclusive || false,
use_starttls: response.use_starttls || false,
connect_timeout: response.connect_timeout || 10,
}
hasPassword.value = !!response.has_bind_password
clearPassword.value = false
} catch (err) {
error('加载 LDAP 配置失败')
console.error('加载 LDAP 配置失败:', err)
} finally {
loading.value = false
}
}
async function handleSave() {
saveLoading.value = true
try {
const payload: LdapConfigUpdateRequest = {
server_url: ldapConfig.value.server_url,
bind_dn: ldapConfig.value.bind_dn,
base_dn: ldapConfig.value.base_dn,
user_search_filter: ldapConfig.value.user_search_filter,
username_attr: ldapConfig.value.username_attr,
email_attr: ldapConfig.value.email_attr,
display_name_attr: ldapConfig.value.display_name_attr,
is_enabled: ldapConfig.value.is_enabled,
is_exclusive: ldapConfig.value.is_exclusive,
use_starttls: ldapConfig.value.use_starttls,
connect_timeout: ldapConfig.value.connect_timeout,
}
// Prefer a newly typed password; otherwise, if clearing is flagged, send an empty string
let passwordAction: 'unchanged' | 'updated' | 'cleared' = 'unchanged'
if (ldapConfig.value.bind_password) {
payload.bind_password = ldapConfig.value.bind_password
passwordAction = 'updated'
} else if (clearPassword.value) {
payload.bind_password = ''
passwordAction = 'cleared'
}
await adminApi.updateLdapConfig(payload)
success('LDAP 配置保存成功')
if (passwordAction === 'cleared') {
hasPassword.value = false
clearPassword.value = false
} else if (passwordAction === 'updated') {
hasPassword.value = true
clearPassword.value = false
}
ldapConfig.value.bind_password = ''
} catch (err) {
error('保存 LDAP 配置失败')
console.error('保存 LDAP 配置失败:', err)
} finally {
saveLoading.value = false
}
}
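Both the SMTP page and this LDAP page resolve the same tri-state (`unchanged` / `updated` / `cleared`) from the current input value and the clear flag. As a shared pure function it would look like this (a sketch; the codebase keeps the logic inline in each save handler):

```typescript
type PasswordAction = 'unchanged' | 'updated' | 'cleared'

// A typed new password always wins; otherwise an explicit clear flag
// clears the stored secret; otherwise the server keeps what it has.
function resolvePasswordAction(
  input: string | null | undefined,
  clearFlag: boolean
): PasswordAction {
  if (input) return 'updated'
  if (clearFlag) return 'cleared'
  return 'unchanged'
}
```

Putting the precedence in one place keeps the two pages from drifting: typing a new password after clicking "clear" must mean "update", which both inline versions implement by checking the input first.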
async function handleTestConnection() {
if (clearPassword.value && !ldapConfig.value.bind_password) {
error('已标记清除绑定密码,请先保存或输入新的绑定密码再测试')
return
}
testLoading.value = true
try {
const payload: LdapConfigUpdateRequest = {
server_url: ldapConfig.value.server_url,
bind_dn: ldapConfig.value.bind_dn,
base_dn: ldapConfig.value.base_dn,
user_search_filter: ldapConfig.value.user_search_filter,
username_attr: ldapConfig.value.username_attr,
email_attr: ldapConfig.value.email_attr,
display_name_attr: ldapConfig.value.display_name_attr,
is_enabled: ldapConfig.value.is_enabled,
is_exclusive: ldapConfig.value.is_exclusive,
use_starttls: ldapConfig.value.use_starttls,
connect_timeout: ldapConfig.value.connect_timeout,
...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }),
}
const response = await adminApi.testLdapConnection(payload)
if (response.success) {
success('LDAP 连接测试成功')
} else {
error(`LDAP 连接测试失败: ${response.message}`)
}
} catch (err) {
error('LDAP 连接测试失败')
console.error('LDAP 连接测试失败:', err)
} finally {
testLoading.value = false
}
}
function handleClearPassword() {
// If there is typed input, clear the input field first
if (ldapConfig.value.bind_password) {
ldapConfig.value.bind_password = ''
return
}
// Flag the server-side password for clearing (applied on save)
if (hasPassword.value) {
clearPassword.value = true
hasPassword.value = false
}
}
</script>
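The `{username}` placeholder in the user search filter is substituted server-side; if that substitution were done naively, input like `*)(uid=*` could rewrite the filter (LDAP injection). RFC 4515 defines the required escaping. A client-agnostic sketch of the rule (the real escaping lives in the backend, which this diff does not show):

```typescript
// Escape the characters RFC 4515 reserves in LDAP filter values.
function escapeLdapFilterValue(value: string): string {
  const map: Record<string, string> = {
    '\\': '\\5c',
    '*': '\\2a',
    '(': '\\28',
    ')': '\\29',
    '\0': '\\00',
  }
  return value.replace(/[\\*()\0]/g, ch => map[ch])
}

// Build the search filter from a template like "(uid={username})".
// Assumes a single {username} placeholder, as in the config above.
function buildUserFilter(template: string, username: string): string {
  return template.replace('{username}', escapeLdapFilterValue(username))
}
```

With escaping in place, `buildUserFilter('(uid={username})', '*)(uid=*')` yields a filter that matches the literal string rather than widening the search.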

View File

@@ -464,6 +464,30 @@
</div>
</div>
</CardSection>
<!-- System version info -->
<CardSection
title="系统信息"
description="当前系统版本和构建信息"
>
<div class="flex items-center gap-4">
<div class="flex items-center gap-2">
<Label class="text-sm font-medium text-muted-foreground">版本:</Label>
<span
v-if="systemVersion"
class="text-sm font-mono"
>
{{ systemVersion }}
</span>
<span
v-else
class="text-sm text-muted-foreground"
>
加载中...
</span>
</div>
</div>
</CardSection>
</div>
<!-- Import config dialog -->
@@ -475,7 +499,7 @@
<div class="space-y-4">
<div
v-if="importPreview"
class="p-3 bg-muted rounded-lg text-sm"
class="text-sm"
>
<p class="font-medium mb-2">
配置预览
@@ -557,7 +581,7 @@
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
全局模型
</p>
@@ -567,7 +591,7 @@
跳过: {{ importResult.stats.global_models.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
提供商
</p>
@@ -577,7 +601,7 @@
跳过: {{ importResult.stats.providers.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
端点
</p>
@@ -587,7 +611,7 @@
跳过: {{ importResult.stats.endpoints.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
API Keys
</p>
@@ -596,7 +620,7 @@
跳过: {{ importResult.stats.keys.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg col-span-2">
<div class="col-span-2">
<p class="font-medium">
模型配置
</p>
@@ -642,7 +666,7 @@
<div class="space-y-4">
<div
v-if="importUsersPreview"
class="p-3 bg-muted rounded-lg text-sm"
class="text-sm"
>
<p class="font-medium mb-2">
数据预览
@@ -652,6 +676,9 @@
<li>
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }}
</li>
<li v-if="importUsersPreview.standalone_keys?.length">
独立余额 Keys: {{ importUsersPreview.standalone_keys.length }}
</li>
</ul>
</div>
@@ -720,7 +747,7 @@
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
用户
</p>
@@ -730,7 +757,7 @@
跳过: {{ importUsersResult.stats.users.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
API Keys
</p>
@@ -739,6 +766,18 @@
跳过: {{ importUsersResult.stats.api_keys.skipped }}
</p>
</div>
<div
v-if="importUsersResult.stats.standalone_keys"
class="col-span-2"
>
<p class="font-medium">
独立余额 Keys
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.standalone_keys.created }},
跳过: {{ importUsersResult.stats.standalone_keys.skipped }}
</p>
</div>
</div>
<div
@@ -839,6 +878,9 @@ const importUsersResult = ref<UsersImportResponse | null>(null)
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const usersMergeModeSelectOpen = ref(false)
// 系统版本信息
const systemVersion = ref<string>('')
const systemConfig = ref<SystemConfig>({
// 基础配置
default_user_quota_usd: 10.0,
@@ -890,9 +932,21 @@ const sensitiveHeadersStr = computed({
})
onMounted(async () => {
await loadSystemConfig()
await Promise.all([
loadSystemConfig(),
loadSystemVersion()
])
})
async function loadSystemVersion() {
try {
const data = await adminApi.getSystemVersion()
systemVersion.value = data.version
} catch (err) {
log.error('加载系统版本失败:', err)
}
}
async function loadSystemConfig() {
try {
const configs = [
@@ -1178,12 +1232,6 @@ function handleUsersFileSelect(event: Event) {
const content = e.target?.result as string
const data = JSON.parse(content) as UsersExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importUsersPreview.value = data
usersMergeMode.value = 'skip'
importUsersDialogOpen.value = true


@@ -5,6 +5,8 @@
<ActivityHeatmapCard
:data="activityHeatmapData"
:title="isAdminPage ? '总体活跃天数' : '我的活跃天数'"
:is-loading="isLoadingHeatmap"
:has-error="heatmapError"
/>
<IntervalTimelineCard
:title="isAdminPage ? '请求间隔时间线' : '我的请求间隔'"
@@ -54,6 +56,7 @@
:show-actual-cost="authStore.isAdmin"
:loading="isLoadingRecords"
:selected-period="selectedPeriod"
:filter-search="filterSearch"
:filter-user="filterUser"
:filter-model="filterModel"
:filter-provider="filterProvider"
@@ -67,6 +70,7 @@
:page-size-options="pageSizeOptions"
:auto-refresh="globalAutoRefresh"
@update:selected-period="handlePeriodChange"
@update:filter-search="handleFilterSearchChange"
@update:filter-user="handleFilterUserChange"
@update:filter-model="handleFilterModelChange"
@update:filter-provider="handleFilterProviderChange"
@@ -112,8 +116,11 @@ import {
import type { PeriodValue, FilterStatusValue } from '@/features/usage/types'
import type { UserOption } from '@/features/usage/components/UsageRecordsTable.vue'
import { log } from '@/utils/logger'
import type { ActivityHeatmap } from '@/types/activity'
import { useToast } from '@/composables/useToast'
const route = useRoute()
const { warning } = useToast()
const authStore = useAuthStore()
// 判断是否是管理员页面
@@ -128,6 +135,7 @@ const pageSize = ref(20)
const pageSizeOptions = [10, 20, 50, 100]
// 筛选状态
const filterSearch = ref('')
const filterUser = ref('__all__')
const filterModel = ref('__all__')
const filterProvider = ref('__all__')
@@ -144,13 +152,35 @@ const {
currentRecords,
totalRecords,
enhancedModelStats,
activityHeatmapData,
availableModels,
availableProviders,
loadStats,
loadRecords
} = useUsageData({ isAdminPage })
// 热力图状态
const activityHeatmapData = ref<ActivityHeatmap | null>(null)
const isLoadingHeatmap = ref(false)
const heatmapError = ref(false)
// 加载热力图数据
async function loadHeatmapData() {
isLoadingHeatmap.value = true
heatmapError.value = false
try {
if (isAdminPage.value) {
activityHeatmapData.value = await usageApi.getActivityHeatmap()
} else {
activityHeatmapData.value = await meApi.getActivityHeatmap()
}
} catch (error) {
log.error('加载热力图数据失败:', error)
heatmapError.value = true
} finally {
isLoadingHeatmap.value = false
}
}
// 用户页面需要前端筛选
const filteredRecords = computed(() => {
if (!isAdminPage.value) {
@@ -232,27 +262,40 @@ async function pollActiveRequests() {
? await usageApi.getActiveRequests(activeRequestIds.value)
: await meApi.getActiveRequests(idsParam)
// 检查是否有状态变化
let hasChanges = false
let shouldRefresh = false
for (const update of requests) {
const record = currentRecords.value.find(r => r.id === update.id)
if (record && record.status !== update.status) {
hasChanges = true
// 如果状态变为 completed 或 failed,需要刷新获取完整数据
if (update.status === 'completed' || update.status === 'failed') {
break
if (!record) {
// 后端返回了未知的活跃请求,触发刷新以获取完整数据
shouldRefresh = true
continue
}
// 否则只更新状态和 token 信息
// 状态变化:completed/failed 需要刷新获取完整数据
if (record.status !== update.status) {
record.status = update.status
}
if (update.status === 'completed' || update.status === 'failed') {
shouldRefresh = true
}
// 进行中状态也需要持续更新provider/key/TTFB 可能在 streaming 后才落库)
record.input_tokens = update.input_tokens
record.output_tokens = update.output_tokens
record.cost = update.cost
record.response_time_ms = update.response_time_ms ?? undefined
record.first_byte_time_ms = update.first_byte_time_ms ?? undefined
// 管理员接口返回额外字段
if ('provider' in update && typeof update.provider === 'string') {
record.provider = update.provider
}
if ('api_key_name' in update) {
record.api_key_name = typeof update.api_key_name === 'string' ? update.api_key_name : undefined
}
}
// 如果有请求完成或失败,刷新整个列表获取完整数据
if (hasChanges && requests.some(r => r.status === 'completed' || r.status === 'failed')) {
if (shouldRefresh) {
await refreshData()
}
} catch (error) {
@@ -335,16 +378,34 @@ const selectedRequestId = ref<string | null>(null)
// 初始化加载
onMounted(async () => {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
// 管理员页面加载用户列表和第一页记录
// 并行加载统计数据和热力图(使用 allSettled 避免其中一个失败影响另一个)
const [statsResult, heatmapResult] = await Promise.allSettled([
loadStats(dateRange),
loadHeatmapData()
])
// 检查加载结果并通知用户
if (statsResult.status === 'rejected') {
log.error('加载统计数据失败:', statsResult.reason)
warning('统计数据加载失败,请刷新重试')
}
if (heatmapResult.status === 'rejected') {
log.error('加载热力图数据失败:', heatmapResult.reason)
// 热力图加载失败不提示,因为 UI 已显示占位符
}
// 加载记录和用户列表
if (isAdminPage.value) {
// 并行加载用户列表和记录
// 管理员页面:并行加载用户列表和记录
const [users] = await Promise.all([
usersApi.getAllUsers(),
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
])
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
} else {
// 用户页面:加载记录
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
}
})
@@ -355,34 +416,26 @@ async function handlePeriodChange(value: string) {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
if (isAdminPage.value) {
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
}
}
// 处理分页变化
async function handlePageChange(page: number) {
currentPage.value = page
if (isAdminPage.value) {
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
}
}
// 处理每页大小变化
async function handlePageSizeChange(size: number) {
pageSize.value = size
currentPage.value = 1 // 重置到第一页
if (isAdminPage.value) {
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
}
}
// 获取当前筛选参数
function getCurrentFilters() {
return {
search: filterSearch.value.trim() || undefined,
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
@@ -391,6 +444,13 @@ function getCurrentFilters() {
}
// 处理筛选变化
async function handleFilterSearchChange(value: string) {
filterSearch.value = value
currentPage.value = 1
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
}
async function handleFilterUserChange(value: string) {
filterUser.value = value
currentPage.value = 1 // 重置到第一页
@@ -431,11 +491,8 @@ async function handleFilterStatusChange(value: string) {
async function refreshData() {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
if (isAdminPage.value) {
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
}
}
// 显示请求详情
function showRequestDetail(id: string) {


@@ -47,6 +47,7 @@ dependencies = [
"redis>=5.0.0",
"prometheus-client>=0.20.0",
"apscheduler>=3.10.0",
"ldap3>=2.9.1",
]
[project.optional-dependencies]


@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.1.1.dev0+g393d4d13f.d20251213'
__version_tuple__ = version_tuple = (0, 1, 1, 'dev0', 'g393d4d13f.d20251213')
__version__ = version = '0.2.3.dev0+g0f78d5cbf.d20260105'
__version_tuple__ = version_tuple = (0, 2, 3, 'dev0', 'g0f78d5cbf.d20260105')
__commit_id__ = commit_id = None


@@ -5,6 +5,7 @@ from fastapi import APIRouter
from .adaptive import router as adaptive_router
from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .ldap import router as ldap_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
@@ -28,5 +29,6 @@ router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)
router.include_router(ldap_router)
__all__ = ["router"]


@@ -3,22 +3,64 @@
独立余额Key不关联用户配额,有独立余额限制,用于给非注册用户使用。
"""
from datetime import datetime, timezone
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import CreateApiKeyRequest
from src.models.database import ApiKey, User
from src.models.database import ApiKey
from src.services.user.apikey import ApiKeyService
# 应用时区配置,默认为 Asia/Shanghai
APP_TIMEZONE = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
def parse_expiry_date(date_str: Optional[str]) -> Optional[datetime]:
"""解析过期日期字符串为 datetime 对象。
Args:
date_str: 日期字符串,支持 "YYYY-MM-DD" 或 ISO 格式
Returns:
datetime 对象(当天 23:59:59.999999,应用时区),或 None 如果输入为空
Raises:
InvalidRequestException: 日期格式无效
"""
if not date_str or not date_str.strip():
return None
date_str = date_str.strip()
# 尝试 YYYY-MM-DD 格式
try:
parsed_date = datetime.strptime(date_str, "%Y-%m-%d")
# 设置为当天结束时间 (23:59:59.999999,应用时区)
return parsed_date.replace(
hour=23, minute=59, second=59, microsecond=999999, tzinfo=APP_TIMEZONE
)
except ValueError:
pass
# 尝试完整 ISO 格式
try:
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except ValueError:
pass
raise InvalidRequestException(f"无效的日期格式: {date_str},请使用 YYYY-MM-DD 格式")
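The `parse_expiry_date` helper above accepts either a bare `YYYY-MM-DD` date (expanded to end-of-day in the app timezone) or a full ISO timestamp. A runnable sketch of the same behavior, with a fixed UTC+8 offset standing in for the configured `Asia/Shanghai` zone:

```python
from datetime import datetime, timedelta, timezone

# Stand-in for the configured APP_TIMEZONE (Asia/Shanghai is UTC+8); the real
# code uses zoneinfo.ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai")).
APP_TIMEZONE = timezone(timedelta(hours=8))

def parse_expiry_date(date_str):
    """Simplified re-statement of the endpoint helper above."""
    if not date_str or not date_str.strip():
        return None
    date_str = date_str.strip()
    try:
        parsed = datetime.strptime(date_str, "%Y-%m-%d")
        # Bare dates become end-of-day in the app timezone
        return parsed.replace(hour=23, minute=59, second=59,
                              microsecond=999999, tzinfo=APP_TIMEZONE)
    except ValueError:
        pass
    try:
        # Full ISO timestamps pass through; "Z" is mapped to "+00:00"
        return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
    except ValueError:
        raise ValueError(f"invalid date format: {date_str}")

print(parse_expiry_date("2026-01-31"))
print(parse_expiry_date("2026-01-31T12:00:00Z"))
```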
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
pipeline = ApiRequestPipeline()
@@ -215,6 +257,9 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
# 独立Key需要关联到管理员用户(从context获取)
admin_user_id = context.user.id
# 解析过期时间(优先使用 expires_at,其次使用 expire_days)
expires_at_dt = parse_expiry_date(self.key_data.expires_at)
# 创建独立Key
api_key, plain_key = ApiKeyService.create_api_key(
db=db,
@@ -224,7 +269,8 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit, # None 表示不限制
expire_days=self.key_data.expire_days,
expire_days=self.key_data.expire_days, # 兼容旧版
expires_at=expires_at_dt, # 优先使用
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
@@ -270,7 +316,8 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
update_data = {}
if self.key_data.name is not None:
update_data["name"] = self.key_data.name
if self.key_data.rate_limit is not None:
# rate_limit: 显式传递时更新(包括 null 表示无限制)
if "rate_limit" in self.key_data.model_fields_set:
update_data["rate_limit"] = self.key_data.rate_limit
if (
hasattr(self.key_data, "auto_delete_on_expiry")
@@ -287,18 +334,20 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
update_data["allowed_models"] = self.key_data.allowed_models
# 处理过期时间
if self.key_data.expire_days is not None:
if self.key_data.expire_days > 0:
from datetime import timedelta
# 优先使用 expires_at(如果显式传递且有值)
if self.key_data.expires_at and self.key_data.expires_at.strip():
update_data["expires_at"] = parse_expiry_date(self.key_data.expires_at)
elif "expires_at" in self.key_data.model_fields_set:
# expires_at 明确传递为 null 或空字符串,设为永不过期
update_data["expires_at"] = None
# 兼容旧版 expire_days
elif "expire_days" in self.key_data.model_fields_set:
if self.key_data.expire_days is not None and self.key_data.expire_days > 0:
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
days=self.key_data.expire_days
)
else:
# expire_days = 0 或负数表示永不过期
update_data["expires_at"] = None
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
# 明确传递 None设为永不过期
# expire_days = None/0/负数 表示永不过期
update_data["expires_at"] = None
# 使用 ApiKeyService 更新


@@ -206,6 +206,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
provider_id=self.provider_id,
api_format=self.endpoint_data.api_format,
base_url=self.endpoint_data.base_url,
custom_path=self.endpoint_data.custom_path,
headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries,

src/api/admin/ldap.py (new file)

@@ -0,0 +1,427 @@
"""LDAP配置管理API端点。"""
import re
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.crypto import crypto_service
from src.core.enums import AuthSource
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.database import AuditEventType, LDAPConfig, User, UserRole
from src.services.system.audit import AuditService
router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"])
pipeline = ApiRequestPipeline()
# bcrypt 哈希格式正则:$2a$, $2b$, $2y$ + 2位cost + $ + 53字符(22位salt + 31位hash)
BCRYPT_HASH_PATTERN = re.compile(r"^\$2[aby]\$\d{2}\$.{53}$")
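The `BCRYPT_HASH_PATTERN` above can be exercised directly; the sample below is an example-shaped bcrypt string (22-character salt plus 31-character digest after the cost field), not a hash from this project:

```python
import re

# Same pattern as the module above: $2a$/$2b$/$2y$ + 2-digit cost + $ + 53 chars
BCRYPT_HASH_PATTERN = re.compile(r"^\$2[aby]\$\d{2}\$.{53}$")

# 22-char salt + 31-char digest = 53 chars after "$2b$12$"
sample = "$2b$12$" + "R9h/cIPz0gi.URNNX3kh2O" + "PST9/PgBkqquzi.Ss7KIUgO2t0jWMUW"

print(bool(BCRYPT_HASH_PATTERN.match(sample)))
print(bool(BCRYPT_HASH_PATTERN.match("$1$legacy-md5-crypt")))
```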
# ========== Request/Response Models ==========
class LDAPConfigResponse(BaseModel):
"""LDAP配置响应(不返回密码)"""
server_url: Optional[str] = None
bind_dn: Optional[str] = None
base_dn: Optional[str] = None
has_bind_password: bool = False
user_search_filter: str
username_attr: str
email_attr: str
display_name_attr: str
is_enabled: bool
is_exclusive: bool
use_starttls: bool
connect_timeout: int
class LDAPConfigUpdate(BaseModel):
"""LDAP配置更新请求"""
server_url: str = Field(..., min_length=1, max_length=255)
bind_dn: str = Field(..., min_length=1, max_length=255)
# 允许空字符串表示"清除密码";非空时自动 strip 并校验不能为空
bind_password: Optional[str] = Field(None, max_length=1024)
base_dn: str = Field(..., min_length=1, max_length=255)
user_search_filter: str = Field(default="(uid={username})", max_length=500)
username_attr: str = Field(default="uid", max_length=50)
email_attr: str = Field(default="mail", max_length=50)
display_name_attr: str = Field(default="cn", max_length=50)
is_enabled: bool = False
is_exclusive: bool = False
use_starttls: bool = False
connect_timeout: int = Field(default=10, ge=1, le=60) # 单次操作超时,跨国网络建议 15-30 秒
@field_validator("bind_password")
@classmethod
def validate_bind_password(cls, v: Optional[str]) -> Optional[str]:
if v is None or v == "":
return v
v = v.strip()
if not v:
raise ValueError("绑定密码不能为空")
return v
@field_validator("user_search_filter")
@classmethod
def validate_search_filter(cls, v: str) -> str:
if "{username}" not in v:
raise ValueError("搜索过滤器必须包含 {username} 占位符")
# 验证括号匹配和嵌套正确性
depth = 0
for char in v:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
if depth < 0:
raise ValueError("搜索过滤器括号不匹配")
if depth != 0:
raise ValueError("搜索过滤器括号不匹配")
# 限制过滤器复杂度,防止构造复杂查询
# 检查嵌套层数而非括号总数
depth = 0
max_depth = 0
for char in v:
if char == "(":
depth += 1
max_depth = max(max_depth, depth)
elif char == ")":
depth -= 1
if max_depth > 5:
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
if len(v) > 200:
raise ValueError("搜索过滤器过长(最多200字符)")
return v
class LDAPTestResponse(BaseModel):
"""LDAP连接测试响应"""
success: bool
message: str
class LDAPConfigTest(BaseModel):
"""LDAP配置测试请求(全部可选,用于临时覆盖)"""
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
bind_password: Optional[str] = Field(None, min_length=1)
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
user_search_filter: Optional[str] = Field(None, max_length=500)
username_attr: Optional[str] = Field(None, max_length=50)
email_attr: Optional[str] = Field(None, max_length=50)
display_name_attr: Optional[str] = Field(None, max_length=50)
is_enabled: Optional[bool] = None
is_exclusive: Optional[bool] = None
use_starttls: Optional[bool] = None
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
@field_validator("user_search_filter")
@classmethod
def validate_search_filter(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
if "{username}" not in v:
raise ValueError("搜索过滤器必须包含 {username} 占位符")
# 验证括号匹配和嵌套正确性
depth = 0
for char in v:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
if depth < 0:
raise ValueError("搜索过滤器括号不匹配")
if depth != 0:
raise ValueError("搜索过滤器括号不匹配")
# 限制过滤器复杂度(检查嵌套层数而非括号总数)
depth = 0
max_depth = 0
for char in v:
if char == "(":
depth += 1
max_depth = max(max_depth, depth)
elif char == ")":
depth -= 1
if max_depth > 5:
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
if len(v) > 200:
raise ValueError("搜索过滤器过长(最多200字符)")
return v
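Both `validate_search_filter` validators implement the same balanced-parentheses and nesting-depth check. Condensed into one standalone predicate (the function name is illustrative):

```python
def filter_depth_ok(f: str, max_depth: int = 5) -> bool:
    """True iff parentheses in an LDAP filter are balanced and nested <= max_depth."""
    depth = peak = 0
    for ch in f:
        if ch == "(":
            depth += 1
            peak = max(peak, depth)
        elif ch == ")":
            depth -= 1
            if depth < 0:          # closing before opening
                return False
    return depth == 0 and peak <= max_depth

print(filter_depth_ok("(&(objectClass=person)(uid={username}))"))  # True
print(filter_depth_ok("(uid={username}"))                          # False
```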
# ========== API Endpoints ==========
@router.get("/config")
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
"""获取LDAP配置(管理员)"""
adapter = AdminGetLDAPConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/config")
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
"""更新LDAP配置(管理员)"""
adapter = AdminUpdateLDAPConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/test")
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
"""测试LDAP连接(管理员)"""
adapter = AdminTestLDAPConnectionAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ========== Adapters ==========
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
db = context.db
config = db.query(LDAPConfig).first()
if not config:
return LDAPConfigResponse(
server_url=None,
bind_dn=None,
base_dn=None,
has_bind_password=False,
user_search_filter="(uid={username})",
username_attr="uid",
email_attr="mail",
display_name_attr="cn",
is_enabled=False,
is_exclusive=False,
use_starttls=False,
connect_timeout=10,
).model_dump()
return LDAPConfigResponse(
server_url=config.server_url,
bind_dn=config.bind_dn,
base_dn=config.base_dn,
has_bind_password=bool(config.bind_password_encrypted),
user_search_filter=config.user_search_filter,
username_attr=config.username_attr,
email_attr=config.email_attr,
display_name_attr=config.display_name_attr,
is_enabled=config.is_enabled,
is_exclusive=config.is_exclusive,
use_starttls=config.use_starttls,
connect_timeout=config.connect_timeout,
).model_dump()
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
async def handle(self, context) -> Dict[str, str]: # type: ignore[override]
db = context.db
payload = context.ensure_json_body()
try:
config_update = LDAPConfigUpdate.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
# 使用行级锁防止并发修改导致的竞态条件
config = db.query(LDAPConfig).with_for_update().first()
is_new_config = config is None
if is_new_config:
# 首次创建配置时必须提供密码
if not config_update.bind_password:
raise InvalidRequestException("首次配置 LDAP 时必须设置绑定密码")
config = LDAPConfig()
db.add(config)
# 需要启用 LDAP 且未提交新密码时,验证已保存密码可解密(避免开启后不可用)
if config_update.is_enabled and config_update.bind_password is None:
try:
if not config.get_bind_password():
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
except InvalidRequestException:
raise
except Exception:
raise InvalidRequestException("绑定密码解密失败,请重新设置绑定密码")
# 计算更新后的密码状态(用于校验是否可启用/独占)
if config_update.bind_password is None:
will_have_password = bool(config.bind_password_encrypted)
elif config_update.bind_password == "":
will_have_password = False
else:
will_have_password = True
# 独占模式必须启用 LDAP 且必须有绑定密码(防止误锁定)
if config_update.is_exclusive and not config_update.is_enabled:
raise InvalidRequestException("仅允许 LDAP 登录 需要先启用 LDAP 认证")
if config_update.is_enabled and not will_have_password:
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
if config_update.is_exclusive and not will_have_password:
raise InvalidRequestException("仅允许 LDAP 登录 需要先设置绑定密码")
config.server_url = config_update.server_url
config.bind_dn = config_update.bind_dn
config.base_dn = config_update.base_dn
config.user_search_filter = config_update.user_search_filter
config.username_attr = config_update.username_attr
config.email_attr = config_update.email_attr
config.display_name_attr = config_update.display_name_attr
config.is_enabled = config_update.is_enabled
config.is_exclusive = config_update.is_exclusive
config.use_starttls = config_update.use_starttls
config.connect_timeout = config_update.connect_timeout
# 启用独占模式前检查是否有足够的本地管理员(防止锁定)
# 使用 with_for_update() 阻塞锁防止竞态条件(移除 nowait 确保并发安全)
if config_update.is_enabled and config_update.is_exclusive:
local_admins = (
db.query(User)
.filter(
User.role == UserRole.ADMIN,
User.auth_source == AuthSource.LOCAL,
User.is_active.is_(True),
User.is_deleted.is_(False),
)
.with_for_update()
.all()
)
# 验证至少有一个管理员有有效的密码哈希(可以登录)
# 使用严格的 bcrypt 格式校验:$2a$/$2b$/$2y$ + 2位cost + $ + 53字符
valid_admin_count = sum(
1
for admin in local_admins
if admin.password_hash
and isinstance(admin.password_hash, str)
and BCRYPT_HASH_PATTERN.match(admin.password_hash)
)
if valid_admin_count < 1:
raise InvalidRequestException(
"启用 LDAP 独占模式前,必须至少保留 1 个有效的本地管理员账户(含有效密码)作为紧急恢复通道"
)
if config_update.bind_password is not None:
if config_update.bind_password == "":
# 显式清除密码(设置为 NULL)
config.bind_password_encrypted = None
password_changed = "cleared"
else:
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
password_changed = "updated"
else:
password_changed = None
db.commit()
# 记录审计日志
AuditService.log_event(
db=db,
event_type=AuditEventType.CONFIG_CHANGED,
description=f"LDAP 配置已更新 (enabled={config_update.is_enabled}, exclusive={config_update.is_exclusive})",
user_id=str(context.user.id) if context.user else None,
metadata={
"server_url": config_update.server_url,
"is_enabled": config_update.is_enabled,
"is_exclusive": config_update.is_exclusive,
"password_changed": password_changed,
"is_new_config": is_new_config,
},
)
return {"message": "LDAP配置更新成功"}
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
from src.services.auth.ldap import LDAPService
db = context.db
if context.json_body is not None:
payload = context.json_body
elif not context.raw_body:
payload = {}
else:
payload = context.ensure_json_body()
saved_config = db.query(LDAPConfig).first()
try:
overrides = LDAPConfigTest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
config_data: Dict[str, Any] = {}
if saved_config:
config_data = {
"server_url": saved_config.server_url,
"bind_dn": saved_config.bind_dn,
"base_dn": saved_config.base_dn,
"user_search_filter": saved_config.user_search_filter,
"username_attr": saved_config.username_attr,
"email_attr": saved_config.email_attr,
"display_name_attr": saved_config.display_name_attr,
"use_starttls": saved_config.use_starttls,
"connect_timeout": saved_config.connect_timeout,
}
# 应用前端传入的覆盖值
for field in [
"server_url",
"bind_dn",
"base_dn",
"user_search_filter",
"username_attr",
"email_attr",
"display_name_attr",
"use_starttls",
"is_enabled",
"is_exclusive",
"connect_timeout",
]:
value = getattr(overrides, field)
if value is not None:
config_data[field] = value
# bind_password 优先使用 overrides,否则使用已保存的密码(允许保存密码无法解密时依然用 overrides 测试)
if overrides.bind_password is not None:
config_data["bind_password"] = overrides.bind_password
elif saved_config and saved_config.bind_password_encrypted:
try:
config_data["bind_password"] = crypto_service.decrypt(
saved_config.bind_password_encrypted
)
except Exception as e:
logger.error(f"绑定密码解密失败: {type(e).__name__}: {e}")
return LDAPTestResponse(
success=False, message="绑定密码解密失败,请检查配置或重新设置密码"
).model_dump()
# 必填字段检查
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
missing = [f for f in required_fields if not config_data.get(f)]
if missing:
return LDAPTestResponse(
success=False, message=f"缺少必要字段: {', '.join(missing)}"
).model_dump()
success, message = LDAPService.test_connection_with_config(config_data)
return LDAPTestResponse(success=success, message=message).model_dump()


@@ -146,20 +146,25 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
search=self.search,
)
# 为每个 GlobalModel 添加统计数据
# 一次性查询所有 GlobalModel 的 provider_count(优化 N+1 问题)
model_ids = [gm.id for gm in models]
provider_counts = {}
if model_ids:
count_results = (
context.db.query(
Model.global_model_id, func.count(func.distinct(Model.provider_id))
)
.filter(Model.global_model_id.in_(model_ids))
.group_by(Model.global_model_id)
.all()
)
provider_counts = {gm_id: count for gm_id, count in count_results}
# 构建响应
model_responses = []
for gm in models:
# 统计关联的 Model 数量(去重 Provider)
provider_count = (
context.db.query(func.count(func.distinct(Model.provider_id)))
.filter(Model.global_model_id == gm.id)
.scalar()
or 0
)
response = GlobalModelResponse.model_validate(gm)
response.provider_count = provider_count
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
response.provider_count = provider_counts.get(gm.id, 0)
model_responses.append(response)
return GlobalModelListResponse(


@@ -2,7 +2,7 @@
提供商策略管理 API 端点
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -103,6 +103,9 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
if config.quota_last_reset_at:
new_reset_at = parser.parse(config.quota_last_reset_at)
# 确保有时区信息,如果没有则假设为 UTC
if new_reset_at.tzinfo is None:
new_reset_at = new_reset_at.replace(tzinfo=timezone.utc)
provider.quota_last_reset_at = new_reset_at
# 自动同步该周期内的历史使用量
@@ -118,7 +121,11 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
if config.quota_expires_at:
provider.quota_expires_at = parser.parse(config.quota_expires_at)
expires_at = parser.parse(config.quota_expires_at)
# 确保有时区信息,如果没有则假设为 UTC
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
provider.quota_expires_at = expires_at
db.commit()
db.refresh(provider)
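The hunk above normalizes parsed timestamps by assuming UTC whenever no timezone is attached. The same rule as a small helper (the name is illustrative):

```python
from datetime import datetime, timezone

def ensure_utc(dt: datetime) -> datetime:
    # Naive timestamps are assumed to be UTC, matching the fix in the diff above
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

print(ensure_utc(datetime.fromisoformat("2026-01-05T10:00:00")))        # gains +00:00
print(ensure_utc(datetime.fromisoformat("2026-01-05T10:00:00+08:00")))  # unchanged
```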
@@ -149,7 +156,7 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
since = datetime.now() - timedelta(hours=self.hours)
since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
stats = (
db.query(ProviderUsageTracking)
.filter(


@@ -1,5 +1,7 @@
"""系统设置API端点。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
@@ -17,6 +19,46 @@ from src.services.email.email_template import EmailTemplate
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
def _get_version_from_git() -> str | None:
"""从 git describe 获取版本号"""
import subprocess
try:
result = subprocess.run(
["git", "describe", "--tags", "--always"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
version = result.stdout.strip()
if version.startswith("v"):
version = version[1:]
return version
except Exception:
pass
return None
@router.get("/version")
async def get_system_version():
"""获取系统版本信息"""
# 优先从 git 获取
version = _get_version_from_git()
if version:
return {"version": version}
# 回退到静态版本文件
try:
from src._version import __version__
return {"version": __version__}
except ImportError:
return {"version": "unknown"}
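The version endpoint strips a leading `v` from `git describe --tags --always` output before returning it. That normalization in isolation (the function name is illustrative; the sample string mimics typical `git describe` output):

```python
def normalize_git_describe(raw: str) -> str:
    # e.g. "v0.2.3-5-g0f78d5c" -> "0.2.3-5-g0f78d5c"
    version = raw.strip()
    return version[1:] if version.startswith("v") else version

print(normalize_git_describe("v0.2.3-5-g0f78d5c\n"))
```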
pipeline = ApiRequestPipeline()
@@ -950,23 +992,12 @@ class AdminExportUsersAdapter(AdminApiAdapter):
db = context.db
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# 导出用户的 API Keys保留加密数据
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
def _serialize_api_key(key: ApiKey, include_is_standalone: bool = False) -> dict:
"""序列化 API Key 为导出格式"""
data = {
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
@@ -977,11 +1008,28 @@ class AdminExportUsersAdapter(AdminApiAdapter):
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
if include_is_standalone:
data["is_standalone"] = key.is_standalone
return data
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# 导出用户的 API Keys(排除独立余额Key,独立Key单独导出)
api_keys = db.query(ApiKey).filter(
ApiKey.user_id == user.id,
ApiKey.is_standalone.is_(False)
).all()
api_keys_data = [_serialize_api_key(key, include_is_standalone=True) for key in api_keys]
users_data.append(
{
@@ -1001,10 +1049,15 @@ class AdminExportUsersAdapter(AdminApiAdapter):
}
)
# Export standalone-balance keys (created by admins, not owned by regular users)
standalone_keys = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True)).all()
standalone_keys_data = [_serialize_api_key(key) for key in standalone_keys]
return {
"version": "1.0",
"version": "1.1",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
"standalone_keys": standalone_keys_data,
}
@@ -1024,21 +1077,72 @@ class AdminImportUsersAdapter(AdminApiAdapter):
db = context.db
payload = context.ensure_json_body()
# Validate the export format version (exports now write "1.1", so both versions must be accepted)
version = payload.get("version")
if version not in ("1.0", "1.1"):
raise InvalidRequestException(f"Unsupported export version: {version}")
# Read import options
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
users_data = payload.get("users", [])
standalone_keys_data = payload.get("standalone_keys", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"standalone_keys": {"created": 0, "skipped": 0},
"errors": [],
}
def _create_api_key_from_data(
key_data: dict,
owner_id: str,
is_standalone: bool = False,
) -> tuple[ApiKey | None, str]:
"""从导入数据创建 ApiKey 对象
Returns:
(ApiKey, "created"): 成功创建
(None, "skipped"): key 已存在,跳过
(None, "invalid"): 数据无效,跳过
"""
key_hash = key_data.get("key_hash", "").strip()
if not key_hash:
return None, "invalid"
# Check whether it already exists
existing = db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
if existing:
return None, "skipped"
# Parse expires_at
expires_at = None
if key_data.get("expires_at"):
try:
expires_at = datetime.fromisoformat(key_data["expires_at"])
except ValueError:
stats["errors"].append(
f"API Key '{key_data.get('name', key_hash[:8])}' 的 expires_at 格式无效"
)
return ApiKey(
id=str(uuid.uuid4()),
user_id=owner_id,
key_hash=key_hash,
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=is_standalone or key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit"),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
expires_at=expires_at,
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
), "created"
try:
for user_data in users_data:
# Skip importing admin-role users (case-insensitive)
@@ -1109,40 +1213,31 @@ class AdminImportUsersAdapter(AdminApiAdapter):
# Import API keys
for key_data in user_data.get("api_keys", []):
# Check whether the same key_hash already exists
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
new_key, status = _create_api_key_from_data(key_data, user_id)
if new_key:
db.add(new_key)
stats["api_keys"]["created"] += 1
elif status == "skipped":
stats["api_keys"]["skipped"] += 1
# invalid entries are not counted in the stats
# Import standalone-balance keys (requires an admin user as the owner)
if standalone_keys_data:
# Find an admin user to own the standalone keys
admin_user = db.query(User).filter(User.role == UserRole.ADMIN).first()
if not admin_user:
stats["errors"].append("无法导入独立余额Key: 系统中没有管理员用户")
else:
for key_data in standalone_keys_data:
new_key, status = _create_api_key_from_data(
key_data, admin_user.id, is_standalone=True
)
if new_key:
db.add(new_key)
stats["standalone_keys"]["created"] += 1
elif status == "skipped":
stats["standalone_keys"]["skipped"] += 1
# invalid entries are not counted in the stats
db.commit()


@@ -73,11 +73,26 @@ async def get_usage_stats(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/heatmap")
async def get_activity_heatmap(
request: Request,
db: Session = Depends(get_db),
):
"""
Get activity heatmap data for the past 365 days.
This endpoint is cached for 5 minutes to reduce database load.
"""
adapter = AdminActivityHeatmapAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/records")
async def get_usage_records(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
search: Optional[str] = None,  # generic search: username, key name, model name, provider name
user_id: Optional[str] = None,
username: Optional[str] = None,
model: Optional[str] = None,
@@ -90,6 +105,7 @@ async def get_usage_records(
adapter = AdminUsageRecordsAdapter(
start_date=start_date,
end_date=end_date,
search=search,
user_id=user_id,
username=username,
model=model,
@@ -168,12 +184,6 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
activity_heatmap = UsageService.get_daily_activity(
db=db,
window_days=365,
include_actual_cost=True,
)
context.add_audit_metadata(
action="usage_stats",
start_date=self.start_date.isoformat() if self.start_date else None,
@@ -204,10 +214,22 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
),
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
},
"activity_heatmap": activity_heatmap,
}
class AdminActivityHeatmapAdapter(AdminApiAdapter):
"""Activity heatmap adapter with Redis caching."""
async def handle(self, context): # type: ignore[override]
result = await UsageService.get_cached_heatmap(
db=context.db,
user_id=None,
include_actual_cost=True,
)
context.add_audit_metadata(action="activity_heatmap")
return result
class AdminUsageByModelAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
@@ -480,6 +502,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
self,
start_date: Optional[datetime],
end_date: Optional[datetime],
search: Optional[str],
user_id: Optional[str],
username: Optional[str],
model: Optional[str],
@@ -490,6 +513,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
):
self.start_date = start_date
self.end_date = end_date
self.search = search
self.user_id = user_id
self.username = username
self.model = model
@@ -499,25 +523,54 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
self.offset = offset
async def handle(self, context): # type: ignore[override]
from sqlalchemy import or_
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
db = context.db
query = (
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey, ApiKey)
.outerjoin(User, Usage.user_id == User.id)
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
)
# When searching/filtering by provider name, JOIN Provider once here
if self.search or self.provider:
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
# Generic search: username, key name, model name, provider name
# Supports space-separated compound search; keywords are AND-combined
# Limits: at most 10 keywords, each at most 100 characters after escaping
if self.search:
keywords = [kw for kw in self.search.strip().split() if kw][:10]
for keyword in keywords:
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
search_pattern = f"%{escaped}%"
query = query.filter(
or_(
User.username.ilike(search_pattern, escape="\\"),
ApiKey.name.ilike(search_pattern, escape="\\"),
Usage.model.ilike(search_pattern, escape="\\"),
Provider.name.ilike(search_pattern, escape="\\"),
)
)
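The escaping utilities referenced above live in `src.utils.database_helpers` and are not shown in this diff; a plausible sketch of their behavior, plus the keyword splitting, assuming backslash is the LIKE escape character (as the `escape="\\"` arguments suggest):

```python
def escape_like_pattern(value: str) -> str:
    """Escape LIKE wildcards so user input is matched literally (escape char: backslash)."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

def split_search_keywords(search: str, max_keywords: int = 10) -> list:
    """Whitespace-separated keywords, AND-combined by the caller; capped to bound query cost."""
    return [kw for kw in search.strip().split() if kw][:max_keywords]
```

Escaping the backslash first matters: otherwise the backslashes introduced for `%` and `_` would be double-escaped.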
if self.user_id:
query = query.filter(Usage.user_id == self.user_id)
if self.username:
# Fuzzy match on username
query = query.filter(User.username.ilike(f"%{self.username}%"))
escaped = escape_like_pattern(self.username)
query = query.filter(User.username.ilike(f"%{escaped}%", escape="\\"))
if self.model:
# Fuzzy match on model name
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
escaped = escape_like_pattern(self.model)
query = query.filter(Usage.model.ilike(f"%{escaped}%", escape="\\"))
if self.provider:
# Filter by provider name (via the Provider table)
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
# Filter by provider name
escaped = escape_like_pattern(self.provider)
query = query.filter(Provider.name.ilike(f"%{escaped}%", escape="\\"))
if self.status:
# Status filter
# Legacy filter values (based on is_stream and status_code): stream, standard, error
@@ -555,7 +608,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
)
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
request_ids = [usage.request_id for usage, _, _, _, _ in records if usage.request_id]
fallback_map = {}
if request_ids:
# Only count candidates that actually ran (success or failed; skipped/pending/available excluded)
@@ -575,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
action="usage_records",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
search=self.search,
user_id=self.user_id,
username=self.username,
model=self.model,
@@ -586,7 +640,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
)
# Build a provider_id -> provider name map to avoid N+1 queries
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
provider_ids = [usage.provider_id for usage, _, _, _, _ in records if usage.provider_id]
provider_map = {}
if provider_ids:
providers_data = (
@@ -595,7 +649,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
provider_map = {str(p.id): p.name for p in providers_data}
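The batched lookup above resolves all provider names in one query rather than one per row. The pattern, sketched over a generic fetch callback (names here are illustrative, not the project's API):

```python
def build_provider_map(provider_ids, fetch_by_ids):
    """Resolve provider names with a single batched lookup to avoid N+1 queries."""
    unique_ids = [pid for pid in dict.fromkeys(provider_ids) if pid]
    if not unique_ids:
        return {}
    return {str(p.id): p.name for p in fetch_by_ids(unique_ids)}
```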
data = []
for usage, user, endpoint, api_key in records:
for usage, user, endpoint, provider_api_key, user_api_key in records:
actual_cost = (
float(usage.actual_total_cost_usd)
if usage.actual_total_cost_usd is not None
@@ -616,6 +670,15 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"user_id": user.id if user else None,
"user_email": user.email if user else "已删除用户",
"username": user.username if user else "已删除用户",
"api_key": (
{
"id": user_api_key.id,
"name": user_api_key.name,
"display": user_api_key.get_display_key(),
}
if user_api_key
else None
),
"provider": provider_name,
"model": usage.model,
"target_model": usage.target_model, # 映射后的目标模型名
@@ -641,7 +704,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"has_fallback": fallback_map.get(usage.request_id, False),
"api_format": usage.api_format
or (endpoint.api_format if endpoint and endpoint.api_format else None),
"api_key_name": api_key.name if api_key else None,
"api_key_name": provider_api_key.name if provider_api_key else None,
"request_metadata": usage.request_metadata, # Provider 响应元数据
}
)
@@ -670,7 +733,9 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
if not id_list:
return {"requests": []}
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
requests = UsageService.get_active_requests_status(
db=db, ids=id_list, include_admin_fields=True
)
return {"requests": requests}


@@ -248,6 +248,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
raise InvalidRequestException("请求数据验证失败")
update_data = request.model_dump(exclude_unset=True)
old_role = existing_user.role
if "role" in update_data and update_data["role"]:
if hasattr(update_data["role"], "value"):
update_data["role"] = update_data["role"]
@@ -258,6 +259,12 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
if not user:
raise NotFoundException("用户不存在", "user")
# 角色变更时清除热力图缓存(影响 include_actual_cost 权限)
if "role" in update_data and update_data["role"] != old_role:
from src.services.usage.service import UsageService
await UsageService.clear_user_heatmap_cache(self.user_id)
changed_fields = list(update_data.keys())
context.add_audit_metadata(
action="update_user",
@@ -424,7 +431,7 @@ class AdminCreateUserKeyAdapter(AdminApiAdapter):
name=key_data.name,
allowed_providers=key_data.allowed_providers,
allowed_models=key_data.allowed_models,
rate_limit=key_data.rate_limit or 100,
rate_limit=key_data.rate_limit,  # None = unlimited
expire_days=key_data.expire_days,
initial_balance_usd=None,  # regular keys carry no balance limit
is_standalone=False,  # not a standalone key


@@ -33,6 +33,7 @@ from src.models.api import (
)
from src.models.database import AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.auth.ldap import LDAPService
from src.services.rate_limit.ip_limiter import IPRateLimiter
from src.services.system.audit import AuditService
from src.services.system.config import SystemConfigService
@@ -99,6 +100,13 @@ async def registration_settings(request: Request, db: Session = Depends(get_db))
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/settings")
async def auth_settings(request: Request, db: Session = Depends(get_db)):
"""公开获取认证设置(用于前端判断显示哪些登录选项)"""
adapter = AuthSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/login", response_model=LoginResponse)
async def login(request: Request, db: Session = Depends(get_db)):
adapter = AuthLoginAdapter()
@@ -193,7 +201,9 @@ class AuthLoginAdapter(AuthPublicAdapter):
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
)
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
user = await AuthService.authenticate_user(
db, login_request.email, login_request.password, login_request.auth_type
)
if not user:
AuditService.log_login_attempt(
db=db,
@@ -305,6 +315,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter):
).model_dump()
class AuthSettingsAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
"""公开返回认证设置"""
db = context.db
ldap_enabled = LDAPService.is_ldap_enabled(db)
ldap_exclusive = LDAPService.is_ldap_exclusive(db)
return {
"local_enabled": not ldap_exclusive,
"ldap_enabled": ldap_enabled,
"ldap_exclusive": ldap_exclusive,
}
class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
from src.models.database import SystemConfig
@@ -324,6 +349,12 @@ class AuthRegisterAdapter(AuthPublicAdapter):
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
)
# Reject local registration when only LDAP login is allowed
if LDAPService.is_ldap_exclusive(db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="LDAP-exclusive login is enabled; local registration is disabled"
)
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
if allow_registration and not allow_registration.value:
AuditService.log_event(


@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import ApiKey, User
from src.utils.request_utils import get_client_ip
@@ -86,7 +87,7 @@ class ApiRequestContext:
setattr(request.state, "request_id", request_id)
start_time = time.time()
client_ip = request.client.host if request.client else "unknown"
client_ip = get_client_ip(request)
user_agent = request.headers.get("user-agent", "unknown")
context = cls(
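Per the top commit, `get_client_ip` resolves the client address with the priority X-Real-IP > X-Forwarded-For > direct connection IP. A sketch of that order over plain dicts rather than a Starlette Request (the real implementation lives in `src.utils.request_utils`; this is an assumed simplification):

```python
def resolve_client_ip(headers: dict, direct_ip=None) -> str:
    """Client IP priority: X-Real-IP > first X-Forwarded-For hop > direct connection IP."""
    real_ip = headers.get("x-real-ip", "").strip()
    if real_ip:
        return real_ip
    forwarded = headers.get("x-forwarded-for", "")
    if forwarded:
        # X-Forwarded-For lists hops left to right; the first entry is the original client
        return forwarded.split(",")[0].strip()
    return direct_ip or "unknown"
```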


@@ -7,6 +7,7 @@ from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session
from src.config.settings import config
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
@@ -64,13 +65,17 @@ class ApiRequestPipeline:
try:
import asyncio
# Add a 30-second timeout to avoid hanging
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
# Add a timeout to avoid hanging
raw_body = await asyncio.wait_for(
http_request.body(), timeout=config.request_body_timeout
)
logger.debug(f"[Pipeline] Raw body read complete | size={len(raw_body) if raw_body is not None else 0} bytes")
except asyncio.TimeoutError:
logger.error("Timed out reading request body (30s); the client may not have sent the full body")
timeout_sec = int(config.request_body_timeout)
logger.error(f"Timed out reading request body ({timeout_sec}s); the client may not have sent the full body")
raise HTTPException(
status_code=408, detail="Request timeout: body not received within 30 seconds"
status_code=408,
detail=f"Request timeout: body not received within {timeout_sec} seconds",
)
else:
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")


@@ -47,7 +47,6 @@ if TYPE_CHECKING:
from src.api.handlers.base.stream_context import StreamContext
class MessageTelemetry:
"""
Records Usage/Audit, avoiding duplicated code across handlers.
@@ -406,7 +405,7 @@ class BaseMessageHandler:
asyncio.create_task(_do_update())
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
"""更新 Usage 状态为 streaming同时更新 provider 和 target_model
"""更新 Usage 状态为 streaming同时更新 provider 相关信息
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
@@ -414,7 +413,7 @@ class BaseMessageHandler:
and is passed to the database at the final record_success, avoiding inconsistencies from duplicate records.
Args:
ctx: streaming context, containing provider_name and mapped_model
ctx: streaming context, containing provider-related info
"""
import asyncio
from src.database.database import get_db
@@ -422,6 +421,17 @@ class BaseMessageHandler:
target_request_id = self.request_id
provider = ctx.provider_name
target_model = ctx.mapped_model
provider_id = ctx.provider_id
endpoint_id = ctx.endpoint_id
key_id = ctx.key_id
first_byte_time_ms = ctx.first_byte_time_ms
# Warn if provider is empty (should not happen; kept for debugging)
if not provider:
logger.warning(
f"[{target_request_id}] provider is empty while updating streaming status: "
f"ctx.provider_name={ctx.provider_name}, ctx.provider_id={ctx.provider_id}"
)
async def _do_update() -> None:
try:
@@ -434,6 +444,10 @@ class BaseMessageHandler:
status="streaming",
provider=provider,
target_model=target_model,
provider_id=provider_id,
provider_endpoint_id=endpoint_id,
provider_api_key_id=key_id,
first_byte_time_ms=first_byte_time_ms,
)
finally:
db.close()


@@ -40,6 +40,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base"
mode = ApiMode.STANDARD
# Billing template config (subclasses may override, e.g. "claude", "openai", "gemini")
BILLING_TEMPLATE: str = "claude"
# A special hook subclasses can configure, used by check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
Get the cache-read price based on the cache TTL
Default implementation: check the cache_ttl_pricing config and pick a price by TTL
Subclasses may override this for different TTL pricing logic
Args:
tier: the current pricing tier config
cache_ttl_minutes: cache lifetime (minutes)
Returns:
cache-read price (per 1M tokens)
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# Past all configured TTLs: use the last one
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
"""
Calculate the request cost
Default implementation: supports flat pricing and tiered billing
Subclasses may override this method for entirely different billing logic
Uses config-driven billing from the billing module
Subclasses can set the BILLING_TEMPLATE class attribute to choose a billing template,
or override this method for fully custom billing.
Args:
input_tokens: number of input tokens
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
"tier_index": Optional[int], # 命中的阶梯索引
}
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# Check tiered pricing
if tiered_pricing and tiered_pricing.get("tiers"):
# Compute the total input context (via a subclass-overridable method)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# Compute each cost component
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""
Determine the pricing tier from the total input token count
Args:
tiered_pricing: tiered billing config {"tiers": [...]}
total_input_tokens: total input token count
Returns:
the matched tier config
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
# If every tier has a cap and all are exceeded, return the last tier
return tiers[-1] if tiers else None
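The inline logic removed here (now centralized in `src/services/billing/`) boils down to: pick an effective per-1M price, possibly from a matched tier, then sum the token-weighted components plus any flat per-request fee. A condensed sketch of the flat-price case shown above (illustrative names, ignoring tiers and cache pricing):

```python
def compute_simple_cost(input_tokens, output_tokens, input_price_per_1m,
                        output_price_per_1m, price_per_request=0.0):
    """Token costs are priced per 1M tokens; a flat per-request fee is added on top."""
    input_cost = (input_tokens / 1_000_000) * input_price_per_1m
    output_cost = (output_tokens / 1_000_000) * output_price_per_1m
    return {
        "input_cost": input_cost,
        "output_cost": output_cost,
        "request_cost": price_per_request,
        "total_cost": input_cost + output_cost + price_per_request,
    }
```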
# =========================================================================
# Model list query - subclasses should override this method


@@ -36,6 +36,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.error_utils import extract_error_message
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -500,6 +501,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
# Attach the upstream error text to the exception so it can be returned to the client on failover
e.upstream_response = error_text # type: ignore[attr-defined]
raise
except EmbeddedErrorException:
@@ -549,7 +552,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
error_message=extract_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
@@ -785,7 +788,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
error_message=extract_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
@@ -802,10 +805,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
try:
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
error_bytes = await e.response.aread()
return error_bytes.decode("utf-8", errors="replace")[:500]
return error_bytes.decode("utf-8", errors="replace")
else:
return (
e.response.text[:500] if hasattr(e.response, "_content") else "Unable to read"
e.response.text if hasattr(e.response, "_content") else "Unable to read"
)
except Exception as decode_error:
return f"Unable to read error: {decode_error}"


@@ -38,6 +38,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
name: str = "cli.base"
mode = ApiMode.PROXY
# Billing template config (subclasses may override, e.g. "claude", "openai", "gemini")
BILLING_TEMPLATE: str = "claude"
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
Get the cache-read price based on the cache TTL
Default implementation: check the cache_ttl_pricing config and pick a price by TTL
Subclasses may override this for different TTL pricing logic
Args:
tier: the current pricing tier config
cache_ttl_minutes: cache lifetime (minutes)
Returns:
cache-read price (per 1M tokens)
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# Past all configured TTLs: use the last one
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
"""
Calculate the request cost
Default implementation: supports flat pricing and tiered billing
Subclasses may override this method for entirely different billing logic
Uses config-driven billing from the billing module
Subclasses can set the BILLING_TEMPLATE class attribute to choose a billing template,
or override this method for fully custom billing.
Args:
input_tokens: number of input tokens
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
Returns:
a dict containing each cost component
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# Check tiered pricing
if tiered_pricing and tiered_pricing.get("tiers"):
# Compute the total input context (via a subclass-overridable method)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# Compute each cost component
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
# =========================================================================
# Model list query - subclasses should override this method


@@ -34,7 +34,12 @@ from src.api.handlers.base.base_handler import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
from src.api.handlers.base.utils import (
build_sse_headers,
check_html_response,
check_prefetched_response_error,
)
from src.core.error_utils import extract_error_message
# Import directly from concrete modules to avoid circular dependencies
from src.api.handlers.base.response_parser import (
@@ -57,6 +62,7 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@@ -328,9 +334,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_generator,
provider_name,
attempt_id,
_provider_id,
_endpoint_id,
_key_id,
provider_id,
endpoint_id,
key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=ctx.api_format,
model_name=ctx.model,
@@ -340,7 +346,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
is_stream=True,
capability_requirements=capability_requirements or None,
)
# Update the context (ensure provider info is set for the streaming status update)
ctx.attempt_id = attempt_id
if not ctx.provider_name:
ctx.provider_name = provider_name
if not ctx.provider_id:
ctx.provider_id = provider_id
if not ctx.endpoint_id:
ctx.endpoint_id = endpoint_id
if not ctx.key_id:
ctx.key_id = key_id
# Create a background task to record stats
background_tasks = BackgroundTasks()
@@ -488,6 +504,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
e.upstream_response = error_text # type: ignore[attr-defined]
raise
except EmbeddedErrorException:
@@ -523,8 +541,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -532,11 +550,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
needs_conversion = self._needs_format_conversion(ctx)
async for chunk in stream_response.aiter_bytes():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
@@ -561,6 +574,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -578,6 +592,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
@@ -585,8 +600,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -637,7 +654,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError as e:
except httpx.RemoteProtocolError:
if ctx.data_count > 0:
error_event = {
"type": "error",
@@ -691,7 +708,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
"""
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
max_prefetch_lines = config.stream_prefetch_lines # 最多预读行数来检测错误
max_prefetch_bytes = StreamDefaults.MAX_PREFETCH_BYTES # 避免无换行响应导致 buffer 增长
total_prefetched_bytes = 0
buffer = b""
line_count = 0
should_stop = False
@@ -718,14 +737,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
total_prefetched_bytes += len(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
total_prefetched_bytes += len(chunk)
buffer += chunk
# 尝试按行解析缓冲区
# 尝试按行解析缓冲区(SSE 格式)
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
@@ -742,15 +763,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
normalized_line = line.rstrip("\r")
# 检测 HTML 响应(base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
if check_html_response(normalized_line):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
@@ -799,9 +820,30 @@ class CliMessageHandlerBase(BaseMessageHandler):
should_stop = True
break
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
logger.debug(
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
f"max_bytes={max_prefetch_bytes}"
)
break
if should_stop or line_count >= max_prefetch_lines:
break
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
# 处理某些代理返回的纯 JSON 错误(可能无换行/多行 JSON),以及 HTML 页面(base_url 配置错误)
if not should_stop and prefetched_chunks:
check_prefetched_response_error(
prefetched_chunks=prefetched_chunks,
parser=provider_parser,
request_id=self.request_id,
provider_name=str(provider.name),
endpoint_id=endpoint.id,
base_url=endpoint.base_url,
)
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
@@ -833,17 +875,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
@@ -870,10 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -883,16 +918,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -931,10 +960,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -952,6 +978,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
@@ -959,16 +986,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -1352,7 +1373,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
error_message=extract_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
@@ -1476,8 +1497,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
@@ -1487,7 +1512,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
# 安全解析 JSON 响应,处理可能的编码错误
@@ -1620,7 +1648,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
error_message=extract_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
@@ -1640,14 +1668,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
for encoding in ["utf-8", "gbk", "latin1"]:
try:
return error_bytes.decode(encoding)[:500]
return error_bytes.decode(encoding)
except (UnicodeDecodeError, LookupError):
continue
return error_bytes.decode("utf-8", errors="replace")[:500]
return error_bytes.decode("utf-8", errors="replace")
else:
return (
e.response.text[:500]
e.response.text
if hasattr(e.response, "_content")
else "Unable to read response"
)
@@ -1665,6 +1693,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
return False
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
def _mark_first_output(self, ctx: StreamContext, state: Dict[str, bool]) -> None:
"""
标记首次输出:记录 TTFB 并更新 streaming 状态
在第一次 yield 数据前调用,确保:
1. 首字时间 (TTFB) 已记录到 ctx
2. Usage 状态已更新为 streaming包含 provider/key/TTFB 信息)
Args:
ctx: 流上下文
state: 包含 first_yield 和 streaming_updated 的状态字典
"""
if state["first_yield"]:
ctx.record_first_byte_time(self.start_time)
state["first_yield"] = False
if not state["streaming_updated"]:
self._update_usage_to_streaming_with_ctx(ctx)
state["streaming_updated"] = True
def _convert_sse_line(
self,
ctx: StreamContext,

View File

@@ -98,6 +98,17 @@ class OpenAIResponseParser(ResponseParser):
chunk.is_done = True
stats.has_completion = True
# 提取 usage 信息(某些 OpenAI 兼容 API 如豆包会在最后一个 chunk 中发送 usage)
# 这个 chunk 通常 choices 为空数组,但包含完整的 usage 信息
usage = parsed.get("usage")
if usage and isinstance(usage, dict):
chunk.input_tokens = usage.get("prompt_tokens", 0)
chunk.output_tokens = usage.get("completion_tokens", 0)
# 更新 stats
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.chunk_count += 1
stats.data_count += 1

View File
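上述解析逻辑针对的最后一个 usage chunk 大致形如下例(字段形状仅为示意,具体以上游实际返回为准):

```python
import json

# 示例:某些 OpenAI 兼容 API(如豆包)在最后一个 chunk 中携带 usage,
# 此时 choices 通常为空数组
final_chunk = json.loads(
    '{"id":"chatcmpl-1","choices":[],'
    '"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46}}'
)

usage = final_chunk.get("usage")
input_tokens = output_tokens = 0
if usage and isinstance(usage, dict):  # 与上文解析逻辑一致
    input_tokens = usage.get("prompt_tokens", 0)
    output_tokens = usage.get("completion_tokens", 0)
print(input_tokens, output_tokens)  # 12 34
```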

@@ -25,8 +25,17 @@ from src.api.handlers.base.content_extractors import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import (
check_html_response,
check_prefetched_response_error,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
from src.core.exceptions import (
EmbeddedErrorException,
ProviderNotAvailableException,
ProviderTimeoutException,
)
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
@@ -165,6 +174,7 @@ class StreamProcessor:
endpoint: ProviderEndpoint,
ctx: StreamContext,
max_prefetch_lines: int = 5,
max_prefetch_bytes: int = StreamDefaults.MAX_PREFETCH_BYTES,
) -> list:
"""
预读流的前几行,检测嵌套错误
@@ -180,12 +190,14 @@ class StreamProcessor:
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
max_prefetch_bytes: 最多预读字节数(避免无换行响应导致 buffer 增长)
Returns:
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
"""
prefetched_chunks: list = []
@@ -193,6 +205,7 @@ class StreamProcessor:
buffer = b""
line_count = 0
should_stop = False
total_prefetched_bytes = 0
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -206,11 +219,13 @@ class StreamProcessor:
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
total_prefetched_bytes += len(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
total_prefetched_bytes += len(chunk)
buffer += chunk
# 尝试按行解析缓冲区
@@ -228,10 +243,21 @@ class StreamProcessor:
line_count += 1
# 检测 HTML 响应(base_url 配置错误的常见症状)
if check_html_response(line):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
@@ -248,7 +274,6 @@ class StreamProcessor:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
@@ -276,14 +301,34 @@ class StreamProcessor:
should_stop = True
break
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
logger.debug(
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
f"max_bytes={max_prefetch_bytes}"
)
break
if should_stop or line_count >= max_prefetch_lines:
break
except (EmbeddedErrorException, ProviderTimeoutException):
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
if not should_stop and prefetched_chunks:
check_prefetched_response_error(
prefetched_chunks=prefetched_chunks,
parser=parser,
request_id=self.request_id,
provider_name=str(provider.name),
endpoint_id=endpoint.id,
base_url=endpoint.base_url,
)
except (EmbeddedErrorException, ProviderNotAvailableException, ProviderTimeoutException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
@@ -332,15 +377,15 @@ class StreamProcessor:
# 处理预读数据
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 把原始数据转发给客户端
yield chunk
@@ -363,14 +408,14 @@ class StreamProcessor:
# 处理剩余的流数据
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 原始数据透传
yield chunk

View File

@@ -2,8 +2,10 @@
Handler 基础工具函数
"""
import json
from typing import Any, Dict, Optional
from src.core.exceptions import EmbeddedErrorException, ProviderNotAvailableException
from src.core.logger import logger
@@ -107,3 +109,95 @@ def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[st
if extra_headers:
headers.update(extra_headers)
return headers
def check_html_response(line: str) -> bool:
"""
检查行是否为 HTML 响应(base_url 配置错误的常见症状)
Args:
line: 要检查的行内容
Returns:
True 如果检测到 HTML 响应
"""
lower_line = line.lstrip().lower()
return lower_line.startswith("<!doctype") or lower_line.startswith("<html")
def check_prefetched_response_error(
prefetched_chunks: list,
parser: Any,
request_id: str,
provider_name: str,
endpoint_id: Optional[str],
base_url: Optional[str],
) -> None:
"""
检查预读的响应是否为非 SSE 格式的错误响应(HTML 或纯 JSON 错误)
某些代理可能返回:
1. HTML 页面(base_url 配置错误)
2. 纯 JSON 错误(无换行或多行 JSON)
Args:
prefetched_chunks: 预读的字节块列表
parser: 响应解析器(需要有 is_error_response 和 parse_response 方法)
request_id: 请求 ID用于日志
provider_name: Provider 名称
endpoint_id: Endpoint ID
base_url: Endpoint 的 base_url
Raises:
ProviderNotAvailableException: 如果检测到 HTML 响应
EmbeddedErrorException: 如果检测到 JSON 错误响应
"""
if not prefetched_chunks:
return
try:
prefetched_bytes = b"".join(prefetched_chunks)
stripped = prefetched_bytes.lstrip()
# 去除 BOM
if stripped.startswith(b"\xef\xbb\xbf"):
stripped = stripped[3:]
# HTML 响应(通常是 base_url 配置错误导致返回网页)
lower_prefix = stripped[:32].lower()
if lower_prefix.startswith(b"<!doctype") or lower_prefix.startswith(b"<html"):
endpoint_short = endpoint_id[:8] + "..." if endpoint_id else "N/A"
logger.error(
f" [{request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider_name}, Endpoint={endpoint_short}, "
f"base_url={base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
# 纯 JSON(可能无换行/多行 JSON)
if stripped.startswith(b"{") or stripped.startswith(b"["):
payload_str = stripped.decode("utf-8", errors="replace").strip()
data = json.loads(payload_str)
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{request_id}] 检测到 JSON 错误响应: "
f"Provider={provider_name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=provider_name,
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
except json.JSONDecodeError:
pass

View File

@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "CLAUDE"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.chat"
@property

View File

@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "CLAUDE_CLI"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.cli"
@property

View File

@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "GEMINI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.chat"
@property

View File

@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "GEMINI_CLI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.cli"
@property

View File

@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "OPENAI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.chat"
@property

View File

@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
"""
FORMAT_ID = "OPENAI_CLI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.cli"
@property

View File
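src/services/billing/ 的真实接口未在本次 diff 中展示;按 BILLING_TEMPLATE 字符串派发计费模板的一种可能写法如下(注册表与函数名均为假设,仅作示意):

```python
# 假设的模板注册表:按 BILLING_TEMPLATE 字符串选择计费逻辑
# (真实的 billing 模块接口以仓库代码为准,此处仅为示意)
from typing import Callable, Dict

BILLING_TEMPLATES: Dict[str, Callable[[dict, dict], float]] = {}

def register_template(name: str):
    def deco(fn):
        BILLING_TEMPLATES[name] = fn
        return fn
    return deco

@register_template("openai")
def bill_openai(usage: dict, price: dict) -> float:
    # 简化:仅输入/输出两段计费(真实模板还需处理缓存 TTL、阶梯定价等)
    return (usage["input_tokens"] * price["input_per_1m"]
            + usage["output_tokens"] * price["output_per_1m"]) / 1_000_000

def calc_cost(adapter_template: str, usage: dict, price: dict) -> float:
    # adapter_template 对应各 adapter 的 BILLING_TEMPLATE 类属性
    return BILLING_TEMPLATES[adapter_template](usage, price)

cost = calc_cost("openai", {"input_tokens": 1000, "output_tokens": 500},
                 {"input_per_1m": 2.0, "output_per_1m": 8.0})
print(round(cost, 6))  # 0.006
```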

@@ -104,9 +104,14 @@ async def get_my_usage(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
search: Optional[str] = None, # 通用搜索:密钥名、模型名
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
db: Session = Depends(get_db),
):
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date)
adapter = GetUsageAdapter(
start_date=start_date, end_date=end_date, search=search, limit=limit, offset=offset
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -133,6 +138,20 @@ async def get_my_interval_timeline(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/usage/heatmap")
async def get_my_activity_heatmap(
request: Request,
db: Session = Depends(get_db),
):
"""
Get user's activity heatmap data for the past 365 days.
This endpoint is cached for 5 minutes to reduce database load.
"""
adapter = GetMyActivityHeatmapAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/providers")
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
adapter = ListAvailableProvidersAdapter()
@@ -471,8 +490,15 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
class GetUsageAdapter(AuthenticatedApiAdapter):
start_date: Optional[datetime]
end_date: Optional[datetime]
search: Optional[str] = None
limit: int = 100
offset: int = 0
async def handle(self, context): # type: ignore[override]
from sqlalchemy import or_
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
db = context.db
user = context.user
summary_list = UsageService.get_usage_summary(
@@ -553,7 +579,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
stats["total_cost_usd"] += item["total_cost_usd"]
# 假设 summary 中的记录都是成功的请求
stats["success_count"] += item["requests"]
if item.get("avg_response_time_ms"):
if item.get("avg_response_time_ms") is not None:
stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
stats["response_time_count"] += item["requests"]
@@ -577,12 +603,33 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
})
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
query = db.query(Usage).filter(Usage.user_id == user.id)
query = (
db.query(Usage, ApiKey)
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
.filter(Usage.user_id == user.id)
)
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
usage_records = query.order_by(Usage.created_at.desc()).limit(100).all()
# 通用搜索:密钥名、模型名
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
if self.search and self.search.strip():
keywords = [kw for kw in self.search.strip().split() if kw][:10]
for keyword in keywords:
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
search_pattern = f"%{escaped}%"
query = query.filter(
or_(
ApiKey.name.ilike(search_pattern, escape="\\"),
Usage.model.ilike(search_pattern, escape="\\"),
)
)
# 计算总数用于分页
total_records = query.count()
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
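上面引用的 escape_like_pattern / safe_truncate_escaped 工具函数未在本次 diff 中展示;一个符合其用途的最小实现草图如下(具体签名以 src/utils/database_helpers 为准):

```python
def escape_like_pattern(value: str, escape_char: str = "\\") -> str:
    # 转义 LIKE 通配符,防止用户输入的 % / _ 改变匹配语义
    return (value.replace(escape_char, escape_char * 2)
                 .replace("%", escape_char + "%")
                 .replace("_", escape_char + "_"))

def safe_truncate_escaped(value: str, max_len: int, escape_char: str = "\\") -> str:
    # 截断时避免把转义序列切成孤立的转义符
    # (结尾若剩奇数个 escape_char,再去掉一个)
    truncated = value[:max_len]
    trailing = len(truncated) - len(truncated.rstrip(escape_char))
    if trailing % 2 == 1:
        truncated = truncated[:-1]
    return truncated

print(escape_like_pattern("50%_off"))  # 50\%\_off
```

配合查询端 `ilike(search_pattern, escape="\\")` 使用,转义符声明与此处保持一致。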
avg_resp_query = db.query(func.avg(Usage.response_time_ms)).filter(
Usage.user_id == user.id,
@@ -608,6 +655,13 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
"used_usd": user.used_usd,
"summary_by_model": summary_by_model,
"summary_by_provider": summary_by_provider,
# 分页信息
"pagination": {
"total": total_records,
"limit": self.limit,
"offset": self.offset,
"has_more": self.offset + self.limit < total_records,
},
"records": [
{
"id": r.id,
@@ -631,23 +685,25 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
"output_price_per_1m": r.output_price_per_1m,
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
"cache_read_price_per_1m": r.cache_read_price_per_1m,
"api_key": (
{
"id": str(api_key.id),
"name": api_key.name,
"display": api_key.get_display_key(),
}
for r in usage_records
if api_key
else None
),
}
for r, api_key in usage_records
],
}
response_data["activity_heatmap"] = UsageService.get_daily_activity(
db=db,
user_id=user.id,
window_days=365,
include_actual_cost=user.role == "admin",
)
# 管理员可以看到真实成本
if user.role == "admin":
response_data["total_actual_cost"] = total_actual_cost
# 为每条记录添加真实成本和倍率信息
for i, r in enumerate(usage_records):
for i, (r, _) in enumerate(usage_records):
# 确保字段有值,避免前端显示 -
actual_cost = (
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
@@ -709,6 +765,20 @@ class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
return result
class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
"""Activity heatmap adapter with Redis caching for user."""
async def handle(self, context): # type: ignore[override]
user = context.user
result = await UsageService.get_cached_heatmap(
db=context.db,
user_id=user.id,
include_actual_cost=user.role == "admin",
)
context.add_audit_metadata(action="activity_heatmap")
return result
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
from sqlalchemy.orm import selectinload

View File

@@ -213,7 +213,7 @@ class RedisClientManager:
f"Redis连接失败: {error_msg}\n"
"缓存亲和性功能需要Redis支持请确保Redis服务正常运行。\n"
"检查事项:\n"
"1. Redis服务是否已启动docker-compose up -d redis\n"
"1. Redis服务是否已启动docker compose up -d redis\n"
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
"3. Redis端口默认6379是否可访问"
) from e

View File

@@ -21,6 +21,9 @@ class CacheTTL:
# L1 本地缓存(用于减少 Redis 访问)
L1_LOCAL = 3 # 3秒
# 活跃度热力图缓存 - 历史数据变化不频繁
ACTIVITY_HEATMAP = 300 # 5分钟
# 并发锁 TTL - 防止死锁
CONCURRENCY_LOCK = 600 # 10分钟
@@ -38,8 +41,25 @@ class CacheSize:
# ==============================================================================
class StreamDefaults:
"""流式处理默认值"""
# 预读字节上限(避免无换行响应导致内存增长)
# 64KB 基于:
# 1. SSE 单条消息通常远小于此值
# 2. 足够检测 HTML 和 JSON 错误响应
# 3. 不会占用过多内存
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
class ConcurrencyDefaults:
"""并发控制默认值"""
"""并发控制默认值
算法说明:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
- 这样可以快速收敛到真实限制附近,避免过度保守
"""
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
INITIAL_LIMIT = 50
@@ -69,10 +89,6 @@ class ConcurrencyDefaults:
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 2
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
# 0.85 表示降到触发 429 时并发数的 85%
DECREASE_MULTIPLIER = 0.85
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 200
@@ -84,6 +100,7 @@ class ConcurrencyDefaults:
# === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
# 探测性扩容可以突破已知边界,尝试更高的并发
PROBE_INCREASE_INTERVAL_MINUTES = 30
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求

View File
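上述「边界记忆 + 渐进探测」算法可以用一个极简的类来示意(类名与方法均为假设,非项目真实实现):

```python
from typing import Optional

class AdaptiveLimiter:
    """边界记忆 + 渐进探测的极简示意(非项目真实实现)"""

    def __init__(self, initial_limit: int = 50, step: int = 2, max_limit: int = 200):
        self.limit = initial_limit
        self.step = step
        self.max_limit = max_limit
        self.last_concurrent_peak: Optional[int] = None  # 已知的 429 边界

    def on_429(self, concurrent_at_failure: int) -> None:
        # 触发 429:记录边界,新限制 = 边界 - 1
        self.last_concurrent_peak = concurrent_at_failure
        self.limit = max(1, concurrent_at_failure - 1)

    def on_success_window(self, probe: bool = False) -> None:
        # 常规扩容不超过已知边界;探测性扩容(长时间无 429)可突破边界
        new_limit = self.limit + self.step
        if not probe and self.last_concurrent_peak is not None:
            new_limit = min(new_limit, self.last_concurrent_peak)
        self.limit = min(new_limit, self.max_limit)

limiter = AdaptiveLimiter()
limiter.on_429(concurrent_at_failure=20)   # 限制降为 19,边界记为 20
limiter.on_success_window()                # 扩容但被边界夹住:20
limiter.on_success_window(probe=True)      # 探测性扩容突破边界:22
print(limiter.limit)  # 22
```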

@@ -106,13 +106,6 @@ class Config:
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
# 可信代理配置
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
# 异常处理配置
# 设置为 True 时ProxyException 会传播到路由层以便记录 provider_request_headers
# 设置为 False 时,使用全局异常处理器统一处理
@@ -161,6 +154,11 @@ class Config:
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
# 请求体读取超时(秒)
# REQUEST_BODY_TIMEOUT: 等待客户端发送完整请求体的超时时间
# 默认 60 秒,防止客户端发送不完整请求导致连接卡死
self.request_body_timeout = float(os.getenv("REQUEST_BODY_TIMEOUT", "60.0"))
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
self.internal_user_agent_claude_cli = os.getenv(

View File

@@ -30,3 +30,10 @@ class ProviderBillingType(Enum):
MONTHLY_QUOTA = "monthly_quota" # 月卡额度
PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费
FREE_TIER = "free_tier" # 免费额度
class AuthSource(str, Enum):
"""认证来源枚举"""
LOCAL = "local" # 本地认证
LDAP = "ldap" # LDAP 认证

28
src/core/error_utils.py Normal file
View File

@@ -0,0 +1,28 @@
"""
错误消息处理工具函数
"""
from typing import Optional
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
"""
从异常中提取错误消息,优先使用上游响应内容
Args:
error: 异常对象
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
Returns:
错误消息字符串
"""
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
upstream_response = getattr(error, "upstream_response", None)
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
return str(upstream_response)
# 回退到异常的字符串表示str 可能为空,如 httpx 超时异常)
error_str = str(error) or repr(error)
if status_code is not None:
return f"HTTP {status_code}: {error_str}"
return error_str

View File

@@ -547,11 +547,19 @@ class ErrorResponse:
- 所有错误都记录到日志,通过错误 ID 关联
"""
if isinstance(e, ProxyException):
details = e.details.copy() if e.details else {}
status_code = e.status_code
message = e.message
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
if e.upstream_status:
status_code = e.upstream_status
message = e.upstream_response
return ErrorResponse.create(
error_type=e.error_type,
message=e.message,
status_code=e.status_code,
details=e.details,
message=message,
status_code=status_code,
details=details if details else None,
)
elif isinstance(e, HTTPException):
return ErrorResponse.create(

View File

@@ -411,7 +411,7 @@ def init_db():
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
print("", file=sys.stderr)
print("如果使用 Docker请先运行:", file=sys.stderr)
print(" docker-compose up -d postgres redis", file=sys.stderr)
print(" docker compose -f docker-compose.build.yml up -d postgres redis", file=sys.stderr)
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈

View File

@@ -203,28 +203,21 @@ class PluginMiddleware:
"""
获取客户端 IP 地址,支持代理头
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
优先级X-Real-IP > X-Forwarded-For > 直连 IP
X-Real-IP 由最外层 Nginx 设置,最可靠
"""
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
# 优先从代理头获取真实 IP
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",")]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip.strip()
# 检查 X-Forwarded-For,取第一个 IP(原始客户端)
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if ips:
return ips[0]
# 回退到直连 IP
if request.client:
return request.client.host

View File
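重构后的 IP 获取优先级(X-Real-IP > X-Forwarded-For > 直连 IP)可以脱离 FastAPI 的 Request 独立验证,以下为等价逻辑的示意:

```python
from typing import Mapping, Optional

def get_client_ip(headers: Mapping[str, str], direct_ip: Optional[str]) -> str:
    # 优先级:X-Real-IP > X-Forwarded-For(第一个 IP)> 直连 IP
    real_ip = headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()
    forwarded_for = headers.get("x-forwarded-for")
    if forwarded_for:
        # X-Forwarded-For 格式: "client, proxy1, proxy2"
        ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
        if ips:
            return ips[0]
    return direct_ip or "unknown"

print(get_client_ip({"x-real-ip": "1.2.3.4"}, "10.0.0.1"))           # 1.2.3.4
print(get_client_ip({"x-forwarded-for": "5.6.7.8, 9.9.9.9"}, None))  # 5.6.7.8
print(get_client_ip({}, "10.0.0.1"))                                 # 10.0.0.1
```

注意:这两个头都可被客户端伪造,仅当最外层可信代理会覆盖它们时,此优先级才是安全的。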

@@ -4,9 +4,9 @@ API端点请求/响应模型定义
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from ..core.enums import UserRole
@@ -15,17 +15,9 @@ from ..core.enums import UserRole
class LoginRequest(BaseModel):
"""登录请求"""
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
email: str = Field(..., min_length=1, max_length=255, description="邮箱/用户名")
password: str = Field(..., min_length=1, max_length=128, description="密码")
@classmethod
@field_validator("email")
def validate_email(cls, v):
"""验证邮箱格式"""
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if not re.match(email_pattern, v):
raise ValueError("邮箱格式无效")
return v.lower()
auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型")
@classmethod
@field_validator("password")
@@ -36,6 +28,24 @@ class LoginRequest(BaseModel):
raise ValueError("密码不能为空")
return v
@model_validator(mode="after")
def validate_login(self):
"""根据认证类型校验并规范化登录标识"""
identifier = self.email.strip()
if not identifier:
raise ValueError("用户名/邮箱不能为空")
# 本地和 LDAP 登录都支持用户名或邮箱
# 如果是邮箱格式,转换为小写
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
if re.match(email_pattern, identifier):
self.email = identifier.lower()
else:
self.email = identifier
return self
class LoginResponse(BaseModel):
"""登录响应"""
@@ -309,8 +319,9 @@ class CreateApiKeyRequest(BaseModel):
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
rate_limit: Optional[int] = 100
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
rate_limit: Optional[int] = None # None = 无限制
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期(兼容旧版)
expires_at: Optional[str] = None # ISO 日期字符串,如 "2025-12-31",优先于 expire_days
initial_balance_usd: Optional[float] = Field(
None, description="初始余额USD仅用于独立KeyNone = 无限制"
)

View File

@@ -30,7 +30,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import declarative_base, relationship
from ..config import config
from ..core.enums import ProviderBillingType, UserRole
from ..core.enums import AuthSource, ProviderBillingType, UserRole
Base = declarative_base()
@@ -54,6 +54,20 @@ class User(Base):
default=UserRole.USER,
nullable=False,
)
auth_source = Column(
Enum(
AuthSource,
name="authsource",
create_type=False,
values_callable=lambda x: [e.value for e in x],
),
default=AuthSource.LOCAL,
nullable=False,
)
# LDAP 标识(仅 auth_source=ldap 时使用,用于在邮箱变更/用户名冲突时稳定关联本地账户)
ldap_dn = Column(String(512), nullable=True, index=True)
ldap_username = Column(String(255), nullable=True, index=True)
# 访问限制(NULL 表示不限制,允许访问所有资源)
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
@@ -150,7 +164,7 @@ class ApiKey(Base):
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
rate_limit = Column(Integer, default=100) # 每分钟请求限制
rate_limit = Column(Integer, default=None, nullable=True)  # 每分钟请求限制(None = 无限制)
concurrent_limit = Column(Integer, default=5, nullable=True) # 并发请求限制
# Key 能力配置
@@ -428,6 +442,68 @@ class SystemConfig(Base):
)
class LDAPConfig(Base):
"""LDAP认证配置表 - 单行配置"""
__tablename__ = "ldap_configs"
id = Column(Integer, primary_key=True, autoincrement=True)
server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636
bind_dn = Column(String(255), nullable=False) # 绑定账号 DN
bind_password_encrypted = Column(Text, nullable=True) # 加密的绑定密码(允许 NULL 表示已清除)
base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN
user_search_filter = Column(
String(500), default="(uid={username})", nullable=False
) # 用户搜索过滤器
username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName)
email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性
display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性
is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证
is_exclusive = Column(
Boolean, default=False, nullable=False
) # 是否仅允许 LDAP 登录(禁用本地认证)
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒)
# 时间戳
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
def set_bind_password(self, password: str) -> None:
"""
设置并加密绑定密码
Args:
password: 明文密码
"""
from src.core.crypto import crypto_service
self.bind_password_encrypted = crypto_service.encrypt(password)
def get_bind_password(self) -> str:
"""
获取解密后的绑定密码
Returns:
str: 解密后的明文密码
Raises:
DecryptionException: 解密失败时抛出异常
"""
from src.core.crypto import crypto_service
if not self.bind_password_encrypted:
return ""
return crypto_service.decrypt(self.bind_password_encrypted)
class Provider(Base):
"""提供商配置表"""


@@ -19,6 +19,7 @@ class ProviderEndpointCreate(BaseModel):
provider_id: str = Field(..., description="Provider ID")
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
# 请求配置
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
@@ -62,6 +63,7 @@ class ProviderEndpointUpdate(BaseModel):
base_url: Optional[str] = Field(
default=None, min_length=1, max_length=500, description="API 基础 URL"
)
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
@@ -94,6 +96,7 @@ class ProviderEndpointResponse(BaseModel):
# API 配置
api_format: str
base_url: str
custom_path: Optional[str] = None
# 请求配置
headers: Optional[Dict[str, str]] = None


@@ -21,7 +21,7 @@ WARNING: 多进程环境注意事项
import asyncio
import time
from collections import deque
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Deque, Dict
from src.core.logger import logger
@@ -95,12 +95,12 @@ class SlidingWindow:
"""获取最早的重置时间"""
self._cleanup()
if not self.requests:
return datetime.now()
return datetime.now(timezone.utc)
# 最早的请求将在window_size秒后过期
oldest_request = self.requests[0]
reset_time = oldest_request + self.window_size
return datetime.fromtimestamp(reset_time)
return datetime.fromtimestamp(reset_time, tz=timezone.utc)
class SlidingWindowStrategy(RateLimitStrategy):
@@ -250,7 +250,7 @@ class SlidingWindowStrategy(RateLimitStrategy):
retry_after = None
if not allowed:
# 计算需要等待的时间(最早请求过期的时间)
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
return RateLimitResult(
allowed=allowed,


@@ -3,7 +3,7 @@
import asyncio
import os
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
@@ -63,11 +63,11 @@ class TokenBucket:
def get_reset_time(self) -> datetime:
"""获取下次完全恢复的时间"""
if self.tokens >= self.capacity:
return datetime.now()
return datetime.now(timezone.utc)
tokens_needed = self.capacity - self.tokens
seconds_to_full = tokens_needed / self.refill_rate
return datetime.now() + timedelta(seconds=seconds_to_full)
return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
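The refill arithmetic behind `get_reset_time` reduces to one expression; a minimal standalone sketch (helper name assumed):

```python
def seconds_until_full(tokens: float, capacity: float, refill_rate: float) -> float:
    """Seconds until the bucket is completely refilled at refill_rate tokens/sec."""
    if tokens >= capacity:
        return 0.0  # already full: reset time is "now"
    return (capacity - tokens) / refill_rate
```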
class TokenBucketStrategy(RateLimitStrategy):
@@ -370,7 +370,7 @@ class RedisTokenBucketBackend:
if tokens is None or last_refill is None:
remaining = capacity
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=capacity / refill_rate)
else:
tokens_value = float(tokens)
last_refill_value = float(last_refill)
@@ -378,7 +378,7 @@ class RedisTokenBucketBackend:
tokens_value = min(capacity, tokens_value + delta * refill_rate)
remaining = int(tokens_value)
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
reset_at = datetime.now() + timedelta(seconds=reset_after)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=reset_after)
allowed = remaining >= amount
retry_after = None

src/services/auth/ldap.py (new file, 363 lines)

@@ -0,0 +1,363 @@
"""LDAP 认证服务"""
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import LDAPConfig
# LDAP 连接默认超时时间(秒)
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
"""
解析 LDAP 服务器地址,支持:
- ldap://host:389
- ldaps://host:636
- host:389(无 scheme 时默认 ldap)
Returns:
(host, port, use_ssl)
"""
raw = (server_url or "").strip()
if not raw:
raise ValueError("LDAP server_url is required")
parsed = urlparse(raw)
if parsed.scheme in {"ldap", "ldaps"}:
host = parsed.hostname
if not host:
raise ValueError("Invalid LDAP server_url")
use_ssl = parsed.scheme == "ldaps"
port = parsed.port or (636 if use_ssl else 389)
return host, port, use_ssl
# 兼容无 scheme:按 ldap:// 解析
parsed = urlparse(f"ldap://{raw}")
host = parsed.hostname
if not host:
raise ValueError("Invalid LDAP server_url")
port = parsed.port or 389
return host, port, False
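Assuming the behavior documented above, the three accepted forms resolve like this (condensed copy of the parser for illustration):

```python
from urllib.parse import urlparse

def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
    # Returns (host, port, use_ssl); scheme-less input defaults to ldap://.
    raw = (server_url or "").strip()
    if not raw:
        raise ValueError("LDAP server_url is required")
    parsed = urlparse(raw)
    if parsed.scheme in {"ldap", "ldaps"}:
        if not parsed.hostname:
            raise ValueError("Invalid LDAP server_url")
        use_ssl = parsed.scheme == "ldaps"
        return parsed.hostname, parsed.port or (636 if use_ssl else 389), use_ssl
    parsed = urlparse(f"ldap://{raw}")  # re-parse with an explicit scheme
    if not parsed.hostname:
        raise ValueError("Invalid LDAP server_url")
    return parsed.hostname, parsed.port or 389, False
```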
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
"""
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
Args:
value: 需要转义的字符串
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
Returns:
转义后的安全字符串
Raises:
ValueError: 输入值过长
"""
import unicodedata
# 先检查原始长度,防止 DoS 攻击
# 128 字符足够覆盖大多数企业用户名和邮箱地址
if len(value) > max_length:
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
value = unicodedata.normalize("NFC", value)
# 再次检查规范化后的长度(防止规范化后长度突增)
if len(value) > max_length:
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
# 使用显式顺序处理,确保反斜杠首先转义
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
value = value.replace("*", r"\2a")
value = value.replace("(", r"\28")
value = value.replace(")", r"\29")
value = value.replace("\x00", r"\00") # NUL
value = value.replace("&", r"\26")
value = value.replace("|", r"\7c")
value = value.replace("=", r"\3d")
value = value.replace(">", r"\3e")
value = value.replace("<", r"\3c")
value = value.replace("~", r"\7e")
value = value.replace("!", r"\21")
return value
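The escape-order rule in the function above (backslash strictly first) can be seen on a minimal subset of the replacements (standalone sketch, name illustrative):

```python
def escape_ldap_filter_minimal(value: str) -> str:
    # Backslash must be escaped first, otherwise the backslashes introduced
    # by the later replacements would themselves get double-escaped.
    for ch, esc in [("\\", r"\5c"), ("*", r"\2a"), ("(", r"\28"), (")", r"\29"), ("\x00", r"\00")]:
        value = value.replace(ch, esc)
    return value
```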
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
"""
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
"""
attr = getattr(entry, attr_name, None)
if not attr:
return default
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
val = getattr(attr, "value", None)
if isinstance(val, list):
val = val[0] if val else default
if val is None:
return default
return str(val)
class LDAPService:
"""LDAP 认证服务"""
@staticmethod
def get_config(db: Session) -> Optional[LDAPConfig]:
"""获取 LDAP 配置"""
return db.query(LDAPConfig).first()
@staticmethod
def is_ldap_enabled(db: Session) -> bool:
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
return LDAPService.get_config_data(db) is not None
@staticmethod
def is_ldap_exclusive(db: Session) -> bool:
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
config = LDAPService.get_config(db)
if not config or config.is_exclusive is not True:
return False
return LDAPService.get_config_data(db) is not None
@staticmethod
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
"""
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
"""
config = LDAPService.get_config(db)
if not config or config.is_enabled is not True:
return None
try:
bind_password = config.get_bind_password()
except Exception as e:
logger.error(f"LDAP 绑定密码解密失败: {e}")
return None
# 绑定密码为空时无法进行 LDAP 认证
if not bind_password:
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
return None
return {
"server_url": config.server_url,
"bind_dn": config.bind_dn,
"bind_password": bind_password,
"base_dn": config.base_dn,
"user_search_filter": config.user_search_filter,
"username_attr": config.username_attr,
"email_attr": config.email_attr,
"display_name_attr": config.display_name_attr,
"use_starttls": config.use_starttls,
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
}
@staticmethod
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
"""
LDAP bind 验证
Args:
config: 已解密的 LDAP 配置
username: 用户名
password: 密码
Returns:
用户属性 dict {username, email, display_name} 或 None
"""
try:
import ldap3
from ldap3 import Server, Connection, SUBTREE
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
except ImportError:
logger.error("ldap3 库未安装")
return None
if not config:
logger.warning("LDAP 未配置或未启用")
return None
admin_conn = None
user_conn = None
try:
# 创建服务器连接
server_url = config["server_url"]
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server(
server_host,
port=server_port,
use_ssl=use_ssl,
get_info=ldap3.ALL,
connect_timeout=timeout,
)
# 使用管理员账号连接
bind_password = config["bind_password"]
admin_conn = Connection(
server,
user=config["bind_dn"],
password=bind_password,
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
)
if config.get("use_starttls") and not use_ssl:
admin_conn.start_tls()
if not admin_conn.bind():
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
return None
# 搜索用户(转义用户名防止 LDAP 注入)
safe_username = escape_ldap_filter(username)
search_filter = config["user_search_filter"].replace("{username}", safe_username)
admin_conn.search(
search_base=config["base_dn"],
search_filter=search_filter,
search_scope=SUBTREE,
size_limit=2, # 防止过滤器误配导致匹配多用户
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
attributes=[
config["username_attr"],
config["email_attr"],
config["display_name_attr"],
],
)
if len(admin_conn.entries) != 1:
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
logger.warning(
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
)
return None
user_entry = admin_conn.entries[0]
user_dn = user_entry.entry_dn
# 用户密码验证
user_conn = Connection(
server,
user=user_dn,
password=password,
receive_timeout=timeout, # 添加读取超时
)
if config.get("use_starttls") and not use_ssl:
user_conn.start_tls()
if not user_conn.bind():
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
bind_result = user_conn.result.get("description", "unknown")
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
return None
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
email = _get_attr_value(
user_entry, config["email_attr"], f"{username}@ldap.local"
)
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
logger.info(f"LDAP 认证成功: {username}")
return {
"username": ldap_username,
"ldap_username": ldap_username,
"ldap_dn": user_dn,
"email": email,
"display_name": display_name,
}
except LDAPSocketOpenError as e:
logger.error(f"LDAP 服务器连接失败: {e}")
return None
except LDAPBindError as e:
logger.error(f"LDAP 绑定失败: {e}")
return None
except Exception as e:
logger.error(f"LDAP 认证异常: {e}")
return None
finally:
# 确保连接关闭,避免失败路径泄漏
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
if conn:
try:
conn.unbind()
except Exception as e:
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
@staticmethod
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
"""
测试 LDAP 连接
Returns:
(success, message)
"""
try:
import ldap3
from ldap3 import Server, Connection
except ImportError:
return False, "ldap3 库未安装"
if not config:
return False, "LDAP 配置不存在"
conn = None
try:
server_url = config["server_url"]
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server(
server_host,
port=server_port,
use_ssl=use_ssl,
get_info=ldap3.ALL,
connect_timeout=timeout,
)
bind_password = config["bind_password"]
conn = Connection(
server,
user=config["bind_dn"],
password=bind_password,
receive_timeout=timeout, # 添加读取超时
)
if config.get("use_starttls") and not use_ssl:
conn.start_tls()
if not conn.bind():
return False, f"绑定失败: {conn.result}"
return True, "连接成功"
except Exception as e:
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
return False, "连接失败,请检查服务器地址、端口和凭据"
finally:
if conn:
try:
conn.unbind()
except Exception as e:
logger.warning(f"LDAP 测试连接关闭失败: {e}")
# 兼容旧接口:如果其他代码直接调用
@staticmethod
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
config = LDAPService.get_config_data(db)
return LDAPService.authenticate_with_config(config, username, password) if config else None
@staticmethod
def test_connection(db: Session) -> Tuple[bool, str]:
config = LDAPService.get_config_data(db)
if not config:
return False, "LDAP 配置不存在或未启用"
return LDAPService.test_connection_with_config(config)


@@ -2,21 +2,25 @@
认证服务
"""
import os
import hashlib
import secrets
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
import jwt
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from src.config import config
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.core.enums import AuthSource
from src.models.database import ApiKey, User, UserRole
from src.services.auth.jwt_blacklist import JWTBlacklistService
from src.services.auth.ldap import LDAPService
from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService
@@ -92,15 +96,86 @@ class AuthService:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证"""
async def authenticate_user(
db: Session, email: str, password: str, auth_type: str = "local"
) -> Optional[User]:
"""用户登录认证
Args:
db: 数据库会话
email: 邮箱/用户名
password: 密码
auth_type: 认证类型("local" 或 "ldap")
"""
if auth_type == "ldap":
# LDAP 认证
# 预取配置,避免将 Session 传递到线程池
config_data = LDAPService.get_config_data(db)
if not config_data:
logger.warning("登录失败 - LDAP 未启用或配置无效")
return None
# 计算总体超时:LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
# 超时策略:
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
# - 公式:单次超时 × 4(覆盖 4 次主要网络操作)+ 10% 缓冲
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
single_timeout = config_data.get("connect_timeout", 10)
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
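Pulling the clamp formula out as a helper makes the bounds easy to check (helper name is illustrative, not from the codebase):

```python
def ldap_total_timeout(single_timeout: int) -> int:
    # 4 network round-trips (connect, admin bind, search, user bind) + 10% slack,
    # clamped to the [20, 60] second window described above.
    return max(20, min(int(single_timeout * 4 * 1.1), 60))
```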
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
# 添加总体超时保护,防止异常场景下请求堆积
import asyncio
try:
ldap_user = await asyncio.wait_for(
run_in_threadpool(
LDAPService.authenticate_with_config, config_data, email, password
),
timeout=total_timeout,
)
except asyncio.TimeoutError:
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
return None
if not ldap_user:
return None
# 获取或创建本地用户
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
if not user:
# 已有本地账号但来源不匹配等情况
return None
if not user.is_active:
logger.warning(f"登录失败 - 用户已禁用: {email}")
return None
return user
# 本地认证
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = db.query(User).filter(User.email == email).first()
# 支持邮箱或用户名登录
from sqlalchemy import or_
user = db.query(User).filter(
or_(User.email == email, User.username == email)
).first()
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
return None
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
if LDAPService.is_ldap_exclusive(db):
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
return None
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
# 检查用户认证来源
if user.auth_source == AuthSource.LDAP:
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
return None
if not user.verify_password(password):
logger.warning(f"登录失败 - 密码错误: {email}")
return None
@@ -118,6 +193,127 @@ class AuthService:
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@staticmethod
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
"""获取或创建 LDAP 用户
Args:
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
注意:使用 with_for_update() 防止并发首次登录创建重复用户
"""
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
email = ldap_user["email"]
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
# 使用 with_for_update() 锁定行,防止并发创建
user: Optional[User] = None
if ldap_dn:
user = (
db.query(User)
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
.with_for_update()
.first()
)
if not user and ldap_username:
user = (
db.query(User)
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
.with_for_update()
.first()
)
if not user:
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
user = db.query(User).filter(User.email == email).with_for_update().first()
if user:
if user.auth_source != AuthSource.LDAP:
# 避免覆盖已有本地账户(不同来源时拒绝登录)
logger.warning(
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
)
return None
# 同步邮箱(LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
if user.email != email:
email_taken = (
db.query(User)
.filter(User.email == email, User.id != user.id)
.first()
)
if email_taken:
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
return None
user.email = email
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
if ldap_dn and user.ldap_dn != ldap_dn:
user.ldap_dn = ldap_dn
if ldap_username and user.ldap_username != ldap_username:
user.ldap_username = ldap_username
user.last_login_at = datetime.now(timezone.utc)
db.commit()
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
return user
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
base_username = ldap_username or ldap_user["username"]
username = base_username
max_retries = 3
for attempt in range(max_retries):
# 检查用户名是否已存在
existing_user_with_username = db.query(User).filter(User.username == username).first()
if existing_user_with_username:
# 如果 username 已存在,使用时间戳+随机数确保唯一性
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
# 创建新用户
user = User(
email=email,
username=username,
password_hash="", # LDAP 用户无本地密码
auth_source=AuthSource.LDAP,
ldap_dn=ldap_dn,
ldap_username=ldap_username,
role=UserRole.USER,
is_active=True,
last_login_at=datetime.now(timezone.utc),
)
try:
db.add(user)
db.commit()
db.refresh(user)
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
return user
except IntegrityError as e:
db.rollback()
error_str = str(e.orig).lower() if e.orig else str(e).lower()
# 解析具体冲突类型
if "email" in error_str or "ix_users_email" in error_str:
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
return None
elif "username" in error_str or "ix_users_username" in error_str:
# 用户名冲突,重试时会生成新用户名
if attempt == max_retries - 1:
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
return None
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
else:
# 其他约束冲突,不重试
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
return None
return None
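The collision-avoiding suffix used above (epoch seconds plus 4 random hex chars) can be sketched in isolation (function name assumed):

```python
import time
import uuid

def unique_ldap_username(base: str) -> str:
    # Timestamp keeps the suffix ordered; the short uuid fragment guards
    # against two conflicting creations within the same second.
    return f"{base}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
```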
@staticmethod
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
"""API密钥认证"""


@@ -0,0 +1,51 @@
"""
计费模块
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
- Claude: input + output + cache_creation + cache_read
- OpenAI: input + output + cache_read (无缓存创建费用)
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
- 按次计费: per_request
使用方式:
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
# 1. 将原始 usage 映射为标准格式
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
# 2. 使用计费计算器计算费用
calculator = BillingCalculator(template="openai")
result = calculator.calculate(usage, prices)
# 3. 获取费用明细
print(result.total_cost)
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
"""
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
__all__ = [
# 数据模型
"BillingDimension",
"BillingUnit",
"CostBreakdown",
"StandardizedUsage",
# 模板
"BillingTemplates",
"BILLING_TEMPLATE_REGISTRY",
# 计算器
"BillingCalculator",
"calculate_request_cost",
# 映射器
"UsageMapper",
"map_usage",
"map_usage_from_response",
]


@@ -0,0 +1,339 @@
"""
计费计算器
配置驱动的计费计算,支持:
- 固定价格计费
- 阶梯计费
- 多种计费模板
- 自定义计费维度
"""
from typing import Any, Dict, List, Optional, Tuple
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import (
BILLING_TEMPLATE_REGISTRY,
BillingTemplates,
get_template,
)
class BillingCalculator:
"""
配置驱动的计费计算器
支持多种计费模式:
- 使用预定义模板(claude、openai、doubao 等)
- 自定义计费维度
- 阶梯计费
示例:
# 使用模板
calculator = BillingCalculator(template="openai")
# 自定义维度
calculator = BillingCalculator(dimensions=[
BillingDimension(name="input", usage_field="input_tokens", price_field="input_price_per_1m"),
BillingDimension(name="output", usage_field="output_tokens", price_field="output_price_per_1m"),
])
# 计算费用
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
result = calculator.calculate(usage, prices)
"""
def __init__(
self,
dimensions: Optional[List[BillingDimension]] = None,
template: Optional[str] = None,
):
"""
初始化计费计算器
Args:
dimensions: 自定义计费维度列表(优先级高于模板)
template: 使用预定义模板名称 ("claude", "openai", "doubao", "per_request" 等)
"""
if dimensions:
self.dimensions = dimensions
elif template:
self.dimensions = get_template(template)
else:
# 默认使用 Claude 模板(向后兼容)
self.dimensions = BillingTemplates.CLAUDE_STANDARD
self.template_name = template
def calculate(
self,
usage: StandardizedUsage,
prices: Dict[str, float],
tiered_pricing: Optional[Dict[str, Any]] = None,
cache_ttl_minutes: Optional[int] = None,
total_input_context: Optional[int] = None,
) -> CostBreakdown:
"""
计算费用
Args:
usage: 标准化的 usage 数据
prices: 价格配置 {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0, ...}
tiered_pricing: 阶梯计费配置(可选)
cache_ttl_minutes: 缓存 TTL 分钟数(用于 TTL 差异化定价)
total_input_context: 总输入上下文(用于阶梯判定,可选)
如果提供,将使用该值进行阶梯判定;否则使用默认计算逻辑
Returns:
费用明细 (CostBreakdown)
"""
result = CostBreakdown()
# 处理阶梯计费
effective_prices = prices.copy()
if tiered_pricing and tiered_pricing.get("tiers"):
tier, tier_index = self._get_tier(usage, tiered_pricing, total_input_context)
if tier:
result.tier_index = tier_index
# 阶梯价格覆盖默认价格
for key, value in tier.items():
if key not in ("up_to", "cache_ttl_pricing") and value is not None:
effective_prices[key] = value
# 处理 TTL 差异化定价
if cache_ttl_minutes is not None:
ttl_price = self._get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
if ttl_price is not None:
effective_prices["cache_read_price_per_1m"] = ttl_price
# 记录使用的价格
result.effective_prices = effective_prices.copy()
# 计算各维度费用
total = 0.0
for dim in self.dimensions:
usage_value = usage.get(dim.usage_field, 0)
price = effective_prices.get(dim.price_field, dim.default_price)
if usage_value and price:
cost = dim.calculate(usage_value, price)
result.costs[dim.name] = cost
total += cost
result.total_cost = total
return result
def _get_tier(
self,
usage: StandardizedUsage,
tiered_pricing: Dict[str, Any],
total_input_context: Optional[int] = None,
) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
"""
确定价格阶梯
Args:
usage: usage 数据
tiered_pricing: 阶梯配置 {"tiers": [...]}
total_input_context: 预计算的总输入上下文(可选)
Returns:
(匹配的阶梯配置, 阶梯索引)
"""
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None, None
# 使用传入的 total_input_context或者默认计算
if total_input_context is None:
total_input_context = self._compute_total_input_context(usage)
for i, tier in enumerate(tiers):
up_to = tier.get("up_to")
# up_to 为 None 表示无上限(最后一个阶梯)
if up_to is None or total_input_context <= up_to:
return tier, i
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
return tiers[-1], len(tiers) - 1
def _compute_total_input_context(self, usage: StandardizedUsage) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认: input_tokens + cache_read_tokens
Args:
usage: usage 数据
Returns:
总输入 token 数
"""
return usage.input_tokens + usage.cache_read_tokens
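The tier-selection rule implemented by `_get_tier` can be isolated into a small sketch (function name and sample tier values are illustrative):

```python
from typing import Any, Dict, List, Tuple

def pick_tier(tiers: List[Dict[str, Any]], total_input_context: int) -> Tuple[Dict[str, Any], int]:
    # First tier whose up_to covers the context wins; up_to=None means unbounded.
    for i, tier in enumerate(tiers):
        up_to = tier.get("up_to")
        if up_to is None or total_input_context <= up_to:
            return tier, i
    # All tiers bounded and exceeded: fall back to the last tier.
    return tiers[-1], len(tiers) - 1
```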
def _get_cache_read_price_for_ttl(
self,
tier: Dict[str, Any],
cache_ttl_minutes: int,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
某些厂商(如 Claude对不同 TTL 的缓存有不同定价。
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格,如果没有 TTL 差异化配置返回 None
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if not ttl_pricing:
return None
# 找到匹配或最接近的 TTL 价格
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
price = ttl_config.get("cache_read_price_per_1m")
return float(price) if price is not None else None
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
price = ttl_pricing[-1].get("cache_read_price_per_1m")
return float(price) if price is not None else None
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BillingCalculator":
"""
从配置创建计费计算器
Config 格式:
{
"template": "claude", # 或 "openai", "doubao", "per_request"
# 或者自定义维度:
"dimensions": [
{"name": "input", "usage_field": "input_tokens", "price_field": "input_price_per_1m"},
...
]
}
Args:
config: 配置字典
Returns:
BillingCalculator 实例
"""
if "dimensions" in config:
dimensions = [BillingDimension.from_dict(d) for d in config["dimensions"]]
return cls(dimensions=dimensions)
return cls(template=config.get("template", "claude"))
def get_dimension_names(self) -> List[str]:
"""获取所有计费维度名称"""
return [dim.name for dim in self.dimensions]
def get_required_price_fields(self) -> List[str]:
"""获取所需的价格字段名称"""
return [dim.price_field for dim in self.dimensions]
def get_required_usage_fields(self) -> List[str]:
"""获取所需的 usage 字段名称"""
return [dim.usage_field for dim in self.dimensions]
def calculate_request_cost(
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[Dict[str, Any]] = None,
cache_ttl_minutes: Optional[int] = None,
total_input_context: Optional[int] = None,
billing_template: str = "claude",
) -> Dict[str, Any]:
"""
计算请求成本的便捷函数
封装了 BillingCalculator 的调用逻辑,返回兼容旧格式的字典。
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens)
output_price_per_1m: 输出价格(每 1M tokens)
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
total_input_context: 总输入上下文(用于阶梯判定)
billing_template: 计费模板名称
Returns:
包含各项成本的字典:
{
"input_cost": float,
"output_cost": float,
"cache_creation_cost": float,
"cache_read_cost": float,
"cache_cost": float,
"request_cost": float,
"total_cost": float,
"tier_index": Optional[int],
}
"""
# 构建标准化 usage
usage = StandardizedUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_tokens=cache_creation_input_tokens,
cache_read_tokens=cache_read_input_tokens,
request_count=1,
)
# 构建价格配置
prices: Dict[str, float] = {
"input_price_per_1m": input_price_per_1m,
"output_price_per_1m": output_price_per_1m,
}
if cache_creation_price_per_1m is not None:
prices["cache_creation_price_per_1m"] = cache_creation_price_per_1m
if cache_read_price_per_1m is not None:
prices["cache_read_price_per_1m"] = cache_read_price_per_1m
if price_per_request is not None:
prices["price_per_request"] = price_per_request
# 使用 BillingCalculator 计算
calculator = BillingCalculator(template=billing_template)
result = calculator.calculate(
usage, prices, tiered_pricing, cache_ttl_minutes, total_input_context
)
# 返回兼容旧格式的字典
return {
"input_cost": result.input_cost,
"output_cost": result.output_cost,
"cache_creation_cost": result.cache_creation_cost,
"cache_read_cost": result.cache_read_cost,
"cache_cost": result.cache_cost,
"request_cost": result.request_cost,
"total_cost": result.total_cost,
"tier_index": result.tier_index,
}


@@ -0,0 +1,281 @@
"""
计费模块数据模型
定义计费相关的核心数据结构:
- BillingUnit: 计费单位枚举
- BillingDimension: 计费维度定义
- StandardizedUsage: 标准化的 usage 数据
- CostBreakdown: 计费明细结果
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
class BillingUnit(str, Enum):
"""计费单位"""
PER_1M_TOKENS = "per_1m_tokens" # 每百万 token
PER_1M_TOKENS_HOUR = "per_1m_tokens_hour" # 每百万 token 每小时(豆包缓存存储)
PER_REQUEST = "per_request" # 每次请求
FIXED = "fixed" # 固定费用
@dataclass
class BillingDimension:
"""
计费维度定义
每个维度描述一种计费方式,例如:
- 输入 token 计费
- 输出 token 计费
- 缓存读取计费
- 按次计费
"""
name: str # 维度名称,如 "input", "output", "cache_read"
usage_field: str # 从 usage 中取值的字段名
price_field: str # 价格配置中的字段名
unit: BillingUnit = BillingUnit.PER_1M_TOKENS # 计费单位
default_price: float = 0.0 # 默认价格(当价格配置中没有时使用)
def calculate(self, usage_value: float, price: float) -> float:
"""
计算该维度的费用
Args:
usage_value: 使用量数值
price: 单价
Returns:
计算后的费用
"""
if usage_value <= 0 or price <= 0:
return 0.0
if self.unit == BillingUnit.PER_1M_TOKENS:
return (usage_value / 1_000_000) * price
elif self.unit == BillingUnit.PER_1M_TOKENS_HOUR:
# 缓存存储按 token 数 * 小时数计费
return (usage_value / 1_000_000) * price
elif self.unit == BillingUnit.PER_REQUEST:
return usage_value * price
elif self.unit == BillingUnit.FIXED:
return price
return 0.0
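For the dominant `PER_1M_TOKENS` unit, the arithmetic in `calculate` is a single linear scaling; a standalone sketch (helper name assumed):

```python
def per_1m_cost(tokens: int, price_per_1m: float) -> float:
    # PER_1M_TOKENS: cost scales linearly with token count.
    return (tokens / 1_000_000) * price_per_1m
```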
def to_dict(self) -> Dict[str, Any]:
"""转换为字典(用于序列化)"""
return {
"name": self.name,
"usage_field": self.usage_field,
"price_field": self.price_field,
"unit": self.unit.value,
"default_price": self.default_price,
}
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BillingDimension":
        """Create an instance from a dict."""
return cls(
name=data["name"],
usage_field=data["usage_field"],
price_field=data["price_field"],
unit=BillingUnit(data.get("unit", "per_1m_tokens")),
default_price=data.get("default_price", 0.0),
)
@dataclass
class StandardizedUsage:
    """
    Normalized usage data.

    Unifies usage payloads from different API formats into one standard
    shape for billing calculations.
    """

    # basic token counts
    input_tokens: int = 0
    output_tokens: int = 0
    # cache-related counts
    cache_creation_tokens: int = 0  # Claude: cache creation
    cache_read_tokens: int = 0  # Claude/OpenAI/Doubao: cache read / cache hit
    # special token types
    reasoning_tokens: int = 0  # o1/Doubao: reasoning tokens (usually included in output; tracked separately for analysis)
    # time-based usage (for per-hour billing)
    cache_storage_token_hours: float = 0.0  # Doubao: cache storage, token count * hours
    # request count (for per-request billing)
    request_count: int = 1
    # extension fields (extra dimensions that may be needed later)
    extra: Dict[str, Any] = field(default_factory=dict)
    def get(self, field_name: str, default: Any = 0) -> Any:
        """
        Generic field getter.

        Supports both standard fields and extension fields.

        Args:
            field_name: field name
            default: default value

        Returns:
            The field value.
        """
        if hasattr(self, field_name):
            value = getattr(self, field_name)
            # never hand back the extra dict itself
            if field_name != "extra":
                return value
        return self.extra.get(field_name, default)
    def set(self, field_name: str, value: Any) -> None:
        """
        Generic field setter.

        Args:
            field_name: field name
            value: field value
        """
        if hasattr(self, field_name) and field_name != "extra":
            setattr(self, field_name, value)
        else:
            self.extra[field_name] = value
    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict."""
result: Dict[str, Any] = {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"cache_creation_tokens": self.cache_creation_tokens,
"cache_read_tokens": self.cache_read_tokens,
"reasoning_tokens": self.reasoning_tokens,
"cache_storage_token_hours": self.cache_storage_token_hours,
"request_count": self.request_count,
}
if self.extra:
result["extra"] = self.extra
return result
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StandardizedUsage":
        """Create an instance from a dict."""
        # use get() rather than pop() so the caller's dict is not mutated
        extra = data.get("extra", {})
        # keep only known fields
known_fields = {
"input_tokens",
"output_tokens",
"cache_creation_tokens",
"cache_read_tokens",
"reasoning_tokens",
"cache_storage_token_hours",
"request_count",
}
filtered = {k: v for k, v in data.items() if k in known_fields}
return cls(**filtered, extra=extra)
@dataclass
class CostBreakdown:
    """
    Itemized billing result.

    Holds per-dimension costs plus the total.
    """

    # per-dimension costs, e.g. {"input": 0.01, "output": 0.02, "cache_read": 0.001, ...}
    costs: Dict[str, float] = field(default_factory=dict)
    # total cost
    total_cost: float = 0.0
    # index of the matched tier (when tiered pricing is used)
    tier_index: Optional[int] = None
    # currency
    currency: str = "USD"
    # prices actually applied (for logging and auditing)
    effective_prices: Dict[str, float] = field(default_factory=dict)

    # =========================================================================
    # Legacy-compatible properties (for gradual migration)
    # =========================================================================
    @property
    def input_cost(self) -> float:
        """Input cost."""
        return self.costs.get("input", 0.0)

    @property
    def output_cost(self) -> float:
        """Output cost."""
        return self.costs.get("output", 0.0)

    @property
    def cache_creation_cost(self) -> float:
        """Cache-creation cost."""
        return self.costs.get("cache_creation", 0.0)

    @property
    def cache_read_cost(self) -> float:
        """Cache-read cost."""
        return self.costs.get("cache_read", 0.0)

    @property
    def cache_cost(self) -> float:
        """Total cache cost (creation + read)."""
        return self.cache_creation_cost + self.cache_read_cost

    @property
    def request_cost(self) -> float:
        """Per-request cost."""
        return self.costs.get("request", 0.0)

    @property
    def cache_storage_cost(self) -> float:
        """Cache-storage cost (Doubao and similar)."""
        return self.costs.get("cache_storage", 0.0)
    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict."""
return {
"costs": self.costs,
"total_cost": self.total_cost,
"tier_index": self.tier_index,
"currency": self.currency,
"effective_prices": self.effective_prices,
            # legacy-compat fields
"input_cost": self.input_cost,
"output_cost": self.output_cost,
"cache_creation_cost": self.cache_creation_cost,
"cache_read_cost": self.cache_read_cost,
"cache_cost": self.cache_cost,
"request_cost": self.request_cost,
}
    def to_legacy_tuple(self) -> tuple:
        """
        Convert to the legacy tuple format.

        Returns:
            (input_cost, output_cost, cache_creation_cost, cache_read_cost,
             cache_cost, request_cost, total_cost, tier_index)
        """
return (
self.input_cost,
self.output_cost,
self.cache_creation_cost,
self.cache_read_cost,
self.cache_cost,
self.request_cost,
self.total_cost,
self.tier_index,
)
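The per-1M-token arithmetic in `BillingDimension.calculate` can be sanity-checked in isolation. A minimal sketch, copying a trimmed `BillingDimension` out of the diff (the enum is reduced to the two units exercised here):

```python
from dataclasses import dataclass
from enum import Enum

class BillingUnit(str, Enum):
    PER_1M_TOKENS = "per_1m_tokens"
    PER_REQUEST = "per_request"

@dataclass
class BillingDimension:
    name: str
    usage_field: str
    price_field: str
    unit: BillingUnit = BillingUnit.PER_1M_TOKENS

    def calculate(self, usage_value: float, price: float) -> float:
        # zero or negative usage/price never produces a charge
        if usage_value <= 0 or price <= 0:
            return 0.0
        if self.unit == BillingUnit.PER_1M_TOKENS:
            return (usage_value / 1_000_000) * price
        return usage_value * price  # PER_REQUEST

# 100k input tokens at $3 per 1M tokens -> $0.30
dim = BillingDimension("input", "input_tokens", "input_price_per_1m")
cost = dim.calculate(100_000, 3.0)
print(cost)
```

Note the guard clause: a dimension with no configured price silently contributes 0, which is what lets every template list optional dimensions such as `cache_read` unconditionally.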

View File

@@ -0,0 +1,213 @@
"""
Predefined billing templates.

Ready-made billing configurations for common providers, so they do not
have to be configured repeatedly:
- CLAUDE_STANDARD: Claude/Anthropic standard billing
- OPENAI_STANDARD: OpenAI standard billing
- DOUBAO_STANDARD: Doubao billing (includes cache storage)
- GEMINI_STANDARD: Gemini standard billing
- PER_REQUEST: per-request billing
"""
from typing import Dict, List, Optional
from src.services.billing.models import BillingDimension, BillingUnit
class BillingTemplates:
    """Predefined billing templates."""

    # =========================================================================
    # Claude/Anthropic standard billing
    # - input tokens
    # - output tokens
    # - cache creation (charged on write, roughly 1.25x the input price)
    # - cache read (roughly 0.1x the input price)
    # =========================================================================
CLAUDE_STANDARD: List[BillingDimension] = [
BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
),
BillingDimension(
name="output",
usage_field="output_tokens",
price_field="output_price_per_1m",
),
BillingDimension(
name="cache_creation",
usage_field="cache_creation_tokens",
price_field="cache_creation_price_per_1m",
),
BillingDimension(
name="cache_read",
usage_field="cache_read_tokens",
price_field="cache_read_price_per_1m",
),
]
    # =========================================================================
    # OpenAI standard billing
    # - input tokens
    # - output tokens
    # - cache read (supported by some models; no cache-creation fee)
    # =========================================================================
OPENAI_STANDARD: List[BillingDimension] = [
BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
),
BillingDimension(
name="output",
usage_field="output_tokens",
price_field="output_price_per_1m",
),
BillingDimension(
name="cache_read",
usage_field="cache_read_tokens",
price_field="cache_read_price_per_1m",
),
]
    # =========================================================================
    # Doubao billing
    # - inference input (input_tokens)
    # - inference output (output_tokens)
    # - cache hit (cache_read_tokens), similar to Claude's cache read
    # - cache storage (cache_storage_token_hours), billed as token count * hours stored
    #
    # Note: Doubao cache creation is free, but storage is billed by time.
    # =========================================================================
DOUBAO_STANDARD: List[BillingDimension] = [
BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
),
BillingDimension(
name="output",
usage_field="output_tokens",
price_field="output_price_per_1m",
),
BillingDimension(
name="cache_read",
usage_field="cache_read_tokens",
price_field="cache_read_price_per_1m",
),
BillingDimension(
name="cache_storage",
usage_field="cache_storage_token_hours",
price_field="cache_storage_price_per_1m_hour",
unit=BillingUnit.PER_1M_TOKENS_HOUR,
),
]
    # =========================================================================
    # Gemini standard billing
    # - input tokens
    # - output tokens
    # - cache read
    # =========================================================================
GEMINI_STANDARD: List[BillingDimension] = [
BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
),
BillingDimension(
name="output",
usage_field="output_tokens",
price_field="output_price_per_1m",
),
BillingDimension(
name="cache_read",
usage_field="cache_read_tokens",
price_field="cache_read_price_per_1m",
),
]
    # =========================================================================
    # Per-request billing
    # - for certain image-generation models, special APIs, etc.
    # - billed purely by request count, never by token
    # =========================================================================
PER_REQUEST: List[BillingDimension] = [
BillingDimension(
name="request",
usage_field="request_count",
price_field="price_per_request",
unit=BillingUnit.PER_REQUEST,
),
]
    # =========================================================================
    # Hybrid billing (per-request + per-token)
    # - some models carry both a fixed fee and a token fee
    # =========================================================================
HYBRID_STANDARD: List[BillingDimension] = [
BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
),
BillingDimension(
name="output",
usage_field="output_tokens",
price_field="output_price_per_1m",
),
BillingDimension(
name="request",
usage_field="request_count",
price_field="price_per_request",
unit=BillingUnit.PER_REQUEST,
),
]
# =========================================================================
# Template registry
# =========================================================================
BILLING_TEMPLATE_REGISTRY: Dict[str, List[BillingDimension]] = {
    # by provider name
"claude": BillingTemplates.CLAUDE_STANDARD,
"anthropic": BillingTemplates.CLAUDE_STANDARD,
"openai": BillingTemplates.OPENAI_STANDARD,
"doubao": BillingTemplates.DOUBAO_STANDARD,
"bytedance": BillingTemplates.DOUBAO_STANDARD,
"gemini": BillingTemplates.GEMINI_STANDARD,
"google": BillingTemplates.GEMINI_STANDARD,
    # by billing mode
"per_request": BillingTemplates.PER_REQUEST,
"hybrid": BillingTemplates.HYBRID_STANDARD,
    # default
"default": BillingTemplates.CLAUDE_STANDARD,
}
def get_template(name: Optional[str]) -> List[BillingDimension]:
    """
    Look up a billing template.

    Args:
        name: template name (case-insensitive)

    Returns:
        The list of billing dimensions.
    """
if not name:
return BILLING_TEMPLATE_REGISTRY["default"]
template = BILLING_TEMPLATE_REGISTRY.get(name.lower())
if template is None:
available = ", ".join(sorted(BILLING_TEMPLATE_REGISTRY.keys()))
raise ValueError(f"Unknown billing template: {name!r}. Available: {available}")
return template
def list_templates() -> List[str]:
    """List all available template names."""
return list(BILLING_TEMPLATE_REGISTRY.keys())
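The registry lookup in `get_template` (case-insensitive names, default fallback, error message listing the available keys) can be sketched standalone; the dimension lists are replaced here with placeholder strings, since the real templates live in `src/services/billing`:

```python
from typing import Dict, Optional

# trimmed registry: template bodies replaced by tag strings for illustration
REGISTRY: Dict[str, str] = {
    "claude": "CLAUDE_STANDARD",
    "anthropic": "CLAUDE_STANDARD",
    "openai": "OPENAI_STANDARD",
    "per_request": "PER_REQUEST",
    "default": "CLAUDE_STANDARD",
}

def get_template(name: Optional[str]) -> str:
    # empty/None falls back to the default template
    if not name:
        return REGISTRY["default"]
    template = REGISTRY.get(name.lower())
    if template is None:
        available = ", ".join(sorted(REGISTRY))
        raise ValueError(f"Unknown billing template: {name!r}. Available: {available}")
    return template

print(get_template("Anthropic"))  # case-insensitive alias lookup
print(get_template(None))         # falls back to the default
```

Raising on an unknown name (rather than silently falling back) surfaces adapter misconfiguration at startup instead of producing zero-cost billing.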

View File

@@ -0,0 +1,267 @@
"""
Usage field mapper.

Maps raw usage payloads from different API formats into the standardized
format. Supported formats:
- OPENAI / OPENAI_CLI: OpenAI Chat Completions API
- CLAUDE / CLAUDE_CLI: Anthropic Messages API
- GEMINI / GEMINI_CLI: Google Gemini API
"""
from typing import Any, Dict, Optional
from src.services.billing.models import StandardizedUsage
class UsageMapper:
    """
    Usage field mapper.

    Maps usage payloads from different API formats into StandardizedUsage.

    Examples:
        # OpenAI format
        raw_usage = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "prompt_tokens_details": {"cached_tokens": 20},
            "completion_tokens_details": {"reasoning_tokens": 10}
        }
        usage = UsageMapper.map(raw_usage, "OPENAI")

        # Claude format
        raw_usage = {
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_creation_input_tokens": 30,
            "cache_read_input_tokens": 20
        }
        usage = UsageMapper.map(raw_usage, "CLAUDE")
    """

    # =========================================================================
    # Field-mapping configuration
    # Format: "source_path" -> "target_field"
    # source_path supports dot-separated nested paths
    # =========================================================================

    # OpenAI field mapping
OPENAI_MAPPING: Dict[str, str] = {
"prompt_tokens": "input_tokens",
"completion_tokens": "output_tokens",
"prompt_tokens_details.cached_tokens": "cache_read_tokens",
"completion_tokens_details.reasoning_tokens": "reasoning_tokens",
}
    # Claude field mapping
CLAUDE_MAPPING: Dict[str, str] = {
"input_tokens": "input_tokens",
"output_tokens": "output_tokens",
"cache_creation_input_tokens": "cache_creation_tokens",
"cache_read_input_tokens": "cache_read_tokens",
}
    # Gemini field mapping
GEMINI_MAPPING: Dict[str, str] = {
"promptTokenCount": "input_tokens",
"candidatesTokenCount": "output_tokens",
"cachedContentTokenCount": "cache_read_tokens",
        # Gemini usageMetadata form
"usageMetadata.promptTokenCount": "input_tokens",
"usageMetadata.candidatesTokenCount": "output_tokens",
"usageMetadata.cachedContentTokenCount": "cache_read_tokens",
}
    # format name -> field mapping
FORMAT_MAPPINGS: Dict[str, Dict[str, str]] = {
"OPENAI": OPENAI_MAPPING,
"OPENAI_CLI": OPENAI_MAPPING,
"CLAUDE": CLAUDE_MAPPING,
"CLAUDE_CLI": CLAUDE_MAPPING,
"GEMINI": GEMINI_MAPPING,
"GEMINI_CLI": GEMINI_MAPPING,
}
@classmethod
def map(
cls,
raw_usage: Dict[str, Any],
api_format: str,
extra_mapping: Optional[Dict[str, str]] = None,
) -> StandardizedUsage:
        """
        Map a raw usage payload into the standardized format.

        Args:
            raw_usage: raw usage dict
            api_format: API format ("OPENAI", "CLAUDE", "GEMINI", ...)
            extra_mapping: extra field mappings (for custom extensions)

        Returns:
            A standardized usage object.
        """
        if not raw_usage:
            return StandardizedUsage()
        # pick the field mapping for this format
        mapping = cls._get_mapping(api_format)
        # merge in any extra mappings
        if extra_mapping:
            mapping = {**mapping, **extra_mapping}
        result = StandardizedUsage()
        # apply the mapping
        for source_path, target_field in mapping.items():
            value = cls._get_nested_value(raw_usage, source_path)
            if value is not None:
                result.set(target_field, value)
return result
@classmethod
def map_from_response(
cls,
response: Dict[str, Any],
api_format: str,
) -> StandardizedUsage:
        """
        Extract and map the usage section from a full response.

        The usage payload lives in a different place per API format:
        - OpenAI: response["usage"]
        - Claude: response["usage"], or inside message_delta
        - Gemini: response["usageMetadata"]

        Args:
            response: full API response
            api_format: API format

        Returns:
            A standardized usage object.
        """
        format_upper = api_format.upper() if api_format else ""
        # extract the usage section
        usage_data: Dict[str, Any] = {}
        if format_upper.startswith("GEMINI"):
            # Gemini: usageMetadata
            usage_data = response.get("usageMetadata", {})
            if not usage_data:
                # fall back to the first candidate
                candidates = response.get("candidates", [])
                if candidates:
                    usage_data = candidates[0].get("usageMetadata", {})
        else:
            # OpenAI/Claude: usage
            usage_data = response.get("usage", {})
return cls.map(usage_data, api_format)
    @classmethod
    def _get_mapping(cls, api_format: str) -> Dict[str, str]:
        """Return the field mapping for the given format."""
        if not api_format:
            return cls.CLAUDE_MAPPING
        format_upper = api_format.upper()
        # exact match
        if format_upper in cls.FORMAT_MAPPINGS:
            return cls.FORMAT_MAPPINGS[format_upper]
        # prefix match
        for key, mapping in cls.FORMAT_MAPPINGS.items():
            if format_upper.startswith(key.split("_")[0]):
                return mapping
        # default to the Claude mapping
        return cls.CLAUDE_MAPPING
@classmethod
def _get_nested_value(cls, data: Dict[str, Any], path: str) -> Any:
        """
        Read a nested field value.

        Supports dot-separated paths such as "prompt_tokens_details.cached_tokens".

        Args:
            data: data dict
            path: field path

        Returns:
            The field value, or None if it does not exist.
        """
keys = path.split(".")
value: Any = data
for key in keys:
if isinstance(value, dict):
value = value.get(key)
if value is None:
return None
else:
return None
return value
@classmethod
def register_format(cls, format_name: str, mapping: Dict[str, str]) -> None:
        """
        Register a new format mapping.

        Args:
            format_name: format name (uppercased automatically)
            mapping: field mapping
        """
cls.FORMAT_MAPPINGS[format_name.upper()] = mapping
@classmethod
def get_supported_formats(cls) -> list:
        """Return all supported formats."""
return list(cls.FORMAT_MAPPINGS.keys())
# =========================================================================
# Convenience functions
# =========================================================================
def map_usage(
raw_usage: Dict[str, Any],
api_format: str,
) -> StandardizedUsage:
    """
    Convenience wrapper: map a raw usage payload to the standardized format.

    Args:
        raw_usage: raw usage dict
        api_format: API format

    Returns:
        A StandardizedUsage object.
    """
return UsageMapper.map(raw_usage, api_format)
def map_usage_from_response(
response: Dict[str, Any],
api_format: str,
) -> StandardizedUsage:
    """
    Convenience wrapper: extract and map usage from a response.

    Args:
        response: API response
        api_format: API format

    Returns:
        A StandardizedUsage object.
    """
return UsageMapper.map_from_response(response, api_format)
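The dot-path mapping that `UsageMapper.map` performs can be sketched standalone; this sketch flattens `StandardizedUsage` into a plain dict and reuses the `OPENAI_MAPPING` table shown in the diff:

```python
from typing import Any, Dict

# mapping copied from the diff's OPENAI_MAPPING
OPENAI_MAPPING = {
    "prompt_tokens": "input_tokens",
    "completion_tokens": "output_tokens",
    "prompt_tokens_details.cached_tokens": "cache_read_tokens",
    "completion_tokens_details.reasoning_tokens": "reasoning_tokens",
}

def get_nested(data: Dict[str, Any], path: str) -> Any:
    # walk dot-separated keys, returning None on any miss
    value: Any = data
    for key in path.split("."):
        if not isinstance(value, dict):
            return None
        value = value.get(key)
        if value is None:
            return None
    return value

def map_usage(raw: Dict[str, Any]) -> Dict[str, int]:
    result: Dict[str, int] = {}
    for source_path, target in OPENAI_MAPPING.items():
        value = get_nested(raw, source_path)
        if value is not None:
            result[target] = value
    return result

raw = {
    "prompt_tokens": 100,
    "completion_tokens": 50,
    "prompt_tokens_details": {"cached_tokens": 20},
}
print(map_usage(raw))
# {'input_tokens': 100, 'output_tokens': 50, 'cache_read_tokens': 20}
```

Missing source paths simply contribute nothing, so one mapping table serves responses with and without the optional `*_details` blocks.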

View File

@@ -148,6 +148,8 @@ class GlobalModelService:
        Delete a GlobalModel.

        Default behavior: cascade-delete all associated provider model implementations.
        Note: allowed_models references on API Keys and Users are NOT cleaned up;
        keeping the stale references lets users see the model as "invalid" in the
        frontend, so it can be cleaned up manually or a same-named model rebuilt.
        """
global_model = GlobalModelService.get_global_model(db, global_model_id)

View File

@@ -237,7 +237,7 @@ class ErrorClassifier:
result["reason"] = str(data.get("reason", data.get("code", "")))
except (json.JSONDecodeError, TypeError, KeyError):
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
result["message"] = error_text
return result
@@ -323,8 +323,8 @@ class ErrorClassifier:
if parts:
return ": ".join(parts) if len(parts) > 1 else parts[0]
        # unparseable: return the raw text (truncated)
        return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
        # unparseable: return the raw text
        return parsed["raw"]
def classify(
self,
@@ -484,11 +484,15 @@ class ErrorClassifier:
return ProviderNotAvailableException(
message=detailed_message,
provider_name=provider_name,
upstream_status=status,
upstream_response=error_response_text,
)
return ProviderNotAvailableException(
message=detailed_message,
provider_name=provider_name,
upstream_status=status,
upstream_response=error_response_text,
)
async def handle_http_error(
@@ -532,10 +536,12 @@ class ErrorClassifier:
provider_name = str(provider.name)
        # try to read the error-response body
        error_response_text = None
        # prefer the upstream_response attribute attached by the handler (response.text may be empty for streaming requests)
error_response_text = getattr(http_error, "upstream_response", None)
if not error_response_text:
try:
if http_error.response and hasattr(http_error.response, "text"):
                    error_response_text = http_error.response.text[:1000]  # cap the length
error_response_text = http_error.response.text
except Exception:
pass

View File

@@ -30,6 +30,7 @@ from redis import Redis
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.core.error_utils import extract_error_message
from src.core.exceptions import (
ConcurrencyLimitError,
ProviderNotAvailableException,
@@ -401,7 +402,7 @@ class FallbackOrchestrator:
db=self.db,
candidate_id=candidate_record_id,
error_type="HTTPStatusError",
error_message=f"HTTP {status_code}: {str(cause)}",
error_message=extract_error_message(cause, status_code),
status_code=status_code,
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
@@ -425,31 +426,22 @@ class FallbackOrchestrator:
attempt=attempt,
max_attempts=max_attempts,
)
            # str(cause) can be empty (e.g. httpx timeout exceptions); fall back to repr()
            error_msg = str(cause) or repr(cause)
            # for ProviderNotAvailableException, append the upstream response
            if hasattr(cause, "upstream_response") and cause.upstream_response:
                error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,
error_type=type(cause).__name__,
error_message=error_msg,
error_message=extract_error_message(cause),
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
)
return "continue" if has_retry_left else "break"
        # unknown error: record the failure and re-raise
        error_msg = str(cause) or repr(cause)
        # for ProviderNotAvailableException, append the upstream response
        if hasattr(cause, "upstream_response") and cause.upstream_response:
            error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,
error_type=type(cause).__name__,
error_message=error_msg,
error_message=extract_error_message(cause),
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
)
@@ -543,7 +535,9 @@ class FallbackOrchestrator:
raise last_error
        # every combination has been tried; all of them failed
self._raise_all_failed_exception(request_id, max_attempts, last_candidate, model_name, api_format_enum)
self._raise_all_failed_exception(
request_id, max_attempts, last_candidate, model_name, api_format_enum, last_error
)
async def _try_candidate_with_retries(
self,
@@ -565,6 +559,7 @@ class FallbackOrchestrator:
provider = candidate.provider
endpoint = candidate.endpoint
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
last_error: Optional[Exception] = None
for retry_index in range(max_retries_for_candidate):
attempt_counter += 1
@@ -599,6 +594,7 @@ class FallbackOrchestrator:
return {"success": True, "response": response}
except ExecutionError as exec_err:
last_error = exec_err.cause
action = await self._handle_candidate_error(
exec_err=exec_err,
candidate=candidate,
@@ -630,6 +626,7 @@ class FallbackOrchestrator:
"success": False,
"attempt_counter": attempt_counter,
"max_attempts": max_attempts,
"error": last_error,
}
def _attach_metadata_to_error(
@@ -678,6 +675,7 @@ class FallbackOrchestrator:
last_candidate: Optional[ProviderCandidate],
model_name: str,
api_format_enum: APIFormat,
last_error: Optional[Exception] = None,
) -> NoReturn:
        """Raise when every combination has failed."""
logger.error(f" [{request_id}] 所有 {max_attempts} 个组合均失败")
@@ -693,9 +691,38 @@ class FallbackOrchestrator:
"api_format": api_format_enum.value,
}
        # extract the upstream error response
        upstream_status: Optional[int] = None
        upstream_response: Optional[str] = None
        if last_error:
            # extract from httpx.HTTPStatusError
            if isinstance(last_error, httpx.HTTPStatusError):
                upstream_status = last_error.response.status_code
                # prefer our attached upstream_response attribute
                # (response.text may be empty once the stream has been consumed)
                upstream_response = getattr(last_error, "upstream_response", None)
                if not upstream_response:
                    try:
                        upstream_response = last_error.response.text
                    except Exception:
                        pass
            # extract from other exception attributes (e.g. ProviderNotAvailableException)
            else:
                upstream_status = getattr(last_error, "upstream_status", None)
                upstream_response = getattr(last_error, "upstream_response", None)
            # if the response is empty or unusable, fall back to the exception's string form
            if (
                not upstream_response
                or not upstream_response.strip()
                or upstream_response.startswith("Unable to read")
            ):
                upstream_response = str(last_error)
raise ProviderNotAvailableException(
            f"所有Provider均不可用,已尝试{max_attempts}个组合",
request_metadata=request_metadata,
upstream_status=upstream_status,
upstream_response=upstream_response,
)
async def execute_with_fallback(

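The `getattr`-based extraction of upstream error details used above can be exercised standalone; `FakeUpstreamError` below is a hypothetical stand-in for `ProviderNotAvailableException`, used only to show the attribute contract:

```python
from typing import Optional, Tuple

class FakeUpstreamError(Exception):
    """Stand-in for ProviderNotAvailableException (hypothetical, for this sketch)."""
    def __init__(self, msg: str, upstream_status: Optional[int] = None,
                 upstream_response: Optional[str] = None):
        super().__init__(msg)
        self.upstream_status = upstream_status
        self.upstream_response = upstream_response

def extract_upstream(err: Exception) -> Tuple[Optional[int], str]:
    # mirror the diff: read the attached attributes, fall back to str(err)
    status = getattr(err, "upstream_status", None)
    response = getattr(err, "upstream_response", None)
    if not response or not response.strip() or response.startswith("Unable to read"):
        response = str(err)
    return status, response

print(extract_upstream(FakeUpstreamError("boom", 429, '{"error":"rate_limited"}')))
# (429, '{"error":"rate_limited"}')
print(extract_upstream(ValueError("plain error")))
# (None, 'plain error')
```

Using `getattr` with a default keeps the extraction safe for arbitrary exception types, which matters because `last_error` may be anything the candidate loop caught.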
View File

@@ -1,14 +1,16 @@
"""
Adaptive concurrency adjuster - concurrency-limit tuning based on sliding-window utilization
Adaptive concurrency adjuster - concurrency-limit tuning based on boundary memory

Core improvements (over the old "sustained high utilization" scheme):
- sliding-window sampling tolerates concurrency fluctuations
- decisions use the ratio of high-utilization samples in the window, not consecutive highs
- adds probe scale-ups: after a long stable period, actively try raising the limit

Core algorithm: boundary memory + gradual probing
- on a 429, record the boundary (last_concurrent_peak); that is the real upper limit
- scale-down policy: new limit = boundary - 1 (instead of multiplicative decrease)
- scale-up policy: never exceed the known boundary, except for probe scale-ups
- probe scale-up: after a long 429-free period, try to break past the boundary

AIMD parameters:
- scale-up: additive increase (+INCREASE_STEP)
- scale-down: multiplicative decrease (*DECREASE_MULTIPLIER, default 0.85)

Design principles:
1. fast convergence: a single 429 lands close to the real limit
2. avoid over-conservatism: repeated 429s must not drive the limit down forever
3. safe probing: allow trying higher concurrency once things are stable
"""
from datetime import datetime, timezone
@@ -35,21 +37,21 @@ class AdaptiveConcurrencyManager:
    """
    Adaptive concurrency manager

    Core algorithm: AIMD driven by sliding-window utilization
    - the window records the utilization of the last N requests
    - scale up when >= 60% of window samples show high utilization
    - on a 429, multiplicative decrease (*0.85)
    - probe scale-up after long 429-free periods with traffic

    Core algorithm: boundary memory + gradual probing
    - on a 429, record the boundary (last_concurrent_peak = concurrency at trigger time)
    - scale-down: new limit = boundary - 1 (converges quickly near the real limit)
    - scale-up: never exceed the boundary (last_concurrent_peak); returning to the boundary is allowed
    - probe scale-up: after a long (30 min) 429-free period, may try +1 past the boundary

    Scale-up conditions (either one suffices):
    1. sliding-window scale-up: >= 60% of window samples at utilization >= 70%, and not in cooldown
    2. probe scale-up: more than 30 minutes since the last 429, with enough requests since
    1. utilization scale-up: high-utilization ratio in the window >= 60%, and current limit < boundary
    2. probe scale-up: more than 30 minutes since the last 429; may break past the boundary

    Key properties:
    1. the sliding window tolerates fluctuation; a single low-utilization sample does not reset it
    2. concurrency limits and RPM limits are tracked separately
    3. probe scale-ups avoid being stuck at a low limit forever
    4. adjustment history is recorded
    1. fast convergence: a single 429 learns a value close to the real limit
    2. boundary protection: ordinary scale-ups never exceed the known boundary
    3. safe probing: higher concurrency may be tried after long stability
    4. concurrency limits and RPM limits are tracked separately
    """
    # default configuration - uses shared constants
@@ -59,7 +61,6 @@ class AdaptiveConcurrencyManager:
    # AIMD parameters
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER
    # sliding-window parameters
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
@@ -115,6 +116,12 @@ class AdaptiveConcurrencyManager:
        # update 429 statistics
        key.last_429_at = datetime.now(timezone.utc)  # type: ignore[assignment]
        key.last_429_type = rate_limit_info.limit_type  # type: ignore[assignment]
        # record the boundary only for concurrency-type 429s with a known concurrency
        # (RPM/UNKNOWN must not overwrite the remembered concurrency boundary)
if (
rate_limit_info.limit_type == RateLimitType.CONCURRENT
and current_concurrent is not None
and current_concurrent > 0
):
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
        # on a 429, clear the utilization sample window (start collecting afresh)
@@ -207,6 +214,9 @@ class AdaptiveConcurrencyManager:
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
        # known boundary (concurrency when the last 429 fired)
        known_boundary = key.last_concurrent_peak
        # compute current utilization
        utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
@@ -217,22 +227,29 @@ class AdaptiveConcurrencyManager:
samples = self._update_utilization_window(key, now_ts, utilization)
        # check the scale-up conditions
increase_reason = self._check_increase_conditions(key, samples, now)
increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
old_limit = current_limit
new_limit = self._increase_limit(current_limit)
is_probe = increase_reason == "probe_increase"
new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
            # no actual growth (already at the boundary): skip
if new_limit <= old_limit:
return None
            # window statistics for logging
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples) if samples else 0
boundary_info = f"边界: {known_boundary}" if known_boundary else "无边界"
logger.info(
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
f"窗口采样: {len(samples)} | "
f"平均利用率: {avg_util:.1%} | "
f"高利用率比例: {high_util_ratio:.1%} | "
f"{boundary_info} | "
f"调整: {old_limit} -> {new_limit}"
)
@@ -246,13 +263,14 @@ class AdaptiveConcurrencyManager:
high_util_ratio=round(high_util_ratio, 2),
sample_count=len(samples),
current_concurrent=current_concurrent,
known_boundary=known_boundary,
)
            # apply the new limit
            key.learned_max_concurrent = new_limit  # type: ignore[assignment]
            # for probe scale-ups, record the probe time
            if increase_reason == "probe_increase":
            if is_probe:
                key.last_probe_increase_at = now  # type: ignore[assignment]
            # after a scale-up, clear the sample window and start collecting afresh
@@ -303,7 +321,11 @@ class AdaptiveConcurrencyManager:
return samples
def _check_increase_conditions(
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
self,
key: ProviderAPIKey,
samples: List[Dict[str, Any]],
now: datetime,
known_boundary: Optional[int] = None,
) -> Optional[str]:
        """
        Check whether a scale-up condition is met
@@ -312,6 +334,7 @@ class AdaptiveConcurrencyManager:
            key: API key object
            samples: list of utilization samples
            now: current time
            known_boundary: known boundary (concurrency when the 429 fired)
        Returns:
            The scale-up reason if a condition is met, otherwise None.
@@ -320,15 +343,25 @@ class AdaptiveConcurrencyManager:
if self._is_in_cooldown(key):
return None
        # condition 1: sliding-window scale-up
        current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
        # condition 1: sliding-window scale-up (never past the boundary)
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples)
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
                # is there still headroom? (boundary protection)
                if known_boundary:
                    # allow scaling up to the boundary itself (not boundary - 1,
                    # because the scale-down already subtracted 1)
                    if current_limit < known_boundary:
                        return "high_utilization"
                    # at the boundary: no ordinary scale-up
                else:
                    # no boundary known: scale-up allowed
                    return "high_utilization"
        # condition 2: probe scale-up (long 429-free period with traffic)
        # condition 2: probe scale-up (long 429-free period with traffic; may break the boundary)
if self._should_probe_increase(key, samples, now):
return "probe_increase"
@@ -406,32 +439,65 @@ class AdaptiveConcurrencyManager:
current_concurrent: Optional[int] = None,
) -> int:
        """
        Decrease the concurrency limit
        Decrease the concurrency limit (boundary-memory policy)

        Policy:
        - if the current concurrency is known, set the limit to 70% of it
        - otherwise, decrease multiplicatively
        - if the concurrency at 429 time is known, new limit = that concurrency - 1
        - this converges quickly near the real limit without being over-conservative
        - e.g. real limit 8, 429 fired at concurrency 8 -> new limit 7, not 8*0.85=6
        """
        if current_concurrent:
            # decrease based on the current concurrency
            new_limit = max(
                int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
            )
        if current_concurrent is not None and current_concurrent > 0:
            # boundary-memory policy: new limit = triggering concurrency - 1
            candidate = current_concurrent - 1
        else:
            # multiplicative decrease
            new_limit = max(
                int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
            )
            # no concurrency info: decrease conservatively by 1
            candidate = current_limit - 1
        # make sure a "decrease" can never increase (e.g. when current_concurrent > current_limit)
        candidate = min(candidate, current_limit - 1)
        new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)
return new_limit
def _increase_limit(self, current_limit: int) -> int:
def _increase_limit(
self,
current_limit: int,
known_boundary: Optional[int] = None,
is_probe: bool = False,
) -> int:
        """
        Increase the concurrency limit
        Increase the concurrency limit (with boundary protection)

        Policy: additive increase (+1)
        Policy:
        - ordinary scale-up: +INCREASE_STEP per step, capped at known_boundary
          (the scale-down already subtracted 1, so returning to the boundary is allowed)
        - probe scale-up: only +1 per step; may break past the boundary, cautiously

        Args:
            current_limit: current limit
            known_boundary: known boundary (last_concurrent_peak, the concurrency when the 429 fired)
            is_probe: whether this is a probe scale-up (allowed to break the boundary)
        """
        new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
        if is_probe:
            # probe mode: only +1 per step, break the boundary cautiously
            new_limit = current_limit + 1
        else:
            # ordinary mode: +INCREASE_STEP per step
            new_limit = current_limit + self.INCREASE_STEP
        # boundary protection: ordinary scale-ups are capped at known_boundary
        # (probes are exempt, otherwise they could never break past it)
        if known_boundary and not is_probe:
            if new_limit > known_boundary:
                new_limit = known_boundary
        # global upper bound
        new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
        # require actual growth (otherwise return the old value: no scale-up)
        if new_limit <= current_limit:
            return current_limit
return new_limit
def _record_adjustment(
@@ -503,11 +569,16 @@ class AdaptiveConcurrencyManager:
if key.last_probe_increase_at:
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
        # boundary info
known_boundary = key.last_concurrent_peak
return {
"adaptive_mode": is_adaptive,
            "max_concurrent": key.max_concurrent,  # NULL = adaptive, number = fixed limit
            "effective_limit": effective_limit,  # current effective limit
            "learned_limit": key.learned_max_concurrent,  # learned limit
            # boundary memory
            "known_boundary": known_boundary,  # concurrency when the 429 fired (the known ceiling)
"concurrent_429_count": int(key.concurrent_429_count or 0),
"rpm_429_count": int(key.rpm_429_count or 0),
"last_429_at": last_429_at_str,

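The boundary-memory arithmetic can be checked in isolation. A minimal sketch of the decrease/increase rules from the diff (the constants are assumptions, and the boundary cap exempts probes, matching the documented intent that probe scale-ups may pass the boundary):

```python
from typing import Optional

MIN_LIMIT = 1       # assumed MIN_CONCURRENT_LIMIT
MAX_LIMIT = 100     # assumed MAX_CONCURRENT_LIMIT
INCREASE_STEP = 1   # assumed INCREASE_STEP

def decrease(current_limit: int, current_concurrent: Optional[int]) -> int:
    # boundary memory: drop to one below the concurrency that triggered the 429
    if current_concurrent is not None and current_concurrent > 0:
        candidate = current_concurrent - 1
    else:
        candidate = current_limit - 1
    # a "decrease" must never raise the limit
    candidate = min(candidate, current_limit - 1)
    return max(candidate, MIN_LIMIT)

def increase(current_limit: int, boundary: Optional[int], is_probe: bool) -> int:
    new_limit = current_limit + (1 if is_probe else INCREASE_STEP)
    # ordinary scale-ups are capped at the remembered boundary; probes may pass it
    if boundary and not is_probe:
        new_limit = min(new_limit, boundary)
    new_limit = min(new_limit, MAX_LIMIT)
    return current_limit if new_limit <= current_limit else new_limit

print(decrease(8, 8))         # 429 at concurrency 8 -> limit 7
print(increase(7, 8, False))  # climb back to the boundary -> 8
print(increase(8, 8, False))  # capped at the boundary -> stays 8
print(increase(8, 8, True))   # a probe may break the boundary -> 9
```

The four calls trace the intended lifecycle: one 429 converges to boundary-1, ordinary traffic climbs back to the boundary, and only a probe ever goes beyond it.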
View File

@@ -289,11 +289,11 @@ class RequestResult:
status_code = 500
error_type = "internal_error"
        # build the error message, including upstream-response info
        # build the error message: prefer the upstream response as the primary error text
        if isinstance(exception, ProviderNotAvailableException) and exception.upstream_response:
            error_message = exception.upstream_response
        else:
            error_message = str(exception)
        if isinstance(exception, ProviderNotAvailableException):
            if exception.upstream_response:
                error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
return cls(
status=RequestStatus.FAILED,

View File

@@ -86,6 +86,118 @@ class UsageRecordParams:
class UsageService:
"""用量统计服务"""
    # ==================== Cache-key constants ====================
    # heatmap cache-key prefix (relies on TTL expiry; cleared explicitly on user-role changes)
    HEATMAP_CACHE_KEY_PREFIX = "activity_heatmap"

    # ==================== Heatmap cache ====================
@classmethod
def _get_heatmap_cache_key(cls, user_id: Optional[str], include_actual_cost: bool) -> str:
        """Build the heatmap cache key."""
cost_suffix = "with_cost" if include_actual_cost else "no_cost"
if user_id:
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:user:{user_id}:{cost_suffix}"
else:
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:admin:all:{cost_suffix}"
@classmethod
async def clear_user_heatmap_cache(cls, user_id: str) -> None:
        """
        Clear a user's heatmap cache (called when the user's role changes).

        Args:
            user_id: user ID
        """
from src.clients.redis_client import get_redis_client
redis_client = await get_redis_client(require_redis=False)
if not redis_client:
return
        # clear all of the user's heatmap caches (with_cost and no_cost)
keys_to_delete = [
cls._get_heatmap_cache_key(user_id, include_actual_cost=True),
cls._get_heatmap_cache_key(user_id, include_actual_cost=False),
]
for key in keys_to_delete:
try:
await redis_client.delete(key)
logger.debug(f"已清除热力图缓存: {key}")
except Exception as e:
logger.warning(f"清除热力图缓存失败: {key}, error={e}")
@classmethod
async def get_cached_heatmap(
cls,
db: Session,
user_id: Optional[str] = None,
include_actual_cost: bool = False,
) -> Dict[str, Any]:
        """
        Fetch heatmap data through the cache.

        Caching policy:
        - TTL: 5 minutes (CacheTTL.ACTIVITY_HEATMAP)
        - relies on TTL expiry only, so new usage records may lag by up to 5 minutes
        - cleared explicitly via clear_user_heatmap_cache() on user-role changes

        Args:
            db: database session
            user_id: user ID (None means the global, admin-level heatmap)
            include_actual_cost: whether to include actual cost

        Returns:
            The heatmap data dict.
        """
from src.clients.redis_client import get_redis_client
from src.config.constants import CacheTTL
import json
cache_key = cls._get_heatmap_cache_key(user_id, include_actual_cost)
cache_ttl = CacheTTL.ACTIVITY_HEATMAP
redis_client = await get_redis_client(require_redis=False)
        # try the cache first
if redis_client:
try:
cached = await redis_client.get(cache_key)
if cached:
try:
return json.loads(cached) # type: ignore[no-any-return]
except json.JSONDecodeError as e:
logger.warning(f"热力图缓存解析失败,删除损坏缓存: {cache_key}, error={e}")
try:
await redis_client.delete(cache_key)
except Exception:
pass
except Exception as e:
logger.error(f"读取热力图缓存出错: {cache_key}, error={e}")
        # fall back to the database
result = cls.get_daily_activity(
db=db,
user_id=user_id,
window_days=365,
include_actual_cost=include_actual_cost,
)
        # write back to the cache (a failure does not affect the returned result)
if redis_client:
try:
await redis_client.setex(
cache_key,
cache_ttl,
json.dumps(result, ensure_ascii=False, default=str),
)
except Exception as e:
logger.warning(f"保存热力图缓存失败: {cache_key}, error={e}")
return result
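The cache-key scheme above is easy to verify standalone; this sketch mirrors `_get_heatmap_cache_key`:

```python
from typing import Optional

PREFIX = "activity_heatmap"  # prefix as shown in the diff

def heatmap_cache_key(user_id: Optional[str], include_actual_cost: bool) -> str:
    # one key per (scope, cost-flag) pair; the admin/global view carries no user id
    cost_suffix = "with_cost" if include_actual_cost else "no_cost"
    if user_id:
        return f"{PREFIX}:user:{user_id}:{cost_suffix}"
    return f"{PREFIX}:admin:all:{cost_suffix}"

print(heatmap_cache_key("u123", True))   # activity_heatmap:user:u123:with_cost
print(heatmap_cache_key(None, False))    # activity_heatmap:admin:all:no_cost
```

Keeping the cost flag in the key is what lets `clear_user_heatmap_cache` delete exactly two keys per user instead of scanning.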
    # ==================== Internal data classes ====================
@staticmethod
@@ -1027,7 +1139,12 @@ class UsageService:
window_days: int = 365,
include_actual_cost: bool = False,
) -> Dict[str, Any]:
        """Aggregate request activity per day, for rendering the heatmap."""
        """Aggregate request activity per day, for rendering the heatmap.

        Optimization:
        - historical days are read from the precomputed StatsDaily/StatsUserDaily tables
        - only "today" is queried live from the Usage table
        """
def ensure_timezone(value: datetime) -> datetime:
if value.tzinfo is None:
@@ -1041,54 +1158,109 @@ class UsageService:
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
)
# Align to the start/end of calendar days to avoid missing boundary data
start_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0)
end_dt = end_dt.replace(hour=23, minute=59, second=59, microsecond=999999)
# Align to the start/end of calendar days
start_dt = datetime.combine(start_dt.date(), datetime.min.time(), tzinfo=timezone.utc)
end_dt = datetime.combine(end_dt.date(), datetime.max.time(), tzinfo=timezone.utc)
from src.utils.database_helpers import date_trunc_portable
today = now.date()
today_start_dt = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
aggregated: Dict[str, Dict[str, Any]] = {}
bind = db.get_bind()
dialect = bind.dialect.name if bind is not None else "sqlite"
day_bucket = date_trunc_portable(dialect, "day", Usage.created_at).label("day")
# 1. Read historical data (excluding today) from the precomputed tables
if user_id:
from src.models.database import StatsUserDaily
columns = [
day_bucket,
hist_query = db.query(StatsUserDaily).filter(
StatsUserDaily.user_id == user_id,
StatsUserDaily.date >= start_dt,
StatsUserDaily.date < today_start_dt,
)
for row in hist_query.all():
key = (
row.date.date().isoformat()
if isinstance(row.date, datetime)
else str(row.date)[:10]
)
aggregated[key] = {
"requests": row.total_requests or 0,
"total_tokens": (
(row.input_tokens or 0)
+ (row.output_tokens or 0)
+ (row.cache_creation_tokens or 0)
+ (row.cache_read_tokens or 0)
),
"total_cost_usd": float(row.total_cost or 0.0),
}
# StatsUserDaily has no actual_total_cost column; the user view does not need markup cost
else:
from src.models.database import StatsDaily
hist_query = db.query(StatsDaily).filter(
StatsDaily.date >= start_dt,
StatsDaily.date < today_start_dt,
)
for row in hist_query.all():
key = (
row.date.date().isoformat()
if isinstance(row.date, datetime)
else str(row.date)[:10]
)
aggregated[key] = {
"requests": row.total_requests or 0,
"total_tokens": (
(row.input_tokens or 0)
+ (row.output_tokens or 0)
+ (row.cache_creation_tokens or 0)
+ (row.cache_read_tokens or 0)
),
"total_cost_usd": float(row.total_cost or 0.0),
}
if include_actual_cost:
aggregated[key]["actual_total_cost_usd"] = float(
row.actual_total_cost or 0.0 # type: ignore[attr-defined]
)
# 2. Query today's data live (if it falls within the requested range)
if today >= start_dt.date() and today <= end_dt.date():
today_start = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
today_end = datetime.combine(today, datetime.max.time(), tzinfo=timezone.utc)
if include_actual_cost:
today_query = db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
]
if include_actual_cost:
columns.append(func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"))
query = db.query(*columns).filter(Usage.created_at >= start_dt, Usage.created_at <= end_dt)
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"),
).filter(
Usage.created_at >= today_start,
Usage.created_at <= today_end,
)
else:
today_query = db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
).filter(
Usage.created_at >= today_start,
Usage.created_at <= today_end,
)
if user_id:
query = query.filter(Usage.user_id == user_id)
today_query = today_query.filter(Usage.user_id == user_id)
query = query.group_by(day_bucket).order_by(day_bucket)
rows = query.all()
def normalize_period(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value[:10]
if isinstance(value, datetime):
return value.date().isoformat()
return str(value)
aggregated: Dict[str, Dict[str, Any]] = {}
for row in rows:
key = normalize_period(row.day)
aggregated[key] = {
"requests": int(row.requests or 0),
"total_tokens": int(row.total_tokens or 0),
"total_cost_usd": float(row.total_cost_usd or 0.0),
today_row = today_query.first()
if today_row and today_row.requests:
aggregated[today.isoformat()] = {
"requests": int(today_row.requests or 0),
"total_tokens": int(today_row.total_tokens or 0),
"total_cost_usd": float(today_row.total_cost_usd or 0.0),
}
if include_actual_cost:
aggregated[key]["actual_total_cost_usd"] = float(row.actual_total_cost_usd or 0.0)
aggregated[today.isoformat()]["actual_total_cost_usd"] = float(
today_row.actual_total_cost_usd or 0.0
)
# 3. Build the result
days: List[Dict[str, Any]] = []
cursor = start_dt.date()
end_date_only = end_dt.date()
@@ -1304,6 +1476,9 @@ class UsageService:
provider: Optional[str] = None,
target_model: Optional[str] = None,
first_byte_time_ms: Optional[int] = None,
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
) -> Optional[Usage]:
"""
Quickly update the status of a usage record.
@@ -1316,6 +1491,9 @@
provider: provider name (optional; updated when status is streaming)
target_model: mapped target model name (optional)
first_byte_time_ms: time to first byte / TTFB (optional; updated when status is streaming)
provider_id: provider ID (optional; updated when status is streaming)
provider_endpoint_id: endpoint ID (optional; updated when status is streaming)
provider_api_key_id: provider API key ID (optional; updated when status is streaming)
Returns:
The updated Usage record, or None if not found
@@ -1331,10 +1509,22 @@ class UsageService:
usage.error_message = error_message
if provider:
usage.provider = provider
elif status == "streaming" and usage.provider == "pending":
# Status changed to streaming but provider is still pending; log a warning
logger.warning(
f"Status updated to streaming but provider is empty: request_id={request_id}, "
f"current provider={usage.provider}"
)
if target_model:
usage.target_model = target_model
if first_byte_time_ms is not None:
usage.first_byte_time_ms = first_byte_time_ms
if provider_id is not None:
usage.provider_id = provider_id
if provider_endpoint_id is not None:
usage.provider_endpoint_id = provider_endpoint_id
if provider_api_key_id is not None:
usage.provider_api_key_id = provider_api_key_id
db.commit()
@@ -1446,6 +1636,8 @@ class UsageService:
ids: Optional[List[str]] = None,
user_id: Optional[str] = None,
default_timeout_seconds: int = 300,
*,
include_admin_fields: bool = False,
) -> List[Dict[str, Any]]:
"""
Get active request statuses (for frontend polling) and automatically clean up timed-out pending/streaming requests
@@ -1482,6 +1674,15 @@ class UsageService:
ProviderEndpoint.timeout.label("endpoint_timeout"),
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
# Admin polling: may include provider and upstream key names (note: never expose upstream key info on the regular user endpoint)
if include_admin_fields:
from src.models.database import ProviderAPIKey
query = query.add_columns(
Usage.provider,
ProviderAPIKey.name.label("api_key_name"),
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
if ids:
query = query.filter(Usage.id.in_(ids))
if user_id:
@@ -1518,8 +1719,9 @@ class UsageService:
)
db.commit()
return [
{
result: List[Dict[str, Any]] = []
for r in records:
item: Dict[str, Any] = {
"id": r.id,
"status": "failed" if r.id in timeout_ids else r.status,
"input_tokens": r.input_tokens,
@@ -1528,8 +1730,12 @@ class UsageService:
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms,  # time to first byte (TTFB)
}
for r in records
]
if include_admin_fields:
item["provider"] = r.provider
item["api_key_name"] = r.api_key_name
result.append(item)
return result
# ========== Cache-affinity analysis methods ==========

View File

@@ -459,34 +459,38 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | Started tracking streaming response | estimated input tokens:{self.input_tokens}")
chunk_count = 0
first_chunk_received = False
first_byte_time_ms = None  # record TTFB up front; computing it after the yield would be inaccurate
try:
async for chunk in stream:
chunk_count += 1
# Keep the raw byte stream (for error diagnostics)
self.raw_chunks.append(chunk)
# On the first chunk, update status to streaming and record TTFB
if not first_chunk_received:
first_chunk_received = True
if self.request_id:
try:
# On the first chunk, record the TTFB timestamp (but defer the DB update to avoid blocking)
if chunk_count == 1:
# Compute TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
# Yield the raw chunk to the client first so TTFB is unaffected by database work
yield chunk
# Update the database status after the yield (first chunk only)
if chunk_count == 1 and self.request_id:
try:
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
provider_id=self.provider_id,
provider_endpoint_id=self.provider_endpoint_id,
provider_api_key_id=self.provider_api_key_id,
)
except Exception as e:
logger.warning(f"Failed to update usage record status to streaming: {e}")
# Yield the raw chunk to the client
yield chunk
# Parse the chunk to extract content and usage info (chunk is raw bytes)
content, usage = self.parse_stream_chunk(chunk)
@@ -916,15 +920,38 @@ class EnhancedStreamUsageTracker(StreamUsageTracker):
logger.debug(f"ID:{self.request_id} | Started tracking streaming response (Enhanced) | estimated input tokens:{self.input_tokens}")
chunk_count = 0
first_byte_time_ms = None  # record TTFB up front; computing it after the yield would be inaccurate
try:
async for chunk in stream:
chunk_count += 1
# Keep the raw byte stream (for error diagnostics)
self.raw_chunks.append(chunk)
# Yield the raw chunk to the client
# On the first chunk, record the TTFB timestamp (but defer the DB update to avoid blocking)
if chunk_count == 1:
# Compute TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
# Yield the raw chunk to the client first so TTFB is unaffected by database work
yield chunk
# Update the database status after the yield (first chunk only)
if chunk_count == 1 and self.request_id:
try:
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
provider_id=self.provider_id,
provider_endpoint_id=self.provider_endpoint_id,
provider_api_key_id=self.provider_api_key_id,
)
except Exception as e:
logger.warning(f"Failed to update usage record status to streaming: {e}")
# Parse the chunk to extract content and usage info (chunk is raw bytes)
content, usage = self.parse_stream_chunk(chunk)
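The ordering both trackers now use (compute TTFB, yield the chunk, then persist the status) can be reduced to a minimal async generator. `track_stream`, `demo`, and the `events` list here are hypothetical stand-ins for the real tracker and the `UsageService.update_usage_status` call:

```python
import asyncio
import time
from typing import AsyncIterator, List, Optional


async def track_stream(stream: AsyncIterator[bytes],
                       events: List[str]) -> AsyncIterator[bytes]:
    """Yield every chunk to the client first; record TTFB after the first yield."""
    start = time.time()
    first_byte_time_ms: Optional[int] = None
    chunk_count = 0
    async for chunk in stream:
        chunk_count += 1
        if chunk_count == 1:
            # TTFB is captured before the yield, while the timestamp is still accurate
            first_byte_time_ms = int((time.time() - start) * 1000)
        yield chunk  # the client sees the bytes before any DB work happens
        if chunk_count == 1:
            # Stand-in for update_usage_status(status="streaming", first_byte_time_ms=...)
            events.append(f"streaming ttfb={first_byte_time_ms}ms")


async def demo() -> List[bytes]:
    async def fake_stream() -> AsyncIterator[bytes]:
        for part in (b"hel", b"lo"):
            yield part

    events: List[str] = []
    chunks = [c async for c in track_stream(fake_stream(), events)]
    assert len(events) == 1  # the status update runs exactly once
    return chunks
```

The point of the reordering is that the (potentially slow) persistence step sits strictly after the first byte has been handed to the consumer, so it can no longer inflate the measured TTFB.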

View File

@@ -25,9 +25,10 @@ class ApiKeyService:
allowed_providers: Optional[List[str]] = None,
allowed_api_formats: Optional[List[str]] = None,
allowed_models: Optional[List[str]] = None,
rate_limit: int = 100,
rate_limit: Optional[int] = None,
concurrent_limit: int = 5,
expire_days: Optional[int] = None,
expires_at: Optional[datetime] = None,  # explicit expiry time; takes precedence over expire_days
initial_balance_usd: Optional[float] = None,
is_standalone: bool = False,
auto_delete_on_expiry: bool = False,
@@ -44,6 +45,7 @@ class ApiKeyService:
rate_limit: rate limit
concurrent_limit: concurrency limit
expire_days: days until expiry (None = never expires)
expires_at: explicit expiry time; takes precedence over expire_days
initial_balance_usd: initial balance in USD (standalone keys only; None = unlimited)
is_standalone: whether this is a standalone-balance key (admin-only)
auto_delete_on_expiry: whether to delete automatically after expiry (True = hard delete, False = disable only)
@@ -54,10 +56,10 @@ class ApiKeyService:
key_hash = ApiKey.hash_key(key)
key_encrypted = crypto_service.encrypt(key)  # store the key encrypted
# Compute expiry time
expires_at = None
if expire_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# Compute expiry time: prefer expires_at, fall back to expire_days
final_expires_at = expires_at
if final_expires_at is None and expire_days:
final_expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# Convert empty lists to None (meaning unrestricted)
api_key = ApiKey(
@@ -70,7 +72,7 @@ class ApiKeyService:
allowed_models=allowed_models or None,
rate_limit=rate_limit,
concurrent_limit=concurrent_limit,
expires_at=expires_at,
expires_at=final_expires_at,
balance_used_usd=0.0,
current_balance_usd=initial_balance_usd,  # use the initial balance directly (None = unlimited)
is_standalone=is_standalone,
@@ -145,6 +147,9 @@ class ApiKeyService:
# Fields that may be explicitly set to an empty list / None (empty lists become None, meaning "all")
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
# Fields that may be explicitly set to None (e.g. expires_at=None means never expires, rate_limit=None means unlimited)
nullable_fields = {"expires_at", "rate_limit"}
for field, value in kwargs.items():
if field not in updatable_fields:
continue
@@ -153,6 +158,9 @@ class ApiKeyService:
if value is not None:
# Convert empty lists to None (meaning all allowed)
setattr(api_key, field, value if value else None)
elif field in nullable_fields:
# These fields may be explicitly set to None
setattr(api_key, field, value)
elif value is not None:
setattr(api_key, field, value)
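The expiry precedence introduced above (an explicit `expires_at` wins over `expire_days`; neither means the key never expires) reduces to a small pure function; `resolve_expiry` is an illustrative name, not part of `ApiKeyService`:

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def resolve_expiry(expires_at: Optional[datetime],
                   expire_days: Optional[int]) -> Optional[datetime]:
    # Prefer the explicit timestamp; fall back to a day offset; None = never expires
    if expires_at is not None:
        return expires_at
    if expire_days:
        return datetime.now(timezone.utc) + timedelta(days=expire_days)
    return None
```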

View File

@@ -49,8 +49,16 @@ def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) ->
# Try the cache first
cached = await redis_client.get(cache_key)
if cached:
try:
result = json.loads(cached)
logger.debug(f"Cache hit: {cache_key}")
return json.loads(cached)
return result
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse cache, deleting corrupted entry: {cache_key}, error: {e}")
try:
await redis_client.delete(cache_key)
except Exception:
pass
# Run the original function
result = await func(*args, **kwargs)

View File

@@ -7,6 +7,59 @@ from typing import Any
from sqlalchemy import func
def escape_like_pattern(pattern: str) -> str:
"""
Escape the special characters (%, _, \\) in a SQL LIKE pattern.
Args:
pattern: the raw search pattern
Returns:
The escaped pattern, safe for use in LIKE queries (together with escape="\\\\")
Examples:
>>> escape_like_pattern("hello_world%test")
'hello\\\\_world\\\\%test'
"""
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def safe_truncate_escaped(escaped: str, max_len: int) -> str:
"""
Safely truncate an already-escaped string without cutting an escape sequence in half.
In the escaped string, backslashes appear either in pairs (\\\\) or as escape prefixes (\\%, \\_).
If truncation leaves an odd number of trailing backslashes, the cut landed in the middle of
an escape sequence, and the last backslash must be dropped to keep the escaping intact.
Args:
escaped: a string already processed by escape_like_pattern
max_len: maximum length
Returns:
The truncated string, guaranteed not to break any escape sequence
"""
if len(escaped) <= max_len:
return escaped
truncated = escaped[:max_len]
# Count the trailing consecutive backslashes
trailing_backslashes = 0
for i in range(len(truncated) - 1, -1, -1):
if truncated[i] == "\\":
trailing_backslashes += 1
else:
break
# An odd number of trailing backslashes means the cut landed inside an escape sequence,
# so drop the last backslash
if trailing_backslashes % 2 == 1:
truncated = truncated[:-1]
return truncated
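Put together, the two helpers above keep a truncated LIKE pattern well-formed; they are restated here (with the backslash count done via `rstrip` instead of the loop) so the example is self-contained:

```python
def escape_like_pattern(pattern: str) -> str:
    # Escape the backslash first, then the LIKE wildcards % and _
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def safe_truncate_escaped(escaped: str, max_len: int) -> str:
    if len(escaped) <= max_len:
        return escaped
    truncated = escaped[:max_len]
    # Count trailing consecutive backslashes
    trailing = len(truncated) - len(truncated.rstrip("\\"))
    # An odd count means the cut landed inside an escape sequence
    if trailing % 2 == 1:
        truncated = truncated[:-1]
    return truncated
```

In SQLAlchemy the result would typically be used as `column.like(f"%{escaped}%", escape="\\")` so the escapes are honoured by the database.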
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
"""
A cross-database date truncation function

View File

@@ -7,22 +7,20 @@ from typing import Optional
from fastapi import Request
from src.config import config
def get_client_ip(request: Request) -> str:
"""
Get the client's real IP address.
Checked in priority order:
1. X-Forwarded-For header (supports proxy chains; extracted according to the trusted proxy count)
2. X-Real-IP header (Nginx proxy)
1. X-Real-IP header (most reliable; set directly by the outermost trusted Nginx)
2. The first IP in the X-Forwarded-For header (the original client)
3. The direct client IP
Security notes:
- This function uses the TRUSTED_PROXY_COUNT setting to decide how many proxy hops to trust
- With TRUSTED_PROXY_COUNT=0, no proxy headers are trusted and the connection IP is used directly
- When the service is exposed directly to the public internet, set TRUSTED_PROXY_COUNT=0 to prevent IP spoofing
- X-Real-IP has the highest priority because it is normally set to $remote_addr by the outermost Nginx,
which overwrites the header outright and never forwards a client-forged value
- As long as the outermost Nginx is configured with proxy_set_header X-Real-IP $remote_addr; the real IP is obtained correctly
Args:
request: the FastAPI Request object
@@ -30,30 +28,19 @@ def get_client_ip(request: Request) -> str:
Returns:
str: the client IP address, or "unknown" if it cannot be determined
"""
trusted_proxy_count = config.trusted_proxy_count
# If no proxy is trusted, return the connection IP directly
if trusted_proxy_count == 0:
if request.client and request.client.host:
return request.client.host
return "unknown"
# Check the X-Forwarded-For header first (it may contain a proxy chain)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For format: "client, proxy1, proxy2"
# Count trusted_proxy_count entries from the right, then take the first one to their left
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# Check the X-Real-IP header (usually set by Nginx)
# Check X-Real-IP first (set by the outermost Nginx; most reliable)
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# Check the X-Forwarded-For header and take the first IP (the original client)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For format: "client, proxy1, proxy2"
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if ips:
return ips[0]
# Fall back to the direct client IP
if request.client and request.client.host:
return request.client.host
@@ -109,36 +96,26 @@ def get_request_metadata(request: Request) -> dict:
}
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
def extract_ip_from_headers(headers: dict) -> str:
"""
Extract the IP address from a dict of HTTP headers (for middleware and similar contexts).
Args:
headers: dict of HTTP headers
trusted_proxy_count: number of trusted proxy hops (None uses the configured value)
Returns:
str: the client IP address
"""
if trusted_proxy_count is None:
trusted_proxy_count = config.trusted_proxy_count
# If no proxy is trusted, return unknown (the caller must obtain the connection IP some other way)
if trusted_proxy_count == 0:
return "unknown"
# Check X-Forwarded-For
forwarded_for = headers.get("x-forwarded-for", "")
if forwarded_for:
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# Check X-Real-IP
# Check X-Real-IP first (set by the outermost Nginx; most reliable)
real_ip = headers.get("x-real-ip", "")
if real_ip:
return real_ip.strip()
# Check X-Forwarded-For and take the first IP
forwarded_for = headers.get("x-forwarded-for", "")
if forwarded_for:
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if ips:
return ips[0]
return "unknown"
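With the TRUSTED_PROXY_COUNT logic gone, both `get_client_ip` and `extract_ip_from_headers` share one priority order, which a stripped-down pure function makes easy to verify (`extract_ip` is an illustrative name, not the module's API):

```python
from typing import Dict


def extract_ip(headers: Dict[str, str]) -> str:
    # Priority: X-Real-IP > first entry of X-Forwarded-For > "unknown"
    real_ip = headers.get("x-real-ip", "")
    if real_ip:
        return real_ip.strip()
    forwarded_for = headers.get("x-forwarded-for", "")
    if forwarded_for:
        ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
        if ips:
            return ips[0]
    return "unknown"
```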

View File


@@ -0,0 +1,440 @@
"""
Billing module tests.
Tests the billing module's core functionality:
- BillingCalculator cost calculation
- billing templates
- tiered pricing
- the calculate_request_cost convenience function
"""
import pytest
from src.services.billing import (
BillingCalculator,
BillingDimension,
BillingTemplates,
BillingUnit,
CostBreakdown,
StandardizedUsage,
calculate_request_cost,
)
from src.services.billing.templates import get_template, list_templates
class TestBillingDimension:
"""Test billing dimensions"""
def test_calculate_per_1m_tokens(self) -> None:
"""Test per_1m_tokens billing"""
dim = BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
)
# 1000 tokens * $3 / 1M = $0.003
cost = dim.calculate(1000, 3.0)
assert abs(cost - 0.003) < 0.0001
def test_calculate_per_request(self) -> None:
"""Test per-request billing"""
dim = BillingDimension(
name="request",
usage_field="request_count",
price_field="price_per_request",
unit=BillingUnit.PER_REQUEST,
)
# Per-request billing: cost = request_count * price
cost = dim.calculate(1, 0.05)
assert cost == 0.05
# Multiple requests are billed per count
cost = dim.calculate(3, 0.05)
assert abs(cost - 0.15) < 0.0001
def test_calculate_zero_usage(self) -> None:
"""Test zero usage"""
dim = BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
)
cost = dim.calculate(0, 3.0)
assert cost == 0.0
def test_calculate_zero_price(self) -> None:
"""Test zero price"""
dim = BillingDimension(
name="input",
usage_field="input_tokens",
price_field="input_price_per_1m",
)
cost = dim.calculate(1000, 0.0)
assert cost == 0.0
def test_to_dict_and_from_dict(self) -> None:
"""Test serialization and deserialization"""
dim = BillingDimension(
name="cache_read",
usage_field="cache_read_tokens",
price_field="cache_read_price_per_1m",
unit=BillingUnit.PER_1M_TOKENS,
default_price=0.3,
)
d = dim.to_dict()
restored = BillingDimension.from_dict(d)
assert restored.name == dim.name
assert restored.usage_field == dim.usage_field
assert restored.price_field == dim.price_field
assert restored.unit == dim.unit
assert restored.default_price == dim.default_price
class TestStandardizedUsage:
"""Test StandardizedUsage"""
def test_basic_usage(self) -> None:
"""Test basic usage"""
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
)
assert usage.input_tokens == 1000
assert usage.output_tokens == 500
assert usage.cache_creation_tokens == 0
assert usage.cache_read_tokens == 0
def test_get_field(self) -> None:
"""Test field access"""
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
)
assert usage.get("input_tokens") == 1000
assert usage.get("nonexistent", 0) == 0
def test_extra_fields(self) -> None:
"""Test extra fields"""
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
extra={"custom_field": 123},
)
assert usage.get("custom_field") == 123
def test_to_dict(self) -> None:
"""Test conversion to a dict"""
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
cache_creation_tokens=100,
)
d = usage.to_dict()
assert d["input_tokens"] == 1000
assert d["output_tokens"] == 500
assert d["cache_creation_tokens"] == 100
class TestCostBreakdown:
"""Test cost breakdowns"""
def test_basic_breakdown(self) -> None:
"""Test a basic cost breakdown"""
breakdown = CostBreakdown(
costs={"input": 0.003, "output": 0.0075},
total_cost=0.0105,
)
assert breakdown.input_cost == 0.003
assert breakdown.output_cost == 0.0075
assert breakdown.total_cost == 0.0105
def test_cache_cost_calculation(self) -> None:
"""Test cache cost aggregation"""
breakdown = CostBreakdown(
costs={
"input": 0.003,
"output": 0.0075,
"cache_creation": 0.001,
"cache_read": 0.0005,
},
total_cost=0.012,
)
# cache_cost = cache_creation + cache_read
assert abs(breakdown.cache_cost - 0.0015) < 0.0001
def test_to_dict(self) -> None:
"""Test conversion to a dict"""
breakdown = CostBreakdown(
costs={"input": 0.003, "output": 0.0075},
total_cost=0.0105,
tier_index=1,
)
d = breakdown.to_dict()
assert d["total_cost"] == 0.0105
assert d["tier_index"] == 1
assert d["input_cost"] == 0.003
class TestBillingTemplates:
"""Test billing templates"""
def test_claude_template(self) -> None:
"""Test the Claude template"""
template = BillingTemplates.CLAUDE_STANDARD
dim_names = [d.name for d in template]
assert "input" in dim_names
assert "output" in dim_names
assert "cache_creation" in dim_names
assert "cache_read" in dim_names
def test_openai_template(self) -> None:
"""Test the OpenAI template"""
template = BillingTemplates.OPENAI_STANDARD
dim_names = [d.name for d in template]
assert "input" in dim_names
assert "output" in dim_names
assert "cache_read" in dim_names
# OpenAI has no cache-creation charge
assert "cache_creation" not in dim_names
def test_gemini_template(self) -> None:
"""Test the Gemini template"""
template = BillingTemplates.GEMINI_STANDARD
dim_names = [d.name for d in template]
assert "input" in dim_names
assert "output" in dim_names
assert "cache_read" in dim_names
def test_per_request_template(self) -> None:
"""Test the per-request template"""
template = BillingTemplates.PER_REQUEST
assert len(template) == 1
assert template[0].name == "request"
assert template[0].unit == BillingUnit.PER_REQUEST
def test_get_template(self) -> None:
"""Test get_template"""
template = get_template("claude")
assert template == BillingTemplates.CLAUDE_STANDARD
template = get_template("openai")
assert template == BillingTemplates.OPENAI_STANDARD
# Case-insensitive
template = get_template("CLAUDE")
assert template == BillingTemplates.CLAUDE_STANDARD
with pytest.raises(ValueError, match="Unknown billing template"):
get_template("unknown_template")
def test_list_templates(self) -> None:
"""Test list_templates"""
templates = list_templates()
assert "claude" in templates
assert "openai" in templates
assert "gemini" in templates
assert "per_request" in templates
class TestBillingCalculator:
"""Test the billing calculator"""
def test_basic_calculation(self) -> None:
"""Test basic cost calculation"""
calculator = BillingCalculator(template="claude")
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
result = calculator.calculate(usage, prices)
# 1000 * 3 / 1M = 0.003
assert abs(result.input_cost - 0.003) < 0.0001
# 500 * 15 / 1M = 0.0075
assert abs(result.output_cost - 0.0075) < 0.0001
# Total = 0.0105
assert abs(result.total_cost - 0.0105) < 0.0001
def test_calculation_with_cache(self) -> None:
"""Test cost calculation with cache"""
calculator = BillingCalculator(template="claude")
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
cache_creation_tokens=200,
cache_read_tokens=300,
)
prices = {
"input_price_per_1m": 3.0,
"output_price_per_1m": 15.0,
"cache_creation_price_per_1m": 3.75,
"cache_read_price_per_1m": 0.3,
}
result = calculator.calculate(usage, prices)
# cache_creation: 200 * 3.75 / 1M = 0.00075
assert abs(result.cache_creation_cost - 0.00075) < 0.0001
# cache_read: 300 * 0.3 / 1M = 0.00009
assert abs(result.cache_read_cost - 0.00009) < 0.0001
def test_tiered_pricing(self) -> None:
"""Test tiered pricing"""
calculator = BillingCalculator(template="claude")
usage = StandardizedUsage(input_tokens=250000, output_tokens=10000)
# More than 200k falls into the second tier
tiered_pricing = {
"tiers": [
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
]
}
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
result = calculator.calculate(usage, prices, tiered_pricing)
# The second-tier price should be used
assert result.tier_index == 1
# 250000 * 1.5 / 1M = 0.375
assert abs(result.input_cost - 0.375) < 0.0001
def test_openai_no_cache_creation(self) -> None:
"""Test that the OpenAI template has no cache-creation charge"""
calculator = BillingCalculator(template="openai")
usage = StandardizedUsage(
input_tokens=1000,
output_tokens=500,
cache_creation_tokens=200,  # this should not be billed
cache_read_tokens=300,
)
prices = {
"input_price_per_1m": 3.0,
"output_price_per_1m": 15.0,
"cache_creation_price_per_1m": 3.75,
"cache_read_price_per_1m": 0.3,
}
result = calculator.calculate(usage, prices)
# The OpenAI template has no cache_creation dimension
assert result.cache_creation_cost == 0.0
# But cache_read should still be billed
assert result.cache_read_cost > 0
def test_from_config(self) -> None:
"""Test creating a calculator from config"""
config = {"template": "openai"}
calculator = BillingCalculator.from_config(config)
assert calculator.template_name == "openai"
class TestCalculateRequestCost:
"""Test the convenience function"""
def test_basic_usage(self) -> None:
"""Test basic usage"""
result = calculate_request_cost(
input_tokens=1000,
output_tokens=500,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_price_per_1m=None,
cache_read_price_per_1m=None,
price_per_request=None,
billing_template="claude",
)
assert "input_cost" in result
assert "output_cost" in result
assert "total_cost" in result
assert abs(result["input_cost"] - 0.003) < 0.0001
assert abs(result["output_cost"] - 0.0075) < 0.0001
def test_with_cache(self) -> None:
"""Test with cache"""
result = calculate_request_cost(
input_tokens=1000,
output_tokens=500,
cache_creation_input_tokens=200,
cache_read_input_tokens=300,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_price_per_1m=3.75,
cache_read_price_per_1m=0.3,
price_per_request=None,
billing_template="claude",
)
assert result["cache_creation_cost"] > 0
assert result["cache_read_cost"] > 0
assert result["cache_cost"] == result["cache_creation_cost"] + result["cache_read_cost"]
def test_different_templates(self) -> None:
"""Test different templates"""
prices = {
"input_tokens": 1000,
"output_tokens": 500,
"cache_creation_input_tokens": 200,
"cache_read_input_tokens": 300,
"input_price_per_1m": 3.0,
"output_price_per_1m": 15.0,
"cache_creation_price_per_1m": 3.75,
"cache_read_price_per_1m": 0.3,
"price_per_request": None,
}
# The Claude template has cache_creation
result_claude = calculate_request_cost(**prices, billing_template="claude")
assert result_claude["cache_creation_cost"] > 0
# The OpenAI template has no cache_creation
result_openai = calculate_request_cost(**prices, billing_template="openai")
assert result_openai["cache_creation_cost"] == 0
def test_tiered_pricing_with_total_context(self) -> None:
"""Test tiered pricing with a custom total_input_context"""
tiered_pricing = {
"tiers": [
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
]
}
# Pass a precomputed total_input_context
result = calculate_request_cost(
input_tokens=1000,
output_tokens=500,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_price_per_1m=None,
cache_read_price_per_1m=None,
price_per_request=None,
tiered_pricing=tiered_pricing,
total_input_context=250000,  # precomputed value, above 200k
billing_template="claude",
)
# The second-tier price should be used
assert result["tier_index"] == 1
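As the tiered-pricing tests exercise it, the whole request is priced at the single tier matched by the total input context, rather than marginally across tiers. A minimal sketch of that lookup, with `select_tier` and `input_cost` as illustrative helpers rather than the module's actual API:

```python
from typing import Any, Dict, List


def select_tier(tiers: List[Dict[str, Any]], total_input_context: int) -> int:
    # The first tier whose up_to covers the context wins; up_to=None is the open-ended last tier
    for i, tier in enumerate(tiers):
        up_to = tier.get("up_to")
        if up_to is None or total_input_context <= up_to:
            return i
    return len(tiers) - 1


def input_cost(tiers: List[Dict[str, Any]], total_input_context: int,
               input_tokens: int) -> float:
    # Price all input tokens at the selected tier's per-1M rate
    tier = tiers[select_tier(tiers, total_input_context)]
    return input_tokens * tier["input_price_per_1m"] / 1_000_000
```

With the tiers from the tests, 250,000 input tokens land in tier 1 at $1.5/1M, giving $0.375, matching the assertion above.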
