Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
"""
健康监控服务模块
包含健康监控相关功能:
- health_monitor: 健康度监控单例
- HealthMonitor: 健康监控类
"""
from .monitor import HealthMonitor, health_monitor
__all__ = [
"health_monitor",
"HealthMonitor",
]

View File

@@ -0,0 +1,452 @@
"""
端点健康状态服务
提供统一的端点健康监控功能,支持:
1. 按 API 格式聚合的健康状态
2. 基于时间窗口的状态追踪
3. 管理员和普通用户的差异化视图
4. Redis 缓存优化
"""
import json
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
# Cache configuration
CACHE_TTL_SECONDS = 30 # cached entries expire after 30 seconds
# Prefix prepended to every Redis key built by this module.
CACHE_KEY_PREFIX = "health:endpoint:"
def _get_redis_client():
    """Return the shared Redis client, or ``None`` if it cannot be imported.

    The import happens at call time and every exception raised by it is
    swallowed, so callers can treat the cache as an optional dependency.
    """
    client = None
    try:
        from src.clients.redis_client import redis_client as client
    except Exception:
        # Best effort: a missing or broken Redis module just disables caching.
        client = None
    return client
class EndpointHealthService:
"""端点健康状态服务"""
@staticmethod
def get_endpoint_health_by_format(
    db: Session,
    lookback_hours: int = 6,
    include_admin_fields: bool = False,
    use_cache: bool = True,
) -> List[Dict[str, Any]]:
    """
    Return endpoint health aggregated per API format.

    Args:
        db: Database session.
        lookback_hours: Size of the look-back window, in hours.
        include_admin_fields: When True, each item additionally carries
            ``total_endpoints``, ``total_keys``, ``active_keys`` and
            ``provider_count``.
        use_cache: When True, the result is read from / written to the
            cache. The cache key contains both ``lookback_hours`` and
            ``include_admin_fields``, so the admin and non-admin views are
            cached independently.

    Returns:
        One dict per API format, sorted by ``health_score`` descending.
    """
    # Try the cache first.
    cache_key = f"{CACHE_KEY_PREFIX}format:{lookback_hours}:{include_admin_fields}"
    if use_cache:
        cached = EndpointHealthService._get_from_cache(cache_key)
        if cached is not None:
            return cached

    now = datetime.now(timezone.utc)

    # Fetch every endpoint that belongs to an active provider in one query.
    endpoints = (
        db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
    )
    all_endpoint_ids = [ep.id for ep in endpoints]

    # Fetch all keys of those endpoints in one query (skipped when there are none).
    all_keys = (
        db.query(ProviderAPIKey)
        .filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids))
        .all()
    ) if all_endpoint_ids else []

    # Group keys by the endpoint they belong to.
    keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
    for key in all_keys:
        keys_by_endpoint[key.endpoint_id].append(key)

    # Per-format accumulators.
    format_stats = defaultdict(
        lambda: {
            "total_endpoints": 0,
            "total_keys": 0,
            "active_keys": 0,
            "health_scores": [],
            "endpoint_ids": [],
            "provider_ids": set(),
            "key_ids": [],
        }
    )
    for ep in endpoints:
        api_format = ep.api_format if ep.api_format else "UNKNOWN"
        stats = format_stats[api_format]

        stats["total_endpoints"] += 1
        stats["endpoint_ids"].append(ep.id)
        stats["provider_ids"].add(ep.provider_id)

        keys = keys_by_endpoint.get(ep.id, [])
        # total_keys counts keys of inactive endpoints too; only keys of
        # active endpoints feed the timeline and the active/health figures.
        stats["total_keys"] += len(keys)
        if ep.is_active:
            for key in keys:
                stats["key_ids"].append(key.id)
                if key.is_active and not key.circuit_breaker_open:
                    stats["active_keys"] += 1
                    health_score = key.health_score if key.health_score is not None else 1.0
                    stats["health_scores"].append(health_score)

    # Build the timelines of all formats with a single batched query.
    format_key_mapping: Dict[str, List[str]] = {
        api_format: stats["key_ids"] for api_format, stats in format_stats.items()
    }
    timeline_data_map = EndpointHealthService._generate_timeline_batch(
        db, format_key_mapping, now, lookback_hours
    )

    result = []
    for api_format, stats in format_stats.items():
        timeline_data = timeline_data_map.get(api_format, {
            "timeline": ["unknown"] * 100,
            "time_range_start": None,
            "time_range_end": None,
        })
        timeline = timeline_data["timeline"]
        time_range_start = timeline_data.get("time_range_start")
        time_range_end = timeline_data.get("time_range_end")

        # Derive the score from the timeline: healthy counts 1.0, warning 0.8,
        # unhealthy 0; "unknown" segments are ignored.
        if timeline:
            healthy_count = sum(1 for status in timeline if status == "healthy")
            warning_count = sum(1 for status in timeline if status == "warning")
            unhealthy_count = sum(1 for status in timeline if status == "unhealthy")
            known_count = healthy_count + warning_count + unhealthy_count
            if known_count > 0:
                avg_health = (healthy_count * 1.0 + warning_count * 0.8) / known_count
            else:
                # No traffic in the window: fall back to the stored key scores.
                if stats["health_scores"]:
                    avg_health = sum(stats["health_scores"]) / len(stats["health_scores"])
                elif stats["total_keys"] == 0:
                    avg_health = 0.0
                else:
                    # Keys exist but none is currently usable.
                    avg_health = 0.1
        else:
            avg_health = 0.0

        item = {
            "api_format": api_format,
            "display_name": EndpointHealthService._format_display_name(api_format),
            "health_score": avg_health,
            "timeline": timeline,
            "time_range_start": time_range_start.isoformat() if time_range_start else None,
            "time_range_end": time_range_end.isoformat() if time_range_end else None,
        }
        if include_admin_fields:
            item.update(
                {
                    "total_endpoints": stats["total_endpoints"],
                    "total_keys": stats["total_keys"],
                    "active_keys": stats["active_keys"],
                    "provider_count": len(stats["provider_ids"]),
                }
            )
        result.append(item)

    result.sort(key=lambda x: x["health_score"], reverse=True)

    # Store the freshly computed result.
    if use_cache:
        EndpointHealthService._set_to_cache(cache_key, result)
    return result
