11 Commits

Author SHA1 Message Date
fawney19
c42ebdd0ee test(handler): add comprehensive stream processor unit tests 2025-12-16 02:40:26 +08:00
fawney19
f1e3c2ab11 feat(frontend-usage): enhance usage UI with first byte latency metrics
- Update usage records table to display first_byte_time_ms metrics
- Improve request timeline visualization for latency tracking
- Extend usage types for new timing information
2025-12-16 02:39:54 +08:00
fawney19
4e2ba0e57f feat(usage): add first_byte_time_ms tracking to usage statistics
- Enhance usage service to capture and store first byte latency metrics
- Update usage API routes to include new timing information
2025-12-16 02:39:36 +08:00
fawney19
a3df41d63d refactor(cli-handler): improve stream handling and response processing
- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
2025-12-16 02:39:20 +08:00
fawney19
ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00
fawney19
9b496abb73 feat(db): add first_byte_time_ms column to usage table 2025-12-16 02:38:43 +08:00
fawney19
f3a69a6160 refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
2025-12-16 00:02:49 +08:00
fawney19
adcdb73d29 feat(frontend): enhance cache monitoring UI and API integration 2025-12-15 23:12:58 +08:00
fawney19
cf67160821 feat(cache): enhance cache monitoring endpoints and handler integrations 2025-12-15 23:12:48 +08:00
fawney19
718f56ba75 refactor(cache): optimize cache service architecture and provider transport 2025-12-15 23:12:34 +08:00
fawney19
d87de10f62 refactor(docker): remove exposed ports for postgres and redis services 2025-12-15 23:12:23 +08:00
31 changed files with 1492 additions and 550 deletions

View File

@@ -0,0 +1,28 @@
"""add first_byte_time_ms to usage table
Revision ID: 180e63a9c83a
Revises: e9b3d63f0cbf
Create Date: 2025-12-15 17:07:44.631032+00:00
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '180e63a9c83a'
down_revision = 'e9b3d63f0cbf'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""应用迁移:升级到新版本"""
# 添加首字时间字段到 usage 表
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
def downgrade() -> None:
"""回滚迁移:降级到旧版本"""
# 删除首字时间字段
op.drop_column('usage', 'first_byte_time_ms')

View File

@@ -12,8 +12,6 @@ services:
TZ: Asia/Shanghai TZ: Asia/Shanghai
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
ports:
- "${DB_PORT:-5432}:5432"
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"] test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s interval: 5s
@@ -27,8 +25,6 @@ services:
command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD} command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
volumes: volumes:
- redis_data:/data - redis_data:/data
ports:
- "${REDIS_PORT:-6379}:6379"
healthcheck: healthcheck:
test: ["CMD", "redis-cli", "--raw", "incr", "ping"] test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
interval: 5s interval: 5s

View File

@@ -290,6 +290,19 @@ export interface UnmappedEntry {
ttl: number | null ttl: number | null
} }
// Provider 模型映射缓存Redis 缓存)
export interface ProviderModelMapping {
provider_id: string
provider_name: string
global_model_id: string
global_model_name: string
global_model_display_name: string | null
provider_model_name: string
aliases: string[] | null
ttl: number | null
hit_count: number
}
export interface ModelMappingCacheStats { export interface ModelMappingCacheStats {
available: boolean available: boolean
message?: string message?: string
@@ -303,6 +316,7 @@ export interface ModelMappingCacheStats {
global_model_resolve: number global_model_resolve: number
} }
mappings?: ModelMappingItem[] mappings?: ModelMappingItem[]
provider_model_mappings?: ProviderModelMapping[] | null
unmapped?: UnmappedEntry[] | null unmapped?: UnmappedEntry[] | null
} }
@@ -337,5 +351,13 @@ export const modelMappingCacheApi = {
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> { async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`) const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
return response.data return response.data
},
/**
* 清除指定 Provider 和 GlobalModel 的映射缓存
*/
async clearProviderModel(providerId: string, globalModelId: string): Promise<ClearModelMappingCacheResponse> {
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/provider/${providerId}/${globalModelId}`)
return response.data
} }
} }

View File

