refactor: unify nested error detection logic in response parsing

- Extract a _check_nested_error helper to handle multiple error formats
- Detect top-level error, top-level type=error, and errors nested inside chunks
- Simplify the error-handling code in OpenAIResponseParser and ClaudeResponseParser
- Improve code reuse and maintainability
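For illustration, a minimal sketch (not part of the commit) of what the new helper returns for each format it detects. It assumes _check_nested_error from the diff below is in scope (the module path of the changed file is not shown here), and the sample payloads are made up:

# 1. A top-level error object is returned as-is
_check_nested_error({"error": {"type": "rate_limit", "message": "Too many requests"}})
# -> (True, {"type": "rate_limit", "message": "Too many requests"})

# A non-dict error value is normalized into a message dict
_check_nested_error({"error": "quota exceeded"})
# -> (True, {"message": "quota exceeded"})

# 2. Top-level type=error: the whole response is treated as the error payload
_check_nested_error({"type": "error", "message": "Overloaded"})
# -> (True, {"type": "error", "message": "Overloaded"})

# 3. Error nested inside chunks, as some proxies return with HTTP 200
_check_nested_error({"chunks": [{"error": {"type": "upstream_error", "message": "bad gateway"}}]})
# -> (True, {"type": "upstream_error", "message": "bad gateway"})

# Normal responses are not flagged
_check_nested_error({"choices": [{"message": {"content": "hi"}}]})
# -> (False, None)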
@@ -5,7 +5,7 @@
 No longer goes through the Protocol abstraction layer.
 """

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple

 from src.api.handlers.base.response_parser import (
     ParsedChunk,
@@ -15,6 +15,48 @@ from src.api.handlers.base.response_parser import (
 )


+def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
+    """
+    Check whether the response contains a nested error (some proxy services return HTTP 200 but embed an error in the response body).
+
+    Detected formats:
+    1. Top-level error: {"error": {...}}
+    2. Top-level type=error: {"type": "error", ...}
+    3. Error nested inside chunks: {"chunks": [{"error": {...}}]}
+
+    Args:
+        response: The response dictionary.
+
+    Returns:
+        (is_error, error_dict): whether this is an error, and the extracted error info.
+    """
+    # Top-level error
+    if "error" in response:
+        error = response["error"]
+        if isinstance(error, dict):
+            return True, error
+        return True, {"message": str(error)}
+
+    # Top-level type=error
+    if response.get("type") == "error":
+        return True, response
+
+    # Error nested inside chunks (some proxies return this format)
+    chunks = response.get("chunks", [])
+    if chunks and isinstance(chunks, list):
+        for chunk in chunks:
+            if isinstance(chunk, dict):
+                if "error" in chunk:
+                    error = chunk["error"]
+                    if isinstance(error, dict):
+                        return True, error
+                    return True, {"message": str(error)}
+                if chunk.get("type") == "error":
+                    return True, chunk
+
+    return False, None
+
+
 class OpenAIResponseParser(ResponseParser):
     """OpenAI format response parser."""

@@ -81,15 +123,12 @@ class OpenAIResponseParser(ResponseParser):
             result.input_tokens = usage.get("prompt_tokens", 0)
             result.output_tokens = usage.get("completion_tokens", 0)

-        # Check for errors
-        if "error" in response:
+        # Check for errors (supports nested error formats)
+        is_error, error_info = _check_nested_error(response)
+        if is_error and error_info:
             result.is_error = True
-            error = response.get("error", {})
-            if isinstance(error, dict):
-                result.error_type = error.get("type")
-                result.error_message = error.get("message")
-            else:
-                result.error_message = str(error)
+            result.error_type = error_info.get("type")
+            result.error_message = error_info.get("message")

         return result

@@ -112,7 +151,8 @@ class OpenAIResponseParser(ResponseParser):
         return ""

     def is_error_response(self, response: Dict[str, Any]) -> bool:
-        return "error" in response
+        is_error, _ = _check_nested_error(response)
+        return is_error


 class OpenAICliResponseParser(OpenAIResponseParser):
@@ -215,15 +255,12 @@ class ClaudeResponseParser(ResponseParser):
             result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
             result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)

-        # Check for errors
-        if "error" in response or response.get("type") == "error":
+        # Check for errors (supports nested error formats)
+        is_error, error_info = _check_nested_error(response)
+        if is_error and error_info:
             result.is_error = True
-            error = response.get("error", {})
-            if isinstance(error, dict):
-                result.error_type = error.get("type")
-                result.error_message = error.get("message")
-            else:
-                result.error_message = str(error)
+            result.error_type = error_info.get("type")
+            result.error_message = error_info.get("message")

         return result

@@ -247,7 +284,8 @@ class ClaudeResponseParser(ResponseParser):
         return ""

     def is_error_response(self, response: Dict[str, Any]) -> bool:
-        return "error" in response or response.get("type") == "error"
+        is_error, _ = _check_nested_error(response)
+        return is_error


 class ClaudeCliResponseParser(ClaudeResponseParser):
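For context, a rough sketch of the shared code path both parsers now follow. ParseResult and apply_error_check below are hypothetical stand-ins for illustration only; the diff shows just the is_error, error_type, and error_message fields of the real result object, and _check_nested_error is assumed to be importable from the changed module:

from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ParseResult:
    # Hypothetical stand-in; the real result class is not shown in this diff.
    is_error: bool = False
    error_type: Optional[str] = None
    error_message: Optional[str] = None


def apply_error_check(result: ParseResult, response: Dict[str, Any]) -> ParseResult:
    # Mirrors the block that replaced each parser's ad-hoc error handling.
    is_error, error_info = _check_nested_error(response)
    if is_error and error_info:
        result.is_error = True
        result.error_type = error_info.get("type")
        result.error_message = error_info.get("message")
    return result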