11 Commits

Author SHA1 Message Date
fawney19
c42ebdd0ee test(handler): add comprehensive stream processor unit tests 2025-12-16 02:40:26 +08:00
fawney19
f1e3c2ab11 feat(frontend-usage): enhance usage UI with first byte latency metrics
- Update usage records table to display first_byte_time_ms metrics
- Improve request timeline visualization for latency tracking
- Extend usage types for new timing information
2025-12-16 02:39:54 +08:00
fawney19
4e2ba0e57f feat(usage): add first_byte_time_ms tracking to usage statistics
- Enhance usage service to capture and store first byte latency metrics
- Update usage API routes to include new timing information
2025-12-16 02:39:36 +08:00
fawney19
a3df41d63d refactor(cli-handler): improve stream handling and response processing
- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
2025-12-16 02:39:20 +08:00
fawney19
ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00
fawney19
9b496abb73 feat(db): add first_byte_time_ms column to usage table 2025-12-16 02:38:43 +08:00
fawney19
f3a69a6160 refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
2025-12-16 00:02:49 +08:00
fawney19
adcdb73d29 feat(frontend): enhance cache monitoring UI and API integration 2025-12-15 23:12:58 +08:00
fawney19
cf67160821 feat(cache): enhance cache monitoring endpoints and handler integrations 2025-12-15 23:12:48 +08:00
fawney19
718f56ba75 refactor(cache): optimize cache service architecture and provider transport 2025-12-15 23:12:34 +08:00
fawney19
d87de10f62 refactor(docker): remove exposed ports for postgres and redis services 2025-12-15 23:12:23 +08:00
31 changed files with 1492 additions and 550 deletions

View File

@@ -0,0 +1,28 @@
"""add first_byte_time_ms to usage table
Revision ID: 180e63a9c83a
Revises: e9b3d63f0cbf
Create Date: 2025-12-15 17:07:44.631032+00:00
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '180e63a9c83a'
down_revision = 'e9b3d63f0cbf'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""应用迁移:升级到新版本"""
# 添加首字时间字段到 usage 表
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 删除首字时间字段
op.drop_column('usage', 'first_byte_time_ms')

View File

@@ -12,8 +12,6 @@ services:
TZ: Asia/Shanghai
volumes:
- postgres_data:/var/lib/postgresql/data
ports:
- "${DB_PORT:-5432}:5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
@@ -27,8 +25,6 @@ services:
command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
volumes:
- redis_data:/data
ports:
- "${REDIS_PORT:-6379}:6379"
healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 5s

View File

