Files
Aether/src/services/orchestration/error_classifier.py

656 lines
24 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
错误分类器
负责错误分类和处理策略决定
"""
import json
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
import httpx
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.core.exceptions import (
ConcurrencyLimitError,
ProviderAuthException,
ProviderException,
ProviderNotAvailableException,
ProviderRateLimitException,
UpstreamClientException,
)
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.detector import RateLimitType, detect_rate_limit_type
class ErrorAction(Enum):
    """
    Follow-up action chosen after a failed request attempt.

    Members:
        CONTINUE: try the current candidate again.
        BREAK: give up on the current candidate and move to the next one.
        RAISE: propagate the exception immediately.
    """

    CONTINUE = "continue"  # retry the same candidate
    BREAK = "break"  # switch to the next candidate
    RAISE = "raise"  # re-raise right away
class ErrorClassifier:
    """
    Error classifier: decides how a failed upstream request is handled.

    Responsibilities:
    1. Classify errors as retriable / non-retriable.
    2. Decide the follow-up action (retry / switch candidate / give up).
    3. Handle specific error kinds such as HTTP 429 rate limiting.
    4. Update key health state and invalidate cache affinity.
    """

    # Error types that trigger retry / failover.
    RETRIABLE_ERRORS: Tuple[type, ...] = (
        ProviderException,  # covers every Provider exception subclass
        ConnectionError,  # stdlib connection errors
        TimeoutError,  # stdlib timeout errors
        httpx.TransportError,  # HTTPX transport-level errors
    )

    # Error types that are never retried (raised directly).
    NON_RETRIABLE_ERRORS: Tuple[type, ...] = (
        ValueError,  # bad argument
        TypeError,  # wrong type
        KeyError,  # missing key
        UpstreamClientException,  # upstream rejected the request itself
    )

    # Case-insensitive keywords marking an error as caused by the user's request
    # itself; switching to another Provider would not help.
    # Standard error.type values are handled via CLIENT_ERROR_TYPES; this list
    # mainly targets non-standard formats and third-party proxy messages.
    #
    # IMPORTANT: do not add Provider key configuration problems here (such as
    # invalid_api_key) -- those must trigger failover, not be returned to the user.
    CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
        "could not process image",  # image processing failed
        "image too large",  # image too large
        "invalid image",  # invalid image
        "unsupported image",  # unsupported image format
        "content_policy_violation",  # content policy violation
        "context_length_exceeded",  # context length exceeded
        "content_length_limit",  # request content length limit (Claude API)
        "content_length_exceeds",  # content length variant (AWS CodeWhisperer)
        "max_tokens",  # token count exceeded
        "invalid_prompt",  # invalid prompt
        "content too long",  # content too long
        "input is too long",  # input too long (AWS)
        "message is too long",  # message too long
        "prompt is too long",  # prompt too long (common third-party proxy wording)
        "image exceeds",  # image exceeds a limit
        "pdf too large",  # PDF too large
        "file too large",  # file too large
        "tool_use_id",  # tool_result references a missing tool_use (non-standard proxies)
        "validationexception",  # AWS validation exception
    )

    # error.type values (case-insensitive, substring match) that mark the request
    # itself as invalid; such errors are not retried.
    # NOTE(review): native Gemini errors appear to carry INVALID_ARGUMENT /
    # FAILED_PRECONDITION in error.status, which _parse_error_response does not
    # read -- these entries presumably only match when a proxy reports them in
    # error.type. Confirm before relying on them.
    CLIENT_ERROR_TYPES: Tuple[str, ...] = (
        # Claude / OpenAI
        "invalid_request_error",
        # Gemini
        "invalid_argument",
        "failed_precondition",
        # AWS
        "validationexception",
        # Generic
        "validation_error",
        "bad_request",
    )

    # reason/code values (upper-case, substring match) that mark a client error.
    CLIENT_ERROR_REASONS: Tuple[str, ...] = (
        "CONTENT_LENGTH_EXCEEDS_THRESHOLD",
        "CONTEXT_LENGTH_EXCEEDED",
        "MAX_TOKENS_EXCEEDED",
        "INVALID_CONTENT",
        "CONTENT_POLICY_VIOLATION",
    )

    def __init__(
        self,
        db: Session,
        adaptive_manager: Any = None,
        cache_scheduler: Optional[CacheAwareScheduler] = None,
    ) -> None:
        """
        Create an error classifier.

        Args:
            db: database session, passed to the adaptive manager and health monitor
            adaptive_manager: adaptive concurrency manager; defaults to
                get_adaptive_manager()
            cache_scheduler: optional cache scheduler used to invalidate affinity
        """
        self.db = db
        self.adaptive_manager = adaptive_manager or get_adaptive_manager()
        self.cache_scheduler = cache_scheduler

    # ------------------------------------------------------------------ parsing

    @staticmethod
    def _as_text(value: Any) -> str:
        """Convert a JSON field to text, mapping a missing/null value to ""."""
        return "" if value is None else str(value)

    def _merge_nested_message(self, result: Dict[str, Any]) -> None:
        """
        Unwrap result["message"] in place when it is itself a JSON error object.

        Handles:
        - AWS: {"__type": "...", "message": "...", "reason": "..."}
        - third-party proxies: {"error": {"message": "..."}}
        - simple: {"message": "..."}
        """
        message = result["message"]
        if not message.lstrip().startswith("{"):
            return
        try:
            nested = json.loads(message)
        except json.JSONDecodeError:
            return
        if not isinstance(nested, dict):
            return

        if "__type" in nested:
            result["type"] = result["type"] or self._as_text(nested.get("__type"))
            result["message"] = self._as_text(nested.get("message")) or message
            nested_reason = self._as_text(nested.get("reason"))
            # Keep a reason/code from the outer object if the nested one has none.
            if nested_reason:
                result["reason"] = nested_reason
        elif isinstance(nested.get("error"), dict):
            inner_msg = self._as_text(nested["error"].get("message"))
            if inner_msg:
                result["message"] = inner_msg
        elif "message" in nested:
            result["message"] = self._as_text(nested["message"])

    def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
        """
        Parse an upstream error body into a normalized dict.

        Supported shapes:
        - {"error": {"type": "...", "message": "..."}}    (Claude/OpenAI)
        - {"error": {"message": "...", "__type": "..."}}  (AWS)
        - {"errorMessage": "..."}                         (Lambda)
        - {"error": "..."}
        - {"message": "...", "reason": "..."}
        - any of the above wrapped in a JSON array (first element is used)
        A message that is itself a JSON-encoded error is unwrapped one level.

        Args:
            error_text: raw response body (may be None or empty)

        Returns:
            {"type": str, "message": str, "reason": str, "raw": str}. A body that
            is not a JSON object (plain text, numbers, null, ...) yields only
            "message" = its first 500 characters, plus "raw".
        """
        result: Dict[str, Any] = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
        if not error_text:
            return result

        try:
            data: Any = json.loads(error_text)
        except (json.JSONDecodeError, TypeError):
            data = None
        # Some upstreams wrap the error object in a JSON array.
        if isinstance(data, list) and data and isinstance(data[0], dict):
            data = data[0]
        if not isinstance(data, dict):
            # Plain text or non-object JSON: calling data.get() here used to
            # raise AttributeError and abort error conversion.
            result["message"] = error_text[:500]
            return result

        text = self._as_text
        try:
            error_field = data.get("error")
            if isinstance(error_field, dict):
                # Shape 1: {"error": {"type": "...", "message": "..."}}
                result["type"] = text(error_field.get("type"))
                result["message"] = text(error_field.get("message"))
                # AWS puts __type directly on the error object.
                if "__type" in error_field:
                    result["type"] = result["type"] or text(error_field.get("__type"))
                if "reason" in error_field:
                    result["reason"] = text(error_field.get("reason"))
                if "code" in error_field:
                    result["reason"] = result["reason"] or text(error_field.get("code"))
                self._merge_nested_message(result)
            elif isinstance(error_field, str):
                # Shape 2: {"error": "..."}
                result["message"] = error_field
            elif "errorMessage" in data:
                # Shape 3: {"errorMessage": "..."} (Lambda)
                result["message"] = text(data["errorMessage"])
            elif "message" in data:
                # Shape 4: {"message": "...", "reason": "..."}
                result["message"] = text(data["message"])
                result["reason"] = text(data.get("reason"))

            # Fall back to a top-level reason/code.
            if not result["reason"]:
                result["reason"] = text(data.get("reason") or data.get("code"))
        except (TypeError, KeyError):
            result["message"] = error_text[:500]
        return result

    def _is_client_error(self, error_text: Optional[str]) -> bool:
        """
        Return True when the error body indicates a client-side request error
        (retrying or switching Provider would not help).

        Checks, in order:
        1. error type against CLIENT_ERROR_TYPES
        2. reason/code against CLIENT_ERROR_REASONS
        3. keyword fallback against CLIENT_ERROR_PATTERNS (message + raw body)

        Args:
            error_text: raw error response body

        Returns:
            Whether the error is a client error.
        """
        if not error_text:
            return False

        parsed = self._parse_error_response(error_text)

        error_type = parsed["type"].lower()
        if error_type and any(t.lower() in error_type for t in self.CLIENT_ERROR_TYPES):
            return True

        reason = parsed["reason"].upper()
        if reason and any(r in reason for r in self.CLIENT_ERROR_REASONS):
            return True

        search_text = f"{parsed['message']} {parsed['raw']}".lower()
        return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)

    def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
        """
        Build a readable one-line message from an error body.

        Format: "<type>: [<reason>]: <message>", omitting empty parts.

        Args:
            error_text: raw error response body

        Returns:
            The readable message, the raw body truncated to 500 characters when
            nothing could be extracted, or None for an empty body.
        """
        if not error_text:
            return None

        parsed = self._parse_error_response(error_text)

        parts = []
        if parsed["type"]:
            parts.append(parsed["type"])
        if parsed["reason"]:
            parts.append(f"[{parsed['reason']}]")
        if parsed["message"]:
            parts.append(parsed["message"])

        if parts:
            return ": ".join(parts)
        return parsed["raw"][:500]

    # ----------------------------------------------------------- classification

    def classify(
        self,
        error: Exception,
        has_retry_left: bool = False,
    ) -> ErrorAction:
        """
        Classify an error into a follow-up action.

        Args:
            error: the exception raised by the attempt
            has_retry_left: whether the current candidate still has retries left

        Returns:
            ErrorAction: BREAK for concurrency limits; CONTINUE/BREAK for
            retriable errors (depending on has_retry_left); RAISE for
            non-retriable and unknown errors.
        """
        if isinstance(error, ConcurrencyLimitError):
            return ErrorAction.BREAK

        # Checked before RETRIABLE_ERRORS on purpose: should UpstreamClientException
        # derive from ProviderException (its base is not visible here), a client
        # error must still never be retried.
        if isinstance(error, UpstreamClientException):
            return ErrorAction.RAISE

        # HTTP status errors are retried like other retriable errors; the status
        # code itself is interpreted later by convert_http_error.
        if isinstance(error, (httpx.HTTPStatusError, *self.RETRIABLE_ERRORS)):
            return ErrorAction.CONTINUE if has_retry_left else ErrorAction.BREAK

        # NON_RETRIABLE_ERRORS and unknown errors are both raised directly.
        return ErrorAction.RAISE

    async def handle_rate_limit(
        self,
        key: ProviderAPIKey,
        provider_name: str,
        current_concurrent: Optional[int],
        exception: ProviderRateLimitException,
        request_id: Optional[str] = None,
    ) -> str:
        """
        Adaptively react to an HTTP 429 rate-limit error.

        Best effort: any exception raised while handling is logged and
        "unknown" is returned.

        Args:
            key: API key that hit the limit
            provider_name: Provider name
            current_concurrent: concurrency captured for the key
            exception: the rate-limit exception (its response_headers are used)
            request_id: request ID, for logging

        Returns:
            Limit type: "concurrent", "rpm" or "unknown".
        """
        try:
            response_headers = getattr(exception, "response_headers", None) or {}

            rate_limit_info = detect_rate_limit_type(
                headers=response_headers,
                provider_name=provider_name,
                current_concurrent=current_concurrent,
            )

            logger.info(f" [{request_id}] 429错误分析: "
                        f"类型={rate_limit_info.limit_type}, "
                        f"retry_after={rate_limit_info.retry_after}s, "
                        f"当前并发={current_concurrent}")

            new_limit = self.adaptive_manager.handle_429_error(
                db=self.db,
                key=key,
                rate_limit_info=rate_limit_info,
                current_concurrent=current_concurrent,
            )

            if rate_limit_info.limit_type == RateLimitType.CONCURRENT:
                # str() so non-string ids (e.g. UUID objects) can be sliced.
                logger.warning(f" [{request_id}] 自适应调整: " f"Key {str(key.id)[:8]}... 并发限制 -> {new_limit}")
                return "concurrent"
            if rate_limit_info.limit_type == RateLimitType.RPM:
                logger.info(f" [{request_id}] [RPM] RPM限制需要切换Provider")
                return "rpm"
            return "unknown"

        except Exception as e:
            logger.exception(f" [{request_id}] 处理429错误时异常: {e}")
            return "unknown"

    @staticmethod
    def _parse_retry_after(value: Optional[str]) -> Optional[int]:
        """
        Parse a Retry-After header value into whole seconds.

        Accepts delay-seconds (integer or fractional) and HTTP-date values.
        Returns None for a missing or unparseable value; negative results are
        clamped to 0.
        """
        if not value:
            return None
        value = value.strip()
        try:
            return max(0, int(float(value)))
        except (ValueError, OverflowError):
            pass

        from datetime import datetime, timezone
        from email.utils import parsedate_to_datetime

        try:
            retry_at = parsedate_to_datetime(value)
        except (TypeError, ValueError, IndexError):
            return None
        if retry_at.tzinfo is None:
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        return max(0, int((retry_at - datetime.now(timezone.utc)).total_seconds()))

    def convert_http_error(
        self,
        error: httpx.HTTPStatusError,
        provider_name: str,
        error_response_text: Optional[str] = None,
    ) -> Union[ProviderException, UpstreamClientException]:
        """
        Convert an HTTP status error into a Provider-level exception.

        Mapping:
        - 401 -> ProviderAuthException
        - 429 -> ProviderRateLimitException (with headers and Retry-After)
        - 400 recognized as a client error -> UpstreamClientException
        - anything else (including 5xx) -> ProviderNotAvailableException

        Args:
            error: HTTP status error
            provider_name: Provider name
            error_response_text: error response body (optional)

        Returns:
            ProviderException or UpstreamClientException: the converted exception.
        """
        response = error.response
        status = response.status_code if response else None

        extracted_message = self._extract_error_message(error_response_text)
        if extracted_message:
            detailed_message = f"提供商 '{provider_name}' 返回错误 {status}: {extracted_message}"
        else:
            detailed_message = f"提供商 '{provider_name}' 返回错误: {status}"

        if status == 401:
            return ProviderAuthException(provider_name=provider_name)

        if status == 429:
            return ProviderRateLimitException(
                message=error_response_text or f"提供商 '{provider_name}' 速率限制",
                provider_name=provider_name,
                response_headers=dict(response.headers) if response else None,
                # Retry-After may be an HTTP-date or fractional seconds; a plain
                # int() used to raise ValueError on those.
                retry_after=(
                    self._parse_retry_after(response.headers.get("retry-after"))
                    if response
                    else None
                ),
            )

        # A 400 caused by the request itself is returned to the caller, not retried.
        if status == 400 and self._is_client_error(error_response_text):
            logger.info(f"检测到客户端请求错误,不进行重试: {extracted_message}")
            return UpstreamClientException(
                message=extracted_message or "请求无效",
                provider_name=provider_name,
                status_code=400,
                upstream_error=error_response_text,
            )

        # 5xx and every other status are treated as Provider-side failures.
        return ProviderNotAvailableException(
            message=detailed_message,
            provider_name=provider_name,
        )

    # ------------------------------------------------------------ side effects

    @staticmethod
    def _api_format_to_str(api_format: Union[str, APIFormat]) -> str:
        """Normalize api_format to its string value for cache keys."""
        if isinstance(api_format, (str, APIFormat)):
            return normalize_api_format(api_format).value
        return str(api_format)

    async def _invalidate_key_cache(
        self,
        *,
        affinity_key: str,
        api_format_str: str,
        global_model_id: str,
        endpoint: Optional[ProviderEndpoint],
        key: Optional[ProviderAPIKey],
    ) -> None:
        """Invalidate cache affinity for (endpoint, key) when both and a scheduler exist."""
        if not (endpoint and key) or self.cache_scheduler is None:
            return
        await self.cache_scheduler.invalidate_cache(
            affinity_key=affinity_key,
            api_format=api_format_str,
            global_model_id=global_model_id,
            endpoint_id=str(endpoint.id),
            key_id=str(key.id),
        )

    async def handle_http_error(
        self,
        http_error: httpx.HTTPStatusError,
        *,
        provider: Provider,
        endpoint: ProviderEndpoint,
        key: ProviderAPIKey,
        affinity_key: str,
        api_format: Union[str, APIFormat],
        global_model_id: str,
        request_id: Optional[str],
        captured_key_concurrent: Optional[int],
        elapsed_ms: Optional[int],
        attempt: int,
        max_attempts: int,
    ) -> Dict[str, Any]:
        """
        Handle an HTTP status error and return extra data for the caller.

        Client errors are returned untouched (no cache invalidation, no health
        failure). Auth errors, rate limits and other errors invalidate cache
        affinity and record a health failure for the key; rate limits are also
        fed to the adaptive concurrency manager.

        Args:
            http_error: HTTP status error
            provider: Provider object
            endpoint: Endpoint object
            key: API key object
            affinity_key: affinity identifier (usually the API key ID)
            api_format: API format
            global_model_id: GlobalModel ID (normalized model identifier)
            request_id: request ID
            captured_key_concurrent: captured concurrency for the key
            elapsed_ms: elapsed time in milliseconds (currently unused)
            attempt: current attempt number
            max_attempts: maximum number of attempts

        Returns:
            Dict[str, Any] containing:
            - converted_error: the converted exception (used to decide retry)
            - error_response: the error response body, if one was read
        """
        provider_name = str(provider.name)
        response = http_error.response

        # Best effort: e.g. a streamed response that was never read has no text.
        error_response_text = None
        try:
            if response and hasattr(response, "text"):
                error_response_text = response.text[:1000]  # cap stored length
        except Exception as read_error:
            logger.debug(f" [{request_id}] 读取错误响应内容失败: {read_error!r}")

        logger.warning(f" [{request_id}] HTTP错误 (attempt={attempt}/{max_attempts}): "
                       f"{response.status_code if response else 'unknown'}")

        converted_error = self.convert_http_error(http_error, provider_name, error_response_text)

        extra_data: Dict[str, Any] = {
            "converted_error": converted_error,
        }
        if error_response_text:
            extra_data["error_response"] = error_response_text

        api_format_str = self._api_format_to_str(api_format)

        # Client request errors: no retry, no cache invalidation, no health failure.
        if isinstance(converted_error, UpstreamClientException):
            logger.warning(f" [{request_id}] 客户端请求错误,不进行重试: {converted_error.message}")
            return extra_data

        if isinstance(converted_error, ProviderAuthException):
            await self._invalidate_key_cache(
                affinity_key=affinity_key,
                api_format_str=api_format_str,
                global_model_id=global_model_id,
                endpoint=endpoint,
                key=key,
            )
            if key:
                health_monitor.record_failure(
                    db=self.db,
                    key_id=str(key.id),
                    error_type="ProviderAuthException",
                )
            return extra_data

        if isinstance(converted_error, ProviderRateLimitException) and key:
            await self.handle_rate_limit(
                key=key,
                provider_name=provider_name,
                current_concurrent=captured_key_concurrent,
                exception=converted_error,
                request_id=request_id,
            )

        # Rate limits and all other errors invalidate cache affinity.
        await self._invalidate_key_cache(
            affinity_key=affinity_key,
            api_format_str=api_format_str,
            global_model_id=global_model_id,
            endpoint=endpoint,
            key=key,
        )

        if key:
            health_monitor.record_failure(
                db=self.db,
                key_id=str(key.id),
                error_type=type(converted_error).__name__,
            )

        return extra_data

    async def handle_retriable_error(
        self,
        error: Exception,
        *,
        provider: Provider,
        endpoint: ProviderEndpoint,
        key: ProviderAPIKey,
        affinity_key: str,
        api_format: Union[str, APIFormat],
        global_model_id: str,
        captured_key_concurrent: Optional[int],
        elapsed_ms: Optional[int],
        request_id: Optional[str],
        attempt: int,
        max_attempts: int,
    ) -> None:
        """
        Handle a retriable (non-HTTP-status) error.

        Rate-limit errors are fed to the adaptive concurrency manager; every
        error invalidates cache affinity and records a health failure for the key.

        Args:
            error: the exception
            provider: Provider object
            endpoint: Endpoint object
            key: API key object
            affinity_key: affinity identifier (usually the API key ID)
            api_format: API format
            global_model_id: GlobalModel ID (normalized model id, for cache affinity)
            captured_key_concurrent: captured concurrency for the key
            elapsed_ms: elapsed time in milliseconds (currently unused)
            request_id: request ID
            attempt: current attempt number
            max_attempts: maximum number of attempts
        """
        provider_name = str(provider.name)

        logger.warning(f" [{request_id}] 请求失败 (attempt={attempt}/{max_attempts}): "
                       f"{type(error).__name__}: {str(error)}")

        api_format_str = self._api_format_to_str(api_format)

        if isinstance(error, ProviderRateLimitException) and key:
            await self.handle_rate_limit(
                key=key,
                provider_name=provider_name,
                current_concurrent=captured_key_concurrent,
                exception=error,
                request_id=request_id,
            )

        await self._invalidate_key_cache(
            affinity_key=affinity_key,
            api_format_str=api_format_str,
            global_model_id=global_model_id,
            endpoint=endpoint,
            key=key,
        )

        if key:
            health_monitor.record_failure(
                db=self.db,
                key_id=str(key.id),
                error_type=type(error).__name__,
            )