""" 请求链路追踪 API 端点 """ from dataclasses import dataclass from datetime import datetime from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request from pydantic import BaseModel, ConfigDict from sqlalchemy.orm import Session from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.pipeline import ApiRequestPipeline from src.database import get_db from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey from src.core.crypto import crypto_service from src.services.request.candidate import RequestCandidateService router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"]) pipeline = ApiRequestPipeline() class CandidateResponse(BaseModel): """候选记录响应""" id: str request_id: str candidate_index: int retry_index: int = 0 # 重试序号(从0开始) provider_id: Optional[str] = None provider_name: Optional[str] = None provider_website: Optional[str] = None # Provider 官网 endpoint_id: Optional[str] = None endpoint_name: Optional[str] = None # 端点显示名称(api_format) key_id: Optional[str] = None key_name: Optional[str] = None # 密钥名称 key_preview: Optional[str] = None # 密钥脱敏预览(如 sk-***abc) key_capabilities: Optional[dict] = None # Key 支持的能力 required_capabilities: Optional[dict] = None # 请求实际需要的能力标签 status: str # 'pending', 'success', 'failed', 'skipped' skip_reason: Optional[str] = None is_cached: bool = False # 执行结果字段 status_code: Optional[int] = None error_type: Optional[str] = None error_message: Optional[str] = None latency_ms: Optional[int] = None concurrent_requests: Optional[int] = None extra_data: Optional[dict] = None created_at: datetime started_at: Optional[datetime] = None finished_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) class RequestTraceResponse(BaseModel): """请求追踪完整响应""" request_id: str total_candidates: int final_status: str # 'success', 'failed', 'streaming', 'pending' total_latency_ms: int candidates: List[CandidateResponse] @router.get("/{request_id}", response_model=RequestTraceResponse) async def get_request_trace( request_id: str, request: Request, db: Session = Depends(get_db), ): """ 获取请求的完整追踪信息 获取指定请求的完整链路追踪信息,包括所有候选(candidates)的执行情况。 **路径参数**: - `request_id`: 请求 ID **返回字段**: - `request_id`: 请求 ID - `total_candidates`: 候选总数 - `final_status`: 最终状态(success: 成功,failed: 失败,streaming: 流式传输中,pending: 等待中) - `total_latency_ms`: 总延迟(毫秒) - `candidates`: 候选列表,每个候选包含: - `id`: 候选 ID - `request_id`: 请求 ID - `candidate_index`: 候选索引 - `retry_index`: 重试序号 - `provider_id`: 提供商 ID - `provider_name`: 提供商名称 - `provider_website`: 提供商官网 - `endpoint_id`: 端点 ID - `endpoint_name`: 端点名称(API 格式) - `key_id`: 密钥 ID - `key_name`: 密钥名称 - `key_preview`: 密钥脱敏预览 - `key_capabilities`: 密钥支持的能力 - `required_capabilities`: 请求需要的能力标签 - `status`: 状态(pending, success, failed, skipped) - `skip_reason`: 跳过原因 - `is_cached`: 是否缓存命中 - `status_code`: HTTP 状态码 - `error_type`: 错误类型 - `error_message`: 错误信息 - `latency_ms`: 延迟(毫秒) - `concurrent_requests`: 并发请求数 - `extra_data`: 额外数据 - `created_at`: 创建时间 - `started_at`: 开始时间 - `finished_at`: 完成时间 """ adapter = AdminGetRequestTraceAdapter(request_id=request_id) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/stats/provider/{provider_id}") async def get_provider_failure_rate( provider_id: str, request: Request, limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"), db: Session = Depends(get_db), ): """ 获取提供商的失败率统计 获取指定提供商最近的失败率统计信息。需要管理员权限。 **路径参数**: - `provider_id`: 提供商 ID **查询参数**: - `limit`: 统计最近的尝试数量,默认 100,最大 1000 **返回字段**: - `provider_id`: 提供商 ID - `total_attempts`: 总尝试次数 - `success_count`: 成功次数 - `failed_count`: 失败次数 - `failure_rate`: 失败率(百分比) - `avg_latency_ms`: 平均延迟(毫秒) """ adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) # -------- 请求追踪适配器 -------- @dataclass class AdminGetRequestTraceAdapter(AdminApiAdapter): request_id: str async def handle(self, context): # type: ignore[override] db = context.db # 只查询 candidates candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id) # 如果没有数据,返回 404 if not candidates: raise HTTPException(status_code=404, detail="Request not found") # 计算总延迟(只统计已完成的候选:success 或 failed) # 使用显式的 is not None 检查,避免过滤掉 0ms 的快速响应 total_latency = sum( c.latency_ms for c in candidates if c.status in ("success", "failed") and c.latency_ms is not None ) # 判断最终状态: # 1. status="success" 即视为成功(无论 status_code 是什么) # - 流式请求即使客户端断开(499),只要 Provider 成功返回数据,也算成功 # 2. 同时检查 status_code 在 200-299 范围,作为额外的成功判断条件 # - 用于兼容非流式请求或未正确设置 status 的旧数据 # 3. status="streaming" 表示流式请求正在进行中 # 4. status="pending" 表示请求尚未开始执行 has_success = any( c.status == "success" or (c.status_code is not None and 200 <= c.status_code < 300) for c in candidates ) has_streaming = any(c.status == "streaming" for c in candidates) has_pending = any(c.status == "pending" for c in candidates) if has_success: final_status = "success" elif has_streaming: # 有候选正在流式传输中 final_status = "streaming" elif has_pending: # 有候选正在等待执行 final_status = "pending" else: final_status = "failed" # 批量加载 provider 信息,避免 N+1 查询 provider_ids = {c.provider_id for c in candidates if c.provider_id} provider_map = {} provider_website_map = {} if provider_ids: providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all() for p in providers: provider_map[p.id] = p.name provider_website_map[p.id] = p.website # 批量加载 endpoint 信息 endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id} endpoint_map = {} if endpoint_ids: endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all() endpoint_map = {e.id: e.api_format for e in endpoints} # 批量加载 key 信息 key_ids = {c.key_id for c in candidates if c.key_id} key_map = {} key_preview_map = {} key_capabilities_map = {} if key_ids: keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all() for k in keys: key_map[k.id] = k.name key_capabilities_map[k.id] = k.capabilities # 生成脱敏预览:先解密再脱敏 try: decrypted_key = crypto_service.decrypt(k.api_key) if len(decrypted_key) > 8: # 检测常见前缀模式 prefix_end = 0 for prefix in ["sk-", "key-", "api-", "ak-"]: if decrypted_key.lower().startswith(prefix): prefix_end = len(prefix) break if prefix_end > 0: key_preview_map[k.id] = f"{decrypted_key[:prefix_end]}***{decrypted_key[-4:]}" else: key_preview_map[k.id] = f"{decrypted_key[:4]}***{decrypted_key[-4:]}" elif len(decrypted_key) > 4: key_preview_map[k.id] = f"***{decrypted_key[-4:]}" else: key_preview_map[k.id] = "***" except Exception: key_preview_map[k.id] = "***" # 构建 candidate 响应列表 candidate_responses: List[CandidateResponse] = [] for candidate in candidates: provider_name = ( provider_map.get(candidate.provider_id) if candidate.provider_id else None ) provider_website = ( provider_website_map.get(candidate.provider_id) if candidate.provider_id else None ) endpoint_name = ( endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None ) key_name = ( key_map.get(candidate.key_id) if candidate.key_id else None ) key_preview = ( key_preview_map.get(candidate.key_id) if candidate.key_id else None ) key_capabilities = ( key_capabilities_map.get(candidate.key_id) if candidate.key_id else None ) candidate_responses.append( CandidateResponse( id=candidate.id, request_id=candidate.request_id, candidate_index=candidate.candidate_index, retry_index=candidate.retry_index, provider_id=candidate.provider_id, provider_name=provider_name, provider_website=provider_website, endpoint_id=candidate.endpoint_id, endpoint_name=endpoint_name, key_id=candidate.key_id, key_name=key_name, key_preview=key_preview, key_capabilities=key_capabilities, required_capabilities=candidate.required_capabilities, status=candidate.status, skip_reason=candidate.skip_reason, is_cached=candidate.is_cached, status_code=candidate.status_code, error_type=candidate.error_type, error_message=candidate.error_message, latency_ms=candidate.latency_ms, concurrent_requests=candidate.concurrent_requests, extra_data=candidate.extra_data, created_at=candidate.created_at, started_at=candidate.started_at, finished_at=candidate.finished_at, ) ) response = RequestTraceResponse( request_id=self.request_id, total_candidates=len(candidates), final_status=final_status, total_latency_ms=total_latency, candidates=candidate_responses, ) context.add_audit_metadata( action="trace_request_detail", request_id=self.request_id, total_candidates=len(candidates), final_status=final_status, total_latency_ms=total_latency, ) return response @dataclass class AdminProviderFailureRateAdapter(AdminApiAdapter): provider_id: str limit: int async def handle(self, context): # type: ignore[override] result = RequestCandidateService.get_candidate_stats_by_provider( db=context.db, provider_id=self.provider_id, limit=self.limit, ) context.add_audit_metadata( action="trace_provider_failure_rate", provider_id=self.provider_id, limit=self.limit, total_attempts=result.get("total_attempts"), failure_rate=result.get("failure_rate"), ) return result