@staticmethod
def _generate_timeline_batch(
    db: Session,
    format_key_mapping: Dict[str, List[str]],
    now: datetime,
    lookback_hours: int,
    segments: int = 100,
) -> Dict[str, Dict[str, Any]]:
    """
    Build the status timelines of several API formats with one query.

    The data comes from the RequestCandidate table, restricted to rows whose
    status is ``success``, ``failed`` or ``skipped``. The look-back window
    ``[now - lookback_hours, now]`` is divided into ``segments`` equal
    buckets; each bucket becomes ``"healthy"``, ``"warning"``,
    ``"unhealthy"`` or ``"unknown"`` (no rows).

    Args:
        db: Database session.
        format_key_mapping: API format -> list of key ids.
        now: End of the look-back window.
        lookback_hours: Length of the look-back window, in hours.
        segments: Number of buckets in each timeline.

    Returns:
        API format -> ``{"timeline", "time_range_start", "time_range_end"}``.
        When there are no key ids, or ``lookback_hours`` / ``segments`` is
        not positive, every timeline is all ``"unknown"`` and both time
        bounds are ``None``.
    """
    all_key_ids = [
        key_id for key_ids in format_key_mapping.values() for key_id in key_ids
    ]

    # Nothing to query, or a window that cannot be bucketed: answer without
    # touching the database (this also avoids a zero bucket width below).
    if not all_key_ids or lookback_hours <= 0 or segments <= 0:
        return {
            api_format: {
                "timeline": ["unknown"] * max(segments, 0),
                "time_range_start": None,
                "time_range_end": None,
            }
            for api_format in format_key_mapping
        }

    # Bucket width in seconds. True division keeps the buckets spanning the
    # whole window; whole-minute truncation would leave its tail uncovered.
    interval_seconds = (lookback_hours * 3600) / segments
    start_time = now - timedelta(hours=lookback_hours)

    # Only final states are counted.
    final_statuses = ["success", "failed", "skipped"]
    segment_expr = func.floor(
        func.extract('epoch', RequestCandidate.created_at - start_time) / interval_seconds
    ).label('segment_idx')
    candidate_stats = (
        db.query(
            RequestCandidate.key_id,
            segment_expr,
            func.count(RequestCandidate.id).label('total_count'),
            func.sum(
                case(
                    (RequestCandidate.status == "success", 1),
                    else_=0
                )
            ).label('success_count'),
            func.sum(
                case(
                    (RequestCandidate.status == "failed", 1),
                    else_=0
                )
            ).label('failed_count'),
            func.min(RequestCandidate.created_at).label('min_time'),
            func.max(RequestCandidate.created_at).label('max_time'),
        )
        .filter(
            RequestCandidate.key_id.in_(all_key_ids),
            RequestCandidate.created_at >= start_time,
            RequestCandidate.created_at <= now,
            RequestCandidate.status.in_(final_statuses),
        )
        .group_by(RequestCandidate.key_id, segment_expr)
        .all()
    )

    # Reverse mapping: key id -> API format.
    key_to_format: Dict[str, str] = {}
    for api_format, key_ids in format_key_mapping.items():
        for key_id in key_ids:
            key_to_format[key_id] = api_format

    # Merge the per-key rows into per-format, per-segment totals.
    format_segment_data: Dict[str, Dict[int, Dict]] = defaultdict(lambda: defaultdict(lambda: {
        "total": 0,
        "success": 0,
        "failed": 0,
        "min_time": None,
        "max_time": None,
    }))
    for row in candidate_stats:
        if row.segment_idx is not None:
            # A row created exactly at `now` computes index == segments;
            # fold it into the last bucket instead of discarding it.
            segment_idx = min(int(row.segment_idx), segments - 1)
        else:
            segment_idx = 0
        api_format = key_to_format.get(row.key_id)
        if api_format and 0 <= segment_idx < segments:
            seg_data = format_segment_data[api_format][segment_idx]
            seg_data["total"] += row.total_count or 0
            seg_data["success"] += row.success_count or 0
            seg_data["failed"] += row.failed_count or 0
            if row.min_time:
                if seg_data["min_time"] is None or row.min_time < seg_data["min_time"]:
                    seg_data["min_time"] = row.min_time
            if row.max_time:
                if seg_data["max_time"] is None or row.max_time > seg_data["max_time"]:
                    seg_data["max_time"] = row.max_time

    # Turn the totals into one status per segment.
    result: Dict[str, Dict[str, Any]] = {}
    for api_format in format_key_mapping:
        timeline = []
        earliest_time = None
        latest_time = None
        segment_data = format_segment_data.get(api_format, {})
        for i in range(segments):
            seg = segment_data.get(i)
            if not seg or seg["total"] == 0:
                timeline.append("unknown")
                continue

            # Track the overall time range covered by real data.
            if seg["min_time"]:
                if earliest_time is None or seg["min_time"] < earliest_time:
                    earliest_time = seg["min_time"]
            if seg["max_time"]:
                if latest_time is None or seg["max_time"] > latest_time:
                    latest_time = seg["max_time"]

            # success rate = success / (success + failed); skipped attempts
            # are neither successes nor failures.
            actual_completed = seg["success"] + seg["failed"]
            if actual_completed > 0:
                success_rate = seg["success"] / actual_completed
            else:
                # Only skipped attempts: treated as healthy.
                success_rate = 1.0

            if success_rate >= 0.95:
                timeline.append("healthy")
            elif success_rate >= 0.7:
                timeline.append("warning")
            else:
                timeline.append("unhealthy")

        result[api_format] = {
            "timeline": timeline,
            "time_range_start": earliest_time,
            "time_range_end": latest_time if latest_time else now,
        }
    return result
