Mirror of https://github.com/fawney19/Aether.git, synced 2026-01-03 00:02:28 +08:00
"""
|
||
端点健康状态服务
|
||
|
||
提供统一的端点健康监控功能,支持:
|
||
1. 按 API 格式聚合的健康状态
|
||
2. 基于时间窗口的状态追踪
|
||
3. 管理员和普通用户的差异化视图
|
||
4. Redis 缓存优化
|
||
"""
|
||
|
||
import json
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

from sqlalchemy import case, func
from sqlalchemy.orm import Session

from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate


# Cache configuration
CACHE_TTL_SECONDS = 30  # cache entries live for 30 seconds
CACHE_KEY_PREFIX = "health:endpoint:"
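
# Example key produced by get_endpoint_health_by_format with its default
# arguments (lookback_hours=6, include_admin_fields=False), for illustration:
#     health:endpoint:format:6:False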


def _get_redis_client():
    """Return the shared Redis client, or None if it is unavailable."""
    try:
        from src.clients.redis_client import redis_client
        return redis_client
    except Exception:
        return None


class EndpointHealthService:
    """Endpoint health status service."""

    @staticmethod
    def get_endpoint_health_by_format(
        db: Session,
        lookback_hours: int = 6,
        include_admin_fields: bool = False,
        use_cache: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Get endpoint health status aggregated by API format.

        Args:
            db: Database session
            lookback_hours: Number of hours to look back
            include_admin_fields: Whether to include admin-only fields
                (provider_count, key_count, etc.)
            use_cache: Whether to use the cache (only effective for the
                regular-user view)

        Returns:
            List of health statuses aggregated by API format
        """
        # Try the cache first
        cache_key = f"{CACHE_KEY_PREFIX}format:{lookback_hours}:{include_admin_fields}"
        if use_cache:
            cached = EndpointHealthService._get_from_cache(cache_key)
            if cached is not None:
                return cached

        now = datetime.now(timezone.utc)

        # Query all endpoints of active providers (fetch everything needed in one pass)
        endpoints = (
            db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
        )

        # Collect all endpoint IDs
        all_endpoint_ids = [ep.id for ep in endpoints]

        # Batch-load all keys for those endpoints
        all_keys = (
            db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids))
            .all()
        ) if all_endpoint_ids else []

        # Group keys by endpoint_id
        keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
        for key in all_keys:
            keys_by_endpoint[key.endpoint_id].append(key)

        # Aggregate by API format
        format_stats = defaultdict(
            lambda: {
                "total_endpoints": 0,
                "total_keys": 0,
                "active_keys": 0,
                "health_scores": [],
                "endpoint_ids": [],
                "provider_ids": set(),
                "key_ids": [],
            }
        )

        for ep in endpoints:
            api_format = ep.api_format if ep.api_format else "UNKNOWN"

            # Count this endpoint
            format_stats[api_format]["total_endpoints"] += 1
            format_stats[api_format]["endpoint_ids"].append(ep.id)
            format_stats[api_format]["provider_ids"].add(ep.provider_id)

            # Look up keys from the preloaded map
            keys = keys_by_endpoint.get(ep.id, [])
            format_stats[api_format]["total_keys"] += len(keys)

            # Count active keys and collect health scores
            if ep.is_active:
                for key in keys:
                    format_stats[api_format]["key_ids"].append(key.id)
                    if key.is_active and not key.circuit_breaker_open:
                        format_stats[api_format]["active_keys"] += 1
                        health_score = key.health_score if key.health_score is not None else 1.0
                        format_stats[api_format]["health_scores"].append(health_score)

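        # Example shape of one aggregated entry (illustrative values):
        #     format_stats["OPENAI"] == {
        #         "total_endpoints": 3, "total_keys": 7, "active_keys": 5,
        #         "health_scores": [1.0, 0.9], "endpoint_ids": [...],
        #         "provider_ids": {...}, "key_ids": [...],
        #     }
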
        # Map each API format to its key_ids for batch timeline generation
        format_key_mapping: Dict[str, List[str]] = {}
        for api_format, stats in format_stats.items():
            format_key_mapping[api_format] = stats["key_ids"]

        # Query all timeline data in one pass
        timeline_data_map = EndpointHealthService._generate_timeline_batch(
            db, format_key_mapping, now, lookback_hours
        )

        # Build the results
        result = []

        for api_format, stats in format_stats.items():
            timeline_data = timeline_data_map.get(api_format, {
                "timeline": ["unknown"] * 100,
                "time_range_start": None,
                "time_range_end": None,
            })
            timeline = timeline_data["timeline"]
            time_range_start = timeline_data.get("time_range_start")
            time_range_end = timeline_data.get("time_range_end")

            # Compute the effective health score from the timeline:
            # healthy segments count 1.0, warning segments 0.8, unhealthy 0.0
            if timeline:
                healthy_count = sum(1 for status in timeline if status == "healthy")
                warning_count = sum(1 for status in timeline if status == "warning")
                unhealthy_count = sum(1 for status in timeline if status == "unhealthy")
                known_count = healthy_count + warning_count + unhealthy_count

                if known_count > 0:
                    avg_health = (healthy_count * 1.0 + warning_count * 0.8) / known_count
                else:
                    # No traffic in the window: fall back to stored key scores
                    if stats["health_scores"]:
                        avg_health = sum(stats["health_scores"]) / len(stats["health_scores"])
                    elif stats["total_keys"] == 0:
                        avg_health = 0.0
                    else:
                        avg_health = 0.1
            else:
                avg_health = 0.0

            item = {
                "api_format": api_format,
                "display_name": EndpointHealthService._format_display_name(api_format),
                "health_score": avg_health,
                "timeline": timeline,
                "time_range_start": time_range_start.isoformat() if time_range_start else None,
                "time_range_end": time_range_end.isoformat() if time_range_end else None,
            }

            if include_admin_fields:
                item.update(
                    {
                        "total_endpoints": stats["total_endpoints"],
                        "total_keys": stats["total_keys"],
                        "active_keys": stats["active_keys"],
                        "provider_count": len(stats["provider_ids"]),
                    }
                )

            result.append(item)

        result.sort(key=lambda x: x["health_score"], reverse=True)

        # Write back to the cache
        if use_cache:
            EndpointHealthService._set_to_cache(cache_key, result)

        return result
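
    # Minimal usage sketch (illustrative; the router, path, and get_db
    # dependency are assumptions, not part of this module):
    #
    #     @router.get("/health/endpoints")
    #     def endpoint_health(db: Session = Depends(get_db)):
    #         return EndpointHealthService.get_endpoint_health_by_format(
    #             db, lookback_hours=6, include_admin_fields=False
    #         )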

    @staticmethod
    def _generate_timeline_batch(
        db: Session,
        format_key_mapping: Dict[str, List[str]],
        now: datetime,
        lookback_hours: int,
        segments: int = 100,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Generate timeline data for multiple API formats in one batch, based on
        the RequestCandidate table.

        Using RequestCandidate makes it possible to:
        1. Record every attempt (including attempts that failed during fallback)
        2. Accurately reflect the real health of each Provider/Key
        3. Render failed requests as red nodes

        Args:
            db: Database session
            format_key_mapping: Mapping of API format -> key_ids
            now: Current time
            lookback_hours: Number of hours to look back
            segments: Number of time segments

        Returns:
            Mapping of API format -> timeline data
        """
        # Collect all key_ids across formats
        all_key_ids = []
        for key_ids in format_key_mapping.values():
            all_key_ids.extend(key_ids)

        if not all_key_ids:
            return {
                api_format: {
                    "timeline": ["unknown"] * segments,
                    "time_range_start": None,
                    "time_range_end": None,
                }
                for api_format in format_key_mapping.keys()
            }

        # Parameter validation (the API layer already enforces Query(ge=1);
        # this is a defensive check)
        if lookback_hours <= 0 or segments <= 0:
            raise ValueError(
                f"lookback_hours and segments must be positive, "
                f"got lookback_hours={lookback_hours}, segments={segments}"
            )

        # Compute the time range
        segment_seconds = (lookback_hours * 3600) / segments
        start_time = now - timedelta(hours=lookback_hours)

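        # Bucketing sketch: an attempt created at time t falls into segment
        #     floor((t - start_time).total_seconds() / segment_seconds)
        # For example, with lookback_hours=6 and segments=100,
        # segment_seconds = 216.0, so an attempt 600 seconds after start_time
        # lands in segment floor(600 / 216) = 2. The EXTRACT(EPOCH ...)
        # expression below assumes a PostgreSQL-style backend.
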
        # Query all attempt records from the RequestCandidate table;
        # only final statuses are counted: success, failed, skipped
        final_statuses = ["success", "failed", "skipped"]

        segment_expr = func.floor(
            func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
        ).label('segment_idx')

        candidate_stats = (
            db.query(
                RequestCandidate.key_id,
                segment_expr,
                func.count(RequestCandidate.id).label('total_count'),
                func.sum(
                    case(
                        (RequestCandidate.status == "success", 1),
                        else_=0
                    )
                ).label('success_count'),
                func.sum(
                    case(
                        (RequestCandidate.status == "failed", 1),
                        else_=0
                    )
                ).label('failed_count'),
                func.min(RequestCandidate.created_at).label('min_time'),
                func.max(RequestCandidate.created_at).label('max_time'),
            )
            .filter(
                RequestCandidate.key_id.in_(all_key_ids),
                RequestCandidate.created_at >= start_time,
                RequestCandidate.created_at <= now,
                RequestCandidate.status.in_(final_statuses),
            )
            .group_by(RequestCandidate.key_id, segment_expr)
            .all()
        )

        # Build the reverse mapping key_id -> api_format
        key_to_format: Dict[str, str] = {}
        for api_format, key_ids in format_key_mapping.items():
            for key_id in key_ids:
                key_to_format[key_id] = api_format

        # Aggregate data by api_format and segment
        format_segment_data: Dict[str, Dict[int, Dict]] = defaultdict(lambda: defaultdict(lambda: {
            "total": 0,
            "success": 0,
            "failed": 0,
            "min_time": None,
            "max_time": None,
        }))

        for row in candidate_stats:
            key_id = row.key_id
            segment_idx = int(row.segment_idx) if row.segment_idx is not None else 0
            api_format = key_to_format.get(key_id)

            if api_format and 0 <= segment_idx < segments:
                seg_data = format_segment_data[api_format][segment_idx]
                seg_data["total"] += row.total_count or 0
                seg_data["success"] += row.success_count or 0
                seg_data["failed"] += row.failed_count or 0

                if row.min_time:
                    if seg_data["min_time"] is None or row.min_time < seg_data["min_time"]:
                        seg_data["min_time"] = row.min_time
                if row.max_time:
                    if seg_data["max_time"] is None or row.max_time > seg_data["max_time"]:
                        seg_data["max_time"] = row.max_time

        # Generate the timeline for each format
        result: Dict[str, Dict[str, Any]] = {}

        for api_format in format_key_mapping.keys():
            timeline = []
            earliest_time = None
            latest_time = None

            segment_data = format_segment_data.get(api_format, {})

            for i in range(segments):
                seg = segment_data.get(i)
                if not seg or seg["total"] == 0:
                    timeline.append("unknown")
                else:
                    # Track the observed time range
                    if seg["min_time"]:
                        if earliest_time is None or seg["min_time"] < earliest_time:
                            earliest_time = seg["min_time"]
                    if seg["max_time"]:
                        if latest_time is None or seg["max_time"] > latest_time:
                            latest_time = seg["max_time"]

                    # Success rate = success / (success + failed);
                    # skipped attempts are not failures and do not affect the rate
                    actual_completed = seg["success"] + seg["failed"]
                    if actual_completed > 0:
                        success_rate = seg["success"] / actual_completed
                    else:
                        # Only skipped attempts: treat the segment as healthy
                        success_rate = 1.0

                    if success_rate >= 0.95:
                        timeline.append("healthy")
                    elif success_rate >= 0.7:
                        timeline.append("warning")
                    else:
                        timeline.append("unhealthy")

            result[api_format] = {
                "timeline": timeline,
                "time_range_start": earliest_time,
                "time_range_end": latest_time if latest_time else now,
            }

        return result
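
    # Classification examples for the thresholds above (illustrative):
    #     19 success / 1 failed -> rate 0.95 -> "healthy"
    #      8 success / 2 failed -> rate 0.80 -> "warning"
    #      3 success / 2 failed -> rate 0.60 -> "unhealthy"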

    @staticmethod
    def _generate_timeline_from_usage(
        db: Session,
        endpoint_ids: List[str],
        now: datetime,
        lookback_hours: int,
        segments: int = 100,
    ) -> Dict[str, Any]:
        """
        Generate timeline data from real usage records (kept for backward
        compatibility; delegates to the optimized batch query).

        Args:
            db: Database session
            endpoint_ids: List of endpoint IDs
            now: Current time
            lookback_hours: Number of hours to look back
            segments: Number of time segments

        Returns:
            Dict containing the timeline and its time range
        """
        if not endpoint_ids:
            return {
                "timeline": ["unknown"] * segments,
                "time_range_start": None,
                "time_range_end": None,
            }

        # First query all keys under these endpoints
        key_ids = [
            k.id
            for k in db.query(ProviderAPIKey.id)
            .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
            .all()
        ]

        if not key_ids:
            return {
                "timeline": ["unknown"] * segments,
                "time_range_start": None,
                "time_range_end": None,
            }

        # Delegate to the batch query under a synthetic single-format key
        format_key_mapping = {"_single": key_ids}
        result = EndpointHealthService._generate_timeline_batch(
            db, format_key_mapping, now, lookback_hours, segments
        )

        return result.get("_single", {
            "timeline": ["unknown"] * segments,
            "time_range_start": None,
            "time_range_end": None,
        })

    @staticmethod
    def _format_display_name(api_format: str) -> str:
        """Format the display name for an API format."""
        format_names = {
            "CLAUDE": "Claude API",
            "CLAUDE_CLI": "Claude CLI",
            "CLAUDE_COMPATIBLE": "Claude Compatible",
            "OPENAI": "OpenAI API",
            "OPENAI_CLI": "OpenAI CLI",
            "OPENAI_COMPATIBLE": "OpenAI Compatible",
        }
        return format_names.get(api_format, api_format)

    @staticmethod
    def _get_from_cache(key: str) -> Optional[List[Dict[str, Any]]]:
        """Read data from the Redis cache."""
        redis_client = _get_redis_client()
        if not redis_client:
            return None

        try:
            data = redis_client.get(key)
            if data:
                return json.loads(data)
        except Exception as e:
            logger.warning(f"Failed to get from cache: {e}")
        return None

    @staticmethod
    def _set_to_cache(key: str, data: List[Dict[str, Any]]) -> None:
        """Write data to the Redis cache."""
        redis_client = _get_redis_client()
        if not redis_client:
            return

        try:
            # default=str makes datetime values JSON-serializable
            redis_client.setex(key, CACHE_TTL_SECONDS, json.dumps(data, default=str))
        except Exception as e:
            logger.warning(f"Failed to set cache: {e}")