Files
Aether/src/api/handlers/base/parsers.py
fawney19 f3a69a6160 refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
2025-12-16 00:02:49 +08:00

509 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
响应解析器工厂
直接根据格式 ID 创建对应的 ResponseParser 实现,
不再经过 Protocol 抽象层。
"""
from typing import Any, Dict, Optional, Tuple, Type
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
from src.api.handlers.base.utils import extract_cache_creation_tokens
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""
检查响应中是否存在嵌套错误(某些代理服务返回 HTTP 200 但在响应体中包含错误)
检测格式:
1. 顶层 error: {"error": {...}}
2. 顶层 type=error: {"type": "error", ...}
3. chunks 内嵌套 error: {"chunks": [{"error": {...}}]}
Args:
response: 响应字典
Returns:
(is_error, error_dict): 是否为错误,以及提取的错误信息
"""
# 顶层 error
if "error" in response:
error = response["error"]
if isinstance(error, dict):
return True, error
return True, {"message": str(error)}
# 顶层 type=error
if response.get("type") == "error":
return True, response
# chunks 内嵌套 error (某些代理返回这种格式)
chunks = response.get("chunks", [])
if chunks and isinstance(chunks, list):
for chunk in chunks:
if isinstance(chunk, dict):
if "error" in chunk:
error = chunk["error"]
if isinstance(error, dict):
return True, error
return True, {"message": str(error)}
if chunk.get("type") == "error":
return True, chunk
return False, None
class OpenAIResponseParser(ResponseParser):
    """Response parser for the OpenAI chat-completions format."""

    def __init__(self) -> None:
        from src.api.handlers.openai.stream_parser import OpenAIStreamParser
        self._parser = OpenAIStreamParser()
        self.name = "OPENAI"
        self.api_format = "OPENAI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """Parse one SSE line; returns None for blank or unparseable input."""
        if not line or not line.strip():
            return None
        # Strip the "data: " SSE prefix when present.
        payload = line[6:] if line.startswith("data: ") else line
        parsed = self._parser.parse_line(payload)
        if parsed is None:
            return None
        chunk = ParsedChunk(raw_line=line, event_type=None, data=parsed)
        # Accumulate any text delta into the running stream stats.
        delta = self._parser.extract_text_delta(parsed)
        if delta:
            chunk.text_delta = delta
            stats.collected_text += delta
        # A terminal chunk marks the stream as completed.
        if self._parser.is_done_chunk(parsed):
            chunk.is_done = True
            stats.has_completion = True
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a complete (non-streaming) chat-completions response."""
        result = ParsedResponse(raw_response=response, status_code=status_code)
        # Assistant text lives at choices[0].message.content.
        choices = response.get("choices", [])
        if choices:
            content = choices[0].get("message", {}).get("content")
            if content:
                result.text_content = content
        result.response_id = response.get("id")
        # Token usage.
        usage = response.get("usage", {})
        result.input_tokens = usage.get("prompt_tokens", 0)
        result.output_tokens = usage.get("completion_tokens", 0)
        # Errors may be nested in the body even under HTTP 200.
        found, info = _check_nested_error(response)
        if found and info:
            result.is_error = True
            result.error_type = info.get("type")
            result.error_message = info.get("message")
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """Map OpenAI usage fields onto the unified token-usage dict."""
        usage = response.get("usage", {})
        return {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
            # The OpenAI format exposes no cache token counters here.
            "cache_creation_tokens": 0,
            "cache_read_tokens": 0,
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Return the assistant text from choices[0], or "" when absent."""
        choices = response.get("choices", [])
        if not choices:
            return ""
        content = choices[0].get("message", {}).get("content")
        return content if isinstance(content, str) else ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """True when the body carries a (possibly nested) error payload."""
        found, _ = _check_nested_error(response)
        return found
class OpenAICliResponseParser(OpenAIResponseParser):
    """OpenAI CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self) -> None:
        super().__init__()
        # Only the format labels differ from the base parser.
        self.name = "OPENAI_CLI"
        self.api_format = "OPENAI_CLI"