@staticmethod
def _generate_timeline_from_usage(
    db: Session,
    endpoint_ids: List[str],
    now: datetime,
    lookback_hours: int,
    segments: int = 100,
) -> Dict[str, Any]:
    """
    Build one timeline for a set of endpoints (legacy entry point).

    Looks up the keys of the given endpoints and delegates to
    ``_generate_timeline_batch``.

    Args:
        db: Database session.
        endpoint_ids: Endpoint ids whose keys should be included.
        now: End of the look-back window.
        lookback_hours: Length of the look-back window, in hours.
        segments: Number of buckets in the timeline.

    Returns:
        ``{"timeline", "time_range_start", "time_range_end"}``. When there
        are no endpoints or no keys, the timeline is ``segments`` entries of
        ``"unknown"`` and both time bounds are ``None``.
    """
    # The fallback honours `segments` so its length always matches what the
    # batch path would have produced.
    unknown_result: Dict[str, Any] = {
        "timeline": ["unknown"] * segments,
        "time_range_start": None,
        "time_range_end": None,
    }
    if not endpoint_ids:
        return unknown_result

    # Collect the ids of every key attached to the requested endpoints.
    key_ids = [
        k.id
        for k in db.query(ProviderAPIKey.id)
        .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
        .all()
    ]
    if not key_ids:
        return unknown_result

    # Reuse the batch implementation with a single synthetic format name.
    result = EndpointHealthService._generate_timeline_batch(
        db, {"_single": key_ids}, now, lookback_hours, segments
    )
    return result.get("_single", unknown_result)
@staticmethod
def _format_display_name(api_format: str) -> str:
    """Map an API format identifier to its display label.

    Identifiers without a dedicated label are returned unchanged.
    """
    labels = dict(
        CLAUDE="Claude API",
        CLAUDE_CLI="Claude CLI",
        CLAUDE_COMPATIBLE="Claude 兼容",
        OPENAI="OpenAI API",
        OPENAI_CLI="OpenAI CLI",
        OPENAI_COMPATIBLE="OpenAI 兼容",
    )
    if api_format in labels:
        return labels[api_format]
    return api_format
@staticmethod
def _get_from_cache(key: str) -> Optional[List[Dict[str, Any]]]:
    """Read and JSON-decode a cached value.

    Returns ``None`` when no Redis client is available, the key is missing
    or empty, or reading/decoding raises (the error is logged as a warning).
    """
    client = _get_redis_client()
    if client:
        try:
            raw = client.get(key)
            if raw:
                return json.loads(raw)
        except Exception as exc:
            logger.warning(f"Failed to get from cache: {exc}")
    return None
@staticmethod
def _set_to_cache(key: str, data: List[Dict[str, Any]]) -> None:
    """JSON-encode ``data`` and store it under ``key`` with the module TTL.

    Does nothing when no Redis client is available. Values that are not
    JSON-serializable are stringified via ``default=str``. Failures are
    logged as warnings and never propagated.
    """
    client = _get_redis_client()
    if client:
        try:
            client.setex(key, CACHE_TTL_SECONDS, json.dumps(data, default=str))
        except Exception as exc:
            logger.warning(f"Failed to set cache: {exc}")

View File