@@ -290,6 +290,19 @@ export interface UnmappedEntry {
ttl: number | null
}
// Provider 模型映射缓存Redis 缓存)
export interface ProviderModelMapping {
provider_id: string
provider_name: string
global_model_id: string
global_model_name: string
global_model_display_name: string | null
provider_model_name: string
aliases: string[] | null
ttl: number | null
hit_count: number
}
export interface ModelMappingCacheStats {
available: boolean
message?: string
@@ -303,6 +316,7 @@ export interface ModelMappingCacheStats {
global_model_resolve: number
}
mappings?: ModelMappingItem[]
provider_model_mappings?: ProviderModelMapping[] | null
unmapped?: UnmappedEntry[] | null
}
@@ -337,5 +351,13 @@ export const modelMappingCacheApi = {
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
return response.data
},
/**
* 清除指定 Provider 和 GlobalModel 的映射缓存
*/
async clearProviderModel(providerId: string, globalModelId: string): Promise<ClearModelMappingCacheResponse> {
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/provider/${providerId}/${globalModelId}`)
return response.data
}
}

View File

@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
return groups
})
// 计算链路总耗时(从第一个节点开始到最后一个节点结束
// 计算链路总耗时(使用成功候选的 latency_ms 字段
// 优先使用 latency_ms因为它与 Usage.response_time_ms 使用相同的时间基准
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
const totalTraceLatency = computed(() => {
if (!timeline.value || timeline.value.length === 0) return 0
// 查找成功的候选,使用其 latency_ms
const successCandidate = timeline.value.find(c => c.status === 'success')
if (successCandidate?.latency_ms != null) {
return successCandidate.latency_ms
}
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
if (failedWithLatency?.latency_ms != null) {
return failedWithLatency.latency_ms
}
// 回退:使用 finished_at - started_at 计算
let earliestStart: number | null = null
let latestEnd: number | null = null

View File

@@ -177,8 +177,9 @@
费用
</TableHead>
<TableHead class="h-12 font-semibold w-[70px] text-right">
<div class="inline-block max-w-[2rem] leading-tight">
响应时间
<div class="flex flex-col items-end text-xs gap-0.5">
<span>首字</span>
<span class="text-muted-foreground font-normal">总耗时</span>
</div>
</TableHead>
</TableRow>
@@ -356,15 +357,28 @@
</div>
</TableCell>
<TableCell class="text-right py-4 w-[70px]">
<span
<div
v-if="record.status === 'pending' || record.status === 'streaming'"
class="text-primary tabular-nums"
class="flex flex-col items-end text-xs gap-0.5"
>
{{ getElapsedTime(record) }}
</span>
<span v-else-if="record.response_time_ms">
{{ (record.response_time_ms / 1000).toFixed(2) }}s
</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
</div>
<span
v-else
class="text-muted-foreground"

View File

@@ -78,6 +78,7 @@ export interface UsageRecord {
cost: number
actual_cost?: number
response_time_ms?: number
first_byte_time_ms?: number // 首字时间 (TTFB)
is_stream: boolean
status_code?: number
error_message?: string

View File

@@ -299,6 +299,26 @@ async function clearModelMappingByName(modelName: string) {
}
}
async function clearProviderModelMapping(providerId: string, globalModelId: string, displayName?: string) {
const confirmed = await showConfirm({
title: '确认清除',
message: `确定要清除 ${displayName || 'Provider 模型映射'} 的缓存吗?`,
confirmText: '确认清除',
variant: 'destructive'
})
if (!confirmed) return
try {
await modelMappingCacheApi.clearProviderModel(providerId, globalModelId)
showSuccess('已清除 Provider 模型映射缓存')
await fetchModelMappingStats()
} catch (error) {
showError('清除缓存失败')
log.error('清除 Provider 模型映射缓存失败', error)
}
}
function formatTTL(ttl: number | null): string {
if (ttl === null || ttl < 0) return '-'
if (ttl < 60) return `${ttl}s`
@@ -872,9 +892,125 @@ onBeforeUnmount(() => {
</div>
</div>
<!-- Provider 模型映射缓存 -->
<div
v-if="modelMappingStats?.available && modelMappingStats.provider_model_mappings && modelMappingStats.provider_model_mappings.length > 0"
class="border-t border-border/40"
>
<div class="px-6 py-3 text-xs text-muted-foreground border-b border-border/30 bg-muted/20">
Provider 模型映射缓存
</div>
<!-- 桌面端表格 -->
<Table class="hidden md:table">
<TableHeader>
<TableRow>
<TableHead class="w-[15%]">
提供商
</TableHead>
<TableHead class="w-[25%]">
请求名称
</TableHead>
<TableHead class="w-8 text-center" />
<TableHead class="w-[25%]">
映射模型
</TableHead>
<TableHead class="w-[10%] text-center">
剩余
</TableHead>
<TableHead class="w-[10%] text-center">
次数
</TableHead>
<TableHead class="w-[7%] text-right">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="index"
>
<TableRow
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`${index}-${aliasIndex}`"
>
<TableCell>
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
</TableCell>
<TableCell>
<span class="text-sm font-mono">{{ alias }}</span>
</TableCell>
<TableCell class="text-center">
<ArrowRight class="h-4 w-4 text-muted-foreground" />
</TableCell>
<TableCell>
<span class="text-sm font-mono font-medium">{{ mapping.provider_model_name }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-sm">{{ mapping.hit_count || 0 }}</span>
</TableCell>
<TableCell class="text-right">
<Button
size="icon"
variant="ghost"
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
</TableCell>
</TableRow>
</template>
</TableBody>
</Table>
<!-- 移动端卡片 -->
<div class="md:hidden divide-y divide-border/40">
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="`m-pm-${index}`"
>
<div
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`m-pm-${index}-${aliasIndex}`"
class="p-4 space-y-2"
>
<div class="flex items-center justify-between">
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
<div class="flex items-center gap-2">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
<span class="text-xs">{{ mapping.hit_count || 0 }}</span>
<Button
size="icon"
variant="ghost"
class="h-6 w-6 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3 w-3" />
</Button>
</div>
</div>
<div class="flex items-center gap-2 text-sm">
<span class="font-mono">{{ alias }}</span>
<ArrowRight class="h-3.5 w-3.5 shrink-0 text-muted-foreground/60" />
<span class="font-mono font-medium">{{ mapping.provider_model_name }}</span>
</div>
</div>
</template>
</div>
</div>
<!-- 无缓存状态 -->
<div
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0)"
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0) && (!modelMappingStats.provider_model_mappings || modelMappingStats.provider_model_mappings.length === 0)"
class="px-6 py-8 text-center text-sm text-muted-foreground"
>
暂无模型解析缓存

View File

@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.context import ApiRequestContext
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync
@@ -87,19 +88,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
# 2. 尝试作为 Username 查询
user = db.query(User).filter(User.username == identifier).first()
if user:
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id
# 3. 尝试作为 Email 查询
user = db.query(User).filter(User.email == identifier).first()
if user:
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id
# 4. 尝试作为 API Key ID 查询
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
if api_key:
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
return api_key.user_id
# 无法识别
@@ -111,7 +112,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
async def get_cache_stats(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
获取缓存亲和性统计信息
@@ -131,7 +132,7 @@ async def get_user_affinity(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
查询指定用户的所有缓存亲和性
@@ -157,7 +158,7 @@ async def list_affinities(
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db),
):
) -> Any:
"""
获取所有缓存亲和性列表,可选按关键词过滤
@@ -173,7 +174,7 @@ async def clear_user_cache(
user_identifier: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear cache affinity for a specific user
@@ -188,7 +189,7 @@ async def clear_user_cache(
async def clear_all_cache(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear all cache affinities
@@ -203,7 +204,7 @@ async def clear_provider_cache(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
Clear cache affinities for a specific provider
@@ -218,7 +219,7 @@ async def clear_provider_cache(
async def get_cache_config(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
获取缓存相关配置
@@ -234,7 +235,7 @@ async def get_cache_config(
async def get_cache_metrics(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
"""
@@ -246,7 +247,7 @@ async def get_cache_metrics(
class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -266,7 +267,7 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try:
redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client)
@@ -391,7 +392,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
class AdminGetUserAffinityAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
user_id = resolve_user_identifier(db, self.user_identifier)
@@ -472,7 +473,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
limit: int
offset: int
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
redis_client = get_redis_client_sync()
if not redis_client:
@@ -682,7 +683,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
class AdminClearUserCacheAdapter(AdminApiAdapter):
user_identifier: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db
try:
redis_client = get_redis_client_sync()
@@ -786,7 +787,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -806,7 +807,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
class AdminClearProviderCacheAdapter(AdminApiAdapter):
provider_id: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try:
redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client)
@@ -829,7 +830,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
@@ -878,7 +879,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
async def get_model_mapping_cache_stats(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
获取模型映射缓存统计信息
@@ -895,7 +896,7 @@ async def get_model_mapping_cache_stats(
async def clear_all_model_mapping_cache(
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
清除所有模型映射缓存
@@ -910,7 +911,7 @@ async def clear_model_mapping_cache_by_name(
model_name: str,
request: Request,
db: Session = Depends(get_db),
):
) -> Any:
"""
清除指定模型名称的映射缓存
@@ -921,8 +922,28 @@ async def clear_model_mapping_cache_by_name(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
async def clear_provider_model_mapping_cache(
provider_id: str,
global_model_id: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
清除指定 Provider 和 GlobalModel 的模型映射缓存
参数:
- provider_id: Provider ID
- global_model_id: GlobalModel ID
"""
adapter = AdminClearProviderModelMappingCacheAdapter(
provider_id=provider_id, global_model_id=global_model_id
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
import json
from src.clients.redis_client import get_redis_client
@@ -955,7 +976,9 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if key_str.startswith("model:id:"):
model_id_keys.append(key_str)
elif key_str.startswith("model:provider_global:"):
provider_global_keys.append(key_str)
# 过滤掉 hits 统计键,只保留实际的缓存键
if not key_str.startswith("model:provider_global:hits:"):
provider_global_keys.append(key_str)
async for key in redis.scan_iter(match="global_model:*", count=100):
key_str = key.decode() if isinstance(key, bytes) else key
@@ -1067,6 +1090,85 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
# 按 mapping_name 排序
mappings.sort(key=lambda x: x["mapping_name"])
# 3. 解析 provider_global 缓存Provider 级别的模型解析缓存)
provider_model_mappings = []
# 预加载 Provider 和 GlobalModel 数据
provider_map = {str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()}
global_model_map = {str(gm.id): gm for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()}
for key in provider_global_keys[:100]: # 最多处理 100 个
# key 格式: model:provider_global:{provider_id}:{global_model_id}
try:
parts = key.replace("model:provider_global:", "").split(":")
if len(parts) != 2:
continue
provider_id, global_model_id = parts
cached_value = await redis.get(key)
ttl = await redis.ttl(key)
# 获取命中次数
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
hit_count_raw = await redis.get(hit_count_key)
hit_count = int(hit_count_raw) if hit_count_raw else 0
if cached_value:
cached_str = (
cached_value.decode()
if isinstance(cached_value, bytes)
else cached_value
)
try:
cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name")
provider_model_aliases = cached_data.get("provider_model_aliases", [])
# 获取 Provider 和 GlobalModel 信息
provider = provider_map.get(provider_id)
global_model = global_model_map.get(global_model_id)
if provider and global_model:
# 提取别名名称
alias_names = []
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict) and alias_entry.get("name"):
alias_names.append(alias_entry["name"])
# provider_model_name 为空时跳过
if not provider_model_name:
continue
# 只显示有实际映射的条目:
# 1. 全局模型名 != Provider 模型名(模型名称映射)
# 2. 或者有别名配置
has_name_mapping = global_model.name != provider_model_name
has_aliases = len(alias_names) > 0
if has_name_mapping or has_aliases:
# 构建用于展示的别名列表
# 如果只有名称映射没有别名,则用 global_model_name 作为"请求名称"
display_aliases = alias_names if alias_names else [global_model.name]
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,
"provider_model_name": provider_model_name,
"aliases": display_aliases,
"ttl": ttl if ttl > 0 else None,
"hit_count": hit_count,
})
except json.JSONDecodeError:
pass
except Exception as e:
logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")
# 按 provider_name + global_model_name 排序
provider_model_mappings.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
response_data = {
"available": True,
"ttl_seconds": CacheTTL.MODEL,
@@ -1079,6 +1181,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
"global_model_resolve": len(global_model_resolve_keys),
},
"mappings": mappings,
"provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
"unmapped": unmapped_entries if unmapped_entries else None,
}
@@ -1094,7 +1197,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1136,7 +1239,7 @@ class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
model_name: str
async def handle(self, context): # type: ignore[override]
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
@@ -1176,3 +1279,55 @@ class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
except Exception as exc:
logger.exception(f"清除模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
provider_id: str
global_model_id: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
redis = await get_redis_client(require_redis=False)
if not redis:
raise HTTPException(status_code=503, detail="Redis 未启用")
deleted_keys = []
# 清除 provider_global 缓存
provider_global_key = f"model:provider_global:{self.provider_id}:{self.global_model_id}"
if await redis.exists(provider_global_key):
await redis.delete(provider_global_key)
deleted_keys.append(provider_global_key)
# 清除对应的 hit_count 缓存
hit_count_key = f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}"
if await redis.exists(hit_count_key):
await redis.delete(hit_count_key)
deleted_keys.append(hit_count_key)
logger.info(
f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
)
context.add_audit_metadata(
action="provider_model_mapping_cache_clear",
provider_id=self.provider_id,
global_model_id=self.global_model_id,
deleted_keys=deleted_keys,
)
return {
"status": "ok",
"message": "已清除 Provider 模型映射缓存",
"provider_id": self.provider_id,
"global_model_id": self.global_model_id,
"deleted_keys": deleted_keys,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")

View File

@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"actual_cost": actual_cost,
"rate_multiplier": rate_multiplier,
"response_time_ms": usage.response_time_ms,
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage.created_at.isoformat(),
"is_stream": usage.is_stream,
"input_price_per_1m": usage.input_price_per_1m,
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"status_code": usage_record.status_code,
"error_message": usage_record.error_message,
"response_time_ms": usage_record.response_time_ms,
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
"request_headers": usage_record.request_headers,
"request_body": usage_record.get_request_body(),

View File

@@ -100,6 +100,8 @@ class MessageTelemetry:
cache_read_tokens: int = 0,
is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None,
# 时间指标
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
# Provider 侧追踪信息(用于记录真实成本)
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
@@ -133,6 +135,7 @@ class MessageTelemetry:
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
status_code=status_code,
request_headers=request_headers,
request_body=request_body,
@@ -395,3 +398,24 @@ class BaseMessageHandler:
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
Args:
message: 错误消息前缀
error: 异常对象
"""
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# 业务异常:简洁日志,不打印堆栈
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:
# 未知异常:完整堆栈
logger.exception(f"{message}: {error}")

View File

@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.exceptions import (
EmbeddedErrorException,
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
ctx,
original_headers,
original_request_body,
self.elapsed_ms(),
self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
)
# 创建监控流
@@ -378,11 +379,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
except Exception as e:
logger.exception(f"流式请求失败: {e}")
self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise
@@ -473,12 +475,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
stream_response.raise_for_status()
# 创建行迭代器
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
byte_iterator = stream_response.aiter_raw()
# 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator,
prefetched_chunks = await stream_processor.prefetch_and_check_error(
byte_iterator,
provider,
endpoint,
ctx,
@@ -503,13 +506,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 创建流生成器
# 创建流生成器(传入字节流迭代器)
return stream_processor.create_response_stream(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
start_time=self.start_time,
)
async def _record_stream_failure(

View File

@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
"""
import asyncio
import codecs
import json
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
)
import httpx
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamContext:
"""流式请求的上下文信息"""
# 请求信息
model: str = "unknown" # 用户请求的原始模型名
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
api_format: str = ""
request_id: str = ""
# 用户信息(提前提取避免 Session detached
user_id: int = 0
api_key_id: int = 0
# 统计信息
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0 # cache_read_input_tokens
cache_creation_tokens: int = 0 # cache_creation_input_tokens
collected_text: str = ""
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
parsed_chunks: list = field(default_factory=list)
# 流状态
start_time: float = field(default_factory=time.time)
chunk_count: int = 0
data_count: int = 0
has_completion: bool = False
# 响应信息
status_code: int = 200
response_headers: Dict[str, str] = field(default_factory=dict)
# 请求信息(发送给 Provider 的)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
# Provider 信息
provider_name: Optional[str] = None
provider_id: Optional[str] = None # Provider ID用于记录真实成本
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
error_message: Optional[str] = None
# 格式转换信息
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
client_api_format: str = "" # 客户端请求的 API 格式
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion
response_metadata: Dict[str, Any] = field(default_factory=dict)
class CliMessageHandlerBase(BaseMessageHandler):
"""
CLI 格式消息处理器基类
@@ -409,24 +352,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
except Exception as e:
# 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# 业务异常:简洁日志
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
else:
# 未知异常:完整堆栈
logger.exception(f"流式请求失败: {e}")
self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise
@@ -446,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.chunk_count = 0
ctx.data_count = 0
ctx.has_completion = False
ctx.collected_text = ""
ctx._collected_text_parts = [] # 重置文本收集
ctx.input_tokens = 0
ctx.output_tokens = 0
ctx.cached_tokens = 0
@@ -534,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用
line_iterator = stream_response.aiter_lines()
# 使用字节流迭代器(避免 aiter_lines 的性能问题
byte_iterator = stream_response.aiter_raw()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
prefetched_chunks = await self._prefetch_and_check_embedded_error(
byte_iterator, provider, endpoint, ctx
)
except httpx.HTTPStatusError as e:
@@ -564,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
ctx,
line_iterator,
byte_iterator,
response_ctx,
http_client,
prefetched_lines,
prefetched_chunks,
)
async def _create_response_stream(
@@ -577,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
"""创建响应流生成器(使用字节流)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines():
async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
yield b"\n"
continue
continue
ctx.chunk_count += 1
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
@@ -702,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -716,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流上下文
Returns:
预读的列表(需要在后续流中先输出)
预读的字节块列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
"""
prefetched_lines: list = []
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
# 获取对应格式的解析器
@@ -742,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
else:
provider_parser = self.parser
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
# 解析数据
normalized_line = line.rstrip("\r")
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
line_count += 1
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
@@ -813,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def _create_response_stream_with_prefetch(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
prefetched_chunks: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
"""创建响应流生成器(带预读数据,使用字节流"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
if prefetched_chunks:
self._update_usage_to_streaming(ctx.request_id)
# 先处理预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 继续处理剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
async for chunk in byte_iterator:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件
flushed_events = sse_parser.flush()
@@ -1047,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 提取文本内容
text = self.parser.extract_text_content(data)
if text:
ctx.collected_text += text
ctx.append_text(text)
# 检查完成事件
if event_type in ("response.completed", "message_stop"):
@@ -1099,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
) -> None:
"""在流完成后记录统计信息"""
try:
await asyncio.sleep(0.1)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
# 注意:不要把统计延迟算进响应时间里
response_time_ms = int((time.time() - self.start_time) * 1000)
response_time_ms = int((time.time() - ctx.start_time) * 1000)
await asyncio.sleep(0.1)
if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
@@ -1181,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
@@ -1201,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
# 简洁的请求完成摘要(两行格式)
line1 = (
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
)
logger.info(f"{line1}\n{line2}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
@@ -1255,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body: Dict[str, Any],
) -> None:
"""记录流式请求失败"""
response_time_ms = int((time.time() - ctx.start_time) * 1000)
# 使用 self.start_time 作为时间基准,与首字时间保持一致
response_time_ms = int((time.time() - self.start_time) * 1000)
status_code = 503
if isinstance(error, ProviderAuthException):

View File

@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
ResponseParser,
StreamStats,
)
from src.api.handlers.base.utils import extract_cache_creation_tokens
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
usage = response.get("usage", {})
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
# 检查错误(支持嵌套错误格式)
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
# 对于 message_start 事件usage 在 message.usage 路径下
# 对于其他响应usage 在顶层
usage = response.get("usage", {})
if not usage and "message" in response:
usage = response.get("message", {}).get("usage", {})
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}

View File

@@ -8,6 +8,7 @@
- 请求/响应数据
"""
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@@ -25,12 +26,18 @@ class StreamContext:
model: str
api_format: str
# 请求标识信息CLI handler 需要)
request_id: str = ""
user_id: int = 0
api_key_id: int = 0
# Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射
@@ -43,7 +50,14 @@ class StreamContext:
cache_creation_tokens: int = 0
# 响应内容
collected_text: str = ""
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
# 时间指标
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
start_time: float = field(default_factory=time.time)
# 响应状态
status_code: int = 200
@@ -55,6 +69,12 @@ class StreamContext:
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
# 格式转换信息CLI handler 需要)
client_api_format: str = ""
# Provider 响应元数据CLI handler 需要)
response_metadata: Dict[str, Any] = field(default_factory=dict)
# 流式处理统计
data_count: int = 0
chunk_count: int = 0
@@ -71,16 +91,30 @@ class StreamContext:
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self.collected_text = ""
self._collected_text_parts = []
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
self.response_id = None
self.final_usage = None
self.final_response = None
@property
def collected_text(self) -> str:
"""已收集的文本内容(按需拼接,避免在流式过程中频繁做字符串拷贝)"""
return "".join(self._collected_text_parts)
def append_text(self, text: str) -> None:
"""追加文本内容(仅在需要收集文本时调用)"""
if text:
self._collected_text_parts.append(text)
def update_provider_info(
self,
@@ -104,14 +138,40 @@ class StreamContext:
cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None,
) -> None:
"""更新 Token 使用统计"""
if input_tokens is not None:
"""
更新 Token 使用统计
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
设计原理:
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
- 后续事件可能会提供完整的统计数据
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
示例场景:
- message_start 事件input_tokens=100, output_tokens=0
- message_delta 事件input_tokens=0, output_tokens=50
- 最终结果input_tokens=100, output_tokens=50
注意事项:
- 此策略假设初始值为 0 是正确的默认状态
- 如果需要将已有值重置为 0请直接修改实例属性不使用此方法
Args:
input_tokens: 输入 tokens 数量
output_tokens: 输出 tokens 数量
cached_tokens: 缓存命中 tokens 数量
cache_creation_tokens: 缓存创建 tokens 数量
"""
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
self.input_tokens = input_tokens
if output_tokens is not None:
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
self.output_tokens = output_tokens
if cached_tokens is not None:
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
self.cached_tokens = cached_tokens
if cache_creation_tokens is not None:
if cache_creation_tokens is not None and (
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
):
self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None:
@@ -119,6 +179,19 @@ class StreamContext:
self.status_code = status_code
self.error_message = error_message
def record_first_byte_time(self, start_time: float) -> None:
"""
记录首字时间 (TTFB - Time To First Byte)
应在第一次向客户端发送数据时调用。
如果已记录过,则不会覆盖(避免重试时重复记录)。
Args:
start_time: 请求开始时间 (time.time())
"""
if self.first_byte_time_ms is None:
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
def is_success(self) -> bool:
"""检查请求是否成功"""
return self.status_code < 400
@@ -145,10 +218,22 @@ class StreamContext:
获取日志摘要
用于请求完成/失败时的日志输出。
包含首字时间 (TTFB) 和总响应时间,分两行显示。
"""
status = "OK" if self.is_success() else "FAIL"
return (
# 第一行:基本信息 + 首字时间
line1 = (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
f"{self.provider_name or 'unknown'}"
)
if self.first_byte_time_ms is not None:
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
# 第二行:总响应时间 + tokens
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)
return f"{line1}\n{line2}"

View File

@@ -9,7 +9,9 @@
"""
import asyncio
import codecs
import json
import time
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
@@ -36,6 +38,8 @@ class StreamProcessor:
request_id: str,
default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
):
"""
初始化流处理器
@@ -48,6 +52,7 @@ class StreamProcessor:
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
@@ -112,9 +117,10 @@ class StreamProcessor:
)
# 提取文本
text = parser.extract_text_content(data)
if text:
ctx.collected_text += text
if self.collect_text:
text = parser.extract_text_content(data)
if text:
ctx.append_text(text)
# 检查完成
event_type = event_name or data.get("type", "")
@@ -123,7 +129,7 @@ class StreamProcessor:
async def prefetch_and_check_error(
self,
line_iterator: Any,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
@@ -136,97 +142,126 @@ class StreamProcessor:
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 迭代器
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
Returns:
预读的列表
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
if line_count >= max_prefetch_lines:
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = line
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
if data_str == "[DONE]":
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
raise
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
return prefetched_chunks
async def create_response_stream(
self,
ctx: StreamContext,
line_iterator: Any,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None,
prefetched_chunks: Optional[list] = None,
*,
start_time: Optional[float] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
统一的流生成器,支持预读数据和不带预读数据两种情况
从字节流中解析 SSE 数据并转发,支持预读数据。
Args:
ctx: 流式上下文
line_iterator: 迭代器
byte_iterator: 字节流迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_lines: 预读的列表(可选)
prefetched_chunks: 预读的字节块列表(可选)
start_time: 请求开始时间,用于计算 TTFB可选
Yields:
编码后的响应数据块
@@ -234,25 +269,82 @@ class StreamProcessor:
try:
sse_parser = SSEEventParser()
streaming_started = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 处理预读数据
if prefetched_lines:
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for line in prefetched_lines:
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 把原始数据转发给客户端
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的流数据
async for line in line_iterator:
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 原始数据透传
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的缓冲区数据(如果有未完成的行)
if buffer:
try:
# 使用 final=True 处理最后的不完整字符
line = decoder.decode(buffer, True)
self._process_line(ctx, sse_parser, line)
except Exception as e:
logger.warning(
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
f"bytes={buffer[:50]!r}"
)
# 处理剩余事件
for event in sse_parser.flush():
@@ -268,7 +360,7 @@ class StreamProcessor:
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> list[bytes]:
) -> None:
"""
处理单行数据
@@ -276,26 +368,17 @@ class StreamProcessor:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
Returns:
要发送的数据块列表
"""
result: list[bytes] = []
normalized_line = line.rstrip("\r")
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
if normalized_line != "":
ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
async def create_monitored_stream(
self,
@@ -317,16 +400,26 @@ class StreamProcessor:
响应数据块
"""
try:
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
next_disconnect_check_at = 0.0
disconnect_check_interval_s = 0.25
async for chunk in stream_generator:
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
now = time.monotonic()
if now >= next_disconnect_check_at:
next_disconnect_check_at = now + disconnect_check_interval_s
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500

View File

@@ -8,6 +8,7 @@
"""
import asyncio
import time
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
response_time_ms: int,
start_time: float,
) -> None:
"""
记录流式统计信息
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
ctx: 流式上下文
original_headers: 原始请求头
original_request_body: 原始请求体
response_time_ms: 响应时间(毫秒)
start_time: 请求开始时间 (time.time())
"""
bg_db = None
try:
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
# 注意不要把统计延迟stream_stats_delay算进响应时间里
response_time_ms = int((time.time() - start_time) * 1000)
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
if not ctx.provider_name:
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,

View File

@@ -0,0 +1,55 @@
"""
Handler 基础工具函数
"""
from typing import Any, Dict, Optional
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
"""
提取缓存创建 tokens兼容新旧格式
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens
- 新格式2024年后使用 claude_cache_creation_5_m_tokens 和
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
此函数自动检测并适配两种格式,优先使用新格式。
Args:
usage: API 响应中的 usage 字典
Returns:
缓存创建 tokens 总数
"""
# 检查新格式字段是否存在(而非值是否为 0
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
has_new_format = (
"claude_cache_creation_5_m_tokens" in usage
or "claude_cache_creation_1_h_tokens" in usage
)
if has_new_format:
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
return int(cache_5m) + int(cache_1h)
# 回退到旧格式
return int(usage.get("cache_creation_input_tokens", 0))
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
构建 SSEtext/event-stream推荐响应头用于减少代理缓冲带来的卡顿/成段输出。
说明:
- Cache-Control: no-transform 可避免部分代理对流做压缩/改写导致缓冲
- X-Accel-Buffering: no 可显式提示 Nginx 关闭缓冲(即使全局已关闭也无害)
"""
headers: Dict[str, str] = {
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
}
if extra_headers:
headers.update(extra_headers)
return headers

View File

@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeChatHandler(ChatHandlerBase):
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
result["model"] = mapped_model
return result
async def _convert_request(self, request):
async def _convert_request(self, request: Any) -> Any:
"""
将请求转换为 Claude 格式
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
Claude 格式使用:
- input_tokens / output_tokens
- cache_creation_input_tokens / cache_read_input_tokens
- 新格式claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
"""
usage = response.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
# 处理新的 cache_creation 格式
if "cache_creation" in usage:
cache_creation_data = usage.get("cache_creation", {})
if not cache_creation_input_tokens:
cache_creation_input_tokens = cache_creation_data.get(
"ephemeral_5m_input_tokens", 0
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
}
def _normalize_response(self, response: Dict) -> Dict:
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
规范化 Claude 响应
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_claude_response(
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return result
return response

View File

@@ -9,6 +9,8 @@ from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeStreamParser:
"""
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}

View File

@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeCliMessageHandler(CliMessageHandlerBase):
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
usage = message.get("usage", {})
if usage:
ctx.input_tokens = usage.get("input_tokens", 0)
# Claude 的缓存 tokens 使用不同的字段名
cache_read = usage.get("cache_read_input_tokens", 0)
if cache_read:
ctx.cached_tokens = cache_read
cache_creation = usage.get("cache_creation_input_tokens", 0)
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation:
ctx.cache_creation_tokens = cache_creation
@@ -109,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 处理消息增量(包含最终 usage
elif event_type == "message_delta":
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
ctx.input_tokens = usage["input_tokens"]
if "output_tokens" in usage:
ctx.output_tokens = usage["output_tokens"]
# 更新缓存 tokens
# 更新缓存读取 tokens
if "cache_read_input_tokens" in usage:
ctx.cached_tokens = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
# 更新缓存创建 tokens
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation > 0:
ctx.cache_creation_tokens = cache_creation
# 检查是否结束
delta = data.get("delta", {})

View File

@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
parts = content.get("parts", [])
for part in parts:
if "text" in part:
ctx.collected_text += part["text"]
ctx.append_text(part["text"])
# 检查结束原因
finish_reason = candidate.get("finishReason")

View File

@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
delta = data.get("delta")
if isinstance(delta, str):
ctx.collected_text += delta
ctx.append_text(delta)
elif isinstance(delta, dict) and "text" in delta:
ctx.collected_text += delta["text"]
ctx.append_text(delta["text"])
# 处理完成事件
elif event_type == "response.completed":
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if content_item.get("type") == "output_text":
text = content_item.get("text", "")
if text:
ctx.collected_text += text
ctx.append_text(text)
# 备用:从顶层 usage 提取
usage_obj = data.get("usage")

View File

@@ -120,6 +120,33 @@ class CacheService:
logger.warning(f"缓存检查失败: {key} - {e}")
return False
@staticmethod
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
"""
递增缓存值
Args:
key: 缓存键
ttl_seconds: 可选,如果提供则刷新 TTL
Returns:
递增后的值,如果失败返回 0
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return 0
result = await redis.incr(key)
# 如果提供了 TTL刷新过期时间
if ttl_seconds is not None:
await redis.expire(key, ttl_seconds)
return result
except Exception as e:
logger.warning(f"缓存递增失败: {key} - {e}")
return 0
# 缓存键前缀
class CacheKeys:

View File

@@ -307,7 +307,8 @@ class Usage(Base):
is_stream = Column(Boolean, default=False) # 是否为流式请求
status_code = Column(Integer)
error_message = Column(Text, nullable=True)
response_time_ms = Column(Integer) # 响应时间(毫秒)
response_time_ms = Column(Integer) # 响应时间(毫秒)
first_byte_time_ms = Column(Integer, nullable=True) # 首字时间/TTFB毫秒
# 请求状态追踪
# pending: 请求开始处理中

View File

@@ -2,11 +2,9 @@
Model 映射缓存服务 - 减少模型查询
"""
import json
import time
from typing import Optional
from typing import List, Optional
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
@@ -106,6 +104,7 @@ class ModelCacheService:
Model 对象或 None
"""
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
@@ -113,6 +112,8 @@ class ModelCacheService:
logger.debug(
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
# 递增命中计数,同时刷新 TTL
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
@@ -130,6 +131,8 @@ class ModelCacheService:
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
# 重置命中计数新缓存从1开始
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
@@ -189,9 +192,10 @@ class ModelCacheService:
# 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}")
# 清除 provider_global 缓存(如果提供了必要参数)
# 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
if provider_id and global_model_id:
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
logger.debug(
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
)
@@ -230,16 +234,20 @@ class ModelCacheService:
db: Session, model_name: str
) -> Optional[GlobalModel]:
"""
通过名称或映射解析 GlobalModel带缓存,支持映射匹配
通过名称解析 GlobalModel带缓存
查找顺序:
1. 检查缓存
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases
2. 通过 provider_model_name 匹配(查询 Model 表
3. 直接匹配 GlobalModel.name兜底
注意:此方法不使用 provider_model_aliases 进行全局解析。
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
由 resolve_provider_model() 处理。
Args:
db: 数据库会话
model_name: 模型名称(可以是 GlobalModel.name 或映射名称
model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name
Returns:
GlobalModel 对象或 None
@@ -273,116 +281,53 @@ class ModelCacheService:
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 优先通过 provider_model_name 和映射名称匹配Provider 配置优先级最高
from sqlalchemy import or_
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases
# 重要provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 别名配置的影响
# 例如Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
from src.models.database import Provider
# 构建精确的映射匹配条件
# 注意provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
# 对于 SQLite会在 Python 层面进行过滤
try:
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
jsonb_pattern = json.dumps([{"name": normalized_name}])
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
or_(
Model.provider_model_name == normalized_name,
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
Model.provider_model_aliases.op("@>")(jsonb_pattern),
),
)
.all()
)
except (OperationalError, ProgrammingError) as e:
# JSONB 操作符不支持(如 SQLite回退到加载匹配 provider_model_name 的 Model
# 并在 Python 层过滤 aliases
logger.debug(
f"JSONB 查询失败,回退到 Python 过滤: {e}",
)
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
)
.all()
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
Model.provider_model_name == normalized_name,
)
.all()
)
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)}
# 使用字典去重,同一个 Model 只保留优先级最高的匹配
matched_models_dict = {}
# 遍历查询结果进行匹配
# 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
matched_global_models: List[GlobalModel] = []
seen_global_model_ids: set[str] = set()
for model, gm in models_with_global:
key = (model.id, gm.id)
# 检查 provider_model_aliases 是否匹配(优先级更高)
if model.provider_model_aliases:
for alias_entry in model.provider_model_aliases:
if isinstance(alias_entry, dict):
alias_name = alias_entry.get("name", "").strip()
if alias_name == normalized_name:
# alias 优先级为 0最高覆盖任何已存在的匹配
matched_models_dict[key] = (gm, "alias", 0)
logger.debug(
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
break
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
if model.provider_model_name == normalized_name:
# provider_model_name 优先级为 1兜底只在没有 alias 匹配时使用
if key not in matched_models_dict:
matched_models_dict[key] = (gm, "provider_model_name", 1)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
# 如果通过 provider_model_name/alias 找到了,直接返回
if matched_models_dict:
# 转换为列表并排序:按 priorityalias=0 优先)、然后按 GlobalModel.name
matched_global_models = [
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
]
matched_global_models.sort(
key=lambda item: (
0 if item[1] == "alias" else 1, # alias 优先
item[0].name # 同优先级按名称排序(确定性)
if gm.id not in seen_global_model_ids:
seen_global_model_ids.add(gm.id)
matched_global_models.append(gm)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
)
# 记录解析方式
resolution_method = matched_global_models[0][1]
# 如果通过 provider_model_name 找到了,返回
if matched_global_models:
resolution_method = "provider_model_name"
if len(matched_global_models) > 1:
# 检测到冲突
unique_models = {gm.id: gm for gm, _ in matched_global_models}
if len(unique_models) > 1:
model_names = [gm.name for gm in unique_models.values()]
logger.warning(
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
f"{', '.join(model_names)},使用第一个匹配结果"
)
# 记录冲突指标
model_mapping_conflict_total.inc()
# 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name
model_names = [gm.name for gm in matched_global_models if gm.name]
logger.warning(
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
f"{', '.join(model_names)},使用第一个匹配结果"
)
# 记录冲突指标
model_mapping_conflict_total.inc()
# 返回第一个匹配的 GlobalModel
result_global_model: GlobalModel = matched_global_models[0][0]
result_global_model = matched_global_models[0]
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL

View File

@@ -6,7 +6,7 @@
- 根据 API 格式或端点配置生成请求 URL
"""
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlencode
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
from src.core.enums import APIFormat
from src.core.logger import logger
if TYPE_CHECKING:
from src.models.database import ProviderAPIKey, ProviderEndpoint
def build_provider_headers(
endpoint,
key,
endpoint: "ProviderEndpoint",
key: "ProviderAPIKey",
original_headers: Optional[Dict[str, str]] = None,
*,
extra_headers: Optional[Dict[str, str]] = None,
@@ -28,7 +31,8 @@ def build_provider_headers(
"""
headers: Dict[str, str] = {}
decrypted_key = crypto_service.decrypt(key.api_key)
# api_key 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
# 根据 API 格式自动选择认证头
api_format = getattr(endpoint, "api_format", None)
@@ -68,8 +72,32 @@ def build_provider_headers(
return headers
def _normalize_base_url(base_url: str, path: str) -> str:
"""
规范化 base_url去除末尾的斜杠和可能与 path 重复的版本前缀。
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
避免拼接出 /v1/v1/messages 这样的重复路径。
兼容用户填写的各种格式:
- https://api.example.com
- https://api.example.com/
- https://api.example.com/v1
- https://api.example.com/v1/
"""
base = base_url.rstrip("/")
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
# 例如base="/v1", path="/v1/messages" -> 去除 /v1
# 例如base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
if base.endswith(suffix) and path.startswith(suffix):
base = base[: -len(suffix)]
break
return base
def build_provider_url(
endpoint,
endpoint: "ProviderEndpoint",
*,
query_params: Optional[Dict[str, Any]] = None,
path_params: Optional[Dict[str, Any]] = None,
@@ -88,8 +116,6 @@ def build_provider_url(
path_params: 路径模板参数 (如 {model})
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
"""
base = endpoint.base_url.rstrip("/")
# 准备路径参数,添加 Gemini API 所需的 action 参数
effective_path_params = dict(path_params) if path_params else {}
@@ -123,6 +149,9 @@ def build_provider_url(
if not path.startswith("/"):
path = f"/{path}"
# 先确定 path再根据 path 规范化 base_url
# base_url 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
url = f"{base}{path}"
# 添加查询参数
@@ -134,7 +163,7 @@ def build_provider_url(
return url
def _resolve_default_path(api_format) -> str:
def _resolve_default_path(api_format: Optional[str]) -> str:
"""
根据 API 格式返回默认路径
"""

View File

@@ -157,6 +157,7 @@ class UsageService:
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
@@ -368,6 +369,7 @@ class UsageService:
status_code=status_code,
error_message=error_message,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪
request_metadata=metadata,
request_headers=processed_request_headers,
@@ -419,6 +421,7 @@ class UsageService:
api_format: Optional[str] = None,
is_stream: bool = False,
response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
@@ -629,6 +632,7 @@ class UsageService:
status_code=status_code,
error_message=error_message,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪
request_metadata=metadata,
request_headers=processed_request_headers,
@@ -649,6 +653,7 @@ class UsageService:
existing_usage.status_code = status_code
existing_usage.error_message = error_message
existing_usage.response_time_ms = response_time_ms
existing_usage.first_byte_time_ms = first_byte_time_ms # 更新首字时间
# 更新请求头和请求体(如果有新值)
if processed_request_headers is not None:
existing_usage.request_headers = processed_request_headers
@@ -1315,11 +1320,11 @@ class UsageService:
default_timeout_seconds: int = 300,
) -> List[Dict[str, Any]]:
"""
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending 请求
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
与 get_active_requests 不同,此方法:
1. 返回轻量级的状态字典而非完整 Usage 对象
2. 自动检测并清理超时的 pending 请求
2. 自动检测并清理超时的 pending/streaming 请求
3. 支持按 ID 列表查询特定请求
Args:
@@ -1343,6 +1348,7 @@ class UsageService:
Usage.output_tokens,
Usage.total_cost_usd,
Usage.response_time_ms,
Usage.first_byte_time_ms, # 首字时间 (TTFB)
Usage.created_at,
Usage.provider_endpoint_id,
ProviderEndpoint.timeout.label("endpoint_timeout"),
@@ -1361,10 +1367,10 @@ class UsageService:
records = query.all()
# 检查超时的 pending 请求
# 检查超时的 pending/streaming 请求
timeout_ids = []
for r in records:
if r.status == "pending" and r.created_at:
if r.status in ("pending", "streaming") and r.created_at:
# 使用端点配置的超时时间,若无则使用默认值
timeout_seconds = r.endpoint_timeout or default_timeout_seconds
@@ -1392,6 +1398,7 @@ class UsageService:
"output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
}
for r in records
]

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""测试模块"""

View File

@@ -0,0 +1,117 @@
from src.api.handlers.base import stream_context
from src.api.handlers.base.stream_context import StreamContext
def test_collected_text_append_and_property() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
assert ctx.collected_text == ""
ctx.append_text("hello")
ctx.append_text(" ")
ctx.append_text("world")
assert ctx.collected_text == "hello world"
def test_reset_for_retry_clears_state() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
ctx.append_text("x")
ctx.update_usage(input_tokens=10, output_tokens=5)
ctx.parsed_chunks.append({"type": "chunk"})
ctx.chunk_count = 3
ctx.data_count = 2
ctx.has_completion = True
ctx.status_code = 418
ctx.error_message = "boom"
ctx.reset_for_retry()
assert ctx.collected_text == ""
assert ctx.input_tokens == 0
assert ctx.output_tokens == 0
assert ctx.parsed_chunks == []
assert ctx.chunk_count == 0
assert ctx.data_count == 0
assert ctx.has_completion is False
assert ctx.status_code == 200
assert ctx.error_message is None
def test_record_first_byte_time(monkeypatch) -> None:
"""测试记录首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123) # 12.3ms
# 记录首字时间
ctx.record_first_byte_time(start_time)
# 验证首字时间已记录
assert ctx.first_byte_time_ms == 12
def test_record_first_byte_time_idempotent(monkeypatch) -> None:
"""测试首字时间只记录一次"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 第一次记录
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
first_value = ctx.first_byte_time_ms
# 第二次记录(应该被忽略)
monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
ctx.record_first_byte_time(start_time)
second_value = ctx.first_byte_time_ms
# 验证值没有改变
assert first_value == second_value
def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
"""测试重试时清除首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 记录首字时间
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
assert ctx.first_byte_time_ms is not None
# 重置
ctx.reset_for_retry()
# 验证首字时间已清除
assert ctx.first_byte_time_ms is None
def test_get_log_summary_with_first_byte_time() -> None:
"""测试日志摘要包含首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
ctx.first_byte_time_ms = 123
summary = ctx.get_log_summary("request-id-123", 456)
# 验证包含首字时间和总时间(大写格式)
assert "TTFB: 123ms" in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary
def test_get_log_summary_without_first_byte_time() -> None:
"""测试日志摘要在没有首字时间时的格式"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
# first_byte_time_ms 保持为 None
summary = ctx.get_log_summary("request-id-123", 456)
# 验证不包含首字时间标记,但有总时间(使用大写 TTFB 和 Total
assert "TTFB:" not in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary

View File

@@ -0,0 +1,32 @@
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import ParsedChunk, ParsedResponse, ResponseParser, StreamStats
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.utils.sse_parser import SSEEventParser
class DummyParser(ResponseParser):
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
return None
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
return ParsedResponse(raw_response=response, status_code=status_code)
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
return {}
def extract_text_content(self, response: Dict[str, Any]) -> str:
return ""
def test_process_line_strips_newlines_and_finalizes_event() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
processor = StreamProcessor(request_id="test-request", default_parser=DummyParser())
sse_parser = SSEEventParser()
processor._process_line(ctx, sse_parser, 'data: {"type":"response.completed"}\n')
processor._process_line(ctx, sse_parser, "\n")
assert ctx.has_completion is True

View File

@@ -0,0 +1,104 @@
"""测试 handler 基础工具函数"""
import pytest
from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens
class TestExtractCacheCreationTokens:
"""测试 extract_cache_creation_tokens 函数"""
def test_new_format_only(self) -> None:
"""测试只有新格式字段"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
}
assert extract_cache_creation_tokens(usage) == 300
def test_new_format_5m_only(self) -> None:
"""测试只有 5 分钟缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 150,
"claude_cache_creation_1_h_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 150
def test_new_format_1h_only(self) -> None:
"""测试只有 1 小时缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 250,
}
assert extract_cache_creation_tokens(usage) == 250
def test_old_format_only(self) -> None:
"""测试只有旧格式字段"""
usage = {
"cache_creation_input_tokens": 500,
}
assert extract_cache_creation_tokens(usage) == 500
def test_both_formats_prefers_new(self) -> None:
"""测试同时存在时优先使用新格式"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
"cache_creation_input_tokens": 999, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
def test_empty_usage(self) -> None:
"""测试空字典"""
usage = {}
assert extract_cache_creation_tokens(usage) == 0
def test_all_zeros(self) -> None:
"""测试所有字段都为 0"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 0
def test_partial_new_format_with_old_format_fallback(self) -> None:
"""测试新格式字段不存在时回退到旧格式"""
usage = {
"cache_creation_input_tokens": 123,
}
assert extract_cache_creation_tokens(usage) == 123
def test_new_format_zero_should_not_fallback(self) -> None:
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 456,
}
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0
# 而不是 fallback 到旧格式(返回 456
assert extract_cache_creation_tokens(usage) == 0
def test_unrelated_fields_ignored(self) -> None:
"""测试忽略无关字段"""
usage = {
"input_tokens": 1000,
"output_tokens": 2000,
"cache_read_input_tokens": 300,
"claude_cache_creation_5_m_tokens": 50,
"claude_cache_creation_1_h_tokens": 75,
}
assert extract_cache_creation_tokens(usage) == 125
class TestBuildSSEHeaders:
def test_default_headers(self) -> None:
headers = build_sse_headers()
assert headers["Cache-Control"] == "no-cache, no-transform"
assert headers["X-Accel-Buffering"] == "no"
def test_merge_extra_headers(self) -> None:
headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
assert headers["X-Test"] == "1"
assert headers["Cache-Control"] == "custom"