12 Commits

Author SHA1 Message Date
fawney19
d696c575e6 refactor(migrations): add idempotency checks to migration scripts 2025-12-16 17:46:38 +08:00
fawney19
46ff5a1a50 refactor(models): enhance model management with official provider marking and extended metadata
- Add OFFICIAL_PROVIDERS set to mark first-party vendors in models.dev
- Implement official provider marking function with cache compatibility
- Extend model metadata with family, context_limit, output_limit fields
- Improve frontend model selection UI with wider panel and better search
- Add dark mode support for provider logos
- Optimize scrollbar styling for model lists
- Update deployment documentation with clearer migration steps
2025-12-16 17:28:40 +08:00
fawney19
edce43d45f fix(auth): make get_current_user and get_current_user_from_header async functions
将 get_current_user 和 get_current_user_from_header 函数声明为 async,
并更新 AuthService.verify_token 的调用为 await,以正确处理异步 Token 验证。
2025-12-16 13:42:26 +08:00
fawney19
33265b4b13 refactor(global-model): migrate model metadata to flexible config structure
将模型配置从多个固定字段(description, official_url, icon_url, default_supports_* 等)
统一为灵活的 config JSON 字段,提高扩展性。同时优化前端模型创建表单,支持从 models-dev
列表直接选择模型快速填充。

主要变更:
- 后端:模型表迁移,支持 config JSON 存储模型能力和元信息
- 前端:GlobalModelFormDialog 支持两种创建方式(列表选择/手动填写)
- API 类型更新,对齐新的数据结构
2025-12-16 12:21:21 +08:00
fawney19
a94aeca2d3 docs(deploy): add database migration step to deployment guide and create migration script 2025-12-16 09:21:24 +08:00
fawney19
c42ebdd0ee test(handler): add comprehensive stream processor unit tests 2025-12-16 02:40:26 +08:00
fawney19
f1e3c2ab11 feat(frontend-usage): enhance usage UI with first byte latency metrics
- Update usage records table to display first_byte_time_ms metrics
- Improve request timeline visualization for latency tracking
- Extend usage types for new timing information
2025-12-16 02:39:54 +08:00
fawney19
4e2ba0e57f feat(usage): add first_byte_time_ms tracking to usage statistics
- Enhance usage service to capture and store first byte latency metrics
- Update usage API routes to include new timing information
2025-12-16 02:39:36 +08:00
fawney19
a3df41d63d refactor(cli-handler): improve stream handling and response processing
- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
2025-12-16 02:39:20 +08:00
fawney19
ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00
fawney19
9b496abb73 feat(db): add first_byte_time_ms column to usage table 2025-12-16 02:38:43 +08:00
fawney19
f3a69a6160 refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
2025-12-16 00:02:49 +08:00
56 changed files with 2688 additions and 1139 deletions

View File

@@ -60,8 +60,11 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
# 3. 部署
docker-compose up -d
# 4. 更新
docker-compose pull && docker-compose up -d
# 4. 首次部署时, 初始化数据库
./migrate.sh
# 5. 更新
docker-compose pull && docker-compose up -d && ./migrate.sh
```
### Docker Compose本地构建镜像

View File

@@ -26,16 +26,66 @@ branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def table_exists(bind, table_name: str) -> bool:
"""检查表是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.tables
WHERE table_name = :table_name
)
"""
),
{"table_name": table_name},
)
return result.scalar()
def index_exists(bind, index_name: str) -> bool:
"""检查索引是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM pg_indexes
WHERE indexname = :index_name
)
"""
),
{"index_name": index_name},
)
return result.scalar()
def upgrade() -> None:
"""添加 provider_model_aliases 字段,迁移数据,删除 model_mappings 表"""
# 1. 添加 provider_model_aliases 字段
op.add_column(
'models',
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
)
# 2. 迁移 model_mappings 数据
bind = op.get_bind()
# 1. 添加 provider_model_aliases 字段(如果不存在)
if not column_exists(bind, "models", "provider_model_aliases"):
op.add_column(
'models',
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
)
# 2. 迁移 model_mappings 数据(如果表存在)
session = Session(bind=bind)
model_mappings_table = sa.table(
@@ -96,104 +146,118 @@ def upgrade() -> None:
# 查询所有活跃的 provider 级别 alias只迁移 is_active=True 且 mapping_type='alias' 的)
# 全局别名/映射不迁移(新架构不再支持 source_model -> GlobalModel.name 的解析)
mappings = session.execute(
sa.select(
model_mappings_table.c.source_model,
model_mappings_table.c.target_global_model_id,
model_mappings_table.c.provider_id,
)
.where(
model_mappings_table.c.is_active.is_(True),
model_mappings_table.c.provider_id.isnot(None),
model_mappings_table.c.mapping_type == "alias",
)
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
).all()
# 按 (provider_id, target_global_model_id) 分组,收集别名
alias_groups: dict = {}
for source_model, target_global_model_id, provider_id in mappings:
if not isinstance(source_model, str):
continue
source_model = source_model.strip()
if not source_model:
continue
if not isinstance(provider_id, str) or not provider_id:
continue
if not isinstance(target_global_model_id, str) or not target_global_model_id:
continue
key = (provider_id, target_global_model_id)
if key not in alias_groups:
alias_groups[key] = []
priority = len(alias_groups[key]) + 1
alias_groups[key].append({"name": source_model, "priority": priority})
# 更新对应的 models 记录
for (provider_id, global_model_id), aliases in alias_groups.items():
model_row = session.execute(
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
# 仅当 model_mappings 表存在时执行迁移
if table_exists(bind, "model_mappings"):
mappings = session.execute(
sa.select(
model_mappings_table.c.source_model,
model_mappings_table.c.target_global_model_id,
model_mappings_table.c.provider_id,
)
.where(
models_table.c.provider_id == provider_id,
models_table.c.global_model_id == global_model_id,
model_mappings_table.c.is_active.is_(True),
model_mappings_table.c.provider_id.isnot(None),
model_mappings_table.c.mapping_type == "alias",
)
.limit(1)
).first()
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
).all()
if model_row:
model_id = model_row[0]
existing_aliases = normalize_alias_list(model_row[1])
# 按 (provider_id, target_global_model_id) 分组,收集别名
alias_groups: dict = {}
for source_model, target_global_model_id, provider_id in mappings:
if not isinstance(source_model, str):
continue
source_model = source_model.strip()
if not source_model:
continue
if not isinstance(provider_id, str) or not provider_id:
continue
if not isinstance(target_global_model_id, str) or not target_global_model_id:
continue
existing_names = {a["name"] for a in existing_aliases}
merged_aliases = list(existing_aliases)
for alias in aliases:
name = alias.get("name")
if not isinstance(name, str):
continue
name = name.strip()
if not name or name in existing_names:
continue
key = (provider_id, target_global_model_id)
if key not in alias_groups:
alias_groups[key] = []
priority = len(alias_groups[key]) + 1
alias_groups[key].append({"name": source_model, "priority": priority})
merged_aliases.append(
{
"name": name,
"priority": len(merged_aliases) + 1,
}
# 更新对应的 models 记录
for (provider_id, global_model_id), aliases in alias_groups.items():
model_row = session.execute(
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
.where(
models_table.c.provider_id == provider_id,
models_table.c.global_model_id == global_model_id,
)
existing_names.add(name)
.limit(1)
).first()
session.execute(
models_table.update()
.where(models_table.c.id == model_id)
.values(
provider_model_aliases=merged_aliases if merged_aliases else None,
updated_at=datetime.now(timezone.utc),
if model_row:
model_id = model_row[0]
existing_aliases = normalize_alias_list(model_row[1])
existing_names = {a["name"] for a in existing_aliases}
merged_aliases = list(existing_aliases)
for alias in aliases:
name = alias.get("name")
if not isinstance(name, str):
continue
name = name.strip()
if not name or name in existing_names:
continue
merged_aliases.append(
{
"name": name,
"priority": len(merged_aliases) + 1,
}
)
existing_names.add(name)
session.execute(
models_table.update()
.where(models_table.c.id == model_id)
.values(
provider_model_aliases=merged_aliases if merged_aliases else None,
updated_at=datetime.now(timezone.utc),
)
)
)
session.commit()
session.commit()
# 3. 删除 model_mappings 表
op.drop_table('model_mappings')
# 3. 删除 model_mappings 表
op.drop_table('model_mappings')
# 4. 添加索引优化别名解析性能
# provider_model_name 索引(支持精确匹配)
op.create_index(
"idx_model_provider_model_name",
"models",
["provider_model_name"],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# provider_model_name 索引(支持精确匹配,如果不存在
if not index_exists(bind, "idx_model_provider_model_name"):
op.create_index(
"idx_model_provider_model_name",
"models",
["provider_model_name"],
unique=False,
postgresql_where=sa.text("is_active = true"),
)
# provider_model_aliases GIN 索引(支持 JSONB 查询,仅 PostgreSQL
if bind.dialect.name == "postgresql":
# 将 json 列转为 jsonbjsonb 性能更好且支持 GIN 索引)
# 使用 IF NOT EXISTS 风格的检查来避免重复转换
op.execute(
"""
ALTER TABLE models
ALTER COLUMN provider_model_aliases TYPE jsonb
USING provider_model_aliases::jsonb
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = 'models'
AND column_name = 'provider_model_aliases'
AND data_type = 'json'
) THEN
ALTER TABLE models
ALTER COLUMN provider_model_aliases TYPE jsonb
USING provider_model_aliases::jsonb;
END IF;
END $$;
"""
)
# 创建 GIN 索引

View File

@@ -0,0 +1,47 @@
"""add first_byte_time_ms to usage table
Revision ID: 180e63a9c83a
Revises: e9b3d63f0cbf
Create Date: 2025-12-15 17:07:44.631032+00:00
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '180e63a9c83a'
down_revision = 'e9b3d63f0cbf'
branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def upgrade() -> None:
"""应用迁移:升级到新版本"""
bind = op.get_bind()
# 添加首字时间字段到 usage 表(如果不存在)
if not column_exists(bind, "usage", "first_byte_time_ms"):
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 删除首字时间字段
op.drop_column('usage', 'first_byte_time_ms')

View File

@@ -0,0 +1,110 @@
"""refactor global_model to use config json field
Revision ID: 1cc6942cf06f
Revises: 180e63a9c83a
Create Date: 2025-12-16 03:11:32.480976+00:00
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '1cc6942cf06f'
down_revision = '180e63a9c83a'
branch_labels = None
depends_on = None
def column_exists(bind, table_name: str, column_name: str) -> bool:
"""检查列是否存在"""
result = bind.execute(
sa.text(
"""
SELECT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_name = :table_name AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
return result.scalar()
def upgrade() -> None:
"""应用迁移:升级到新版本
1. 添加 config 列
2. 把旧数据迁移到 config
3. 删除旧列
"""
bind = op.get_bind()
# 检查是否已经迁移过config 列存在且旧列不存在)
has_config = column_exists(bind, "global_models", "config")
has_old_columns = column_exists(bind, "global_models", "default_supports_streaming")
if has_config and not has_old_columns:
# 已完成迁移,跳过
return
# 1. 添加 config 列(使用 JSONB 类型,支持索引和更高效的查询)
if not has_config:
op.add_column('global_models', sa.Column('config', postgresql.JSONB(), nullable=True))
# 2. 迁移数据:把旧字段合并到 config JSON仅当旧列存在时
if has_old_columns:
op.execute("""
UPDATE global_models
SET config = jsonb_strip_nulls(jsonb_build_object(
'streaming', COALESCE(default_supports_streaming, true),
'vision', CASE WHEN COALESCE(default_supports_vision, false) THEN true ELSE NULL END,
'function_calling', CASE WHEN COALESCE(default_supports_function_calling, false) THEN true ELSE NULL END,
'extended_thinking', CASE WHEN COALESCE(default_supports_extended_thinking, false) THEN true ELSE NULL END,
'image_generation', CASE WHEN COALESCE(default_supports_image_generation, false) THEN true ELSE NULL END,
'description', description,
'icon_url', icon_url,
'official_url', official_url
))
""")
# 3. 删除旧列
op.drop_column('global_models', 'default_supports_streaming')
op.drop_column('global_models', 'default_supports_vision')
op.drop_column('global_models', 'default_supports_function_calling')
op.drop_column('global_models', 'default_supports_extended_thinking')
op.drop_column('global_models', 'default_supports_image_generation')
op.drop_column('global_models', 'description')
op.drop_column('global_models', 'icon_url')
op.drop_column('global_models', 'official_url')
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 1. 添加旧列
op.add_column('global_models', sa.Column('icon_url', sa.VARCHAR(length=500), nullable=True))
op.add_column('global_models', sa.Column('official_url', sa.VARCHAR(length=500), nullable=True))
op.add_column('global_models', sa.Column('description', sa.TEXT(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_streaming', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_vision', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_function_calling', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_extended_thinking', sa.BOOLEAN(), nullable=True))
op.add_column('global_models', sa.Column('default_supports_image_generation', sa.BOOLEAN(), nullable=True))
# 2. 从 config 恢复数据
op.execute("""
UPDATE global_models
SET
default_supports_streaming = COALESCE((config->>'streaming')::boolean, true),
default_supports_vision = COALESCE((config->>'vision')::boolean, false),
default_supports_function_calling = COALESCE((config->>'function_calling')::boolean, false),
default_supports_extended_thinking = COALESCE((config->>'extended_thinking')::boolean, false),
default_supports_image_generation = COALESCE((config->>'image_generation')::boolean, false),
description = config->>'description',
icon_url = config->>'icon_url',
official_url = config->>'official_url'
""")
# 3. 删除 config 列
op.drop_column('global_models', 'config')

View File

@@ -407,67 +407,45 @@ export interface TieredPricingConfig {
export interface GlobalModelCreate {
name: string
display_name: string
description?: string
official_url?: string
icon_url?: string
// 按次计费配置(可选,与阶梯计费叠加)
default_price_per_request?: number
// 阶梯计费配置(必填,固定价格用单阶梯表示)
default_tiered_pricing: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[]
// 模型配置JSON格式- 包含能力、规格、元信息等
config?: Record<string, any>
is_active?: boolean
}
export interface GlobalModelUpdate {
display_name?: string
description?: string
official_url?: string
icon_url?: string
is_active?: boolean
// 按次计费配置
default_price_per_request?: number | null // null 表示清空
// 阶梯计费配置
default_tiered_pricing?: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[] | null
// 模型配置JSON格式- 包含能力、规格、元信息等
config?: Record<string, any> | null
}
export interface GlobalModelResponse {
id: string
name: string
display_name: string
description?: string
official_url?: string
icon_url?: string
is_active: boolean
// 按次计费配置
default_price_per_request?: number
// 阶梯计费配置(必填)
default_tiered_pricing: TieredPricingConfig
// 默认能力配置
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_streaming?: boolean
default_supports_extended_thinking?: boolean
default_supports_image_generation?: boolean
// Key 能力配置 - 模型支持的能力列表
supported_capabilities?: string[] | null
// 模型配置JSON格式
config?: Record<string, any> | null
// 统计数据
provider_count?: number
alias_count?: number
usage_count?: number
created_at: string
updated_at?: string

View File

@@ -0,0 +1,288 @@
/**
* Models.dev API 服务
* 通过后端代理获取 models.dev 数据(解决跨域问题)
*/
import api from './client'
// 缓存配置
const CACHE_KEY = 'models_dev_cache'
const CACHE_DURATION = 15 * 60 * 1000 // 15 分钟
// Models.dev API 数据结构
export interface ModelsDevCost {
input?: number
output?: number
reasoning?: number
cache_read?: number
}
export interface ModelsDevLimit {
context?: number
output?: number
}
export interface ModelsDevModel {
id: string
name: string
family?: string
reasoning?: boolean
tool_call?: boolean
structured_output?: boolean
temperature?: boolean
attachment?: boolean
knowledge?: string
release_date?: string
last_updated?: string
input?: string[] // 输入模态: text, image, audio, video, pdf
output?: string[] // 输出模态: text, image, audio
open_weights?: boolean
cost?: ModelsDevCost
limit?: ModelsDevLimit
deprecated?: boolean
}
export interface ModelsDevProvider {
id: string
env?: string[]
npm?: string
api?: string
name: string
doc?: string
models: Record<string, ModelsDevModel>
official?: boolean // 是否为官方提供商
}
export type ModelsDevData = Record<string, ModelsDevProvider>
// 扁平化的模型列表项(用于搜索和选择)
export interface ModelsDevModelItem {
providerId: string
providerName: string
modelId: string
modelName: string
family?: string
inputPrice?: number
outputPrice?: number
contextLimit?: number
outputLimit?: number
supportsVision?: boolean
supportsToolCall?: boolean
supportsReasoning?: boolean
supportsStructuredOutput?: boolean
supportsTemperature?: boolean
supportsAttachment?: boolean
openWeights?: boolean
deprecated?: boolean
official?: boolean // 是否来自官方提供商
// 用于 display_metadata 的额外字段
knowledgeCutoff?: string
releaseDate?: string
inputModalities?: string[]
outputModalities?: string[]
}
interface CacheData {
timestamp: number
data: ModelsDevData
}
// 内存缓存
let memoryCache: CacheData | null = null
function hasOfficialFlag(data: ModelsDevData): boolean {
return Object.values(data).some(provider => typeof provider?.official === 'boolean')
}
/**
* 获取 models.dev 数据(带缓存)
*/
export async function getModelsDevData(): Promise<ModelsDevData> {
// 1. 检查内存缓存
if (memoryCache && Date.now() - memoryCache.timestamp < CACHE_DURATION) {
// 兼容旧缓存:没有 official 字段时丢弃,强制刷新一次
if (hasOfficialFlag(memoryCache.data)) {
return memoryCache.data
}
memoryCache = null
}
// 2. 检查 localStorage 缓存
try {
const cached = localStorage.getItem(CACHE_KEY)
if (cached) {
const cacheData: CacheData = JSON.parse(cached)
if (Date.now() - cacheData.timestamp < CACHE_DURATION) {
// 兼容旧缓存:没有 official 字段时丢弃,强制刷新一次
if (hasOfficialFlag(cacheData.data)) {
memoryCache = cacheData
return cacheData.data
}
localStorage.removeItem(CACHE_KEY)
}
}
} catch {
// 缓存解析失败,忽略
}
// 3. 从后端代理获取新数据
const response = await api.get<ModelsDevData>('/api/admin/models/external')
const data = response.data
// 4. 更新缓存
const cacheData: CacheData = {
timestamp: Date.now(),
data,
}
memoryCache = cacheData
try {
localStorage.setItem(CACHE_KEY, JSON.stringify(cacheData))
} catch {
// localStorage 写入失败,忽略
}
return data
}
// 模型列表缓存(避免重复转换)
let modelsListCache: ModelsDevModelItem[] | null = null
let modelsListCacheTimestamp: number | null = null
/**
* 获取扁平化的模型列表
* 数据只加载一次,通过参数过滤官方/全部
*/
export async function getModelsDevList(officialOnly: boolean = true): Promise<ModelsDevModelItem[]> {
const data = await getModelsDevData()
const currentTimestamp = memoryCache?.timestamp ?? 0
// 如果缓存为空或数据已刷新,构建一次
if (!modelsListCache || modelsListCacheTimestamp !== currentTimestamp) {
const items: ModelsDevModelItem[] = []
for (const [providerId, provider] of Object.entries(data)) {
if (!provider.models) continue
for (const [modelId, model] of Object.entries(provider.models)) {
items.push({
providerId,
providerName: provider.name,
modelId,
modelName: model.name || modelId,
family: model.family,
inputPrice: model.cost?.input,
outputPrice: model.cost?.output,
contextLimit: model.limit?.context,
outputLimit: model.limit?.output,
supportsVision: model.input?.includes('image'),
supportsToolCall: model.tool_call,
supportsReasoning: model.reasoning,
supportsStructuredOutput: model.structured_output,
supportsTemperature: model.temperature,
supportsAttachment: model.attachment,
openWeights: model.open_weights,
deprecated: model.deprecated,
official: provider.official,
// display_metadata 相关字段
knowledgeCutoff: model.knowledge,
releaseDate: model.release_date,
inputModalities: model.input,
outputModalities: model.output,
})
}
}
// 按 provider 名称和模型名称排序
items.sort((a, b) => {
const providerCompare = a.providerName.localeCompare(b.providerName)
if (providerCompare !== 0) return providerCompare
return a.modelName.localeCompare(b.modelName)
})
modelsListCache = items
modelsListCacheTimestamp = currentTimestamp
}
// 根据参数过滤
if (officialOnly) {
return modelsListCache.filter(m => m.official)
}
return modelsListCache
}
/**
* 搜索模型
* 搜索时包含所有提供商(包括第三方)
*/
export async function searchModelsDevModels(
query: string,
options?: {
limit?: number
excludeDeprecated?: boolean
}
): Promise<ModelsDevModelItem[]> {
// 搜索时包含全部提供商
const allModels = await getModelsDevList(false)
const { limit = 50, excludeDeprecated = true } = options || {}
const queryLower = query.toLowerCase()
const filtered = allModels.filter((model) => {
if (excludeDeprecated && model.deprecated) return false
// 搜索模型 ID、名称、provider 名称、family
return (
model.modelId.toLowerCase().includes(queryLower) ||
model.modelName.toLowerCase().includes(queryLower) ||
model.providerName.toLowerCase().includes(queryLower) ||
model.family?.toLowerCase().includes(queryLower)
)
})
// 排序:精确匹配优先
filtered.sort((a, b) => {
const aExact =
a.modelId.toLowerCase() === queryLower ||
a.modelName.toLowerCase() === queryLower
const bExact =
b.modelId.toLowerCase() === queryLower ||
b.modelName.toLowerCase() === queryLower
if (aExact && !bExact) return -1
if (!aExact && bExact) return 1
return 0
})
return filtered.slice(0, limit)
}
/**
* 获取特定模型详情
*/
export async function getModelsDevModel(
providerId: string,
modelId: string
): Promise<ModelsDevModel | null> {
const data = await getModelsDevData()
return data[providerId]?.models?.[modelId] || null
}
/**
* 获取 provider logo URL
*/
export function getProviderLogoUrl(providerId: string): string {
return `https://models.dev/logos/${providerId}.svg`
}
/**
* 清除缓存
*/
export function clearModelsDevCache(): void {
memoryCache = null
modelsListCache = null
modelsListCacheTimestamp = null
try {
localStorage.removeItem(CACHE_KEY)
} catch {
// 忽略错误
}
}