@@ -0,0 +1,641 @@
"""
健康监控器 - Endpoint 和 Key 的健康度追踪
功能:
1. 基于滑动窗口的错误率计算
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭
3. 半开状态允许少量请求验证服务恢复
4. 提供健康度查询和管理 API
"""
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from src.config.constants import CircuitBreakerDefaults
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.core.metrics import health_open_circuits
from src.models.database import ProviderAPIKey, ProviderEndpoint
class CircuitState:
"""String constants naming the three circuit breaker states."""
CLOSED = "closed" # closed (normal operation)
OPEN = "open" # open (circuit tripped)
HALF_OPEN = "half_open" # half-open (verifying recovery)
class HealthMonitor:
"""Health monitor for API keys (sliding window + half-open circuit breaker)."""
# Every setting below is read once, when the class body is executed: an
# environment variable wins, otherwise the CircuitBreakerDefaults value is used.
# === Sliding window settings ===
WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE)))
WINDOW_SECONDS = int(
os.getenv("HEALTH_WINDOW_SECONDS", str(CircuitBreakerDefaults.WINDOW_SECONDS))
)
MIN_REQUESTS = int(
os.getenv("HEALTH_MIN_REQUESTS", str(CircuitBreakerDefaults.MIN_REQUESTS_FOR_DECISION))
)
ERROR_RATE_THRESHOLD = float(
os.getenv("HEALTH_ERROR_RATE_THRESHOLD", str(CircuitBreakerDefaults.ERROR_RATE_THRESHOLD))
)
# === Half-open state settings ===
HALF_OPEN_DURATION = int(
os.getenv(
"HEALTH_HALF_OPEN_DURATION", str(CircuitBreakerDefaults.HALF_OPEN_DURATION_SECONDS)
)
)
HALF_OPEN_SUCCESS_THRESHOLD = int(
os.getenv(
"HEALTH_HALF_OPEN_SUCCESS", str(CircuitBreakerDefaults.HALF_OPEN_SUCCESS_THRESHOLD)
)
)
HALF_OPEN_FAILURE_THRESHOLD = int(
os.getenv(
"HEALTH_HALF_OPEN_FAILURE", str(CircuitBreakerDefaults.HALF_OPEN_FAILURE_THRESHOLD)
)
)
# === Recovery settings ===
INITIAL_RECOVERY_SECONDS = int(
os.getenv(
"HEALTH_INITIAL_RECOVERY_SECONDS", str(CircuitBreakerDefaults.INITIAL_RECOVERY_SECONDS)
)
)
# NOTE(review): parsed with int(); a fractional multiplier such as "1.5" or "2.0"
# would raise ValueError here - confirm the default is an integer.
RECOVERY_BACKOFF = int(
os.getenv(
"HEALTH_RECOVERY_BACKOFF", str(CircuitBreakerDefaults.RECOVERY_BACKOFF_MULTIPLIER)
)
)
MAX_RECOVERY_SECONDS = int(
os.getenv("HEALTH_MAX_RECOVERY_SECONDS", str(CircuitBreakerDefaults.MAX_RECOVERY_SECONDS))
)
# === Legacy parameters (used for the displayed health score) ===
SUCCESS_INCREMENT = float(
os.getenv("HEALTH_SUCCESS_INCREMENT", str(CircuitBreakerDefaults.SUCCESS_INCREMENT))
)
FAILURE_DECREMENT = float(
os.getenv("HEALTH_FAILURE_DECREMENT", str(CircuitBreakerDefaults.FAILURE_DECREMENT))
)
PROBE_RECOVERY_SCORE = float(
os.getenv("HEALTH_PROBE_RECOVERY_SCORE", str(CircuitBreakerDefaults.PROBE_RECOVERY_SCORE))
)
# === Other settings ===
ALLOW_AUTO_RECOVER = os.getenv("HEALTH_AUTO_RECOVER_ENABLED", "true").lower() == "true"
CIRCUIT_HISTORY_LIMIT = int(os.getenv("HEALTH_CIRCUIT_HISTORY_LIMIT", "200"))
# Process-level state, stored on the class itself.
_circuit_history: List[Dict[str, Any]] = []
_open_circuit_keys: int = 0
# ==================== Core methods ====================
@classmethod
def record_success(
    cls,
    db: Session,
    key_id: Optional[str] = None,
    response_time_ms: Optional[int] = None,
) -> None:
    """Record a successful request for a key.

    Updates the sliding window, the displayed health score, the counters and
    the circuit breaker state, then flushes and marks the session dirty for
    the batch committer. Does nothing when ``key_id`` is empty or unknown.

    Args:
        db: Database session.
        key_id: Id of the key that served the request.
        response_time_ms: Response time to add to the key's running total.

    Any exception is logged and the session is rolled back; nothing is raised.
    """
    try:
        if not key_id:
            return
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
        if not key:
            return

        now = datetime.now(timezone.utc)
        now_ts = now.timestamp()

        # 1. Sliding window.
        cls._add_to_window(key, now_ts, success=True)

        # 2. Displayed health score. A missing score counts as fully healthy
        #    (1.0), matching how the rest of this file treats NULL scores.
        current_score = 1.0 if key.health_score is None else float(key.health_score)
        new_score = min(current_score + cls.SUCCESS_INCREMENT, 1.0)
        key.health_score = new_score  # type: ignore[assignment]

        # 3. Counters.
        key.consecutive_failures = 0  # type: ignore[assignment]
        key.last_failure_at = None  # type: ignore[assignment]
        key.success_count = int(key.success_count or 0) + 1  # type: ignore[assignment]
        key.request_count = int(key.request_count or 0) + 1  # type: ignore[assignment]
        if response_time_ms:
            key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms  # type: ignore[assignment]

        # 4. Circuit breaker.
        state = cls._get_circuit_state(key, now)
        if state == CircuitState.HALF_OPEN:
            # Half-open: count the success and close once the threshold is met.
            key.half_open_successes = int(key.half_open_successes or 0) + 1  # type: ignore[assignment]
            if int(key.half_open_successes or 0) >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
                cls._close_circuit(key, now, reason="半开状态验证成功")
        elif state == CircuitState.OPEN:
            # A success while open (a probe got through): move to half-open.
            cls._enter_half_open(key, now)

        db.flush()
        get_batch_committer().mark_dirty(db)
    except Exception as e:
        logger.error(f"记录成功请求失败: {e}")
        db.rollback()
