"""
|
||
请求链路追踪 API 端点
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from typing import List, Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||
from pydantic import BaseModel, ConfigDict
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.base.admin_adapter import AdminApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.database import get_db
|
||
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
|
||
from src.core.crypto import crypto_service
|
||
from src.services.request.candidate import RequestCandidateService
|
||
|
||
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
|
||
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
class CandidateResponse(BaseModel):
|
||
"""候选记录响应"""
|
||
|
||
id: str
|
||
request_id: str
|
||
candidate_index: int
|
||
retry_index: int = 0 # 重试序号(从0开始)
|
||
provider_id: Optional[str] = None
|
||
provider_name: Optional[str] = None
|
||
provider_website: Optional[str] = None # Provider 官网
|
||
endpoint_id: Optional[str] = None
|
||
endpoint_name: Optional[str] = None # 端点显示名称(api_format)
|
||
key_id: Optional[str] = None
|
||
key_name: Optional[str] = None # 密钥名称
|
||
key_preview: Optional[str] = None # 密钥脱敏预览(如 sk-***abc)
|
||
key_capabilities: Optional[dict] = None # Key 支持的能力
|
||
required_capabilities: Optional[dict] = None # 请求实际需要的能力标签
|
||
status: str # 'pending', 'success', 'failed', 'skipped'
|
||
skip_reason: Optional[str] = None
|
||
is_cached: bool = False
|
||
# 执行结果字段
|
||
status_code: Optional[int] = None
|
||
error_type: Optional[str] = None
|
||
error_message: Optional[str] = None
|
||
latency_ms: Optional[int] = None
|
||
concurrent_requests: Optional[int] = None
|
||
extra_data: Optional[dict] = None
|
||
created_at: datetime
|
||
started_at: Optional[datetime] = None
|
||
finished_at: Optional[datetime] = None
|
||
|
||
model_config = ConfigDict(from_attributes=True)
|
||
|
||
|
||
class RequestTraceResponse(BaseModel):
|
||
"""请求追踪完整响应"""
|
||
|
||
request_id: str
|
||
total_candidates: int
|
||
final_status: str # 'success', 'failed', 'streaming', 'pending'
|
||
total_latency_ms: int
|
||
candidates: List[CandidateResponse]
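
# Illustrative shape of a serialized trace response (values are hypothetical and shown
# only for documentation; the models above are the authoritative schema):
#
#   {
#     "request_id": "req-123",
#     "total_candidates": 2,
#     "final_status": "success",
#     "total_latency_ms": 845,
#     "candidates": [
#       {"candidate_index": 0, "retry_index": 0, "status": "failed", "status_code": 429, ...},
#       {"candidate_index": 1, "retry_index": 0, "status": "success", "status_code": 200, ...}
#     ]
#   }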


@router.get("/{request_id}", response_model=RequestTraceResponse)
async def get_request_trace(
    request_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Get the complete trace for a specific request."""

    adapter = AdminGetRequestTraceAdapter(request_id=request_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
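
# Example request (hypothetical host and request id; authentication is handled by the
# admin pipeline, so attach whatever admin credentials your deployment expects):
#
#   curl http://localhost:8000/api/admin/monitoring/trace/req-123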


@router.get("/stats/provider/{provider_id}")
async def get_provider_failure_rate(
    provider_id: str,
    request: Request,
    limit: int = Query(100, ge=1, le=1000, description="Number of recent attempts to include in the stats"),
    db: Session = Depends(get_db),
):
    """
    Get failure-rate statistics for a single provider.

    Requires admin privileges.
    """
    adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
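
# Example request (hypothetical host and provider id; `limit` caps how many recent
# attempts are counted and defaults to 100):
#
#   curl "http://localhost:8000/api/admin/monitoring/trace/stats/provider/prov-42?limit=200"
#
# The response body is whatever RequestCandidateService.get_candidate_stats_by_provider
# returns; it is expected to include at least "total_attempts" and "failure_rate", which
# the adapter below also records as audit metadata.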


# -------- Request trace adapters --------


@dataclass
class AdminGetRequestTraceAdapter(AdminApiAdapter):
    request_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # Query only the candidates
        candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)

        # No data for this request id: return 404
        if not candidates:
            raise HTTPException(status_code=404, detail="Request not found")

        # Total latency (count only completed candidates: success or failed).
        # Use an explicit `is not None` check so fast 0 ms responses are not filtered out.
        total_latency = sum(
            c.latency_ms
            for c in candidates
            if c.status in ("success", "failed") and c.latency_ms is not None
        )

        # Determine the final status:
        # 1. status == "success" counts as success regardless of status_code.
        #    - A streaming request counts as success even if the client disconnected (499),
        #      as long as the provider returned data.
        # 2. Also treat status_code in the 200-299 range as success, for compatibility with
        #    non-streaming requests or legacy rows where status was not set correctly.
        # 3. status == "streaming" means a streaming request is still in progress.
        # 4. status == "pending" means the request has not started executing yet.
        has_success = any(
            c.status == "success"
            or (c.status_code is not None and 200 <= c.status_code < 300)
            for c in candidates
        )
        has_streaming = any(c.status == "streaming" for c in candidates)
        has_pending = any(c.status == "pending" for c in candidates)

        if has_success:
            final_status = "success"
        elif has_streaming:
            # A candidate is still streaming
            final_status = "streaming"
        elif has_pending:
            # A candidate is still waiting to execute
            final_status = "pending"
        else:
            final_status = "failed"
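
        # For example, statuses ["failed", "success"] resolve to "success", while
        # ["failed", "streaming"] resolve to "streaming" (assuming the failed attempt
        # did not record a 2xx status_code).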

        # Bulk-load provider info to avoid N+1 queries
        provider_ids = {c.provider_id for c in candidates if c.provider_id}
        provider_map = {}
        provider_website_map = {}
        if provider_ids:
            providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all()
            for p in providers:
                provider_map[p.id] = p.name
                provider_website_map[p.id] = p.website

        # Bulk-load endpoint info
        endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
        endpoint_map = {}
        if endpoint_ids:
            endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
            endpoint_map = {e.id: e.api_format for e in endpoints}

        # Bulk-load key info
        key_ids = {c.key_id for c in candidates if c.key_id}
        key_map = {}
        key_preview_map = {}
        key_capabilities_map = {}
        if key_ids:
            keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all()
            for k in keys:
                key_map[k.id] = k.name
                key_capabilities_map[k.id] = k.capabilities
                # Build the masked preview: decrypt first, then mask
                try:
                    decrypted_key = crypto_service.decrypt(k.api_key)
                    if len(decrypted_key) > 8:
                        # Detect common prefix patterns
                        prefix_end = 0
                        for prefix in ["sk-", "key-", "api-", "ak-"]:
                            if decrypted_key.lower().startswith(prefix):
                                prefix_end = len(prefix)
                                break
                        if prefix_end > 0:
                            key_preview_map[k.id] = f"{decrypted_key[:prefix_end]}***{decrypted_key[-4:]}"
                        else:
                            key_preview_map[k.id] = f"{decrypted_key[:4]}***{decrypted_key[-4:]}"
                    elif len(decrypted_key) > 4:
                        key_preview_map[k.id] = f"***{decrypted_key[-4:]}"
                    else:
                        key_preview_map[k.id] = "***"
                except Exception:
                    key_preview_map[k.id] = "***"
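        # Hypothetical illustration of the masking above: a decrypted key "sk-abcd1234wxyz"
        # becomes "sk-***wxyz"; a key with no recognized prefix, e.g. "abcd1234wxyz",
        # becomes "abcd***wxyz".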

        # Build the candidate response list
        candidate_responses: List[CandidateResponse] = []
        for candidate in candidates:
            provider_name = (
                provider_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            provider_website = (
                provider_website_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            endpoint_name = (
                endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None
            )
            key_name = (
                key_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_preview = (
                key_preview_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_capabilities = (
                key_capabilities_map.get(candidate.key_id) if candidate.key_id else None
            )

            candidate_responses.append(
                CandidateResponse(
                    id=candidate.id,
                    request_id=candidate.request_id,
                    candidate_index=candidate.candidate_index,
                    retry_index=candidate.retry_index,
                    provider_id=candidate.provider_id,
                    provider_name=provider_name,
                    provider_website=provider_website,
                    endpoint_id=candidate.endpoint_id,
                    endpoint_name=endpoint_name,
                    key_id=candidate.key_id,
                    key_name=key_name,
                    key_preview=key_preview,
                    key_capabilities=key_capabilities,
                    required_capabilities=candidate.required_capabilities,
                    status=candidate.status,
                    skip_reason=candidate.skip_reason,
                    is_cached=candidate.is_cached,
                    status_code=candidate.status_code,
                    error_type=candidate.error_type,
                    error_message=candidate.error_message,
                    latency_ms=candidate.latency_ms,
                    concurrent_requests=candidate.concurrent_requests,
                    extra_data=candidate.extra_data,
                    created_at=candidate.created_at,
                    started_at=candidate.started_at,
                    finished_at=candidate.finished_at,
                )
            )

        response = RequestTraceResponse(
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
            candidates=candidate_responses,
        )
        context.add_audit_metadata(
            action="trace_request_detail",
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
        )
        return response


@dataclass
class AdminProviderFailureRateAdapter(AdminApiAdapter):
    provider_id: str
    limit: int

    async def handle(self, context):  # type: ignore[override]
        result = RequestCandidateService.get_candidate_stats_by_provider(
            db=context.db,
            provider_id=self.provider_id,
            limit=self.limit,
        )
        context.add_audit_metadata(
            action="trace_provider_failure_rate",
            provider_id=self.provider_id,
            limit=self.limit,
            total_attempts=result.get("total_attempts"),
            failure_rate=result.get("failure_rate"),
        )
        return result