class ClaudeResponseParser(ResponseParser):
    """Response parser for the Claude (Anthropic messages) format."""

    def __init__(self) -> None:
        from src.api.handlers.claude.stream_parser import ClaudeStreamParser
        self._parser = ClaudeStreamParser()
        self.name = "CLAUDE"
        self.api_format = "CLAUDE"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """Parse one Claude SSE line, updating the stream stats in place."""
        if not line or not line.strip():
            return None
        # Strip the "data: " SSE prefix when present.
        payload = line[6:] if line.startswith("data: ") else line
        parsed = self._parser.parse_line(payload)
        if parsed is None:
            return None
        chunk = ParsedChunk(
            raw_line=line,
            event_type=self._parser.get_event_type(parsed),
            data=parsed,
        )
        # Accumulate any text delta.
        delta = self._parser.extract_text_delta(parsed)
        if delta:
            chunk.text_delta = delta
            stats.collected_text += delta
        # Terminal event marks completion.
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True
        # Mirror token usage onto both the chunk and the running stats.
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
            chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_creation_tokens = chunk.cache_creation_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens
        # Error events carry their payload under "error".
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            err = parsed.get("error", {})
            if isinstance(err, dict):
                chunk.error_message = err.get("message", str(err))
            else:
                chunk.error_message = str(err)
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a complete Claude messages-API response."""
        result = ParsedResponse(raw_response=response, status_code=status_code)
        # Concatenate every text-type content block.
        content = response.get("content", [])
        if isinstance(content, list):
            result.text_content = "".join(
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        result.response_id = response.get("id")
        # Token usage.
        usage = response.get("usage", {})
        result.input_tokens = usage.get("input_tokens", 0)
        result.output_tokens = usage.get("output_tokens", 0)
        # Handles both new and legacy cache-creation token layouts.
        result.cache_creation_tokens = extract_cache_creation_tokens(usage)
        result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
        # Errors may be nested in the body even under HTTP 200.
        found, info = _check_nested_error(response)
        if found and info:
            result.is_error = True
            result.error_type = info.get("type")
            result.error_message = info.get("message")
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """Extract token usage from a Claude response or event.

        For message_start events, usage sits under ``message.usage``;
        elsewhere it is at the top level.
        """
        usage = response.get("usage", {})
        if not usage and "message" in response:
            usage = response.get("message", {}).get("usage", {})
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "cache_creation_tokens": extract_cache_creation_tokens(usage),
            "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Concatenate all text-type content blocks, or return ""."""
        content = response.get("content", [])
        if not isinstance(content, list):
            return ""
        return "".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """True when the body carries a (possibly nested) error payload."""
        found, _ = _check_nested_error(response)
        return found