@classmethod
def record_failure(
    cls,
    db: Session,
    key_id: Optional[str] = None,
    error_type: Optional[str] = None,
) -> None:
    """Record a failed request for a key.

    Updates the sliding window, the displayed health score, the counters and
    the circuit breaker state, then flushes and marks the session dirty for
    the batch committer. Does nothing when ``key_id`` is empty or unknown.

    Args:
        db: Database session.
        key_id: Id of the key that served the request.
        error_type: Error label, used only in the debug log line.

    Any exception is logged and the session is rolled back; nothing is raised.
    """
    try:
        if not key_id:
            return
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
        if not key:
            return

        now = datetime.now(timezone.utc)
        now_ts = now.timestamp()

        # 1. Sliding window.
        cls._add_to_window(key, now_ts, success=False)

        # 2. Displayed health score. Only a missing score defaults to 1.0;
        #    a stored 0.0 must stay 0.0 instead of being read as "healthy".
        current_score = 1.0 if key.health_score is None else float(key.health_score)
        new_score = max(current_score - cls.FAILURE_DECREMENT, 0.0)
        key.health_score = new_score  # type: ignore[assignment]

        # 3. Counters.
        key.consecutive_failures = int(key.consecutive_failures or 0) + 1  # type: ignore[assignment]
        key.last_failure_at = now  # type: ignore[assignment]
        key.error_count = int(key.error_count or 0) + 1  # type: ignore[assignment]
        key.request_count = int(key.request_count or 0) + 1  # type: ignore[assignment]

        # 4. Circuit breaker.
        state = cls._get_circuit_state(key, now)
        if state == CircuitState.HALF_OPEN:
            # Half-open: count the failure and re-open once the threshold is met.
            key.half_open_failures = int(key.half_open_failures or 0) + 1  # type: ignore[assignment]
            if int(key.half_open_failures or 0) >= cls.HALF_OPEN_FAILURE_THRESHOLD:
                cls._open_circuit(key, now, reason="半开状态验证失败")
        elif state == CircuitState.CLOSED:
            # Closed: trip the breaker when the window holds enough requests
            # and the error rate reaches the threshold.
            error_rate = cls._calculate_error_rate(key, now_ts)
            window = key.request_results_window or []
            if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD:
                cls._open_circuit(
                    key, now, reason=f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
                )

        logger.debug(
            f"[WARN] Key 健康度下降: {key_id[:8]}... -> {new_score:.2f} "
            f"(连续失败 {key.consecutive_failures} 次, error_type={error_type})"
        )
        db.flush()
        get_batch_committer().mark_dirty(db)
    except Exception as e:
        logger.error(f"记录失败请求失败: {e}")
        db.rollback()
# ==================== 滑动窗口方法 ====================
@classmethod
def _add_to_window(cls, key: ProviderAPIKey, now_ts: float, success: bool) -> None:
    """Append one request outcome to the key's sliding window.

    Entries not newer than ``now_ts - WINDOW_SECONDS`` are dropped and the
    window is capped at the newest ``WINDOW_SIZE`` entries.

    Args:
        key: Key whose ``request_results_window`` is updated.
        now_ts: Timestamp of the request, in epoch seconds.
        success: Whether the request succeeded.
    """
    # Work on a copy. Appending to the loaded list in place would also change
    # the "old" value the ORM compares against, so (for a column without
    # mutation tracking) the reassignment below could look like a no-op and
    # the new entry would not be written.
    window: List[Dict[str, Any]] = list(key.request_results_window or [])
    window.append({"ts": now_ts, "ok": success})

    # Drop expired entries.
    cutoff_ts = now_ts - cls.WINDOW_SECONDS
    window = [r for r in window if r["ts"] > cutoff_ts]

    # Keep only the newest WINDOW_SIZE entries.
    if len(window) > cls.WINDOW_SIZE:
        window = window[-cls.WINDOW_SIZE :]
    key.request_results_window = window  # type: ignore[assignment]
@classmethod
def _calculate_error_rate(cls, key: ProviderAPIKey, now_ts: float) -> float:
    """Return the share of failed requests among the key's unexpired window entries.

    Entries not newer than ``now_ts - WINDOW_SECONDS`` are ignored. The
    result is ``0.0`` when no entry remains.
    """
    oldest_allowed = now_ts - cls.WINDOW_SECONDS
    recent = [
        entry
        for entry in (key.request_results_window or [])
        if entry["ts"] > oldest_allowed
    ]
    if not recent:
        return 0.0
    failed = len([entry for entry in recent if not entry["ok"]])
    return failed / len(recent)
# ==================== 熔断器状态方法 ====================
@classmethod
def _get_circuit_state(cls, key: ProviderAPIKey, now: datetime) -> str:
    """Derive the circuit breaker state of ``key`` at time ``now``.

    Returns one of the ``CircuitState`` constants:
    - CLOSED when the breaker flag is not set;
    - HALF_OPEN while ``half_open_until`` is in the future, or once
      ``next_probe_at`` has been reached;
    - OPEN otherwise.
    """
    if not key.circuit_breaker_open:
        return CircuitState.CLOSED

    # Still inside an explicitly started half-open period?
    half_open_deadline = key.half_open_until
    if half_open_deadline and now < half_open_deadline:
        return CircuitState.HALF_OPEN

    # Probe time reached: the breaker counts as half-open from here on.
    probe_time = key.next_probe_at
    if probe_time and now >= probe_time:
        return CircuitState.HALF_OPEN

    return CircuitState.OPEN
