mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 12:38:31 +08:00
Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f6a6410626 | ||
|
|
835be3d329 | ||
|
|
2395093394 | ||
|
|
28209e1c2a | ||
|
|
00562dd1d4 | ||
|
|
0f78d5cbf3 | ||
|
|
431c6de8d2 | ||
|
|
142e15bbcc | ||
|
|
31acc5c607 | ||
|
|
bfa0a26d41 | ||
|
|
93ab9b6a5e | ||
|
|
35e29d46bd | ||
|
|
465da6f818 | ||
|
|
e5f12fddd9 | ||
|
|
4fa9a1303a | ||
|
|
43f349d415 | ||
|
|
02069954de | ||
|
|
2e15875fed | ||
|
|
b34cfb676d | ||
|
|
3064497636 | ||
|
|
dec681fea0 | ||
|
|
523e27ba9a | ||
|
|
e7db76e581 | ||
|
|
689339117a | ||
|
|
b202765be4 | ||
|
|
3bbf3073df | ||
|
|
f46aaa2182 | ||
|
|
a2f33a6c35 | ||
|
|
b6bd6357ed | ||
|
|
c3a5878b1b | ||
|
|
3e4309eba3 | ||
|
|
414f45aa71 | ||
|
|
ebdc76346f | ||
|
|
64bfa955f4 | ||
|
|
612992fa1f | ||
|
|
c02ac56da8 | ||
|
|
9bfb295238 |
@@ -39,7 +39,18 @@ COPY alembic.ini ./
|
||||
COPY alembic/ ./alembic/
|
||||
|
||||
# Nginx 配置模板
|
||||
# 智能处理 IP:有外层代理头就透传,没有就用直连 IP
|
||||
RUN printf '%s\n' \
|
||||
'map $http_x_real_ip $real_ip {' \
|
||||
' default $http_x_real_ip;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'map $http_x_forwarded_for $forwarded_for {' \
|
||||
' default $http_x_forwarded_for;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'server {' \
|
||||
' listen 80;' \
|
||||
' server_name _;' \
|
||||
@@ -70,8 +81,8 @@ RUN printf '%s\n' \
|
||||
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
|
||||
' proxy_http_version 1.1;' \
|
||||
' proxy_set_header Host $host;' \
|
||||
' proxy_set_header X-Real-IP $remote_addr;' \
|
||||
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
|
||||
' proxy_set_header X-Real-IP $real_ip;' \
|
||||
' proxy_set_header X-Forwarded-For $forwarded_for;' \
|
||||
' proxy_set_header X-Forwarded-Proto $scheme;' \
|
||||
' proxy_set_header Connection "";' \
|
||||
' proxy_set_header Accept $http_accept;' \
|
||||
|
||||
@@ -40,7 +40,18 @@ COPY alembic.ini ./
|
||||
COPY alembic/ ./alembic/
|
||||
|
||||
# Nginx 配置模板
|
||||
# 智能处理 IP:有外层代理头就透传,没有就用直连 IP
|
||||
RUN printf '%s\n' \
|
||||
'map $http_x_real_ip $real_ip {' \
|
||||
' default $http_x_real_ip;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'map $http_x_forwarded_for $forwarded_for {' \
|
||||
' default $http_x_forwarded_for;' \
|
||||
' "" $remote_addr;' \
|
||||
'}' \
|
||||
'' \
|
||||
'server {' \
|
||||
' listen 80;' \
|
||||
' server_name _;' \
|
||||
@@ -71,8 +82,8 @@ RUN printf '%s\n' \
|
||||
' proxy_pass http://127.0.0.1:PORT_PLACEHOLDER;' \
|
||||
' proxy_http_version 1.1;' \
|
||||
' proxy_set_header Host $host;' \
|
||||
' proxy_set_header X-Real-IP $remote_addr;' \
|
||||
' proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;' \
|
||||
' proxy_set_header X-Real-IP $real_ip;' \
|
||||
' proxy_set_header X-Forwarded-For $forwarded_for;' \
|
||||
' proxy_set_header X-Forwarded-Proto $scheme;' \
|
||||
' proxy_set_header Connection "";' \
|
||||
' proxy_set_header Accept $http_accept;' \
|
||||
|
||||
10
README.md
10
README.md
@@ -51,20 +51,20 @@ Aether 是一个自托管的 AI API 网关,为团队和个人提供多租户
|
||||
```bash
|
||||
# 1. 克隆代码
|
||||
git clone https://github.com/fawney19/Aether.git
|
||||
cd aether
|
||||
cd Aether
|
||||
|
||||
# 2. 配置环境变量
|
||||
cp .env.example .env
|
||||
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||
|
||||
# 3. 部署
|
||||
docker-compose up -d
|
||||
docker compose up -d
|
||||
|
||||
# 4. 首次部署时, 初始化数据库
|
||||
./migrate.sh
|
||||
|
||||
# 5. 更新
|
||||
docker-compose pull && docker-compose up -d && ./migrate.sh
|
||||
docker compose pull && docker compose up -d && ./migrate.sh
|
||||
```
|
||||
|
||||
### Docker Compose(本地构建镜像)
|
||||
@@ -72,7 +72,7 @@ docker-compose pull && docker-compose up -d && ./migrate.sh
|
||||
```bash
|
||||
# 1. 克隆代码
|
||||
git clone https://github.com/fawney19/Aether.git
|
||||
cd aether
|
||||
cd Aether
|
||||
|
||||
# 2. 配置环境变量
|
||||
cp .env.example .env
|
||||
@@ -86,7 +86,7 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||
|
||||
```bash
|
||||
# 启动依赖
|
||||
docker-compose -f docker-compose.build.yml up -d postgres redis
|
||||
docker compose -f docker-compose.build.yml up -d postgres redis
|
||||
|
||||
# 后端
|
||||
uv sync
|
||||
|
||||
@@ -30,7 +30,7 @@ from src.models.database import Base
|
||||
config = context.config
|
||||
|
||||
# 从环境变量获取数据库 URL
|
||||
# 优先使用 DATABASE_URL,否则从 DB_PASSWORD 自动构建(与 docker-compose 保持一致)
|
||||
# 优先使用 DATABASE_URL,否则从 DB_PASSWORD 自动构建(与 docker compose 保持一致)
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
if not database_url:
|
||||
db_password = os.getenv("DB_PASSWORD", "")
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"""add ldap authentication support
|
||||
|
||||
Revision ID: c3d4e5f6g7h8
|
||||
Revises: b2c3d4e5f6g7
|
||||
Create Date: 2026-01-01 14:00:00.000000+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'c3d4e5f6g7h8'
|
||||
down_revision = 'b2c3d4e5f6g7'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _type_exists(conn, type_name: str) -> bool:
|
||||
"""检查 PostgreSQL 类型是否存在"""
|
||||
result = conn.execute(
|
||||
text("SELECT 1 FROM pg_type WHERE typname = :name"),
|
||||
{"name": type_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _column_exists(conn, table_name: str, column_name: str) -> bool:
|
||||
"""检查列是否存在"""
|
||||
result = conn.execute(
|
||||
text("""
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = :table AND column_name = :column
|
||||
"""),
|
||||
{"table": table_name, "column": column_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _index_exists(conn, index_name: str) -> bool:
|
||||
"""检查索引是否存在"""
|
||||
result = conn.execute(
|
||||
text("SELECT 1 FROM pg_indexes WHERE indexname = :name"),
|
||||
{"name": index_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def _table_exists(conn, table_name: str) -> bool:
|
||||
"""检查表是否存在"""
|
||||
result = conn.execute(
|
||||
text("""
|
||||
SELECT 1 FROM information_schema.tables
|
||||
WHERE table_name = :name AND table_schema = 'public'
|
||||
"""),
|
||||
{"name": table_name}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""添加 LDAP 认证支持
|
||||
|
||||
1. 创建 authsource 枚举类型
|
||||
2. 在 users 表添加 auth_source 字段和 LDAP 标识字段
|
||||
3. 创建 ldap_configs 表
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 1. 创建 authsource 枚举类型(幂等)
|
||||
if not _type_exists(conn, 'authsource'):
|
||||
conn.execute(text("CREATE TYPE authsource AS ENUM ('local', 'ldap')"))
|
||||
|
||||
# 2. 在 users 表添加字段(幂等)
|
||||
if not _column_exists(conn, 'users', 'auth_source'):
|
||||
op.add_column('users', sa.Column(
|
||||
'auth_source',
|
||||
sa.Enum('local', 'ldap', name='authsource', create_type=False),
|
||||
nullable=False,
|
||||
server_default='local'
|
||||
))
|
||||
|
||||
if not _column_exists(conn, 'users', 'ldap_dn'):
|
||||
op.add_column('users', sa.Column('ldap_dn', sa.String(length=512), nullable=True))
|
||||
|
||||
if not _column_exists(conn, 'users', 'ldap_username'):
|
||||
op.add_column('users', sa.Column('ldap_username', sa.String(length=255), nullable=True))
|
||||
|
||||
# 创建索引(幂等)
|
||||
if not _index_exists(conn, 'ix_users_ldap_dn'):
|
||||
op.create_index('ix_users_ldap_dn', 'users', ['ldap_dn'])
|
||||
|
||||
if not _index_exists(conn, 'ix_users_ldap_username'):
|
||||
op.create_index('ix_users_ldap_username', 'users', ['ldap_username'])
|
||||
|
||||
# 3. 创建 ldap_configs 表(幂等)
|
||||
if not _table_exists(conn, 'ldap_configs'):
|
||||
op.create_table(
|
||||
'ldap_configs',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('server_url', sa.String(length=255), nullable=False),
|
||||
sa.Column('bind_dn', sa.String(length=255), nullable=False),
|
||||
sa.Column('bind_password_encrypted', sa.Text(), nullable=True),
|
||||
sa.Column('base_dn', sa.String(length=255), nullable=False),
|
||||
sa.Column('user_search_filter', sa.String(length=500), nullable=False, server_default='(uid={username})'),
|
||||
sa.Column('username_attr', sa.String(length=50), nullable=False, server_default='uid'),
|
||||
sa.Column('email_attr', sa.String(length=50), nullable=False, server_default='mail'),
|
||||
sa.Column('display_name_attr', sa.String(length=50), nullable=False, server_default='cn'),
|
||||
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'),
|
||||
sa.Column('connect_timeout', sa.Integer(), nullable=False, server_default='10'),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚 LDAP 认证支持
|
||||
|
||||
警告:回滚前请确保:
|
||||
1. 已备份数据库
|
||||
2. 没有 LDAP 用户需要保留
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 检查是否存在 LDAP 用户,防止数据丢失
|
||||
if _column_exists(conn, 'users', 'auth_source'):
|
||||
result = conn.execute(text("SELECT COUNT(*) FROM users WHERE auth_source = 'ldap'"))
|
||||
ldap_user_count = result.scalar()
|
||||
if ldap_user_count and ldap_user_count > 0:
|
||||
raise RuntimeError(
|
||||
f"无法回滚:存在 {ldap_user_count} 个 LDAP 用户。"
|
||||
f"请先删除或转换这些用户,或使用 --force 参数强制回滚(将丢失数据)。"
|
||||
)
|
||||
|
||||
# 1. 删除 ldap_configs 表(幂等)
|
||||
if _table_exists(conn, 'ldap_configs'):
|
||||
op.drop_table('ldap_configs')
|
||||
|
||||
# 2. 删除 users 表的 LDAP 相关字段(幂等)
|
||||
if _index_exists(conn, 'ix_users_ldap_username'):
|
||||
op.drop_index('ix_users_ldap_username', table_name='users')
|
||||
|
||||
if _index_exists(conn, 'ix_users_ldap_dn'):
|
||||
op.drop_index('ix_users_ldap_dn', table_name='users')
|
||||
|
||||
if _column_exists(conn, 'users', 'ldap_username'):
|
||||
op.drop_column('users', 'ldap_username')
|
||||
|
||||
if _column_exists(conn, 'users', 'ldap_dn'):
|
||||
op.drop_column('users', 'ldap_dn')
|
||||
|
||||
if _column_exists(conn, 'users', 'auth_source'):
|
||||
op.drop_column('users', 'auth_source')
|
||||
|
||||
# 3. 删除 authsource 枚举类型(幂等)
|
||||
# 注意:不使用 CASCADE,因为此时所有依赖应该已被删除
|
||||
if _type_exists(conn, 'authsource'):
|
||||
conn.execute(text("DROP TYPE authsource"))
|
||||
@@ -1,7 +1,7 @@
|
||||
# Aether 部署配置 - 本地构建
|
||||
# 使用方法:
|
||||
# 首次构建 base: docker build -f Dockerfile.base -t aether-base:latest .
|
||||
# 启动服务: docker-compose -f docker-compose.build.yml up -d --build
|
||||
# 启动服务: docker compose -f docker-compose.build.yml up -d --build
|
||||
|
||||
services:
|
||||
postgres:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Aether 部署配置 - 使用预构建镜像
|
||||
# 使用方法: docker-compose up -d
|
||||
# 使用方法: docker compose up -d
|
||||
|
||||
services:
|
||||
postgres:
|
||||
|
||||
@@ -13,6 +13,7 @@ export interface UsersExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
users: UserExport[]
|
||||
standalone_keys?: StandaloneKeyExport[]
|
||||
}
|
||||
|
||||
export interface UserExport {
|
||||
@@ -42,15 +43,19 @@ export interface UserApiKeyExport {
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
rate_limit?: number
|
||||
rate_limit?: number | null // null = 无限制
|
||||
concurrent_limit?: number | null
|
||||
force_capabilities?: any
|
||||
is_active: boolean
|
||||
expires_at?: string | null
|
||||
auto_delete_on_expiry?: boolean
|
||||
total_requests?: number
|
||||
total_cost_usd?: number
|
||||
}
|
||||
|
||||
// 独立余额 Key 导出结构(与 UserApiKeyExport 相同,但不包含 is_standalone)
|
||||
export type StandaloneKeyExport = Omit<UserApiKeyExport, 'is_standalone'>
|
||||
|
||||
export interface GlobalModelExport {
|
||||
name: string
|
||||
display_name: string
|
||||
@@ -155,6 +160,44 @@ export interface EmailTemplateResetResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// LDAP 配置响应
|
||||
export interface LdapConfigResponse {
|
||||
server_url: string | null
|
||||
bind_dn: string | null
|
||||
base_dn: string | null
|
||||
has_bind_password: boolean
|
||||
user_search_filter: string
|
||||
username_attr: string
|
||||
email_attr: string
|
||||
display_name_attr: string
|
||||
is_enabled: boolean
|
||||
is_exclusive: boolean
|
||||
use_starttls: boolean
|
||||
connect_timeout: number
|
||||
}
|
||||
|
||||
// LDAP 配置更新请求
|
||||
export interface LdapConfigUpdateRequest {
|
||||
server_url: string
|
||||
bind_dn: string
|
||||
bind_password?: string
|
||||
base_dn: string
|
||||
user_search_filter?: string
|
||||
username_attr?: string
|
||||
email_attr?: string
|
||||
display_name_attr?: string
|
||||
is_enabled?: boolean
|
||||
is_exclusive?: boolean
|
||||
use_starttls?: boolean
|
||||
connect_timeout?: number
|
||||
}
|
||||
|
||||
// LDAP 连接测试响应
|
||||
export interface LdapTestResponse {
|
||||
success: boolean
|
||||
message: string
|
||||
}
|
||||
|
||||
// Provider 模型查询响应
|
||||
export interface ProviderModelsQueryResponse {
|
||||
success: boolean
|
||||
@@ -189,6 +232,7 @@ export interface UsersImportResponse {
|
||||
stats: {
|
||||
users: { created: number; updated: number; skipped: number }
|
||||
api_keys: { created: number; skipped: number }
|
||||
standalone_keys?: { created: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
@@ -220,7 +264,7 @@ export interface AdminApiKey {
|
||||
total_requests?: number
|
||||
total_tokens?: number
|
||||
total_cost_usd?: number
|
||||
rate_limit?: number
|
||||
rate_limit?: number | null // null = 无限制
|
||||
allowed_providers?: string[] | null // 允许的提供商列表
|
||||
allowed_api_formats?: string[] | null // 允许的 API 格式列表
|
||||
allowed_models?: string[] | null // 允许的模型列表
|
||||
@@ -236,8 +280,8 @@ export interface CreateStandaloneApiKeyRequest {
|
||||
allowed_providers?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
rate_limit?: number
|
||||
expire_days?: number | null // null = 永不过期
|
||||
rate_limit?: number | null // null = 无限制
|
||||
expires_at?: string | null // ISO 日期字符串,如 "2025-12-31",null = 永不过期
|
||||
initial_balance_usd: number // 初始余额,必须设置
|
||||
auto_delete_on_expiry?: boolean // 过期后是否自动删除
|
||||
}
|
||||
@@ -473,5 +517,35 @@ export const adminApi = {
|
||||
`/api/admin/system/email/templates/${templateType}/reset`
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 获取系统版本信息
|
||||
async getSystemVersion(): Promise<{ version: string }> {
|
||||
const response = await apiClient.get<{ version: string }>(
|
||||
'/api/admin/system/version'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// LDAP 配置相关
|
||||
// 获取 LDAP 配置
|
||||
async getLdapConfig(): Promise<LdapConfigResponse> {
|
||||
const response = await apiClient.get<LdapConfigResponse>('/api/admin/ldap/config')
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 更新 LDAP 配置
|
||||
async updateLdapConfig(config: LdapConfigUpdateRequest): Promise<{ message: string }> {
|
||||
const response = await apiClient.put<{ message: string }>(
|
||||
'/api/admin/ldap/config',
|
||||
config
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 测试 LDAP 连接
|
||||
async testLdapConnection(config: LdapConfigUpdateRequest): Promise<LdapTestResponse> {
|
||||
const response = await apiClient.post<LdapTestResponse>('/api/admin/ldap/test', config)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { log } from '@/utils/logger'
|
||||
export interface LoginRequest {
|
||||
email: string
|
||||
password: string
|
||||
auth_type?: 'local' | 'ldap'
|
||||
}
|
||||
|
||||
export interface LoginResponse {
|
||||
@@ -81,6 +82,12 @@ export interface RegistrationSettingsResponse {
|
||||
require_email_verification: boolean
|
||||
}
|
||||
|
||||
export interface AuthSettingsResponse {
|
||||
local_enabled: boolean
|
||||
ldap_enabled: boolean
|
||||
ldap_exclusive: boolean
|
||||
}
|
||||
|
||||
export interface User {
|
||||
id: string // UUID
|
||||
username: string
|
||||
@@ -173,5 +180,10 @@ export const authApi = {
|
||||
{ email }
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
async getAuthSettings(): Promise<AuthSettingsResponse> {
|
||||
const response = await apiClient.get<AuthSettingsResponse>('/api/auth/settings')
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,11 @@ export interface UsageRecordDetail {
|
||||
cache_creation_price_per_1m?: number
|
||||
cache_read_price_per_1m?: number
|
||||
price_per_request?: number // 按次计费价格
|
||||
api_key?: {
|
||||
id: string
|
||||
name: string
|
||||
display: string
|
||||
}
|
||||
}
|
||||
|
||||
// 模型统计接口
|
||||
@@ -75,6 +80,16 @@ export interface ModelSummary {
|
||||
actual_total_cost_usd?: number // 倍率消耗(仅管理员可见)
|
||||
}
|
||||
|
||||
// 提供商统计接口
|
||||
export interface ProviderSummary {
|
||||
provider: string
|
||||
requests: number
|
||||
total_tokens: number
|
||||
total_cost_usd: number
|
||||
success_rate: number | null
|
||||
avg_response_time_ms: number | null
|
||||
}
|
||||
|
||||
// 使用统计响应接口
|
||||
export interface UsageResponse {
|
||||
total_requests: number
|
||||
@@ -87,6 +102,13 @@ export interface UsageResponse {
|
||||
quota_usd: number | null
|
||||
used_usd: number
|
||||
summary_by_model: ModelSummary[]
|
||||
summary_by_provider?: ProviderSummary[]
|
||||
pagination?: {
|
||||
total: number
|
||||
limit: number
|
||||
offset: number
|
||||
has_more: boolean
|
||||
}
|
||||
records: UsageRecordDetail[]
|
||||
activity_heatmap?: ActivityHeatmap | null
|
||||
}
|
||||
@@ -175,6 +197,9 @@ export const meApi = {
|
||||
async getUsage(params?: {
|
||||
start_date?: string
|
||||
end_date?: string
|
||||
search?: string // 通用搜索:密钥名、模型名
|
||||
limit?: number
|
||||
offset?: number
|
||||
}): Promise<UsageResponse> {
|
||||
const response = await apiClient.get<UsageResponse>('/api/users/me/usage', { params })
|
||||
return response.data
|
||||
@@ -184,11 +209,12 @@ export const meApi = {
|
||||
async getActiveRequests(ids?: string): Promise<{
|
||||
requests: Array<{
|
||||
id: string
|
||||
status: string
|
||||
status: 'pending' | 'streaming' | 'completed' | 'failed'
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
cost: number
|
||||
response_time_ms: number | null
|
||||
first_byte_time_ms: number | null
|
||||
}>
|
||||
}> {
|
||||
const params = ids ? { ids } : {}
|
||||
@@ -267,5 +293,14 @@ export const meApi = {
|
||||
}> {
|
||||
const response = await apiClient.get('/api/users/me/usage/interval-timeline', { params })
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取活跃度热力图数据(用户)
|
||||
* 后端已缓存5分钟
|
||||
*/
|
||||
async getActivityHeatmap(): Promise<ActivityHeatmap> {
|
||||
const response = await apiClient.get<ActivityHeatmap>('/api/users/me/usage/heatmap')
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,10 +192,17 @@ export async function getModelsDevList(officialOnly: boolean = true): Promise<Mo
|
||||
}
|
||||
}
|
||||
|
||||
// 按 provider 名称和模型名称排序
|
||||
// 按 provider 名称排序,provider 中的模型按 release_date 从近到远排序
|
||||
items.sort((a, b) => {
|
||||
const providerCompare = a.providerName.localeCompare(b.providerName)
|
||||
if (providerCompare !== 0) return providerCompare
|
||||
|
||||
// 模型按 release_date 从近到远排序(没有日期的排到最后)
|
||||
const aDate = a.releaseDate ? new Date(a.releaseDate).getTime() : 0
|
||||
const bDate = b.releaseDate ? new Date(b.releaseDate).getTime() : 0
|
||||
if (aDate !== bDate) return bDate - aDate // 降序:新的在前
|
||||
|
||||
// 日期相同或都没有日期时,按模型名称排序
|
||||
return a.modelName.localeCompare(b.modelName)
|
||||
})
|
||||
|
||||
|
||||
@@ -164,6 +164,7 @@ export const usageApi = {
|
||||
async getAllUsageRecords(params?: {
|
||||
start_date?: string
|
||||
end_date?: string
|
||||
search?: string // 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
user_id?: string // UUID
|
||||
username?: string
|
||||
model?: string
|
||||
@@ -193,10 +194,22 @@ export const usageApi = {
|
||||
output_tokens: number
|
||||
cost: number
|
||||
response_time_ms: number | null
|
||||
first_byte_time_ms: number | null
|
||||
provider?: string | null
|
||||
api_key_name?: string | null
|
||||
}>
|
||||
}> {
|
||||
const params = ids?.length ? { ids: ids.join(',') } : {}
|
||||
const response = await apiClient.get('/api/admin/usage/active', { params })
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 获取活跃度热力图数据(管理员)
|
||||
* 后端已缓存5分钟
|
||||
*/
|
||||
async getActivityHeatmap(): Promise<ActivityHeatmap> {
|
||||
const response = await apiClient.get<ActivityHeatmap>('/api/admin/usage/heatmap')
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
117
frontend/src/components/common/ModelMultiSelect.vue
Normal file
117
frontend/src/components/common/ModelMultiSelect.vue
Normal file
@@ -0,0 +1,117 @@
|
||||
<template>
|
||||
<div class="space-y-2">
|
||||
<Label class="text-sm font-medium">允许的模型</Label>
|
||||
<div class="relative">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
||||
@click="isOpen = !isOpen"
|
||||
>
|
||||
<span :class="modelValue.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||
{{ modelValue.length ? `已选择 ${modelValue.length} 个` : '全部可用' }}
|
||||
<span
|
||||
v-if="invalidModels.length"
|
||||
class="text-destructive"
|
||||
>({{ invalidModels.length }} 个已失效)</span>
|
||||
</span>
|
||||
<ChevronDown
|
||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
||||
:class="isOpen ? 'rotate-180' : ''"
|
||||
/>
|
||||
</button>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 z-[80]"
|
||||
@click.stop="isOpen = false"
|
||||
/>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
||||
>
|
||||
<!-- 失效模型(置顶显示,只能取消选择) -->
|
||||
<div
|
||||
v-for="modelName in invalidModels"
|
||||
:key="modelName"
|
||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer bg-destructive/5"
|
||||
@click="removeModel(modelName)"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="true"
|
||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||
@click.stop
|
||||
@change="removeModel(modelName)"
|
||||
>
|
||||
<span class="text-sm text-destructive">{{ modelName }}</span>
|
||||
<span class="text-xs text-destructive/70">(已失效)</span>
|
||||
</div>
|
||||
<!-- 有效模型 -->
|
||||
<div
|
||||
v-for="model in models"
|
||||
:key="model.name"
|
||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
||||
@click="toggleModel(model.name)"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="modelValue.includes(model.name)"
|
||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||
@click.stop
|
||||
@change="toggleModel(model.name)"
|
||||
>
|
||||
<span class="text-sm">{{ model.name }}</span>
|
||||
</div>
|
||||
<div
|
||||
v-if="models.length === 0 && invalidModels.length === 0"
|
||||
class="px-3 py-2 text-sm text-muted-foreground"
|
||||
>
|
||||
暂无可用模型
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { Label } from '@/components/ui'
|
||||
import { ChevronDown } from 'lucide-vue-next'
|
||||
import { useInvalidModels } from '@/composables/useInvalidModels'
|
||||
|
||||
export interface ModelWithName {
|
||||
name: string
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: string[]
|
||||
models: ModelWithName[]
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: string[]]
|
||||
}>()
|
||||
|
||||
const isOpen = ref(false)
|
||||
|
||||
// 检测失效模型
|
||||
const { invalidModels } = useInvalidModels(
|
||||
computed(() => props.modelValue),
|
||||
computed(() => props.models)
|
||||
)
|
||||
|
||||
function toggleModel(name: string) {
|
||||
const newValue = [...props.modelValue]
|
||||
const index = newValue.indexOf(name)
|
||||
if (index === -1) {
|
||||
newValue.push(name)
|
||||
} else {
|
||||
newValue.splice(index, 1)
|
||||
}
|
||||
emit('update:modelValue', newValue)
|
||||
}
|
||||
|
||||
function removeModel(name: string) {
|
||||
const newValue = props.modelValue.filter(m => m !== name)
|
||||
emit('update:modelValue', newValue)
|
||||
}
|
||||
</script>
|
||||
@@ -7,3 +7,6 @@
|
||||
export { default as EmptyState } from './EmptyState.vue'
|
||||
export { default as AlertDialog } from './AlertDialog.vue'
|
||||
export { default as LoadingState } from './LoadingState.vue'
|
||||
|
||||
// 表单组件
|
||||
export { default as ModelMultiSelect } from './ModelMultiSelect.vue'
|
||||
|
||||
13
frontend/src/components/icons/GithubIcon.vue
Normal file
13
frontend/src/components/icons/GithubIcon.vue
Normal file
@@ -0,0 +1,13 @@
|
||||
<template>
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4" />
|
||||
<path d="M9 18c-4.51 2-5-2-7-2" />
|
||||
</svg>
|
||||
</template>
|
||||
34
frontend/src/composables/useInvalidModels.ts
Normal file
34
frontend/src/composables/useInvalidModels.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { computed, type Ref, type ComputedRef } from 'vue'
|
||||
|
||||
/**
|
||||
* 检测失效模型的 composable
|
||||
*
|
||||
* 用于检测 allowed_models 中已不存在于 globalModels 的模型名称,
|
||||
* 这些模型可能已被删除但引用未清理。
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const { invalidModels } = useInvalidModels(
|
||||
* computed(() => form.value.allowed_models),
|
||||
* globalModels
|
||||
* )
|
||||
* ```
|
||||
*/
|
||||
export interface ModelWithName {
|
||||
name: string
|
||||
}
|
||||
|
||||
export function useInvalidModels<T extends ModelWithName>(
|
||||
allowedModels: Ref<string[]> | ComputedRef<string[]>,
|
||||
globalModels: Ref<T[]>
|
||||
): { invalidModels: ComputedRef<string[]> } {
|
||||
const validModelNames = computed(() =>
|
||||
new Set(globalModels.value.map(m => m.name))
|
||||
)
|
||||
|
||||
const invalidModels = computed(() =>
|
||||
allowedModels.value.filter(name => !validModelNames.value.has(name))
|
||||
)
|
||||
|
||||
return { invalidModels }
|
||||
}
|
||||
@@ -79,45 +79,45 @@
|
||||
|
||||
<div class="space-y-2">
|
||||
<Label
|
||||
for="form-expire-days"
|
||||
for="form-expires-at"
|
||||
class="text-sm font-medium"
|
||||
>有效期设置</Label>
|
||||
<div class="flex items-center gap-2">
|
||||
<Input
|
||||
id="form-expire-days"
|
||||
:model-value="form.expire_days ?? ''"
|
||||
type="number"
|
||||
min="1"
|
||||
max="3650"
|
||||
placeholder="天数"
|
||||
:class="form.never_expire ? 'flex-1 h-9 opacity-50' : 'flex-1 h-9'"
|
||||
:disabled="form.never_expire"
|
||||
@update:model-value="(v) => form.expire_days = parseNumberInput(v, { min: 1, max: 3650 })"
|
||||
/>
|
||||
<label class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap">
|
||||
<input
|
||||
v-model="form.never_expire"
|
||||
type="checkbox"
|
||||
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
|
||||
@change="onNeverExpireChange"
|
||||
<div class="relative flex-1">
|
||||
<Input
|
||||
id="form-expires-at"
|
||||
:model-value="form.expires_at || ''"
|
||||
type="date"
|
||||
:min="minExpiryDate"
|
||||
class="h-9 pr-8"
|
||||
:placeholder="form.expires_at ? '' : '永不过期'"
|
||||
@update:model-value="(v) => form.expires_at = v || undefined"
|
||||
/>
|
||||
<button
|
||||
v-if="form.expires_at"
|
||||
type="button"
|
||||
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground"
|
||||
title="清空(永不过期)"
|
||||
@click="clearExpiryDate"
|
||||
>
|
||||
永不过期
|
||||
</label>
|
||||
<X class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
<label
|
||||
class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap"
|
||||
:class="form.never_expire ? 'opacity-50' : ''"
|
||||
:class="!form.expires_at ? 'opacity-50 cursor-not-allowed' : ''"
|
||||
>
|
||||
<input
|
||||
v-model="form.auto_delete_on_expiry"
|
||||
type="checkbox"
|
||||
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
|
||||
:disabled="form.never_expire"
|
||||
:disabled="!form.expires_at"
|
||||
>
|
||||
到期删除
|
||||
</label>
|
||||
</div>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
不勾选"到期删除"则仅禁用
|
||||
{{ form.expires_at ? '到期后' + (form.auto_delete_on_expiry ? '自动删除' : '仅禁用') + '(当天 23:59 失效)' : '留空表示永不过期' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -244,55 +244,10 @@
|
||||
</div>
|
||||
|
||||
<!-- 模型多选下拉框 -->
|
||||
<div class="space-y-2">
|
||||
<Label class="text-sm font-medium">允许的模型</Label>
|
||||
<div class="relative">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
||||
@click="modelDropdownOpen = !modelDropdownOpen"
|
||||
>
|
||||
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length} 个` : '全部可用' }}
|
||||
</span>
|
||||
<ChevronDown
|
||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
||||
:class="modelDropdownOpen ? 'rotate-180' : ''"
|
||||
/>
|
||||
</button>
|
||||
<div
|
||||
v-if="modelDropdownOpen"
|
||||
class="fixed inset-0 z-[80]"
|
||||
@click.stop="modelDropdownOpen = false"
|
||||
/>
|
||||
<div
|
||||
v-if="modelDropdownOpen"
|
||||
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
||||
>
|
||||
<div
|
||||
v-for="model in globalModels"
|
||||
:key="model.name"
|
||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
||||
@click="toggleSelection('allowed_models', model.name)"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.allowed_models.includes(model.name)"
|
||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||
@click.stop
|
||||
@change="toggleSelection('allowed_models', model.name)"
|
||||
>
|
||||
<span class="text-sm">{{ model.name }}</span>
|
||||
</div>
|
||||
<div
|
||||
v-if="globalModels.length === 0"
|
||||
class="px-3 py-2 text-sm text-muted-foreground"
|
||||
>
|
||||
暂无可用模型
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ModelMultiSelect
|
||||
v-model="form.allowed_models"
|
||||
:models="globalModels"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
@@ -325,8 +280,9 @@ import {
|
||||
Input,
|
||||
Label,
|
||||
} from '@/components/ui'
|
||||
import { Plus, SquarePen, Key, Shield, ChevronDown } from 'lucide-vue-next'
|
||||
import { Plus, SquarePen, Key, Shield, ChevronDown, X } from 'lucide-vue-next'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import { ModelMultiSelect } from '@/components/common'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { getGlobalModels } from '@/api/global-models'
|
||||
import { adminApi } from '@/api/admin'
|
||||
@@ -338,8 +294,7 @@ export interface StandaloneKeyFormData {
|
||||
id?: string
|
||||
name: string
|
||||
initial_balance_usd?: number
|
||||
expire_days?: number
|
||||
never_expire: boolean
|
||||
expires_at?: string // ISO 日期字符串,如 "2025-12-31",undefined = 永不过期
|
||||
rate_limit?: number
|
||||
auto_delete_on_expiry: boolean
|
||||
allowed_providers: string[]
|
||||
@@ -363,7 +318,6 @@ const saving = ref(false)
|
||||
// 下拉框状态
|
||||
const providerDropdownOpen = ref(false)
|
||||
const apiFormatDropdownOpen = ref(false)
|
||||
const modelDropdownOpen = ref(false)
|
||||
|
||||
// 选项数据
|
||||
const providers = ref<ProviderWithEndpointsSummary[]>([])
|
||||
@@ -374,8 +328,7 @@ const allApiFormats = ref<string[]>([])
|
||||
const form = ref<StandaloneKeyFormData>({
|
||||
name: '',
|
||||
initial_balance_usd: 10,
|
||||
expire_days: undefined,
|
||||
never_expire: true,
|
||||
expires_at: undefined,
|
||||
rate_limit: undefined,
|
||||
auto_delete_on_expiry: false,
|
||||
allowed_providers: [],
|
||||
@@ -383,12 +336,18 @@ const form = ref<StandaloneKeyFormData>({
|
||||
allowed_models: []
|
||||
})
|
||||
|
||||
// 计算最小可选日期(明天)
|
||||
const minExpiryDate = computed(() => {
|
||||
const tomorrow = new Date()
|
||||
tomorrow.setDate(tomorrow.getDate() + 1)
|
||||
return tomorrow.toISOString().split('T')[0]
|
||||
})
|
||||
|
||||
function resetForm() {
|
||||
form.value = {
|
||||
name: '',
|
||||
initial_balance_usd: 10,
|
||||
expire_days: undefined,
|
||||
never_expire: true,
|
||||
expires_at: undefined,
|
||||
rate_limit: undefined,
|
||||
auto_delete_on_expiry: false,
|
||||
allowed_providers: [],
|
||||
@@ -397,7 +356,6 @@ function resetForm() {
|
||||
}
|
||||
providerDropdownOpen.value = false
|
||||
apiFormatDropdownOpen.value = false
|
||||
modelDropdownOpen.value = false
|
||||
}
|
||||
|
||||
function loadKeyData() {
|
||||
@@ -406,8 +364,7 @@ function loadKeyData() {
|
||||
id: props.apiKey.id,
|
||||
name: props.apiKey.name || '',
|
||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||
expire_days: props.apiKey.expire_days,
|
||||
never_expire: props.apiKey.never_expire,
|
||||
expires_at: props.apiKey.expires_at,
|
||||
rate_limit: props.apiKey.rate_limit,
|
||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||
allowed_providers: props.apiKey.allowed_providers || [],
|
||||
@@ -452,12 +409,10 @@ function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'a
|
||||
}
|
||||
}
|
||||
|
||||
// 永不过期切换
|
||||
function onNeverExpireChange() {
|
||||
if (form.value.never_expire) {
|
||||
form.value.expire_days = undefined
|
||||
form.value.auto_delete_on_expiry = false
|
||||
}
|
||||
// 清空过期日期(同时清空到期删除选项)
|
||||
function clearExpiryDate() {
|
||||
form.value.expires_at = undefined
|
||||
form.value.auto_delete_on_expiry = false
|
||||
}
|
||||
|
||||
// 提交表单
|
||||
|
||||
@@ -66,19 +66,59 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 认证方式切换 -->
|
||||
<div
|
||||
v-if="showAuthTypeTabs"
|
||||
class="auth-type-tabs"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
:class="['auth-tab', authType === 'local' && 'active']"
|
||||
@click="authType = 'local'"
|
||||
>
|
||||
本地登录
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
:class="['auth-tab', authType === 'ldap' && 'active']"
|
||||
@click="authType = 'ldap'"
|
||||
>
|
||||
LDAP 登录
|
||||
</button>
|
||||
</div>
|
||||
|
||||
|
||||
<!-- 登录表单 -->
|
||||
<form
|
||||
class="space-y-4"
|
||||
@submit.prevent="handleLogin"
|
||||
>
|
||||
<div class="space-y-2">
|
||||
<Label for="login-email">邮箱</Label>
|
||||
<div class="flex items-center justify-between">
|
||||
<Label for="login-email">{{ emailLabel }}</Label>
|
||||
<button
|
||||
v-if="ldapExclusive && authType === 'ldap'"
|
||||
type="button"
|
||||
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
|
||||
@click="authType = 'local'"
|
||||
>
|
||||
管理员本地登录
|
||||
</button>
|
||||
<button
|
||||
v-if="ldapExclusive && authType === 'local'"
|
||||
type="button"
|
||||
class="text-xs text-muted-foreground/60 hover:text-muted-foreground transition-colors"
|
||||
@click="authType = 'ldap'"
|
||||
>
|
||||
返回 LDAP 登录
|
||||
</button>
|
||||
</div>
|
||||
<Input
|
||||
id="login-email"
|
||||
v-model="form.email"
|
||||
type="email"
|
||||
type="text"
|
||||
required
|
||||
placeholder="hello@example.com"
|
||||
placeholder="username 或 email"
|
||||
autocomplete="off"
|
||||
/>
|
||||
</div>
|
||||
@@ -180,6 +220,30 @@ const showRegisterDialog = ref(false)
|
||||
const requireEmailVerification = ref(false)
|
||||
const allowRegistration = ref(false) // 由系统配置控制,默认关闭
|
||||
|
||||
// LDAP authentication settings
|
||||
const PREFERRED_AUTH_TYPE_KEY = 'aether_preferred_auth_type'
|
||||
function getStoredAuthType(): 'local' | 'ldap' {
|
||||
const stored = localStorage.getItem(PREFERRED_AUTH_TYPE_KEY)
|
||||
return (stored === 'ldap' || stored === 'local') ? stored : 'local'
|
||||
}
|
||||
const authType = ref<'local' | 'ldap'>(getStoredAuthType())
|
||||
const localEnabled = ref(true)
|
||||
const ldapEnabled = ref(false)
|
||||
const ldapExclusive = ref(false)
|
||||
|
||||
// 保存用户的认证类型偏好
|
||||
watch(authType, (newType) => {
|
||||
localStorage.setItem(PREFERRED_AUTH_TYPE_KEY, newType)
|
||||
})
|
||||
|
||||
const showAuthTypeTabs = computed(() => {
|
||||
return localEnabled.value && ldapEnabled.value && !ldapExclusive.value
|
||||
})
|
||||
|
||||
const emailLabel = computed(() => {
|
||||
return '用户名/邮箱'
|
||||
})
|
||||
|
||||
watch(() => props.modelValue, (val) => {
|
||||
isOpen.value = val
|
||||
// 打开对话框时重置表单
|
||||
@@ -212,7 +276,7 @@ async function handleLogin() {
|
||||
return
|
||||
}
|
||||
|
||||
const success = await authStore.login(form.value.email, form.value.password)
|
||||
const success = await authStore.login(form.value.email, form.value.password, authType.value)
|
||||
if (success) {
|
||||
showSuccess('登录成功,正在跳转...')
|
||||
|
||||
@@ -246,16 +310,84 @@ function handleSwitchToLogin() {
|
||||
isOpen.value = true
|
||||
}
|
||||
|
||||
// Load registration settings on mount
|
||||
// Load authentication and registration settings on mount
|
||||
onMounted(async () => {
|
||||
try {
|
||||
const settings = await authApi.getRegistrationSettings()
|
||||
allowRegistration.value = !!settings.enable_registration
|
||||
requireEmailVerification.value = !!settings.require_email_verification
|
||||
// Load registration settings
|
||||
const regSettings = await authApi.getRegistrationSettings()
|
||||
allowRegistration.value = !!regSettings.enable_registration
|
||||
requireEmailVerification.value = !!regSettings.require_email_verification
|
||||
|
||||
// Load authentication settings
|
||||
const authSettings = await authApi.getAuthSettings()
|
||||
localEnabled.value = authSettings.local_enabled
|
||||
ldapEnabled.value = authSettings.ldap_enabled
|
||||
ldapExclusive.value = authSettings.ldap_exclusive
|
||||
// 若仅允许 LDAP 登录,则禁用本地注册入口
|
||||
if (ldapExclusive.value) {
|
||||
allowRegistration.value = false
|
||||
}
|
||||
|
||||
// Set default auth type based on settings
|
||||
if (authSettings.ldap_exclusive) {
|
||||
authType.value = 'ldap'
|
||||
} else if (!authSettings.local_enabled && authSettings.ldap_enabled) {
|
||||
authType.value = 'ldap'
|
||||
} else {
|
||||
authType.value = 'local'
|
||||
}
|
||||
} catch (error) {
|
||||
// If获取失败,保持默认:关闭注册 & 关闭邮箱验证
|
||||
// If获取失败,保持默认:关闭注册 & 关闭邮箱验证 & 使用本地认证
|
||||
allowRegistration.value = false
|
||||
requireEmailVerification.value = false
|
||||
localEnabled.value = true
|
||||
ldapEnabled.value = false
|
||||
ldapExclusive.value = false
|
||||
authType.value = 'local'
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.auth-type-tabs {
|
||||
display: flex;
|
||||
border-bottom: 1px solid hsl(var(--border));
|
||||
}
|
||||
|
||||
.auth-tab {
|
||||
flex: 1;
|
||||
padding: 0.625rem 1rem;
|
||||
font-size: 0.875rem;
|
||||
font-weight: 500;
|
||||
color: hsl(var(--muted-foreground));
|
||||
background: transparent;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
transition: color 0.15s ease;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.auth-tab::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
bottom: -1px;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 2px;
|
||||
background: transparent;
|
||||
transition: background 0.15s ease;
|
||||
}
|
||||
|
||||
.auth-tab:hover:not(.active) {
|
||||
color: hsl(var(--foreground));
|
||||
}
|
||||
|
||||
.auth-tab.active {
|
||||
color: var(--book-cloth);
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.auth-tab.active::after {
|
||||
background: var(--book-cloth);
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -18,8 +18,22 @@
|
||||
<span class="flex-shrink-0">多</span>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="isLoading"
|
||||
class="h-full min-h-[160px] flex items-center justify-center text-sm text-muted-foreground"
|
||||
>
|
||||
<Loader2 class="h-5 w-5 animate-spin mr-2" />
|
||||
加载中...
|
||||
</div>
|
||||
<div
|
||||
v-else-if="hasError"
|
||||
class="h-full min-h-[160px] flex items-center justify-center text-sm text-destructive"
|
||||
>
|
||||
<AlertCircle class="h-4 w-4 mr-1.5" />
|
||||
加载失败
|
||||
</div>
|
||||
<ActivityHeatmap
|
||||
v-if="hasData"
|
||||
v-else-if="hasData"
|
||||
:data="data"
|
||||
:show-header="false"
|
||||
/>
|
||||
@@ -34,6 +48,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { Loader2, AlertCircle } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import ActivityHeatmap from '@/components/stats/ActivityHeatmap.vue'
|
||||
import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
|
||||
@@ -41,6 +56,8 @@ import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
|
||||
const props = defineProps<{
|
||||
data: ActivityHeatmapData | null
|
||||
title: string
|
||||
isLoading?: boolean
|
||||
hasError?: boolean
|
||||
}>()
|
||||
|
||||
const legendLevels = [0.08, 0.25, 0.45, 0.65, 0.85]
|
||||
|
||||
@@ -32,6 +32,17 @@
|
||||
<!-- 分隔线 -->
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 通用搜索 -->
|
||||
<div class="relative">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
|
||||
<Input
|
||||
id="usage-records-search"
|
||||
v-model="localSearch"
|
||||
:placeholder="isAdmin ? '搜索用户/密钥/模型/提供商' : '搜索密钥/模型'"
|
||||
class="w-32 sm:w-48 h-8 text-xs border-border/60 pl-8"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- 用户筛选(仅管理员可见) -->
|
||||
<Select
|
||||
v-if="isAdmin && availableUsers.length > 0"
|
||||
@@ -164,6 +175,12 @@
|
||||
>
|
||||
用户
|
||||
</TableHead>
|
||||
<TableHead
|
||||
v-if="!isAdmin"
|
||||
class="h-12 font-semibold w-[100px]"
|
||||
>
|
||||
密钥
|
||||
</TableHead>
|
||||
<TableHead class="h-12 font-semibold w-[140px]">
|
||||
模型
|
||||
</TableHead>
|
||||
@@ -196,7 +213,7 @@
|
||||
<TableBody>
|
||||
<TableRow v-if="records.length === 0">
|
||||
<TableCell
|
||||
:colspan="isAdmin ? 9 : 7"
|
||||
:colspan="isAdmin ? 9 : 8"
|
||||
class="text-center py-12 text-muted-foreground"
|
||||
>
|
||||
暂无请求记录
|
||||
@@ -218,7 +235,34 @@
|
||||
class="py-4 w-[100px] truncate"
|
||||
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
|
||||
>
|
||||
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
||||
<div class="flex flex-col text-xs gap-0.5">
|
||||
<span class="truncate">
|
||||
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
||||
</span>
|
||||
<span
|
||||
v-if="record.api_key?.name"
|
||||
class="text-muted-foreground truncate"
|
||||
:title="record.api_key.name"
|
||||
>
|
||||
{{ record.api_key.name }}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<!-- 用户页面的密钥列 -->
|
||||
<TableCell
|
||||
v-if="!isAdmin"
|
||||
class="py-4 w-[100px]"
|
||||
:title="record.api_key?.name || '-'"
|
||||
>
|
||||
<div class="flex flex-col text-xs gap-0.5">
|
||||
<span class="truncate">{{ record.api_key?.name || '-' }}</span>
|
||||
<span
|
||||
v-if="record.api_key?.display"
|
||||
class="text-muted-foreground truncate"
|
||||
>
|
||||
{{ record.api_key.display }}
|
||||
</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell
|
||||
class="font-medium py-4 w-[140px]"
|
||||
@@ -438,6 +482,7 @@ import {
|
||||
TableCard,
|
||||
Badge,
|
||||
Button,
|
||||
Input,
|
||||
Select,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
@@ -451,7 +496,7 @@ import {
|
||||
TableCell,
|
||||
Pagination,
|
||||
} from '@/components/ui'
|
||||
import { RefreshCcw } from 'lucide-vue-next'
|
||||
import { RefreshCcw, Search } from 'lucide-vue-next'
|
||||
import { formatTokens, formatCurrency } from '@/utils/format'
|
||||
import { formatDateTime } from '../composables'
|
||||
import { useRowClick } from '@/composables/useRowClick'
|
||||
@@ -471,6 +516,7 @@ const props = defineProps<{
|
||||
// 时间段
|
||||
selectedPeriod: string
|
||||
// 筛选
|
||||
filterSearch: string
|
||||
filterUser: string
|
||||
filterModel: string
|
||||
filterProvider: string
|
||||
@@ -489,6 +535,7 @@ const props = defineProps<{
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:selectedPeriod': [value: string]
|
||||
'update:filterSearch': [value: string]
|
||||
'update:filterUser': [value: string]
|
||||
'update:filterModel': [value: string]
|
||||
'update:filterProvider': [value: string]
|
||||
@@ -507,6 +554,23 @@ const filterModelSelectOpen = ref(false)
|
||||
const filterProviderSelectOpen = ref(false)
|
||||
const filterStatusSelectOpen = ref(false)
|
||||
|
||||
// 通用搜索(输入防抖)
|
||||
const localSearch = ref(props.filterSearch)
|
||||
let searchDebounceTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
watch(() => props.filterSearch, (value) => {
|
||||
if (value !== localSearch.value) {
|
||||
localSearch.value = value
|
||||
}
|
||||
})
|
||||
|
||||
watch(localSearch, (value) => {
|
||||
if (searchDebounceTimer) clearTimeout(searchDebounceTimer)
|
||||
searchDebounceTimer = setTimeout(() => {
|
||||
emit('update:filterSearch', value)
|
||||
}, 300)
|
||||
})
|
||||
|
||||
// 动态计时器相关
|
||||
const now = ref(Date.now())
|
||||
let timerInterval: ReturnType<typeof setInterval> | null = null
|
||||
@@ -574,6 +638,10 @@ function handleRowClick(event: MouseEvent, id: string) {
|
||||
// 组件卸载时清理
|
||||
onUnmounted(() => {
|
||||
stopTimer()
|
||||
if (searchDebounceTimer) {
|
||||
clearTimeout(searchDebounceTimer)
|
||||
searchDebounceTimer = null
|
||||
}
|
||||
})
|
||||
|
||||
// 格式化 API 格式显示名称
|
||||
|
||||
@@ -23,6 +23,7 @@ export interface PaginationParams {
|
||||
}
|
||||
|
||||
export interface FilterParams {
|
||||
search?: string
|
||||
user_id?: string
|
||||
model?: string
|
||||
provider?: string
|
||||
@@ -64,9 +65,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
}))
|
||||
})
|
||||
|
||||
// 活跃度热图数据
|
||||
const activityHeatmapData = computed(() => stats.value.activity_heatmap)
|
||||
|
||||
// 加载统计数据(不加载记录)
|
||||
async function loadStats(dateRange?: DateRangeParams) {
|
||||
isLoadingStats.value = true
|
||||
@@ -93,7 +91,7 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
cache_stats: (statsData as any).cache_stats,
|
||||
period_start: '',
|
||||
period_end: '',
|
||||
activity_heatmap: statsData.activity_heatmap || null
|
||||
activity_heatmap: null
|
||||
}
|
||||
|
||||
modelStats.value = modelData.map(item => ({
|
||||
@@ -143,7 +141,7 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
avg_response_time: userData.avg_response_time || 0,
|
||||
period_start: '',
|
||||
period_end: '',
|
||||
activity_heatmap: userData.activity_heatmap || null
|
||||
activity_heatmap: null
|
||||
}
|
||||
|
||||
modelStats.value = (userData.summary_by_model || []).map((item: any) => ({
|
||||
@@ -237,11 +235,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
pagination: PaginationParams,
|
||||
filters?: FilterParams
|
||||
): Promise<void> {
|
||||
if (!isAdminPage.value) {
|
||||
// 用户页面不需要分页加载,记录已在 loadStats 中获取
|
||||
return
|
||||
}
|
||||
|
||||
isLoadingRecords.value = true
|
||||
|
||||
try {
|
||||
@@ -255,24 +248,34 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
}
|
||||
|
||||
// 添加筛选条件
|
||||
if (filters?.user_id) {
|
||||
params.user_id = filters.user_id
|
||||
}
|
||||
if (filters?.model) {
|
||||
params.model = filters.model
|
||||
}
|
||||
if (filters?.provider) {
|
||||
params.provider = filters.provider
|
||||
}
|
||||
if (filters?.status) {
|
||||
params.status = filters.status
|
||||
if (filters?.search?.trim()) {
|
||||
params.search = filters.search.trim()
|
||||
}
|
||||
|
||||
const response = await usageApi.getAllUsageRecords(params)
|
||||
|
||||
currentRecords.value = (response.records || []) as UsageRecord[]
|
||||
totalRecords.value = response.total || 0
|
||||
if (isAdminPage.value) {
|
||||
// 管理员页面:使用管理员 API
|
||||
if (filters?.user_id) {
|
||||
params.user_id = filters.user_id
|
||||
}
|
||||
if (filters?.model) {
|
||||
params.model = filters.model
|
||||
}
|
||||
if (filters?.provider) {
|
||||
params.provider = filters.provider
|
||||
}
|
||||
if (filters?.status) {
|
||||
params.status = filters.status
|
||||
}
|
||||
|
||||
const response = await usageApi.getAllUsageRecords(params)
|
||||
currentRecords.value = (response.records || []) as UsageRecord[]
|
||||
totalRecords.value = response.total || 0
|
||||
} else {
|
||||
// 用户页面:使用用户 API
|
||||
const userData = await meApi.getUsage(params)
|
||||
currentRecords.value = (userData.records || []) as UsageRecord[]
|
||||
totalRecords.value = userData.pagination?.total || currentRecords.value.length
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('加载记录失败:', error)
|
||||
currentRecords.value = []
|
||||
@@ -305,7 +308,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
||||
|
||||
// 计算属性
|
||||
enhancedModelStats,
|
||||
activityHeatmapData,
|
||||
|
||||
// 方法
|
||||
loadStats,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import type { ActivityHeatmap } from '@/types/activity'
|
||||
|
||||
// 统计数据状态
|
||||
export interface UsageStatsState {
|
||||
total_requests: number
|
||||
@@ -17,7 +15,6 @@ export interface UsageStatsState {
|
||||
}
|
||||
period_start: string
|
||||
period_end: string
|
||||
activity_heatmap: ActivityHeatmap | null
|
||||
}
|
||||
|
||||
// 模型统计
|
||||
@@ -64,6 +61,11 @@ export interface UsageRecord {
|
||||
user_id?: string
|
||||
username?: string
|
||||
user_email?: string
|
||||
api_key?: {
|
||||
id: string | null
|
||||
name: string | null
|
||||
display: string | null
|
||||
} | null
|
||||
provider: string
|
||||
api_key_name?: string
|
||||
rate_multiplier?: number
|
||||
@@ -115,7 +117,6 @@ export function createDefaultStats(): UsageStatsState {
|
||||
error_rate: undefined,
|
||||
cache_stats: undefined,
|
||||
period_start: '',
|
||||
period_end: '',
|
||||
activity_heatmap: null
|
||||
period_end: ''
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,55 +316,10 @@
|
||||
</div>
|
||||
|
||||
<!-- 模型多选下拉框 -->
|
||||
<div class="space-y-2">
|
||||
<Label class="text-sm font-medium">允许的模型</Label>
|
||||
<div class="relative">
|
||||
<button
|
||||
type="button"
|
||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
||||
@click="modelDropdownOpen = !modelDropdownOpen"
|
||||
>
|
||||
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length} 个` : '全部可用' }}
|
||||
</span>
|
||||
<ChevronDown
|
||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
||||
:class="modelDropdownOpen ? 'rotate-180' : ''"
|
||||
/>
|
||||
</button>
|
||||
<div
|
||||
v-if="modelDropdownOpen"
|
||||
class="fixed inset-0 z-[80]"
|
||||
@click.stop="modelDropdownOpen = false"
|
||||
/>
|
||||
<div
|
||||
v-if="modelDropdownOpen"
|
||||
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
||||
>
|
||||
<div
|
||||
v-for="model in globalModels"
|
||||
:key="model.name"
|
||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
||||
@click="toggleSelection('allowed_models', model.name)"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.allowed_models.includes(model.name)"
|
||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||
@click.stop
|
||||
@change="toggleSelection('allowed_models', model.name)"
|
||||
>
|
||||
<span class="text-sm">{{ model.name }}</span>
|
||||
</div>
|
||||
<div
|
||||
v-if="globalModels.length === 0"
|
||||
class="px-3 py-2 text-sm text-muted-foreground"
|
||||
>
|
||||
暂无可用模型
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<ModelMultiSelect
|
||||
v-model="form.allowed_models"
|
||||
:models="globalModels"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
@@ -404,10 +359,12 @@ import {
|
||||
} from '@/components/ui'
|
||||
import { UserPlus, SquarePen, ChevronDown } from 'lucide-vue-next'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import { ModelMultiSelect } from '@/components/common'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { getGlobalModels } from '@/api/global-models'
|
||||
import { adminApi } from '@/api/admin'
|
||||
import { log } from '@/utils/logger'
|
||||
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
|
||||
|
||||
export interface UserFormData {
|
||||
id?: string
|
||||
@@ -440,11 +397,10 @@ const roleSelectOpen = ref(false)
|
||||
// 下拉框状态
|
||||
const providerDropdownOpen = ref(false)
|
||||
const endpointDropdownOpen = ref(false)
|
||||
const modelDropdownOpen = ref(false)
|
||||
|
||||
// 选项数据
|
||||
const providers = ref<any[]>([])
|
||||
const globalModels = ref<any[]>([])
|
||||
const providers = ref<ProviderWithEndpointsSummary[]>([])
|
||||
const globalModels = ref<GlobalModelResponse[]>([])
|
||||
const apiFormats = ref<Array<{ value: string; label: string }>>([])
|
||||
|
||||
// 表单数据
|
||||
|
||||
@@ -280,6 +280,16 @@
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- GitHub Link -->
|
||||
<a
|
||||
href="https://github.com/fawney19/Aether"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
title="GitHub 仓库"
|
||||
>
|
||||
<GithubIcon class="h-4 w-4" />
|
||||
</a>
|
||||
</div>
|
||||
</header>
|
||||
</template>
|
||||
@@ -322,6 +332,7 @@ import {
|
||||
X,
|
||||
Mail,
|
||||
} from 'lucide-vue-next'
|
||||
import GithubIcon from '@/components/icons/GithubIcon.vue'
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
@@ -423,6 +434,7 @@ const navigation = computed(() => {
|
||||
{ name: 'IP 安全', href: '/admin/ip-security', icon: Shield },
|
||||
{ name: '审计日志', href: '/admin/audit-logs', icon: AlertTriangle },
|
||||
{ name: '邮件配置', href: '/admin/email', icon: Mail },
|
||||
{ name: 'LDAP 配置', href: '/admin/ldap', icon: Shield },
|
||||
{ name: '系统设置', href: '/admin/system', icon: Cog },
|
||||
]
|
||||
}
|
||||
|
||||
@@ -367,6 +367,11 @@ function generateMockUsageRecords(count: number = 100) {
|
||||
user_id: user.id,
|
||||
username: user.username,
|
||||
user_email: user.email,
|
||||
api_key: {
|
||||
id: `key-${user.id}-${Math.ceil(Math.random() * 2)}`,
|
||||
name: `${user.username} Key ${Math.ceil(Math.random() * 3)}`,
|
||||
display: `sk-ae...${String(1000 + Math.floor(Math.random() * 9000))}`
|
||||
},
|
||||
provider: model.provider,
|
||||
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
|
||||
rate_multiplier: 1.0,
|
||||
@@ -835,10 +840,26 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
||||
'GET /api/admin/usage/records': async (config) => {
|
||||
await delay()
|
||||
requireAdmin()
|
||||
const records = getUsageRecords()
|
||||
let records = getUsageRecords()
|
||||
const params = config.params || {}
|
||||
const limit = parseInt(params.limit) || 20
|
||||
const offset = parseInt(params.offset) || 0
|
||||
|
||||
// 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
// 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
if (typeof params.search === 'string' && params.search.trim()) {
|
||||
const keywords = params.search.trim().toLowerCase().split(/\s+/)
|
||||
records = records.filter(r => {
|
||||
// 每个关键词都要匹配至少一个字段
|
||||
return keywords.every((keyword: string) =>
|
||||
(r.username || '').toLowerCase().includes(keyword) ||
|
||||
(r.api_key?.name || '').toLowerCase().includes(keyword) ||
|
||||
(r.model || '').toLowerCase().includes(keyword) ||
|
||||
(r.provider || '').toLowerCase().includes(keyword)
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
return createMockResponse({
|
||||
records: records.slice(offset, offset + limit),
|
||||
total: records.length,
|
||||
|
||||
@@ -111,6 +111,11 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'EmailSettings',
|
||||
component: () => importWithRetry(() => import('@/views/admin/EmailSettings.vue'))
|
||||
},
|
||||
{
|
||||
path: 'ldap',
|
||||
name: 'LdapSettings',
|
||||
component: () => importWithRetry(() => import('@/views/admin/LdapSettings.vue'))
|
||||
},
|
||||
{
|
||||
path: 'audit-logs',
|
||||
name: 'AuditLogs',
|
||||
|
||||
@@ -31,12 +31,12 @@ export const useAuthStore = defineStore('auth', () => {
|
||||
}
|
||||
const isAdmin = computed(() => user.value?.role === 'admin')
|
||||
|
||||
async function login(email: string, password: string) {
|
||||
async function login(email: string, password: string, authType: 'local' | 'ldap' = 'local') {
|
||||
loading.value = true
|
||||
error.value = null
|
||||
|
||||
try {
|
||||
const response = await authApi.login({ email, password })
|
||||
const response = await authApi.login({ email, password, auth_type: authType })
|
||||
token.value = response.access_token
|
||||
|
||||
// 获取用户信息
|
||||
|
||||
@@ -850,28 +850,20 @@ async function deleteApiKey(apiKey: AdminApiKey) {
|
||||
}
|
||||
|
||||
function editApiKey(apiKey: AdminApiKey) {
|
||||
// 计算过期天数
|
||||
let expireDays: number | undefined = undefined
|
||||
let neverExpire = true
|
||||
// 解析过期日期为 YYYY-MM-DD 格式
|
||||
// 保留原始日期,不做时间过滤(避免编辑当天过期的 Key 时意外清空)
|
||||
let expiresAt: string | undefined = undefined
|
||||
|
||||
if (apiKey.expires_at) {
|
||||
const expiresDate = new Date(apiKey.expires_at)
|
||||
const now = new Date()
|
||||
const diffMs = expiresDate.getTime() - now.getTime()
|
||||
const diffDays = Math.ceil(diffMs / (1000 * 60 * 60 * 24))
|
||||
|
||||
if (diffDays > 0) {
|
||||
expireDays = diffDays
|
||||
neverExpire = false
|
||||
}
|
||||
expiresAt = expiresDate.toISOString().split('T')[0]
|
||||
}
|
||||
|
||||
editingKeyData.value = {
|
||||
id: apiKey.id,
|
||||
name: apiKey.name || '',
|
||||
expire_days: expireDays,
|
||||
never_expire: neverExpire,
|
||||
rate_limit: apiKey.rate_limit || 100,
|
||||
expires_at: expiresAt,
|
||||
rate_limit: apiKey.rate_limit ?? undefined,
|
||||
auto_delete_on_expiry: apiKey.auto_delete_on_expiry || false,
|
||||
allowed_providers: apiKey.allowed_providers || [],
|
||||
allowed_api_formats: apiKey.allowed_api_formats || [],
|
||||
@@ -1033,14 +1025,25 @@ function closeKeyFormDialog() {
|
||||
|
||||
// 统一处理表单提交
|
||||
async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
||||
// 验证过期日期(如果设置了,必须晚于今天)
|
||||
if (data.expires_at) {
|
||||
const selectedDate = new Date(data.expires_at)
|
||||
const today = new Date()
|
||||
today.setHours(0, 0, 0, 0)
|
||||
if (selectedDate <= today) {
|
||||
error('过期日期必须晚于今天')
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
keyFormDialogRef.value?.setSaving(true)
|
||||
try {
|
||||
if (data.id) {
|
||||
// 更新
|
||||
const updateData: Partial<CreateStandaloneApiKeyRequest> = {
|
||||
name: data.name || undefined,
|
||||
rate_limit: data.rate_limit,
|
||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||
rate_limit: data.rate_limit ?? null, // undefined = 无限制,显式传 null
|
||||
expires_at: data.expires_at || null, // undefined/空 = 永不过期
|
||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||
// 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
|
||||
allowed_providers: data.allowed_providers,
|
||||
@@ -1058,8 +1061,8 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
||||
const createData: CreateStandaloneApiKeyRequest = {
|
||||
name: data.name || undefined,
|
||||
initial_balance_usd: data.initial_balance_usd,
|
||||
rate_limit: data.rate_limit,
|
||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||
rate_limit: data.rate_limit ?? null, // undefined = 无限制,显式传 null
|
||||
expires_at: data.expires_at || null, // undefined/空 = 永不过期
|
||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||
// 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
|
||||
allowed_providers: data.allowed_providers,
|
||||
|
||||
@@ -106,23 +106,23 @@
|
||||
type="text"
|
||||
:placeholder="smtpPasswordIsSet ? '已设置(留空保持不变)' : '请输入密码'"
|
||||
class="-webkit-text-security-disc"
|
||||
:class="smtpPasswordIsSet ? 'pr-8' : ''"
|
||||
:class="(smtpPasswordIsSet || emailConfig.smtp_password) ? 'pr-10' : ''"
|
||||
autocomplete="one-time-code"
|
||||
data-lpignore="true"
|
||||
data-1p-ignore="true"
|
||||
data-form-type="other"
|
||||
/>
|
||||
<button
|
||||
v-if="smtpPasswordIsSet"
|
||||
v-if="smtpPasswordIsSet || emailConfig.smtp_password"
|
||||
type="button"
|
||||
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground transition-colors"
|
||||
title="清除已保存的密码"
|
||||
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
|
||||
title="清除密码"
|
||||
@click="handleClearSmtpPassword"
|
||||
>
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="16"
|
||||
height="16"
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
@@ -498,6 +498,7 @@ const smtpEncryptionSelectOpen = ref(false)
|
||||
const emailSuffixModeSelectOpen = ref(false)
|
||||
const testSmtpLoading = ref(false)
|
||||
const smtpPasswordIsSet = ref(false)
|
||||
const clearSmtpPassword = ref(false) // 标记是否要清除密码
|
||||
|
||||
// 邮件模板相关状态
|
||||
const templateLoading = ref(false)
|
||||
@@ -710,6 +711,7 @@ async function loadEmailConfig() {
|
||||
// 配置不存在时使用默认值,无需处理
|
||||
}
|
||||
}
|
||||
clearSmtpPassword.value = false
|
||||
} catch (err) {
|
||||
error('加载邮件配置失败')
|
||||
log.error('加载邮件配置失败:', err)
|
||||
@@ -720,6 +722,12 @@ async function loadEmailConfig() {
|
||||
async function saveSmtpConfig() {
|
||||
smtpSaveLoading.value = true
|
||||
try {
|
||||
const passwordAction: 'unchanged' | 'updated' | 'cleared' = emailConfig.value.smtp_password
|
||||
? 'updated'
|
||||
: clearSmtpPassword.value
|
||||
? 'cleared'
|
||||
: 'unchanged'
|
||||
|
||||
const configItems = [
|
||||
{
|
||||
key: 'smtp_host',
|
||||
@@ -737,7 +745,7 @@ async function saveSmtpConfig() {
|
||||
description: 'SMTP 用户名'
|
||||
},
|
||||
// 只有输入了新密码才提交(空值表示保持原密码)
|
||||
...(emailConfig.value.smtp_password
|
||||
...(passwordAction === 'updated'
|
||||
? [{
|
||||
key: 'smtp_password',
|
||||
value: emailConfig.value.smtp_password,
|
||||
@@ -770,8 +778,23 @@ async function saveSmtpConfig() {
|
||||
adminApi.updateSystemConfig(item.key, item.value, item.description)
|
||||
)
|
||||
|
||||
// 如果标记了清除密码,删除密码配置
|
||||
if (passwordAction === 'cleared') {
|
||||
promises.push(adminApi.deleteSystemConfig('smtp_password'))
|
||||
}
|
||||
|
||||
await Promise.all(promises)
|
||||
success('SMTP 配置已保存')
|
||||
|
||||
// 更新状态
|
||||
if (passwordAction === 'cleared') {
|
||||
clearSmtpPassword.value = false
|
||||
smtpPasswordIsSet.value = false
|
||||
} else if (passwordAction === 'updated') {
|
||||
clearSmtpPassword.value = false
|
||||
smtpPasswordIsSet.value = true
|
||||
}
|
||||
emailConfig.value.smtp_password = null
|
||||
} catch (err) {
|
||||
error('保存配置失败')
|
||||
log.error('保存 SMTP 配置失败:', err)
|
||||
@@ -812,15 +835,16 @@ async function saveEmailSuffixConfig() {
|
||||
}
|
||||
|
||||
// 清除 SMTP 密码
|
||||
async function handleClearSmtpPassword() {
|
||||
try {
|
||||
await adminApi.deleteSystemConfig('smtp_password')
|
||||
smtpPasswordIsSet.value = false
|
||||
function handleClearSmtpPassword() {
|
||||
// 如果有输入内容,先清空输入框
|
||||
if (emailConfig.value.smtp_password) {
|
||||
emailConfig.value.smtp_password = null
|
||||
success('SMTP 密码已清除')
|
||||
} catch (err) {
|
||||
error('清除密码失败')
|
||||
log.error('清除 SMTP 密码失败:', err)
|
||||
return
|
||||
}
|
||||
// 标记要清除服务端密码(保存时生效)
|
||||
if (smtpPasswordIsSet.value) {
|
||||
clearSmtpPassword.value = true
|
||||
smtpPasswordIsSet.value = false
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
379
frontend/src/views/admin/LdapSettings.vue
Normal file
379
frontend/src/views/admin/LdapSettings.vue
Normal file
@@ -0,0 +1,379 @@
|
||||
<template>
|
||||
<PageContainer>
|
||||
<PageHeader
|
||||
title="LDAP 配置"
|
||||
description="配置 LDAP 认证服务"
|
||||
/>
|
||||
|
||||
<div class="mt-6 space-y-6">
|
||||
<CardSection
|
||||
title="LDAP 服务器配置"
|
||||
description="配置 LDAP 服务器连接参数"
|
||||
>
|
||||
<template #actions>
|
||||
<div class="flex gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
:disabled="testLoading"
|
||||
@click="handleTestConnection"
|
||||
>
|
||||
{{ testLoading ? '测试中...' : '测试连接' }}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
:disabled="saveLoading"
|
||||
@click="handleSave"
|
||||
>
|
||||
{{ saveLoading ? '保存中...' : '保存' }}
|
||||
</Button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
<div>
|
||||
<Label for="server-url" class="block text-sm font-medium">
|
||||
服务器地址
|
||||
</Label>
|
||||
<Input
|
||||
id="server-url"
|
||||
v-model="ldapConfig.server_url"
|
||||
type="text"
|
||||
placeholder="ldap://ldap.example.com:389"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
格式: ldap://host:389 或 ldaps://host:636
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="bind-dn" class="block text-sm font-medium">
|
||||
绑定 DN
|
||||
</Label>
|
||||
<Input
|
||||
id="bind-dn"
|
||||
v-model="ldapConfig.bind_dn"
|
||||
type="text"
|
||||
placeholder="cn=admin,dc=example,dc=com"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
用于连接 LDAP 服务器的管理员 DN
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="bind-password" class="block text-sm font-medium">
|
||||
绑定密码
|
||||
</Label>
|
||||
<div class="relative mt-1">
|
||||
<Input
|
||||
id="bind-password"
|
||||
v-model="ldapConfig.bind_password"
|
||||
type="password"
|
||||
:placeholder="hasPassword ? '已设置(留空保持不变)' : '请输入密码'"
|
||||
:class="(hasPassword || ldapConfig.bind_password) ? 'pr-10' : ''"
|
||||
autocomplete="new-password"
|
||||
/>
|
||||
<button
|
||||
v-if="hasPassword || ldapConfig.bind_password"
|
||||
type="button"
|
||||
class="absolute right-3 top-1/2 -translate-y-1/2 p-1 rounded-full text-muted-foreground/60 hover:text-muted-foreground hover:bg-muted/50 transition-colors"
|
||||
@click="handleClearPassword"
|
||||
title="清除密码"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
绑定账号的密码
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="base-dn" class="block text-sm font-medium">
|
||||
基础 DN
|
||||
</Label>
|
||||
<Input
|
||||
id="base-dn"
|
||||
v-model="ldapConfig.base_dn"
|
||||
type="text"
|
||||
placeholder="ou=users,dc=example,dc=com"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
用户搜索的基础 DN
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="user-search-filter" class="block text-sm font-medium">
|
||||
用户搜索过滤器
|
||||
</Label>
|
||||
<Input
|
||||
id="user-search-filter"
|
||||
v-model="ldapConfig.user_search_filter"
|
||||
type="text"
|
||||
placeholder="(uid={username})"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{username} 会被替换为登录用户名
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="username-attr" class="block text-sm font-medium">
|
||||
用户名属性
|
||||
</Label>
|
||||
<Input
|
||||
id="username-attr"
|
||||
v-model="ldapConfig.username_attr"
|
||||
type="text"
|
||||
placeholder="uid"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
常用: uid (OpenLDAP), sAMAccountName (AD)
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="email-attr" class="block text-sm font-medium">
|
||||
邮箱属性
|
||||
</Label>
|
||||
<Input
|
||||
id="email-attr"
|
||||
v-model="ldapConfig.email_attr"
|
||||
type="text"
|
||||
placeholder="mail"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="display-name-attr" class="block text-sm font-medium">
|
||||
显示名称属性
|
||||
</Label>
|
||||
<Input
|
||||
id="display-name-attr"
|
||||
v-model="ldapConfig.display_name_attr"
|
||||
type="text"
|
||||
placeholder="cn"
|
||||
class="mt-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label for="connect-timeout" class="block text-sm font-medium">
|
||||
连接超时 (秒)
|
||||
</Label>
|
||||
<Input
|
||||
id="connect-timeout"
|
||||
v-model.number="ldapConfig.connect_timeout"
|
||||
type="number"
|
||||
min="1"
|
||||
max="60"
|
||||
placeholder="10"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
单次 LDAP 操作超时时间 (1-60秒),跨国网络建议 15-30 秒
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-6 space-y-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">使用 STARTTLS</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
在非 SSL 连接上启用 TLS 加密
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.use_starttls" />
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">启用 LDAP 认证</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
允许用户使用 LDAP 账号登录
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.is_enabled" />
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between">
|
||||
<div>
|
||||
<Label class="text-sm font-medium">仅允许 LDAP 登录</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
禁用本地账号登录,仅允许 LDAP 认证
|
||||
</p>
|
||||
</div>
|
||||
<Switch v-model="ldapConfig.is_exclusive" />
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
</PageContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { PageContainer, PageHeader, CardSection } from '@/components/layout'
|
||||
import { Button, Input, Label, Switch } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { adminApi, type LdapConfigUpdateRequest } from '@/api/admin'
|
||||
|
||||
const { success, error } = useToast()
|
||||
|
||||
const loading = ref(false)
|
||||
const saveLoading = ref(false)
|
||||
const testLoading = ref(false)
|
||||
const hasPassword = ref(false)
|
||||
const clearPassword = ref(false) // 标记是否要清除密码
|
||||
|
||||
const ldapConfig = ref({
|
||||
server_url: '',
|
||||
bind_dn: '',
|
||||
bind_password: '',
|
||||
base_dn: '',
|
||||
user_search_filter: '(uid={username})',
|
||||
username_attr: 'uid',
|
||||
email_attr: 'mail',
|
||||
display_name_attr: 'cn',
|
||||
is_enabled: false,
|
||||
is_exclusive: false,
|
||||
use_starttls: false,
|
||||
connect_timeout: 10,
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await loadConfig()
|
||||
})
|
||||
|
||||
async function loadConfig() {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await adminApi.getLdapConfig()
|
||||
ldapConfig.value = {
|
||||
server_url: response.server_url || '',
|
||||
bind_dn: response.bind_dn || '',
|
||||
bind_password: '',
|
||||
base_dn: response.base_dn || '',
|
||||
user_search_filter: response.user_search_filter || '(uid={username})',
|
||||
username_attr: response.username_attr || 'uid',
|
||||
email_attr: response.email_attr || 'mail',
|
||||
display_name_attr: response.display_name_attr || 'cn',
|
||||
is_enabled: response.is_enabled || false,
|
||||
is_exclusive: response.is_exclusive || false,
|
||||
use_starttls: response.use_starttls || false,
|
||||
connect_timeout: response.connect_timeout || 10,
|
||||
}
|
||||
hasPassword.value = !!response.has_bind_password
|
||||
clearPassword.value = false
|
||||
} catch (err) {
|
||||
error('加载 LDAP 配置失败')
|
||||
console.error('加载 LDAP 配置失败:', err)
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSave() {
|
||||
saveLoading.value = true
|
||||
try {
|
||||
const payload: LdapConfigUpdateRequest = {
|
||||
server_url: ldapConfig.value.server_url,
|
||||
bind_dn: ldapConfig.value.bind_dn,
|
||||
base_dn: ldapConfig.value.base_dn,
|
||||
user_search_filter: ldapConfig.value.user_search_filter,
|
||||
username_attr: ldapConfig.value.username_attr,
|
||||
email_attr: ldapConfig.value.email_attr,
|
||||
display_name_attr: ldapConfig.value.display_name_attr,
|
||||
is_enabled: ldapConfig.value.is_enabled,
|
||||
is_exclusive: ldapConfig.value.is_exclusive,
|
||||
use_starttls: ldapConfig.value.use_starttls,
|
||||
connect_timeout: ldapConfig.value.connect_timeout,
|
||||
}
|
||||
|
||||
// 优先使用输入的新密码;否则如果标记清除则发送空字符串
|
||||
let passwordAction: 'unchanged' | 'updated' | 'cleared' = 'unchanged'
|
||||
if (ldapConfig.value.bind_password) {
|
||||
payload.bind_password = ldapConfig.value.bind_password
|
||||
passwordAction = 'updated'
|
||||
} else if (clearPassword.value) {
|
||||
payload.bind_password = ''
|
||||
passwordAction = 'cleared'
|
||||
}
|
||||
|
||||
await adminApi.updateLdapConfig(payload)
|
||||
success('LDAP 配置保存成功')
|
||||
|
||||
if (passwordAction === 'cleared') {
|
||||
hasPassword.value = false
|
||||
clearPassword.value = false
|
||||
} else if (passwordAction === 'updated') {
|
||||
hasPassword.value = true
|
||||
clearPassword.value = false
|
||||
}
|
||||
ldapConfig.value.bind_password = ''
|
||||
} catch (err) {
|
||||
error('保存 LDAP 配置失败')
|
||||
console.error('保存 LDAP 配置失败:', err)
|
||||
} finally {
|
||||
saveLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function handleTestConnection() {
|
||||
if (clearPassword.value && !ldapConfig.value.bind_password) {
|
||||
error('已标记清除绑定密码,请先保存或输入新的绑定密码再测试')
|
||||
return
|
||||
}
|
||||
|
||||
testLoading.value = true
|
||||
try {
|
||||
const payload: LdapConfigUpdateRequest = {
|
||||
server_url: ldapConfig.value.server_url,
|
||||
bind_dn: ldapConfig.value.bind_dn,
|
||||
base_dn: ldapConfig.value.base_dn,
|
||||
user_search_filter: ldapConfig.value.user_search_filter,
|
||||
username_attr: ldapConfig.value.username_attr,
|
||||
email_attr: ldapConfig.value.email_attr,
|
||||
display_name_attr: ldapConfig.value.display_name_attr,
|
||||
is_enabled: ldapConfig.value.is_enabled,
|
||||
is_exclusive: ldapConfig.value.is_exclusive,
|
||||
use_starttls: ldapConfig.value.use_starttls,
|
||||
connect_timeout: ldapConfig.value.connect_timeout,
|
||||
...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }),
|
||||
}
|
||||
const response = await adminApi.testLdapConnection(payload)
|
||||
if (response.success) {
|
||||
success('LDAP 连接测试成功')
|
||||
} else {
|
||||
error(`LDAP 连接测试失败: ${response.message}`)
|
||||
}
|
||||
} catch (err) {
|
||||
error('LDAP 连接测试失败')
|
||||
console.error('LDAP 连接测试失败:', err)
|
||||
} finally {
|
||||
testLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handleClearPassword() {
|
||||
// 如果有输入内容,先清空输入框
|
||||
if (ldapConfig.value.bind_password) {
|
||||
ldapConfig.value.bind_password = ''
|
||||
return
|
||||
}
|
||||
// 标记要清除服务端密码(保存时生效)
|
||||
if (hasPassword.value) {
|
||||
clearPassword.value = true
|
||||
hasPassword.value = false
|
||||
}
|
||||
}
|
||||
</script>
|
||||
@@ -464,6 +464,30 @@
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 系统版本信息 -->
|
||||
<CardSection
|
||||
title="系统信息"
|
||||
description="当前系统版本和构建信息"
|
||||
>
|
||||
<div class="flex items-center gap-4">
|
||||
<div class="flex items-center gap-2">
|
||||
<Label class="text-sm font-medium text-muted-foreground">版本:</Label>
|
||||
<span
|
||||
v-if="systemVersion"
|
||||
class="text-sm font-mono"
|
||||
>
|
||||
{{ systemVersion }}
|
||||
</span>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>
|
||||
加载中...
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
|
||||
<!-- 导入配置对话框 -->
|
||||
@@ -475,7 +499,7 @@
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
class="text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
配置预览
|
||||
@@ -557,7 +581,7 @@
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
全局模型
|
||||
</p>
|
||||
@@ -567,7 +591,7 @@
|
||||
跳过: {{ importResult.stats.global_models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
提供商
|
||||
</p>
|
||||
@@ -577,7 +601,7 @@
|
||||
跳过: {{ importResult.stats.providers.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
端点
|
||||
</p>
|
||||
@@ -587,7 +611,7 @@
|
||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
@@ -596,7 +620,7 @@
|
||||
跳过: {{ importResult.stats.keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||
<div class="col-span-2">
|
||||
<p class="font-medium">
|
||||
模型配置
|
||||
</p>
|
||||
@@ -642,7 +666,7 @@
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importUsersPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
class="text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
数据预览
|
||||
@@ -652,6 +676,9 @@
|
||||
<li>
|
||||
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||
</li>
|
||||
<li v-if="importUsersPreview.standalone_keys?.length">
|
||||
独立余额 Keys: {{ importUsersPreview.standalone_keys.length }} 个
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
@@ -720,7 +747,7 @@
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
用户
|
||||
</p>
|
||||
@@ -730,7 +757,7 @@
|
||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<div>
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
@@ -739,6 +766,18 @@
|
||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div
|
||||
v-if="importUsersResult.stats.standalone_keys"
|
||||
class="col-span-2"
|
||||
>
|
||||
<p class="font-medium">
|
||||
独立余额 Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.standalone_keys.created }},
|
||||
跳过: {{ importUsersResult.stats.standalone_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
@@ -839,6 +878,9 @@ const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const usersMergeModeSelectOpen = ref(false)
|
||||
|
||||
// 系统版本信息
|
||||
const systemVersion = ref<string>('')
|
||||
|
||||
const systemConfig = ref<SystemConfig>({
|
||||
// 基础配置
|
||||
default_user_quota_usd: 10.0,
|
||||
@@ -890,9 +932,21 @@ const sensitiveHeadersStr = computed({
|
||||
})
|
||||
|
||||
onMounted(async () => {
|
||||
await loadSystemConfig()
|
||||
await Promise.all([
|
||||
loadSystemConfig(),
|
||||
loadSystemVersion()
|
||||
])
|
||||
})
|
||||
|
||||
async function loadSystemVersion() {
|
||||
try {
|
||||
const data = await adminApi.getSystemVersion()
|
||||
systemVersion.value = data.version
|
||||
} catch (err) {
|
||||
log.error('加载系统版本失败:', err)
|
||||
}
|
||||
}
|
||||
|
||||
async function loadSystemConfig() {
|
||||
try {
|
||||
const configs = [
|
||||
@@ -1178,12 +1232,6 @@ function handleUsersFileSelect(event: Event) {
|
||||
const content = e.target?.result as string
|
||||
const data = JSON.parse(content) as UsersExportData
|
||||
|
||||
// 验证版本
|
||||
if (data.version !== '1.0') {
|
||||
error(`不支持的配置版本: ${data.version}`)
|
||||
return
|
||||
}
|
||||
|
||||
importUsersPreview.value = data
|
||||
usersMergeMode.value = 'skip'
|
||||
importUsersDialogOpen.value = true
|
||||
|
||||
@@ -20,10 +20,11 @@
|
||||
</nav>
|
||||
|
||||
<!-- Header -->
|
||||
<header class="fixed top-0 left-0 right-0 z-50 border-b border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-xl transition-all">
|
||||
<div class="mx-auto max-w-7xl px-6 py-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<!-- Logo & Brand -->
|
||||
<header class="sticky top-0 z-50 border-b border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-xl transition-all">
|
||||
<div class="h-16 flex items-center">
|
||||
<!-- Centered content container (max-w-7xl) -->
|
||||
<div class="mx-auto max-w-7xl w-full px-6 flex items-center justify-between">
|
||||
<!-- Left: Logo & Brand -->
|
||||
<div
|
||||
class="flex items-center gap-3 group/logo cursor-pointer"
|
||||
@click="scrollToSection(0)"
|
||||
@@ -40,7 +41,7 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Center Navigation -->
|
||||
<!-- Center: Navigation -->
|
||||
<nav class="hidden md:flex items-center gap-2">
|
||||
<button
|
||||
v-for="(section, index) in sections"
|
||||
@@ -59,42 +60,54 @@
|
||||
</button>
|
||||
</nav>
|
||||
|
||||
<!-- Right Actions -->
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
:title="themeMode === 'system' ? '跟随系统' : themeMode === 'dark' ? '深色模式' : '浅色模式'"
|
||||
@click="toggleDarkMode"
|
||||
>
|
||||
<SunMoon
|
||||
v-if="themeMode === 'system'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Sun
|
||||
v-else-if="themeMode === 'light'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Moon
|
||||
v-else
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- Right: Login/Dashboard Button -->
|
||||
<RouterLink
|
||||
v-if="authStore.isAuthenticated"
|
||||
:to="dashboardPath"
|
||||
class="min-w-[72px] text-center rounded-xl bg-[#191919] dark:bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-sm transition hover:bg-[#262625] dark:hover:bg-[#b86d52] whitespace-nowrap"
|
||||
>
|
||||
控制台
|
||||
</RouterLink>
|
||||
<button
|
||||
v-else
|
||||
class="min-w-[72px] text-center rounded-xl bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-lg shadow-[#cc785c]/30 transition hover:bg-[#d4a27f] whitespace-nowrap"
|
||||
@click="showLoginDialog = true"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<RouterLink
|
||||
v-if="authStore.isAuthenticated"
|
||||
:to="dashboardPath"
|
||||
class="rounded-xl bg-[#191919] dark:bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-sm transition hover:bg-[#262625] dark:hover:bg-[#b86d52]"
|
||||
>
|
||||
控制台
|
||||
</RouterLink>
|
||||
<button
|
||||
<!-- Fixed right icons (px-8 to match dashboard) -->
|
||||
<div class="absolute right-8 top-1/2 -translate-y-1/2 flex items-center gap-2">
|
||||
<!-- Theme Toggle -->
|
||||
<button
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
:title="themeMode === 'system' ? '跟随系统' : themeMode === 'dark' ? '深色模式' : '浅色模式'"
|
||||
@click="toggleDarkMode"
|
||||
>
|
||||
<SunMoon
|
||||
v-if="themeMode === 'system'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Sun
|
||||
v-else-if="themeMode === 'light'"
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
<Moon
|
||||
v-else
|
||||
class="rounded-xl bg-[#cc785c] px-4 py-2 text-sm font-medium text-white shadow-lg shadow-[#cc785c]/30 transition hover:bg-[#d4a27f]"
|
||||
@click="showLoginDialog = true"
|
||||
>
|
||||
登录
|
||||
</button>
|
||||
</div>
|
||||
class="h-4 w-4"
|
||||
/>
|
||||
</button>
|
||||
<!-- GitHub Link -->
|
||||
<a
|
||||
href="https://github.com/fawney19/Aether"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="flex h-9 w-9 items-center justify-center rounded-lg text-muted-foreground hover:text-foreground hover:bg-muted/50 transition"
|
||||
title="GitHub 仓库"
|
||||
>
|
||||
<GithubIcon class="h-4 w-4" />
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
@@ -336,31 +349,6 @@
|
||||
</section>
|
||||
</main>
|
||||
|
||||
<!-- Footer -->
|
||||
<footer class="relative z-10 border-t border-[#cc785c]/10 dark:border-[rgba(227,224,211,0.12)] bg-[#fafaf7]/90 dark:bg-[#191714]/95 backdrop-blur-md py-8">
|
||||
<div class="mx-auto max-w-7xl px-6">
|
||||
<div class="flex flex-col items-center justify-between gap-4 sm:flex-row">
|
||||
<p class="text-sm text-[#91918d] dark:text-muted-foreground">
|
||||
© 2025 Aether. 团队内部使用
|
||||
</p>
|
||||
<div class="flex items-center gap-6 text-sm text-[#91918d] dark:text-muted-foreground">
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>使用条款</a>
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>隐私政策</a>
|
||||
<a
|
||||
href="#"
|
||||
class="transition hover:text-[#191919] dark:hover:text-white"
|
||||
>技术支持</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
|
||||
<LoginDialog v-model="showLoginDialog" />
|
||||
</div>
|
||||
</template>
|
||||
@@ -378,6 +366,7 @@ import {
|
||||
SunMoon,
|
||||
Terminal
|
||||
} from 'lucide-vue-next'
|
||||
import GithubIcon from '@/components/icons/GithubIcon.vue'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { useDarkMode } from '@/composables/useDarkMode'
|
||||
import { useClipboard } from '@/composables/useClipboard'
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
<ActivityHeatmapCard
|
||||
:data="activityHeatmapData"
|
||||
:title="isAdminPage ? '总体活跃天数' : '我的活跃天数'"
|
||||
:is-loading="isLoadingHeatmap"
|
||||
:has-error="heatmapError"
|
||||
/>
|
||||
<IntervalTimelineCard
|
||||
:title="isAdminPage ? '请求间隔时间线' : '我的请求间隔'"
|
||||
@@ -54,6 +56,7 @@
|
||||
:show-actual-cost="authStore.isAdmin"
|
||||
:loading="isLoadingRecords"
|
||||
:selected-period="selectedPeriod"
|
||||
:filter-search="filterSearch"
|
||||
:filter-user="filterUser"
|
||||
:filter-model="filterModel"
|
||||
:filter-provider="filterProvider"
|
||||
@@ -67,6 +70,7 @@
|
||||
:page-size-options="pageSizeOptions"
|
||||
:auto-refresh="globalAutoRefresh"
|
||||
@update:selected-period="handlePeriodChange"
|
||||
@update:filter-search="handleFilterSearchChange"
|
||||
@update:filter-user="handleFilterUserChange"
|
||||
@update:filter-model="handleFilterModelChange"
|
||||
@update:filter-provider="handleFilterProviderChange"
|
||||
@@ -112,8 +116,11 @@ import {
|
||||
import type { PeriodValue, FilterStatusValue } from '@/features/usage/types'
|
||||
import type { UserOption } from '@/features/usage/components/UsageRecordsTable.vue'
|
||||
import { log } from '@/utils/logger'
|
||||
import type { ActivityHeatmap } from '@/types/activity'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
|
||||
const route = useRoute()
|
||||
const { warning } = useToast()
|
||||
const authStore = useAuthStore()
|
||||
|
||||
// 判断是否是管理员页面
|
||||
@@ -128,6 +135,7 @@ const pageSize = ref(20)
|
||||
const pageSizeOptions = [10, 20, 50, 100]
|
||||
|
||||
// 筛选状态
|
||||
const filterSearch = ref('')
|
||||
const filterUser = ref('__all__')
|
||||
const filterModel = ref('__all__')
|
||||
const filterProvider = ref('__all__')
|
||||
@@ -144,13 +152,35 @@ const {
|
||||
currentRecords,
|
||||
totalRecords,
|
||||
enhancedModelStats,
|
||||
activityHeatmapData,
|
||||
availableModels,
|
||||
availableProviders,
|
||||
loadStats,
|
||||
loadRecords
|
||||
} = useUsageData({ isAdminPage })
|
||||
|
||||
// 热力图状态
|
||||
const activityHeatmapData = ref<ActivityHeatmap | null>(null)
|
||||
const isLoadingHeatmap = ref(false)
|
||||
const heatmapError = ref(false)
|
||||
|
||||
// 加载热力图数据
|
||||
async function loadHeatmapData() {
|
||||
isLoadingHeatmap.value = true
|
||||
heatmapError.value = false
|
||||
try {
|
||||
if (isAdminPage.value) {
|
||||
activityHeatmapData.value = await usageApi.getActivityHeatmap()
|
||||
} else {
|
||||
activityHeatmapData.value = await meApi.getActivityHeatmap()
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('加载热力图数据失败:', error)
|
||||
heatmapError.value = true
|
||||
} finally {
|
||||
isLoadingHeatmap.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 用户页面需要前端筛选
|
||||
const filteredRecords = computed(() => {
|
||||
if (!isAdminPage.value) {
|
||||
@@ -232,27 +262,40 @@ async function pollActiveRequests() {
|
||||
? await usageApi.getActiveRequests(activeRequestIds.value)
|
||||
: await meApi.getActiveRequests(idsParam)
|
||||
|
||||
// 检查是否有状态变化
|
||||
let hasChanges = false
|
||||
let shouldRefresh = false
|
||||
|
||||
for (const update of requests) {
|
||||
const record = currentRecords.value.find(r => r.id === update.id)
|
||||
if (record && record.status !== update.status) {
|
||||
hasChanges = true
|
||||
// 如果状态变为 completed 或 failed,需要刷新获取完整数据
|
||||
if (update.status === 'completed' || update.status === 'failed') {
|
||||
break
|
||||
}
|
||||
// 否则只更新状态和 token 信息
|
||||
if (!record) {
|
||||
// 后端返回了未知的活跃请求,触发刷新以获取完整数据
|
||||
shouldRefresh = true
|
||||
continue
|
||||
}
|
||||
|
||||
// 状态变化:completed/failed 需要刷新获取完整数据
|
||||
if (record.status !== update.status) {
|
||||
record.status = update.status
|
||||
record.input_tokens = update.input_tokens
|
||||
record.output_tokens = update.output_tokens
|
||||
record.cost = update.cost
|
||||
record.response_time_ms = update.response_time_ms ?? undefined
|
||||
}
|
||||
if (update.status === 'completed' || update.status === 'failed') {
|
||||
shouldRefresh = true
|
||||
}
|
||||
|
||||
// 进行中状态也需要持续更新(provider/key/TTFB 可能在 streaming 后才落库)
|
||||
record.input_tokens = update.input_tokens
|
||||
record.output_tokens = update.output_tokens
|
||||
record.cost = update.cost
|
||||
record.response_time_ms = update.response_time_ms ?? undefined
|
||||
record.first_byte_time_ms = update.first_byte_time_ms ?? undefined
|
||||
// 管理员接口返回额外字段
|
||||
if ('provider' in update && typeof update.provider === 'string') {
|
||||
record.provider = update.provider
|
||||
}
|
||||
if ('api_key_name' in update) {
|
||||
record.api_key_name = typeof update.api_key_name === 'string' ? update.api_key_name : undefined
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有请求完成或失败,刷新整个列表获取完整数据
|
||||
if (hasChanges && requests.some(r => r.status === 'completed' || r.status === 'failed')) {
|
||||
if (shouldRefresh) {
|
||||
await refreshData()
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -335,16 +378,34 @@ const selectedRequestId = ref<string | null>(null)
|
||||
// 初始化加载
|
||||
onMounted(async () => {
|
||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||
await loadStats(dateRange)
|
||||
|
||||
// 管理员页面加载用户列表和第一页记录
|
||||
// 并行加载统计数据和热力图(使用 allSettled 避免其中一个失败影响另一个)
|
||||
const [statsResult, heatmapResult] = await Promise.allSettled([
|
||||
loadStats(dateRange),
|
||||
loadHeatmapData()
|
||||
])
|
||||
|
||||
// 检查加载结果并通知用户
|
||||
if (statsResult.status === 'rejected') {
|
||||
log.error('加载统计数据失败:', statsResult.reason)
|
||||
warning('统计数据加载失败,请刷新重试')
|
||||
}
|
||||
if (heatmapResult.status === 'rejected') {
|
||||
log.error('加载热力图数据失败:', heatmapResult.reason)
|
||||
// 热力图加载失败不提示,因为 UI 已显示占位符
|
||||
}
|
||||
|
||||
// 加载记录和用户列表
|
||||
if (isAdminPage.value) {
|
||||
// 并行加载用户列表和记录
|
||||
// 管理员页面:并行加载用户列表和记录
|
||||
const [users] = await Promise.all([
|
||||
usersApi.getAllUsers(),
|
||||
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
])
|
||||
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
|
||||
} else {
|
||||
// 用户页面:加载记录
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -355,34 +416,26 @@ async function handlePeriodChange(value: string) {
|
||||
|
||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||
await loadStats(dateRange)
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 处理分页变化
|
||||
async function handlePageChange(page: number) {
|
||||
currentPage.value = page
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 处理每页大小变化
|
||||
async function handlePageSizeChange(size: number) {
|
||||
pageSize.value = size
|
||||
currentPage.value = 1 // 重置到第一页
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 获取当前筛选参数
|
||||
function getCurrentFilters() {
|
||||
return {
|
||||
search: filterSearch.value.trim() || undefined,
|
||||
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
|
||||
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
|
||||
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
|
||||
@@ -391,6 +444,13 @@ function getCurrentFilters() {
|
||||
}
|
||||
|
||||
// 处理筛选变化
|
||||
async function handleFilterSearchChange(value: string) {
|
||||
filterSearch.value = value
|
||||
currentPage.value = 1
|
||||
|
||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
async function handleFilterUserChange(value: string) {
|
||||
filterUser.value = value
|
||||
currentPage.value = 1 // 重置到第一页
|
||||
@@ -431,10 +491,7 @@ async function handleFilterStatusChange(value: string) {
|
||||
async function refreshData() {
|
||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||
await loadStats(dateRange)
|
||||
|
||||
if (isAdminPage.value) {
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||
}
|
||||
|
||||
// 显示请求详情
|
||||
|
||||
@@ -47,6 +47,7 @@ dependencies = [
|
||||
"redis>=5.0.0",
|
||||
"prometheus-client>=0.20.0",
|
||||
"apscheduler>=3.10.0",
|
||||
"ldap3>=2.9.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.1.1.dev0+g393d4d13f.d20251213'
|
||||
__version_tuple__ = version_tuple = (0, 1, 1, 'dev0', 'g393d4d13f.d20251213')
|
||||
__version__ = version = '0.2.3.dev0+g0f78d5cbf.d20260105'
|
||||
__version_tuple__ = version_tuple = (0, 2, 3, 'dev0', 'g0f78d5cbf.d20260105')
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import APIRouter
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .ldap import router as ldap_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
@@ -28,5 +29,6 @@ router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
router.include_router(ldap_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -3,22 +3,64 @@
|
||||
独立余额Key:不关联用户配额,有独立余额限制,用于给非注册用户使用。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import CreateApiKeyRequest
|
||||
from src.models.database import ApiKey, User
|
||||
from src.models.database import ApiKey
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
|
||||
# 应用时区配置,默认为 Asia/Shanghai
|
||||
APP_TIMEZONE = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
|
||||
|
||||
|
||||
def parse_expiry_date(date_str: Optional[str]) -> Optional[datetime]:
|
||||
"""解析过期日期字符串为 datetime 对象。
|
||||
|
||||
Args:
|
||||
date_str: 日期字符串,支持 "YYYY-MM-DD" 或 ISO 格式
|
||||
|
||||
Returns:
|
||||
datetime 对象(当天 23:59:59.999999,应用时区),或 None 如果输入为空
|
||||
|
||||
Raises:
|
||||
BadRequestException: 日期格式无效
|
||||
"""
|
||||
if not date_str or not date_str.strip():
|
||||
return None
|
||||
|
||||
date_str = date_str.strip()
|
||||
|
||||
# 尝试 YYYY-MM-DD 格式
|
||||
try:
|
||||
parsed_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
# 设置为当天结束时间 (23:59:59.999999,应用时区)
|
||||
return parsed_date.replace(
|
||||
hour=23, minute=59, second=59, microsecond=999999, tzinfo=APP_TIMEZONE
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 尝试完整 ISO 格式
|
||||
try:
|
||||
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
raise InvalidRequestException(f"无效的日期格式: {date_str},请使用 YYYY-MM-DD 格式")
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
@@ -215,6 +257,9 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
||||
# 独立Key需要关联到管理员用户(从context获取)
|
||||
admin_user_id = context.user.id
|
||||
|
||||
# 解析过期时间(优先使用 expires_at,其次使用 expire_days)
|
||||
expires_at_dt = parse_expiry_date(self.key_data.expires_at)
|
||||
|
||||
# 创建独立Key
|
||||
api_key, plain_key = ApiKeyService.create_api_key(
|
||||
db=db,
|
||||
@@ -224,7 +269,8 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||
allowed_models=self.key_data.allowed_models,
|
||||
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||
expire_days=self.key_data.expire_days,
|
||||
expire_days=self.key_data.expire_days, # 兼容旧版
|
||||
expires_at=expires_at_dt, # 优先使用
|
||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||
is_standalone=True, # 标记为独立Key
|
||||
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
|
||||
@@ -270,7 +316,8 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
|
||||
update_data = {}
|
||||
if self.key_data.name is not None:
|
||||
update_data["name"] = self.key_data.name
|
||||
if self.key_data.rate_limit is not None:
|
||||
# rate_limit: 显式传递时更新(包括 null 表示无限制)
|
||||
if "rate_limit" in self.key_data.model_fields_set:
|
||||
update_data["rate_limit"] = self.key_data.rate_limit
|
||||
if (
|
||||
hasattr(self.key_data, "auto_delete_on_expiry")
|
||||
@@ -287,19 +334,21 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
|
||||
update_data["allowed_models"] = self.key_data.allowed_models
|
||||
|
||||
# 处理过期时间
|
||||
if self.key_data.expire_days is not None:
|
||||
if self.key_data.expire_days > 0:
|
||||
from datetime import timedelta
|
||||
|
||||
# 优先使用 expires_at(如果显式传递且有值)
|
||||
if self.key_data.expires_at and self.key_data.expires_at.strip():
|
||||
update_data["expires_at"] = parse_expiry_date(self.key_data.expires_at)
|
||||
elif "expires_at" in self.key_data.model_fields_set:
|
||||
# expires_at 明确传递为 null 或空字符串,设为永不过期
|
||||
update_data["expires_at"] = None
|
||||
# 兼容旧版 expire_days
|
||||
elif "expire_days" in self.key_data.model_fields_set:
|
||||
if self.key_data.expire_days is not None and self.key_data.expire_days > 0:
|
||||
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
|
||||
days=self.key_data.expire_days
|
||||
)
|
||||
else:
|
||||
# expire_days = 0 或负数表示永不过期
|
||||
# expire_days = None/0/负数 表示永不过期
|
||||
update_data["expires_at"] = None
|
||||
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
|
||||
# 明确传递 None,设为永不过期
|
||||
update_data["expires_at"] = None
|
||||
|
||||
# 使用 ApiKeyService 更新
|
||||
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
|
||||
|
||||
@@ -206,6 +206,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
provider_id=self.provider_id,
|
||||
api_format=self.endpoint_data.api_format,
|
||||
base_url=self.endpoint_data.base_url,
|
||||
custom_path=self.endpoint_data.custom_path,
|
||||
headers=self.endpoint_data.headers,
|
||||
timeout=self.endpoint_data.timeout,
|
||||
max_retries=self.endpoint_data.max_retries,
|
||||
|
||||
427
src/api/admin/ldap.py
Normal file
427
src/api/admin/ldap.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""LDAP配置管理API端点。"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.enums import AuthSource
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, LDAPConfig, User, UserRole
|
||||
from src.services.system.audit import AuditService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
# bcrypt 哈希格式正则:$2a$, $2b$, $2y$ + 2位cost + $ + 53字符(22位salt + 31位hash)
|
||||
BCRYPT_HASH_PATTERN = re.compile(r"^\$2[aby]\$\d{2}\$.{53}$")
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class LDAPConfigResponse(BaseModel):
|
||||
"""LDAP配置响应(不返回密码)"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
bind_dn: Optional[str] = None
|
||||
base_dn: Optional[str] = None
|
||||
has_bind_password: bool = False
|
||||
user_search_filter: str
|
||||
username_attr: str
|
||||
email_attr: str
|
||||
display_name_attr: str
|
||||
is_enabled: bool
|
||||
is_exclusive: bool
|
||||
use_starttls: bool
|
||||
connect_timeout: int
|
||||
|
||||
|
||||
class LDAPConfigUpdate(BaseModel):
|
||||
"""LDAP配置更新请求"""
|
||||
|
||||
server_url: str = Field(..., min_length=1, max_length=255)
|
||||
bind_dn: str = Field(..., min_length=1, max_length=255)
|
||||
# 允许空字符串表示"清除密码";非空时自动 strip 并校验不能为空
|
||||
bind_password: Optional[str] = Field(None, max_length=1024)
|
||||
base_dn: str = Field(..., min_length=1, max_length=255)
|
||||
user_search_filter: str = Field(default="(uid={username})", max_length=500)
|
||||
username_attr: str = Field(default="uid", max_length=50)
|
||||
email_attr: str = Field(default="mail", max_length=50)
|
||||
display_name_attr: str = Field(default="cn", max_length=50)
|
||||
is_enabled: bool = False
|
||||
is_exclusive: bool = False
|
||||
use_starttls: bool = False
|
||||
connect_timeout: int = Field(default=10, ge=1, le=60) # 单次操作超时,跨国网络建议 15-30 秒
|
||||
|
||||
@field_validator("bind_password")
|
||||
@classmethod
|
||||
def validate_bind_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None or v == "":
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("绑定密码不能为空")
|
||||
return v
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: str) -> str:
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度,防止构造复杂查询
|
||||
# 检查嵌套层数而非括号总数
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
class LDAPTestResponse(BaseModel):
|
||||
"""LDAP连接测试响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class LDAPConfigTest(BaseModel):
|
||||
"""LDAP配置测试请求(全部可选,用于临时覆盖)"""
|
||||
|
||||
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_password: Optional[str] = Field(None, min_length=1)
|
||||
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
user_search_filter: Optional[str] = Field(None, max_length=500)
|
||||
username_attr: Optional[str] = Field(None, max_length=50)
|
||||
email_attr: Optional[str] = Field(None, max_length=50)
|
||||
display_name_attr: Optional[str] = Field(None, max_length=50)
|
||||
is_enabled: Optional[bool] = None
|
||||
is_exclusive: Optional[bool] = None
|
||||
use_starttls: Optional[bool] = None
|
||||
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度(检查嵌套层数而非括号总数)
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
# ========== API Endpoints ==========
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""获取LDAP配置(管理员)"""
|
||||
adapter = AdminGetLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""更新LDAP配置(管理员)"""
|
||||
adapter = AdminUpdateLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""测试LDAP连接(管理员)"""
|
||||
adapter = AdminTestLDAPConnectionAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
return LDAPConfigResponse(
|
||||
server_url=None,
|
||||
bind_dn=None,
|
||||
base_dn=None,
|
||||
has_bind_password=False,
|
||||
user_search_filter="(uid={username})",
|
||||
username_attr="uid",
|
||||
email_attr="mail",
|
||||
display_name_attr="cn",
|
||||
is_enabled=False,
|
||||
is_exclusive=False,
|
||||
use_starttls=False,
|
||||
connect_timeout=10,
|
||||
).model_dump()
|
||||
|
||||
return LDAPConfigResponse(
|
||||
server_url=config.server_url,
|
||||
bind_dn=config.bind_dn,
|
||||
base_dn=config.base_dn,
|
||||
has_bind_password=bool(config.bind_password_encrypted),
|
||||
user_search_filter=config.user_search_filter,
|
||||
username_attr=config.username_attr,
|
||||
email_attr=config.email_attr,
|
||||
display_name_attr=config.display_name_attr,
|
||||
is_enabled=config.is_enabled,
|
||||
is_exclusive=config.is_exclusive,
|
||||
use_starttls=config.use_starttls,
|
||||
connect_timeout=config.connect_timeout,
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, str]: # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
config_update = LDAPConfigUpdate.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
# 使用行级锁防止并发修改导致的竞态条件
|
||||
config = db.query(LDAPConfig).with_for_update().first()
|
||||
is_new_config = config is None
|
||||
|
||||
if is_new_config:
|
||||
# 首次创建配置时必须提供密码
|
||||
if not config_update.bind_password:
|
||||
raise InvalidRequestException("首次配置 LDAP 时必须设置绑定密码")
|
||||
config = LDAPConfig()
|
||||
db.add(config)
|
||||
|
||||
# 需要启用 LDAP 且未提交新密码时,验证已保存密码可解密(避免开启后不可用)
|
||||
if config_update.is_enabled and config_update.bind_password is None:
|
||||
try:
|
||||
if not config.get_bind_password():
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
except InvalidRequestException:
|
||||
raise
|
||||
except Exception:
|
||||
raise InvalidRequestException("绑定密码解密失败,请重新设置绑定密码")
|
||||
|
||||
# 计算更新后的密码状态(用于校验是否可启用/独占)
|
||||
if config_update.bind_password is None:
|
||||
will_have_password = bool(config.bind_password_encrypted)
|
||||
elif config_update.bind_password == "":
|
||||
will_have_password = False
|
||||
else:
|
||||
will_have_password = True
|
||||
|
||||
# 独占模式必须启用 LDAP 且必须有绑定密码(防止误锁定)
|
||||
if config_update.is_exclusive and not config_update.is_enabled:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先启用 LDAP 认证")
|
||||
if config_update.is_enabled and not will_have_password:
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
if config_update.is_exclusive and not will_have_password:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先设置绑定密码")
|
||||
|
||||
config.server_url = config_update.server_url
|
||||
config.bind_dn = config_update.bind_dn
|
||||
config.base_dn = config_update.base_dn
|
||||
config.user_search_filter = config_update.user_search_filter
|
||||
config.username_attr = config_update.username_attr
|
||||
config.email_attr = config_update.email_attr
|
||||
config.display_name_attr = config_update.display_name_attr
|
||||
config.is_enabled = config_update.is_enabled
|
||||
config.is_exclusive = config_update.is_exclusive
|
||||
config.use_starttls = config_update.use_starttls
|
||||
config.connect_timeout = config_update.connect_timeout
|
||||
|
||||
# 启用独占模式前检查是否有足够的本地管理员(防止锁定)
|
||||
# 使用 with_for_update() 阻塞锁防止竞态条件(移除 nowait 确保并发安全)
|
||||
if config_update.is_enabled and config_update.is_exclusive:
|
||||
local_admins = (
|
||||
db.query(User)
|
||||
.filter(
|
||||
User.role == UserRole.ADMIN,
|
||||
User.auth_source == AuthSource.LOCAL,
|
||||
User.is_active.is_(True),
|
||||
User.is_deleted.is_(False),
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
# 验证至少有一个管理员有有效的密码哈希(可以登录)
|
||||
# 使用严格的 bcrypt 格式校验:$2a$/$2b$/$2y$ + 2位cost + $ + 53字符
|
||||
valid_admin_count = sum(
|
||||
1
|
||||
for admin in local_admins
|
||||
if admin.password_hash
|
||||
and isinstance(admin.password_hash, str)
|
||||
and BCRYPT_HASH_PATTERN.match(admin.password_hash)
|
||||
)
|
||||
if valid_admin_count < 1:
|
||||
raise InvalidRequestException(
|
||||
"启用 LDAP 独占模式前,必须至少保留 1 个有效的本地管理员账户(含有效密码)作为紧急恢复通道"
|
||||
)
|
||||
|
||||
if config_update.bind_password is not None:
|
||||
if config_update.bind_password == "":
|
||||
# 显式清除密码(设置为 NULL)
|
||||
config.bind_password_encrypted = None
|
||||
password_changed = "cleared"
|
||||
else:
|
||||
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
|
||||
password_changed = "updated"
|
||||
else:
|
||||
password_changed = None
|
||||
|
||||
db.commit()
|
||||
|
||||
# 记录审计日志
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.CONFIG_CHANGED,
|
||||
description=f"LDAP 配置已更新 (enabled={config_update.is_enabled}, exclusive={config_update.is_exclusive})",
|
||||
user_id=str(context.user.id) if context.user else None,
|
||||
metadata={
|
||||
"server_url": config_update.server_url,
|
||||
"is_enabled": config_update.is_enabled,
|
||||
"is_exclusive": config_update.is_exclusive,
|
||||
"password_changed": password_changed,
|
||||
"is_new_config": is_new_config,
|
||||
},
|
||||
)
|
||||
|
||||
return {"message": "LDAP配置更新成功"}
|
||||
|
||||
|
||||
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
from src.services.auth.ldap import LDAPService
|
||||
|
||||
db = context.db
|
||||
if context.json_body is not None:
|
||||
payload = context.json_body
|
||||
elif not context.raw_body:
|
||||
payload = {}
|
||||
else:
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
saved_config = db.query(LDAPConfig).first()
|
||||
|
||||
try:
|
||||
overrides = LDAPConfigTest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
config_data: Dict[str, Any] = {}
|
||||
|
||||
if saved_config:
|
||||
config_data = {
|
||||
"server_url": saved_config.server_url,
|
||||
"bind_dn": saved_config.bind_dn,
|
||||
"base_dn": saved_config.base_dn,
|
||||
"user_search_filter": saved_config.user_search_filter,
|
||||
"username_attr": saved_config.username_attr,
|
||||
"email_attr": saved_config.email_attr,
|
||||
"display_name_attr": saved_config.display_name_attr,
|
||||
"use_starttls": saved_config.use_starttls,
|
||||
"connect_timeout": saved_config.connect_timeout,
|
||||
}
|
||||
|
||||
# 应用前端传入的覆盖值
|
||||
for field in [
|
||||
"server_url",
|
||||
"bind_dn",
|
||||
"base_dn",
|
||||
"user_search_filter",
|
||||
"username_attr",
|
||||
"email_attr",
|
||||
"display_name_attr",
|
||||
"use_starttls",
|
||||
"is_enabled",
|
||||
"is_exclusive",
|
||||
"connect_timeout",
|
||||
]:
|
||||
value = getattr(overrides, field)
|
||||
if value is not None:
|
||||
config_data[field] = value
|
||||
|
||||
# bind_password 优先使用 overrides;否则使用已保存的密码(允许保存密码无法解密时依然用 overrides 测试)
|
||||
if overrides.bind_password is not None:
|
||||
config_data["bind_password"] = overrides.bind_password
|
||||
elif saved_config and saved_config.bind_password_encrypted:
|
||||
try:
|
||||
config_data["bind_password"] = crypto_service.decrypt(
|
||||
saved_config.bind_password_encrypted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"绑定密码解密失败: {type(e).__name__}: {e}")
|
||||
return LDAPTestResponse(
|
||||
success=False, message="绑定密码解密失败,请检查配置或重新设置密码"
|
||||
).model_dump()
|
||||
|
||||
# 必填字段检查
|
||||
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
|
||||
missing = [f for f in required_fields if not config_data.get(f)]
|
||||
if missing:
|
||||
return LDAPTestResponse(
|
||||
success=False, message=f"缺少必要字段: {', '.join(missing)}"
|
||||
).model_dump()
|
||||
|
||||
success, message = LDAPService.test_connection_with_config(config_data)
|
||||
return LDAPTestResponse(success=success, message=message).model_dump()
|
||||
@@ -146,20 +146,25 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
search=self.search,
|
||||
)
|
||||
|
||||
# 为每个 GlobalModel 添加统计数据
|
||||
# 一次性查询所有 GlobalModel 的 provider_count(优化 N+1 问题)
|
||||
model_ids = [gm.id for gm in models]
|
||||
provider_counts = {}
|
||||
if model_ids:
|
||||
count_results = (
|
||||
context.db.query(
|
||||
Model.global_model_id, func.count(func.distinct(Model.provider_id))
|
||||
)
|
||||
.filter(Model.global_model_id.in_(model_ids))
|
||||
.group_by(Model.global_model_id)
|
||||
.all()
|
||||
)
|
||||
provider_counts = {gm_id: count for gm_id, count in count_results}
|
||||
|
||||
# 构建响应
|
||||
model_responses = []
|
||||
for gm in models:
|
||||
# 统计关联的 Model 数量(去重 Provider)
|
||||
provider_count = (
|
||||
context.db.query(func.count(func.distinct(Model.provider_id)))
|
||||
.filter(Model.global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
response = GlobalModelResponse.model_validate(gm)
|
||||
response.provider_count = provider_count
|
||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
||||
response.provider_count = provider_counts.get(gm.id, 0)
|
||||
model_responses.append(response)
|
||||
|
||||
return GlobalModelListResponse(
|
||||
|
||||
@@ -107,6 +107,9 @@ class AdminGetAuditLogsAdapter(AdminApiAdapter):
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
# 查看审计日志本身不应该产生审计记录,避免刷新页面时产生大量无意义的日志
|
||||
audit_log_enabled: bool = False
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
提供商策略管理 API 端点
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
@@ -103,6 +103,9 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
||||
|
||||
if config.quota_last_reset_at:
|
||||
new_reset_at = parser.parse(config.quota_last_reset_at)
|
||||
# 确保有时区信息,如果没有则假设为 UTC
|
||||
if new_reset_at.tzinfo is None:
|
||||
new_reset_at = new_reset_at.replace(tzinfo=timezone.utc)
|
||||
provider.quota_last_reset_at = new_reset_at
|
||||
|
||||
# 自动同步该周期内的历史使用量
|
||||
@@ -118,7 +121,11 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
||||
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
|
||||
|
||||
if config.quota_expires_at:
|
||||
provider.quota_expires_at = parser.parse(config.quota_expires_at)
|
||||
expires_at = parser.parse(config.quota_expires_at)
|
||||
# 确保有时区信息,如果没有则假设为 UTC
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
provider.quota_expires_at = expires_at
|
||||
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
@@ -149,7 +156,7 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
since = datetime.now() - timedelta(hours=self.hours)
|
||||
since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
|
||||
stats = (
|
||||
db.query(ProviderUsageTracking)
|
||||
.filter(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""系统设置API端点。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -17,6 +19,46 @@ from src.services.email.email_template import EmailTemplate
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||||
|
||||
|
||||
def _get_version_from_git() -> str | None:
|
||||
"""从 git describe 获取版本号"""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "describe", "--tags", "--always"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
if version.startswith("v"):
|
||||
version = version[1:]
|
||||
return version
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def get_system_version():
|
||||
"""获取系统版本信息"""
|
||||
# 优先从 git 获取
|
||||
version = _get_version_from_git()
|
||||
if version:
|
||||
return {"version": version}
|
||||
|
||||
# 回退到静态版本文件
|
||||
try:
|
||||
from src._version import __version__
|
||||
|
||||
return {"version": __version__}
|
||||
except ImportError:
|
||||
return {"version": "unknown"}
|
||||
|
||||
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@@ -950,6 +992,31 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
|
||||
db = context.db
|
||||
|
||||
def _serialize_api_key(key: ApiKey, include_is_standalone: bool = False) -> dict:
|
||||
"""序列化 API Key 为导出格式"""
|
||||
data = {
|
||||
"key_hash": key.key_hash,
|
||||
"key_encrypted": key.key_encrypted,
|
||||
"name": key.name,
|
||||
"balance_used_usd": key.balance_used_usd,
|
||||
"current_balance_usd": key.current_balance_usd,
|
||||
"allowed_providers": key.allowed_providers,
|
||||
"allowed_endpoints": key.allowed_endpoints,
|
||||
"allowed_api_formats": key.allowed_api_formats,
|
||||
"allowed_models": key.allowed_models,
|
||||
"rate_limit": key.rate_limit,
|
||||
"concurrent_limit": key.concurrent_limit,
|
||||
"force_capabilities": key.force_capabilities,
|
||||
"is_active": key.is_active,
|
||||
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
|
||||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||
"total_requests": key.total_requests,
|
||||
"total_cost_usd": key.total_cost_usd,
|
||||
}
|
||||
if include_is_standalone:
|
||||
data["is_standalone"] = key.is_standalone
|
||||
return data
|
||||
|
||||
# 导出 Users(排除管理员)
|
||||
users = db.query(User).filter(
|
||||
User.is_deleted.is_(False),
|
||||
@@ -957,31 +1024,12 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
).all()
|
||||
users_data = []
|
||||
for user in users:
|
||||
# 导出用户的 API Keys(保留加密数据)
|
||||
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
|
||||
api_keys_data = []
|
||||
for key in api_keys:
|
||||
api_keys_data.append(
|
||||
{
|
||||
"key_hash": key.key_hash,
|
||||
"key_encrypted": key.key_encrypted,
|
||||
"name": key.name,
|
||||
"is_standalone": key.is_standalone,
|
||||
"balance_used_usd": key.balance_used_usd,
|
||||
"current_balance_usd": key.current_balance_usd,
|
||||
"allowed_providers": key.allowed_providers,
|
||||
"allowed_endpoints": key.allowed_endpoints,
|
||||
"allowed_api_formats": key.allowed_api_formats,
|
||||
"allowed_models": key.allowed_models,
|
||||
"rate_limit": key.rate_limit,
|
||||
"concurrent_limit": key.concurrent_limit,
|
||||
"force_capabilities": key.force_capabilities,
|
||||
"is_active": key.is_active,
|
||||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||
"total_requests": key.total_requests,
|
||||
"total_cost_usd": key.total_cost_usd,
|
||||
}
|
||||
)
|
||||
# 导出用户的 API Keys(排除独立余额Key,独立Key单独导出)
|
||||
api_keys = db.query(ApiKey).filter(
|
||||
ApiKey.user_id == user.id,
|
||||
ApiKey.is_standalone.is_(False)
|
||||
).all()
|
||||
api_keys_data = [_serialize_api_key(key, include_is_standalone=True) for key in api_keys]
|
||||
|
||||
users_data.append(
|
||||
{
|
||||
@@ -1001,10 +1049,15 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
||||
}
|
||||
)
|
||||
|
||||
# 导出独立余额 Keys(管理员创建的,不属于普通用户)
|
||||
standalone_keys = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True)).all()
|
||||
standalone_keys_data = [_serialize_api_key(key) for key in standalone_keys]
|
||||
|
||||
return {
|
||||
"version": "1.0",
|
||||
"version": "1.1",
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"users": users_data,
|
||||
"standalone_keys": standalone_keys_data,
|
||||
}
|
||||
|
||||
|
||||
@@ -1024,21 +1077,72 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
# 验证配置版本
|
||||
version = payload.get("version")
|
||||
if version != "1.0":
|
||||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||
|
||||
# 获取导入选项
|
||||
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||||
users_data = payload.get("users", [])
|
||||
standalone_keys_data = payload.get("standalone_keys", [])
|
||||
|
||||
stats = {
|
||||
"users": {"created": 0, "updated": 0, "skipped": 0},
|
||||
"api_keys": {"created": 0, "skipped": 0},
|
||||
"standalone_keys": {"created": 0, "skipped": 0},
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
def _create_api_key_from_data(
|
||||
key_data: dict,
|
||||
owner_id: str,
|
||||
is_standalone: bool = False,
|
||||
) -> tuple[ApiKey | None, str]:
|
||||
"""从导入数据创建 ApiKey 对象
|
||||
|
||||
Returns:
|
||||
(ApiKey, "created"): 成功创建
|
||||
(None, "skipped"): key 已存在,跳过
|
||||
(None, "invalid"): 数据无效,跳过
|
||||
"""
|
||||
key_hash = key_data.get("key_hash", "").strip()
|
||||
if not key_hash:
|
||||
return None, "invalid"
|
||||
|
||||
# 检查是否已存在
|
||||
existing = db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
|
||||
if existing:
|
||||
return None, "skipped"
|
||||
|
||||
# 解析 expires_at
|
||||
expires_at = None
|
||||
if key_data.get("expires_at"):
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(key_data["expires_at"])
|
||||
except ValueError:
|
||||
stats["errors"].append(
|
||||
f"API Key '{key_data.get('name', key_hash[:8])}' 的 expires_at 格式无效"
|
||||
)
|
||||
|
||||
return ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=owner_id,
|
||||
key_hash=key_hash,
|
||||
key_encrypted=key_data.get("key_encrypted"),
|
||||
name=key_data.get("name"),
|
||||
is_standalone=is_standalone or key_data.get("is_standalone", False),
|
||||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||
current_balance_usd=key_data.get("current_balance_usd"),
|
||||
allowed_providers=key_data.get("allowed_providers"),
|
||||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
rate_limit=key_data.get("rate_limit"),
|
||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||
force_capabilities=key_data.get("force_capabilities"),
|
||||
is_active=key_data.get("is_active", True),
|
||||
expires_at=expires_at,
|
||||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||
total_requests=key_data.get("total_requests", 0),
|
||||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||
), "created"
|
||||
|
||||
try:
|
||||
for user_data in users_data:
|
||||
# 跳过管理员角色的导入(不区分大小写)
|
||||
@@ -1109,40 +1213,31 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
||||
|
||||
# 导入 API Keys
|
||||
for key_data in user_data.get("api_keys", []):
|
||||
# 检查是否已存在相同的 key_hash
|
||||
if key_data.get("key_hash"):
|
||||
existing_key = (
|
||||
db.query(ApiKey)
|
||||
.filter(ApiKey.key_hash == key_data["key_hash"])
|
||||
.first()
|
||||
)
|
||||
if existing_key:
|
||||
stats["api_keys"]["skipped"] += 1
|
||||
continue
|
||||
new_key, status = _create_api_key_from_data(key_data, user_id)
|
||||
if new_key:
|
||||
db.add(new_key)
|
||||
stats["api_keys"]["created"] += 1
|
||||
elif status == "skipped":
|
||||
stats["api_keys"]["skipped"] += 1
|
||||
# invalid 数据不计入统计
|
||||
|
||||
new_key = ApiKey(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
key_hash=key_data.get("key_hash", ""),
|
||||
key_encrypted=key_data.get("key_encrypted"),
|
||||
name=key_data.get("name"),
|
||||
is_standalone=key_data.get("is_standalone", False),
|
||||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||
current_balance_usd=key_data.get("current_balance_usd"),
|
||||
allowed_providers=key_data.get("allowed_providers"),
|
||||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||
allowed_models=key_data.get("allowed_models"),
|
||||
rate_limit=key_data.get("rate_limit", 100),
|
||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||
force_capabilities=key_data.get("force_capabilities"),
|
||||
is_active=key_data.get("is_active", True),
|
||||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||
total_requests=key_data.get("total_requests", 0),
|
||||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||
)
|
||||
db.add(new_key)
|
||||
stats["api_keys"]["created"] += 1
|
||||
# 导入独立余额 Keys(需要找一个管理员用户作为 owner)
|
||||
if standalone_keys_data:
|
||||
# 查找一个管理员用户作为独立Key的owner
|
||||
admin_user = db.query(User).filter(User.role == UserRole.ADMIN).first()
|
||||
if not admin_user:
|
||||
stats["errors"].append("无法导入独立余额Key: 系统中没有管理员用户")
|
||||
else:
|
||||
for key_data in standalone_keys_data:
|
||||
new_key, status = _create_api_key_from_data(
|
||||
key_data, admin_user.id, is_standalone=True
|
||||
)
|
||||
if new_key:
|
||||
db.add(new_key)
|
||||
stats["standalone_keys"]["created"] += 1
|
||||
elif status == "skipped":
|
||||
stats["standalone_keys"]["skipped"] += 1
|
||||
# invalid 数据不计入统计
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -73,11 +73,26 @@ async def get_usage_stats(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/heatmap")
|
||||
async def get_activity_heatmap(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get activity heatmap data for the past 365 days.
|
||||
|
||||
This endpoint is cached for 5 minutes to reduce database load.
|
||||
"""
|
||||
adapter = AdminActivityHeatmapAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/records")
|
||||
async def get_usage_records(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
search: Optional[str] = None, # 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
@@ -90,6 +105,7 @@ async def get_usage_records(
|
||||
adapter = AdminUsageRecordsAdapter(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
search=search,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
model=model,
|
||||
@@ -168,12 +184,6 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
).count()
|
||||
|
||||
activity_heatmap = UsageService.get_daily_activity(
|
||||
db=db,
|
||||
window_days=365,
|
||||
include_actual_cost=True,
|
||||
)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="usage_stats",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
@@ -204,10 +214,22 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
|
||||
),
|
||||
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
|
||||
},
|
||||
"activity_heatmap": activity_heatmap,
|
||||
}
|
||||
|
||||
|
||||
class AdminActivityHeatmapAdapter(AdminApiAdapter):
|
||||
"""Activity heatmap adapter with Redis caching."""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
result = await UsageService.get_cached_heatmap(
|
||||
db=context.db,
|
||||
user_id=None,
|
||||
include_actual_cost=True,
|
||||
)
|
||||
context.add_audit_metadata(action="activity_heatmap")
|
||||
return result
|
||||
|
||||
|
||||
class AdminUsageByModelAdapter(AdminApiAdapter):
|
||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||
self.start_date = start_date
|
||||
@@ -480,6 +502,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
self,
|
||||
start_date: Optional[datetime],
|
||||
end_date: Optional[datetime],
|
||||
search: Optional[str],
|
||||
user_id: Optional[str],
|
||||
username: Optional[str],
|
||||
model: Optional[str],
|
||||
@@ -490,6 +513,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.search = search
|
||||
self.user_id = user_id
|
||||
self.username = username
|
||||
self.model = model
|
||||
@@ -499,25 +523,54 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
self.offset = offset
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import or_
|
||||
|
||||
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||
|
||||
db = context.db
|
||||
query = (
|
||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
|
||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey, ApiKey)
|
||||
.outerjoin(User, Usage.user_id == User.id)
|
||||
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||
)
|
||||
|
||||
# 如果需要按 Provider 名称搜索/筛选,统一在这里 JOIN
|
||||
if self.search or self.provider:
|
||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||
|
||||
# 通用搜索:用户名、密钥名、模型名、提供商名
|
||||
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
# 限制:最多 10 个关键词,转义后每个关键词最长 100 字符
|
||||
if self.search:
|
||||
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||
for keyword in keywords:
|
||||
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||
search_pattern = f"%{escaped}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.username.ilike(search_pattern, escape="\\"),
|
||||
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||
Usage.model.ilike(search_pattern, escape="\\"),
|
||||
Provider.name.ilike(search_pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_id:
|
||||
query = query.filter(Usage.user_id == self.user_id)
|
||||
if self.username:
|
||||
# 支持用户名模糊搜索
|
||||
query = query.filter(User.username.ilike(f"%{self.username}%"))
|
||||
escaped = escape_like_pattern(self.username)
|
||||
query = query.filter(User.username.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.model:
|
||||
# 支持模型名模糊搜索
|
||||
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
|
||||
escaped = escape_like_pattern(self.model)
|
||||
query = query.filter(Usage.model.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.provider:
|
||||
# 支持提供商名称搜索(通过 Provider 表)
|
||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
|
||||
# 支持提供商名称搜索
|
||||
escaped = escape_like_pattern(self.provider)
|
||||
query = query.filter(Provider.name.ilike(f"%{escaped}%", escape="\\"))
|
||||
if self.status:
|
||||
# 状态筛选
|
||||
# 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error
|
||||
@@ -555,7 +608,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||
)
|
||||
|
||||
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
|
||||
request_ids = [usage.request_id for usage, _, _, _, _ in records if usage.request_id]
|
||||
fallback_map = {}
|
||||
if request_ids:
|
||||
# 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available
|
||||
@@ -575,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
action="usage_records",
|
||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||
search=self.search,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
model=self.model,
|
||||
@@ -586,7 +640,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
)
|
||||
|
||||
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
|
||||
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
|
||||
provider_ids = [usage.provider_id for usage, _, _, _, _ in records if usage.provider_id]
|
||||
provider_map = {}
|
||||
if provider_ids:
|
||||
providers_data = (
|
||||
@@ -595,7 +649,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
provider_map = {str(p.id): p.name for p in providers_data}
|
||||
|
||||
data = []
|
||||
for usage, user, endpoint, api_key in records:
|
||||
for usage, user, endpoint, provider_api_key, user_api_key in records:
|
||||
actual_cost = (
|
||||
float(usage.actual_total_cost_usd)
|
||||
if usage.actual_total_cost_usd is not None
|
||||
@@ -616,6 +670,15 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"user_id": user.id if user else None,
|
||||
"user_email": user.email if user else "已删除用户",
|
||||
"username": user.username if user else "已删除用户",
|
||||
"api_key": (
|
||||
{
|
||||
"id": user_api_key.id,
|
||||
"name": user_api_key.name,
|
||||
"display": user_api_key.get_display_key(),
|
||||
}
|
||||
if user_api_key
|
||||
else None
|
||||
),
|
||||
"provider": provider_name,
|
||||
"model": usage.model,
|
||||
"target_model": usage.target_model, # 映射后的目标模型名
|
||||
@@ -641,7 +704,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"has_fallback": fallback_map.get(usage.request_id, False),
|
||||
"api_format": usage.api_format
|
||||
or (endpoint.api_format if endpoint and endpoint.api_format else None),
|
||||
"api_key_name": api_key.name if api_key else None,
|
||||
"api_key_name": provider_api_key.name if provider_api_key else None,
|
||||
"request_metadata": usage.request_metadata, # Provider 响应元数据
|
||||
}
|
||||
)
|
||||
@@ -670,7 +733,9 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
|
||||
if not id_list:
|
||||
return {"requests": []}
|
||||
|
||||
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
|
||||
requests = UsageService.get_active_requests_status(
|
||||
db=db, ids=id_list, include_admin_fields=True
|
||||
)
|
||||
return {"requests": requests}
|
||||
|
||||
|
||||
|
||||
@@ -248,6 +248,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
old_role = existing_user.role
|
||||
if "role" in update_data and update_data["role"]:
|
||||
if hasattr(update_data["role"], "value"):
|
||||
update_data["role"] = update_data["role"]
|
||||
@@ -258,6 +259,12 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
|
||||
if not user:
|
||||
raise NotFoundException("用户不存在", "user")
|
||||
|
||||
# 角色变更时清除热力图缓存(影响 include_actual_cost 权限)
|
||||
if "role" in update_data and update_data["role"] != old_role:
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
await UsageService.clear_user_heatmap_cache(self.user_id)
|
||||
|
||||
changed_fields = list(update_data.keys())
|
||||
context.add_audit_metadata(
|
||||
action="update_user",
|
||||
@@ -424,7 +431,7 @@ class AdminCreateUserKeyAdapter(AdminApiAdapter):
|
||||
name=key_data.name,
|
||||
allowed_providers=key_data.allowed_providers,
|
||||
allowed_models=key_data.allowed_models,
|
||||
rate_limit=key_data.rate_limit or 100,
|
||||
rate_limit=key_data.rate_limit, # None = 无限制
|
||||
expire_days=key_data.expire_days,
|
||||
initial_balance_usd=None, # 普通Key不设置余额限制
|
||||
is_standalone=False, # 不是独立Key
|
||||
|
||||
@@ -33,6 +33,7 @@ from src.models.api import (
|
||||
)
|
||||
from src.models.database import AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.system.config import SystemConfigService
|
||||
@@ -99,6 +100,13 @@ async def registration_settings(request: Request, db: Session = Depends(get_db))
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def auth_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""公开获取认证设置(用于前端判断显示哪些登录选项)"""
|
||||
adapter = AuthSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthLoginAdapter()
|
||||
@@ -193,7 +201,9 @@ class AuthLoginAdapter(AuthPublicAdapter):
|
||||
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_request.email, login_request.password, login_request.auth_type
|
||||
)
|
||||
if not user:
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
@@ -305,6 +315,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter):
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AuthSettingsAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""公开返回认证设置"""
|
||||
db = context.db
|
||||
|
||||
ldap_enabled = LDAPService.is_ldap_enabled(db)
|
||||
ldap_exclusive = LDAPService.is_ldap_exclusive(db)
|
||||
|
||||
return {
|
||||
"local_enabled": not ldap_exclusive,
|
||||
"ldap_enabled": ldap_enabled,
|
||||
"ldap_exclusive": ldap_exclusive,
|
||||
}
|
||||
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.models.database import SystemConfig
|
||||
@@ -324,6 +349,12 @@ class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
# 仅允许 LDAP 登录时拒绝本地注册
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="系统已启用 LDAP 专属登录,禁止本地注册"
|
||||
)
|
||||
|
||||
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
|
||||
if allow_registration and not allow_registration.value:
|
||||
AuditService.log_event(
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, User
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
|
||||
|
||||
@@ -86,7 +87,7 @@ class ApiRequestContext:
|
||||
setattr(request.state, "request_id", request_id)
|
||||
|
||||
start_time = time.time()
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
client_ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
context = cls(
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Optional, Tuple
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
@@ -64,13 +65,17 @@ class ApiRequestPipeline:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# 添加30秒超时防止卡死
|
||||
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
|
||||
# 添加超时防止卡死
|
||||
raw_body = await asyncio.wait_for(
|
||||
http_request.body(), timeout=config.request_body_timeout
|
||||
)
|
||||
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
|
||||
timeout_sec = int(config.request_body_timeout)
|
||||
logger.error(f"读取请求体超时({timeout_sec}s),可能客户端未发送完整请求体")
|
||||
raise HTTPException(
|
||||
status_code=408, detail="Request timeout: body not received within 30 seconds"
|
||||
status_code=408,
|
||||
detail=f"Request timeout: body not received within {timeout_sec} seconds",
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
|
||||
|
||||
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
"""
|
||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||
@@ -406,7 +405,7 @@ class BaseMessageHandler:
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||
"""更新 Usage 状态为 streaming,同时更新 provider 相关信息
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
@@ -414,7 +413,7 @@ class BaseMessageHandler:
|
||||
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||
ctx: 流式上下文,包含 provider 相关信息
|
||||
"""
|
||||
import asyncio
|
||||
from src.database.database import get_db
|
||||
@@ -422,6 +421,17 @@ class BaseMessageHandler:
|
||||
target_request_id = self.request_id
|
||||
provider = ctx.provider_name
|
||||
target_model = ctx.mapped_model
|
||||
provider_id = ctx.provider_id
|
||||
endpoint_id = ctx.endpoint_id
|
||||
key_id = ctx.key_id
|
||||
first_byte_time_ms = ctx.first_byte_time_ms
|
||||
|
||||
# 如果 provider 为空,记录警告(不应该发生,但用于调试)
|
||||
if not provider:
|
||||
logger.warning(
|
||||
f"[{target_request_id}] 更新 streaming 状态时 provider 为空: "
|
||||
f"ctx.provider_name={ctx.provider_name}, ctx.provider_id={ctx.provider_id}"
|
||||
)
|
||||
|
||||
async def _do_update() -> None:
|
||||
try:
|
||||
@@ -434,6 +444,10 @@ class BaseMessageHandler:
|
||||
status="streaming",
|
||||
provider=provider,
|
||||
target_model=target_model,
|
||||
provider_id=provider_id,
|
||||
provider_endpoint_id=endpoint_id,
|
||||
provider_api_key_id=key_id,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -40,6 +40,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||
@classmethod
|
||||
def build_endpoint_url(cls, base_url: str) -> str:
|
||||
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -36,6 +36,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
|
||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
from src.config.settings import config
|
||||
from src.core.error_utils import extract_error_message
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
ProviderAuthException,
|
||||
@@ -500,6 +501,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
error_text = await self._extract_error_text(e)
|
||||
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
|
||||
await http_client.aclose()
|
||||
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
|
||||
e.upstream_response = error_text # type: ignore[attr-defined]
|
||||
raise
|
||||
|
||||
except EmbeddedErrorException:
|
||||
@@ -549,7 +552,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
model=ctx.model,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=str(error),
|
||||
error_message=extract_error_message(error),
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=True,
|
||||
@@ -785,7 +788,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
model=model,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=str(e),
|
||||
error_message=extract_error_message(e),
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=False,
|
||||
@@ -802,10 +805,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
try:
|
||||
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
|
||||
error_bytes = await e.response.aread()
|
||||
return error_bytes.decode("utf-8", errors="replace")[:500]
|
||||
return error_bytes.decode("utf-8", errors="replace")
|
||||
else:
|
||||
return (
|
||||
e.response.text[:500] if hasattr(e.response, "_content") else "Unable to read"
|
||||
e.response.text if hasattr(e.response, "_content") else "Unable to read"
|
||||
)
|
||||
except Exception as decode_error:
|
||||
return f"Unable to read error: {decode_error}"
|
||||
|
||||
@@ -38,6 +38,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -34,7 +34,12 @@ from src.api.handlers.base.base_handler import (
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
from src.api.handlers.base.utils import (
|
||||
build_sse_headers,
|
||||
check_html_response,
|
||||
check_prefetched_response_error,
|
||||
)
|
||||
from src.core.error_utils import extract_error_message
|
||||
|
||||
# 直接从具体模块导入,避免循环依赖
|
||||
from src.api.handlers.base.response_parser import (
|
||||
@@ -57,6 +62,7 @@ from src.models.database import (
|
||||
ProviderEndpoint,
|
||||
User,
|
||||
)
|
||||
from src.config.constants import StreamDefaults
|
||||
from src.config.settings import config
|
||||
from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
@@ -328,9 +334,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
stream_generator,
|
||||
provider_name,
|
||||
attempt_id,
|
||||
_provider_id,
|
||||
_endpoint_id,
|
||||
_key_id,
|
||||
provider_id,
|
||||
endpoint_id,
|
||||
key_id,
|
||||
) = await self.orchestrator.execute_with_fallback(
|
||||
api_format=ctx.api_format,
|
||||
model_name=ctx.model,
|
||||
@@ -340,7 +346,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
is_stream=True,
|
||||
capability_requirements=capability_requirements or None,
|
||||
)
|
||||
|
||||
# 更新上下文(确保 provider 信息已设置,用于 streaming 状态更新)
|
||||
ctx.attempt_id = attempt_id
|
||||
if not ctx.provider_name:
|
||||
ctx.provider_name = provider_name
|
||||
if not ctx.provider_id:
|
||||
ctx.provider_id = provider_id
|
||||
if not ctx.endpoint_id:
|
||||
ctx.endpoint_id = endpoint_id
|
||||
if not ctx.key_id:
|
||||
ctx.key_id = key_id
|
||||
|
||||
# 创建后台任务记录统计
|
||||
background_tasks = BackgroundTasks()
|
||||
@@ -488,6 +504,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
error_text = await self._extract_error_text(e)
|
||||
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
|
||||
await http_client.aclose()
|
||||
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
|
||||
e.upstream_response = error_text # type: ignore[attr-defined]
|
||||
raise
|
||||
|
||||
except EmbeddedErrorException:
|
||||
@@ -523,8 +541,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
streaming_status_updated = False
|
||||
buffer = b""
|
||||
output_state = {"first_yield": True, "streaming_updated": False}
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
@@ -532,11 +550,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
async for chunk in stream_response.aiter_bytes():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
streaming_status_updated = True
|
||||
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
@@ -561,6 +574,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
@@ -578,6 +592,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return # 结束生成器
|
||||
|
||||
@@ -585,8 +600,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
@@ -637,7 +654,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
except httpx.RemoteProtocolError as e:
|
||||
except httpx.RemoteProtocolError:
|
||||
if ctx.data_count > 0:
|
||||
error_event = {
|
||||
"type": "error",
|
||||
@@ -691,7 +708,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||
"""
|
||||
prefetched_chunks: list = []
|
||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||
max_prefetch_lines = config.stream_prefetch_lines # 最多预读行数来检测错误
|
||||
max_prefetch_bytes = StreamDefaults.MAX_PREFETCH_BYTES # 避免无换行响应导致 buffer 增长
|
||||
total_prefetched_bytes = 0
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
@@ -718,14 +737,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
provider_name=str(provider.name),
|
||||
)
|
||||
prefetched_chunks.append(first_chunk)
|
||||
total_prefetched_bytes += len(first_chunk)
|
||||
buffer += first_chunk
|
||||
|
||||
# 继续读取剩余的预读数据
|
||||
async for chunk in aiter:
|
||||
prefetched_chunks.append(chunk)
|
||||
total_prefetched_bytes += len(chunk)
|
||||
buffer += chunk
|
||||
|
||||
# 尝试按行解析缓冲区
|
||||
# 尝试按行解析缓冲区(SSE 格式)
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
@@ -742,15 +763,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
normalized_line = line.rstrip("\r")
|
||||
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
lower_line = normalized_line.lower()
|
||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||
if check_html_response(normalized_line):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
@@ -799,9 +820,30 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
|
||||
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
|
||||
logger.debug(
|
||||
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
|
||||
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
|
||||
f"max_bytes={max_prefetch_bytes}"
|
||||
)
|
||||
break
|
||||
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
|
||||
# 处理某些代理返回的纯 JSON 错误(可能无换行/多行 JSON)以及 HTML 页面(base_url 配置错误)
|
||||
if not should_stop and prefetched_chunks:
|
||||
check_prefetched_response_error(
|
||||
prefetched_chunks=prefetched_chunks,
|
||||
parser=provider_parser,
|
||||
request_id=self.request_id,
|
||||
provider_name=str(provider.name),
|
||||
endpoint_id=endpoint.id,
|
||||
base_url=endpoint.base_url,
|
||||
)
|
||||
|
||||
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||
raise
|
||||
@@ -833,17 +875,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
buffer = b""
|
||||
first_yield = True # 标记是否是第一次 yield
|
||||
output_state = {"first_yield": True, "streaming_updated": False}
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_chunks:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
|
||||
# 先处理预读的字节块
|
||||
for chunk in prefetched_chunks:
|
||||
buffer += chunk
|
||||
@@ -870,10 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
@@ -883,16 +918,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
@@ -931,10 +960,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
@@ -952,6 +978,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return
|
||||
|
||||
@@ -959,16 +986,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
self._mark_first_output(ctx, output_state)
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
@@ -1352,7 +1373,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
model=ctx.model,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=str(error),
|
||||
error_message=extract_error_message(error),
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=True,
|
||||
@@ -1476,8 +1497,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
elif 300 <= resp.status_code < 400:
|
||||
redirect_url = resp.headers.get("location", "unknown")
|
||||
@@ -1487,7 +1512,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
elif resp.status_code != 200:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
|
||||
# 安全解析 JSON 响应,处理可能的编码错误
|
||||
@@ -1620,7 +1648,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
model=model,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=str(e),
|
||||
error_message=extract_error_message(e),
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=False,
|
||||
@@ -1640,14 +1668,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
for encoding in ["utf-8", "gbk", "latin1"]:
|
||||
try:
|
||||
return error_bytes.decode(encoding)[:500]
|
||||
return error_bytes.decode(encoding)
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
return error_bytes.decode("utf-8", errors="replace")[:500]
|
||||
return error_bytes.decode("utf-8", errors="replace")
|
||||
else:
|
||||
return (
|
||||
e.response.text[:500]
|
||||
e.response.text
|
||||
if hasattr(e.response, "_content")
|
||||
else "Unable to read response"
|
||||
)
|
||||
@@ -1665,6 +1693,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
return False
|
||||
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
|
||||
|
||||
def _mark_first_output(self, ctx: StreamContext, state: Dict[str, bool]) -> None:
|
||||
"""
|
||||
标记首次输出:记录 TTFB 并更新 streaming 状态
|
||||
|
||||
在第一次 yield 数据前调用,确保:
|
||||
1. 首字时间 (TTFB) 已记录到 ctx
|
||||
2. Usage 状态已更新为 streaming(包含 provider/key/TTFB 信息)
|
||||
|
||||
Args:
|
||||
ctx: 流上下文
|
||||
state: 包含 first_yield 和 streaming_updated 的状态字典
|
||||
"""
|
||||
if state["first_yield"]:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
state["first_yield"] = False
|
||||
if not state["streaming_updated"]:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
state["streaming_updated"] = True
|
||||
|
||||
def _convert_sse_line(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
|
||||
@@ -98,6 +98,17 @@ class OpenAIResponseParser(ResponseParser):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
# 提取 usage 信息(某些 OpenAI 兼容 API 如豆包会在最后一个 chunk 中发送 usage)
|
||||
# 这个 chunk 通常 choices 为空数组,但包含完整的 usage 信息
|
||||
usage = parsed.get("usage")
|
||||
if usage and isinstance(usage, dict):
|
||||
chunk.input_tokens = usage.get("prompt_tokens", 0)
|
||||
chunk.output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# 更新 stats
|
||||
stats.input_tokens = chunk.input_tokens
|
||||
stats.output_tokens = chunk.output_tokens
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
|
||||
@@ -25,8 +25,17 @@ from src.api.handlers.base.content_extractors import (
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import (
|
||||
check_html_response,
|
||||
check_prefetched_response_error,
|
||||
)
|
||||
from src.config.constants import StreamDefaults
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderEndpoint
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
@@ -165,6 +174,7 @@ class StreamProcessor:
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: StreamContext,
|
||||
max_prefetch_lines: int = 5,
|
||||
max_prefetch_bytes: int = StreamDefaults.MAX_PREFETCH_BYTES,
|
||||
) -> list:
|
||||
"""
|
||||
预读流的前几行,检测嵌套错误
|
||||
@@ -180,12 +190,14 @@ class StreamProcessor:
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 流式上下文
|
||||
max_prefetch_lines: 最多预读行数
|
||||
max_prefetch_bytes: 最多预读字节数(避免无换行响应导致 buffer 增长)
|
||||
|
||||
Returns:
|
||||
预读的字节块列表
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||
"""
|
||||
prefetched_chunks: list = []
|
||||
@@ -193,6 +205,7 @@ class StreamProcessor:
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
total_prefetched_bytes = 0
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
@@ -206,11 +219,13 @@ class StreamProcessor:
|
||||
provider_name=str(provider.name),
|
||||
)
|
||||
prefetched_chunks.append(first_chunk)
|
||||
total_prefetched_bytes += len(first_chunk)
|
||||
buffer += first_chunk
|
||||
|
||||
# 继续读取剩余的预读数据
|
||||
async for chunk in aiter:
|
||||
prefetched_chunks.append(chunk)
|
||||
total_prefetched_bytes += len(chunk)
|
||||
buffer += chunk
|
||||
|
||||
# 尝试按行解析缓冲区
|
||||
@@ -228,10 +243,21 @@ class StreamProcessor:
|
||||
|
||||
line_count += 1
|
||||
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
if check_html_response(line):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
|
||||
# 跳过空行和注释行
|
||||
if not line or line.startswith(":"):
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
@@ -248,7 +274,6 @@ class StreamProcessor:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
@@ -276,14 +301,34 @@ class StreamProcessor:
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
|
||||
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
|
||||
logger.debug(
|
||||
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
|
||||
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
|
||||
f"max_bytes={max_prefetch_bytes}"
|
||||
)
|
||||
break
|
||||
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except (EmbeddedErrorException, ProviderTimeoutException):
|
||||
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
|
||||
if not should_stop and prefetched_chunks:
|
||||
check_prefetched_response_error(
|
||||
prefetched_chunks=prefetched_chunks,
|
||||
parser=parser,
|
||||
request_id=self.request_id,
|
||||
provider_name=str(provider.name),
|
||||
endpoint_id=endpoint.id,
|
||||
base_url=endpoint.base_url,
|
||||
)
|
||||
|
||||
except (EmbeddedErrorException, ProviderNotAvailableException, ProviderTimeoutException):
|
||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||
raise
|
||||
except (OSError, IOError) as e:
|
||||
# 网络 I/O <EFBFBD><EFBFBD><EFBFBD>常:记录警告,可能需要重试
|
||||
# 网络 I/O 异常:记录警告,可能需要重试
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||
)
|
||||
@@ -332,15 +377,15 @@ class StreamProcessor:
|
||||
|
||||
# 处理预读数据
|
||||
if prefetched_chunks:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
for chunk in prefetched_chunks:
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
# 把原始数据转发给客户端
|
||||
yield chunk
|
||||
@@ -363,14 +408,14 @@ class StreamProcessor:
|
||||
|
||||
# 处理剩余的流数据
|
||||
async for chunk in byte_iterator:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
# 原始数据透传
|
||||
yield chunk
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
Handler 基础工具函数
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.core.exceptions import EmbeddedErrorException, ProviderNotAvailableException
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
@@ -107,3 +109,95 @@ def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[st
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
return headers
|
||||
|
||||
|
||||
def check_html_response(line: str) -> bool:
|
||||
"""
|
||||
检查行是否为 HTML 响应(base_url 配置错误的常见症状)
|
||||
|
||||
Args:
|
||||
line: 要检查的行内容
|
||||
|
||||
Returns:
|
||||
True 如果检测到 HTML 响应
|
||||
"""
|
||||
lower_line = line.lstrip().lower()
|
||||
return lower_line.startswith("<!doctype") or lower_line.startswith("<html")
|
||||
|
||||
|
||||
def check_prefetched_response_error(
|
||||
prefetched_chunks: list,
|
||||
parser: Any,
|
||||
request_id: str,
|
||||
provider_name: str,
|
||||
endpoint_id: Optional[str],
|
||||
base_url: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
检查预读的响应是否为非 SSE 格式的错误响应(HTML 或纯 JSON 错误)
|
||||
|
||||
某些代理可能返回:
|
||||
1. HTML 页面(base_url 配置错误)
|
||||
2. 纯 JSON 错误(无换行或多行 JSON)
|
||||
|
||||
Args:
|
||||
prefetched_chunks: 预读的字节块列表
|
||||
parser: 响应解析器(需要有 is_error_response 和 parse_response 方法)
|
||||
request_id: 请求 ID(用于日志)
|
||||
provider_name: Provider 名称
|
||||
endpoint_id: Endpoint ID
|
||||
base_url: Endpoint 的 base_url
|
||||
|
||||
Raises:
|
||||
ProviderNotAvailableException: 如果检测到 HTML 响应
|
||||
EmbeddedErrorException: 如果检测到 JSON 错误响应
|
||||
"""
|
||||
if not prefetched_chunks:
|
||||
return
|
||||
|
||||
try:
|
||||
prefetched_bytes = b"".join(prefetched_chunks)
|
||||
stripped = prefetched_bytes.lstrip()
|
||||
|
||||
# 去除 BOM
|
||||
if stripped.startswith(b"\xef\xbb\xbf"):
|
||||
stripped = stripped[3:]
|
||||
|
||||
# HTML 响应(通常是 base_url 配置错误导致返回网页)
|
||||
lower_prefix = stripped[:32].lower()
|
||||
if lower_prefix.startswith(b"<!doctype") or lower_prefix.startswith(b"<html"):
|
||||
endpoint_short = endpoint_id[:8] + "..." if endpoint_id else "N/A"
|
||||
logger.error(
|
||||
f" [{request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider_name}, Endpoint={endpoint_short}, "
|
||||
f"base_url={base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
|
||||
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
|
||||
# 纯 JSON(可能无换行/多行 JSON)
|
||||
if stripped.startswith(b"{") or stripped.startswith(b"["):
|
||||
payload_str = stripped.decode("utf-8", errors="replace").strip()
|
||||
data = json.loads(payload_str)
|
||||
if isinstance(data, dict) and parser.is_error_response(data):
|
||||
parsed = parser.parse_response(data, 200)
|
||||
logger.warning(
|
||||
f" [{request_id}] 检测到 JSON 错误响应: "
|
||||
f"Provider={provider_name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=provider_name,
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -104,9 +104,14 @@ async def get_my_usage(
|
||||
request: Request,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
search: Optional[str] = None, # 通用搜索:密钥名、模型名
|
||||
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
|
||||
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date)
|
||||
adapter = GetUsageAdapter(
|
||||
start_date=start_date, end_date=end_date, search=search, limit=limit, offset=offset
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@@ -133,6 +138,20 @@ async def get_my_interval_timeline(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/usage/heatmap")
|
||||
async def get_my_activity_heatmap(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get user's activity heatmap data for the past 365 days.
|
||||
|
||||
This endpoint is cached for 5 minutes to reduce database load.
|
||||
"""
|
||||
adapter = GetMyActivityHeatmapAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = ListAvailableProvidersAdapter()
|
||||
@@ -471,8 +490,15 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
|
||||
class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
start_date: Optional[datetime]
|
||||
end_date: Optional[datetime]
|
||||
search: Optional[str] = None
|
||||
limit: int = 100
|
||||
offset: int = 0
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import or_
|
||||
|
||||
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||
|
||||
db = context.db
|
||||
user = context.user
|
||||
summary_list = UsageService.get_usage_summary(
|
||||
@@ -553,7 +579,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
stats["total_cost_usd"] += item["total_cost_usd"]
|
||||
# 假设 summary 中的都是成功的请求
|
||||
stats["success_count"] += item["requests"]
|
||||
if item.get("avg_response_time_ms"):
|
||||
if item.get("avg_response_time_ms") is not None:
|
||||
stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
|
||||
stats["response_time_count"] += item["requests"]
|
||||
|
||||
@@ -577,12 +603,33 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
})
|
||||
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
|
||||
|
||||
query = db.query(Usage).filter(Usage.user_id == user.id)
|
||||
query = (
|
||||
db.query(Usage, ApiKey)
|
||||
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||
.filter(Usage.user_id == user.id)
|
||||
)
|
||||
if self.start_date:
|
||||
query = query.filter(Usage.created_at >= self.start_date)
|
||||
if self.end_date:
|
||||
query = query.filter(Usage.created_at <= self.end_date)
|
||||
usage_records = query.order_by(Usage.created_at.desc()).limit(100).all()
|
||||
|
||||
# 通用搜索:密钥名、模型名
|
||||
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||
if self.search and self.search.strip():
|
||||
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||
for keyword in keywords:
|
||||
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||
search_pattern = f"%{escaped}%"
|
||||
query = query.filter(
|
||||
or_(
|
||||
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||
Usage.model.ilike(search_pattern, escape="\\"),
|
||||
)
|
||||
)
|
||||
|
||||
# 计算总数用于分页
|
||||
total_records = query.count()
|
||||
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||
|
||||
avg_resp_query = db.query(func.avg(Usage.response_time_ms)).filter(
|
||||
Usage.user_id == user.id,
|
||||
@@ -608,6 +655,13 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
"used_usd": user.used_usd,
|
||||
"summary_by_model": summary_by_model,
|
||||
"summary_by_provider": summary_by_provider,
|
||||
# 分页信息
|
||||
"pagination": {
|
||||
"total": total_records,
|
||||
"limit": self.limit,
|
||||
"offset": self.offset,
|
||||
"has_more": self.offset + self.limit < total_records,
|
||||
},
|
||||
"records": [
|
||||
{
|
||||
"id": r.id,
|
||||
@@ -631,23 +685,25 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||
"output_price_per_1m": r.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": r.cache_read_price_per_1m,
|
||||
"api_key": (
|
||||
{
|
||||
"id": str(api_key.id),
|
||||
"name": api_key.name,
|
||||
"display": api_key.get_display_key(),
|
||||
}
|
||||
if api_key
|
||||
else None
|
||||
),
|
||||
}
|
||||
for r in usage_records
|
||||
for r, api_key in usage_records
|
||||
],
|
||||
}
|
||||
|
||||
response_data["activity_heatmap"] = UsageService.get_daily_activity(
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
window_days=365,
|
||||
include_actual_cost=user.role == "admin",
|
||||
)
|
||||
|
||||
# 管理员可以看到真实成本
|
||||
if user.role == "admin":
|
||||
response_data["total_actual_cost"] = total_actual_cost
|
||||
# 为每条记录添加真实成本和倍率信息
|
||||
for i, r in enumerate(usage_records):
|
||||
for i, (r, _) in enumerate(usage_records):
|
||||
# 确保字段有值,避免前端显示 -
|
||||
actual_cost = (
|
||||
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
|
||||
@@ -709,6 +765,20 @@ class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
|
||||
return result
|
||||
|
||||
|
||||
class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
|
||||
"""Activity heatmap adapter with Redis caching for user."""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
user = context.user
|
||||
result = await UsageService.get_cached_heatmap(
|
||||
db=context.db,
|
||||
user_id=user.id,
|
||||
include_actual_cost=user.role == "admin",
|
||||
)
|
||||
context.add_audit_metadata(action="activity_heatmap")
|
||||
return result
|
||||
|
||||
|
||||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
@@ -213,7 +213,7 @@ class RedisClientManager:
|
||||
f"Redis连接失败: {error_msg}\n"
|
||||
"缓存亲和性功能需要Redis支持,请确保Redis服务正常运行。\n"
|
||||
"检查事项:\n"
|
||||
"1. Redis服务是否已启动(docker-compose up -d redis)\n"
|
||||
"1. Redis服务是否已启动(docker compose up -d redis)\n"
|
||||
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
|
||||
"3. Redis端口(默认6379)是否可访问"
|
||||
) from e
|
||||
|
||||
@@ -21,6 +21,9 @@ class CacheTTL:
|
||||
# L1 本地缓存(用于减少 Redis 访问)
|
||||
L1_LOCAL = 3 # 3秒
|
||||
|
||||
# 活跃度热力图缓存 - 历史数据变化不频繁
|
||||
ACTIVITY_HEATMAP = 300 # 5分钟
|
||||
|
||||
# 并发锁 TTL - 防止死锁
|
||||
CONCURRENCY_LOCK = 600 # 10分钟
|
||||
|
||||
@@ -38,8 +41,25 @@ class CacheSize:
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StreamDefaults:
|
||||
"""流式处理默认值"""
|
||||
|
||||
# 预读字节上限(避免无换行响应导致内存增长)
|
||||
# 64KB 基于:
|
||||
# 1. SSE 单条消息通常远小于此值
|
||||
# 2. 足够检测 HTML 和 JSON 错误响应
|
||||
# 3. 不会占用过多内存
|
||||
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
|
||||
|
||||
|
||||
class ConcurrencyDefaults:
|
||||
"""并发控制默认值"""
|
||||
"""并发控制默认值
|
||||
|
||||
算法说明:边界记忆 + 渐进探测
|
||||
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
|
||||
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
|
||||
- 这样可以快速收敛到真实限制附近,避免过度保守
|
||||
"""
|
||||
|
||||
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
||||
INITIAL_LIMIT = 50
|
||||
@@ -69,10 +89,6 @@ class ConcurrencyDefaults:
|
||||
# 扩容步长 - 每次扩容增加的并发数
|
||||
INCREASE_STEP = 2
|
||||
|
||||
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
|
||||
# 0.85 表示降到触发 429 时并发数的 85%
|
||||
DECREASE_MULTIPLIER = 0.85
|
||||
|
||||
# 最大并发限制上限
|
||||
MAX_CONCURRENT_LIMIT = 200
|
||||
|
||||
@@ -84,6 +100,7 @@ class ConcurrencyDefaults:
|
||||
|
||||
# === 探测性扩容参数 ===
|
||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||
# 探测性扩容可以突破已知边界,尝试更高的并发
|
||||
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
||||
|
||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||
|
||||
@@ -106,13 +106,6 @@ class Config:
|
||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||
|
||||
# 可信代理配置
|
||||
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||
|
||||
# 异常处理配置
|
||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||
# 设置为 False 时,使用全局异常处理器统一处理
|
||||
@@ -161,6 +154,11 @@ class Config:
|
||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||
|
||||
# 请求体读取超时(秒)
|
||||
# REQUEST_BODY_TIMEOUT: 等待客户端发送完整请求体的超时时间
|
||||
# 默认 60 秒,防止客户端发送不完整请求导致连接卡死
|
||||
self.request_body_timeout = float(os.getenv("REQUEST_BODY_TIMEOUT", "60.0"))
|
||||
|
||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||
self.internal_user_agent_claude_cli = os.getenv(
|
||||
|
||||
@@ -30,3 +30,10 @@ class ProviderBillingType(Enum):
|
||||
MONTHLY_QUOTA = "monthly_quota" # 月卡额度
|
||||
PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费
|
||||
FREE_TIER = "free_tier" # 免费额度
|
||||
|
||||
|
||||
class AuthSource(str, Enum):
|
||||
"""认证来源枚举"""
|
||||
|
||||
LOCAL = "local" # 本地认证
|
||||
LDAP = "ldap" # LDAP 认证
|
||||
|
||||
28
src/core/error_utils.py
Normal file
28
src/core/error_utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
错误消息处理工具函数
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
|
||||
"""
|
||||
从异常中提取错误消息,优先使用上游响应内容
|
||||
|
||||
Args:
|
||||
error: 异常对象
|
||||
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
|
||||
|
||||
Returns:
|
||||
错误消息字符串
|
||||
"""
|
||||
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
|
||||
upstream_response = getattr(error, "upstream_response", None)
|
||||
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
|
||||
return str(upstream_response)
|
||||
|
||||
# 回退到异常的字符串表示(str 可能为空,如 httpx 超时异常)
|
||||
error_str = str(error) or repr(error)
|
||||
if status_code is not None:
|
||||
return f"HTTP {status_code}: {error_str}"
|
||||
return error_str
|
||||
@@ -547,11 +547,19 @@ class ErrorResponse:
|
||||
- 所有错误都记录到日志,通过错误 ID 关联
|
||||
"""
|
||||
if isinstance(e, ProxyException):
|
||||
details = e.details.copy() if e.details else {}
|
||||
status_code = e.status_code
|
||||
message = e.message
|
||||
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
|
||||
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
|
||||
if e.upstream_status:
|
||||
status_code = e.upstream_status
|
||||
message = e.upstream_response
|
||||
return ErrorResponse.create(
|
||||
error_type=e.error_type,
|
||||
message=e.message,
|
||||
status_code=e.status_code,
|
||||
details=e.details,
|
||||
message=message,
|
||||
status_code=status_code,
|
||||
details=details if details else None,
|
||||
)
|
||||
elif isinstance(e, HTTPException):
|
||||
return ErrorResponse.create(
|
||||
|
||||
@@ -411,7 +411,7 @@ def init_db():
|
||||
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print("如果使用 Docker,请先运行:", file=sys.stderr)
|
||||
print(" docker-compose up -d postgres redis", file=sys.stderr)
|
||||
print(" docker compose -f docker-compose.build.yml up -d postgres redis", file=sys.stderr)
|
||||
print("", file=sys.stderr)
|
||||
print("=" * 60, file=sys.stderr)
|
||||
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
|
||||
|
||||
@@ -203,28 +203,21 @@ class PluginMiddleware:
|
||||
"""
|
||||
获取客户端 IP 地址,支持代理头
|
||||
|
||||
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||
优先级:X-Real-IP > X-Forwarded-For > 直连 IP
|
||||
X-Real-IP 由最外层 Nginx 设置,最可靠
|
||||
"""
|
||||
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||
|
||||
# 优先从代理头获取真实 IP
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For,取第一个 IP(原始客户端)
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
# 回退到直连 IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
@@ -4,9 +4,9 @@ API端点请求/响应模型定义
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
@@ -15,17 +15,9 @@ from ..core.enums import UserRole
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
email: str = Field(..., min_length=1, max_length=255, description="邮箱/用户名")
|
||||
password: str = Field(..., min_length=1, max_length=128, description="密码")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型")
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
@@ -36,6 +28,24 @@ class LoginRequest(BaseModel):
|
||||
raise ValueError("密码不能为空")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_login(self):
|
||||
"""根据认证类型校验并规范化登录标识"""
|
||||
identifier = self.email.strip()
|
||||
|
||||
if not identifier:
|
||||
raise ValueError("用户名/邮箱不能为空")
|
||||
|
||||
# 本地和 LDAP 登录都支持用户名或邮箱
|
||||
# 如果是邮箱格式,转换为小写
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if re.match(email_pattern, identifier):
|
||||
self.email = identifier.lower()
|
||||
else:
|
||||
self.email = identifier
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应"""
|
||||
@@ -309,8 +319,9 @@ class CreateApiKeyRequest(BaseModel):
|
||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||
rate_limit: Optional[int] = 100
|
||||
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
|
||||
rate_limit: Optional[int] = None # None = 无限制
|
||||
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期(兼容旧版)
|
||||
expires_at: Optional[str] = None # ISO 日期字符串,如 "2025-12-31",优先于 expire_days
|
||||
initial_balance_usd: Optional[float] = Field(
|
||||
None, description="初始余额(USD),仅用于独立Key,None = 无限制"
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
from ..core.enums import AuthSource, ProviderBillingType, UserRole
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -54,6 +54,20 @@ class User(Base):
|
||||
default=UserRole.USER,
|
||||
nullable=False,
|
||||
)
|
||||
auth_source = Column(
|
||||
Enum(
|
||||
AuthSource,
|
||||
name="authsource",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AuthSource.LOCAL,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# LDAP 标识(仅 auth_source=ldap 时使用,用于在邮箱变更/用户名冲突时稳定关联本地账户)
|
||||
ldap_dn = Column(String(512), nullable=True, index=True)
|
||||
ldap_username = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# 访问限制(NULL 表示不限制,允许访问所有资源)
|
||||
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
|
||||
@@ -150,7 +164,7 @@ class ApiKey(Base):
|
||||
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
|
||||
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
|
||||
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
|
||||
rate_limit = Column(Integer, default=100) # 每分钟请求限制
|
||||
rate_limit = Column(Integer, default=None, nullable=True) # 每分钟请求限制,None = 无限制
|
||||
concurrent_limit = Column(Integer, default=5, nullable=True) # 并发请求限制
|
||||
|
||||
# Key 能力配置
|
||||
@@ -428,6 +442,68 @@ class SystemConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class LDAPConfig(Base):
|
||||
"""LDAP认证配置表 - 单行配置"""
|
||||
|
||||
__tablename__ = "ldap_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636
|
||||
bind_dn = Column(String(255), nullable=False) # 绑定账号 DN
|
||||
bind_password_encrypted = Column(Text, nullable=True) # 加密的绑定密码(允许 NULL 表示已清除)
|
||||
base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN
|
||||
user_search_filter = Column(
|
||||
String(500), default="(uid={username})", nullable=False
|
||||
) # 用户搜索过滤器
|
||||
username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName)
|
||||
email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性
|
||||
display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性
|
||||
is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证
|
||||
is_exclusive = Column(
|
||||
Boolean, default=False, nullable=False
|
||||
) # 是否仅允许 LDAP 登录(禁用本地认证)
|
||||
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
|
||||
connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def set_bind_password(self, password: str) -> None:
|
||||
"""
|
||||
设置并加密绑定密码
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
self.bind_password_encrypted = crypto_service.encrypt(password)
|
||||
|
||||
def get_bind_password(self) -> str:
|
||||
"""
|
||||
获取解密后的绑定密码
|
||||
|
||||
Returns:
|
||||
str: 解密后的明文密码
|
||||
|
||||
Raises:
|
||||
DecryptionException: 解密失败时抛出异常
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
if not self.bind_password_encrypted:
|
||||
return ""
|
||||
return crypto_service.decrypt(self.bind_password_encrypted)
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""提供商配置表"""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class ProviderEndpointCreate(BaseModel):
|
||||
provider_id: str = Field(..., description="Provider ID")
|
||||
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
|
||||
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
||||
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
|
||||
|
||||
# 请求配置
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||
@@ -62,6 +63,7 @@ class ProviderEndpointUpdate(BaseModel):
|
||||
base_url: Optional[str] = Field(
|
||||
default=None, min_length=1, max_length=500, description="API 基础 URL"
|
||||
)
|
||||
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
|
||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
||||
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
||||
@@ -94,6 +96,7 @@ class ProviderEndpointResponse(BaseModel):
|
||||
# API 配置
|
||||
api_format: str
|
||||
base_url: str
|
||||
custom_path: Optional[str] = None
|
||||
|
||||
# 请求配置
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
|
||||
@@ -21,7 +21,7 @@ WARNING: 多进程环境注意事项
|
||||
import asyncio
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Deque, Dict
|
||||
|
||||
from src.core.logger import logger
|
||||
@@ -95,12 +95,12 @@ class SlidingWindow:
|
||||
"""获取最早的重置时间"""
|
||||
self._cleanup()
|
||||
if not self.requests:
|
||||
return datetime.now()
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
# 最早的请求将在window_size秒后过期
|
||||
oldest_request = self.requests[0]
|
||||
reset_time = oldest_request + self.window_size
|
||||
return datetime.fromtimestamp(reset_time)
|
||||
return datetime.fromtimestamp(reset_time, tz=timezone.utc)
|
||||
|
||||
|
||||
class SlidingWindowStrategy(RateLimitStrategy):
|
||||
@@ -250,7 +250,7 @@ class SlidingWindowStrategy(RateLimitStrategy):
|
||||
retry_after = None
|
||||
if not allowed:
|
||||
# 计算需要等待的时间(最早请求过期的时间)
|
||||
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
|
||||
retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=allowed,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from ...clients.redis_client import get_redis_client_sync
|
||||
@@ -63,11 +63,11 @@ class TokenBucket:
|
||||
def get_reset_time(self) -> datetime:
|
||||
"""获取下次完全恢复的时间"""
|
||||
if self.tokens >= self.capacity:
|
||||
return datetime.now()
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
tokens_needed = self.capacity - self.tokens
|
||||
seconds_to_full = tokens_needed / self.refill_rate
|
||||
return datetime.now() + timedelta(seconds=seconds_to_full)
|
||||
return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
|
||||
|
||||
|
||||
class TokenBucketStrategy(RateLimitStrategy):
|
||||
@@ -370,7 +370,7 @@ class RedisTokenBucketBackend:
|
||||
|
||||
if tokens is None or last_refill is None:
|
||||
remaining = capacity
|
||||
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
|
||||
reset_at = datetime.now(timezone.utc) + timedelta(seconds=capacity / refill_rate)
|
||||
else:
|
||||
tokens_value = float(tokens)
|
||||
last_refill_value = float(last_refill)
|
||||
@@ -378,7 +378,7 @@ class RedisTokenBucketBackend:
|
||||
tokens_value = min(capacity, tokens_value + delta * refill_rate)
|
||||
remaining = int(tokens_value)
|
||||
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
|
||||
reset_at = datetime.now() + timedelta(seconds=reset_after)
|
||||
reset_at = datetime.now(timezone.utc) + timedelta(seconds=reset_after)
|
||||
|
||||
allowed = remaining >= amount
|
||||
retry_after = None
|
||||
|
||||
363
src/services/auth/ldap.py
Normal file
363
src/services/auth/ldap.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
# LDAP 连接默认超时时间(秒)
|
||||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||||
|
||||
|
||||
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
|
||||
"""
|
||||
解析 LDAP 服务器地址,支持:
|
||||
- ldap://host:389
|
||||
- ldaps://host:636
|
||||
- host:389(无 scheme 时默认 ldap)
|
||||
|
||||
Returns:
|
||||
(host, port, use_ssl)
|
||||
"""
|
||||
raw = (server_url or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("LDAP server_url is required")
|
||||
|
||||
parsed = urlparse(raw)
|
||||
if parsed.scheme in {"ldap", "ldaps"}:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
use_ssl = parsed.scheme == "ldaps"
|
||||
port = parsed.port or (636 if use_ssl else 389)
|
||||
return host, port, use_ssl
|
||||
|
||||
# 兼容无 scheme:按 ldap:// 解析
|
||||
parsed = urlparse(f"ldap://{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
port = parsed.port or 389
|
||||
return host, port, False
|
||||
|
||||
|
||||
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
|
||||
"""
|
||||
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
|
||||
|
||||
Args:
|
||||
value: 需要转义的字符串
|
||||
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
|
||||
|
||||
Returns:
|
||||
转义后的安全字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 输入值过长
|
||||
"""
|
||||
import unicodedata
|
||||
|
||||
# 先检查原始长度,防止 DoS 攻击
|
||||
# 128 字符足够覆盖大多数企业用户名和邮箱地址
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
|
||||
|
||||
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
|
||||
# 再次检查规范化后的长度(防止规范化后长度突增)
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
|
||||
|
||||
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
|
||||
# 使用显式顺序处理,确保反斜杠首先转义
|
||||
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
|
||||
value = value.replace("*", r"\2a")
|
||||
value = value.replace("(", r"\28")
|
||||
value = value.replace(")", r"\29")
|
||||
value = value.replace("\x00", r"\00") # NUL
|
||||
value = value.replace("&", r"\26")
|
||||
value = value.replace("|", r"\7c")
|
||||
value = value.replace("=", r"\3d")
|
||||
value = value.replace(">", r"\3e")
|
||||
value = value.replace("<", r"\3c")
|
||||
value = value.replace("~", r"\7e")
|
||||
value = value.replace("!", r"\21")
|
||||
return value
|
||||
|
||||
|
||||
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
|
||||
"""
|
||||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||||
"""
|
||||
attr = getattr(entry, attr_name, None)
|
||||
if not attr:
|
||||
return default
|
||||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||||
val = getattr(attr, "value", None)
|
||||
if isinstance(val, list):
|
||||
val = val[0] if val else default
|
||||
if val is None:
|
||||
return default
|
||||
return str(val)
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||||
"""获取 LDAP 配置"""
|
||||
return db.query(LDAPConfig).first()
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_enabled(db: Session) -> bool:
|
||||
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_exclusive(db: Session) -> bool:
|
||||
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_exclusive is not True:
|
||||
return False
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||||
"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_enabled is not True:
|
||||
return None
|
||||
|
||||
try:
|
||||
bind_password = config.get_bind_password()
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||||
return None
|
||||
|
||||
# 绑定密码为空时无法进行 LDAP 认证
|
||||
if not bind_password:
|
||||
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
|
||||
return None
|
||||
|
||||
return {
|
||||
"server_url": config.server_url,
|
||||
"bind_dn": config.bind_dn,
|
||||
"bind_password": bind_password,
|
||||
"base_dn": config.base_dn,
|
||||
"user_search_filter": config.user_search_filter,
|
||||
"username_attr": config.username_attr,
|
||||
"email_attr": config.email_attr,
|
||||
"display_name_attr": config.display_name_attr,
|
||||
"use_starttls": config.use_starttls,
|
||||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
config: 已解密的 LDAP 配置
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户属性 dict {username, email, display_name} 或 None
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection, SUBTREE
|
||||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||||
except ImportError:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
if not config:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
admin_conn = None
|
||||
user_conn = None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config["bind_password"]
|
||||
admin_conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||||
return None
|
||||
|
||||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||||
safe_username = escape_ldap_filter(username)
|
||||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||||
admin_conn.search(
|
||||
search_base=config["base_dn"],
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
size_limit=2, # 防止过滤器误配导致匹配多用户
|
||||
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
|
||||
attributes=[
|
||||
config["username_attr"],
|
||||
config["email_attr"],
|
||||
config["display_name_attr"],
|
||||
],
|
||||
)
|
||||
|
||||
if len(admin_conn.entries) != 1:
|
||||
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
|
||||
logger.warning(
|
||||
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
|
||||
)
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(
|
||||
server,
|
||||
user=user_dn,
|
||||
password=password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
|
||||
bind_result = user_conn.result.get("description", "unknown")
|
||||
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
|
||||
return None
|
||||
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
|
||||
email = _get_attr_value(
|
||||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||||
)
|
||||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
"username": ldap_username,
|
||||
"ldap_username": ldap_username,
|
||||
"ldap_dn": user_dn,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
}
|
||||
|
||||
except LDAPSocketOpenError as e:
|
||||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||||
return None
|
||||
except LDAPBindError as e:
|
||||
logger.error(f"LDAP 绑定失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
finally:
|
||||
# 确保连接关闭,避免失败路径泄漏
|
||||
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
|
||||
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
Returns:
|
||||
(success, message)
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
conn = None
|
||||
try:
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
bind_password = config["bind_password"]
|
||||
conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return False, f"绑定失败: {conn.result}"
|
||||
|
||||
return True, "连接成功"
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
|
||||
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
|
||||
return False, "连接失败,请检查服务器地址、端口和凭据"
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP 测试连接关闭失败: {e}")
|
||||
|
||||
# 兼容旧接口:如果其他代码直接调用
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在或未启用"
|
||||
return LDAPService.test_connection_with_config(config)
|
||||
@@ -2,21 +2,25 @@
|
||||
认证服务
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config import config
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.core.enums import AuthSource
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
@@ -92,15 +96,86 @@ class AuthService:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
async def authenticate_user(
|
||||
db: Session, email: str, password: str, auth_type: str = "local"
|
||||
) -> Optional[User]:
|
||||
"""用户登录认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱/用户名
|
||||
password: 密码
|
||||
auth_type: 认证类型 ("local" 或 "ldap")
|
||||
"""
|
||||
if auth_type == "ldap":
|
||||
# LDAP 认证
|
||||
# 预取配置,避免将 Session 传递到线程池
|
||||
config_data = LDAPService.get_config_data(db)
|
||||
if not config_data:
|
||||
logger.warning("登录失败 - LDAP 未启用或配置无效")
|
||||
return None
|
||||
|
||||
# 计算总体超时:LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
|
||||
# 超时策略:
|
||||
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
|
||||
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
|
||||
# - 公式:单次超时 × 4(覆盖 4 次主要网络操作)+ 10% 缓冲
|
||||
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
|
||||
single_timeout = config_data.get("connect_timeout", 10)
|
||||
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
|
||||
|
||||
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
|
||||
# 添加总体超时保护,防止异常场景下请求堆积
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
ldap_user = await asyncio.wait_for(
|
||||
run_in_threadpool(
|
||||
LDAPService.authenticate_with_config, config_data, email, password
|
||||
),
|
||||
timeout=total_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
|
||||
return None
|
||||
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
# 获取或创建本地用户
|
||||
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
|
||||
if not user:
|
||||
# 已有本地账号但来源不匹配等情况
|
||||
return None
|
||||
if not user.is_active:
|
||||
logger.warning(f"登录失败 - 用户已禁用: {email}")
|
||||
return None
|
||||
return user
|
||||
|
||||
# 本地认证
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
# 支持邮箱或用户名登录
|
||||
from sqlalchemy import or_
|
||||
user = db.query(User).filter(
|
||||
or_(User.email == email, User.username == email)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
return None
|
||||
|
||||
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
|
||||
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
|
||||
return None
|
||||
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
|
||||
|
||||
# 检查用户认证来源
|
||||
if user.auth_source == AuthSource.LDAP:
|
||||
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
|
||||
return None
|
||||
|
||||
if not user.verify_password(password):
|
||||
logger.warning(f"登录失败 - 密码错误: {email}")
|
||||
return None
|
||||
@@ -118,6 +193,127 @@ class AuthService:
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
|
||||
"""获取或创建 LDAP 用户
|
||||
|
||||
Args:
|
||||
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
|
||||
|
||||
注意:使用 with_for_update() 防止并发首次登录创建重复用户
|
||||
"""
|
||||
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
|
||||
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
|
||||
email = ldap_user["email"]
|
||||
|
||||
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
|
||||
# 使用 with_for_update() 锁定行,防止并发创建
|
||||
user: Optional[User] = None
|
||||
if ldap_dn:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user and ldap_username:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
|
||||
user = db.query(User).filter(User.email == email).with_for_update().first()
|
||||
|
||||
if user:
|
||||
if user.auth_source != AuthSource.LDAP:
|
||||
# 避免覆盖已有本地账户(不同来源时拒绝登录)
|
||||
logger.warning(
|
||||
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 同步邮箱(LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
|
||||
if user.email != email:
|
||||
email_taken = (
|
||||
db.query(User)
|
||||
.filter(User.email == email, User.id != user.id)
|
||||
.first()
|
||||
)
|
||||
if email_taken:
|
||||
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
|
||||
return None
|
||||
user.email = email
|
||||
|
||||
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
|
||||
if ldap_dn and user.ldap_dn != ldap_dn:
|
||||
user.ldap_dn = ldap_dn
|
||||
if ldap_username and user.ldap_username != ldap_username:
|
||||
user.ldap_username = ldap_username
|
||||
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
|
||||
base_username = ldap_username or ldap_user["username"]
|
||||
username = base_username
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# 检查用户名是否已存在
|
||||
existing_user_with_username = db.query(User).filter(User.username == username).first()
|
||||
if existing_user_with_username:
|
||||
# 如果 username 已存在,使用时间戳+随机数确保唯一性
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
|
||||
|
||||
# 创建新用户
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash="", # LDAP 用户无本地密码
|
||||
auth_source=AuthSource.LDAP,
|
||||
ldap_dn=ldap_dn,
|
||||
ldap_username=ldap_username,
|
||||
role=UserRole.USER,
|
||||
is_active=True,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_str = str(e.orig).lower() if e.orig else str(e).lower()
|
||||
|
||||
# 解析具体冲突类型
|
||||
if "email" in error_str or "ix_users_email" in error_str:
|
||||
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
|
||||
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
|
||||
return None
|
||||
elif "username" in error_str or "ix_users_username" in error_str:
|
||||
# 用户名冲突,重试时会生成新用户名
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
|
||||
return None
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
|
||||
else:
|
||||
# 其他约束冲突,不重试
|
||||
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
||||
"""API密钥认证"""
|
||||
|
||||
51
src/services/billing/__init__.py
Normal file
51
src/services/billing/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
计费模块
|
||||
|
||||
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
|
||||
- Claude: input + output + cache_creation + cache_read
|
||||
- OpenAI: input + output + cache_read (无缓存创建费用)
|
||||
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
|
||||
- 按次计费: per_request
|
||||
|
||||
使用方式:
|
||||
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
|
||||
|
||||
# 1. 将原始 usage 映射为标准格式
|
||||
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
|
||||
|
||||
# 2. 使用计费计算器计算费用
|
||||
calculator = BillingCalculator(template="openai")
|
||||
result = calculator.calculate(usage, prices)
|
||||
|
||||
# 3. 获取费用明细
|
||||
print(result.total_cost)
|
||||
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
|
||||
"""
|
||||
|
||||
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
|
||||
from src.services.billing.models import (
|
||||
BillingDimension,
|
||||
BillingUnit,
|
||||
CostBreakdown,
|
||||
StandardizedUsage,
|
||||
)
|
||||
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
|
||||
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
|
||||
|
||||
__all__ = [
|
||||
# 数据模型
|
||||
"BillingDimension",
|
||||
"BillingUnit",
|
||||
"CostBreakdown",
|
||||
"StandardizedUsage",
|
||||
# 模板
|
||||
"BillingTemplates",
|
||||
"BILLING_TEMPLATE_REGISTRY",
|
||||
# 计算器
|
||||
"BillingCalculator",
|
||||
"calculate_request_cost",
|
||||
# 映射器
|
||||
"UsageMapper",
|
||||
"map_usage",
|
||||
"map_usage_from_response",
|
||||
]
|
||||
339
src/services/billing/calculator.py
Normal file
339
src/services/billing/calculator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
计费计算器
|
||||
|
||||
配置驱动的计费计算,支持:
|
||||
- 固定价格计费
|
||||
- 阶梯计费
|
||||
- 多种计费模板
|
||||
- 自定义计费维度
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.services.billing.models import (
|
||||
BillingDimension,
|
||||
BillingUnit,
|
||||
CostBreakdown,
|
||||
StandardizedUsage,
|
||||
)
|
||||
from src.services.billing.templates import (
|
||||
BILLING_TEMPLATE_REGISTRY,
|
||||
BillingTemplates,
|
||||
get_template,
|
||||
)
|
||||
|
||||
|
||||
class BillingCalculator:
|
||||
"""
|
||||
配置驱动的计费计算器
|
||||
|
||||
支持多种计费模式:
|
||||
- 使用预定义模板(claude, openai, doubao 等)
|
||||
- 自定义计费维度
|
||||
- 阶梯计费
|
||||
|
||||
示例:
|
||||
# 使用模板
|
||||
calculator = BillingCalculator(template="openai")
|
||||
|
||||
# 自定义维度
|
||||
calculator = BillingCalculator(dimensions=[
|
||||
BillingDimension(name="input", usage_field="input_tokens", price_field="input_price_per_1m"),
|
||||
BillingDimension(name="output", usage_field="output_tokens", price_field="output_price_per_1m"),
|
||||
])
|
||||
|
||||
# 计算费用
|
||||
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
|
||||
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||
result = calculator.calculate(usage, prices)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimensions: Optional[List[BillingDimension]] = None,
|
||||
template: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
初始化计费计算器
|
||||
|
||||
Args:
|
||||
dimensions: 自定义计费维度列表(优先级高于模板)
|
||||
template: 使用预定义模板名称 ("claude", "openai", "doubao", "per_request" 等)
|
||||
"""
|
||||
if dimensions:
|
||||
self.dimensions = dimensions
|
||||
elif template:
|
||||
self.dimensions = get_template(template)
|
||||
else:
|
||||
# 默认使用 Claude 模板(向后兼容)
|
||||
self.dimensions = BillingTemplates.CLAUDE_STANDARD
|
||||
|
||||
self.template_name = template
|
||||
|
||||
def calculate(
|
||||
self,
|
||||
usage: StandardizedUsage,
|
||||
prices: Dict[str, float],
|
||||
tiered_pricing: Optional[Dict[str, Any]] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
total_input_context: Optional[int] = None,
|
||||
) -> CostBreakdown:
|
||||
"""
|
||||
计算费用
|
||||
|
||||
Args:
|
||||
usage: 标准化的 usage 数据
|
||||
prices: 价格配置 {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0, ...}
|
||||
tiered_pricing: 阶梯计费配置(可选)
|
||||
cache_ttl_minutes: 缓存 TTL 分钟数(用于 TTL 差异化定价)
|
||||
total_input_context: 总输入上下文(用于阶梯判定,可选)
|
||||
如果提供,将使用该值进行阶梯判定;否则使用默认计算逻辑
|
||||
|
||||
Returns:
|
||||
费用明细 (CostBreakdown)
|
||||
"""
|
||||
result = CostBreakdown()
|
||||
|
||||
# 处理阶梯计费
|
||||
effective_prices = prices.copy()
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
tier, tier_index = self._get_tier(usage, tiered_pricing, total_input_context)
|
||||
if tier:
|
||||
result.tier_index = tier_index
|
||||
# 阶梯价格覆盖默认价格
|
||||
for key, value in tier.items():
|
||||
if key not in ("up_to", "cache_ttl_pricing") and value is not None:
|
||||
effective_prices[key] = value
|
||||
|
||||
# 处理 TTL 差异化定价
|
||||
if cache_ttl_minutes is not None:
|
||||
ttl_price = self._get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
|
||||
if ttl_price is not None:
|
||||
effective_prices["cache_read_price_per_1m"] = ttl_price
|
||||
|
||||
# 记录使用的价格
|
||||
result.effective_prices = effective_prices.copy()
|
||||
|
||||
# 计算各维度费用
|
||||
total = 0.0
|
||||
for dim in self.dimensions:
|
||||
usage_value = usage.get(dim.usage_field, 0)
|
||||
price = effective_prices.get(dim.price_field, dim.default_price)
|
||||
|
||||
if usage_value and price:
|
||||
cost = dim.calculate(usage_value, price)
|
||||
result.costs[dim.name] = cost
|
||||
total += cost
|
||||
|
||||
result.total_cost = total
|
||||
return result
|
||||
|
||||
def _get_tier(
|
||||
self,
|
||||
usage: StandardizedUsage,
|
||||
tiered_pricing: Dict[str, Any],
|
||||
total_input_context: Optional[int] = None,
|
||||
) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
|
||||
"""
|
||||
确定价格阶梯
|
||||
|
||||
Args:
|
||||
usage: usage 数据
|
||||
tiered_pricing: 阶梯配置 {"tiers": [...]}
|
||||
total_input_context: 预计算的总输入上下文(可选)
|
||||
|
||||
Returns:
|
||||
(匹配的阶梯配置, 阶梯索引)
|
||||
"""
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None, None
|
||||
|
||||
# 使用传入的 total_input_context,或者默认计算
|
||||
if total_input_context is None:
|
||||
total_input_context = self._compute_total_input_context(usage)
|
||||
|
||||
for i, tier in enumerate(tiers):
|
||||
up_to = tier.get("up_to")
|
||||
# up_to 为 None 表示无上限(最后一个阶梯)
|
||||
if up_to is None or total_input_context <= up_to:
|
||||
return tier, i
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1], len(tiers) - 1
|
||||
|
||||
def _compute_total_input_context(self, usage: StandardizedUsage) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认: input_tokens + cache_read_tokens
|
||||
|
||||
Args:
|
||||
usage: usage 数据
|
||||
|
||||
Returns:
|
||||
总输入 token 数
|
||||
"""
|
||||
return usage.input_tokens + usage.cache_read_tokens
|
||||
|
||||
def _get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: Dict[str, Any],
|
||||
cache_ttl_minutes: int,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
某些厂商(如 Claude)对不同 TTL 的缓存有不同定价。
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格,如果没有 TTL 差异化配置返回 None
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if not ttl_pricing:
|
||||
return None
|
||||
|
||||
# 找到匹配或最接近的 TTL 价格
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
price = ttl_config.get("cache_read_price_per_1m")
|
||||
return float(price) if price is not None else None
|
||||
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
price = ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
return float(price) if price is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "BillingCalculator":
|
||||
"""
|
||||
从配置创建计费计算器
|
||||
|
||||
Config 格式:
|
||||
{
|
||||
"template": "claude", # 或 "openai", "doubao", "per_request"
|
||||
# 或者自定义维度:
|
||||
"dimensions": [
|
||||
{"name": "input", "usage_field": "input_tokens", "price_field": "input_price_per_1m"},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
|
||||
Returns:
|
||||
BillingCalculator 实例
|
||||
"""
|
||||
if "dimensions" in config:
|
||||
dimensions = [BillingDimension.from_dict(d) for d in config["dimensions"]]
|
||||
return cls(dimensions=dimensions)
|
||||
|
||||
return cls(template=config.get("template", "claude"))
|
||||
|
||||
def get_dimension_names(self) -> List[str]:
|
||||
"""获取所有计费维度名称"""
|
||||
return [dim.name for dim in self.dimensions]
|
||||
|
||||
def get_required_price_fields(self) -> List[str]:
|
||||
"""获取所需的价格字段名称"""
|
||||
return [dim.price_field for dim in self.dimensions]
|
||||
|
||||
def get_required_usage_fields(self) -> List[str]:
|
||||
"""获取所需的 usage 字段名称"""
|
||||
return [dim.usage_field for dim in self.dimensions]
|
||||
|
||||
|
||||
def calculate_request_cost(
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[Dict[str, Any]] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
total_input_context: Optional[int] = None,
|
||||
billing_template: str = "claude",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本的便捷函数
|
||||
|
||||
封装了 BillingCalculator 的调用逻辑,返回兼容旧格式的字典。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
total_input_context: 总输入上下文(用于阶梯判定)
|
||||
billing_template: 计费模板名称
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典:
|
||||
{
|
||||
"input_cost": float,
|
||||
"output_cost": float,
|
||||
"cache_creation_cost": float,
|
||||
"cache_read_cost": float,
|
||||
"cache_cost": float,
|
||||
"request_cost": float,
|
||||
"total_cost": float,
|
||||
"tier_index": Optional[int],
|
||||
}
|
||||
"""
|
||||
# 构建标准化 usage
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_tokens=cache_creation_input_tokens,
|
||||
cache_read_tokens=cache_read_input_tokens,
|
||||
request_count=1,
|
||||
)
|
||||
|
||||
# 构建价格配置
|
||||
prices: Dict[str, float] = {
|
||||
"input_price_per_1m": input_price_per_1m,
|
||||
"output_price_per_1m": output_price_per_1m,
|
||||
}
|
||||
if cache_creation_price_per_1m is not None:
|
||||
prices["cache_creation_price_per_1m"] = cache_creation_price_per_1m
|
||||
if cache_read_price_per_1m is not None:
|
||||
prices["cache_read_price_per_1m"] = cache_read_price_per_1m
|
||||
if price_per_request is not None:
|
||||
prices["price_per_request"] = price_per_request
|
||||
|
||||
# 使用 BillingCalculator 计算
|
||||
calculator = BillingCalculator(template=billing_template)
|
||||
result = calculator.calculate(
|
||||
usage, prices, tiered_pricing, cache_ttl_minutes, total_input_context
|
||||
)
|
||||
|
||||
# 返回兼容旧格式的字典
|
||||
return {
|
||||
"input_cost": result.input_cost,
|
||||
"output_cost": result.output_cost,
|
||||
"cache_creation_cost": result.cache_creation_cost,
|
||||
"cache_read_cost": result.cache_read_cost,
|
||||
"cache_cost": result.cache_cost,
|
||||
"request_cost": result.request_cost,
|
||||
"total_cost": result.total_cost,
|
||||
"tier_index": result.tier_index,
|
||||
}
|
||||
281
src/services/billing/models.py
Normal file
281
src/services/billing/models.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
计费模块数据模型
|
||||
|
||||
定义计费相关的核心数据结构:
|
||||
- BillingUnit: 计费单位枚举
|
||||
- BillingDimension: 计费维度定义
|
||||
- StandardizedUsage: 标准化的 usage 数据
|
||||
- CostBreakdown: 计费明细结果
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class BillingUnit(str, Enum):
|
||||
"""计费单位"""
|
||||
|
||||
PER_1M_TOKENS = "per_1m_tokens" # 每百万 token
|
||||
PER_1M_TOKENS_HOUR = "per_1m_tokens_hour" # 每百万 token 每小时(豆包缓存存储)
|
||||
PER_REQUEST = "per_request" # 每次请求
|
||||
FIXED = "fixed" # 固定费用
|
||||
|
||||
|
||||
@dataclass
|
||||
class BillingDimension:
|
||||
"""
|
||||
计费维度定义
|
||||
|
||||
每个维度描述一种计费方式,例如:
|
||||
- 输入 token 计费
|
||||
- 输出 token 计费
|
||||
- 缓存读取计费
|
||||
- 按次计费
|
||||
"""
|
||||
|
||||
name: str # 维度名称,如 "input", "output", "cache_read"
|
||||
usage_field: str # 从 usage 中取值的字段名
|
||||
price_field: str # 价格配置中的字段名
|
||||
unit: BillingUnit = BillingUnit.PER_1M_TOKENS # 计费单位
|
||||
default_price: float = 0.0 # 默认价格(当价格配置中没有时使用)
|
||||
|
||||
def calculate(self, usage_value: float, price: float) -> float:
|
||||
"""
|
||||
计算该维度的费用
|
||||
|
||||
Args:
|
||||
usage_value: 使用量数值
|
||||
price: 单价
|
||||
|
||||
Returns:
|
||||
计算后的费用
|
||||
"""
|
||||
if usage_value <= 0 or price <= 0:
|
||||
return 0.0
|
||||
|
||||
if self.unit == BillingUnit.PER_1M_TOKENS:
|
||||
return (usage_value / 1_000_000) * price
|
||||
elif self.unit == BillingUnit.PER_1M_TOKENS_HOUR:
|
||||
# 缓存存储按 token 数 * 小时数计费
|
||||
return (usage_value / 1_000_000) * price
|
||||
elif self.unit == BillingUnit.PER_REQUEST:
|
||||
return usage_value * price
|
||||
elif self.unit == BillingUnit.FIXED:
|
||||
return price
|
||||
|
||||
return 0.0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典(用于序列化)"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"usage_field": self.usage_field,
|
||||
"price_field": self.price_field,
|
||||
"unit": self.unit.value,
|
||||
"default_price": self.default_price,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BillingDimension":
|
||||
"""从字典创建实例"""
|
||||
return cls(
|
||||
name=data["name"],
|
||||
usage_field=data["usage_field"],
|
||||
price_field=data["price_field"],
|
||||
unit=BillingUnit(data.get("unit", "per_1m_tokens")),
|
||||
default_price=data.get("default_price", 0.0),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StandardizedUsage:
|
||||
"""
|
||||
标准化的 Usage 数据
|
||||
|
||||
将不同 API 格式的 usage 统一为标准格式,便于计费计算。
|
||||
"""
|
||||
|
||||
# 基础 token 计数
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
|
||||
# 缓存相关
|
||||
cache_creation_tokens: int = 0 # Claude: 缓存创建
|
||||
cache_read_tokens: int = 0 # Claude/OpenAI/豆包: 缓存读取/命中
|
||||
|
||||
# 特殊 token 类型
|
||||
reasoning_tokens: int = 0 # o1/豆包: 推理 token(通常包含在 output 中,单独记录用于分析)
|
||||
|
||||
# 时间相关(用于按时计费)
|
||||
cache_storage_token_hours: float = 0.0 # 豆包: 缓存存储 token*小时
|
||||
|
||||
# 请求计数(用于按次计费)
|
||||
request_count: int = 1
|
||||
|
||||
# 扩展字段(未来可能需要的额外维度)
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def get(self, field_name: str, default: Any = 0) -> Any:
|
||||
"""
|
||||
通用字段获取
|
||||
|
||||
支持获取标准字段和扩展字段。
|
||||
|
||||
Args:
|
||||
field_name: 字段名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
字段值
|
||||
"""
|
||||
if hasattr(self, field_name):
|
||||
value = getattr(self, field_name)
|
||||
# 对于 extra 字段,不直接返回
|
||||
if field_name != "extra":
|
||||
return value
|
||||
return self.extra.get(field_name, default)
|
||||
|
||||
def set(self, field_name: str, value: Any) -> None:
|
||||
"""
|
||||
通用字段设置
|
||||
|
||||
Args:
|
||||
field_name: 字段名
|
||||
value: 字段值
|
||||
"""
|
||||
if hasattr(self, field_name) and field_name != "extra":
|
||||
setattr(self, field_name, value)
|
||||
else:
|
||||
self.extra[field_name] = value
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result: Dict[str, Any] = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_creation_tokens": self.cache_creation_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"reasoning_tokens": self.reasoning_tokens,
|
||||
"cache_storage_token_hours": self.cache_storage_token_hours,
|
||||
"request_count": self.request_count,
|
||||
}
|
||||
if self.extra:
|
||||
result["extra"] = self.extra
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "StandardizedUsage":
|
||||
"""从字典创建实例"""
|
||||
extra = data.pop("extra", {}) if "extra" in data else {}
|
||||
# 只取已知字段
|
||||
known_fields = {
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"cache_creation_tokens",
|
||||
"cache_read_tokens",
|
||||
"reasoning_tokens",
|
||||
"cache_storage_token_hours",
|
||||
"request_count",
|
||||
}
|
||||
filtered = {k: v for k, v in data.items() if k in known_fields}
|
||||
return cls(**filtered, extra=extra)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostBreakdown:
|
||||
"""
|
||||
计费明细结果
|
||||
|
||||
包含各维度的费用和总费用。
|
||||
"""
|
||||
|
||||
# 各维度费用 {"input": 0.01, "output": 0.02, "cache_read": 0.001, ...}
|
||||
costs: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# 总费用
|
||||
total_cost: float = 0.0
|
||||
|
||||
# 命中的阶梯索引(如果使用阶梯计费)
|
||||
tier_index: Optional[int] = None
|
||||
|
||||
# 货币单位
|
||||
currency: str = "USD"
|
||||
|
||||
# 使用的价格(用于记录和审计)
|
||||
effective_prices: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
# =========================================================================
|
||||
# 兼容旧接口的属性(便于渐进式迁移)
|
||||
# =========================================================================
|
||||
|
||||
@property
|
||||
def input_cost(self) -> float:
|
||||
"""输入费用"""
|
||||
return self.costs.get("input", 0.0)
|
||||
|
||||
@property
|
||||
def output_cost(self) -> float:
|
||||
"""输出费用"""
|
||||
return self.costs.get("output", 0.0)
|
||||
|
||||
@property
|
||||
def cache_creation_cost(self) -> float:
|
||||
"""缓存创建费用"""
|
||||
return self.costs.get("cache_creation", 0.0)
|
||||
|
||||
@property
|
||||
def cache_read_cost(self) -> float:
|
||||
"""缓存读取费用"""
|
||||
return self.costs.get("cache_read", 0.0)
|
||||
|
||||
@property
|
||||
def cache_cost(self) -> float:
|
||||
"""总缓存费用(创建 + 读取)"""
|
||||
return self.cache_creation_cost + self.cache_read_cost
|
||||
|
||||
@property
|
||||
def request_cost(self) -> float:
|
||||
"""按次计费费用"""
|
||||
return self.costs.get("request", 0.0)
|
||||
|
||||
@property
|
||||
def cache_storage_cost(self) -> float:
|
||||
"""缓存存储费用(豆包等)"""
|
||||
return self.costs.get("cache_storage", 0.0)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"costs": self.costs,
|
||||
"total_cost": self.total_cost,
|
||||
"tier_index": self.tier_index,
|
||||
"currency": self.currency,
|
||||
"effective_prices": self.effective_prices,
|
||||
# 兼容字段
|
||||
"input_cost": self.input_cost,
|
||||
"output_cost": self.output_cost,
|
||||
"cache_creation_cost": self.cache_creation_cost,
|
||||
"cache_read_cost": self.cache_read_cost,
|
||||
"cache_cost": self.cache_cost,
|
||||
"request_cost": self.request_cost,
|
||||
}
|
||||
|
||||
def to_legacy_tuple(self) -> tuple:
|
||||
"""
|
||||
转换为旧接口的元组格式
|
||||
|
||||
Returns:
|
||||
(input_cost, output_cost, cache_creation_cost, cache_read_cost,
|
||||
cache_cost, request_cost, total_cost, tier_index)
|
||||
"""
|
||||
return (
|
||||
self.input_cost,
|
||||
self.output_cost,
|
||||
self.cache_creation_cost,
|
||||
self.cache_read_cost,
|
||||
self.cache_cost,
|
||||
self.request_cost,
|
||||
self.total_cost,
|
||||
self.tier_index,
|
||||
)
|
||||
213
src/services/billing/templates.py
Normal file
213
src/services/billing/templates.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
预定义计费模板
|
||||
|
||||
提供常见厂商的计费配置模板,避免重复配置:
|
||||
- CLAUDE_STANDARD: Claude/Anthropic 标准计费
|
||||
- OPENAI_STANDARD: OpenAI 标准计费
|
||||
- DOUBAO_STANDARD: 豆包计费(含缓存存储)
|
||||
- GEMINI_STANDARD: Gemini 标准计费
|
||||
- PER_REQUEST: 按次计费
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from src.services.billing.models import BillingDimension, BillingUnit
|
||||
|
||||
|
||||
class BillingTemplates:
|
||||
"""预定义的计费模板"""
|
||||
|
||||
# =========================================================================
|
||||
# Claude/Anthropic 标准计费
|
||||
# - 输入 token
|
||||
# - 输出 token
|
||||
# - 缓存创建(创建时收费,约 1.25x 输入价格)
|
||||
# - 缓存读取(约 0.1x 输入价格)
|
||||
# =========================================================================
|
||||
CLAUDE_STANDARD: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="output",
|
||||
usage_field="output_tokens",
|
||||
price_field="output_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_creation",
|
||||
usage_field="cache_creation_tokens",
|
||||
price_field="cache_creation_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_read",
|
||||
usage_field="cache_read_tokens",
|
||||
price_field="cache_read_price_per_1m",
|
||||
),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# OpenAI 标准计费
|
||||
# - 输入 token
|
||||
# - 输出 token
|
||||
# - 缓存读取(部分模型支持,无缓存创建费用)
|
||||
# =========================================================================
|
||||
OPENAI_STANDARD: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="output",
|
||||
usage_field="output_tokens",
|
||||
price_field="output_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_read",
|
||||
usage_field="cache_read_tokens",
|
||||
price_field="cache_read_price_per_1m",
|
||||
),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# 豆包计费
|
||||
# - 推理输入 (input_tokens)
|
||||
# - 推理输出 (output_tokens)
|
||||
# - 缓存命中 (cache_read_tokens) - 类似 Claude 的缓存读取
|
||||
# - 缓存存储 (cache_storage_token_hours) - 按 token 数 * 存储时长计费
|
||||
#
|
||||
# 注意:豆包的缓存创建是免费的,但存储需要按时付费
|
||||
# =========================================================================
|
||||
DOUBAO_STANDARD: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="output",
|
||||
usage_field="output_tokens",
|
||||
price_field="output_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_read",
|
||||
usage_field="cache_read_tokens",
|
||||
price_field="cache_read_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_storage",
|
||||
usage_field="cache_storage_token_hours",
|
||||
price_field="cache_storage_price_per_1m_hour",
|
||||
unit=BillingUnit.PER_1M_TOKENS_HOUR,
|
||||
),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# Gemini 标准计费
|
||||
# - 输入 token
|
||||
# - 输出 token
|
||||
# - 缓存读取
|
||||
# =========================================================================
|
||||
GEMINI_STANDARD: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="output",
|
||||
usage_field="output_tokens",
|
||||
price_field="output_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="cache_read",
|
||||
usage_field="cache_read_tokens",
|
||||
price_field="cache_read_price_per_1m",
|
||||
),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# 按次计费
|
||||
# - 适用于某些图片生成模型、特殊 API 等
|
||||
# - 仅按请求次数计费,不按 token 计费
|
||||
# =========================================================================
|
||||
PER_REQUEST: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="request",
|
||||
usage_field="request_count",
|
||||
price_field="price_per_request",
|
||||
unit=BillingUnit.PER_REQUEST,
|
||||
),
|
||||
]
|
||||
|
||||
# =========================================================================
|
||||
# 混合计费(按次 + 按 token)
|
||||
# - 某些模型既有固定费用又有 token 费用
|
||||
# =========================================================================
|
||||
HYBRID_STANDARD: List[BillingDimension] = [
|
||||
BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="output",
|
||||
usage_field="output_tokens",
|
||||
price_field="output_price_per_1m",
|
||||
),
|
||||
BillingDimension(
|
||||
name="request",
|
||||
usage_field="request_count",
|
||||
price_field="price_per_request",
|
||||
unit=BillingUnit.PER_REQUEST,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# 模板注册表
|
||||
# =========================================================================
|
||||
|
||||
BILLING_TEMPLATE_REGISTRY: Dict[str, List[BillingDimension]] = {
|
||||
# 按厂商名称
|
||||
"claude": BillingTemplates.CLAUDE_STANDARD,
|
||||
"anthropic": BillingTemplates.CLAUDE_STANDARD,
|
||||
"openai": BillingTemplates.OPENAI_STANDARD,
|
||||
"doubao": BillingTemplates.DOUBAO_STANDARD,
|
||||
"bytedance": BillingTemplates.DOUBAO_STANDARD,
|
||||
"gemini": BillingTemplates.GEMINI_STANDARD,
|
||||
"google": BillingTemplates.GEMINI_STANDARD,
|
||||
# 按计费模式
|
||||
"per_request": BillingTemplates.PER_REQUEST,
|
||||
"hybrid": BillingTemplates.HYBRID_STANDARD,
|
||||
# 默认
|
||||
"default": BillingTemplates.CLAUDE_STANDARD,
|
||||
}
|
||||
|
||||
|
||||
def get_template(name: Optional[str]) -> List[BillingDimension]:
|
||||
"""
|
||||
获取计费模板
|
||||
|
||||
Args:
|
||||
name: 模板名称(不区分大小写)
|
||||
|
||||
Returns:
|
||||
计费维度列表
|
||||
"""
|
||||
if not name:
|
||||
return BILLING_TEMPLATE_REGISTRY["default"]
|
||||
|
||||
template = BILLING_TEMPLATE_REGISTRY.get(name.lower())
|
||||
if template is None:
|
||||
available = ", ".join(sorted(BILLING_TEMPLATE_REGISTRY.keys()))
|
||||
raise ValueError(f"Unknown billing template: {name!r}. Available: {available}")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
def list_templates() -> List[str]:
|
||||
"""列出所有可用的模板名称"""
|
||||
return list(BILLING_TEMPLATE_REGISTRY.keys())
|
||||
267
src/services/billing/usage_mapper.py
Normal file
267
src/services/billing/usage_mapper.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Usage 字段映射器
|
||||
|
||||
将不同 API 格式的原始 usage 数据映射为标准化格式。
|
||||
|
||||
支持的格式:
|
||||
- OPENAI / OPENAI_CLI: OpenAI Chat Completions API
|
||||
- CLAUDE / CLAUDE_CLI: Anthropic Messages API
|
||||
- GEMINI / GEMINI_CLI: Google Gemini API
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.services.billing.models import StandardizedUsage
|
||||
|
||||
|
||||
class UsageMapper:
|
||||
"""
|
||||
Usage 字段映射器
|
||||
|
||||
将不同 API 格式的 usage 统一映射为 StandardizedUsage。
|
||||
|
||||
示例:
|
||||
# OpenAI 格式
|
||||
raw_usage = {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"prompt_tokens_details": {"cached_tokens": 20},
|
||||
"completion_tokens_details": {"reasoning_tokens": 10}
|
||||
}
|
||||
usage = UsageMapper.map(raw_usage, "OPENAI")
|
||||
|
||||
# Claude 格式
|
||||
raw_usage = {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 30,
|
||||
"cache_read_input_tokens": 20
|
||||
}
|
||||
usage = UsageMapper.map(raw_usage, "CLAUDE")
|
||||
"""
|
||||
|
||||
# =========================================================================
|
||||
# 字段映射配置
|
||||
# 格式: "source_path" -> "target_field"
|
||||
# source_path 支持点号分隔的嵌套路径
|
||||
# =========================================================================
|
||||
|
||||
# OpenAI 格式字段映射
|
||||
OPENAI_MAPPING: Dict[str, str] = {
|
||||
"prompt_tokens": "input_tokens",
|
||||
"completion_tokens": "output_tokens",
|
||||
"prompt_tokens_details.cached_tokens": "cache_read_tokens",
|
||||
"completion_tokens_details.reasoning_tokens": "reasoning_tokens",
|
||||
}
|
||||
|
||||
# Claude 格式字段映射
|
||||
CLAUDE_MAPPING: Dict[str, str] = {
|
||||
"input_tokens": "input_tokens",
|
||||
"output_tokens": "output_tokens",
|
||||
"cache_creation_input_tokens": "cache_creation_tokens",
|
||||
"cache_read_input_tokens": "cache_read_tokens",
|
||||
}
|
||||
|
||||
# Gemini 格式字段映射
|
||||
GEMINI_MAPPING: Dict[str, str] = {
|
||||
"promptTokenCount": "input_tokens",
|
||||
"candidatesTokenCount": "output_tokens",
|
||||
"cachedContentTokenCount": "cache_read_tokens",
|
||||
# Gemini 的 usageMetadata 格式
|
||||
"usageMetadata.promptTokenCount": "input_tokens",
|
||||
"usageMetadata.candidatesTokenCount": "output_tokens",
|
||||
"usageMetadata.cachedContentTokenCount": "cache_read_tokens",
|
||||
}
|
||||
|
||||
# 格式名称到映射的对应关系
|
||||
FORMAT_MAPPINGS: Dict[str, Dict[str, str]] = {
|
||||
"OPENAI": OPENAI_MAPPING,
|
||||
"OPENAI_CLI": OPENAI_MAPPING,
|
||||
"CLAUDE": CLAUDE_MAPPING,
|
||||
"CLAUDE_CLI": CLAUDE_MAPPING,
|
||||
"GEMINI": GEMINI_MAPPING,
|
||||
"GEMINI_CLI": GEMINI_MAPPING,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def map(
|
||||
cls,
|
||||
raw_usage: Dict[str, Any],
|
||||
api_format: str,
|
||||
extra_mapping: Optional[Dict[str, str]] = None,
|
||||
) -> StandardizedUsage:
|
||||
"""
|
||||
将原始 usage 映射为标准化格式
|
||||
|
||||
Args:
|
||||
raw_usage: 原始 usage 字典
|
||||
api_format: API 格式 ("OPENAI", "CLAUDE", "GEMINI" 等)
|
||||
extra_mapping: 额外的字段映射(用于自定义扩展)
|
||||
|
||||
Returns:
|
||||
标准化的 usage 对象
|
||||
"""
|
||||
if not raw_usage:
|
||||
return StandardizedUsage()
|
||||
|
||||
# 获取对应格式的字段映射
|
||||
mapping = cls._get_mapping(api_format)
|
||||
|
||||
# 合并额外映射
|
||||
if extra_mapping:
|
||||
mapping = {**mapping, **extra_mapping}
|
||||
|
||||
result = StandardizedUsage()
|
||||
|
||||
# 执行映射
|
||||
for source_path, target_field in mapping.items():
|
||||
value = cls._get_nested_value(raw_usage, source_path)
|
||||
if value is not None:
|
||||
result.set(target_field, value)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def map_from_response(
|
||||
cls,
|
||||
response: Dict[str, Any],
|
||||
api_format: str,
|
||||
) -> StandardizedUsage:
|
||||
"""
|
||||
从完整响应中提取并映射 usage
|
||||
|
||||
不同 API 格式的 usage 位置可能不同:
|
||||
- OpenAI: response["usage"]
|
||||
- Claude: response["usage"] 或 message_delta 中
|
||||
- Gemini: response["usageMetadata"]
|
||||
|
||||
Args:
|
||||
response: 完整的 API 响应
|
||||
api_format: API 格式
|
||||
|
||||
Returns:
|
||||
标准化的 usage 对象
|
||||
"""
|
||||
format_upper = api_format.upper() if api_format else ""
|
||||
|
||||
# 提取 usage 部分
|
||||
usage_data: Dict[str, Any] = {}
|
||||
|
||||
if format_upper.startswith("GEMINI"):
|
||||
# Gemini: usageMetadata
|
||||
usage_data = response.get("usageMetadata", {})
|
||||
if not usage_data:
|
||||
# 尝试从 candidates 中获取
|
||||
candidates = response.get("candidates", [])
|
||||
if candidates:
|
||||
usage_data = candidates[0].get("usageMetadata", {})
|
||||
else:
|
||||
# OpenAI/Claude: usage
|
||||
usage_data = response.get("usage", {})
|
||||
|
||||
return cls.map(usage_data, api_format)
|
||||
|
||||
@classmethod
|
||||
def _get_mapping(cls, api_format: str) -> Dict[str, str]:
|
||||
"""获取对应格式的字段映射"""
|
||||
if not api_format:
|
||||
return cls.CLAUDE_MAPPING
|
||||
|
||||
format_upper = api_format.upper()
|
||||
|
||||
# 精确匹配
|
||||
if format_upper in cls.FORMAT_MAPPINGS:
|
||||
return cls.FORMAT_MAPPINGS[format_upper]
|
||||
|
||||
# 前缀匹配
|
||||
for key, mapping in cls.FORMAT_MAPPINGS.items():
|
||||
if format_upper.startswith(key.split("_")[0]):
|
||||
return mapping
|
||||
|
||||
# 默认使用 Claude 映射
|
||||
return cls.CLAUDE_MAPPING
|
||||
|
||||
@classmethod
|
||||
def _get_nested_value(cls, data: Dict[str, Any], path: str) -> Any:
|
||||
"""
|
||||
获取嵌套字段值
|
||||
|
||||
支持点号分隔的路径,如 "prompt_tokens_details.cached_tokens"
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
path: 字段路径
|
||||
|
||||
Returns:
|
||||
字段值,不存在则返回 None
|
||||
"""
|
||||
if not data or not path:
|
||||
return None
|
||||
|
||||
keys = path.split(".")
|
||||
value: Any = data
|
||||
|
||||
for key in keys:
|
||||
if isinstance(value, dict):
|
||||
value = value.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def register_format(cls, format_name: str, mapping: Dict[str, str]) -> None:
|
||||
"""
|
||||
注册新的格式映射
|
||||
|
||||
Args:
|
||||
format_name: 格式名称(会自动转为大写)
|
||||
mapping: 字段映射
|
||||
"""
|
||||
cls.FORMAT_MAPPINGS[format_name.upper()] = mapping
|
||||
|
||||
@classmethod
|
||||
def get_supported_formats(cls) -> list:
|
||||
"""获取所有支持的格式"""
|
||||
return list(cls.FORMAT_MAPPINGS.keys())
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# 便捷函数
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def map_usage(
|
||||
raw_usage: Dict[str, Any],
|
||||
api_format: str,
|
||||
) -> StandardizedUsage:
|
||||
"""
|
||||
便捷函数:将原始 usage 映射为标准化格式
|
||||
|
||||
Args:
|
||||
raw_usage: 原始 usage 字典
|
||||
api_format: API 格式
|
||||
|
||||
Returns:
|
||||
StandardizedUsage 对象
|
||||
"""
|
||||
return UsageMapper.map(raw_usage, api_format)
|
||||
|
||||
|
||||
def map_usage_from_response(
|
||||
response: Dict[str, Any],
|
||||
api_format: str,
|
||||
) -> StandardizedUsage:
|
||||
"""
|
||||
便捷函数:从响应中提取并映射 usage
|
||||
|
||||
Args:
|
||||
response: API 响应
|
||||
api_format: API 格式
|
||||
|
||||
Returns:
|
||||
StandardizedUsage 对象
|
||||
"""
|
||||
return UsageMapper.map_from_response(response, api_format)
|
||||
@@ -148,6 +148,8 @@ class GlobalModelService:
|
||||
删除 GlobalModel
|
||||
|
||||
默认行为: 级联删除所有关联的 Provider 模型实现
|
||||
注意: 不清理 API Key 和 User 的 allowed_models 引用,
|
||||
保留无效引用可让用户在前端看到"已失效"的模型,便于手动清理或等待重建同名模型
|
||||
"""
|
||||
global_model = GlobalModelService.get_global_model(db, global_model_id)
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ class ErrorClassifier:
|
||||
result["reason"] = str(data.get("reason", data.get("code", "")))
|
||||
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
|
||||
result["message"] = error_text
|
||||
|
||||
return result
|
||||
|
||||
@@ -323,8 +323,8 @@ class ErrorClassifier:
|
||||
if parts:
|
||||
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# 无法解析,返回原始文本(截断)
|
||||
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
|
||||
# 无法解析,返回原始文本
|
||||
return parsed["raw"]
|
||||
|
||||
def classify(
|
||||
self,
|
||||
@@ -484,11 +484,15 @@ class ErrorClassifier:
|
||||
return ProviderNotAvailableException(
|
||||
message=detailed_message,
|
||||
provider_name=provider_name,
|
||||
upstream_status=status,
|
||||
upstream_response=error_response_text,
|
||||
)
|
||||
|
||||
return ProviderNotAvailableException(
|
||||
message=detailed_message,
|
||||
provider_name=provider_name,
|
||||
upstream_status=status,
|
||||
upstream_response=error_response_text,
|
||||
)
|
||||
|
||||
async def handle_http_error(
|
||||
@@ -532,12 +536,14 @@ class ErrorClassifier:
|
||||
provider_name = str(provider.name)
|
||||
|
||||
# 尝试读取错误响应内容
|
||||
error_response_text = None
|
||||
try:
|
||||
if http_error.response and hasattr(http_error.response, "text"):
|
||||
error_response_text = http_error.response.text[:1000] # 限制长度
|
||||
except Exception:
|
||||
pass
|
||||
# 优先使用 handler 附加的 upstream_response 属性(流式请求中 response.text 可能为空)
|
||||
error_response_text = getattr(http_error, "upstream_response", None)
|
||||
if not error_response_text:
|
||||
try:
|
||||
if http_error.response and hasattr(http_error.response, "text"):
|
||||
error_response_text = http_error.response.text
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f" [{request_id}] HTTP错误 (attempt={attempt}/{max_attempts}): "
|
||||
f"{http_error.response.status_code if http_error.response else 'unknown'}")
|
||||
|
||||
@@ -30,6 +30,7 @@ from redis import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.error_utils import extract_error_message
|
||||
from src.core.exceptions import (
|
||||
ConcurrencyLimitError,
|
||||
ProviderNotAvailableException,
|
||||
@@ -401,7 +402,7 @@ class FallbackOrchestrator:
|
||||
db=self.db,
|
||||
candidate_id=candidate_record_id,
|
||||
error_type="HTTPStatusError",
|
||||
error_message=f"HTTP {status_code}: {str(cause)}",
|
||||
error_message=extract_error_message(cause, status_code),
|
||||
status_code=status_code,
|
||||
latency_ms=elapsed_ms,
|
||||
concurrent_requests=captured_key_concurrent,
|
||||
@@ -425,31 +426,22 @@ class FallbackOrchestrator:
|
||||
attempt=attempt,
|
||||
max_attempts=max_attempts,
|
||||
)
|
||||
# str(cause) 可能为空(如 httpx 超时异常),使用 repr() 作为备用
|
||||
error_msg = str(cause) or repr(cause)
|
||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=self.db,
|
||||
candidate_id=candidate_record_id,
|
||||
error_type=type(cause).__name__,
|
||||
error_message=error_msg,
|
||||
error_message=extract_error_message(cause),
|
||||
latency_ms=elapsed_ms,
|
||||
concurrent_requests=captured_key_concurrent,
|
||||
)
|
||||
return "continue" if has_retry_left else "break"
|
||||
|
||||
# 未知错误:记录失败并抛出
|
||||
error_msg = str(cause) or repr(cause)
|
||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=self.db,
|
||||
candidate_id=candidate_record_id,
|
||||
error_type=type(cause).__name__,
|
||||
error_message=error_msg,
|
||||
error_message=extract_error_message(cause),
|
||||
latency_ms=elapsed_ms,
|
||||
concurrent_requests=captured_key_concurrent,
|
||||
)
|
||||
@@ -543,7 +535,9 @@ class FallbackOrchestrator:
|
||||
raise last_error
|
||||
|
||||
# 所有组合都已尝试完毕,全部失败
|
||||
self._raise_all_failed_exception(request_id, max_attempts, last_candidate, model_name, api_format_enum)
|
||||
self._raise_all_failed_exception(
|
||||
request_id, max_attempts, last_candidate, model_name, api_format_enum, last_error
|
||||
)
|
||||
|
||||
async def _try_candidate_with_retries(
|
||||
self,
|
||||
@@ -565,6 +559,7 @@ class FallbackOrchestrator:
|
||||
provider = candidate.provider
|
||||
endpoint = candidate.endpoint
|
||||
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for retry_index in range(max_retries_for_candidate):
|
||||
attempt_counter += 1
|
||||
@@ -599,6 +594,7 @@ class FallbackOrchestrator:
|
||||
return {"success": True, "response": response}
|
||||
|
||||
except ExecutionError as exec_err:
|
||||
last_error = exec_err.cause
|
||||
action = await self._handle_candidate_error(
|
||||
exec_err=exec_err,
|
||||
candidate=candidate,
|
||||
@@ -630,6 +626,7 @@ class FallbackOrchestrator:
|
||||
"success": False,
|
||||
"attempt_counter": attempt_counter,
|
||||
"max_attempts": max_attempts,
|
||||
"error": last_error,
|
||||
}
|
||||
|
||||
def _attach_metadata_to_error(
|
||||
@@ -678,6 +675,7 @@ class FallbackOrchestrator:
|
||||
last_candidate: Optional[ProviderCandidate],
|
||||
model_name: str,
|
||||
api_format_enum: APIFormat,
|
||||
last_error: Optional[Exception] = None,
|
||||
) -> NoReturn:
|
||||
"""所有组合都失败时抛出异常"""
|
||||
logger.error(f" [{request_id}] 所有 {max_attempts} 个组合均失败")
|
||||
@@ -693,9 +691,38 @@ class FallbackOrchestrator:
|
||||
"api_format": api_format_enum.value,
|
||||
}
|
||||
|
||||
# 提取上游错误响应
|
||||
upstream_status: Optional[int] = None
|
||||
upstream_response: Optional[str] = None
|
||||
if last_error:
|
||||
# 从 httpx.HTTPStatusError 提取
|
||||
if isinstance(last_error, httpx.HTTPStatusError):
|
||||
upstream_status = last_error.response.status_code
|
||||
# 优先使用我们附加的 upstream_response 属性(流已读取时 response.text 可能为空)
|
||||
upstream_response = getattr(last_error, "upstream_response", None)
|
||||
if not upstream_response:
|
||||
try:
|
||||
upstream_response = last_error.response.text
|
||||
except Exception:
|
||||
pass
|
||||
# 从其他异常属性提取(如 ProviderNotAvailableException)
|
||||
else:
|
||||
upstream_status = getattr(last_error, "upstream_status", None)
|
||||
upstream_response = getattr(last_error, "upstream_response", None)
|
||||
|
||||
# 如果响应为空或无效,使用异常的字符串表示
|
||||
if (
|
||||
not upstream_response
|
||||
or not upstream_response.strip()
|
||||
or upstream_response.startswith("Unable to read")
|
||||
):
|
||||
upstream_response = str(last_error)
|
||||
|
||||
raise ProviderNotAvailableException(
|
||||
f"所有Provider均不可用,已尝试{max_attempts}个组合",
|
||||
request_metadata=request_metadata,
|
||||
upstream_status=upstream_status,
|
||||
upstream_response=upstream_response,
|
||||
)
|
||||
|
||||
async def execute_with_fallback(
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""
|
||||
自适应并发调整器 - 基于滑动窗口利用率的并发限制调整
|
||||
自适应并发调整器 - 基于边界记忆的并发限制调整
|
||||
|
||||
核心改进(相对于旧版基于"持续高利用率"的方案):
|
||||
- 使用滑动窗口采样,容忍并发波动
|
||||
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
||||
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
||||
核心算法:边界记忆 + 渐进探测
|
||||
- 触发 429 时记录边界(last_concurrent_peak),这就是真实上限
|
||||
- 缩容策略:新限制 = 边界 - 1,而非乘性减少
|
||||
- 扩容策略:不超过已知边界,除非是探测性扩容
|
||||
- 探测性扩容:长时间无 429 时尝试突破边界
|
||||
|
||||
AIMD 参数说明:
|
||||
- 扩容:加性增加 (+INCREASE_STEP)
|
||||
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
|
||||
设计原则:
|
||||
1. 快速收敛:一次 429 就能找到接近真实的限制
|
||||
2. 避免过度保守:不会因为多次 429 而无限下降
|
||||
3. 安全探测:允许在稳定后尝试更高并发
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
@@ -35,21 +37,21 @@ class AdaptiveConcurrencyManager:
|
||||
"""
|
||||
自适应并发管理器
|
||||
|
||||
核心算法:基于滑动窗口利用率的 AIMD
|
||||
- 滑动窗口记录最近 N 次请求的利用率
|
||||
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
||||
- 遇到 429 错误时乘性减少 (*0.85)
|
||||
- 长时间无 429 且有流量时触发探测性扩容
|
||||
核心算法:边界记忆 + 渐进探测
|
||||
- 触发 429 时记录边界(last_concurrent_peak = 触发时的并发数)
|
||||
- 缩容:新限制 = 边界 - 1(快速收敛到真实限制附近)
|
||||
- 扩容:不超过边界(即 last_concurrent_peak),允许回到边界值尝试
|
||||
- 探测性扩容:长时间(30分钟)无 429 时,可以尝试 +1 突破边界
|
||||
|
||||
扩容条件(满足任一即可):
|
||||
1. 滑动窗口扩容:窗口内 >= 60% 的采样利用率 >= 70%,且不在冷却期
|
||||
2. 探测性扩容:距上次 429 超过 30 分钟,且期间有足够请求量
|
||||
1. 利用率扩容:窗口内高利用率比例 >= 60%,且当前限制 < 边界
|
||||
2. 探测性扩容:距上次 429 超过 30 分钟,可以尝试突破边界
|
||||
|
||||
关键特性:
|
||||
1. 滑动窗口容忍并发波动,不会因单次低利用率重置
|
||||
2. 区分并发限制和 RPM 限制
|
||||
3. 探测性扩容避免长期卡在低限制
|
||||
4. 记录调整历史
|
||||
1. 快速收敛:一次 429 就能学到接近真实的限制值
|
||||
2. 边界保护:普通扩容不会超过已知边界
|
||||
3. 安全探测:长时间稳定后允许尝试更高并发
|
||||
4. 区分并发限制和 RPM 限制
|
||||
"""
|
||||
|
||||
# 默认配置 - 使用统一常量
|
||||
@@ -59,7 +61,6 @@ class AdaptiveConcurrencyManager:
|
||||
|
||||
# AIMD 参数
|
||||
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
|
||||
DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER
|
||||
|
||||
# 滑动窗口参数
|
||||
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
|
||||
@@ -115,7 +116,13 @@ class AdaptiveConcurrencyManager:
|
||||
# 更新429统计
|
||||
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
|
||||
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
|
||||
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
|
||||
# 仅在并发限制且拿到并发数时记录边界(RPM/UNKNOWN 不应覆盖并发边界记忆)
|
||||
if (
|
||||
rate_limit_info.limit_type == RateLimitType.CONCURRENT
|
||||
and current_concurrent is not None
|
||||
and current_concurrent > 0
|
||||
):
|
||||
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
|
||||
|
||||
# 遇到 429 错误,清空利用率采样窗口(重新开始收集)
|
||||
key.utilization_samples = [] # type: ignore[assignment]
|
||||
@@ -207,6 +214,9 @@ class AdaptiveConcurrencyManager:
|
||||
|
||||
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
|
||||
|
||||
# 获取已知边界(上次触发 429 时的并发数)
|
||||
known_boundary = key.last_concurrent_peak
|
||||
|
||||
# 计算当前利用率
|
||||
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
|
||||
|
||||
@@ -217,22 +227,29 @@ class AdaptiveConcurrencyManager:
|
||||
samples = self._update_utilization_window(key, now_ts, utilization)
|
||||
|
||||
# 检查是否满足扩容条件
|
||||
increase_reason = self._check_increase_conditions(key, samples, now)
|
||||
increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
|
||||
|
||||
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
|
||||
old_limit = current_limit
|
||||
new_limit = self._increase_limit(current_limit)
|
||||
is_probe = increase_reason == "probe_increase"
|
||||
new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
|
||||
|
||||
# 如果没有实际增长(已达边界),跳过
|
||||
if new_limit <= old_limit:
|
||||
return None
|
||||
|
||||
# 计算窗口统计用于日志
|
||||
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
|
||||
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
||||
high_util_ratio = high_util_count / len(samples) if samples else 0
|
||||
|
||||
boundary_info = f"边界: {known_boundary}" if known_boundary else "无边界"
|
||||
logger.info(
|
||||
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
|
||||
f"窗口采样: {len(samples)} | "
|
||||
f"平均利用率: {avg_util:.1%} | "
|
||||
f"高利用率比例: {high_util_ratio:.1%} | "
|
||||
f"{boundary_info} | "
|
||||
f"调整: {old_limit} -> {new_limit}"
|
||||
)
|
||||
|
||||
@@ -246,13 +263,14 @@ class AdaptiveConcurrencyManager:
|
||||
high_util_ratio=round(high_util_ratio, 2),
|
||||
sample_count=len(samples),
|
||||
current_concurrent=current_concurrent,
|
||||
known_boundary=known_boundary,
|
||||
)
|
||||
|
||||
# 更新限制
|
||||
key.learned_max_concurrent = new_limit # type: ignore[assignment]
|
||||
|
||||
# 如果是探测性扩容,更新探测时间
|
||||
if increase_reason == "probe_increase":
|
||||
if is_probe:
|
||||
key.last_probe_increase_at = now # type: ignore[assignment]
|
||||
|
||||
# 扩容后清空采样窗口,重新开始收集
|
||||
@@ -303,7 +321,11 @@ class AdaptiveConcurrencyManager:
|
||||
return samples
|
||||
|
||||
def _check_increase_conditions(
|
||||
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
|
||||
self,
|
||||
key: ProviderAPIKey,
|
||||
samples: List[Dict[str, Any]],
|
||||
now: datetime,
|
||||
known_boundary: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
检查是否满足扩容条件
|
||||
@@ -312,6 +334,7 @@ class AdaptiveConcurrencyManager:
|
||||
key: API Key对象
|
||||
samples: 利用率采样列表
|
||||
now: 当前时间
|
||||
known_boundary: 已知边界(触发 429 时的并发数)
|
||||
|
||||
Returns:
|
||||
扩容原因(如果满足条件),否则返回 None
|
||||
@@ -320,15 +343,25 @@ class AdaptiveConcurrencyManager:
|
||||
if self._is_in_cooldown(key):
|
||||
return None
|
||||
|
||||
# 条件1:滑动窗口扩容
|
||||
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
|
||||
|
||||
# 条件1:滑动窗口扩容(不超过边界)
|
||||
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
|
||||
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
||||
high_util_ratio = high_util_count / len(samples)
|
||||
|
||||
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
|
||||
return "high_utilization"
|
||||
# 检查是否还有扩容空间(边界保护)
|
||||
if known_boundary:
|
||||
# 允许扩容到边界值(而非 boundary - 1),因为缩容时已经 -1 了
|
||||
if current_limit < known_boundary:
|
||||
return "high_utilization"
|
||||
# 已达边界,不触发普通扩容
|
||||
else:
|
||||
# 无边界信息,允许扩容
|
||||
return "high_utilization"
|
||||
|
||||
# 条件2:探测性扩容(长时间无 429 且有流量)
|
||||
# 条件2:探测性扩容(长时间无 429 且有流量,可以突破边界)
|
||||
if self._should_probe_increase(key, samples, now):
|
||||
return "probe_increase"
|
||||
|
||||
@@ -406,32 +439,65 @@ class AdaptiveConcurrencyManager:
|
||||
current_concurrent: Optional[int] = None,
|
||||
) -> int:
|
||||
"""
|
||||
减少并发限制
|
||||
减少并发限制(基于边界记忆策略)
|
||||
|
||||
策略:
|
||||
- 如果知道当前并发数,设置为当前并发的70%
|
||||
- 否则,使用乘性减少
|
||||
- 如果知道触发 429 时的并发数,新限制 = 并发数 - 1
|
||||
- 这样可以快速收敛到真实限制附近,而不会过度保守
|
||||
- 例如:真实限制 8,触发时并发 8 -> 新限制 7(而非 8*0.85=6)
|
||||
"""
|
||||
if current_concurrent:
|
||||
# 基于当前并发数减少
|
||||
new_limit = max(
|
||||
int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
|
||||
)
|
||||
if current_concurrent is not None and current_concurrent > 0:
|
||||
# 边界记忆策略:新限制 = 触发边界 - 1
|
||||
candidate = current_concurrent - 1
|
||||
else:
|
||||
# 乘性减少
|
||||
new_limit = max(
|
||||
int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
|
||||
)
|
||||
# 没有并发信息时,保守减少 1
|
||||
candidate = current_limit - 1
|
||||
|
||||
# 保证不会“缩容变扩容”(例如 current_concurrent > current_limit 的异常场景)
|
||||
candidate = min(candidate, current_limit - 1)
|
||||
|
||||
new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)
|
||||
|
||||
return new_limit
|
||||
|
||||
def _increase_limit(self, current_limit: int) -> int:
|
||||
def _increase_limit(
|
||||
self,
|
||||
current_limit: int,
|
||||
known_boundary: Optional[int] = None,
|
||||
is_probe: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
增加并发限制
|
||||
增加并发限制(考虑边界保护)
|
||||
|
||||
策略:加性增加 (+1)
|
||||
策略:
|
||||
- 普通扩容:每次 +INCREASE_STEP,但不超过 known_boundary
|
||||
(因为缩容时已经 -1 了,这里允许回到边界值尝试)
|
||||
- 探测性扩容:每次只 +1,可以突破边界,但要谨慎
|
||||
|
||||
Args:
|
||||
current_limit: 当前限制
|
||||
known_boundary: 已知边界(last_concurrent_peak),即触发 429 时的并发数
|
||||
is_probe: 是否是探测性扩容(可以突破边界)
|
||||
"""
|
||||
new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
|
||||
if is_probe:
|
||||
# 探测模式:每次只 +1,谨慎突破边界
|
||||
new_limit = current_limit + 1
|
||||
else:
|
||||
# 普通模式:每次 +INCREASE_STEP
|
||||
new_limit = current_limit + self.INCREASE_STEP
|
||||
|
||||
# 边界保护:普通扩容不超过 known_boundary(允许回到边界值尝试)
|
||||
if known_boundary:
|
||||
if new_limit > known_boundary:
|
||||
new_limit = known_boundary
|
||||
|
||||
# 全局上限保护
|
||||
new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
|
||||
|
||||
# 确保有增长(否则返回原值表示不扩容)
|
||||
if new_limit <= current_limit:
|
||||
return current_limit
|
||||
|
||||
return new_limit
|
||||
|
||||
def _record_adjustment(
|
||||
@@ -503,11 +569,16 @@ class AdaptiveConcurrencyManager:
|
||||
if key.last_probe_increase_at:
|
||||
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
|
||||
|
||||
# 边界信息
|
||||
known_boundary = key.last_concurrent_peak
|
||||
|
||||
return {
|
||||
"adaptive_mode": is_adaptive,
|
||||
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
|
||||
"effective_limit": effective_limit, # 当前有效限制
|
||||
"learned_limit": key.learned_max_concurrent, # 学习到的限制
|
||||
# 边界记忆相关
|
||||
"known_boundary": known_boundary, # 触发 429 时的并发数(已知上限)
|
||||
"concurrent_429_count": int(key.concurrent_429_count or 0),
|
||||
"rpm_429_count": int(key.rpm_429_count or 0),
|
||||
"last_429_at": last_429_at_str,
|
||||
|
||||
@@ -289,11 +289,11 @@ class RequestResult:
|
||||
status_code = 500
|
||||
error_type = "internal_error"
|
||||
|
||||
# 构建错误消息,包含上游响应信息
|
||||
error_message = str(exception)
|
||||
if isinstance(exception, ProviderNotAvailableException):
|
||||
if exception.upstream_response:
|
||||
error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
|
||||
# 构建错误消息:优先使用上游响应作为主要错误信息
|
||||
if isinstance(exception, ProviderNotAvailableException) and exception.upstream_response:
|
||||
error_message = exception.upstream_response
|
||||
else:
|
||||
error_message = str(exception)
|
||||
|
||||
return cls(
|
||||
status=RequestStatus.FAILED,
|
||||
|
||||
@@ -86,6 +86,118 @@ class UsageRecordParams:
|
||||
class UsageService:
|
||||
"""用量统计服务"""
|
||||
|
||||
# ==================== 缓存键常量 ====================
|
||||
|
||||
# 热力图缓存键前缀(依赖 TTL 自动过期,用户角色变更时主动清除)
|
||||
HEATMAP_CACHE_KEY_PREFIX = "activity_heatmap"
|
||||
|
||||
# ==================== 热力图缓存 ====================
|
||||
|
||||
@classmethod
|
||||
def _get_heatmap_cache_key(cls, user_id: Optional[str], include_actual_cost: bool) -> str:
|
||||
"""生成热力图缓存键"""
|
||||
cost_suffix = "with_cost" if include_actual_cost else "no_cost"
|
||||
if user_id:
|
||||
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:user:{user_id}:{cost_suffix}"
|
||||
else:
|
||||
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:admin:all:{cost_suffix}"
|
||||
|
||||
@classmethod
|
||||
async def clear_user_heatmap_cache(cls, user_id: str) -> None:
|
||||
"""
|
||||
清除用户的热力图缓存(用户角色变更时调用)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
redis_client = await get_redis_client(require_redis=False)
|
||||
if not redis_client:
|
||||
return
|
||||
|
||||
# 清除该用户的所有热力图缓存(with_cost 和 no_cost)
|
||||
keys_to_delete = [
|
||||
cls._get_heatmap_cache_key(user_id, include_actual_cost=True),
|
||||
cls._get_heatmap_cache_key(user_id, include_actual_cost=False),
|
||||
]
|
||||
|
||||
for key in keys_to_delete:
|
||||
try:
|
||||
await redis_client.delete(key)
|
||||
logger.debug(f"已清除热力图缓存: {key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清除热力图缓存失败: {key}, error={e}")
|
||||
|
||||
@classmethod
|
||||
async def get_cached_heatmap(
|
||||
cls,
|
||||
db: Session,
|
||||
user_id: Optional[str] = None,
|
||||
include_actual_cost: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取带缓存的热力图数据
|
||||
|
||||
缓存策略:
|
||||
- TTL: 5分钟(CacheTTL.ACTIVITY_HEATMAP)
|
||||
- 仅依赖 TTL 自动过期,新使用记录最多延迟 5 分钟出现
|
||||
- 用户角色变更时通过 clear_user_heatmap_cache() 主动清除
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID,None 表示获取全局热力图(管理员)
|
||||
include_actual_cost: 是否包含实际成本
|
||||
|
||||
Returns:
|
||||
热力图数据字典
|
||||
"""
|
||||
from src.clients.redis_client import get_redis_client
|
||||
from src.config.constants import CacheTTL
|
||||
import json
|
||||
|
||||
cache_key = cls._get_heatmap_cache_key(user_id, include_actual_cost)
|
||||
|
||||
cache_ttl = CacheTTL.ACTIVITY_HEATMAP
|
||||
redis_client = await get_redis_client(require_redis=False)
|
||||
|
||||
# 尝试从缓存获取
|
||||
if redis_client:
|
||||
try:
|
||||
cached = await redis_client.get(cache_key)
|
||||
if cached:
|
||||
try:
|
||||
return json.loads(cached) # type: ignore[no-any-return]
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"热力图缓存解析失败,删除损坏缓存: {cache_key}, error={e}")
|
||||
try:
|
||||
await redis_client.delete(cache_key)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"读取热力图缓存出错: {cache_key}, error={e}")
|
||||
|
||||
# 从数据库查询
|
||||
result = cls.get_daily_activity(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
window_days=365,
|
||||
include_actual_cost=include_actual_cost,
|
||||
)
|
||||
|
||||
# 保存到缓存(失败不影响返回结果)
|
||||
if redis_client:
|
||||
try:
|
||||
await redis_client.setex(
|
||||
cache_key,
|
||||
cache_ttl,
|
||||
json.dumps(result, ensure_ascii=False, default=str),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存热力图缓存失败: {cache_key}, error={e}")
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 内部数据类 ====================
|
||||
|
||||
@staticmethod
|
||||
@@ -1027,7 +1139,12 @@ class UsageService:
|
||||
window_days: int = 365,
|
||||
include_actual_cost: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""按天统计请求活跃度,用于渲染热力图。"""
|
||||
"""按天统计请求活跃度,用于渲染热力图。
|
||||
|
||||
优化策略:
|
||||
- 历史数据从预计算的 StatsDaily/StatsUserDaily 表读取
|
||||
- 只有"今天"的数据才实时查询 Usage 表
|
||||
"""
|
||||
|
||||
def ensure_timezone(value: datetime) -> datetime:
|
||||
if value.tzinfo is None:
|
||||
@@ -1041,54 +1158,109 @@ class UsageService:
|
||||
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
|
||||
)
|
||||
|
||||
# 对齐到自然日的开始/结束,避免遗漏边界数据
|
||||
start_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
end_dt = end_dt.replace(hour=23, minute=59, second=59, microsecond=999999)
|
||||
|
||||
from src.utils.database_helpers import date_trunc_portable
|
||||
|
||||
bind = db.get_bind()
|
||||
dialect = bind.dialect.name if bind is not None else "sqlite"
|
||||
day_bucket = date_trunc_portable(dialect, "day", Usage.created_at).label("day")
|
||||
|
||||
columns = [
|
||||
day_bucket,
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||
]
|
||||
|
||||
if include_actual_cost:
|
||||
columns.append(func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"))
|
||||
|
||||
query = db.query(*columns).filter(Usage.created_at >= start_dt, Usage.created_at <= end_dt)
|
||||
|
||||
if user_id:
|
||||
query = query.filter(Usage.user_id == user_id)
|
||||
|
||||
query = query.group_by(day_bucket).order_by(day_bucket)
|
||||
rows = query.all()
|
||||
|
||||
def normalize_period(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value[:10]
|
||||
if isinstance(value, datetime):
|
||||
return value.date().isoformat()
|
||||
return str(value)
|
||||
# 对齐到自然日的开始/结束
|
||||
start_dt = datetime.combine(start_dt.date(), datetime.min.time(), tzinfo=timezone.utc)
|
||||
end_dt = datetime.combine(end_dt.date(), datetime.max.time(), tzinfo=timezone.utc)
|
||||
|
||||
today = now.date()
|
||||
today_start_dt = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
|
||||
aggregated: Dict[str, Dict[str, Any]] = {}
|
||||
for row in rows:
|
||||
key = normalize_period(row.day)
|
||||
aggregated[key] = {
|
||||
"requests": int(row.requests or 0),
|
||||
"total_tokens": int(row.total_tokens or 0),
|
||||
"total_cost_usd": float(row.total_cost_usd or 0.0),
|
||||
}
|
||||
if include_actual_cost:
|
||||
aggregated[key]["actual_total_cost_usd"] = float(row.actual_total_cost_usd or 0.0)
|
||||
|
||||
# 1. 从预计算表读取历史数据(不包括今天)
|
||||
if user_id:
|
||||
from src.models.database import StatsUserDaily
|
||||
|
||||
hist_query = db.query(StatsUserDaily).filter(
|
||||
StatsUserDaily.user_id == user_id,
|
||||
StatsUserDaily.date >= start_dt,
|
||||
StatsUserDaily.date < today_start_dt,
|
||||
)
|
||||
for row in hist_query.all():
|
||||
key = (
|
||||
row.date.date().isoformat()
|
||||
if isinstance(row.date, datetime)
|
||||
else str(row.date)[:10]
|
||||
)
|
||||
aggregated[key] = {
|
||||
"requests": row.total_requests or 0,
|
||||
"total_tokens": (
|
||||
(row.input_tokens or 0)
|
||||
+ (row.output_tokens or 0)
|
||||
+ (row.cache_creation_tokens or 0)
|
||||
+ (row.cache_read_tokens or 0)
|
||||
),
|
||||
"total_cost_usd": float(row.total_cost or 0.0),
|
||||
}
|
||||
# StatsUserDaily 没有 actual_total_cost 字段,用户视图不需要倍率成本
|
||||
else:
|
||||
from src.models.database import StatsDaily
|
||||
|
||||
hist_query = db.query(StatsDaily).filter(
|
||||
StatsDaily.date >= start_dt,
|
||||
StatsDaily.date < today_start_dt,
|
||||
)
|
||||
for row in hist_query.all():
|
||||
key = (
|
||||
row.date.date().isoformat()
|
||||
if isinstance(row.date, datetime)
|
||||
else str(row.date)[:10]
|
||||
)
|
||||
aggregated[key] = {
|
||||
"requests": row.total_requests or 0,
|
||||
"total_tokens": (
|
||||
(row.input_tokens or 0)
|
||||
+ (row.output_tokens or 0)
|
||||
+ (row.cache_creation_tokens or 0)
|
||||
+ (row.cache_read_tokens or 0)
|
||||
),
|
||||
"total_cost_usd": float(row.total_cost or 0.0),
|
||||
}
|
||||
if include_actual_cost:
|
||||
aggregated[key]["actual_total_cost_usd"] = float(
|
||||
row.actual_total_cost or 0.0 # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# 2. 实时查询今天的数据(如果在查询范围内)
|
||||
if today >= start_dt.date() and today <= end_dt.date():
|
||||
today_start = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
|
||||
today_end = datetime.combine(today, datetime.max.time(), tzinfo=timezone.utc)
|
||||
|
||||
if include_actual_cost:
|
||||
today_query = db.query(
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"),
|
||||
).filter(
|
||||
Usage.created_at >= today_start,
|
||||
Usage.created_at <= today_end,
|
||||
)
|
||||
else:
|
||||
today_query = db.query(
|
||||
func.count(Usage.id).label("requests"),
|
||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||
).filter(
|
||||
Usage.created_at >= today_start,
|
||||
Usage.created_at <= today_end,
|
||||
)
|
||||
|
||||
if user_id:
|
||||
today_query = today_query.filter(Usage.user_id == user_id)
|
||||
|
||||
today_row = today_query.first()
|
||||
if today_row and today_row.requests:
|
||||
aggregated[today.isoformat()] = {
|
||||
"requests": int(today_row.requests or 0),
|
||||
"total_tokens": int(today_row.total_tokens or 0),
|
||||
"total_cost_usd": float(today_row.total_cost_usd or 0.0),
|
||||
}
|
||||
if include_actual_cost:
|
||||
aggregated[today.isoformat()]["actual_total_cost_usd"] = float(
|
||||
today_row.actual_total_cost_usd or 0.0
|
||||
)
|
||||
|
||||
# 3. 构建返回结果
|
||||
days: List[Dict[str, Any]] = []
|
||||
cursor = start_dt.date()
|
||||
end_date_only = end_dt.date()
|
||||
@@ -1304,6 +1476,9 @@ class UsageService:
|
||||
provider: Optional[str] = None,
|
||||
target_model: Optional[str] = None,
|
||||
first_byte_time_ms: Optional[int] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
provider_api_key_id: Optional[str] = None,
|
||||
) -> Optional[Usage]:
|
||||
"""
|
||||
快速更新使用记录状态
|
||||
@@ -1316,6 +1491,9 @@ class UsageService:
|
||||
provider: 提供商名称(可选,streaming 状态时更新)
|
||||
target_model: 映射后的目标模型名(可选)
|
||||
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
||||
provider_id: Provider ID(可选,streaming 状态时更新)
|
||||
provider_endpoint_id: Endpoint ID(可选,streaming 状态时更新)
|
||||
provider_api_key_id: Provider API Key ID(可选,streaming 状态时更新)
|
||||
|
||||
Returns:
|
||||
更新后的 Usage 记录,如果未找到则返回 None
|
||||
@@ -1331,10 +1509,22 @@ class UsageService:
|
||||
usage.error_message = error_message
|
||||
if provider:
|
||||
usage.provider = provider
|
||||
elif status == "streaming" and usage.provider == "pending":
|
||||
# 状态变为 streaming 但 provider 仍为 pending,记录警告
|
||||
logger.warning(
|
||||
f"状态更新为 streaming 但 provider 为空: request_id={request_id}, "
|
||||
f"当前 provider={usage.provider}"
|
||||
)
|
||||
if target_model:
|
||||
usage.target_model = target_model
|
||||
if first_byte_time_ms is not None:
|
||||
usage.first_byte_time_ms = first_byte_time_ms
|
||||
if provider_id is not None:
|
||||
usage.provider_id = provider_id
|
||||
if provider_endpoint_id is not None:
|
||||
usage.provider_endpoint_id = provider_endpoint_id
|
||||
if provider_api_key_id is not None:
|
||||
usage.provider_api_key_id = provider_api_key_id
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -1446,6 +1636,8 @@ class UsageService:
|
||||
ids: Optional[List[str]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
default_timeout_seconds: int = 300,
|
||||
*,
|
||||
include_admin_fields: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
|
||||
@@ -1482,6 +1674,15 @@ class UsageService:
|
||||
ProviderEndpoint.timeout.label("endpoint_timeout"),
|
||||
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||
|
||||
# 管理员轮询:可附带 provider 与上游 key 名称(注意:不要在普通用户接口暴露上游 key 信息)
|
||||
if include_admin_fields:
|
||||
from src.models.database import ProviderAPIKey
|
||||
|
||||
query = query.add_columns(
|
||||
Usage.provider,
|
||||
ProviderAPIKey.name.label("api_key_name"),
|
||||
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||
|
||||
if ids:
|
||||
query = query.filter(Usage.id.in_(ids))
|
||||
if user_id:
|
||||
@@ -1518,8 +1719,9 @@ class UsageService:
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
{
|
||||
result: List[Dict[str, Any]] = []
|
||||
for r in records:
|
||||
item: Dict[str, Any] = {
|
||||
"id": r.id,
|
||||
"status": "failed" if r.id in timeout_ids else r.status,
|
||||
"input_tokens": r.input_tokens,
|
||||
@@ -1528,8 +1730,12 @@ class UsageService:
|
||||
"response_time_ms": r.response_time_ms,
|
||||
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
if include_admin_fields:
|
||||
item["provider"] = r.provider
|
||||
item["api_key_name"] = r.api_key_name
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
||||
# ========== 缓存亲和性分析方法 ==========
|
||||
|
||||
|
||||
@@ -459,34 +459,38 @@ class StreamUsageTracker:
|
||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||
|
||||
chunk_count = 0
|
||||
first_chunk_received = False
|
||||
first_byte_time_ms = None # 预先记录 TTFB,避免 yield 后计算不准确
|
||||
try:
|
||||
async for chunk in stream:
|
||||
chunk_count += 1
|
||||
# 保存原始字节流(用于错误诊断)
|
||||
self.raw_chunks.append(chunk)
|
||||
|
||||
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
if self.request_id:
|
||||
try:
|
||||
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||
base_time = self.request_start_time or self.start_time
|
||||
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||
UsageService.update_usage_status(
|
||||
db=self.db,
|
||||
request_id=self.request_id,
|
||||
status="streaming",
|
||||
provider=self.provider,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||
# 第一个 chunk 收到时,记录 TTFB 时间点(但先不更新数据库,避免阻塞)
|
||||
if chunk_count == 1:
|
||||
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||
base_time = self.request_start_time or self.start_time
|
||||
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||
|
||||
# 返回原始块给客户端
|
||||
# 先返回原始块给客户端,确保 TTFB 不受数据库操作影响
|
||||
yield chunk
|
||||
|
||||
# yield 后再更新数据库状态(仅第一个 chunk 时执行)
|
||||
if chunk_count == 1 and self.request_id:
|
||||
try:
|
||||
UsageService.update_usage_status(
|
||||
db=self.db,
|
||||
request_id=self.request_id,
|
||||
status="streaming",
|
||||
provider=self.provider,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
provider_id=self.provider_id,
|
||||
provider_endpoint_id=self.provider_endpoint_id,
|
||||
provider_api_key_id=self.provider_api_key_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||
|
||||
# 解析块以提取内容和使用信息(chunk是原始字节)
|
||||
content, usage = self.parse_stream_chunk(chunk)
|
||||
|
||||
@@ -916,15 +920,38 @@ class EnhancedStreamUsageTracker(StreamUsageTracker):
|
||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应(Enhanced) | 估算输入tokens:{self.input_tokens}")
|
||||
|
||||
chunk_count = 0
|
||||
first_byte_time_ms = None # 预先记录 TTFB,避免 yield 后计算不准确
|
||||
try:
|
||||
async for chunk in stream:
|
||||
chunk_count += 1
|
||||
# 保存原始字节流(用于错误诊断)
|
||||
self.raw_chunks.append(chunk)
|
||||
|
||||
# 返回原始块给客户端
|
||||
# 第一个 chunk 收到时,记录 TTFB 时间点(但先不更新数据库,避免阻塞)
|
||||
if chunk_count == 1:
|
||||
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||
base_time = self.request_start_time or self.start_time
|
||||
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||
|
||||
# 先返回原始块给客户端,确保 TTFB 不受数据库操作影响
|
||||
yield chunk
|
||||
|
||||
# yield 后再更新数据库状态(仅第一个 chunk 时执行)
|
||||
if chunk_count == 1 and self.request_id:
|
||||
try:
|
||||
UsageService.update_usage_status(
|
||||
db=self.db,
|
||||
request_id=self.request_id,
|
||||
status="streaming",
|
||||
provider=self.provider,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
provider_id=self.provider_id,
|
||||
provider_endpoint_id=self.provider_endpoint_id,
|
||||
provider_api_key_id=self.provider_api_key_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||
|
||||
# 解析块以提取内容和使用信息(chunk是原始字节)
|
||||
content, usage = self.parse_stream_chunk(chunk)
|
||||
|
||||
|
||||
@@ -25,9 +25,10 @@ class ApiKeyService:
|
||||
allowed_providers: Optional[List[str]] = None,
|
||||
allowed_api_formats: Optional[List[str]] = None,
|
||||
allowed_models: Optional[List[str]] = None,
|
||||
rate_limit: int = 100,
|
||||
rate_limit: Optional[int] = None,
|
||||
concurrent_limit: int = 5,
|
||||
expire_days: Optional[int] = None,
|
||||
expires_at: Optional[datetime] = None, # 直接传入过期时间,优先于 expire_days
|
||||
initial_balance_usd: Optional[float] = None,
|
||||
is_standalone: bool = False,
|
||||
auto_delete_on_expiry: bool = False,
|
||||
@@ -44,6 +45,7 @@ class ApiKeyService:
|
||||
rate_limit: 速率限制
|
||||
concurrent_limit: 并发限制
|
||||
expire_days: 过期天数,None = 永不过期
|
||||
expires_at: 直接指定过期时间,优先于 expire_days
|
||||
initial_balance_usd: 初始余额(USD),仅用于独立Key,None = 无限制
|
||||
is_standalone: 是否为独立余额Key(仅管理员可创建)
|
||||
auto_delete_on_expiry: 过期后是否自动删除(True=物理删除,False=仅禁用)
|
||||
@@ -54,10 +56,10 @@ class ApiKeyService:
|
||||
key_hash = ApiKey.hash_key(key)
|
||||
key_encrypted = crypto_service.encrypt(key) # 加密存储密钥
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if expire_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||
# 计算过期时间:优先使用 expires_at,其次使用 expire_days
|
||||
final_expires_at = expires_at
|
||||
if final_expires_at is None and expire_days:
|
||||
final_expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||
|
||||
# 空数组转为 None(表示不限制)
|
||||
api_key = ApiKey(
|
||||
@@ -70,7 +72,7 @@ class ApiKeyService:
|
||||
allowed_models=allowed_models or None,
|
||||
rate_limit=rate_limit,
|
||||
concurrent_limit=concurrent_limit,
|
||||
expires_at=expires_at,
|
||||
expires_at=final_expires_at,
|
||||
balance_used_usd=0.0,
|
||||
current_balance_usd=initial_balance_usd, # 直接使用初始余额,None = 无限制
|
||||
is_standalone=is_standalone,
|
||||
@@ -145,6 +147,9 @@ class ApiKeyService:
|
||||
# 允许显式设置为空数组/None 的字段(空数组会转为 None,表示"全部")
|
||||
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
|
||||
|
||||
# 允许显式设置为 None 的字段(如 expires_at=None 表示永不过期,rate_limit=None 表示无限制)
|
||||
nullable_fields = {"expires_at", "rate_limit"}
|
||||
|
||||
for field, value in kwargs.items():
|
||||
if field not in updatable_fields:
|
||||
continue
|
||||
@@ -153,6 +158,9 @@ class ApiKeyService:
|
||||
if value is not None:
|
||||
# 空数组转为 None(表示允许全部)
|
||||
setattr(api_key, field, value if value else None)
|
||||
elif field in nullable_fields:
|
||||
# 这些字段允许显式设置为 None
|
||||
setattr(api_key, field, value)
|
||||
elif value is not None:
|
||||
setattr(api_key, field, value)
|
||||
|
||||
|
||||
@@ -49,8 +49,16 @@ def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) ->
|
||||
# 尝试从缓存获取
|
||||
cached = await redis_client.get(cache_key)
|
||||
if cached:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
return json.loads(cached)
|
||||
try:
|
||||
result = json.loads(cached)
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"缓存解析失败,删除损坏缓存: {cache_key}, 错误: {e}")
|
||||
try:
|
||||
await redis_client.delete(cache_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 执行原函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
@@ -7,6 +7,59 @@ from typing import Any
|
||||
from sqlalchemy import func
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
转义 SQL LIKE 语句中的特殊字符(%、_、\\)
|
||||
|
||||
Args:
|
||||
pattern: 原始搜索模式
|
||||
|
||||
Returns:
|
||||
转义后的模式,可安全用于 LIKE 查询(需配合 escape="\\\\")
|
||||
|
||||
Examples:
|
||||
>>> escape_like_pattern("hello_world%test")
|
||||
'hello\\\\_world\\\\%test'
|
||||
"""
|
||||
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def safe_truncate_escaped(escaped: str, max_len: int) -> str:
|
||||
"""
|
||||
安全截断已转义的字符串,避免截断在转义序列中间
|
||||
|
||||
转义后的字符串中,反斜杠总是成对出现(\\\\)或作为转义符(\\%, \\_)。
|
||||
如果在某个位置截断导致末尾有奇数个反斜杠,说明截断发生在转义序列中间,
|
||||
需要去掉最后一个反斜杠以保持转义完整性。
|
||||
|
||||
Args:
|
||||
escaped: 已经过 escape_like_pattern 处理的字符串
|
||||
max_len: 最大长度
|
||||
|
||||
Returns:
|
||||
截断后的字符串,保证不会破坏转义序列
|
||||
"""
|
||||
if len(escaped) <= max_len:
|
||||
return escaped
|
||||
|
||||
truncated = escaped[:max_len]
|
||||
|
||||
# 统计末尾连续的反斜杠数量
|
||||
trailing_backslashes = 0
|
||||
for i in range(len(truncated) - 1, -1, -1):
|
||||
if truncated[i] == "\\":
|
||||
trailing_backslashes += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 如果末尾反斜杠数量为奇数,说明截断在转义序列中间
|
||||
# 需要去掉最后一个反斜杠
|
||||
if trailing_backslashes % 2 == 1:
|
||||
truncated = truncated[:-1]
|
||||
|
||||
return truncated
|
||||
|
||||
|
||||
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
|
||||
"""
|
||||
跨数据库的日期截断函数
|
||||
|
||||
@@ -7,22 +7,20 @@ from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.config import config
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址
|
||||
|
||||
按优先级检查:
|
||||
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||
2. X-Real-IP 头(Nginx 代理)
|
||||
1. X-Real-IP 头(最可靠,由最外层可信 Nginx 直接设置)
|
||||
2. X-Forwarded-For 头的第一个 IP(原始客户端)
|
||||
3. 直接客户端IP
|
||||
|
||||
安全说明:
|
||||
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||
- X-Real-IP 优先级最高,因为它通常由最外层 Nginx 设置为 $remote_addr,
|
||||
Nginx 会直接覆盖这个头,不会传递客户端伪造的值
|
||||
- 只要最外层 Nginx 配置了 proxy_set_header X-Real-IP $remote_addr; 即可正确获取真实 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
@@ -30,30 +28,19 @@ def get_client_ip(request: Request) -> str:
|
||||
Returns:
|
||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||
"""
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,直接返回连接 IP
|
||||
if trusted_proxy_count == 0:
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||
# 优先检查 X-Real-IP 头(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For 头,取第一个 IP(原始客户端)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
# 回退到直接客户端IP
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
@@ -109,36 +96,26 @@ def get_request_metadata(request: Request) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||
def extract_ip_from_headers(headers: dict) -> str:
|
||||
"""
|
||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||
|
||||
Args:
|
||||
headers: HTTP头字典
|
||||
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||
|
||||
Returns:
|
||||
str: 客户端IP地址
|
||||
"""
|
||||
if trusted_proxy_count is None:
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||
if trusted_proxy_count == 0:
|
||||
return "unknown"
|
||||
|
||||
# 检查 X-Forwarded-For
|
||||
forwarded_for = headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP
|
||||
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = headers.get("x-real-ip", "")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For,取第一个 IP
|
||||
forwarded_for = headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
0
tests/services/billing/__init__.py
Normal file
0
tests/services/billing/__init__.py
Normal file
440
tests/services/billing/test_billing.py
Normal file
440
tests/services/billing/test_billing.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Billing 模块测试
|
||||
|
||||
测试计费模块的核心功能:
|
||||
- BillingCalculator 计费计算
|
||||
- 计费模板
|
||||
- 阶梯计费
|
||||
- calculate_request_cost 便捷函数
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.billing import (
|
||||
BillingCalculator,
|
||||
BillingDimension,
|
||||
BillingTemplates,
|
||||
BillingUnit,
|
||||
CostBreakdown,
|
||||
StandardizedUsage,
|
||||
calculate_request_cost,
|
||||
)
|
||||
from src.services.billing.templates import get_template, list_templates
|
||||
|
||||
|
||||
class TestBillingDimension:
|
||||
"""测试计费维度"""
|
||||
|
||||
def test_calculate_per_1m_tokens(self) -> None:
|
||||
"""测试 per_1m_tokens 计费"""
|
||||
dim = BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
)
|
||||
|
||||
# 1000 tokens * $3 / 1M = $0.003
|
||||
cost = dim.calculate(1000, 3.0)
|
||||
assert abs(cost - 0.003) < 0.0001
|
||||
|
||||
def test_calculate_per_request(self) -> None:
|
||||
"""测试按次计费"""
|
||||
dim = BillingDimension(
|
||||
name="request",
|
||||
usage_field="request_count",
|
||||
price_field="price_per_request",
|
||||
unit=BillingUnit.PER_REQUEST,
|
||||
)
|
||||
|
||||
# 按次计费:cost = request_count * price
|
||||
cost = dim.calculate(1, 0.05)
|
||||
assert cost == 0.05
|
||||
|
||||
# 多次请求应按次数计费
|
||||
cost = dim.calculate(3, 0.05)
|
||||
assert abs(cost - 0.15) < 0.0001
|
||||
|
||||
def test_calculate_zero_usage(self) -> None:
|
||||
"""测试零用量"""
|
||||
dim = BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
)
|
||||
|
||||
cost = dim.calculate(0, 3.0)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_calculate_zero_price(self) -> None:
|
||||
"""测试零价格"""
|
||||
dim = BillingDimension(
|
||||
name="input",
|
||||
usage_field="input_tokens",
|
||||
price_field="input_price_per_1m",
|
||||
)
|
||||
|
||||
cost = dim.calculate(1000, 0.0)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_to_dict_and_from_dict(self) -> None:
|
||||
"""测试序列化和反序列化"""
|
||||
dim = BillingDimension(
|
||||
name="cache_read",
|
||||
usage_field="cache_read_tokens",
|
||||
price_field="cache_read_price_per_1m",
|
||||
unit=BillingUnit.PER_1M_TOKENS,
|
||||
default_price=0.3,
|
||||
)
|
||||
|
||||
d = dim.to_dict()
|
||||
restored = BillingDimension.from_dict(d)
|
||||
|
||||
assert restored.name == dim.name
|
||||
assert restored.usage_field == dim.usage_field
|
||||
assert restored.price_field == dim.price_field
|
||||
assert restored.unit == dim.unit
|
||||
assert restored.default_price == dim.default_price
|
||||
|
||||
|
||||
class TestStandardizedUsage:
|
||||
"""测试标准化 Usage"""
|
||||
|
||||
def test_basic_usage(self) -> None:
|
||||
"""测试基础 usage"""
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
)
|
||||
|
||||
assert usage.input_tokens == 1000
|
||||
assert usage.output_tokens == 500
|
||||
assert usage.cache_creation_tokens == 0
|
||||
assert usage.cache_read_tokens == 0
|
||||
|
||||
def test_get_field(self) -> None:
|
||||
"""测试字段获取"""
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
)
|
||||
|
||||
assert usage.get("input_tokens") == 1000
|
||||
assert usage.get("nonexistent", 0) == 0
|
||||
|
||||
def test_extra_fields(self) -> None:
|
||||
"""测试扩展字段"""
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
extra={"custom_field": 123},
|
||||
)
|
||||
|
||||
assert usage.get("custom_field") == 123
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""测试转换为字典"""
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_tokens=100,
|
||||
)
|
||||
|
||||
d = usage.to_dict()
|
||||
assert d["input_tokens"] == 1000
|
||||
assert d["output_tokens"] == 500
|
||||
assert d["cache_creation_tokens"] == 100
|
||||
|
||||
|
||||
class TestCostBreakdown:
|
||||
"""测试费用明细"""
|
||||
|
||||
def test_basic_breakdown(self) -> None:
|
||||
"""测试基础费用明细"""
|
||||
breakdown = CostBreakdown(
|
||||
costs={"input": 0.003, "output": 0.0075},
|
||||
total_cost=0.0105,
|
||||
)
|
||||
|
||||
assert breakdown.input_cost == 0.003
|
||||
assert breakdown.output_cost == 0.0075
|
||||
assert breakdown.total_cost == 0.0105
|
||||
|
||||
def test_cache_cost_calculation(self) -> None:
|
||||
"""测试缓存费用汇总"""
|
||||
breakdown = CostBreakdown(
|
||||
costs={
|
||||
"input": 0.003,
|
||||
"output": 0.0075,
|
||||
"cache_creation": 0.001,
|
||||
"cache_read": 0.0005,
|
||||
},
|
||||
total_cost=0.012,
|
||||
)
|
||||
|
||||
# cache_cost = cache_creation + cache_read
|
||||
assert abs(breakdown.cache_cost - 0.0015) < 0.0001
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""测试转换为字典"""
|
||||
breakdown = CostBreakdown(
|
||||
costs={"input": 0.003, "output": 0.0075},
|
||||
total_cost=0.0105,
|
||||
tier_index=1,
|
||||
)
|
||||
|
||||
d = breakdown.to_dict()
|
||||
assert d["total_cost"] == 0.0105
|
||||
assert d["tier_index"] == 1
|
||||
assert d["input_cost"] == 0.003
|
||||
|
||||
|
||||
class TestBillingTemplates:
|
||||
"""测试计费模板"""
|
||||
|
||||
def test_claude_template(self) -> None:
|
||||
"""测试 Claude 模板"""
|
||||
template = BillingTemplates.CLAUDE_STANDARD
|
||||
dim_names = [d.name for d in template]
|
||||
|
||||
assert "input" in dim_names
|
||||
assert "output" in dim_names
|
||||
assert "cache_creation" in dim_names
|
||||
assert "cache_read" in dim_names
|
||||
|
||||
def test_openai_template(self) -> None:
|
||||
"""测试 OpenAI 模板"""
|
||||
template = BillingTemplates.OPENAI_STANDARD
|
||||
dim_names = [d.name for d in template]
|
||||
|
||||
assert "input" in dim_names
|
||||
assert "output" in dim_names
|
||||
assert "cache_read" in dim_names
|
||||
# OpenAI 没有缓存创建费用
|
||||
assert "cache_creation" not in dim_names
|
||||
|
||||
def test_gemini_template(self) -> None:
|
||||
"""测试 Gemini 模板"""
|
||||
template = BillingTemplates.GEMINI_STANDARD
|
||||
dim_names = [d.name for d in template]
|
||||
|
||||
assert "input" in dim_names
|
||||
assert "output" in dim_names
|
||||
assert "cache_read" in dim_names
|
||||
|
||||
def test_per_request_template(self) -> None:
|
||||
"""测试按次计费模板"""
|
||||
template = BillingTemplates.PER_REQUEST
|
||||
assert len(template) == 1
|
||||
assert template[0].name == "request"
|
||||
assert template[0].unit == BillingUnit.PER_REQUEST
|
||||
|
||||
def test_get_template(self) -> None:
|
||||
"""测试获取模板"""
|
||||
template = get_template("claude")
|
||||
assert template == BillingTemplates.CLAUDE_STANDARD
|
||||
|
||||
template = get_template("openai")
|
||||
assert template == BillingTemplates.OPENAI_STANDARD
|
||||
|
||||
# 不区分大小写
|
||||
template = get_template("CLAUDE")
|
||||
assert template == BillingTemplates.CLAUDE_STANDARD
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown billing template"):
|
||||
get_template("unknown_template")
|
||||
|
||||
def test_list_templates(self) -> None:
|
||||
"""测试列出模板"""
|
||||
templates = list_templates()
|
||||
|
||||
assert "claude" in templates
|
||||
assert "openai" in templates
|
||||
assert "gemini" in templates
|
||||
assert "per_request" in templates
|
||||
|
||||
|
||||
class TestBillingCalculator:
|
||||
"""测试计费计算器"""
|
||||
|
||||
def test_basic_calculation(self) -> None:
|
||||
"""测试基础计费计算"""
|
||||
calculator = BillingCalculator(template="claude")
|
||||
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
|
||||
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||
|
||||
result = calculator.calculate(usage, prices)
|
||||
|
||||
# 1000 * 3 / 1M = 0.003
|
||||
assert abs(result.input_cost - 0.003) < 0.0001
|
||||
# 500 * 15 / 1M = 0.0075
|
||||
assert abs(result.output_cost - 0.0075) < 0.0001
|
||||
# Total = 0.0105
|
||||
assert abs(result.total_cost - 0.0105) < 0.0001
|
||||
|
||||
def test_calculation_with_cache(self) -> None:
|
||||
"""测试带缓存的计费计算"""
|
||||
calculator = BillingCalculator(template="claude")
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_tokens=200,
|
||||
cache_read_tokens=300,
|
||||
)
|
||||
prices = {
|
||||
"input_price_per_1m": 3.0,
|
||||
"output_price_per_1m": 15.0,
|
||||
"cache_creation_price_per_1m": 3.75,
|
||||
"cache_read_price_per_1m": 0.3,
|
||||
}
|
||||
|
||||
result = calculator.calculate(usage, prices)
|
||||
|
||||
# cache_creation: 200 * 3.75 / 1M = 0.00075
|
||||
assert abs(result.cache_creation_cost - 0.00075) < 0.0001
|
||||
# cache_read: 300 * 0.3 / 1M = 0.00009
|
||||
assert abs(result.cache_read_cost - 0.00009) < 0.0001
|
||||
|
||||
def test_tiered_pricing(self) -> None:
|
||||
"""测试阶梯计费"""
|
||||
calculator = BillingCalculator(template="claude")
|
||||
usage = StandardizedUsage(input_tokens=250000, output_tokens=10000)
|
||||
|
||||
# 大于 200k 进入第二阶梯
|
||||
tiered_pricing = {
|
||||
"tiers": [
|
||||
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
|
||||
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
|
||||
]
|
||||
}
|
||||
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||
|
||||
result = calculator.calculate(usage, prices, tiered_pricing)
|
||||
|
||||
# 应该使用第二阶梯价格
|
||||
assert result.tier_index == 1
|
||||
# 250000 * 1.5 / 1M = 0.375
|
||||
assert abs(result.input_cost - 0.375) < 0.0001
|
||||
|
||||
def test_openai_no_cache_creation(self) -> None:
|
||||
"""测试 OpenAI 模板没有缓存创建费用"""
|
||||
calculator = BillingCalculator(template="openai")
|
||||
usage = StandardizedUsage(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_tokens=200, # 这个不应该计费
|
||||
cache_read_tokens=300,
|
||||
)
|
||||
prices = {
|
||||
"input_price_per_1m": 3.0,
|
||||
"output_price_per_1m": 15.0,
|
||||
"cache_creation_price_per_1m": 3.75,
|
||||
"cache_read_price_per_1m": 0.3,
|
||||
}
|
||||
|
||||
result = calculator.calculate(usage, prices)
|
||||
|
||||
# OpenAI 模板不包含 cache_creation 维度
|
||||
assert result.cache_creation_cost == 0.0
|
||||
# 但 cache_read 应该计费
|
||||
assert result.cache_read_cost > 0
|
||||
|
||||
def test_from_config(self) -> None:
|
||||
"""测试从配置创建计算器"""
|
||||
config = {"template": "openai"}
|
||||
calculator = BillingCalculator.from_config(config)
|
||||
|
||||
assert calculator.template_name == "openai"
|
||||
|
||||
|
||||
class TestCalculateRequestCost:
|
||||
"""测试便捷函数"""
|
||||
|
||||
def test_basic_usage(self) -> None:
|
||||
"""测试基础用法"""
|
||||
result = calculate_request_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_input_tokens=0,
|
||||
cache_read_input_tokens=0,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
cache_creation_price_per_1m=None,
|
||||
cache_read_price_per_1m=None,
|
||||
price_per_request=None,
|
||||
billing_template="claude",
|
||||
)
|
||||
|
||||
assert "input_cost" in result
|
||||
assert "output_cost" in result
|
||||
assert "total_cost" in result
|
||||
assert abs(result["input_cost"] - 0.003) < 0.0001
|
||||
assert abs(result["output_cost"] - 0.0075) < 0.0001
|
||||
|
||||
def test_with_cache(self) -> None:
|
||||
"""测试带缓存"""
|
||||
result = calculate_request_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_input_tokens=200,
|
||||
cache_read_input_tokens=300,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
cache_creation_price_per_1m=3.75,
|
||||
cache_read_price_per_1m=0.3,
|
||||
price_per_request=None,
|
||||
billing_template="claude",
|
||||
)
|
||||
|
||||
assert result["cache_creation_cost"] > 0
|
||||
assert result["cache_read_cost"] > 0
|
||||
assert result["cache_cost"] == result["cache_creation_cost"] + result["cache_read_cost"]
|
||||
|
||||
def test_different_templates(self) -> None:
|
||||
"""测试不同模板"""
|
||||
prices = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 500,
|
||||
"cache_creation_input_tokens": 200,
|
||||
"cache_read_input_tokens": 300,
|
||||
"input_price_per_1m": 3.0,
|
||||
"output_price_per_1m": 15.0,
|
||||
"cache_creation_price_per_1m": 3.75,
|
||||
"cache_read_price_per_1m": 0.3,
|
||||
"price_per_request": None,
|
||||
}
|
||||
|
||||
# Claude 模板有 cache_creation
|
||||
result_claude = calculate_request_cost(**prices, billing_template="claude")
|
||||
assert result_claude["cache_creation_cost"] > 0
|
||||
|
||||
# OpenAI 模板没有 cache_creation
|
||||
result_openai = calculate_request_cost(**prices, billing_template="openai")
|
||||
assert result_openai["cache_creation_cost"] == 0
|
||||
|
||||
def test_tiered_pricing_with_total_context(self) -> None:
|
||||
"""测试使用自定义 total_input_context 的阶梯计费"""
|
||||
tiered_pricing = {
|
||||
"tiers": [
|
||||
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
|
||||
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
|
||||
]
|
||||
}
|
||||
|
||||
# 传入预计算的 total_input_context
|
||||
result = calculate_request_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
cache_creation_input_tokens=0,
|
||||
cache_read_input_tokens=0,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
cache_creation_price_per_1m=None,
|
||||
cache_read_price_per_1m=None,
|
||||
price_per_request=None,
|
||||
tiered_pricing=tiered_pricing,
|
||||
total_input_context=250000, # 预计算的值,超过 200k
|
||||
billing_template="claude",
|
||||
)
|
||||
|
||||
# 应该使用第二阶梯价格
|
||||
assert result["tier_index"] == 1
|
||||
Reference in New Issue
Block a user