class ClaudeCliResponseParser(ClaudeResponseParser):
    """Claude CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self) -> None:
        super().__init__()
        # Only the format labels differ from the base parser.
        self.name = "CLAUDE_CLI"
        self.api_format = "CLAUDE_CLI"
class GeminiResponseParser(ResponseParser):
    """Response parser for the Gemini generateContent format."""

    def __init__(self) -> None:
        from src.api.handlers.gemini.stream_parser import GeminiStreamParser
        self._parser = GeminiStreamParser()
        self.name = "GEMINI"
        self.api_format = "GEMINI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """Parse one Gemini SSE line, updating the stream stats in place.

        Gemini streaming uses the standard SSE framing (``data: {...}``).
        """
        if not line or not line.strip():
            return None
        # Strip the "data: " SSE prefix when present.
        payload = line[6:] if line.startswith("data: ") else line
        parsed = self._parser.parse_line(payload)
        if parsed is None:
            return None
        chunk = ParsedChunk(raw_line=line, event_type="content", data=parsed)
        # Accumulate any text delta.
        delta = self._parser.extract_text_delta(parsed)
        if delta:
            chunk.text_delta = delta
            stats.collected_text += delta
        # Terminal event marks completion.
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True
        # Mirror token usage onto both the chunk and the running stats.
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_read_tokens = usage.get("cached_tokens", 0)
            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens
        # Error events carry their payload under "error".
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            err = parsed.get("error", {})
            if isinstance(err, dict):
                chunk.error_message = err.get("message", str(err))
            else:
                chunk.error_message = str(err)
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a complete Gemini response."""
        result = ParsedResponse(raw_response=response, status_code=status_code)
        # Text lives at candidates[0].content.parts[*].text.
        candidates = response.get("candidates", [])
        if candidates:
            parts = candidates[0].get("content", {}).get("parts", [])
            result.text_content = "".join(p["text"] for p in parts if "text" in p)
        result.response_id = response.get("modelVersion")
        # Delegate usage extraction to GeminiStreamParser (single source of truth).
        usage = self._parser.extract_usage(response)
        if usage:
            result.input_tokens = usage.get("input_tokens", 0)
            result.output_tokens = usage.get("output_tokens", 0)
            result.cache_read_tokens = usage.get("cached_tokens", 0)
        # Enhanced error detection (covers errors nested in chunks).
        error_info = self._parser.extract_error_info(response)
        if error_info:
            result.is_error = True
            result.error_type = error_info.get("status")
            result.error_message = error_info.get("message")
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """Extract token usage from a Gemini response.

        Delegates to GeminiStreamParser.extract_usage as the single
        implementation source; missing usage yields all-zero counters.
        """
        usage = self._parser.extract_usage(response) or {}
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            # Gemini reports no cache-creation tokens.
            "cache_creation_tokens": 0,
            "cache_read_tokens": usage.get("cached_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Concatenate the text parts of the first candidate, or return ""."""
        candidates = response.get("candidates", [])
        if not candidates:
            return ""
        parts = candidates[0].get("content", {}).get("parts", [])
        return "".join(p["text"] for p in parts if "text" in p)

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """True for error responses, including errors nested inside chunks."""
        return bool(self._parser.is_error_event(response))
class GeminiCliResponseParser(GeminiResponseParser):
    """Gemini CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self) -> None:
        super().__init__()
        # Only the format labels differ from the base parser.
        self.name = "GEMINI_CLI"
        self.api_format = "GEMINI_CLI"
# Registry mapping format IDs to their ResponseParser implementations;
# consumed by get_parser_for_format() below.
_PARSERS: Dict[str, Type[ResponseParser]] = {
    "CLAUDE": ClaudeResponseParser,
    "CLAUDE_CLI": ClaudeCliResponseParser,
    "OPENAI": OpenAIResponseParser,
    "OPENAI_CLI": OpenAICliResponseParser,
    "GEMINI": GeminiResponseParser,
    "GEMINI_CLI": GeminiCliResponseParser,
}
def get_parser_for_format(format_id: str) -> ResponseParser:
    """Instantiate the ResponseParser registered for *format_id*.

    Args:
        format_id: format identifier such as "CLAUDE", "OPENAI",
            "CLAUDE_CLI", "OPENAI_CLI" (case-insensitive).

    Returns:
        A fresh ResponseParser instance.

    Raises:
        KeyError: when the format is not registered.
    """
    key = format_id.upper()
    if key not in _PARSERS:
        raise KeyError(f"Unknown format: {key}")
    return _PARSERS[key]()
def is_cli_format(format_id: str) -> bool:
    """Return True when *format_id* names a CLI variant (case-insensitive "_CLI" suffix)."""
    normalized = format_id.upper()
    return normalized.endswith("_CLI")
# Public API of this module.
__all__ = [
    "OpenAIResponseParser",
    "OpenAICliResponseParser",
    "ClaudeResponseParser",
    "ClaudeCliResponseParser",
    "GeminiResponseParser",
    "GeminiCliResponseParser",
    "get_parser_for_format",
    "is_cli_format",
]