@classmethod
def _open_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
    """Trip the circuit breaker of ``key``.

    Resets the half-open bookkeeping, schedules the next probe based on the
    key's consecutive failure count, bumps the open-circuit gauge when the
    breaker was not already open, and records an ``opened`` history event.
    """
    already_open = key.circuit_breaker_open

    # Schedule the next probe (the moment the breaker becomes half-open).
    failures_in_a_row = int(key.consecutive_failures or 0)
    recovery_seconds = cls._calculate_recovery_seconds(failures_in_a_row)

    key.circuit_breaker_open = True  # type: ignore[assignment]
    key.circuit_breaker_open_at = now  # type: ignore[assignment]
    key.half_open_until = None  # type: ignore[assignment]
    key.half_open_successes = 0  # type: ignore[assignment]
    key.half_open_failures = 0  # type: ignore[assignment]
    key.next_probe_at = now + timedelta(seconds=recovery_seconds)  # type: ignore[assignment]

    # Count the key only on a closed -> open transition.
    if not already_open:
        cls._open_circuit_keys += 1
        health_open_circuits.set(cls._open_circuit_keys)

    message = (
        f"[OPEN] Key 熔断器打开: {key.id[:8]}... | 原因: {reason} | "
        f"{recovery_seconds}秒后进入半开状态"
    )
    logger.warning(message)

    event = {
        "event": "opened",
        "key_id": key.id,
        "reason": reason,
        "recovery_seconds": recovery_seconds,
        "timestamp": now.isoformat(),
    }
    cls._push_circuit_event(event)
@classmethod
def _enter_half_open(cls, key: ProviderAPIKey, now: datetime) -> None:
    """Start a half-open period for ``key``.

    The period lasts ``HALF_OPEN_DURATION`` seconds from ``now``; both
    half-open counters restart at zero and a ``half_open`` history event is
    recorded.
    """
    key.half_open_successes = 0  # type: ignore[assignment]
    key.half_open_failures = 0  # type: ignore[assignment]
    key.half_open_until = now + timedelta(seconds=cls.HALF_OPEN_DURATION)  # type: ignore[assignment]

    message = (
        f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}... | "
        f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
    )
    logger.info(message)

    event = {
        "event": "half_open",
        "key_id": key.id,
        "timestamp": now.isoformat(),
    }
    cls._push_circuit_event(event)
@classmethod
def _close_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
    """Close the circuit breaker of ``key``.

    Clears all breaker timestamps and half-open counters, lifts the health
    score to at least ``PROBE_RECOVERY_SCORE``, decrements the open-circuit
    gauge (never below zero) and records a ``closed`` history event.
    """
    # Clear every piece of breaker state.
    for timestamp_field in ("circuit_breaker_open_at", "next_probe_at", "half_open_until"):
        setattr(key, timestamp_field, None)
    key.circuit_breaker_open = False  # type: ignore[assignment]
    key.half_open_successes = 0  # type: ignore[assignment]
    key.half_open_failures = 0  # type: ignore[assignment]

    # Fast recovery of the displayed score.
    previous_score = float(key.health_score or 0)
    key.health_score = max(previous_score, cls.PROBE_RECOVERY_SCORE)  # type: ignore[assignment]

    remaining_open = cls._open_circuit_keys - 1
    cls._open_circuit_keys = max(0, remaining_open)
    health_open_circuits.set(cls._open_circuit_keys)

    logger.info(f"[CLOSED] Key 熔断器关闭: {key.id[:8]}... | 原因: {reason}")

    event = {
        "event": "closed",
        "key_id": key.id,
        "reason": reason,
        "timestamp": now.isoformat(),
    }
    cls._push_circuit_event(event)
@classmethod
def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int:
    """Return how long to wait before probing again (exponential backoff).

    The backoff level rises by one for every 5 consecutive failures and is
    capped at level 4; the wait is
    ``INITIAL_RECOVERY_SECONDS * RECOVERY_BACKOFF ** level``, limited to
    ``MAX_RECOVERY_SECONDS``.
    """
    # The original note gives 30s -> 60s -> 120s -> 240s -> 300s (cap) as the
    # progression; presumably that reflects the default settings - verify.
    level = min(consecutive_failures // 5, 4)
    wait = int(cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**level))
    return min(wait, cls.MAX_RECOVERY_SECONDS)
# ==================== 状态查询方法 ====================
@classmethod
def is_circuit_breaker_closed(cls, resource: ProviderAPIKey) -> bool:
    """Tell whether the circuit breaker lets requests through.

    Returns True when the breaker is closed or half-open, False while it is
    open. This method only reads ``resource``; it never modifies it.
    """
    if not resource.circuit_breaker_open:
        return True

    now = datetime.now(timezone.utc)
    # _get_circuit_state already reports HALF_OPEN once next_probe_at has
    # been reached, so a separate probe-time check here could never run
    # (and neither could the half-open transition it used to guard).
    return cls._get_circuit_state(resource, now) == CircuitState.HALF_OPEN