View File

@@ -9,20 +9,14 @@ export interface PublicGlobalModel {
id: string
name: string
display_name: string | null
description: string | null
icon_url: string | null
is_active: boolean
// 阶梯计费配置
default_tiered_pricing: TieredPricingConfig
default_price_per_request: number | null // 按次计费价格
// 能力
default_supports_vision: boolean
default_supports_function_calling: boolean
default_supports_streaming: boolean
default_supports_extended_thinking: boolean
default_supports_image_generation: boolean
// Key 能力支持
supported_capabilities: string[] | null
// 模型配置JSON
config: Record<string, any> | null
}
export interface PublicGlobalModelListResponse {

View File

@@ -299,7 +299,7 @@ function formatDuration(ms: number): string {
const hours = Math.floor(ms / (1000 * 60 * 60))
const minutes = Math.floor((ms % (1000 * 60 * 60)) / (1000 * 60))
if (hours > 0) {
return `${hours}h${minutes > 0 ? minutes + 'm' : ''}`
return `${hours}h${minutes > 0 ? `${minutes}m` : ''}`
}
return `${minutes}m`
}

View File

@@ -2,174 +2,298 @@
<Dialog
:model-value="open"
:title="isEditMode ? '编辑模型' : '创建统一模型'"
:description="isEditMode ? '修改模型配置和价格信息' : '添加一个新的全局模型定义'"
:description="isEditMode ? '修改模型配置和价格信息' : ''"
:icon="isEditMode ? SquarePen : Layers"
size="xl"
size="3xl"
@update:model-value="handleDialogUpdate"
>
<form
class="space-y-5 max-h-[70vh] overflow-y-auto pr-1"
@submit.prevent="handleSubmit"
<div
class="flex gap-4"
:class="isEditMode ? '' : 'h-[500px]'"
>
<!-- 基本信息 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
基本信息
</h4>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label
for="model-name"
class="text-xs"
>模型名称 *</Label>
<Input
id="model-name"
v-model="form.name"
placeholder="claude-3-5-sonnet-20241022"
:disabled="isEditMode"
required
/>
<p
v-if="!isEditMode"
class="text-xs text-muted-foreground"
>
创建后不可修改
</p>
</div>
<div class="space-y-1.5">
<Label
for="model-display-name"
class="text-xs"
>显示名称 *</Label>
<Input
id="model-display-name"
v-model="form.display_name"
placeholder="Claude 3.5 Sonnet"
required
/>
</div>
</div>
<div class="space-y-1.5">
<Label
for="model-description"
class="text-xs"
>描述</Label>
<Input
id="model-description"
v-model="form.description"
placeholder="简短描述此模型的特点"
/>
</div>
</section>
<!-- 能力配置 -->
<section class="space-y-2">
<h4 class="font-medium text-sm">
默认能力
</h4>
<div class="flex flex-wrap gap-2">
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_streaming"
type="checkbox"
class="rounded"
>
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
<span>流式输出</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_vision"
type="checkbox"
class="rounded"
>
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
<span>视觉理解</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_function_calling"
type="checkbox"
class="rounded"
>
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
<span>工具调用</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_extended_thinking"
type="checkbox"
class="rounded"
>
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
<span>深度思考</span>
</label>
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
<input
v-model="form.default_supports_image_generation"
type="checkbox"
class="rounded"
>
<Image class="w-3.5 h-3.5 text-muted-foreground" />
<span>图像生成</span>
</label>
</div>
</section>
<!-- Key 能力配置 -->
<section
v-if="availableCapabilities.length > 0"
class="space-y-2"
<!-- 左侧模型选择仅创建模式 -->
<div
v-if="!isEditMode"
class="w-[260px] shrink-0 flex flex-col h-full"
>
<h4 class="font-medium text-sm">
Key 能力支持
</h4>
<div class="flex flex-wrap gap-2">
<label
v-for="cap in availableCapabilities"
:key="cap.name"
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
>
<input
type="checkbox"
:checked="form.supported_capabilities?.includes(cap.name)"
class="rounded"
@change="toggleCapability(cap.name)"
>
<span>{{ cap.display_name }}</span>
</label>
</div>
</section>
<!-- 价格配置 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
价格配置
</h4>
<TieredPricingEditor
ref="tieredPricingEditorRef"
v-model="tieredPricing"
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
/>
<!-- 按次计费 -->
<div class="flex items-center gap-3 pt-2 border-t">
<Label class="text-xs whitespace-nowrap">按次计费 ($/)</Label>
<!-- 搜索框 -->
<div class="relative mb-3">
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
:model-value="form.default_price_per_request ?? ''"
type="number"
step="0.001"
min="0"
class="w-32"
placeholder="留空不启用"
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
v-model="searchQuery"
type="text"
placeholder="搜索模型、提供商..."
class="pl-8 h-8 text-sm"
/>
<span class="text-xs text-muted-foreground">每次请求固定费用,可与 Token 计费叠加</span>
</div>
</section>
</form>
<!-- 模型列表两级结构 -->
<div class="flex-1 overflow-y-auto border rounded-lg min-h-0 scrollbar-thin">
<div
v-if="loading"
class="flex items-center justify-center h-32"
>
<Loader2 class="w-5 h-5 animate-spin text-muted-foreground" />
</div>
<template v-else>
<!-- 提供商分组 -->
<div
v-for="group in groupedModels"
:key="group.providerId"
class="border-b last:border-b-0"
>
<!-- 提供商标题行 -->
<div
class="flex items-center gap-2 px-2.5 py-2 cursor-pointer hover:bg-muted text-sm"
@click="toggleProvider(group.providerId)"
>
<ChevronRight
class="w-3.5 h-3.5 text-muted-foreground transition-transform shrink-0"
:class="expandedProvider === group.providerId ? 'rotate-90' : ''"
/>
<img
:src="getProviderLogoUrl(group.providerId)"
:alt="group.providerName"
class="w-4 h-4 rounded shrink-0 dark:invert dark:brightness-90"
@error="handleLogoError"
>
<span class="truncate font-medium text-xs flex-1">{{ group.providerName }}</span>
<span class="text-[10px] text-muted-foreground shrink-0">{{ group.models.length }}</span>
</div>
<!-- 模型列表 -->
<div
v-if="expandedProvider === group.providerId"
class="bg-muted/30"
>
<div
v-for="model in group.models"
:key="model.modelId"
class="flex items-center gap-2 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'bg-primary text-primary-foreground'
: 'hover:bg-muted'"
@click="selectModel(model)"
>
<span class="truncate">{{ model.modelName }}</span>
</div>
</div>
</div>
<div
v-if="groupedModels.length === 0"
class="text-center py-8 text-sm text-muted-foreground"
>
{{ searchQuery ? '未找到模型' : '加载中...' }}
</div>
</template>
</div>
</div>
<!-- 右侧表单 -->
<div
class="flex-1 overflow-y-auto h-full scrollbar-thin"
:class="isEditMode ? 'max-h-[70vh]' : ''"
>
<form
class="space-y-5"
@submit.prevent="handleSubmit"
>
<!-- 基本信息 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
基本信息
</h4>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label
for="model-name"
class="text-xs"
>模型名称 *</Label>
<Input
id="model-name"
v-model="form.name"
placeholder="claude-3-5-sonnet-20241022"
:disabled="isEditMode"
required
/>
</div>
<div class="space-y-1.5">
<Label
for="model-display-name"
class="text-xs"
>显示名称 *</Label>
<Input
id="model-display-name"
v-model="form.display_name"
placeholder="Claude 3.5 Sonnet"
required
/>
</div>
</div>
<div class="space-y-1.5">
<Label
for="model-description"
class="text-xs"
>描述</Label>
<Input
id="model-description"
:model-value="form.config?.description || ''"
placeholder="简短描述此模型的特点"
@update:model-value="(v) => setConfigField('description', v || undefined)"
/>
</div>
<div class="grid grid-cols-3 gap-3">
<div class="space-y-1.5">
<Label
for="model-family"
class="text-xs"
>模型系列</Label>
<Input
id="model-family"
:model-value="form.config?.family || ''"
placeholder=" GPT-4Claude 3"
@update:model-value="(v) => setConfigField('family', v || undefined)"
/>
</div>
<div class="space-y-1.5">
<Label
for="model-context-limit"
class="text-xs"
>上下文限制</Label>
<Input
id="model-context-limit"
type="number"
:model-value="form.config?.context_limit ?? ''"
placeholder=" 128000"
@update:model-value="(v) => setConfigField('context_limit', v ? Number(v) : undefined)"
/>
</div>
<div class="space-y-1.5">
<Label
for="model-output-limit"
class="text-xs"
>输出限制</Label>
<Input
id="model-output-limit"
type="number"
:model-value="form.config?.output_limit ?? ''"
placeholder=" 8192"
@update:model-value="(v) => setConfigField('output_limit', v ? Number(v) : undefined)"
/>
</div>
</div>
</section>
<!-- 能力配置 -->
<section class="space-y-2">
<h4 class="font-medium text-sm">
默认能力
</h4>
<div class="flex flex-wrap gap-2">
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.streaming !== false"
class="rounded"
@change="setConfigField('streaming', ($event.target as HTMLInputElement).checked)"
>
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
<span>流式</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.vision === true"
class="rounded"
@change="setConfigField('vision', ($event.target as HTMLInputElement).checked)"
>
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
<span>视觉</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.function_calling === true"
class="rounded"
@change="setConfigField('function_calling', ($event.target as HTMLInputElement).checked)"
>
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
<span>工具</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.extended_thinking === true"
class="rounded"
@change="setConfigField('extended_thinking', ($event.target as HTMLInputElement).checked)"
>
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
<span>思考</span>
</label>
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
<input
type="checkbox"
:checked="form.config?.image_generation === true"
class="rounded"
@change="setConfigField('image_generation', ($event.target as HTMLInputElement).checked)"
>
<Image class="w-3.5 h-3.5 text-muted-foreground" />
<span>生图</span>
</label>
</div>
</section>
<!-- Key 能力配置 -->
<section
v-if="availableCapabilities.length > 0"
class="space-y-2"
>
<h4 class="font-medium text-sm">
Key 能力支持
</h4>
<div class="flex flex-wrap gap-2">
<label
v-for="cap in availableCapabilities"
:key="cap.name"
class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm"
>
<input
type="checkbox"
:checked="form.supported_capabilities?.includes(cap.name)"
class="rounded"
@change="toggleCapability(cap.name)"
>
<span>{{ cap.display_name }}</span>
</label>
</div>
</section>
<!-- 价格配置 -->
<section class="space-y-3">
<h4 class="font-medium text-sm">
价格配置
</h4>
<TieredPricingEditor
ref="tieredPricingEditorRef"
v-model="tieredPricing"
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
/>
<div class="flex items-center gap-3 pt-2 border-t">
<Label class="text-xs whitespace-nowrap">按次计费</Label>
<Input
:model-value="form.default_price_per_request ?? ''"
type="number"
step="0.001"
min="0"
class="w-24"
placeholder="$/"
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
/>
<span class="text-xs text-muted-foreground">可与 Token 计费叠加</span>
</div>
</section>
</form>
</div>
</div>
<template #footer>
<Button
@@ -180,7 +304,7 @@
取消
</Button>
<Button
:disabled="submitting"
:disabled="submitting || !form.name || !form.display_name"
@click="handleSubmit"
>
<Loader2
@@ -189,19 +313,35 @@
/>
{{ isEditMode ? '保存' : '创建' }}
</Button>
<Button
v-if="selectedModel && !isEditMode"
type="button"
variant="ghost"
@click="clearSelection"
>
清空
</Button>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen } from 'lucide-vue-next'
import { ref, computed, onMounted, watch } from 'vue'
import {
Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen,
Search, ChevronRight
} from 'lucide-vue-next'
import { Dialog, Button, Input, Label } from '@/components/ui'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
import { parseNumberInput } from '@/utils/form'
import { log } from '@/utils/logger'
import TieredPricingEditor from './TieredPricingEditor.vue'
import {
getModelsDevList,
getProviderLogoUrl,
type ModelsDevModelItem,
} from '@/api/models-dev'
import {
createGlobalModel,
updateGlobalModel,
@@ -226,42 +366,147 @@ const { success, error: showError } = useToast()
const submitting = ref(false)
const tieredPricingEditorRef = ref<InstanceType<typeof TieredPricingEditor> | null>(null)
// 阶梯计费配置(统一使用,固定价格就是单阶梯)
// 模型列表相关
const loading = ref(false)
const searchQuery = ref('')
const allModelsCache = ref<ModelsDevModelItem[]>([]) // 全部模型(缓存)
const selectedModel = ref<ModelsDevModelItem | null>(null)
const expandedProvider = ref<string | null>(null)
// 当前显示的模型列表:有搜索词时用全部,否则只用官方
const allModels = computed(() => {
if (searchQuery.value) {
return allModelsCache.value
}
return allModelsCache.value.filter(m => m.official)
})
// 按提供商分组的模型
interface ProviderGroup {
providerId: string
providerName: string
models: ModelsDevModelItem[]
}
const groupedModels = computed(() => {
let models = allModels.value.filter(m => !m.deprecated)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
models = models.filter(model =>
model.providerId.toLowerCase().includes(query) ||
model.providerName.toLowerCase().includes(query) ||
model.modelId.toLowerCase().includes(query) ||
model.modelName.toLowerCase().includes(query) ||
model.family?.toLowerCase().includes(query)
)
}
// 按提供商分组
const groups = new Map<string, ProviderGroup>()
for (const model of models) {
if (!groups.has(model.providerId)) {
groups.set(model.providerId, {
providerId: model.providerId,
providerName: model.providerName,
models: []
})
}
groups.get(model.providerId)!.models.push(model)
}
// 转换为数组并排序
let result = Array.from(groups.values())
// 如果有搜索词,把提供商名称/ID匹配的排在前面
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result.sort((a, b) => {
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
if (aProviderMatch && !bProviderMatch) return -1
if (!aProviderMatch && bProviderMatch) return 1
return a.providerName.localeCompare(b.providerName)
})
} else {
result.sort((a, b) => a.providerName.localeCompare(b.providerName))
}
return result
})
// 搜索时如果只有一个提供商,自动展开
watch(groupedModels, (groups) => {
if (searchQuery.value && groups.length === 1) {
expandedProvider.value = groups[0].providerId
}
})
// 切换提供商展开状态
function toggleProvider(providerId: string) {
expandedProvider.value = expandedProvider.value === providerId ? null : providerId
}
// 阶梯计费配置
const tieredPricing = ref<TieredPricingConfig | null>(null)
interface FormData {
name: string
display_name: string
description?: string
default_price_per_request?: number
default_supports_streaming?: boolean
default_supports_image_generation?: boolean
default_supports_vision?: boolean
default_supports_function_calling?: boolean
default_supports_extended_thinking?: boolean
supported_capabilities?: string[]
config?: Record<string, any>
is_active?: boolean
}
const defaultForm = (): FormData => ({
name: '',
display_name: '',
description: '',
default_price_per_request: undefined,
default_supports_streaming: true,
default_supports_image_generation: false,
default_supports_vision: false,
default_supports_function_calling: false,
default_supports_extended_thinking: false,
supported_capabilities: [],
config: { streaming: true },
is_active: true,
})
const form = ref<FormData>(defaultForm())
const KEEP_FALSE_CONFIG_KEYS = new Set(['streaming'])
// 设置 config 字段
function setConfigField(key: string, value: any) {
if (!form.value.config) {
form.value.config = {}
}
if (value === undefined || value === '' || (value === false && !KEEP_FALSE_CONFIG_KEYS.has(key))) {
delete form.value.config[key]
} else {
form.value.config[key] = value
}
}
// Key 能力选项
const availableCapabilities = ref<CapabilityDefinition[]>([])
// 加载模型列表
async function loadModels() {
if (allModelsCache.value.length > 0) return
loading.value = true
try {
// 只加载一次全部模型,过滤在 computed 中完成
allModelsCache.value = await getModelsDevList(false)
} catch (err) {
log.error('Failed to load models:', err)
} finally {
loading.value = false
}
}
// 打开对话框时加载数据
watch(() => props.open, (isOpen) => {
if (isOpen && !props.model) {
loadModels()
}
})
// 加载可用能力列表
async function loadCapabilities() {
try {
@@ -284,15 +529,70 @@ function toggleCapability(capName: string) {
}
}
// 组件挂载时加载能力列表
onMounted(() => {
loadCapabilities()
})
// 选择模型并填充表单
function selectModel(model: ModelsDevModelItem) {
selectedModel.value = model
expandedProvider.value = model.providerId
form.value.name = model.modelId
form.value.display_name = model.modelName
// 构建 config
const config: Record<string, any> = {
streaming: true,
}
if (model.supportsVision) config.vision = true
if (model.supportsToolCall) config.function_calling = true
if (model.supportsReasoning) config.extended_thinking = true
if (model.supportsStructuredOutput) config.structured_output = true
if (model.supportsTemperature !== false) config.temperature = model.supportsTemperature
if (model.supportsAttachment) config.attachment = true
if (model.openWeights) config.open_weights = true
if (model.contextLimit) config.context_limit = model.contextLimit
if (model.outputLimit) config.output_limit = model.outputLimit
if (model.knowledgeCutoff) config.knowledge_cutoff = model.knowledgeCutoff
if (model.family) config.family = model.family
if (model.releaseDate) config.release_date = model.releaseDate
if (model.inputModalities?.length) config.input_modalities = model.inputModalities
if (model.outputModalities?.length) config.output_modalities = model.outputModalities
form.value.config = config
if (model.inputPrice !== undefined || model.outputPrice !== undefined) {
tieredPricing.value = {
tiers: [{
up_to: null,
input_price_per_1m: model.inputPrice || 0,
output_price_per_1m: model.outputPrice || 0,
}]
}
} else {
tieredPricing.value = null
}
}
// 清除选择(手动填写)
function clearSelection() {
selectedModel.value = null
form.value = defaultForm()
tieredPricing.value = null
}
// Logo 加载失败处理
function handleLogoError(event: Event) {
const img = event.target as HTMLImageElement
img.style.display = 'none'
}
// 重置表单
function resetForm() {
form.value = defaultForm()
tieredPricing.value = null
searchQuery.value = ''
selectedModel.value = null
expandedProvider.value = null
}
// 加载模型数据(编辑模式)
@@ -301,18 +601,11 @@ function loadModelData() {
form.value = {
name: props.model.name,
display_name: props.model.display_name,
description: props.model.description,
default_price_per_request: props.model.default_price_per_request,
default_supports_streaming: props.model.default_supports_streaming,
default_supports_image_generation: props.model.default_supports_image_generation,
default_supports_vision: props.model.default_supports_vision,
default_supports_function_calling: props.model.default_supports_function_calling,
default_supports_extended_thinking: props.model.default_supports_extended_thinking,
supported_capabilities: [...(props.model.supported_capabilities || [])],
config: props.model.config ? { ...props.model.config } : { streaming: true },
is_active: props.model.is_active,
}
// 加载阶梯计费配置(深拷贝)
if (props.model.default_tiered_pricing) {
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
}
@@ -339,24 +632,22 @@ async function handleSubmit() {
return
}
// 获取包含自动计算缓存价格的最终数据
const finalTiers = tieredPricingEditorRef.value?.getFinalTiers()
const finalTieredPricing = finalTiers ? { tiers: finalTiers } : tieredPricing.value
// 清理空的 config
const cleanConfig = form.value.config && Object.keys(form.value.config).length > 0
? form.value.config
: undefined
submitting.value = true
try {
if (isEditMode.value && props.model) {
const updateData: GlobalModelUpdate = {
display_name: form.value.display_name,
description: form.value.description,
// 使用 null 而不是 undefined 来显式清空字段
config: cleanConfig || null,
default_price_per_request: form.value.default_price_per_request ?? null,
default_tiered_pricing: finalTieredPricing,
default_supports_streaming: form.value.default_supports_streaming,
default_supports_image_generation: form.value.default_supports_image_generation,
default_supports_vision: form.value.default_supports_vision,
default_supports_function_calling: form.value.default_supports_function_calling,
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : null,
is_active: form.value.is_active,
}
@@ -366,14 +657,9 @@ async function handleSubmit() {
const createData: GlobalModelCreate = {
name: form.value.name!,
display_name: form.value.display_name!,
description: form.value.description,
default_price_per_request: form.value.default_price_per_request || undefined,
config: cleanConfig,
default_price_per_request: form.value.default_price_per_request ?? undefined,
default_tiered_pricing: finalTieredPricing,
default_supports_streaming: form.value.default_supports_streaming,
default_supports_image_generation: form.value.default_supports_image_generation,
default_supports_vision: form.value.default_supports_vision,
default_supports_function_calling: form.value.default_supports_function_calling,
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : undefined,
is_active: form.value.is_active,
}

View File

@@ -38,12 +38,12 @@
>
<Copy class="w-3 h-3" />
</button>
<template v-if="model.description">
<template v-if="model.config?.description">
<span class="shrink-0">·</span>
<span
class="text-xs truncate"
:title="model.description"
>{{ model.description }}</span>
:title="model.config?.description"
>{{ model.config?.description }}</span>
</template>
</div>
</div>
@@ -143,10 +143,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -160,10 +160,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -177,10 +177,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
:variant="model.config?.vision === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
{{ model.config?.vision === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -194,10 +194,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -211,10 +211,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
</Badge>
</div>
</div>
@@ -396,11 +396,11 @@
</div>
<div class="p-3 rounded-lg border bg-muted/20">
<div class="flex items-center justify-between">
<Label class="text-xs text-muted-foreground">别名数量</Label>
<Tag class="w-4 h-4 text-muted-foreground" />
<Label class="text-xs text-muted-foreground">调用次数</Label>
<BarChart3 class="w-4 h-4 text-muted-foreground" />
</div>
<p class="text-2xl font-bold mt-1">
{{ model.alias_count || 0 }}
{{ model.usage_count || 0 }}
</p>
</div>
</div>
@@ -455,105 +455,153 @@
<template v-else-if="providers.length > 0">
<!-- 桌面端表格 -->
<Table class="hidden sm:table">
<TableHeader>
<TableRow class="border-b border-border/60 hover:bg-transparent">
<TableHead class="h-10 font-semibold">
Provider
</TableHead>
<TableHead class="w-[120px] h-10 font-semibold">
能力
</TableHead>
<TableHead class="w-[180px] h-10 font-semibold">
价格 ($/M)
</TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<TableRow
<TableHeader>
<TableRow class="border-b border-border/60 hover:bg-transparent">
<TableHead class="h-10 font-semibold">
Provider
</TableHead>
<TableHead class="w-[120px] h-10 font-semibold">
能力
</TableHead>
<TableHead class="w-[180px] h-10 font-semibold">
价格 ($/M)
</TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<TableRow
v-for="provider in providers"
:key="provider.id"
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
>
<TableCell class="py-3">
<div class="flex items-center gap-2">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
</TableCell>
<TableCell class="py-3">
<div class="flex gap-0.5">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
title="流式输出"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
title="工具调用"
/>
</div>
</TableCell>
<TableCell class="py-3">
<div class="text-xs font-mono space-y-0.5">
<!-- Token 计费输入/输出 -->
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
<span class="text-muted-foreground">输入/输出:</span>
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
<!-- 阶梯标记 -->
<span
v-if="(provider.tier_count || 1) > 1"
class="ml-1 text-muted-foreground"
title="阶梯计费"
>[阶梯]</span>
</div>
<!-- 缓存价格 -->
<div
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>缓存:</span>
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 1h 缓存价格 -->
<div
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>1h 缓存:</span>
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 按次计费 -->
<div v-if="(provider.price_per_request || 0) > 0">
<span class="text-muted-foreground">按次:</span>
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/</span>
</div>
<!-- 无定价 -->
<span
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
class="text-muted-foreground"
>-</span>
</div>
</TableCell>
<TableCell class="py-3 text-center">
<div class="flex items-center justify-center gap-1">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑此关联"
@click="$emit('editProvider', provider)"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
:title="provider.is_active ? '停用此关联' : '启用此关联'"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="删除此关联"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</TableCell>
</TableRow>
</TableBody>
</Table>
<!-- 移动端卡片列表 -->
<div class="sm:hidden divide-y divide-border/40">
<div
v-for="provider in providers"
:key="provider.id"
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
class="p-4 space-y-3"
>
<TableCell class="py-3">
<div class="flex items-center gap-2">
<div class="flex items-start justify-between gap-3">
<div class="flex items-center gap-2 min-w-0">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
</TableCell>
<TableCell class="py-3">
<div class="flex gap-0.5">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
title="流式输出"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
title="工具调用"
/>
</div>
</TableCell>
<TableCell class="py-3">
<div class="text-xs font-mono space-y-0.5">
<!-- Token 计费输入/输出 -->
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
<span class="text-muted-foreground">输入/输出:</span>
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
<!-- 阶梯标记 -->
<span
v-if="(provider.tier_count || 1) > 1"
class="ml-1 text-muted-foreground"
title="阶梯计费"
>[阶梯]</span>
</div>
<!-- 缓存价格 -->
<div
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>缓存:</span>
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 1h 缓存价格 -->
<div
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
class="text-muted-foreground"
>
<span>1h 缓存:</span>
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
</div>
<!-- 按次计费 -->
<div v-if="(provider.price_per_request || 0) > 0">
<span class="text-muted-foreground">按次:</span>
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/</span>
</div>
<!-- 无定价 -->
<span
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
class="text-muted-foreground"
>-</span>
</div>
</TableCell>
<TableCell class="py-3 text-center">
<div class="flex items-center justify-center gap-1">
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑此关联"
@click="$emit('editProvider', provider)"
>
<Edit class="w-3.5 h-3.5" />
@@ -562,7 +610,6 @@
variant="ghost"
size="icon"
class="h-7 w-7"
:title="provider.is_active ? '停用此关联' : '启用此关联'"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
@@ -571,82 +618,35 @@
variant="ghost"
size="icon"
class="h-7 w-7"
title="删除此关联"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</TableCell>
</TableRow>
</TableBody>
</Table>
<!-- 移动端卡片列表 -->
<div class="sm:hidden divide-y divide-border/40">
<div
v-for="provider in providers"
:key="provider.id"
class="p-4 space-y-3"
>
<div class="flex items-start justify-between gap-3">
<div class="flex items-center gap-2 min-w-0">
<span
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
</div>
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('editProvider', provider)"
<div class="flex items-center gap-3 text-xs">
<div class="flex gap-1">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
/>
</div>
<div
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
class="text-muted-foreground font-mono"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('toggleProviderStatus', provider)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
@click="$emit('deleteProvider', provider)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
</div>
</div>
</div>
<div class="flex items-center gap-3 text-xs">
<div class="flex gap-1">
<Zap
v-if="provider.supports_streaming"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Eye
v-if="provider.supports_vision"
class="w-3.5 h-3.5 text-muted-foreground"
/>
<Wrench
v-if="provider.supports_function_calling"
class="w-3.5 h-3.5 text-muted-foreground"
/>
</div>
<div
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
class="text-muted-foreground font-mono"
>
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
</div>
</div>
</div>
</div>
</template>
@@ -695,7 +695,8 @@ import {
Loader2,
RefreshCw,
Copy,
Layers
Layers,
BarChart3
} from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import Card from '@/components/ui/card.vue'

View File

@@ -117,8 +117,12 @@
class="text-center py-6 text-muted-foreground border rounded-lg border-dashed"
>
<Tag class="w-8 h-8 mx-auto mb-2 opacity-50" />
<p class="text-sm">未配置映射</p>
<p class="text-xs mt-1">将只使用主模型名称</p>
<p class="text-sm">
未配置映射
</p>
<p class="text-xs mt-1">
将只使用主模型名称
</p>
</div>
</div>
</div>

View File

@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
return groups
})
// 计算链路总耗时(从第一个节点开始到最后一个节点结束
// 计算链路总耗时(使用成功候选的 latency_ms 字段
// 优先使用 latency_ms因为它与 Usage.response_time_ms 使用相同的时间基准
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
const totalTraceLatency = computed(() => {
if (!timeline.value || timeline.value.length === 0) return 0
// 查找成功的候选,使用其 latency_ms
const successCandidate = timeline.value.find(c => c.status === 'success')
if (successCandidate?.latency_ms != null) {
return successCandidate.latency_ms
}
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
if (failedWithLatency?.latency_ms != null) {
return failedWithLatency.latency_ms
}
// 回退:使用 finished_at - started_at 计算
let earliestStart: number | null = null
let latestEnd: number | null = null

View File

@@ -177,8 +177,9 @@
费用
</TableHead>
<TableHead class="h-12 font-semibold w-[70px] text-right">
<div class="inline-block max-w-[2rem] leading-tight">
响应时间
<div class="flex flex-col items-end text-xs gap-0.5">
<span>首字</span>
<span class="text-muted-foreground font-normal">总耗时</span>
</div>
</TableHead>
</TableRow>
@@ -356,15 +357,28 @@
</div>
</TableCell>
<TableCell class="text-right py-4 w-[70px]">
<span
<div
v-if="record.status === 'pending' || record.status === 'streaming'"
class="text-primary tabular-nums"
class="flex flex-col items-end text-xs gap-0.5"
>
{{ getElapsedTime(record) }}
</span>
<span v-else-if="record.response_time_ms">
{{ (record.response_time_ms / 1000).toFixed(2) }}s
</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
</div>
<span
v-else
class="text-muted-foreground"

View File

@@ -78,6 +78,7 @@ export interface UsageRecord {
cost: number
actual_cost?: number
response_time_ms?: number
first_byte_time_ms?: number // 首字时间 (TTFB)
is_stream: boolean
status_code?: number
error_message?: string

View File

@@ -611,41 +611,42 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
id: 'gm-001',
name: 'claude-haiku-4-5-20251001',
display_name: 'claude-haiku-4-5',
description: 'Anthropic 最快速的 Claude 4 系列模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.00, output_price_per_1m: 5.00, cache_creation_price_per_1m: 1.25, cache_read_price_per_1m: 0.1 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 最快速的 Claude 4 系列模型'
},
provider_count: 3,
alias_count: 2,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-002',
name: 'claude-opus-4-5-20251101',
display_name: 'claude-opus-4-5',
description: 'Anthropic 最强大的模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 5.00, output_price_per_1m: 25.00, cache_creation_price_per_1m: 6.25, cache_read_price_per_1m: 0.5 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 最强大的模型'
},
provider_count: 2,
alias_count: 1,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-003',
name: 'claude-sonnet-4-5-20250929',
display_name: 'claude-sonnet-4-5',
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文',
is_active: true,
default_tiered_pricing: {
tiers: [
@@ -677,116 +678,124 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
}
]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文'
},
supported_capabilities: ['cache_1h', 'cli_1m'],
provider_count: 3,
alias_count: 2,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-004',
name: 'gemini-3-pro-image-preview',
display_name: 'gemini-3-pro-image-preview',
description: 'Google Gemini 3 Pro 图像生成预览版',
is_active: true,
default_price_per_request: 0.300,
default_tiered_pricing: {
tiers: []
},
default_supports_vision: true,
default_supports_function_calling: false,
default_supports_streaming: true,
default_supports_image_generation: true,
config: {
streaming: true,
vision: true,
function_calling: false,
image_generation: true,
description: 'Google Gemini 3 Pro 图像生成预览版'
},
provider_count: 1,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-005',
name: 'gemini-3-pro-preview',
display_name: 'gemini-3-pro-preview',
description: 'Google Gemini 3 Pro 预览版',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 2.00, output_price_per_1m: 12.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'Google Gemini 3 Pro 预览版'
},
provider_count: 1,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-006',
name: 'gpt-5.1',
display_name: 'gpt-5.1',
description: 'OpenAI GPT-5.1 模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 模型'
},
provider_count: 2,
alias_count: 1,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-007',
name: 'gpt-5.1-codex',
display_name: 'gpt-5.1-codex',
description: 'OpenAI GPT-5.1 Codex 代码专用模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex 代码专用模型'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-008',
name: 'gpt-5.1-codex-max',
display_name: 'gpt-5.1-codex-max',
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
},
{
id: 'gm-009',
name: 'gpt-5.1-codex-mini',
display_name: 'gpt-5.1-codex-mini',
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型',
is_active: true,
default_tiered_pricing: {
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
},
default_supports_vision: true,
default_supports_function_calling: true,
default_supports_streaming: true,
default_supports_extended_thinking: true,
config: {
streaming: true,
vision: true,
function_calling: true,
extended_thinking: true,
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型'
},
provider_count: 2,
alias_count: 0,
created_at: '2024-01-01T00:00:00Z'
}
]

View File

@@ -1000,17 +1000,11 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
id: m.id,
name: m.name,
display_name: m.display_name,
description: m.description,
icon_url: null,
is_active: m.is_active,
default_tiered_pricing: m.default_tiered_pricing,
default_price_per_request: null,
default_supports_vision: m.default_supports_vision,
default_supports_function_calling: m.default_supports_function_calling,
default_supports_streaming: m.default_supports_streaming,
default_supports_extended_thinking: m.default_supports_extended_thinking || false,
default_supports_image_generation: false,
supported_capabilities: null
default_price_per_request: m.default_price_per_request,
supported_capabilities: m.supported_capabilities,
config: m.config
})),
total: MOCK_GLOBAL_MODELS.length
})

