Files
Aether/src/api/admin/monitoring/trace.py

284 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
请求链路追踪 API 端点
"""
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
from src.core.crypto import crypto_service
from src.services.request.candidate import RequestCandidateService
# All trace endpoints in this module are mounted under this prefix/tag.
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
# Shared pipeline: every route below builds an adapter and delegates to pipeline.run().
pipeline = ApiRequestPipeline()
class CandidateResponse(BaseModel):
    """候选记录响应"""

    # English: response model for one candidate record of a traced request.
    # NOTE(review): the docstring above is kept verbatim because a Pydantic model
    # docstring may be surfaced in the generated API schema — confirm before translating.

    id: str
    request_id: str
    candidate_index: int
    retry_index: int = 0  # retry number, starting at 0
    provider_id: Optional[str] = None
    provider_name: Optional[str] = None
    provider_website: Optional[str] = None  # provider's official website
    endpoint_id: Optional[str] = None
    # Endpoint display name; the trace adapter fills it with the endpoint's api_format.
    endpoint_name: Optional[str] = None
    key_id: Optional[str] = None
    key_name: Optional[str] = None  # key name
    key_preview: Optional[str] = None  # masked key preview (e.g. sk-***abcd)
    key_capabilities: Optional[dict] = None  # capabilities supported by the key
    required_capabilities: Optional[dict] = None  # capability tags the request actually needed
    # Candidate status, e.g. 'pending', 'success', 'failed', 'skipped'; the trace
    # adapter also checks for 'streaming'.
    status: str
    skip_reason: Optional[str] = None
    is_cached: bool = False
    # Execution result fields
    status_code: Optional[int] = None
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    latency_ms: Optional[int] = None
    concurrent_requests: Optional[int] = None
    extra_data: Optional[dict] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
    # Allow validating from attribute-bearing objects (e.g. ORM rows).
    model_config = ConfigDict(from_attributes=True)
class RequestTraceResponse(BaseModel):
    """请求追踪完整响应"""

    # English: full trace response for one request (summary plus every candidate).
    # NOTE(review): docstring kept verbatim since it may appear in the API schema.

    request_id: str
    total_candidates: int  # number of candidate records found for the request
    final_status: str  # 'success', 'failed', 'streaming', 'pending'
    # Sum of latency_ms over candidates whose status is 'success' or 'failed'.
    total_latency_ms: int
    candidates: List[CandidateResponse]
@router.get("/{request_id}", response_model=RequestTraceResponse)
async def get_request_trace(
request_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""获取特定请求的完整追踪信息"""
adapter = AdminGetRequestTraceAdapter(request_id=request_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats/provider/{provider_id}")
async def get_provider_failure_rate(
provider_id: str,
request: Request,
limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"),
db: Session = Depends(get_db),
):
"""
获取某个 Provider 的失败率统计
需要管理员权限
"""
adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 请求追踪适配器 --------
_KNOWN_KEY_PREFIXES = ("sk-", "key-", "api-", "ak-")


def _mask_api_key(plain_key: str) -> str:
    """Return a masked preview of a decrypted API key.

    - longer than 8 chars: a known vendor prefix (``sk-``/``key-``/``api-``/``ak-``,
      matched case-insensitively) or, when none matches, the first 4 chars,
      followed by ``***`` and the last 4 chars;
    - 5 to 8 chars: ``***`` followed by the last 4 chars;
    - 4 chars or fewer: ``***``.

    NOTE(review): keys of 9-12 chars still reveal most of their characters with
    this rule; consider a stricter scheme if such short keys can occur.
    """
    if len(plain_key) > 8:
        lowered = plain_key.lower()
        prefix_len = next(
            (len(prefix) for prefix in _KNOWN_KEY_PREFIXES if lowered.startswith(prefix)),
            0,
        )
        # Slice the original (not lowered) key so the shown prefix keeps its case.
        head = plain_key[:prefix_len] if prefix_len else plain_key[:4]
        return f"{head}***{plain_key[-4:]}"
    if len(plain_key) > 4:
        return f"***{plain_key[-4:]}"
    return "***"


def _build_key_preview(encrypted_key) -> str:
    """Decrypt a stored key and mask it; any failure yields ``"***"``."""
    try:
        return _mask_api_key(crypto_service.decrypt(encrypted_key))
    except Exception:
        # Deliberate best effort: one undecryptable key must not break the whole
        # trace response, and the raw value must never leak into the preview.
        return "***"


def _sum_completed_latency(candidates) -> int:
    """Sum latency_ms over candidates whose status is 'success' or 'failed'.

    Uses an explicit ``is not None`` check so genuine 0 ms responses are counted.
    """
    return sum(
        c.latency_ms
        for c in candidates
        if c.status in ("success", "failed") and c.latency_ms is not None
    )


def _determine_final_status(candidates) -> str:
    """Collapse candidate statuses into the request's overall status.

    Precedence is success > streaming > pending > failed:
    1. A candidate counts as successful when its status is "success" (a streamed
       request whose client disconnected, e.g. 499, still succeeded if the
       provider returned data), or when its status_code is 2xx (compatibility
       with non-stream requests / older rows that did not set status correctly).
    2. "streaming" means a candidate is still streaming.
    3. "pending" means a candidate has not started executing yet.
    """
    if any(
        c.status == "success" or (c.status_code is not None and 200 <= c.status_code < 300)
        for c in candidates
    ):
        return "success"
    if any(c.status == "streaming" for c in candidates):
        return "streaming"
    if any(c.status == "pending" for c in candidates):
        return "pending"
    return "failed"


def _load_provider_info(db, candidates):
    """Batch-load provider names and websites in one query (avoids N+1).

    ``db`` is the request's SQLAlchemy session. Returns ``(id -> name,
    id -> website)``; ids without a matching Provider row are simply absent.
    """
    provider_ids = {c.provider_id for c in candidates if c.provider_id}
    names = {}
    websites = {}
    if provider_ids:
        for provider in db.query(Provider).filter(Provider.id.in_(provider_ids)).all():
            names[provider.id] = provider.name
            websites[provider.id] = provider.website
    return names, websites


def _load_endpoint_formats(db, candidates):
    """Batch-load ``endpoint id -> api_format`` (used as the endpoint display name)."""
    endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
    if not endpoint_ids:
        return {}
    endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
    return {e.id: e.api_format for e in endpoints}


def _load_key_info(db, candidates):
    """Batch-load key names, masked previews and capabilities in one query.

    Returns ``(id -> name, id -> masked preview, id -> capabilities)``.
    """
    key_ids = {c.key_id for c in candidates if c.key_id}
    names = {}
    previews = {}
    capabilities = {}
    if key_ids:
        for key in db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all():
            names[key.id] = key.name
            capabilities[key.id] = key.capabilities
            # Decrypt first, then mask; the plaintext never leaves this call.
            previews[key.id] = _build_key_preview(key.api_key)
    return names, previews, capabilities


@dataclass
class AdminGetRequestTraceAdapter(AdminApiAdapter):
    """Admin adapter that assembles the full candidate trace of one request."""

    request_id: str  # id of the request whose candidate records are traced

    async def handle(self, context):  # type: ignore[override]
        """Build the RequestTraceResponse for ``self.request_id``.

        Also records audit metadata (status, candidate count, total latency).

        Raises:
            HTTPException: 404 when no candidate records exist for the request.
        """
        db = context.db
        # Only candidate records are queried; the whole trace is derived from them.
        candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)
        if not candidates:
            raise HTTPException(status_code=404, detail="Request not found")

        total_latency = _sum_completed_latency(candidates)
        final_status = _determine_final_status(candidates)

        provider_names, provider_websites = _load_provider_info(db, candidates)
        endpoint_names = _load_endpoint_formats(db, candidates)
        key_names, key_previews, key_capabilities = _load_key_info(db, candidates)

        # The lookup maps only hold truthy ids, so .get() on a missing/empty id
        # returns None, same as an explicit "if id else None" guard.
        candidate_responses: List[CandidateResponse] = [
            CandidateResponse(
                id=c.id,
                request_id=c.request_id,
                candidate_index=c.candidate_index,
                retry_index=c.retry_index,
                provider_id=c.provider_id,
                provider_name=provider_names.get(c.provider_id),
                provider_website=provider_websites.get(c.provider_id),
                endpoint_id=c.endpoint_id,
                endpoint_name=endpoint_names.get(c.endpoint_id),
                key_id=c.key_id,
                key_name=key_names.get(c.key_id),
                key_preview=key_previews.get(c.key_id),
                key_capabilities=key_capabilities.get(c.key_id),
                required_capabilities=c.required_capabilities,
                status=c.status,
                skip_reason=c.skip_reason,
                is_cached=c.is_cached,
                status_code=c.status_code,
                error_type=c.error_type,
                error_message=c.error_message,
                latency_ms=c.latency_ms,
                concurrent_requests=c.concurrent_requests,
                extra_data=c.extra_data,
                created_at=c.created_at,
                started_at=c.started_at,
                finished_at=c.finished_at,
            )
            for c in candidates
        ]

        response = RequestTraceResponse(
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
            candidates=candidate_responses,
        )
        context.add_audit_metadata(
            action="trace_request_detail",
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
        )
        return response
@dataclass
class AdminProviderFailureRateAdapter(AdminApiAdapter):
    """Admin adapter returning candidate statistics for a single provider."""

    provider_id: str  # provider whose candidate attempts are summarized
    limit: int  # number of most recent attempts handed to the stats query

    async def handle(self, context):  # type: ignore[override]
        """Fetch the provider's candidate stats, record audit metadata, return the stats."""
        stats = RequestCandidateService.get_candidate_stats_by_provider(
            db=context.db,
            provider_id=self.provider_id,
            limit=self.limit,
        )
        audit_fields = dict(
            action="trace_provider_failure_rate",
            provider_id=self.provider_id,
            limit=self.limit,
            total_attempts=stats.get("total_attempts"),
            failure_rate=stats.get("failure_rate"),
        )
        context.add_audit_metadata(**audit_fields)
        return stats