@classmethod
def get_circuit_breaker_status(
    cls, resource: ProviderAPIKey
) -> Tuple[bool, Optional[str]]:
    """Describe the circuit breaker of ``resource``.

    Returns:
        ``(allowed, message)``:
        - ``(True, None)`` when the breaker is closed;
        - ``(True, <half-open progress>)`` when it is half-open;
        - ``(False, <message>)`` when it is open; the message includes the
          time left until the next probe when ``next_probe_at`` is set.
    """
    if not resource.circuit_breaker_open:
        return True, None

    now = datetime.now(timezone.utc)
    if cls._get_circuit_state(resource, now) == CircuitState.HALF_OPEN:
        successes = int(resource.half_open_successes or 0)
        return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)"

    # The state is OPEN here, so next_probe_at is either unset or still in
    # the future (a reached probe time would have been reported as HALF_OPEN).
    if not resource.next_probe_at:
        return False, "熔断中"

    remaining_seconds = int((resource.next_probe_at - now).total_seconds())
    if remaining_seconds >= 60:
        time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s"
    else:
        time_str = f"{remaining_seconds}s"
    return False, f"熔断中({time_str}后半开)"
@classmethod
def get_key_health(cls, db: Session, key_id: str) -> Optional[Dict[str, Any]]:
    """Return a health snapshot of one key.

    Args:
        db: Database session.
        key_id: Id of the key to inspect.

    Returns:
        A dict with the health score, current error rate, window size,
        failure info, request statistics and circuit breaker details, or
        ``None`` when the key does not exist or an error occurs (logged).
    """
    try:
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
        if not key:
            return None

        now = datetime.now(timezone.utc)
        now_ts = now.timestamp()

        # Current error rate and the number of unexpired window entries.
        error_rate = cls._calculate_error_rate(key, now_ts)
        window = key.request_results_window or []
        valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]

        avg_response_time_ms = (
            int(key.total_response_time_ms or 0) / int(key.success_count or 1)
            if key.success_count
            else 0
        )
        # Only a missing score defaults to 1.0; a stored 0.0 is reported as 0.0.
        health_score = 1.0 if key.health_score is None else float(key.health_score)

        return {
            "key_id": key.id,
            "health_score": health_score,
            "error_rate": error_rate,
            "window_size": len(valid_window),
            "consecutive_failures": int(key.consecutive_failures or 0),
            "last_failure_at": key.last_failure_at.isoformat() if key.last_failure_at else None,
            "is_active": key.is_active,
            "statistics": {
                "request_count": int(key.request_count or 0),
                "success_count": int(key.success_count or 0),
                "error_count": int(key.error_count or 0),
                "success_rate": (
                    int(key.success_count or 0) / int(key.request_count or 1)
                    if key.request_count
                    else 0.0
                ),
                "avg_response_time_ms": round(avg_response_time_ms, 2),
            },
            "circuit_breaker": {
                "state": cls._get_circuit_state(key, now),
                "open": key.circuit_breaker_open,
                "open_at": (
                    key.circuit_breaker_open_at.isoformat()
                    if key.circuit_breaker_open_at
                    else None
                ),
                "next_probe_at": (
                    key.next_probe_at.isoformat() if key.next_probe_at else None
                ),
                "half_open_until": (
                    key.half_open_until.isoformat() if key.half_open_until else None
                ),
                "half_open_successes": int(key.half_open_successes or 0),
                "half_open_failures": int(key.half_open_failures or 0),
            },
        }
    except Exception as e:
        logger.error(f"获取 Key 健康状态失败: {e}")
        return None
@classmethod
def get_endpoint_health(cls, db: Session, endpoint_id: str) -> Optional[Dict[str, Any]]:
    """Return a health snapshot of one endpoint.

    Args:
        db: Database session.
        endpoint_id: Id of the endpoint to inspect.

    Returns:
        A dict with the endpoint id, health score, consecutive failures,
        last failure time and active flag, or ``None`` when the endpoint does
        not exist or an error occurs (logged).
    """
    try:
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
        )
        if not endpoint:
            return None

        # Only a missing score defaults to 1.0; a stored 0.0 is reported as 0.0.
        health_score = (
            1.0 if endpoint.health_score is None else float(endpoint.health_score)
        )
        return {
            "endpoint_id": endpoint.id,
            "health_score": health_score,
            "consecutive_failures": int(endpoint.consecutive_failures or 0),
            "last_failure_at": (
                endpoint.last_failure_at.isoformat() if endpoint.last_failure_at else None
            ),
            "is_active": endpoint.is_active,
        }
    except Exception as e:
        logger.error(f"获取 Endpoint 健康状态失败: {e}")
        return None
# ==================== 管理方法 ====================
@classmethod
def reset_health(cls, db: Session, key_id: Optional[str] = None) -> bool:
    """Reset the health state of a key to its pristine values.

    Restores the health score to 1.0, clears the failure info and the
    sliding window, and closes the circuit breaker. When the breaker was
    open, the open-circuit gauge is decremented as well so it stays in step
    with the breaker flag.

    Args:
        db: Database session.
        key_id: Id of the key to reset; nothing is reset when it is empty
            or unknown.

    Returns:
        True unless an exception occurred (in which case the session is
        rolled back and False is returned).
    """
    try:
        if key_id:
            key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
            if key:
                was_open = bool(key.circuit_breaker_open)

                key.health_score = 1.0  # type: ignore[assignment]
                key.consecutive_failures = 0  # type: ignore[assignment]
                key.last_failure_at = None  # type: ignore[assignment]
                key.request_results_window = []  # type: ignore[assignment]
                key.circuit_breaker_open = False  # type: ignore[assignment]
                key.circuit_breaker_open_at = None  # type: ignore[assignment]
                key.next_probe_at = None  # type: ignore[assignment]
                key.half_open_until = None  # type: ignore[assignment]
                key.half_open_successes = 0  # type: ignore[assignment]
                key.half_open_failures = 0  # type: ignore[assignment]

                # Mirror the bookkeeping done when a breaker is closed, so a
                # manual reset does not leave the gauge overstated.
                if was_open:
                    cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1)
                    health_open_circuits.set(cls._open_circuit_keys)

                logger.info(f"[RESET] 重置 Key 健康度: {key_id}")
        db.flush()
        get_batch_committer().mark_dirty(db)
        return True
    except Exception as e:
        logger.error(f"重置健康度失败: {e}")
        db.rollback()
        return False