View File

@@ -1169,4 +1169,26 @@ body[theme-mode='dark'] .literary-annotation {
.scrollbar-hide::-webkit-scrollbar {
display: none;
}
}
.scrollbar-thin {
scrollbar-width: thin;
scrollbar-color: hsl(var(--border)) transparent;
}
.scrollbar-thin::-webkit-scrollbar {
width: 6px;
}
.scrollbar-thin::-webkit-scrollbar-track {
background: transparent;
}
.scrollbar-thin::-webkit-scrollbar-thumb {
background-color: hsl(var(--border));
border-radius: 3px;
}
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
background-color: hsl(var(--muted-foreground) / 0.5);
}
}

View File

@@ -935,7 +935,10 @@ onBeforeUnmount(() => {
:key="`${index}-${aliasIndex}`"
>
<TableCell>
<Badge variant="outline" class="text-xs">
<Badge
variant="outline"
class="text-xs"
>
{{ mapping.provider_name }}
</Badge>
</TableCell>
@@ -981,7 +984,10 @@ onBeforeUnmount(() => {
class="p-4 space-y-2"
>
<div class="flex items-center justify-between">
<Badge variant="outline" class="text-xs">
<Badge
variant="outline"
class="text-xs"
>
{{ mapping.provider_name }}
</Badge>
<div class="flex items-center gap-2">

View File

@@ -111,9 +111,6 @@
<TableHead class="w-[80px] text-center">
提供商
</TableHead>
<TableHead class="w-[70px] text-center">
别名/映射
</TableHead>
<TableHead class="w-[80px] text-center">
调用次数
</TableHead>
@@ -128,7 +125,7 @@
<TableBody>
<TableRow v-if="loading">
<TableCell
colspan="8"
colspan="7"
class="text-center py-8"
>
<Loader2 class="w-6 h-6 animate-spin mx-auto" />
@@ -136,7 +133,7 @@
</TableRow>
<TableRow v-else-if="filteredGlobalModels.length === 0">
<TableCell
colspan="8"
colspan="7"
class="text-center py-8 text-muted-foreground"
>
没有找到匹配的模型
@@ -171,27 +168,27 @@
<div class="space-y-1 w-fit">
<div class="flex flex-wrap gap-1">
<Zap
v-if="model.default_supports_streaming"
v-if="model.config?.streaming !== false"
class="w-4 h-4 text-muted-foreground"
title="流式输出"
/>
<Image
v-if="model.default_supports_image_generation"
v-if="model.config?.image_generation === true"
class="w-4 h-4 text-muted-foreground"
title="图像生成"
/>
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
title="视觉理解"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
title="工具调用"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
title="深度思考"
/>
@@ -244,11 +241,6 @@
{{ model.provider_count || 0 }}
</Badge>
</TableCell>
<TableCell class="text-center">
<Badge variant="secondary">
{{ model.alias_count || 0 }}
</Badge>
</TableCell>
<TableCell class="text-center">
<span class="text-sm font-mono">{{ formatUsageCount(model.usage_count || 0) }}</span>
</TableCell>
@@ -369,23 +361,23 @@
<!-- 第二行能力图标 -->
<div class="flex flex-wrap gap-1.5">
<Zap
v-if="model.default_supports_streaming"
v-if="model.config?.streaming !== false"
class="w-4 h-4 text-muted-foreground"
/>
<Image
v-if="model.default_supports_image_generation"
v-if="model.config?.image_generation === true"
class="w-4 h-4 text-muted-foreground"
/>
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
/>
</div>
@@ -393,7 +385,6 @@
<!-- 第三行统计信息 -->
<div class="flex flex-wrap items-center gap-3 text-xs text-muted-foreground">
<span>提供商 {{ model.provider_count || 0 }}</span>
<span>别名 {{ model.alias_count || 0 }}</span>
<span>调用 {{ formatUsageCount(model.usage_count || 0) }}</span>
<span
v-if="getFirstTierPrice(model, 'input') || getFirstTierPrice(model, 'output')"
@@ -1022,19 +1013,19 @@ const filteredGlobalModels = computed(() => {
// 能力筛选
if (capabilityFilters.value.streaming) {
result = result.filter(m => m.default_supports_streaming)
result = result.filter(m => m.config?.streaming !== false)
}
if (capabilityFilters.value.imageGeneration) {
result = result.filter(m => m.default_supports_image_generation)
result = result.filter(m => m.config?.image_generation === true)
}
if (capabilityFilters.value.vision) {
result = result.filter(m => m.default_supports_vision)
result = result.filter(m => m.config?.vision === true)
}
if (capabilityFilters.value.toolUse) {
result = result.filter(m => m.default_supports_function_calling)
result = result.filter(m => m.config?.function_calling === true)
}
if (capabilityFilters.value.extendedThinking) {
result = result.filter(m => m.default_supports_extended_thinking)
result = result.filter(m => m.config?.extended_thinking === true)
}
return result

View File

@@ -226,8 +226,8 @@
<div
v-for="announcement in announcements"
:key="announcement.id"
class="p-4 space-y-2 cursor-pointer transition-colors"
:class="[
'p-4 space-y-2 cursor-pointer transition-colors',
announcement.is_read ? 'hover:bg-muted/30' : 'bg-primary/5 hover:bg-primary/10'
]"
@click="viewAnnouncementDetail(announcement)"

View File

@@ -165,17 +165,17 @@
<TableCell class="py-4">
<div class="flex gap-1.5">
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
title="Vision"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
title="Tool Use"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
title="Extended Thinking"
/>
@@ -253,15 +253,15 @@
<!-- 第二行能力图标 -->
<div class="flex gap-1.5">
<Eye
v-if="model.default_supports_vision"
v-if="model.config?.vision === true"
class="w-4 h-4 text-muted-foreground"
/>
<Wrench
v-if="model.default_supports_function_calling"
v-if="model.config?.function_calling === true"
class="w-4 h-4 text-muted-foreground"
/>
<Brain
v-if="model.default_supports_extended_thinking"
v-if="model.config?.extended_thinking === true"
class="w-4 h-4 text-muted-foreground"
/>
</div>
@@ -485,13 +485,13 @@ const filteredModels = computed(() => {
// 能力筛选
if (capabilityFilters.value.vision) {
result = result.filter(m => m.default_supports_vision)
result = result.filter(m => m.config?.vision === true)
}
if (capabilityFilters.value.toolUse) {
result = result.filter(m => m.default_supports_function_calling)
result = result.filter(m => m.config?.function_calling === true)
}
if (capabilityFilters.value.extendedThinking) {
result = result.filter(m => m.default_supports_extended_thinking)
result = result.filter(m => m.config?.extended_thinking === true)
}
return result

View File

@@ -38,10 +38,10 @@
</button>
</div>
<p
v-if="model.description"
v-if="model.config?.description"
class="text-xs text-muted-foreground"
>
{{ model.description }}
{{ model.config?.description }}
</p>
</div>
<Button
@@ -73,10 +73,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -90,10 +90,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -107,10 +107,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
:variant="model.config?.vision === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
{{ model.config?.vision === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -124,10 +124,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
</Badge>
</div>
<div class="flex items-center gap-2 p-3 rounded-lg border">
@@ -141,10 +141,10 @@
</p>
</div>
<Badge
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
class="text-xs"
>
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
</Badge>
</div>
</div>

12
migrate.sh Executable file
View File

@@ -0,0 +1,12 @@
#!/bin/bash
# 数据库迁移脚本 - 在 Docker 容器内执行 Alembic 迁移
set -e
CONTAINER_NAME="aether-app"
echo "Running database migrations in container: $CONTAINER_NAME"
docker exec $CONTAINER_NAME alembic upgrade head
echo "Database migration completed successfully"

View File

@@ -5,6 +5,7 @@
from fastapi import APIRouter
from .catalog import router as catalog_router
from .external import router as external_router
from .global_models import router as global_models_router
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
@@ -12,3 +13,4 @@ router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"]
# 挂载子路由
router.include_router(catalog_router)
router.include_router(global_models_router)
router.include_router(external_router)

View File

@@ -72,10 +72,12 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
for gm in global_models:
gm_id = gm.id
provider_entries: List[ModelCatalogProviderDetail] = []
# 从 config JSON 读取能力标志
gm_config = gm.config or {}
capability_flags = {
"supports_vision": gm.default_supports_vision or False,
"supports_function_calling": gm.default_supports_function_calling or False,
"supports_streaming": gm.default_supports_streaming or False,
"supports_vision": gm_config.get("vision", False),
"supports_function_calling": gm_config.get("function_calling", False),
"supports_streaming": gm_config.get("streaming", True),
}
# 遍历该 GlobalModel 的所有关联提供商
@@ -140,7 +142,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
ModelCatalogItem(
global_model_name=gm.name,
display_name=gm.display_name,
description=gm.description,
description=gm_config.get("description"),
providers=provider_entries,
price_range=price_range,
total_providers=len(provider_entries),

View File

@@ -0,0 +1,141 @@
"""
models.dev 外部模型数据代理
"""
import json
from typing import Any, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from src.clients import get_redis_client
from src.core.logger import logger
from src.models.database import User
from src.utils.auth_utils import require_admin
router = APIRouter()
CACHE_KEY = "aether:external:models_dev"
CACHE_TTL = 15 * 60 # 15 分钟
# 标记官方/一手提供商,前端可据此过滤第三方转售商
OFFICIAL_PROVIDERS = {
"anthropic", # Claude 官方
"openai", # OpenAI 官方
"google", # Gemini 官方
"google-vertex", # Google Vertex AI
"azure", # Azure OpenAI
"amazon-bedrock", # AWS Bedrock
"xai", # Grok 官方
"meta", # Llama 官方
"deepseek", # DeepSeek 官方
"mistral", # Mistral 官方
"cohere", # Cohere 官方
"zhipuai", # 智谱 AI 官方
"alibaba", # 阿里云(通义千问)
"minimax", # MiniMax 官方
"moonshot", # 月之暗面Kimi
"baichuan", # 百川智能
"ai21", # AI21 Labs
}
async def _get_cached_data() -> Optional[dict[str, Any]]:
"""从 Redis 获取缓存数据"""
redis = await get_redis_client()
if redis is None:
return None
try:
cached = await redis.get(CACHE_KEY)
if cached:
result: dict[str, Any] = json.loads(cached)
return result
except Exception as e:
logger.warning(f"读取 models.dev 缓存失败: {e}")
return None
async def _set_cached_data(data: dict) -> None:
"""将数据写入 Redis 缓存"""
redis = await get_redis_client()
if redis is None:
return
try:
await redis.setex(CACHE_KEY, CACHE_TTL, json.dumps(data, ensure_ascii=False))
except Exception as e:
logger.warning(f"写入 models.dev 缓存失败: {e}")
def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
"""为每个提供商标记是否为官方"""
result = {}
for provider_id, provider_data in data.items():
result[provider_id] = {
**provider_data,
"official": provider_id in OFFICIAL_PROVIDERS,
}
return result
@router.get("/external")
async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
"""
获取 models.dev 的模型数据(代理请求,解决跨域问题)
数据缓存 15 分钟(使用 Redis多 worker 共享)
每个提供商会标记 official 字段,前端可据此过滤
"""
# 检查缓存
cached = await _get_cached_data()
if cached is not None:
# 兼容旧缓存:如果没有 official 字段则补全并回写
try:
needs_mark = False
for provider_data in cached.values():
if not isinstance(provider_data, dict) or "official" not in provider_data:
needs_mark = True
break
if needs_mark:
marked_cached = _mark_official_providers(cached)
await _set_cached_data(marked_cached)
return JSONResponse(content=marked_cached)
except Exception as e:
logger.warning(f"处理 models.dev 缓存数据失败,将直接返回原缓存: {e}")
return JSONResponse(content=cached)
# 从 models.dev 获取数据
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get("https://models.dev/api.json")
response.raise_for_status()
data = response.json()
# 标记官方提供商
marked_data = _mark_official_providers(data)
# 写入缓存
await _set_cached_data(marked_data)
return JSONResponse(content=marked_data)
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="请求 models.dev 超时")
except httpx.HTTPStatusError as e:
raise HTTPException(
status_code=502, detail=f"models.dev 返回错误: {e.response.status_code}"
)
except Exception as e:
raise HTTPException(status_code=502, detail=f"获取外部模型数据失败: {str(e)}")
@router.delete("/external/cache")
async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
"""清除 models.dev 缓存"""
redis = await get_redis_client()
if redis is None:
return {"cleared": False, "message": "Redis 未启用"}
try:
await redis.delete(CACHE_KEY)
return {"cleared": True}
except Exception as e:
raise HTTPException(status_code=500, detail=f"清除缓存失败: {str(e)}")

View File

@@ -187,21 +187,15 @@ class AdminCreateGlobalModelAdapter(AdminApiAdapter):
db=context.db,
name=self.payload.name,
display_name=self.payload.display_name,
description=self.payload.description,
official_url=self.payload.official_url,
icon_url=self.payload.icon_url,
is_active=self.payload.is_active,
# 按次计费配置
default_price_per_request=self.payload.default_price_per_request,
# 阶梯计费配置
default_tiered_pricing=tiered_pricing_dict,
# 默认能力配置
default_supports_vision=self.payload.default_supports_vision,
default_supports_function_calling=self.payload.default_supports_function_calling,
default_supports_streaming=self.payload.default_supports_streaming,
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
# Key 能力配置
supported_capabilities=self.payload.supported_capabilities,
# 模型配置JSON
config=self.payload.config,
)
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")

View File

@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"actual_cost": actual_cost,
"rate_multiplier": rate_multiplier,
"response_time_ms": usage.response_time_ms,
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage.created_at.isoformat(),
"is_stream": usage.is_stream,
"input_price_per_1m": usage.input_price_per_1m,
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"status_code": usage_record.status_code,
"error_message": usage_record.error_message,
"response_time_ms": usage_record.response_time_ms,
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
"request_headers": usage_record.request_headers,
"request_body": usage_record.get_request_body(),

View File

@@ -65,6 +65,21 @@ class ModelInfo:
created_at: Optional[str] # ISO 格式
created_timestamp: int # Unix 时间戳
provider_name: str
# 能力配置
streaming: bool = True
vision: bool = False
function_calling: bool = False
extended_thinking: bool = False
image_generation: bool = False
structured_output: bool = False
# 规格参数
context_limit: Optional[int] = None
output_limit: Optional[int] = None
# 元信息
family: Optional[str] = None
knowledge_cutoff: Optional[str] = None
input_modalities: Optional[list[str]] = None
output_modalities: Optional[list[str]] = None
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
@@ -181,13 +196,19 @@ def _extract_model_info(model: Any) -> ModelInfo:
global_model = model.global_model
model_id: str = global_model.name if global_model else model.provider_model_name
display_name: str = global_model.display_name if global_model else model.provider_model_name
description: Optional[str] = global_model.description if global_model else None
created_at: Optional[str] = (
model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
)
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
provider_name: str = model.provider.name if model.provider else "unknown"
# 从 GlobalModel.config 提取配置信息
config: dict = {}
description: Optional[str] = None
if global_model:
config = global_model.config or {}
description = config.get("description")
return ModelInfo(
id=model_id,
display_name=display_name,
@@ -195,6 +216,21 @@ def _extract_model_info(model: Any) -> ModelInfo:
created_at=created_at,
created_timestamp=created_timestamp,
provider_name=provider_name,
# 能力配置
streaming=config.get("streaming", True),
vision=config.get("vision", False),
function_calling=config.get("function_calling", False),
extended_thinking=config.get("extended_thinking", False),
image_generation=config.get("image_generation", False),
structured_output=config.get("structured_output", False),
# 规格参数
context_limit=config.get("context_limit"),
output_limit=config.get("output_limit"),
# 元信息
family=config.get("family"),
knowledge_cutoff=config.get("knowledge_cutoff"),
input_modalities=config.get("input_modalities"),
output_modalities=config.get("output_modalities"),
)

View File

@@ -100,6 +100,8 @@ class MessageTelemetry:
cache_read_tokens: int = 0,
is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None,
# 时间指标
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
# Provider 侧追踪信息(用于记录真实成本)
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
@@ -133,6 +135,7 @@ class MessageTelemetry:
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
status_code=status_code,
request_headers=request_headers,
request_body=request_body,

View File

@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.exceptions import (
EmbeddedErrorException,
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
ctx,
original_headers,
original_request_body,
self.elapsed_ms(),
self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
)
# 创建监控流
@@ -378,6 +379,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
@@ -473,12 +475,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
stream_response.raise_for_status()
# 创建行迭代器
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
byte_iterator = stream_response.aiter_raw()
# 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator,
prefetched_chunks = await stream_processor.prefetch_and_check_error(
byte_iterator,
provider,
endpoint,
ctx,
@@ -503,13 +506,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 创建流生成器
# 创建流生成器(传入字节流迭代器)
return stream_processor.create_response_stream(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
start_time=self.start_time,
)
async def _record_stream_failure(

View File

@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
"""
import asyncio
import codecs
import json
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
)
import httpx
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamContext:
"""流式请求的上下文信息"""
# 请求信息
model: str = "unknown" # 用户请求的原始模型名
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
api_format: str = ""
request_id: str = ""
# 用户信息(提前提取避免 Session detached
user_id: int = 0
api_key_id: int = 0
# 统计信息
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0 # cache_read_input_tokens
cache_creation_tokens: int = 0 # cache_creation_input_tokens
collected_text: str = ""
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
parsed_chunks: list = field(default_factory=list)
# 流状态
start_time: float = field(default_factory=time.time)
chunk_count: int = 0
data_count: int = 0
has_completion: bool = False
# 响应信息
status_code: int = 200
response_headers: Dict[str, str] = field(default_factory=dict)
# 请求信息(发送给 Provider 的)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
# Provider 信息
provider_name: Optional[str] = None
provider_id: Optional[str] = None # Provider ID用于记录真实成本
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
error_message: Optional[str] = None
# 格式转换信息
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
client_api_format: str = "" # 客户端请求的 API 格式
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion
response_metadata: Dict[str, Any] = field(default_factory=dict)
class CliMessageHandlerBase(BaseMessageHandler):
"""
CLI 格式消息处理器基类
@@ -409,6 +352,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
@@ -433,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.chunk_count = 0
ctx.data_count = 0
ctx.has_completion = False
ctx.collected_text = ""
ctx._collected_text_parts = [] # 重置文本收集
ctx.input_tokens = 0
ctx.output_tokens = 0
ctx.cached_tokens = 0
@@ -521,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题
byte_iterator = stream_response.aiter_raw()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
prefetched_chunks = await self._prefetch_and_check_embedded_error(
byte_iterator, provider, endpoint, ctx
)
except httpx.HTTPStatusError as e:
@@ -551,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
)
async def _create_response_stream(
@@ -564,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
"""创建响应流生成器(使用字节流)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines():
async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
yield b"\n"
continue
continue
ctx.chunk_count += 1
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
@@ -689,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -703,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流上下文
Returns:
预读的列表(需要在后续流中先输出)
预读的字节块列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
"""
prefetched_lines: list = []
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
# 获取对应格式的解析器
@@ -729,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
else:
provider_parser = self.parser
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
# 解析数据
normalized_line = line.rstrip("\r")
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
line_count += 1
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
@@ -800,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def _create_response_stream_with_prefetch(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
prefetched_chunks: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
"""创建响应流生成器(带预读数据,使用字节流"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
if prefetched_chunks:
self._update_usage_to_streaming(ctx.request_id)
# 先处理预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 继续处理剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
async for chunk in byte_iterator:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件
flushed_events = sse_parser.flush()
@@ -1034,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 提取文本内容
text = self.parser.extract_text_content(data)
if text:
ctx.collected_text += text
ctx.append_text(text)
# 检查完成事件
if event_type in ("response.completed", "message_stop"):
@@ -1086,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
) -> None:
"""在流完成后记录统计信息"""
try:
await asyncio.sleep(0.1)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
# 注意:不要把统计延迟算进响应时间里
response_time_ms = int((time.time() - self.start_time) * 1000)
response_time_ms = int((time.time() - ctx.start_time) * 1000)
await asyncio.sleep(0.1)
if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
@@ -1168,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
@@ -1188,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
# 简洁的请求完成摘要(两行格式)
line1 = (
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
)
logger.info(f"{line1}\n{line2}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
@@ -1242,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body: Dict[str, Any],
) -> None:
"""记录流式请求失败"""
response_time_ms = int((time.time() - ctx.start_time) * 1000)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
response_time_ms = int((time.time() - self.start_time) * 1000)
status_code = 503
if isinstance(error, ProviderAuthException):

View File

@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
ResponseParser,
StreamStats,
)
from src.api.handlers.base.utils import extract_cache_creation_tokens
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
usage = response.get("usage", {})
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
# 检查错误(支持嵌套错误格式)
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
# 对于 message_start 事件usage 在 message.usage 路径下
# 对于其他响应usage 在顶层
usage = response.get("usage", {})
if not usage and "message" in response:
usage = response.get("message", {}).get("usage", {})
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}

View File

@@ -8,6 +8,7 @@
- 请求/响应数据
"""
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@@ -25,12 +26,18 @@ class StreamContext:
model: str
api_format: str
# 请求标识信息CLI handler 需要)
request_id: str = ""
user_id: int = 0
api_key_id: int = 0
# Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射
@@ -43,7 +50,14 @@ class StreamContext:
cache_creation_tokens: int = 0
# 响应内容
collected_text: str = ""
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
# 时间指标
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
start_time: float = field(default_factory=time.time)
# 响应状态
status_code: int = 200
@@ -55,6 +69,12 @@ class StreamContext:
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
# 格式转换信息CLI handler 需要)
client_api_format: str = ""
# Provider 响应元数据CLI handler 需要)
response_metadata: Dict[str, Any] = field(default_factory=dict)
# 流式处理统计
data_count: int = 0
chunk_count: int = 0
@@ -71,16 +91,30 @@ class StreamContext:
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self.collected_text = ""
self._collected_text_parts = []
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
self.response_id = None
self.final_usage = None
self.final_response = None
@property
def collected_text(self) -> str:
"""已收集的文本内容(按需拼接,避免在流式过程中频繁做字符串拷贝)"""
return "".join(self._collected_text_parts)
def append_text(self, text: str) -> None:
"""追加文本内容(仅在需要收集文本时调用)"""
if text:
self._collected_text_parts.append(text)
def update_provider_info(
self,
@@ -104,14 +138,40 @@ class StreamContext:
cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None,
) -> None:
"""更新 Token 使用统计"""
if input_tokens is not None:
"""
更新 Token 使用统计
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
设计原理:
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
- 后续事件可能会提供完整的统计数据
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
示例场景:
- message_start 事件input_tokens=100, output_tokens=0
- message_delta 事件input_tokens=0, output_tokens=50
- 最终结果input_tokens=100, output_tokens=50
注意事项:
- 此策略假设初始值为 0 是正确的默认状态
- 如果需要将已有值重置为 0请直接修改实例属性不使用此方法
Args:
input_tokens: 输入 tokens 数量
output_tokens: 输出 tokens 数量
cached_tokens: 缓存命中 tokens 数量
cache_creation_tokens: 缓存创建 tokens 数量
"""
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
self.input_tokens = input_tokens
if output_tokens is not None:
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
self.output_tokens = output_tokens
if cached_tokens is not None:
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
self.cached_tokens = cached_tokens
if cache_creation_tokens is not None:
if cache_creation_tokens is not None and (
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
):
self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None:
@@ -119,6 +179,19 @@ class StreamContext:
self.status_code = status_code
self.error_message = error_message
def record_first_byte_time(self, start_time: float) -> None:
"""
记录首字时间 (TTFB - Time To First Byte)
应在第一次向客户端发送数据时调用。
如果已记录过,则不会覆盖(避免重试时重复记录)。
Args:
start_time: 请求开始时间 (time.time())
"""
if self.first_byte_time_ms is None:
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
def is_success(self) -> bool:
"""检查请求是否成功"""
return self.status_code < 400
@@ -145,10 +218,22 @@ class StreamContext:
获取日志摘要
用于请求完成/失败时的日志输出。
包含首字时间 (TTFB) 和总响应时间,分两行显示。
"""
status = "OK" if self.is_success() else "FAIL"
return (
# 第一行:基本信息 + 首字时间
line1 = (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
f"{self.provider_name or 'unknown'}"
)
if self.first_byte_time_ms is not None:
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
# 第二行:总响应时间 + tokens
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)
return f"{line1}\n{line2}"

View File

@@ -9,7 +9,9 @@
"""
import asyncio
import codecs
import json
import time
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
@@ -36,6 +38,8 @@ class StreamProcessor:
request_id: str,
default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
):
"""
初始化流处理器
@@ -48,6 +52,7 @@ class StreamProcessor:
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
@@ -112,9 +117,10 @@ class StreamProcessor:
)
# 提取文本
text = parser.extract_text_content(data)
if text:
ctx.collected_text += text
if self.collect_text:
text = parser.extract_text_content(data)
if text:
ctx.append_text(text)
# 检查完成
event_type = event_name or data.get("type", "")
@@ -123,7 +129,7 @@ class StreamProcessor:
async def prefetch_and_check_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -136,97 +142,126 @@ class StreamProcessor:
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
Returns:
预读的列表
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
if line_count >= max_prefetch_lines:
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = line
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
raise
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def create_response_stream(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None,
prefetched_chunks: Optional[list] = None,
*,
start_time: Optional[float] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
统一的流生成器,支持预读数据和不带预读数据两种情况
从字节流中解析 SSE 数据并转发,支持预读数据。
Args:
ctx: 流式上下文
line_iterator: 迭代器
byte_iterator: 字节流迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_lines: 预读的列表(可选)
prefetched_chunks: 预读的字节块列表(可选)
start_time: 请求开始时间,用于计算 TTFB可选
Yields:
编码后的响应数据块
@@ -234,25 +269,82 @@ class StreamProcessor:
try:
sse_parser = SSEEventParser()
streaming_started = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 处理预读数据
if prefetched_lines:
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for line in prefetched_lines:
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 把原始数据转发给客户端
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的流数据
async for line in line_iterator:
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 原始数据透传
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的缓冲区数据(如果有未完成的行)
if buffer:
try:
# 使用 final=True 处理最后的不完整字符
line = decoder.decode(buffer, True)
self._process_line(ctx, sse_parser, line)
except Exception as e:
logger.warning(
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
f"bytes={buffer[:50]!r}"
)
# 处理剩余事件
for event in sse_parser.flush():
@@ -268,7 +360,7 @@ class StreamProcessor:
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> list[bytes]:
) -> None:
"""
处理单行数据
@@ -276,26 +368,17 @@ class StreamProcessor:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
Returns:
要发送的数据块列表
"""
result: list[bytes] = []
normalized_line = line.rstrip("\r")
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
if normalized_line != "":
ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
async def create_monitored_stream(
self,
@@ -317,16 +400,26 @@ class StreamProcessor:
响应数据块
"""
try:
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
next_disconnect_check_at = 0.0
disconnect_check_interval_s = 0.25
async for chunk in stream_generator:
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
now = time.monotonic()
if now >= next_disconnect_check_at:
next_disconnect_check_at = now + disconnect_check_interval_s
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500

View File

@@ -8,6 +8,7 @@
"""
import asyncio
import time
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
response_time_ms: int,
start_time: float,
) -> None:
"""
记录流式统计信息
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
ctx: 流式上下文
original_headers: 原始请求头
original_request_body: 原始请求体
response_time_ms: 响应时间(毫秒)
start_time: 请求开始时间 (time.time())
"""
bg_db = None
try:
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
# 注意不要把统计延迟stream_stats_delay算进响应时间里
response_time_ms = int((time.time() - start_time) * 1000)
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
if not ctx.provider_name:
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,

View File

@@ -0,0 +1,55 @@
"""
Handler 基础工具函数
"""
from typing import Any, Dict, Optional
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
"""
提取缓存创建 tokens兼容新旧格式
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens
- 新格式2024年后使用 claude_cache_creation_5_m_tokens 和
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
此函数自动检测并适配两种格式,优先使用新格式。
Args:
usage: API 响应中的 usage 字典
Returns:
缓存创建 tokens 总数
"""
# 检查新格式字段是否存在(而非值是否为 0
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
has_new_format = (
"claude_cache_creation_5_m_tokens" in usage
or "claude_cache_creation_1_h_tokens" in usage
)
if has_new_format:
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
return int(cache_5m) + int(cache_1h)
# 回退到旧格式
return int(usage.get("cache_creation_input_tokens", 0))
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
构建 SSEtext/event-stream推荐响应头用于减少代理缓冲带来的卡顿/成段输出。
说明:
- Cache-Control: no-transform 可避免部分代理对流做压缩/改写导致缓冲
- X-Accel-Buffering: no 可显式提示 Nginx 关闭缓冲(即使全局已关闭也无害)
"""
headers: Dict[str, str] = {
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
}
if extra_headers:
headers.update(extra_headers)
return headers

View File

@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeChatHandler(ChatHandlerBase):
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
result["model"] = mapped_model
return result
async def _convert_request(self, request):
async def _convert_request(self, request: Any) -> Any:
"""
将请求转换为 Claude 格式
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
Claude 格式使用:
- input_tokens / output_tokens
- cache_creation_input_tokens / cache_read_input_tokens
- 新格式claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
"""
usage = response.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
# 处理新的 cache_creation 格式
if "cache_creation" in usage:
cache_creation_data = usage.get("cache_creation", {})
if not cache_creation_input_tokens:
cache_creation_input_tokens = cache_creation_data.get(
"ephemeral_5m_input_tokens", 0
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
}
def _normalize_response(self, response: Dict) -> Dict:
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化 Claude 响应
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_claude_response(
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return result
return response

View File

@@ -9,6 +9,8 @@ from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeStreamParser:
"""
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}

View File

@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeCliMessageHandler(CliMessageHandlerBase):
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
usage = message.get("usage", {})
if usage:
ctx.input_tokens = usage.get("input_tokens", 0)
# Claude 的缓存 tokens 使用不同的字段名
cache_read = usage.get("cache_read_input_tokens", 0)
if cache_read:
ctx.cached_tokens = cache_read
cache_creation = usage.get("cache_creation_input_tokens", 0)
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation:
ctx.cache_creation_tokens = cache_creation
@@ -109,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 处理消息增量(包含最终 usage
elif event_type == "message_delta":
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
ctx.input_tokens = usage["input_tokens"]
if "output_tokens" in usage:
ctx.output_tokens = usage["output_tokens"]
# 更新缓存 tokens
# 更新缓存读取 tokens
if "cache_read_input_tokens" in usage:
ctx.cached_tokens = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
# 更新缓存创建 tokens
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation > 0:
ctx.cache_creation_tokens = cache_creation
# 检查是否结束
delta = data.get("delta", {})

View File

@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
parts = content.get("parts", [])
for part in parts:
if "text" in part:
ctx.collected_text += part["text"]
ctx.append_text(part["text"])
# 检查结束原因
finish_reason = candidate.get("finishReason")

View File

@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
delta = data.get("delta")
if isinstance(delta, str):
ctx.collected_text += delta
ctx.append_text(delta)
elif isinstance(delta, dict) and "text" in delta:
ctx.collected_text += delta["text"]
ctx.append_text(delta["text"])
# 处理完成事件
elif event_type == "response.completed":
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if content_item.get("type") == "output_text":
text = content_item.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 备用:从顶层 usage 提取
usage_obj = data.get("usage")

View File

@@ -210,9 +210,9 @@ class PublicModelsAdapter(PublicApiAdapter):
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
description=global_model.config.get("description") if global_model and global_model.config else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
@@ -274,7 +274,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
Model.provider_model_name.ilike(f"%{self.query}%")
| GlobalModel.name.ilike(f"%{self.query}%")
| GlobalModel.display_name.ilike(f"%{self.query}%")
| GlobalModel.description.ilike(f"%{self.query}%")
)
query_stmt = query_stmt.filter(search_filter)
if self.provider_id is not None:
@@ -293,9 +292,9 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.description if global_model else None,
description=global_model.config.get("description") if global_model and global_model.config else None,
tags=None,
icon_url=global_model.icon_url if global_model else None,
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
@@ -499,7 +498,6 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
or_(
GlobalModel.name.ilike(search_term),
GlobalModel.display_name.ilike(search_term),
GlobalModel.description.ilike(search_term),
)
)
@@ -517,21 +515,11 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
id=gm.id,
name=gm.name,
display_name=gm.display_name,
description=gm.description,
icon_url=gm.icon_url,
is_active=gm.is_active,
default_price_per_request=gm.default_price_per_request,
default_tiered_pricing=gm.default_tiered_pricing,
default_supports_vision=gm.default_supports_vision or False,
default_supports_function_calling=gm.default_supports_function_calling or False,
default_supports_streaming=(
gm.default_supports_streaming
if gm.default_supports_streaming is not None
else True
),
default_supports_extended_thinking=gm.default_supports_extended_thinking
or False,
supported_capabilities=gm.supported_capabilities,
config=gm.config,
)
)

View File

@@ -251,8 +251,8 @@ def _build_gemini_list_response(
"version": "001",
"displayName": m.display_name,
"description": m.description or f"Model {m.id}",
"inputTokenLimit": 128000,
"outputTokenLimit": 8192,
"inputTokenLimit": m.context_limit if m.context_limit is not None else 128000,
"outputTokenLimit": m.output_limit if m.output_limit is not None else 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"],
"temperature": 1.0,
"maxTemperature": 2.0,
@@ -297,8 +297,8 @@ def _build_gemini_model_response(model_info: ModelInfo) -> dict:
"version": "001",
"displayName": model_info.display_name,
"description": model_info.description or f"Model {model_info.id}",
"inputTokenLimit": 128000,
"outputTokenLimit": 8192,
"inputTokenLimit": model_info.context_limit if model_info.context_limit is not None else 128000,
"outputTokenLimit": model_info.output_limit if model_info.output_limit is not None else 8192,
"supportedGenerationMethods": ["generateContent", "countTokens"],
"temperature": 1.0,
"maxTemperature": 2.0,

View File

@@ -273,16 +273,17 @@ def get_db_url() -> str:
def init_db():
"""初始化数据库"""
"""初始化数据库
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
"""
logger.info("初始化数据库...")
# 确保引擎已创建
engine = _ensure_engine()
_ensure_engine()
# 创建所有表
Base.metadata.create_all(bind=engine)
# 数据库表已通过SQLAlchemy自动创建
# 数据库表结构由 Alembic 迁移管理
# 首次部署或更新后请运行: ./migrate.sh
db = _SessionLocal()
try:

View File

@@ -562,20 +562,15 @@ class PublicGlobalModelResponse(BaseModel):
id: str
name: str
display_name: Optional[str] = None
description: Optional[str] = None
icon_url: Optional[str] = None
is_active: bool = True
# 按次计费配置
default_price_per_request: Optional[float] = None
# 阶梯计费配置
default_tiered_pricing: Optional[dict] = None
# 默认能力
default_supports_vision: bool = False
default_supports_function_calling: bool = False
default_supports_streaming: bool = True
default_supports_extended_thinking: bool = False
# Key 能力配置
supported_capabilities: Optional[List[str]] = None
# 模型配置JSON
config: Optional[dict] = None
class PublicGlobalModelListResponse(BaseModel):

View File

@@ -26,6 +26,7 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
@@ -307,7 +308,8 @@ class Usage(Base):
is_stream = Column(Boolean, default=False) # 是否为流式请求
status_code = Column(Integer)
error_message = Column(Text, nullable=True)
response_time_ms = Column(Integer) # 响应时间(毫秒)
response_time_ms = Column(Integer) # 响应时间(毫秒)
first_byte_time_ms = Column(Integer, nullable=True) # 首字时间/TTFB毫秒
# 请求状态追踪
# pending: 请求开始处理中
@@ -575,11 +577,6 @@ class GlobalModel(Base):
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
name = Column(String(100), unique=True, nullable=False, index=True) # 统一模型名(唯一)
display_name = Column(String(100), nullable=False)
description = Column(Text, nullable=True)
# 模型元数据
icon_url = Column(String(500), nullable=True)
official_url = Column(String(500), nullable=True) # 官方文档链接
# 按次计费配置(每次请求的固定费用,美元)- 可选,与按 token 计费叠加
default_price_per_request = Column(Float, nullable=True, default=None) # 每次请求固定费用
@@ -605,17 +602,34 @@ class GlobalModel(Base):
# }
default_tiered_pricing = Column(JSON, nullable=False)
# 默认能力配置 - Provider 可覆盖
default_supports_vision = Column(Boolean, default=False, nullable=True)
default_supports_function_calling = Column(Boolean, default=False, nullable=True)
default_supports_streaming = Column(Boolean, default=True, nullable=True)
default_supports_extended_thinking = Column(Boolean, default=False, nullable=True)
default_supports_image_generation = Column(Boolean, default=False, nullable=True)
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"]
# Key 只能启用模型支持的能力
supported_capabilities = Column(JSON, nullable=True, default=list)
# 模型配置JSON格式- 包含能力、规格、元信息等
# 结构示例:
# {
# # 能力配置
# "streaming": true,
# "vision": true,
# "function_calling": true,
# "extended_thinking": false,
# "image_generation": false,
# # 规格参数
# "context_limit": 200000,
# "output_limit": 8192,
# # 元信息
# "description": "...",
# "icon_url": "...",
# "official_url": "...",
# "knowledge_cutoff": "2024-04",
# "family": "claude-3.5",
# "release_date": "2024-10-22",
# "input_modalities": ["text", "image"],
# "output_modalities": ["text"],
# }
config = Column(JSONB, nullable=True, default=dict)
# 状态
is_active = Column(Boolean, default=True, nullable=False)
@@ -766,11 +780,22 @@ class Model(Base):
"""获取有效的能力配置(通用辅助方法)"""
local_value = getattr(self, attr_name, None)
if local_value is not None:
return local_value
return bool(local_value)
if self.global_model:
global_value = getattr(self.global_model, f"default_{attr_name}", None)
if global_value is not None:
return global_value
config_key_map = {
"supports_vision": "vision",
"supports_function_calling": "function_calling",
"supports_streaming": "streaming",
"supports_extended_thinking": "extended_thinking",
"supports_image_generation": "image_generation",
}
config_key = config_key_map.get(attr_name)
if config_key:
global_config = getattr(self.global_model, "config", None)
if isinstance(global_config, dict):
global_value = global_config.get(config_key)
if global_value is not None:
return bool(global_value)
return default
def get_effective_supports_vision(self) -> bool:

View File

@@ -187,9 +187,6 @@ class GlobalModelCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100, description="统一模型名(唯一)")
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, description="模型描述")
official_url: Optional[str] = Field(None, max_length=500, description="官方文档链接")
icon_url: Optional[str] = Field(None, max_length=500, description="图标 URL")
# 按次计费配置(可选,与阶梯计费叠加)
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
# 统一阶梯计费配置(必填)
@@ -197,22 +194,15 @@ class GlobalModelCreate(BaseModel):
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置(固定价格用单阶梯表示)"
)
# 默认能力配置
default_supports_vision: Optional[bool] = Field(False, description="默认是否支持视觉")
default_supports_function_calling: Optional[bool] = Field(
False, description="默认是否支持函数调用"
)
default_supports_streaming: Optional[bool] = Field(True, description="默认是否支持流式输出")
default_supports_extended_thinking: Optional[bool] = Field(
False, description="默认是否支持扩展思考"
)
default_supports_image_generation: Optional[bool] = Field(
False, description="默认是否支持图像生成"
)
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"]
supported_capabilities: Optional[List[str]] = Field(
None, description="支持的 Key 能力列表"
)
# 模型配置JSON格式- 包含能力、规格、元信息等
config: Optional[Dict[str, Any]] = Field(
None,
description="模型配置streaming, vision, context_limit, description 等)"
)
is_active: Optional[bool] = Field(True, description="是否激活")
@@ -220,9 +210,6 @@ class GlobalModelUpdate(BaseModel):
"""更新 GlobalModel 请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
official_url: Optional[str] = Field(None, max_length=500)
icon_url: Optional[str] = Field(None, max_length=500)
is_active: Optional[bool] = None
# 按次计费配置
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
@@ -230,16 +217,15 @@ class GlobalModelUpdate(BaseModel):
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
None, description="阶梯计费配置"
)
# 默认能力配置
default_supports_vision: Optional[bool] = None
default_supports_function_calling: Optional[bool] = None
default_supports_streaming: Optional[bool] = None
default_supports_extended_thinking: Optional[bool] = None
default_supports_image_generation: Optional[bool] = None
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"]
supported_capabilities: Optional[List[str]] = Field(
None, description="支持的 Key 能力列表"
)
# 模型配置JSON格式- 包含能力、规格、元信息等
config: Optional[Dict[str, Any]] = Field(
None,
description="模型配置streaming, vision, context_limit, description 等)"
)
class GlobalModelResponse(BaseModel):
@@ -248,9 +234,6 @@ class GlobalModelResponse(BaseModel):
id: str
name: str
display_name: str
description: Optional[str]
official_url: Optional[str]
icon_url: Optional[str]
is_active: bool
# 按次计费配置
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
@@ -258,16 +241,15 @@ class GlobalModelResponse(BaseModel):
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置"
)
# 默认能力配置
default_supports_vision: Optional[bool]
default_supports_function_calling: Optional[bool]
default_supports_streaming: Optional[bool]
default_supports_extended_thinking: Optional[bool]
default_supports_image_generation: Optional[bool]
# Key 能力配置 - 模型支持的能力列表
supported_capabilities: Optional[List[str]] = Field(
default=None, description="支持的 Key 能力列表"
)
# 模型配置JSON格式
config: Optional[Dict[str, Any]] = Field(
default=None,
description="模型配置streaming, vision, context_limit, description 等)"
)
# 统计数据(可选)
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
usage_count: Optional[int] = Field(default=0, description="调用次数")

View File

@@ -385,7 +385,7 @@ class ModelCacheService:
"is_active": model.is_active,
"is_available": model.is_available if hasattr(model, "is_available") else True,
"price_per_request": (
float(model.price_per_request) if model.price_per_request else None
float(model.price_per_request) if model.price_per_request is not None else None
),
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
@@ -425,14 +425,15 @@ class ModelCacheService:
"id": global_model.id,
"name": global_model.name,
"display_name": global_model.display_name,
"default_supports_vision": global_model.default_supports_vision,
"default_supports_function_calling": global_model.default_supports_function_calling,
"default_supports_streaming": global_model.default_supports_streaming,
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
"default_supports_image_generation": global_model.default_supports_image_generation,
"supported_capabilities": global_model.supported_capabilities,
"config": global_model.config,
"default_tiered_pricing": global_model.default_tiered_pricing,
"default_price_per_request": (
float(global_model.default_price_per_request)
if global_model.default_price_per_request is not None
else None
),
"is_active": global_model.is_active,
"description": global_model.description,
}
@staticmethod
@@ -442,19 +443,10 @@ class ModelCacheService:
id=global_model_dict["id"],
name=global_model_dict["name"],
display_name=global_model_dict.get("display_name"),
default_supports_vision=global_model_dict.get("default_supports_vision", False),
default_supports_function_calling=global_model_dict.get(
"default_supports_function_calling", False
),
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
default_supports_extended_thinking=global_model_dict.get(
"default_supports_extended_thinking", False
),
default_supports_image_generation=global_model_dict.get(
"default_supports_image_generation", False
),
supported_capabilities=global_model_dict.get("supported_capabilities") or [],
config=global_model_dict.get("config"),
default_tiered_pricing=global_model_dict.get("default_tiered_pricing"),
default_price_per_request=global_model_dict.get("default_price_per_request"),
is_active=global_model_dict.get("is_active", True),
description=global_model_dict.get("description"),
)
return global_model

View File

@@ -62,7 +62,6 @@ class GlobalModelService:
query = query.filter(
(GlobalModel.name.ilike(search_pattern))
| (GlobalModel.display_name.ilike(search_pattern))
| (GlobalModel.description.ilike(search_pattern))
)
# 按名称排序
@@ -75,21 +74,15 @@ class GlobalModelService:
db: Session,
name: str,
display_name: str,
description: Optional[str] = None,
official_url: Optional[str] = None,
icon_url: Optional[str] = None,
is_active: Optional[bool] = True,
# 按次计费配置
default_price_per_request: Optional[float] = None,
# 阶梯计费配置(必填)
default_tiered_pricing: dict = None,
# 默认能力配置
default_supports_vision: Optional[bool] = None,
default_supports_function_calling: Optional[bool] = None,
default_supports_streaming: Optional[bool] = None,
default_supports_extended_thinking: Optional[bool] = None,
# Key 能力配置
supported_capabilities: Optional[List[str]] = None,
# 模型配置JSON
config: Optional[dict] = None,
) -> GlobalModel:
"""创建 GlobalModel"""
# 检查名称是否已存在
@@ -100,21 +93,15 @@ class GlobalModelService:
global_model = GlobalModel(
name=name,
display_name=display_name,
description=description,
official_url=official_url,
icon_url=icon_url,
is_active=is_active,
# 按次计费配置
default_price_per_request=default_price_per_request,
# 阶梯计费配置
default_tiered_pricing=default_tiered_pricing,
# 默认能力配置
default_supports_vision=default_supports_vision,
default_supports_function_calling=default_supports_function_calling,
default_supports_streaming=default_supports_streaming,
default_supports_extended_thinking=default_supports_extended_thinking,
# Key 能力配置
supported_capabilities=supported_capabilities,
# 模型配置JSON
config=config,
)
db.add(global_model)

View File

@@ -157,6 +157,7 @@ class UsageService:
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
@@ -368,6 +369,7 @@ class UsageService:
status_code=status_code,
error_message=error_message,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪
request_metadata=metadata,
request_headers=processed_request_headers,
@@ -419,6 +421,7 @@ class UsageService:
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
@@ -629,6 +632,7 @@ class UsageService:
status_code=status_code,
error_message=error_message,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪
request_metadata=metadata,
request_headers=processed_request_headers,
@@ -649,6 +653,7 @@ class UsageService:
existing_usage.status_code = status_code
existing_usage.error_message = error_message
existing_usage.response_time_ms = response_time_ms
existing_usage.first_byte_time_ms = first_byte_time_ms # 更新首字时间
# 更新请求头和请求体(如果有新值)
if processed_request_headers is not None:
existing_usage.request_headers = processed_request_headers
@@ -1315,11 +1320,11 @@ class UsageService:
default_timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
"""
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending 请求
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
与 get_active_requests 不同,此方法:
1. 返回轻量级的状态字典而非完整 Usage 对象
2. 自动检测并清理超时的 pending 请求
2. 自动检测并清理超时的 pending/streaming 请求
3. 支持按 ID 列表查询特定请求
Args:
@@ -1343,6 +1348,7 @@ class UsageService:
Usage.output_tokens,
Usage.total_cost_usd,
Usage.response_time_ms,
Usage.first_byte_time_ms, # 首字时间 (TTFB)
Usage.created_at,
Usage.provider_endpoint_id,
ProviderEndpoint.timeout.label("endpoint_timeout"),
@@ -1361,10 +1367,10 @@ class UsageService:
records = query.all()
# 检查超时的 pending 请求
# 检查超时的 pending/streaming 请求
timeout_ids = []
for r in records:
if r.status == "pending" and r.created_at:
if r.status in ("pending", "streaming") and r.created_at:
# 使用端点配置的超时时间,若无则使用默认值
timeout_seconds = r.endpoint_timeout or default_timeout_seconds
@@ -1392,6 +1398,7 @@ class UsageService:
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
}
for r in records
]

View File

@@ -19,7 +19,7 @@ from ..models.database import User, UserRole
security = HTTPBearer()
def get_current_user(
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db)
) -> User:
"""
@@ -41,7 +41,7 @@ def get_current_user(
try:
# 验证Token格式和签名
try:
payload = AuthService.verify_token(token)
payload = await AuthService.verify_token(token)
except HTTPException as token_error:
# 保持原始的HTTP状态码如401 Unauthorized不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
@@ -122,7 +122,7 @@ def get_current_user(
)
def get_current_user_from_header(
async def get_current_user_from_header(
authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
) -> User:
"""
@@ -144,7 +144,7 @@ def get_current_user_from_header(
token = authorization.replace("Bearer ", "")
try:
payload = AuthService.verify_token(token)
payload = await AuthService.verify_token(token)
user_id = payload.get("user_id")
if not user_id:

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""测试模块"""

View File

@@ -0,0 +1,117 @@
from src.api.handlers.base import stream_context
from src.api.handlers.base.stream_context import StreamContext
def test_collected_text_append_and_property() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
assert ctx.collected_text == ""
ctx.append_text("hello")
ctx.append_text(" ")
ctx.append_text("world")
assert ctx.collected_text == "hello world"
def test_reset_for_retry_clears_state() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
ctx.append_text("x")
ctx.update_usage(input_tokens=10, output_tokens=5)
ctx.parsed_chunks.append({"type": "chunk"})
ctx.chunk_count = 3
ctx.data_count = 2
ctx.has_completion = True
ctx.status_code = 418
ctx.error_message = "boom"
ctx.reset_for_retry()
assert ctx.collected_text == ""
assert ctx.input_tokens == 0
assert ctx.output_tokens == 0
assert ctx.parsed_chunks == []
assert ctx.chunk_count == 0
assert ctx.data_count == 0
assert ctx.has_completion is False
assert ctx.status_code == 200
assert ctx.error_message is None
def test_record_first_byte_time(monkeypatch) -> None:
"""测试记录首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123) # 12.3ms
# 记录首字时间
ctx.record_first_byte_time(start_time)
# 验证首字时间已记录
assert ctx.first_byte_time_ms == 12
def test_record_first_byte_time_idempotent(monkeypatch) -> None:
"""测试首字时间只记录一次"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 第一次记录
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
first_value = ctx.first_byte_time_ms
# 第二次记录(应该被忽略)
monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
ctx.record_first_byte_time(start_time)
second_value = ctx.first_byte_time_ms
# 验证值没有改变
assert first_value == second_value
def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
"""测试重试时清除首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 记录首字时间
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
assert ctx.first_byte_time_ms is not None
# 重置
ctx.reset_for_retry()
# 验证首字时间已清除
assert ctx.first_byte_time_ms is None
def test_get_log_summary_with_first_byte_time() -> None:
"""测试日志摘要包含首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
ctx.first_byte_time_ms = 123
summary = ctx.get_log_summary("request-id-123", 456)
# 验证包含首字时间和总时间(大写格式)
assert "TTFB: 123ms" in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary
def test_get_log_summary_without_first_byte_time() -> None:
"""测试日志摘要在没有首字时间时的格式"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
# first_byte_time_ms 保持为 None
summary = ctx.get_log_summary("request-id-123", 456)
# 验证不包含首字时间标记,但有总时间(使用大写 TTFB 和 Total
assert "TTFB:" not in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary

View File

@@ -0,0 +1,32 @@
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import ParsedChunk, ParsedResponse, ResponseParser, StreamStats
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.utils.sse_parser import SSEEventParser
class DummyParser(ResponseParser):
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
return None
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
return ParsedResponse(raw_response=response, status_code=status_code)
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
return {}
def extract_text_content(self, response: Dict[str, Any]) -> str:
return ""
def test_process_line_strips_newlines_and_finalizes_event() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
processor = StreamProcessor(request_id="test-request", default_parser=DummyParser())
sse_parser = SSEEventParser()
processor._process_line(ctx, sse_parser, 'data: {"type":"response.completed"}\n')
processor._process_line(ctx, sse_parser, "\n")
assert ctx.has_completion is True

View File

@@ -0,0 +1,104 @@
"""测试 handler 基础工具函数"""
import pytest
from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens
class TestExtractCacheCreationTokens:
"""测试 extract_cache_creation_tokens 函数"""
def test_new_format_only(self) -> None:
"""测试只有新格式字段"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
}
assert extract_cache_creation_tokens(usage) == 300
def test_new_format_5m_only(self) -> None:
"""测试只有 5 分钟缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 150,
"claude_cache_creation_1_h_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 150
def test_new_format_1h_only(self) -> None:
"""测试只有 1 小时缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 250,
}
assert extract_cache_creation_tokens(usage) == 250
def test_old_format_only(self) -> None:
"""测试只有旧格式字段"""
usage = {
"cache_creation_input_tokens": 500,
}
assert extract_cache_creation_tokens(usage) == 500
def test_both_formats_prefers_new(self) -> None:
"""测试同时存在时优先使用新格式"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
"cache_creation_input_tokens": 999, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
def test_empty_usage(self) -> None:
"""测试空字典"""
usage = {}
assert extract_cache_creation_tokens(usage) == 0
def test_all_zeros(self) -> None:
"""测试所有字段都为 0"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 0
def test_partial_new_format_with_old_format_fallback(self) -> None:
"""测试新格式字段不存在时回退到旧格式"""
usage = {
"cache_creation_input_tokens": 123,
}
assert extract_cache_creation_tokens(usage) == 123
def test_new_format_zero_should_not_fallback(self) -> None:
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 456,
}
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0
# 而不是 fallback 到旧格式(返回 456
assert extract_cache_creation_tokens(usage) == 0
def test_unrelated_fields_ignored(self) -> None:
"""测试忽略无关字段"""
usage = {
"input_tokens": 1000,
"output_tokens": 2000,
"cache_read_input_tokens": 300,
"claude_cache_creation_5_m_tokens": 50,
"claude_cache_creation_1_h_tokens": 75,
}
assert extract_cache_creation_tokens(usage) == 125
class TestBuildSSEHeaders:
def test_default_headers(self) -> None:
headers = build_sse_headers()
assert headers["Cache-Control"] == "no-cache, no-transform"
assert headers["X-Accel-Buffering"] == "no"
def test_merge_extra_headers(self) -> None:
headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
assert headers["X-Test"] == "1"
assert headers["Cache-Control"] == "custom"