@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
return groups return groups
}) })
// 计算链路总耗时(从第一个节点开始到最后一个节点结束 // 计算链路总耗时(使用成功候选的 latency_ms 字段
// 优先使用 latency_ms因为它与 Usage.response_time_ms 使用相同的时间基准
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
const totalTraceLatency = computed(() => { const totalTraceLatency = computed(() => {
if (!timeline.value || timeline.value.length === 0) return 0 if (!timeline.value || timeline.value.length === 0) return 0
// 查找成功的候选,使用其 latency_ms
const successCandidate = timeline.value.find(c => c.status === 'success')
if (successCandidate?.latency_ms != null) {
return successCandidate.latency_ms
}
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
if (failedWithLatency?.latency_ms != null) {
return failedWithLatency.latency_ms
}
// 回退:使用 finished_at - started_at 计算
let earliestStart: number | null = null let earliestStart: number | null = null
let latestEnd: number | null = null let latestEnd: number | null = null

View File

@@ -177,8 +177,9 @@
费用 费用
</TableHead> </TableHead>
<TableHead class="h-12 font-semibold w-[70px] text-right"> <TableHead class="h-12 font-semibold w-[70px] text-right">
<div class="inline-block max-w-[2rem] leading-tight"> <div class="flex flex-col items-end text-xs gap-0.5">
响应时间 <span>首字</span>
<span class="text-muted-foreground font-normal">总耗时</span>
</div> </div>
</TableHead> </TableHead>
</TableRow> </TableRow>
@@ -356,15 +357,28 @@
</div> </div>
</TableCell> </TableCell>
<TableCell class="text-right py-4 w-[70px]"> <TableCell class="text-right py-4 w-[70px]">
<span <div
v-if="record.status === 'pending' || record.status === 'streaming'" v-if="record.status === 'pending' || record.status === 'streaming'"
class="text-primary tabular-nums" class="flex flex-col items-end text-xs gap-0.5"
> >
{{ getElapsedTime(record) }} <span class="text-primary tabular-nums">
</span> {{ getElapsedTime(record) }}
<span v-else-if="record.response_time_ms"> </span>
{{ (record.response_time_ms / 1000).toFixed(2) }}s </div>
</span> <div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
</div>
<span <span
v-else v-else
class="text-muted-foreground" class="text-muted-foreground"

View File

@@ -78,6 +78,7 @@ export interface UsageRecord {
cost: number cost: number
actual_cost?: number actual_cost?: number
response_time_ms?: number response_time_ms?: number
first_byte_time_ms?: number // 首字时间 (TTFB)
is_stream: boolean is_stream: boolean
status_code?: number status_code?: number
error_message?: string error_message?: string

View File

@@ -299,6 +299,26 @@ async function clearModelMappingByName(modelName: string) {
} }
} }
async function clearProviderModelMapping(providerId: string, globalModelId: string, displayName?: string) {
const confirmed = await showConfirm({
title: '确认清除',
message: `确定要清除 ${displayName || 'Provider 模型映射'} 的缓存吗?`,
confirmText: '确认清除',
variant: 'destructive'
})
if (!confirmed) return
try {
await modelMappingCacheApi.clearProviderModel(providerId, globalModelId)
showSuccess('已清除 Provider 模型映射缓存')
await fetchModelMappingStats()
} catch (error) {
showError('清除缓存失败')
log.error('清除 Provider 模型映射缓存失败', error)
}
}
function formatTTL(ttl: number | null): string { function formatTTL(ttl: number | null): string {
if (ttl === null || ttl < 0) return '-' if (ttl === null || ttl < 0) return '-'
if (ttl < 60) return `${ttl}s` if (ttl < 60) return `${ttl}s`
@@ -872,9 +892,125 @@ onBeforeUnmount(() => {
</div> </div>
</div> </div>
<!-- Provider 模型映射缓存 -->
<div
v-if="modelMappingStats?.available && modelMappingStats.provider_model_mappings && modelMappingStats.provider_model_mappings.length > 0"
class="border-t border-border/40"
>
<div class="px-6 py-3 text-xs text-muted-foreground border-b border-border/30 bg-muted/20">
Provider 模型映射缓存
</div>
<!-- 桌面端表格 -->
<Table class="hidden md:table">
<TableHeader>
<TableRow>
<TableHead class="w-[15%]">
提供商
</TableHead>
<TableHead class="w-[25%]">
请求名称
</TableHead>
<TableHead class="w-8 text-center" />
<TableHead class="w-[25%]">
映射模型
</TableHead>
<TableHead class="w-[10%] text-center">
剩余
</TableHead>
<TableHead class="w-[10%] text-center">
次数
</TableHead>
<TableHead class="w-[7%] text-right">
操作
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="index"
>
<TableRow
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`${index}-${aliasIndex}`"
>
<TableCell>
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
</TableCell>
<TableCell>
<span class="text-sm font-mono">{{ alias }}</span>
</TableCell>
<TableCell class="text-center">
<ArrowRight class="h-4 w-4 text-muted-foreground" />
</TableCell>
<TableCell>
<span class="text-sm font-mono font-medium">{{ mapping.provider_model_name }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
</TableCell>
<TableCell class="text-center">
<span class="text-sm">{{ mapping.hit_count || 0 }}</span>
</TableCell>
<TableCell class="text-right">
<Button
size="icon"
variant="ghost"
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3.5 w-3.5" />
</Button>
</TableCell>
</TableRow>
</template>
</TableBody>
</Table>
<!-- 移动端卡片 -->
<div class="md:hidden divide-y divide-border/40">
<template
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
:key="`m-pm-${index}`"
>
<div
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
:key="`m-pm-${index}-${aliasIndex}`"
class="p-4 space-y-2"
>
<div class="flex items-center justify-between">
<Badge variant="outline" class="text-xs">
{{ mapping.provider_name }}
</Badge>
<div class="flex items-center gap-2">
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
<span class="text-xs">{{ mapping.hit_count || 0 }}</span>
<Button
size="icon"
variant="ghost"
class="h-6 w-6 text-muted-foreground/70 hover:text-destructive"
title="清除缓存"
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
>
<Trash2 class="h-3 w-3" />
</Button>
</div>
</div>
<div class="flex items-center gap-2 text-sm">
<span class="font-mono">{{ alias }}</span>
<ArrowRight class="h-3.5 w-3.5 shrink-0 text-muted-foreground/60" />
<span class="font-mono font-medium">{{ mapping.provider_model_name }}</span>
</div>
</div>
</template>
</div>
</div>
<!-- 无缓存状态 --> <!-- 无缓存状态 -->
<div <div
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0)" v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0) && (!modelMappingStats.provider_model_mappings || modelMappingStats.provider_model_mappings.length === 0)"
class="px-6 py-8 text-center text-sm text-muted-foreground" class="px-6 py-8 text-center text-sm text-muted-foreground"
> >
暂无模型解析缓存 暂无模型解析缓存

View File

@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.context import ApiRequestContext
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync from src.clients.redis_client import get_redis_client_sync
@@ -87,19 +88,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
# 2. 尝试作为 Username 查询 # 2. 尝试作为 Username 查询
user = db.query(User).filter(User.username == identifier).first() user = db.query(User).filter(User.username == identifier).first()
if user: if user:
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id return user.id
# 3. 尝试作为 Email 查询 # 3. 尝试作为 Email 查询
user = db.query(User).filter(User.email == identifier).first() user = db.query(User).filter(User.email == identifier).first()
if user: if user:
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
return user.id return user.id
# 4. 尝试作为 API Key ID 查询 # 4. 尝试作为 API Key ID 查询
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first() api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
if api_key: if api_key:
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
return api_key.user_id return api_key.user_id
# 无法识别 # 无法识别
@@ -111,7 +112,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
async def get_cache_stats( async def get_cache_stats(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
获取缓存亲和性统计信息 获取缓存亲和性统计信息
@@ -131,7 +132,7 @@ async def get_user_affinity(
user_identifier: str, user_identifier: str,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
查询指定用户的所有缓存亲和性 查询指定用户的所有缓存亲和性
@@ -157,7 +158,7 @@ async def list_affinities(
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"), limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
offset: int = Query(0, ge=0, description="偏移量"), offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
获取所有缓存亲和性列表,可选按关键词过滤 获取所有缓存亲和性列表,可选按关键词过滤
@@ -173,7 +174,7 @@ async def clear_user_cache(
user_identifier: str, user_identifier: str,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
Clear cache affinity for a specific user Clear cache affinity for a specific user
@@ -188,7 +189,7 @@ async def clear_user_cache(
async def clear_all_cache( async def clear_all_cache(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
Clear all cache affinities Clear all cache affinities
@@ -203,7 +204,7 @@ async def clear_provider_cache(
provider_id: str, provider_id: str,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
Clear cache affinities for a specific provider Clear cache affinities for a specific provider
@@ -218,7 +219,7 @@ async def clear_provider_cache(
async def get_cache_config( async def get_cache_config(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
获取缓存相关配置 获取缓存相关配置
@@ -234,7 +235,7 @@ async def get_cache_config(
async def get_cache_metrics( async def get_cache_metrics(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。 以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
""" """
@@ -246,7 +247,7 @@ async def get_cache_metrics(
class AdminCacheStatsAdapter(AdminApiAdapter): class AdminCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try: try:
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client) scheduler = await get_cache_aware_scheduler(redis_client)
@@ -266,7 +267,7 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
class AdminCacheMetricsAdapter(AdminApiAdapter): class AdminCacheMetricsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
try: try:
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
scheduler = await get_cache_aware_scheduler(redis_client) scheduler = await get_cache_aware_scheduler(redis_client)
@@ -391,7 +392,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
class AdminGetUserAffinityAdapter(AdminApiAdapter): class AdminGetUserAffinityAdapter(AdminApiAdapter):
user_identifier: str user_identifier: str
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db db = context.db
try: try:
user_id = resolve_user_identifier(db, self.user_identifier) user_id = resolve_user_identifier(db, self.user_identifier)
@@ -472,7 +473,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
limit: int limit: int
offset: int offset: int
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db db = context.db
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
if not redis_client: if not redis_client:
@@ -682,7 +683,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
class AdminClearUserCacheAdapter(AdminApiAdapter): class AdminClearUserCacheAdapter(AdminApiAdapter):
user_identifier: str user_identifier: str
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
db = context.db db = context.db
try: try:
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
@@ -786,7 +787,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
class AdminClearAllCacheAdapter(AdminApiAdapter): class AdminClearAllCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try: try:
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client) affinity_mgr = await get_affinity_manager(redis_client)
@@ -806,7 +807,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
class AdminClearProviderCacheAdapter(AdminApiAdapter): class AdminClearProviderCacheAdapter(AdminApiAdapter):
provider_id: str provider_id: str
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
try: try:
redis_client = get_redis_client_sync() redis_client = get_redis_client_sync()
affinity_mgr = await get_affinity_manager(redis_client) affinity_mgr = await get_affinity_manager(redis_client)
@@ -829,7 +830,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter): class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
@@ -878,7 +879,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
async def get_model_mapping_cache_stats( async def get_model_mapping_cache_stats(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
获取模型映射缓存统计信息 获取模型映射缓存统计信息
@@ -895,7 +896,7 @@ async def get_model_mapping_cache_stats(
async def clear_all_model_mapping_cache( async def clear_all_model_mapping_cache(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
清除所有模型映射缓存 清除所有模型映射缓存
@@ -910,7 +911,7 @@ async def clear_model_mapping_cache_by_name(
model_name: str, model_name: str,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ) -> Any:
""" """
清除指定模型名称的映射缓存 清除指定模型名称的映射缓存
@@ -921,8 +922,28 @@ async def clear_model_mapping_cache_by_name(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
async def clear_provider_model_mapping_cache(
provider_id: str,
global_model_id: str,
request: Request,
db: Session = Depends(get_db),
) -> Any:
"""
清除指定 Provider 和 GlobalModel 的模型映射缓存
参数:
- provider_id: Provider ID
- global_model_id: GlobalModel ID
"""
adapter = AdminClearProviderModelMappingCacheAdapter(
provider_id=provider_id, global_model_id=global_model_id
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter): class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
import json import json
from src.clients.redis_client import get_redis_client from src.clients.redis_client import get_redis_client
@@ -955,7 +976,9 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if key_str.startswith("model:id:"): if key_str.startswith("model:id:"):
model_id_keys.append(key_str) model_id_keys.append(key_str)
elif key_str.startswith("model:provider_global:"): elif key_str.startswith("model:provider_global:"):
provider_global_keys.append(key_str) # 过滤掉 hits 统计键,只保留实际的缓存键
if not key_str.startswith("model:provider_global:hits:"):
provider_global_keys.append(key_str)
async for key in redis.scan_iter(match="global_model:*", count=100): async for key in redis.scan_iter(match="global_model:*", count=100):
key_str = key.decode() if isinstance(key, bytes) else key key_str = key.decode() if isinstance(key, bytes) else key
@@ -1067,6 +1090,85 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
# 按 mapping_name 排序 # 按 mapping_name 排序
mappings.sort(key=lambda x: x["mapping_name"]) mappings.sort(key=lambda x: x["mapping_name"])
# 3. 解析 provider_global 缓存Provider 级别的模型解析缓存)
provider_model_mappings = []
# 预加载 Provider 和 GlobalModel 数据
provider_map = {str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()}
global_model_map = {str(gm.id): gm for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()}
for key in provider_global_keys[:100]: # 最多处理 100 个
# key 格式: model:provider_global:{provider_id}:{global_model_id}
try:
parts = key.replace("model:provider_global:", "").split(":")
if len(parts) != 2:
continue
provider_id, global_model_id = parts
cached_value = await redis.get(key)
ttl = await redis.ttl(key)
# 获取命中次数
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
hit_count_raw = await redis.get(hit_count_key)
hit_count = int(hit_count_raw) if hit_count_raw else 0
if cached_value:
cached_str = (
cached_value.decode()
if isinstance(cached_value, bytes)
else cached_value
)
try:
cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name")
provider_model_aliases = cached_data.get("provider_model_aliases", [])
# 获取 Provider 和 GlobalModel 信息
provider = provider_map.get(provider_id)
global_model = global_model_map.get(global_model_id)
if provider and global_model:
# 提取别名名称
alias_names = []
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict) and alias_entry.get("name"):
alias_names.append(alias_entry["name"])
# provider_model_name 为空时跳过
if not provider_model_name:
continue
# 只显示有实际映射的条目:
# 1. 全局模型名 != Provider 模型名(模型名称映射)
# 2. 或者有别名配置
has_name_mapping = global_model.name != provider_model_name
has_aliases = len(alias_names) > 0
if has_name_mapping or has_aliases:
# 构建用于展示的别名列表
# 如果只有名称映射没有别名,则用 global_model_name 作为"请求名称"
display_aliases = alias_names if alias_names else [global_model.name]
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,
"provider_model_name": provider_model_name,
"aliases": display_aliases,
"ttl": ttl if ttl > 0 else None,
"hit_count": hit_count,
})
except json.JSONDecodeError:
pass
except Exception as e:
logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")
# 按 provider_name + global_model_name 排序
provider_model_mappings.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
response_data = { response_data = {
"available": True, "available": True,
"ttl_seconds": CacheTTL.MODEL, "ttl_seconds": CacheTTL.MODEL,
@@ -1079,6 +1181,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
"global_model_resolve": len(global_model_resolve_keys), "global_model_resolve": len(global_model_resolve_keys),
}, },
"mappings": mappings, "mappings": mappings,
"provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
"unmapped": unmapped_entries if unmapped_entries else None, "unmapped": unmapped_entries if unmapped_entries else None,
} }
@@ -1094,7 +1197,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter): class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client from src.clients.redis_client import get_redis_client
try: try:
@@ -1136,7 +1239,7 @@ class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter): class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
model_name: str model_name: str
async def handle(self, context): # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client from src.clients.redis_client import get_redis_client
try: try:
@@ -1176,3 +1279,55 @@ class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
except Exception as exc: except Exception as exc:
logger.exception(f"清除模型映射缓存失败: {exc}") logger.exception(f"清除模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}") raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
provider_id: str
global_model_id: str
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.clients.redis_client import get_redis_client
try:
redis = await get_redis_client(require_redis=False)
if not redis:
raise HTTPException(status_code=503, detail="Redis 未启用")
deleted_keys = []
# 清除 provider_global 缓存
provider_global_key = f"model:provider_global:{self.provider_id}:{self.global_model_id}"
if await redis.exists(provider_global_key):
await redis.delete(provider_global_key)
deleted_keys.append(provider_global_key)
# 清除对应的 hit_count 缓存
hit_count_key = f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}"
if await redis.exists(hit_count_key):
await redis.delete(hit_count_key)
deleted_keys.append(hit_count_key)
logger.info(
f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
)
context.add_audit_metadata(
action="provider_model_mapping_cache_clear",
provider_id=self.provider_id,
global_model_id=self.global_model_id,
deleted_keys=deleted_keys,
)
return {
"status": "ok",
"message": "已清除 Provider 模型映射缓存",
"provider_id": self.provider_id,
"global_model_id": self.global_model_id,
"deleted_keys": deleted_keys,
}
except HTTPException:
raise
except Exception as exc:
logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
raise HTTPException(status_code=500, detail=f"清除失败: {exc}")

View File

@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"actual_cost": actual_cost, "actual_cost": actual_cost,
"rate_multiplier": rate_multiplier, "rate_multiplier": rate_multiplier,
"response_time_ms": usage.response_time_ms, "response_time_ms": usage.response_time_ms,
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage.created_at.isoformat(), "created_at": usage.created_at.isoformat(),
"is_stream": usage.is_stream, "is_stream": usage.is_stream,
"input_price_per_1m": usage.input_price_per_1m, "input_price_per_1m": usage.input_price_per_1m,
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"status_code": usage_record.status_code, "status_code": usage_record.status_code,
"error_message": usage_record.error_message, "error_message": usage_record.error_message,
"response_time_ms": usage_record.response_time_ms, "response_time_ms": usage_record.response_time_ms,
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None, "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
"request_headers": usage_record.request_headers, "request_headers": usage_record.request_headers,
"request_body": usage_record.get_request_body(), "request_body": usage_record.get_request_body(),

View File

@@ -100,6 +100,8 @@ class MessageTelemetry:
cache_read_tokens: int = 0, cache_read_tokens: int = 0,
is_stream: bool = False, is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None, provider_request_headers: Optional[Dict[str, Any]] = None,
# 时间指标
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
# Provider 侧追踪信息(用于记录真实成本) # Provider 侧追踪信息(用于记录真实成本)
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None, provider_endpoint_id: Optional[str] = None,
@@ -133,6 +135,7 @@ class MessageTelemetry:
api_format=api_format, api_format=api_format,
is_stream=is_stream, is_stream=is_stream,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
status_code=status_code, status_code=status_code,
request_headers=request_headers, request_headers=request_headers,
request_body=request_body, request_body=request_body,
@@ -395,3 +398,24 @@ class BaseMessageHandler:
# 创建后台任务,不阻塞当前流 # 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update()) asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈
Args:
message: 错误消息前缀
error: 异常对象
"""
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# 业务异常:简洁日志,不打印堆栈
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:
# 未知异常:完整堆栈
logger.exception(f"{message}: {error}")

View File

@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config from src.config.settings import config
from src.core.exceptions import ( from src.core.exceptions import (
EmbeddedErrorException, EmbeddedErrorException,
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
ctx, ctx,
original_headers, original_headers,
original_request_body, original_request_body,
self.elapsed_ms(), self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
) )
# 创建监控流 # 创建监控流
@@ -378,11 +379,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
return StreamingResponse( return StreamingResponse(
monitored_stream, monitored_stream,
media_type="text/event-stream", media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks, background=background_tasks,
) )
except Exception as e: except Exception as e:
logger.exception(f"流式请求失败: {e}") self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body) await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise raise
@@ -473,12 +475,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
stream_response.raise_for_status() stream_response.raise_for_status()
# 创建行迭代器 # 使用字节流迭代器(避免 aiter_lines 的性能问题)
line_iterator = stream_response.aiter_lines() # aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
byte_iterator = stream_response.aiter_raw()
# 预读检测嵌套错误 # 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error( prefetched_chunks = await stream_processor.prefetch_and_check_error(
line_iterator, byte_iterator,
provider, provider,
endpoint, endpoint,
ctx, ctx,
@@ -503,13 +506,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose() await http_client.aclose()
raise raise
# 创建流生成器 # 创建流生成器(传入字节流迭代器)
return stream_processor.create_response_stream( return stream_processor.create_response_stream(
ctx, ctx,
line_iterator, byte_iterator,
response_ctx, response_ctx,
http_client, http_client,
prefetched_lines, prefetched_chunks,
start_time=self.start_time,
) )
async def _record_stream_failure( async def _record_stream_failure(

View File

@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
""" """
import asyncio import asyncio
import codecs
import json import json
import time import time
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
Callable, Callable,
Dict, Dict,
Optional, Optional,
Tuple,
) )
import httpx import httpx
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
) )
from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
# 直接从具体模块导入,避免循环依赖 # 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import ( from src.api.handlers.base.response_parser import (
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamContext:
"""流式请求的上下文信息"""
# 请求信息
model: str = "unknown" # 用户请求的原始模型名
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
api_format: str = ""
request_id: str = ""
# 用户信息(提前提取避免 Session detached
user_id: int = 0
api_key_id: int = 0
# 统计信息
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0 # cache_read_input_tokens
cache_creation_tokens: int = 0 # cache_creation_input_tokens
collected_text: str = ""
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
parsed_chunks: list = field(default_factory=list)
# 流状态
start_time: float = field(default_factory=time.time)
chunk_count: int = 0
data_count: int = 0
has_completion: bool = False
# 响应信息
status_code: int = 200
response_headers: Dict[str, str] = field(default_factory=dict)
# 请求信息(发送给 Provider 的)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
# Provider 信息
provider_name: Optional[str] = None
provider_id: Optional[str] = None # Provider ID用于记录真实成本
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
error_message: Optional[str] = None
# 格式转换信息
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
client_api_format: str = "" # 客户端请求的 API 格式
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion
response_metadata: Dict[str, Any] = field(default_factory=dict)
class CliMessageHandlerBase(BaseMessageHandler): class CliMessageHandlerBase(BaseMessageHandler):
""" """
CLI 格式消息处理器基类 CLI 格式消息处理器基类
@@ -409,24 +352,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
return StreamingResponse( return StreamingResponse(
monitored_stream, monitored_stream,
media_type="text/event-stream", media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks, background=background_tasks,
) )
except Exception as e: except Exception as e:
# 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈 self._log_request_error("流式请求失败", e)
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# 业务异常:简洁日志
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
else:
# 未知异常:完整堆栈
logger.exception(f"流式请求失败: {e}")
await self._record_stream_failure(ctx, e, original_headers, original_request_body) await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise raise
@@ -446,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.chunk_count = 0 ctx.chunk_count = 0
ctx.data_count = 0 ctx.data_count = 0
ctx.has_completion = False ctx.has_completion = False
ctx.collected_text = "" ctx._collected_text_parts = [] # 重置文本收集
ctx.input_tokens = 0 ctx.input_tokens = 0
ctx.output_tokens = 0 ctx.output_tokens = 0
ctx.cached_tokens = 0 ctx.cached_tokens = 0
@@ -534,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_response.raise_for_status() stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用 # 使用字节流迭代器(避免 aiter_lines 的性能问题
line_iterator = stream_response.aiter_lines() byte_iterator = stream_response.aiter_raw()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误) # 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error( prefetched_chunks = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx byte_iterator, provider, endpoint, ctx
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
@@ -564,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 创建流生成器(带预读数据,使用同一个迭代器) # 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch( return self._create_response_stream_with_prefetch(
ctx, ctx,
line_iterator, byte_iterator,
response_ctx, response_ctx,
http_client, http_client,
prefetched_lines, prefetched_chunks,
) )
async def _create_response_stream( async def _create_response_stream(
@@ -577,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_ctx: Any, response_ctx: Any,
http_client: httpx.AsyncClient, http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器""" """创建响应流生成器(使用字节流)"""
try: try:
sse_parser = SSEEventParser() sse_parser = SSEEventParser()
last_data_time = time.time() last_data_time = time.time()
streaming_status_updated = False streaming_status_updated = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换 # 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx) needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines(): async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming # 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated: if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id) self._update_usage_to_streaming(ctx.request_id)
streaming_status_updated = True streaming_status_updated = True
normalized_line = line.rstrip("\r") buffer += chunk
events = sse_parser.feed_line(normalized_line) # 处理缓冲区中的完整行
while b"\n" in buffer:
if normalized_line == "": line_bytes, buffer = buffer.split(b"\n", 1)
for event in events: try:
self._handle_sse_event( # 使用增量解码器,可以正确处理跨 chunk 的多字节字符
ctx, line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
event.get("event"), except Exception as e:
event.get("data") or "", logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
) )
yield b"\n" continue
continue
ctx.chunk_count += 1 normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
# 空流检测:超过阈值且无数据,发送错误事件并结束 if normalized_line == "":
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0: for event in events:
elapsed = time.time() - last_data_time self._handle_sse_event(
if elapsed > self.DATA_TIMEOUT: ctx,
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据") event.get("event"),
error_event = { event.get("data") or "",
"type": "error", )
"error": { yield b"\n"
"type": "empty_stream_timeout", continue
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传 ctx.chunk_count += 1
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events) # 空流检测:超过阈值且无数据,发送错误事件并结束
if converted_line: if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
yield (converted_line + "\n").encode("utf-8") elapsed = time.time() - last_data_time
else: if elapsed > self.DATA_TIMEOUT:
yield (line + "\n").encode("utf-8") logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events: for event in events:
self._handle_sse_event( self._handle_sse_event(
@@ -702,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async def _prefetch_and_check_embedded_error( async def _prefetch_and_check_embedded_error(
self, self,
line_iterator: Any, byte_iterator: Any,
provider: Provider, provider: Provider,
endpoint: ProviderEndpoint, endpoint: ProviderEndpoint,
ctx: StreamContext, ctx: StreamContext,
@@ -716,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。 同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
Args: Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器 byte_iterator: 字节流迭代器
provider: Provider 对象 provider: Provider 对象
endpoint: Endpoint 对象 endpoint: Endpoint 对象
ctx: 流上下文 ctx: 流上下文
Returns: Returns:
预读的列表(需要在后续流中先输出) 预读的字节块列表(需要在后续流中先输出)
Raises: Raises:
EmbeddedErrorException: 如果检测到嵌套错误 EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误) ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
""" """
prefetched_lines: list = [] prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误 max_prefetch_lines = 5 # 最多预读5行来检测错误
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try: try:
# 获取对应格式的解析器 # 获取对应格式的解析器
@@ -742,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
else: else:
provider_parser = self.parser provider_parser = self.parser
line_count = 0 async for chunk in byte_iterator:
async for line in line_iterator: prefetched_chunks.append(chunk)
prefetched_lines.append(line) buffer += chunk
line_count += 1
# 解析数据 # 尝试按行解析缓冲区
normalized_line = line.rstrip("\r") while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 检测 HTML 响应base_url 配置错误的常见症状) line_count += 1
lower_line = normalized_line.lower() normalized_line = line.rstrip("\r")
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"): # 检测 HTML 响应base_url 配置错误的常见症状)
# 空行或注释行,继续预读 lower_line = normalized_line.lower()
if line_count >= max_prefetch_lines: if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break break
continue
# 尝试解析 SSE 数据 try:
data_str = normalized_line data = json.loads(data_str)
if normalized_line.startswith("data: "): except json.JSONDecodeError:
data_str = normalized_line[6:] # 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
if data_str == "[DONE]": # 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break break
try: if should_stop or line_count >= max_prefetch_lines:
data = json.loads(data_str) break
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException: except EmbeddedErrorException:
# 重新抛出嵌套错误 # 重新抛出嵌套错误
@@ -813,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断 # 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines return prefetched_chunks
async def _create_response_stream_with_prefetch( async def _create_response_stream_with_prefetch(
self, self,
ctx: StreamContext, ctx: StreamContext,
line_iterator: Any, byte_iterator: Any,
response_ctx: Any, response_ctx: Any,
http_client: httpx.AsyncClient, http_client: httpx.AsyncClient,
prefetched_lines: list, prefetched_chunks: list,
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)""" """创建响应流生成器(带预读数据,使用字节流"""
try: try:
sse_parser = SSEEventParser() sse_parser = SSEEventParser()
last_data_time = time.time() last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换 # 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx) needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming # 在第一次输出数据前更新状态为 streaming
if prefetched_lines: if prefetched_chunks:
self._update_usage_to_streaming(ctx.request_id) self._update_usage_to_streaming(ctx.request_id)
# 先处理预读的数据 # 先处理预读的字节块
for line in prefetched_lines: for chunk in prefetched_chunks:
normalized_line = line.rstrip("\r") buffer += chunk
events = sse_parser.feed_line(normalized_line) # 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events: for event in events:
self._handle_sse_event( self._handle_sse_event(
ctx, ctx,
event.get("event"), event.get("event"),
event.get("data") or "", event.get("data") or "",
) )
yield b"\n"
continue
ctx.chunk_count += 1 if ctx.data_count > 0:
last_data_time = time.time()
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# 继续处理剩余的流数据(使用同一个迭代器) # 继续处理剩余的流数据(使用同一个迭代器)
async for line in line_iterator: async for chunk in byte_iterator:
normalized_line = line.rstrip("\r") buffer += chunk
events = sse_parser.feed_line(normalized_line) # 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
if normalized_line == "":
for event in events: for event in events:
self._handle_sse_event( self._handle_sse_event(
ctx, ctx,
event.get("event"), event.get("event"),
event.get("data") or "", event.get("data") or "",
) )
yield b"\n"
continue
ctx.chunk_count += 1 if ctx.data_count > 0:
last_data_time = time.time()
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件 # 处理剩余事件
flushed_events = sse_parser.flush() flushed_events = sse_parser.flush()
@@ -1047,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 提取文本内容 # 提取文本内容
text = self.parser.extract_text_content(data) text = self.parser.extract_text_content(data)
if text: if text:
ctx.collected_text += text ctx.append_text(text)
# 检查完成事件 # 检查完成事件
if event_type in ("response.completed", "message_stop"): if event_type in ("response.completed", "message_stop"):
@@ -1099,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
) -> None: ) -> None:
"""在流完成后记录统计信息""" """在流完成后记录统计信息"""
try: try:
await asyncio.sleep(0.1) # 使用 self.start_time 作为时间基准,与首字时间保持一致
# 注意:不要把统计延迟算进响应时间里
response_time_ms = int((time.time() - self.start_time) * 1000)
response_time_ms = int((time.time() - ctx.start_time) * 1000) await asyncio.sleep(0.1)
if not ctx.provider_name: if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商") logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
@@ -1181,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
input_tokens=actual_input_tokens, input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens, output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code, status_code=ctx.status_code,
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
@@ -1201,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
response_metadata=ctx.response_metadata if ctx.response_metadata else None, response_metadata=ctx.response_metadata if ctx.response_metadata else None,
) )
logger.debug(f"{self.FORMAT_ID} 流式响应完成") logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要 # 简洁的请求完成摘要(两行格式)
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | " line1 = (
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}") f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
)
logger.info(f"{line1}\n{line2}")
# 更新候选记录的最终状态和延迟时间 # 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间) # 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
@@ -1255,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body: Dict[str, Any], original_request_body: Dict[str, Any],
) -> None: ) -> None:
"""记录流式请求失败""" """记录流式请求失败"""
response_time_ms = int((time.time() - ctx.start_time) * 1000) # 使用 self.start_time 作为时间基准,与首字时间保持一致
response_time_ms = int((time.time() - self.start_time) * 1000)
status_code = 503 status_code = 503
if isinstance(error, ProviderAuthException): if isinstance(error, ProviderAuthException):

View File

@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
ResponseParser, ResponseParser,
StreamStats, StreamStats,
) )
from src.api.handlers.base.utils import extract_cache_creation_tokens
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]: def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
usage = response.get("usage", {}) usage = response.get("usage", {})
result.input_tokens = usage.get("input_tokens", 0) result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0) result.output_tokens = usage.get("output_tokens", 0)
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0) result.cache_creation_tokens = extract_cache_creation_tokens(usage)
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0) result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
# 检查错误(支持嵌套错误格式) # 检查错误(支持嵌套错误格式)
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
return result return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]: def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
# 对于 message_start 事件usage 在 message.usage 路径下
# 对于其他响应usage 在顶层
usage = response.get("usage", {}) usage = response.get("usage", {})
if not usage and "message" in response:
usage = response.get("message", {}).get("usage", {})
return { return {
"input_tokens": usage.get("input_tokens", 0), "input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0), "output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), "cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0), "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
} }

View File

@@ -8,6 +8,7 @@
- 请求/响应数据 - 请求/响应数据
""" """
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@@ -25,12 +26,18 @@ class StreamContext:
model: str model: str
api_format: str api_format: str
# 请求标识信息CLI handler 需要)
request_id: str = ""
user_id: int = 0
api_key_id: int = 0
# Provider 信息(在请求执行时填充) # Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None provider_name: Optional[str] = None
provider_id: Optional[str] = None provider_id: Optional[str] = None
endpoint_id: Optional[str] = None endpoint_id: Optional[str] = None
key_id: Optional[str] = None key_id: Optional[str] = None
attempt_id: Optional[str] = None attempt_id: Optional[str] = None
attempt_synced: bool = False
provider_api_format: Optional[str] = None # Provider 的响应格式 provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射 # 模型映射
@@ -43,7 +50,14 @@ class StreamContext:
cache_creation_tokens: int = 0 cache_creation_tokens: int = 0
# 响应内容 # 响应内容
collected_text: str = "" _collected_text_parts: List[str] = field(default_factory=list, repr=False)
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
# 时间指标
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
start_time: float = field(default_factory=time.time)
# 响应状态 # 响应状态
status_code: int = 200 status_code: int = 200
@@ -55,6 +69,12 @@ class StreamContext:
provider_request_headers: Dict[str, str] = field(default_factory=dict) provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None provider_request_body: Optional[Dict[str, Any]] = None
# 格式转换信息CLI handler 需要)
client_api_format: str = ""
# Provider 响应元数据CLI handler 需要)
response_metadata: Dict[str, Any] = field(default_factory=dict)
# 流式处理统计 # 流式处理统计
data_count: int = 0 data_count: int = 0
chunk_count: int = 0 chunk_count: int = 0
@@ -71,16 +91,30 @@ class StreamContext:
self.chunk_count = 0 self.chunk_count = 0
self.data_count = 0 self.data_count = 0
self.has_completion = False self.has_completion = False
self.collected_text = "" self._collected_text_parts = []
self.input_tokens = 0 self.input_tokens = 0
self.output_tokens = 0 self.output_tokens = 0
self.cached_tokens = 0 self.cached_tokens = 0
self.cache_creation_tokens = 0 self.cache_creation_tokens = 0
self.error_message = None self.error_message = None
self.status_code = 200 self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {} self.response_headers = {}
self.provider_request_headers = {} self.provider_request_headers = {}
self.provider_request_body = None self.provider_request_body = None
self.response_id = None
self.final_usage = None
self.final_response = None
@property
def collected_text(self) -> str:
"""已收集的文本内容(按需拼接,避免在流式过程中频繁做字符串拷贝)"""
return "".join(self._collected_text_parts)
def append_text(self, text: str) -> None:
"""追加文本内容(仅在需要收集文本时调用)"""
if text:
self._collected_text_parts.append(text)
def update_provider_info( def update_provider_info(
self, self,
@@ -104,14 +138,40 @@ class StreamContext:
cached_tokens: Optional[int] = None, cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None, cache_creation_tokens: Optional[int] = None,
) -> None: ) -> None:
"""更新 Token 使用统计""" """
if input_tokens is not None: 更新 Token 使用统计
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
设计原理:
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
- 后续事件可能会提供完整的统计数据
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
示例场景:
- message_start 事件input_tokens=100, output_tokens=0
- message_delta 事件input_tokens=0, output_tokens=50
- 最终结果input_tokens=100, output_tokens=50
注意事项:
- 此策略假设初始值为 0 是正确的默认状态
- 如果需要将已有值重置为 0请直接修改实例属性不使用此方法
Args:
input_tokens: 输入 tokens 数量
output_tokens: 输出 tokens 数量
cached_tokens: 缓存命中 tokens 数量
cache_creation_tokens: 缓存创建 tokens 数量
"""
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
self.input_tokens = input_tokens self.input_tokens = input_tokens
if output_tokens is not None: if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
self.output_tokens = output_tokens self.output_tokens = output_tokens
if cached_tokens is not None: if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
self.cached_tokens = cached_tokens self.cached_tokens = cached_tokens
if cache_creation_tokens is not None: if cache_creation_tokens is not None and (
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
):
self.cache_creation_tokens = cache_creation_tokens self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None: def mark_failed(self, status_code: int, error_message: str) -> None:
@@ -119,6 +179,19 @@ class StreamContext:
self.status_code = status_code self.status_code = status_code
self.error_message = error_message self.error_message = error_message
def record_first_byte_time(self, start_time: float) -> None:
"""
记录首字时间 (TTFB - Time To First Byte)
应在第一次向客户端发送数据时调用。
如果已记录过,则不会覆盖(避免重试时重复记录)。
Args:
start_time: 请求开始时间 (time.time())
"""
if self.first_byte_time_ms is None:
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
def is_success(self) -> bool: def is_success(self) -> bool:
"""检查请求是否成功""" """检查请求是否成功"""
return self.status_code < 400 return self.status_code < 400
@@ -145,10 +218,22 @@ class StreamContext:
获取日志摘要 获取日志摘要
用于请求完成/失败时的日志输出。 用于请求完成/失败时的日志输出。
包含首字时间 (TTFB) 和总响应时间,分两行显示。
""" """
status = "OK" if self.is_success() else "FAIL" status = "OK" if self.is_success() else "FAIL"
return (
# 第一行:基本信息 + 首字时间
line1 = (
f"[{status}] {request_id[:8]} | {self.model} | " f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | " f"{self.provider_name or 'unknown'}"
)
if self.first_byte_time_ms is not None:
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
# 第二行:总响应时间 + tokens
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}" f"in:{self.input_tokens} out:{self.output_tokens}"
) )
return f"{line1}\n{line2}"

View File

@@ -9,7 +9,9 @@
""" """
import asyncio import asyncio
import codecs
import json import json
import time
from typing import Any, AsyncGenerator, Callable, Optional from typing import Any, AsyncGenerator, Callable, Optional
import httpx import httpx
@@ -36,6 +38,8 @@ class StreamProcessor:
request_id: str, request_id: str,
default_parser: ResponseParser, default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None, on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
): ):
""" """
初始化流处理器 初始化流处理器
@@ -48,6 +52,7 @@ class StreamProcessor:
self.request_id = request_id self.request_id = request_id
self.default_parser = default_parser self.default_parser = default_parser
self.on_streaming_start = on_streaming_start self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser: def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
""" """
@@ -112,9 +117,10 @@ class StreamProcessor:
) )
# 提取文本 # 提取文本
text = parser.extract_text_content(data) if self.collect_text:
if text: text = parser.extract_text_content(data)
ctx.collected_text += text if text:
ctx.append_text(text)
# 检查完成 # 检查完成
event_type = event_name or data.get("type", "") event_type = event_name or data.get("type", "")
@@ -123,7 +129,7 @@ class StreamProcessor:
async def prefetch_and_check_error( async def prefetch_and_check_error(
self, self,
line_iterator: Any, byte_iterator: Any,
provider: Provider, provider: Provider,
endpoint: ProviderEndpoint, endpoint: ProviderEndpoint,
ctx: StreamContext, ctx: StreamContext,
@@ -136,97 +142,126 @@ class StreamProcessor:
这种情况需要在流开始输出之前检测,以便触发重试逻辑。 这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args: Args:
line_iterator: 迭代器 byte_iterator: 字节流迭代器
provider: Provider 对象 provider: Provider 对象
endpoint: Endpoint 对象 endpoint: Endpoint 对象
ctx: 流式上下文 ctx: 流式上下文
max_prefetch_lines: 最多预读行数 max_prefetch_lines: 最多预读行数
Returns: Returns:
预读的列表 预读的字节块列表
Raises: Raises:
EmbeddedErrorException: 如果检测到嵌套错误 EmbeddedErrorException: 如果检测到嵌套错误
""" """
prefetched_lines: list = [] prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx) parser = self.get_parser_for_provider(ctx)
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try: try:
line_count = 0 async for chunk in byte_iterator:
async for line in line_iterator: prefetched_chunks.append(chunk)
prefetched_lines.append(line) buffer += chunk
line_count += 1
normalized_line = line.rstrip("\r") # 尝试按行解析缓冲区
if not normalized_line or normalized_line.startswith(":"): while b"\n" in buffer:
if line_count >= max_prefetch_lines: line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
# 尝试解析 SSE 数据
data_str = line
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
should_stop = True
break break
continue
# 尝试解析 SSE 数据 try:
data_str = normalized_line data = json.loads(data_str)
if normalized_line.startswith("data: "): except json.JSONDecodeError:
data_str = normalized_line[6:] if line_count >= max_prefetch_lines:
should_stop = True
break
continue
if data_str == "[DONE]": # 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break break
try: if should_stop or line_count >= max_prefetch_lines:
data = json.loads(data_str) break
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException: except EmbeddedErrorException:
raise raise
except Exception as e: except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines return prefetched_chunks
async def create_response_stream( async def create_response_stream(
self, self,
ctx: StreamContext, ctx: StreamContext,
line_iterator: Any, byte_iterator: Any,
response_ctx: Any, response_ctx: Any,
http_client: httpx.AsyncClient, http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None, prefetched_chunks: Optional[list] = None,
*,
start_time: Optional[float] = None,
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
""" """
创建响应流生成器 创建响应流生成器
统一的流生成器,支持预读数据和不带预读数据两种情况 从字节流中解析 SSE 数据并转发,支持预读数据。
Args: Args:
ctx: 流式上下文 ctx: 流式上下文
line_iterator: 迭代器 byte_iterator: 字节流迭代器
response_ctx: HTTP 响应上下文管理器 response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端 http_client: HTTP 客户端
prefetched_lines: 预读的列表(可选) prefetched_chunks: 预读的字节块列表(可选)
start_time: 请求开始时间,用于计算 TTFB可选
Yields: Yields:
编码后的响应数据块 编码后的响应数据块
@@ -234,25 +269,82 @@ class StreamProcessor:
try: try:
sse_parser = SSEEventParser() sse_parser = SSEEventParser()
streaming_started = False streaming_started = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 处理预读数据 # 处理预读数据
if prefetched_lines: if prefetched_chunks:
if not streaming_started and self.on_streaming_start: if not streaming_started and self.on_streaming_start:
self.on_streaming_start() self.on_streaming_start()
streaming_started = True streaming_started = True
for line in prefetched_lines: for chunk in prefetched_chunks:
for chunk in self._process_line(ctx, sse_parser, line): # 记录首字时间 (TTFB) - 在 yield 之前记录
yield chunk if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 把原始数据转发给客户端
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的流数据 # 处理剩余的流数据
async for line in line_iterator: async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start: if not streaming_started and self.on_streaming_start:
self.on_streaming_start() self.on_streaming_start()
streaming_started = True streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line): # 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
yield chunk if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 原始数据透传
yield chunk
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
# 解码失败,记录警告但继续处理
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 处理剩余的缓冲区数据(如果有未完成的行)
if buffer:
try:
# 使用 final=True 处理最后的不完整字符
line = decoder.decode(buffer, True)
self._process_line(ctx, sse_parser, line)
except Exception as e:
logger.warning(
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
f"bytes={buffer[:50]!r}"
)
# 处理剩余事件 # 处理剩余事件
for event in sse_parser.flush(): for event in sse_parser.flush():
@@ -268,7 +360,7 @@ class StreamProcessor:
ctx: StreamContext, ctx: StreamContext,
sse_parser: SSEEventParser, sse_parser: SSEEventParser,
line: str, line: str,
) -> list[bytes]: ) -> None:
""" """
处理单行数据 处理单行数据
@@ -276,26 +368,17 @@ class StreamProcessor:
ctx: 流式上下文 ctx: 流式上下文
sse_parser: SSE 解析器 sse_parser: SSE 解析器
line: 原始行数据 line: 原始行数据
Returns:
要发送的数据块列表
""" """
result: list[bytes] = [] # SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF
normalized_line = line.rstrip("\r") # 避免把空行误判成 "\n" 并导致事件边界解析错误。
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line) events = sse_parser.feed_line(normalized_line)
if normalized_line == "": if normalized_line != "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
ctx.chunk_count += 1 ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events: for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "") self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
async def create_monitored_stream( async def create_monitored_stream(
self, self,
@@ -317,16 +400,26 @@ class StreamProcessor:
响应数据块 响应数据块
""" """
try: try:
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
next_disconnect_check_at = 0.0
disconnect_check_interval_s = 0.25
async for chunk in stream_generator: async for chunk in stream_generator:
if await is_disconnected(): now = time.monotonic()
logger.warning(f"ID:{self.request_id} | Client disconnected") if now >= next_disconnect_check_at:
ctx.status_code = 499 # Client Closed Request next_disconnect_check_at = now + disconnect_check_interval_s
ctx.error_message = "client_disconnected" if await is_disconnected():
break logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk yield chunk
except asyncio.CancelledError: except asyncio.CancelledError:
ctx.status_code = 499 ctx.status_code = 499
ctx.error_message = "client_disconnected" ctx.error_message = "client_disconnected"
raise raise
except Exception as e: except Exception as e:
ctx.status_code = 500 ctx.status_code = 500

View File

@@ -8,6 +8,7 @@
""" """
import asyncio import asyncio
import time
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
ctx: StreamContext, ctx: StreamContext,
original_headers: Dict[str, str], original_headers: Dict[str, str],
original_request_body: Dict[str, Any], original_request_body: Dict[str, Any],
response_time_ms: int, start_time: float,
) -> None: ) -> None:
""" """
记录流式统计信息 记录流式统计信息
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
ctx: 流式上下文 ctx: 流式上下文
original_headers: 原始请求头 original_headers: 原始请求头
original_request_body: 原始请求体 original_request_body: 原始请求体
response_time_ms: 响应时间(毫秒) start_time: 请求开始时间 (time.time())
""" """
bg_db = None bg_db = None
try: try:
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
# 注意不要把统计延迟stream_stats_delay算进响应时间里
response_time_ms = int((time.time() - start_time) * 1000)
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭 await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
if not ctx.provider_name: if not ctx.provider_name:
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
input_tokens=ctx.input_tokens, input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens, output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code, status_code=ctx.status_code,
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,

View File

@@ -0,0 +1,55 @@
"""
Handler 基础工具函数
"""
from typing import Any, Dict, Optional
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
"""
提取缓存创建 tokens兼容新旧格式
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens
- 新格式2024年后使用 claude_cache_creation_5_m_tokens 和
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
此函数自动检测并适配两种格式,优先使用新格式。
Args:
usage: API 响应中的 usage 字典
Returns:
缓存创建 tokens 总数
"""
# 检查新格式字段是否存在(而非值是否为 0
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
has_new_format = (
"claude_cache_creation_5_m_tokens" in usage
or "claude_cache_creation_1_h_tokens" in usage
)
if has_new_format:
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
return int(cache_5m) + int(cache_1h)
# 回退到旧格式
return int(usage.get("cache_creation_input_tokens", 0))
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
"""
构建 SSEtext/event-stream推荐响应头用于减少代理缓冲带来的卡顿/成段输出。
说明:
- Cache-Control: no-transform 可避免部分代理对流做压缩/改写导致缓冲
- X-Accel-Buffering: no 可显式提示 Nginx 关闭缓冲(即使全局已关闭也无害)
"""
headers: Dict[str, str] = {
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
}
if extra_headers:
headers.update(extra_headers)
return headers

View File

@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeChatHandler(ChatHandlerBase): class ClaudeChatHandler(ChatHandlerBase):
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
result["model"] = mapped_model result["model"] = mapped_model
return result return result
async def _convert_request(self, request): async def _convert_request(self, request: Any) -> Any:
""" """
将请求转换为 Claude 格式 将请求转换为 Claude 格式
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
Claude 格式使用: Claude 格式使用:
- input_tokens / output_tokens - input_tokens / output_tokens
- cache_creation_input_tokens / cache_read_input_tokens - cache_creation_input_tokens / cache_read_input_tokens
- 新格式claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
""" """
usage = response.get("usage", {}) usage = response.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
# 处理新的 cache_creation 格式
if "cache_creation" in usage:
cache_creation_data = usage.get("cache_creation", {})
if not cache_creation_input_tokens:
cache_creation_input_tokens = cache_creation_data.get(
"ephemeral_5m_input_tokens", 0
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
return { return {
"input_tokens": input_tokens, "input_tokens": usage.get("input_tokens", 0),
"output_tokens": output_tokens, "output_tokens": usage.get("output_tokens", 0),
"cache_creation_input_tokens": cache_creation_input_tokens, "cache_creation_input_tokens": extract_cache_creation_tokens(usage),
"cache_read_input_tokens": cache_read_input_tokens, "cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
} }
def _normalize_response(self, response: Dict) -> Dict: def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
""" """
规范化 Claude 响应 规范化 Claude 响应
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
规范化后的响应 规范化后的响应
""" """
if self.response_normalizer and self.response_normalizer.should_normalize(response): if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_claude_response( result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response, response_data=response,
request_id=self.request_id, request_id=self.request_id,
) )
return result
return response return response

View File

@@ -9,6 +9,8 @@ from __future__ import annotations
import json import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeStreamParser: class ClaudeStreamParser:
""" """
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
return { return {
"input_tokens": usage.get("input_tokens", 0), "input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0), "output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), "cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0), "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
} }
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
return { return {
"input_tokens": usage.get("input_tokens", 0), "input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0), "output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), "cache_creation_tokens": extract_cache_creation_tokens(usage),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0), "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
} }

View File

@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase, CliMessageHandlerBase,
StreamContext, StreamContext,
) )
from src.api.handlers.base.utils import extract_cache_creation_tokens
class ClaudeCliMessageHandler(CliMessageHandlerBase): class ClaudeCliMessageHandler(CliMessageHandlerBase):
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
usage = message.get("usage", {}) usage = message.get("usage", {})
if usage: if usage:
ctx.input_tokens = usage.get("input_tokens", 0) ctx.input_tokens = usage.get("input_tokens", 0)
# Claude 的缓存 tokens 使用不同的字段名
cache_read = usage.get("cache_read_input_tokens", 0) cache_read = usage.get("cache_read_input_tokens", 0)
if cache_read: if cache_read:
ctx.cached_tokens = cache_read ctx.cached_tokens = cache_read
cache_creation = usage.get("cache_creation_input_tokens", 0)
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation: if cache_creation:
ctx.cache_creation_tokens = cache_creation ctx.cache_creation_tokens = cache_creation
@@ -109,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
if delta.get("type") == "text_delta": if delta.get("type") == "text_delta":
text = delta.get("text", "") text = delta.get("text", "")
if text: if text:
ctx.collected_text += text ctx.append_text(text)
# 处理消息增量(包含最终 usage # 处理消息增量(包含最终 usage
elif event_type == "message_delta": elif event_type == "message_delta":
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
ctx.input_tokens = usage["input_tokens"] ctx.input_tokens = usage["input_tokens"]
if "output_tokens" in usage: if "output_tokens" in usage:
ctx.output_tokens = usage["output_tokens"] ctx.output_tokens = usage["output_tokens"]
# 更新缓存 tokens
# 更新缓存读取 tokens
if "cache_read_input_tokens" in usage: if "cache_read_input_tokens" in usage:
ctx.cached_tokens = usage["cache_read_input_tokens"] ctx.cached_tokens = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"] # 更新缓存创建 tokens
cache_creation = extract_cache_creation_tokens(usage)
if cache_creation > 0:
ctx.cache_creation_tokens = cache_creation
# 检查是否结束 # 检查是否结束
delta = data.get("delta", {}) delta = data.get("delta", {})

View File

@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
parts = content.get("parts", []) parts = content.get("parts", [])
for part in parts: for part in parts:
if "text" in part: if "text" in part:
ctx.collected_text += part["text"] ctx.append_text(part["text"])
# 检查结束原因 # 检查结束原因
finish_reason = candidate.get("finishReason") finish_reason = candidate.get("finishReason")

View File

@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if event_type in ["response.output_text.delta", "response.outtext.delta"]: if event_type in ["response.output_text.delta", "response.outtext.delta"]:
delta = data.get("delta") delta = data.get("delta")
if isinstance(delta, str): if isinstance(delta, str):
ctx.collected_text += delta ctx.append_text(delta)
elif isinstance(delta, dict) and "text" in delta: elif isinstance(delta, dict) and "text" in delta:
ctx.collected_text += delta["text"] ctx.append_text(delta["text"])
# 处理完成事件 # 处理完成事件
elif event_type == "response.completed": elif event_type == "response.completed":
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
if content_item.get("type") == "output_text": if content_item.get("type") == "output_text":
text = content_item.get("text", "") text = content_item.get("text", "")
if text: if text:
ctx.collected_text += text ctx.append_text(text)
# 备用:从顶层 usage 提取 # 备用:从顶层 usage 提取
usage_obj = data.get("usage") usage_obj = data.get("usage")

View File

@@ -120,6 +120,33 @@ class CacheService:
logger.warning(f"缓存检查失败: {key} - {e}") logger.warning(f"缓存检查失败: {key} - {e}")
return False return False
@staticmethod
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
"""
递增缓存值
Args:
key: 缓存键
ttl_seconds: 可选,如果提供则刷新 TTL
Returns:
递增后的值,如果失败返回 0
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return 0
result = await redis.incr(key)
# 如果提供了 TTL刷新过期时间
if ttl_seconds is not None:
await redis.expire(key, ttl_seconds)
return result
except Exception as e:
logger.warning(f"缓存递增失败: {key} - {e}")
return 0
# 缓存键前缀 # 缓存键前缀
class CacheKeys: class CacheKeys:

View File

@@ -307,7 +307,8 @@ class Usage(Base):
is_stream = Column(Boolean, default=False) # 是否为流式请求 is_stream = Column(Boolean, default=False) # 是否为流式请求
status_code = Column(Integer) status_code = Column(Integer)
error_message = Column(Text, nullable=True) error_message = Column(Text, nullable=True)
response_time_ms = Column(Integer) # 响应时间(毫秒) response_time_ms = Column(Integer) # 响应时间(毫秒)
first_byte_time_ms = Column(Integer, nullable=True) # 首字时间/TTFB毫秒
# 请求状态追踪 # 请求状态追踪
# pending: 请求开始处理中 # pending: 请求开始处理中

View File

@@ -2,11 +2,9 @@
Model 映射缓存服务 - 减少模型查询 Model 映射缓存服务 - 减少模型查询
""" """
import json
import time import time
from typing import Optional from typing import List, Optional
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.config.constants import CacheTTL from src.config.constants import CacheTTL
@@ -106,6 +104,7 @@ class ModelCacheService:
Model 对象或 None Model 对象或 None
""" """
cache_key = f"model:provider_global:{provider_id}:{global_model_id}" cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取 # 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key) cached_data = await CacheService.get(cache_key)
@@ -113,6 +112,8 @@ class ModelCacheService:
logger.debug( logger.debug(
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
) )
# 递增命中计数,同时刷新 TTL
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
return ModelCacheService._dict_to_model(cached_data) return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库 # 2. 缓存未命中,查询数据库
@@ -130,6 +131,8 @@ class ModelCacheService:
if model: if model:
model_dict = ModelCacheService._model_to_dict(model) model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL) await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
# 重置命中计数新缓存从1开始
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug( logger.debug(
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
) )
@@ -189,9 +192,10 @@ class ModelCacheService:
# 清除 model:id 缓存 # 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}") await CacheService.delete(f"model:id:{model_id}")
# 清除 provider_global 缓存(如果提供了必要参数) # 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
if provider_id and global_model_id: if provider_id and global_model_id:
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}") await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
logger.debug( logger.debug(
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..." f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
) )
@@ -230,16 +234,20 @@ class ModelCacheService:
db: Session, model_name: str db: Session, model_name: str
) -> Optional[GlobalModel]: ) -> Optional[GlobalModel]:
""" """
通过名称或映射解析 GlobalModel带缓存,支持映射匹配 通过名称解析 GlobalModel带缓存
查找顺序: 查找顺序:
1. 检查缓存 1. 检查缓存
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases 2. 通过 provider_model_name 匹配(查询 Model 表
3. 直接匹配 GlobalModel.name兜底 3. 直接匹配 GlobalModel.name兜底
注意:此方法不使用 provider_model_aliases 进行全局解析。
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
由 resolve_provider_model() 处理。
Args: Args:
db: 数据库会话 db: 数据库会话
model_name: 模型名称(可以是 GlobalModel.name 或映射名称 model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name
Returns: Returns:
GlobalModel 对象或 None GlobalModel 对象或 None
@@ -273,116 +281,53 @@ class ModelCacheService:
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}") logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data) return ModelCacheService._dict_to_global_model(cached_data)
# 2. 优先通过 provider_model_name 和映射名称匹配Provider 配置优先级最高 # 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases
from sqlalchemy import or_ # 重要provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 别名配置的影响
# 例如Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
from src.models.database import Provider from src.models.database import Provider
# 构建精确的映射匹配条件 models_with_global = (
# 注意provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符 db.query(Model, GlobalModel)
# 对于 SQLite会在 Python 层面进行过滤 .join(Provider, Model.provider_id == Provider.id)
try: .join(GlobalModel, Model.global_model_id == GlobalModel.id)
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效) .filter(
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入 Provider.is_active == True,
jsonb_pattern = json.dumps([{"name": normalized_name}]) Model.is_active == True,
models_with_global = ( GlobalModel.is_active == True,
db.query(Model, GlobalModel) Model.provider_model_name == normalized_name,
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
or_(
Model.provider_model_name == normalized_name,
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
Model.provider_model_aliases.op("@>")(jsonb_pattern),
),
)
.all()
)
except (OperationalError, ProgrammingError) as e:
# JSONB 操作符不支持(如 SQLite回退到加载匹配 provider_model_name 的 Model
# 并在 Python 层过滤 aliases
logger.debug(
f"JSONB 查询失败,回退到 Python 过滤: {e}",
)
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
)
.all()
) )
.all()
)
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)} # 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
# 使用字典去重,同一个 Model 只保留优先级最高的匹配 matched_global_models: List[GlobalModel] = []
matched_models_dict = {} seen_global_model_ids: set[str] = set()
# 遍历查询结果进行匹配
for model, gm in models_with_global: for model, gm in models_with_global:
key = (model.id, gm.id) if gm.id not in seen_global_model_ids:
seen_global_model_ids.add(gm.id)
# 检查 provider_model_aliases 是否匹配(优先级更高) matched_global_models.append(gm)
if model.provider_model_aliases: logger.debug(
for alias_entry in model.provider_model_aliases: f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
if isinstance(alias_entry, dict): f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
alias_name = alias_entry.get("name", "").strip()
if alias_name == normalized_name:
# alias 优先级为 0最高覆盖任何已存在的匹配
matched_models_dict[key] = (gm, "alias", 0)
logger.debug(
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
break
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
if model.provider_model_name == normalized_name:
# provider_model_name 优先级为 1兜底只在没有 alias 匹配时使用
if key not in matched_models_dict:
matched_models_dict[key] = (gm, "provider_model_name", 1)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
# 如果通过 provider_model_name/alias 找到了,直接返回
if matched_models_dict:
# 转换为列表并排序:按 priorityalias=0 优先)、然后按 GlobalModel.name
matched_global_models = [
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
]
matched_global_models.sort(
key=lambda item: (
0 if item[1] == "alias" else 1, # alias 优先
item[0].name # 同优先级按名称排序(确定性)
) )
)
# 记录解析方式 # 如果通过 provider_model_name 找到了,返回
resolution_method = matched_global_models[0][1] if matched_global_models:
resolution_method = "provider_model_name"
if len(matched_global_models) > 1: if len(matched_global_models) > 1:
# 检测到冲突 # 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name
unique_models = {gm.id: gm for gm, _ in matched_global_models} model_names = [gm.name for gm in matched_global_models if gm.name]
if len(unique_models) > 1: logger.warning(
model_names = [gm.name for gm in unique_models.values()] f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
logger.warning( f"{', '.join(model_names)},使用第一个匹配结果"
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: " )
f"{', '.join(model_names)},使用第一个匹配结果" # 记录冲突指标
) model_mapping_conflict_total.inc()
# 记录冲突指标
model_mapping_conflict_total.inc()
# 返回第一个匹配的 GlobalModel # 返回第一个匹配的 GlobalModel
result_global_model: GlobalModel = matched_global_models[0][0] result_global_model = matched_global_models[0]
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model) global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
await CacheService.set( await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL

View File

@@ -6,7 +6,7 @@
- 根据 API 格式或端点配置生成请求 URL - 根据 API 格式或端点配置生成请求 URL
""" """
from typing import Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlencode from urllib.parse import urlencode
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
from src.core.enums import APIFormat from src.core.enums import APIFormat
from src.core.logger import logger from src.core.logger import logger
if TYPE_CHECKING:
from src.models.database import ProviderAPIKey, ProviderEndpoint
def build_provider_headers( def build_provider_headers(
endpoint, endpoint: "ProviderEndpoint",
key, key: "ProviderAPIKey",
original_headers: Optional[Dict[str, str]] = None, original_headers: Optional[Dict[str, str]] = None,
*, *,
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
@@ -28,7 +31,8 @@ def build_provider_headers(
""" """
headers: Dict[str, str] = {} headers: Dict[str, str] = {}
decrypted_key = crypto_service.decrypt(key.api_key) # api_key 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
# 根据 API 格式自动选择认证头 # 根据 API 格式自动选择认证头
api_format = getattr(endpoint, "api_format", None) api_format = getattr(endpoint, "api_format", None)
@@ -68,8 +72,32 @@ def build_provider_headers(
return headers return headers
def _normalize_base_url(base_url: str, path: str) -> str:
"""
规范化 base_url去除末尾的斜杠和可能与 path 重复的版本前缀。
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
避免拼接出 /v1/v1/messages 这样的重复路径。
兼容用户填写的各种格式:
- https://api.example.com
- https://api.example.com/
- https://api.example.com/v1
- https://api.example.com/v1/
"""
base = base_url.rstrip("/")
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
# 例如base="/v1", path="/v1/messages" -> 去除 /v1
# 例如base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
if base.endswith(suffix) and path.startswith(suffix):
base = base[: -len(suffix)]
break
return base
def build_provider_url( def build_provider_url(
endpoint, endpoint: "ProviderEndpoint",
*, *,
query_params: Optional[Dict[str, Any]] = None, query_params: Optional[Dict[str, Any]] = None,
path_params: Optional[Dict[str, Any]] = None, path_params: Optional[Dict[str, Any]] = None,
@@ -88,8 +116,6 @@ def build_provider_url(
path_params: 路径模板参数 (如 {model}) path_params: 路径模板参数 (如 {model})
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法 is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
""" """
base = endpoint.base_url.rstrip("/")
# 准备路径参数,添加 Gemini API 所需的 action 参数 # 准备路径参数,添加 Gemini API 所需的 action 参数
effective_path_params = dict(path_params) if path_params else {} effective_path_params = dict(path_params) if path_params else {}
@@ -123,6 +149,9 @@ def build_provider_url(
if not path.startswith("/"): if not path.startswith("/"):
path = f"/{path}" path = f"/{path}"
# 先确定 path再根据 path 规范化 base_url
# base_url 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
url = f"{base}{path}" url = f"{base}{path}"
# 添加查询参数 # 添加查询参数
@@ -134,7 +163,7 @@ def build_provider_url(
return url return url
def _resolve_default_path(api_format) -> str: def _resolve_default_path(api_format: Optional[str]) -> str:
""" """
根据 API 格式返回默认路径 根据 API 格式返回默认路径
""" """

View File

@@ -157,6 +157,7 @@ class UsageService:
api_format: Optional[str] = None, api_format: Optional[str] = None,
is_stream: bool = False, is_stream: bool = False,
response_time_ms: Optional[int] = None, response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200, status_code: int = 200,
error_message: Optional[str] = None, error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
@@ -368,6 +369,7 @@ class UsageService:
status_code=status_code, status_code=status_code,
error_message=error_message, error_message=error_message,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪 status=status, # 请求状态追踪
request_metadata=metadata, request_metadata=metadata,
request_headers=processed_request_headers, request_headers=processed_request_headers,
@@ -419,6 +421,7 @@ class UsageService:
api_format: Optional[str] = None, api_format: Optional[str] = None,
is_stream: bool = False, is_stream: bool = False,
response_time_ms: Optional[int] = None, response_time_ms: Optional[int] = None,
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
status_code: int = 200, status_code: int = 200,
error_message: Optional[str] = None, error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
@@ -629,6 +632,7 @@ class UsageService:
status_code=status_code, status_code=status_code,
error_message=error_message, error_message=error_message,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
status=status, # 请求状态追踪 status=status, # 请求状态追踪
request_metadata=metadata, request_metadata=metadata,
request_headers=processed_request_headers, request_headers=processed_request_headers,
@@ -649,6 +653,7 @@ class UsageService:
existing_usage.status_code = status_code existing_usage.status_code = status_code
existing_usage.error_message = error_message existing_usage.error_message = error_message
existing_usage.response_time_ms = response_time_ms existing_usage.response_time_ms = response_time_ms
existing_usage.first_byte_time_ms = first_byte_time_ms # 更新首字时间
# 更新请求头和请求体(如果有新值) # 更新请求头和请求体(如果有新值)
if processed_request_headers is not None: if processed_request_headers is not None:
existing_usage.request_headers = processed_request_headers existing_usage.request_headers = processed_request_headers
@@ -1315,11 +1320,11 @@ class UsageService:
default_timeout_seconds: int = 300, default_timeout_seconds: int = 300,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending 请求 获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
与 get_active_requests 不同,此方法: 与 get_active_requests 不同,此方法:
1. 返回轻量级的状态字典而非完整 Usage 对象 1. 返回轻量级的状态字典而非完整 Usage 对象
2. 自动检测并清理超时的 pending 请求 2. 自动检测并清理超时的 pending/streaming 请求
3. 支持按 ID 列表查询特定请求 3. 支持按 ID 列表查询特定请求
Args: Args:
@@ -1343,6 +1348,7 @@ class UsageService:
Usage.output_tokens, Usage.output_tokens,
Usage.total_cost_usd, Usage.total_cost_usd,
Usage.response_time_ms, Usage.response_time_ms,
Usage.first_byte_time_ms, # 首字时间 (TTFB)
Usage.created_at, Usage.created_at,
Usage.provider_endpoint_id, Usage.provider_endpoint_id,
ProviderEndpoint.timeout.label("endpoint_timeout"), ProviderEndpoint.timeout.label("endpoint_timeout"),
@@ -1361,10 +1367,10 @@ class UsageService:
records = query.all() records = query.all()
# 检查超时的 pending 请求 # 检查超时的 pending/streaming 请求
timeout_ids = [] timeout_ids = []
for r in records: for r in records:
if r.status == "pending" and r.created_at: if r.status in ("pending", "streaming") and r.created_at:
# 使用端点配置的超时时间,若无则使用默认值 # 使用端点配置的超时时间,若无则使用默认值
timeout_seconds = r.endpoint_timeout or default_timeout_seconds timeout_seconds = r.endpoint_timeout or default_timeout_seconds
@@ -1392,6 +1398,7 @@ class UsageService:
"output_tokens": r.output_tokens, "output_tokens": r.output_tokens,
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0, "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
"response_time_ms": r.response_time_ms, "response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
} }
for r in records for r in records
] ]

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""测试模块"""

View File

@@ -0,0 +1,117 @@
from src.api.handlers.base import stream_context
from src.api.handlers.base.stream_context import StreamContext
def test_collected_text_append_and_property() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
assert ctx.collected_text == ""
ctx.append_text("hello")
ctx.append_text(" ")
ctx.append_text("world")
assert ctx.collected_text == "hello world"
def test_reset_for_retry_clears_state() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
ctx.append_text("x")
ctx.update_usage(input_tokens=10, output_tokens=5)
ctx.parsed_chunks.append({"type": "chunk"})
ctx.chunk_count = 3
ctx.data_count = 2
ctx.has_completion = True
ctx.status_code = 418
ctx.error_message = "boom"
ctx.reset_for_retry()
assert ctx.collected_text == ""
assert ctx.input_tokens == 0
assert ctx.output_tokens == 0
assert ctx.parsed_chunks == []
assert ctx.chunk_count == 0
assert ctx.data_count == 0
assert ctx.has_completion is False
assert ctx.status_code == 200
assert ctx.error_message is None
def test_record_first_byte_time(monkeypatch) -> None:
"""测试记录首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123) # 12.3ms
# 记录首字时间
ctx.record_first_byte_time(start_time)
# 验证首字时间已记录
assert ctx.first_byte_time_ms == 12
def test_record_first_byte_time_idempotent(monkeypatch) -> None:
"""测试首字时间只记录一次"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 第一次记录
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
first_value = ctx.first_byte_time_ms
# 第二次记录(应该被忽略)
monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
ctx.record_first_byte_time(start_time)
second_value = ctx.first_byte_time_ms
# 验证值没有改变
assert first_value == second_value
def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
"""测试重试时清除首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
start_time = 100.0
# 记录首字时间
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
ctx.record_first_byte_time(start_time)
assert ctx.first_byte_time_ms is not None
# 重置
ctx.reset_for_retry()
# 验证首字时间已清除
assert ctx.first_byte_time_ms is None
def test_get_log_summary_with_first_byte_time() -> None:
"""测试日志摘要包含首字时间"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
ctx.first_byte_time_ms = 123
summary = ctx.get_log_summary("request-id-123", 456)
# 验证包含首字时间和总时间(大写格式)
assert "TTFB: 123ms" in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary
def test_get_log_summary_without_first_byte_time() -> None:
"""测试日志摘要在没有首字时间时的格式"""
ctx = StreamContext(model="claude-3", api_format="claude_messages")
ctx.provider_name = "anthropic"
ctx.input_tokens = 100
ctx.output_tokens = 50
# first_byte_time_ms 保持为 None
summary = ctx.get_log_summary("request-id-123", 456)
# 验证不包含首字时间标记,但有总时间(使用大写 TTFB 和 Total
assert "TTFB:" not in summary
assert "Total: 456ms" in summary
assert "in:100 out:50" in summary

View File

@@ -0,0 +1,32 @@
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import ParsedChunk, ParsedResponse, ResponseParser, StreamStats
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.utils.sse_parser import SSEEventParser
class DummyParser(ResponseParser):
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
return None
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
return ParsedResponse(raw_response=response, status_code=status_code)
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
return {}
def extract_text_content(self, response: Dict[str, Any]) -> str:
return ""
def test_process_line_strips_newlines_and_finalizes_event() -> None:
ctx = StreamContext(model="test-model", api_format="OPENAI")
processor = StreamProcessor(request_id="test-request", default_parser=DummyParser())
sse_parser = SSEEventParser()
processor._process_line(ctx, sse_parser, 'data: {"type":"response.completed"}\n')
processor._process_line(ctx, sse_parser, "\n")
assert ctx.has_completion is True

View File

@@ -0,0 +1,104 @@
"""测试 handler 基础工具函数"""
import pytest
from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens
class TestExtractCacheCreationTokens:
"""测试 extract_cache_creation_tokens 函数"""
def test_new_format_only(self) -> None:
"""测试只有新格式字段"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
}
assert extract_cache_creation_tokens(usage) == 300
def test_new_format_5m_only(self) -> None:
"""测试只有 5 分钟缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 150,
"claude_cache_creation_1_h_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 150
def test_new_format_1h_only(self) -> None:
"""测试只有 1 小时缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 250,
}
assert extract_cache_creation_tokens(usage) == 250
def test_old_format_only(self) -> None:
"""测试只有旧格式字段"""
usage = {
"cache_creation_input_tokens": 500,
}
assert extract_cache_creation_tokens(usage) == 500
def test_both_formats_prefers_new(self) -> None:
"""测试同时存在时优先使用新格式"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
"cache_creation_input_tokens": 999, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
def test_empty_usage(self) -> None:
"""测试空字典"""
usage = {}
assert extract_cache_creation_tokens(usage) == 0
def test_all_zeros(self) -> None:
"""测试所有字段都为 0"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 0,
}
assert extract_cache_creation_tokens(usage) == 0
def test_partial_new_format_with_old_format_fallback(self) -> None:
"""测试新格式字段不存在时回退到旧格式"""
usage = {
"cache_creation_input_tokens": 123,
}
assert extract_cache_creation_tokens(usage) == 123
def test_new_format_zero_should_not_fallback(self) -> None:
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 456,
}
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0
# 而不是 fallback 到旧格式(返回 456
assert extract_cache_creation_tokens(usage) == 0
def test_unrelated_fields_ignored(self) -> None:
"""测试忽略无关字段"""
usage = {
"input_tokens": 1000,
"output_tokens": 2000,
"cache_read_input_tokens": 300,
"claude_cache_creation_5_m_tokens": 50,
"claude_cache_creation_1_h_tokens": 75,
}
assert extract_cache_creation_tokens(usage) == 125
class TestBuildSSEHeaders:
def test_default_headers(self) -> None:
headers = build_sse_headers()
assert headers["Cache-Control"] == "no-cache, no-transform"
assert headers["X-Accel-Buffering"] == "no"
def test_merge_extra_headers(self) -> None:
headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
assert headers["X-Test"] == "1"
assert headers["Cache-Control"] == "custom"