# Aether/src/services/health/endpoint.py
"""
端点健康状态服务
提供统一的端点健康监控功能,支持:
1. 按 API 格式聚合的健康状态
2. 基于时间窗口的状态追踪
3. 管理员和普通用户的差异化视图
4. Redis 缓存优化
"""
import json
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
# Cache configuration
CACHE_TTL_SECONDS = 30  # cache entries live for 30 seconds
CACHE_KEY_PREFIX = "health:endpoint:"
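# Example of a fully built cache key (prefix + view parameters), as assembled
# in get_endpoint_health_by_format below:
#   "health:endpoint:format:6:False"  -> 6-hour lookback, regular-user view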
def _get_redis_client():
"""获取 Redis 客户端,失败返回 None"""
try:
from src.clients.redis_client import redis_client
return redis_client
except Exception:
return None
class EndpointHealthService:
"""端点健康状态服务"""
@staticmethod
def get_endpoint_health_by_format(
db: Session,
lookback_hours: int = 6,
include_admin_fields: bool = False,
use_cache: bool = True,
) -> List[Dict[str, Any]]:
"""
获取按 API 格式聚合的端点健康状态
Args:
db: 数据库会话
lookback_hours: 回溯小时数
include_admin_fields: 是否包含管理员字段provider_count, key_count等
use_cache: 是否使用缓存(仅对普通用户视图有效)
Returns:
按 API 格式聚合的健康状态列表
"""
# Try the cache first
cache_key = f"{CACHE_KEY_PREFIX}format:{lookback_hours}:{include_admin_fields}"
if use_cache:
cached = EndpointHealthService._get_from_cache(cache_key)
if cached is not None:
return cached
now = datetime.now(timezone.utc)
# Query all endpoints belonging to active providers (fetch everything needed in one pass)
endpoints = (
db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
)
# Collect all endpoint IDs
all_endpoint_ids = [ep.id for ep in endpoints]
# Bulk-load all API keys
all_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids))
.all()
) if all_endpoint_ids else []
# Group keys by endpoint_id
keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
for key in all_keys:
keys_by_endpoint[key.endpoint_id].append(key)
# Aggregate by API format
format_stats = defaultdict(
lambda: {
"total_endpoints": 0,
"total_keys": 0,
"active_keys": 0,
"health_scores": [],
"endpoint_ids": [],
"provider_ids": set(),
"key_ids": [],
}
)
for ep in endpoints:
api_format = ep.api_format if ep.api_format else "UNKNOWN"
# Count endpoints
format_stats[api_format]["total_endpoints"] += 1
format_stats[api_format]["endpoint_ids"].append(ep.id)
format_stats[api_format]["provider_ids"].add(ep.provider_id)
# Pull from the preloaded keys
keys = keys_by_endpoint.get(ep.id, [])
format_stats[api_format]["total_keys"] += len(keys)
# Count active keys and collect health scores
if ep.is_active:
for key in keys:
format_stats[api_format]["key_ids"].append(key.id)
if key.is_active and not key.circuit_breaker_open:
format_stats[api_format]["active_keys"] += 1
health_score = key.health_score if key.health_score is not None else 1.0
format_stats[api_format]["health_scores"].append(health_score)
# Generate timeline data for all formats in one batch
# (the batch helper collects the key IDs itself, so no separate list is needed)
format_key_mapping: Dict[str, List[str]] = {
api_format: stats["key_ids"] for api_format, stats in format_stats.items()
}
# Query all timeline data in a single pass
timeline_data_map = EndpointHealthService._generate_timeline_batch(
db, format_key_mapping, now, lookback_hours
)
# Build the result
result = []
for api_format, stats in format_stats.items():
timeline_data = timeline_data_map.get(api_format, {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
})
timeline = timeline_data["timeline"]
time_range_start = timeline_data.get("time_range_start")
time_range_end = timeline_data.get("time_range_end")
# Compute the effective health score from the timeline
if timeline:
healthy_count = sum(1 for status in timeline if status == "healthy")
warning_count = sum(1 for status in timeline if status == "warning")
unhealthy_count = sum(1 for status in timeline if status == "unhealthy")
known_count = healthy_count + warning_count + unhealthy_count
if known_count > 0:
avg_health = (healthy_count * 1.0 + warning_count * 0.8) / known_count
else:
if stats["health_scores"]:
avg_health = sum(stats["health_scores"]) / len(stats["health_scores"])
elif stats["total_keys"] == 0:
avg_health = 0.0
else:
avg_health = 0.1
else:
avg_health = 0.0
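# Worked example: a timeline with 90 healthy, 8 warning, and 2 unhealthy
# segments yields avg_health = (90 * 1.0 + 8 * 0.8) / 100 = 0.964.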
item = {
"api_format": api_format,
"display_name": EndpointHealthService._format_display_name(api_format),
"health_score": avg_health,
"timeline": timeline,
"time_range_start": time_range_start.isoformat() if time_range_start else None,
"time_range_end": time_range_end.isoformat() if time_range_end else None,
}
if include_admin_fields:
item.update(
{
"total_endpoints": stats["total_endpoints"],
"total_keys": stats["total_keys"],
"active_keys": stats["active_keys"],
"provider_count": len(stats["provider_ids"]),
}
)
result.append(item)
result.sort(key=lambda x: x["health_score"], reverse=True)
# Write to the cache
if use_cache:
EndpointHealthService._set_to_cache(cache_key, result)
return result
@staticmethod
def _generate_timeline_batch(
db: Session,
format_key_mapping: Dict[str, List[str]],
now: datetime,
lookback_hours: int,
segments: int = 100,
) -> Dict[str, Dict[str, Any]]:
"""
批量生成多个 API 格式的时间线数据(基于 RequestCandidate 表)
使用 RequestCandidate 表可以:
1. 记录所有尝试(包括 fallback 中失败的尝试)
2. 准确反映每个 Provider/Key 的真实健康状态
3. 失败的请求会显示为红色节点
Args:
db: 数据库会话
format_key_mapping: API格式 -> key_ids 的映射
now: 当前时间
lookback_hours: 回溯小时数
segments: 时间段数量
Returns:
API格式 -> 时间线数据的映射
"""
# Collect all key IDs
all_key_ids = []
for key_ids in format_key_mapping.values():
all_key_ids.extend(key_ids)
# Parameter validation (the API layer already enforces this via Query(ge=1);
# this defensive check runs before the parameters are used below)
if lookback_hours <= 0 or segments <= 0:
raise ValueError(
f"lookback_hours and segments must be positive, "
f"got lookback_hours={lookback_hours}, segments={segments}"
)
if not all_key_ids:
return {
api_format: {
"timeline": ["unknown"] * segments,
"time_range_start": None,
"time_range_end": None,
}
for api_format in format_key_mapping.keys()
}
# Compute the time range
segment_seconds = (lookback_hours * 3600) / segments
start_time = now - timedelta(hours=lookback_hours)
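# Worked example: lookback_hours=6 and segments=100 give
# segment_seconds = 6 * 3600 / 100 = 216.0, i.e. each timeline slot
# covers 3 minutes 36 seconds of history.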
# Query all attempt records from the RequestCandidate table.
# Only count terminal statuses: success, failed, skipped.
final_statuses = ["success", "failed", "skipped"]
segment_expr = func.floor(
func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
).label('segment_idx')
candidate_stats = (
db.query(
RequestCandidate.key_id,
segment_expr,
func.count(RequestCandidate.id).label('total_count'),
func.sum(
case(
(RequestCandidate.status == "success", 1),
else_=0
)
).label('success_count'),
func.sum(
case(
(RequestCandidate.status == "failed", 1),
else_=0
)
).label('failed_count'),
func.min(RequestCandidate.created_at).label('min_time'),
func.max(RequestCandidate.created_at).label('max_time'),
)
.filter(
RequestCandidate.key_id.in_(all_key_ids),
RequestCandidate.created_at >= start_time,
RequestCandidate.created_at <= now,
RequestCandidate.status.in_(final_statuses),
)
.group_by(RequestCandidate.key_id, segment_expr)
.all()
)
# Build the reverse mapping key_id -> api_format
key_to_format: Dict[str, str] = {}
for api_format, key_ids in format_key_mapping.items():
for key_id in key_ids:
key_to_format[key_id] = api_format
# Aggregate data by api_format and segment
format_segment_data: Dict[str, Dict[int, Dict]] = defaultdict(lambda: defaultdict(lambda: {
"total": 0,
"success": 0,
"failed": 0,
"min_time": None,
"max_time": None,
}))
for row in candidate_stats:
key_id = row.key_id
segment_idx = int(row.segment_idx) if row.segment_idx is not None else 0
api_format = key_to_format.get(key_id)
if api_format and 0 <= segment_idx < segments:
seg_data = format_segment_data[api_format][segment_idx]
seg_data["total"] += row.total_count or 0
seg_data["success"] += row.success_count or 0
seg_data["failed"] += row.failed_count or 0
if row.min_time:
if seg_data["min_time"] is None or row.min_time < seg_data["min_time"]:
seg_data["min_time"] = row.min_time
if row.max_time:
if seg_data["max_time"] is None or row.max_time > seg_data["max_time"]:
seg_data["max_time"] = row.max_time
# Generate the timeline for each format
result: Dict[str, Dict[str, Any]] = {}
for api_format in format_key_mapping.keys():
timeline = []
earliest_time = None
latest_time = None
segment_data = format_segment_data.get(api_format, {})
for i in range(segments):
seg = segment_data.get(i)
if not seg or seg["total"] == 0:
timeline.append("unknown")
else:
# Update the time range
if seg["min_time"]:
if earliest_time is None or seg["min_time"] < earliest_time:
earliest_time = seg["min_time"]
if seg["max_time"]:
if latest_time is None or seg["max_time"] > latest_time:
latest_time = seg["max_time"]
# Success rate = success / (success + failed);
# skipped attempts do not count as failures and do not affect the rate
actual_completed = seg["success"] + seg["failed"]
if actual_completed > 0:
success_rate = seg["success"] / actual_completed
else:
# Only skipped attempts in this segment; treat it as healthy
success_rate = 1.0
if success_rate >= 0.95:
timeline.append("healthy")
elif success_rate >= 0.7:
timeline.append("warning")
else:
timeline.append("unhealthy")
result[api_format] = {
"timeline": timeline,
"time_range_start": earliest_time,
"time_range_end": latest_time if latest_time else now,
}
return result
@staticmethod
def _generate_timeline_from_usage(
db: Session,
endpoint_ids: List[str],
now: datetime,
lookback_hours: int,
segments: int = 100,
) -> Dict[str, Any]:
"""
从真实使用记录生成时间线数据(兼容旧接口,使用批量查询优化)
Args:
db: 数据库会话
endpoint_ids: 端点ID列表
now: 当前时间
lookback_hours: 回溯小时数
segments: 时间段数量
Returns:
包含时间线和时间范围的字典
"""
if not endpoint_ids:
return {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
}
# First, look up all keys under the given endpoints
key_ids = [
k.id
for k in db.query(ProviderAPIKey.id)
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
.all()
]
if not key_ids:
return {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
}
# Delegate to the batch query
format_key_mapping = {"_single": key_ids}
result = EndpointHealthService._generate_timeline_batch(
db, format_key_mapping, now, lookback_hours, segments
)
return result.get("_single", {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
})
@staticmethod
def _format_display_name(api_format: str) -> str:
"""格式化 API 格式的显示名称"""
format_names = {
"CLAUDE": "Claude API",
"CLAUDE_CLI": "Claude CLI",
"CLAUDE_COMPATIBLE": "Claude 兼容",
"OPENAI": "OpenAI API",
"OPENAI_CLI": "OpenAI CLI",
"OPENAI_COMPATIBLE": "OpenAI 兼容",
}
return format_names.get(api_format, api_format)
@staticmethod
def _get_from_cache(key: str) -> Optional[List[Dict[str, Any]]]:
"""从 Redis 缓存获取数据"""
redis_client = _get_redis_client()
if not redis_client:
return None
try:
data = redis_client.get(key)
if data:
return json.loads(data)
except Exception as e:
logger.warning(f"Failed to get from cache: {e}")
return None
@staticmethod
def _set_to_cache(key: str, data: List[Dict[str, Any]]) -> None:
"""写入 Redis 缓存"""
redis_client = _get_redis_client()
if not redis_client:
return
try:
redis_client.setex(key, CACHE_TTL_SECONDS, json.dumps(data, default=str))
except Exception as e:
logger.warning(f"Failed to set cache: {e}")
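# A minimal smoke-test sketch (hypothetical: assumes a synchronous
# `SessionLocal` factory; adjust the import to the project's actual layout):
#
#     if __name__ == "__main__":
#         from src.core.database import SessionLocal
#         with SessionLocal() as db:
#             for entry in EndpointHealthService.get_endpoint_health_by_format(
#                 db, lookback_hours=1, include_admin_fields=True, use_cache=False
#             ):
#                 print(entry["api_format"], round(entry["health_score"], 3))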