@classmethod
def manually_enable(cls, db: Session, key_id: Optional[str] = None) -> bool:
    """Re-activate an inactive key and clear its consecutive failure count.

    Keys that are unknown or already active are left untouched. The session
    is flushed and marked dirty in every non-error case.

    Returns:
        True unless an exception occurred (the session is then rolled back
        and False is returned).
    """
    try:
        target = None
        if key_id:
            target = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
        if target and not target.is_active:
            target.is_active = True  # type: ignore[assignment]
            target.consecutive_failures = 0  # type: ignore[assignment]
            logger.info(f"[OK] 手动启用 Key: {key_id}")
        db.flush()
        get_batch_committer().mark_dirty(db)
    except Exception as exc:
        logger.error(f"手动启用失败: {exc}")
        db.rollback()
        return False
    return True
@classmethod
def get_all_health_status(cls, db: Session) -> Dict[str, Any]:
    """Return aggregate health counters for all endpoints and keys.

    Returns:
        ``{"endpoints": {total, active, unhealthy},
        "keys": {total, active, unhealthy, circuit_open}}``, where
        "unhealthy" means a health score below 0.5. All counters are zero
        when the queries fail (the error is logged).
    """
    try:
        endpoint_row = db.query(
            func.count(ProviderEndpoint.id).label("total"),
            func.sum(case((ProviderEndpoint.is_active == True, 1), else_=0)).label("active"),
            func.sum(case((ProviderEndpoint.health_score < 0.5, 1), else_=0)).label(
                "unhealthy"
            ),
        ).first()
        key_row = db.query(
            func.count(ProviderAPIKey.id).label("total"),
            func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
            func.sum(case((ProviderAPIKey.health_score < 0.5, 1), else_=0)).label("unhealthy"),
            func.sum(case((ProviderAPIKey.circuit_breaker_open == True, 1), else_=0)).label(
                "circuit_open"
            ),
        ).first()

        # Start from zeroed summaries and fill them in when a row came back;
        # SUM() yields NULL on empty tables, hence the `or 0` fallbacks.
        endpoint_summary = {"total": 0, "active": 0, "unhealthy": 0}
        if endpoint_row:
            endpoint_summary = {
                "total": endpoint_row.total or 0,
                "active": int(endpoint_row.active or 0),
                "unhealthy": int(endpoint_row.unhealthy or 0),
            }

        key_summary = {"total": 0, "active": 0, "unhealthy": 0, "circuit_open": 0}
        if key_row:
            key_summary = {
                "total": key_row.total or 0,
                "active": int(key_row.active or 0),
                "unhealthy": int(key_row.unhealthy or 0),
                "circuit_open": int(key_row.circuit_open or 0),
            }

        return {"endpoints": endpoint_summary, "keys": key_summary}
    except Exception as exc:
        logger.error(f"获取健康状态摘要失败: {exc}")
        return {
            "endpoints": {"total": 0, "active": 0, "unhealthy": 0},
            "keys": {"total": 0, "active": 0, "unhealthy": 0, "circuit_open": 0},
        }
# ==================== 历史记录方法 ====================
@classmethod
def _push_circuit_event(cls, event: Dict[str, Any]) -> None:
    """Append ``event`` to the in-process circuit history.

    When the history then exceeds ``CIRCUIT_HISTORY_LIMIT`` entries, the
    oldest entry is removed.
    """
    history = cls._circuit_history
    history.append(event)
    if len(history) > cls.CIRCUIT_HISTORY_LIMIT:
        del history[0]
@classmethod
def get_circuit_history(cls, limit: int = 50) -> List[Dict[str, Any]]:
    """Return the newest ``limit`` circuit events, oldest first.

    A non-positive ``limit`` yields an empty list. The returned list is a
    new list object (a slice of the internal history).
    """
    return cls._circuit_history[-limit:] if limit > 0 else []
# ==================== 兼容旧方法 ====================
@classmethod
def is_eligible_for_probe(
    cls,
    db: Session,
    endpoint_id: Optional[str] = None,
    key_id: Optional[str] = None,
) -> bool:
    """Tell whether a probe request may be sent (legacy interface).

    Returns True only when auto-recovery is enabled, no ``endpoint_id`` is
    given (endpoints do not support probing), and the key identified by
    ``key_id`` exists with an open breaker that is currently half-open.
    """
    if not cls.ALLOW_AUTO_RECOVER or endpoint_id or not key_id:
        return False

    key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
    if not (key and key.circuit_breaker_open):
        return False

    current_state = cls._get_circuit_state(key, datetime.now(timezone.utc))
    return current_state == CircuitState.HALF_OPEN
# Global health monitor instance
health_monitor = HealthMonitor()
# Start the open-circuit gauge at zero when this module is loaded.
health_open_circuits.set(0)