Files
Aether/src/api/handlers/gemini/stream_parser.py

316 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Gemini SSE/JSON 流解析器
Gemini API 的流式响应格式与 Claude/OpenAI 不同:
- 使用 JSON 数组格式 (不是 SSE)
- 每个块是一个完整的 JSON 对象
- 响应以 [ 开始,以 ] 结束,块之间用 , 分隔
参考: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
"""
import codecs
import json
from typing import Any, Dict, List, Optional, Union
class GeminiStreamParser:
    """Incremental parser for Gemini ``streamGenerateContent`` responses.

    The parser consumes the JSON-array wire format
    ``[{chunk1}, {chunk2}, ...]`` piece by piece and returns every top-level
    JSON object as soon as it is complete.  Helper methods then pull
    individual fields (``candidates``, ``usageMetadata``, ``error`` ...) out
    of a decoded chunk.

    Finish reasons recognised by :meth:`is_done_event`:
    ``STOP``, ``MAX_TOKENS``, ``SAFETY``, ``RECITATION`` and ``OTHER``.
    """

    # Finish reasons treated as "generation is over".
    FINISH_REASON_STOP = "STOP"
    FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
    FINISH_REASON_SAFETY = "SAFETY"
    FINISH_REASON_RECITATION = "RECITATION"
    FINISH_REASON_OTHER = "OTHER"

    # Sentinel distinguishing "no error key" from an error whose value is None.
    _MISSING = object()

    def __init__(self) -> None:
        """Create a parser in its initial (empty) state."""
        self.reset()

    def reset(self) -> None:
        """Discard all buffered data and return to the initial state."""
        # Text of the object currently being assembled (across chunks).
        self._buffer = ""
        # True between the opening "[" and the closing "]" of the array.
        self._in_array = False
        # Nesting depth of "{" ... "}" outside of string literals.
        self._brace_depth = 0
        # String-literal tracking, so braces inside text are not counted.
        self._in_string = False
        self._escape_next = False
        # Incremental decoder keeps the tail of a multi-byte UTF-8 sequence
        # that was cut by a chunk boundary until the rest arrives.
        self._decoder: codecs.IncrementalDecoder = codecs.getincrementaldecoder(
            "utf-8"
        )()

    def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
        """Feed one piece of the stream and return the objects it completes.

        Args:
            chunk: Raw data, either ``bytes`` (decoded as UTF-8) or ``str``.

        Returns:
            The JSON objects that became complete with this chunk, in stream
            order.  Incomplete trailing data is buffered for the next call.

        Raises:
            UnicodeDecodeError: If ``chunk`` is ``bytes`` containing invalid
                UTF-8.  A multi-byte character that is merely split between
                two chunks is handled and does not raise.
        """
        if isinstance(chunk, bytes):
            text = self._decoder.decode(chunk)
        else:
            text = chunk

        events: List[Dict[str, Any]] = []
        # Position in ``text`` where the currently open object (if any)
        # continues; 0 when the object started in an earlier chunk.
        start = 0

        for index, char in enumerate(text):
            if self._brace_depth == 0:
                # Between objects: only structural characters matter.  The
                # separators (",", whitespace) are skipped instead of being
                # buffered, so they can never corrupt the next object.
                if char == "{":
                    if self._in_array:
                        self._brace_depth = 1
                        start = index
                elif char == "[":
                    if not self._in_array:
                        self._in_array = True
                elif char == "]":
                    if self._in_array:
                        self._in_array = False
                continue

            if self._in_string:
                # Inside a string literal braces are plain text; only track
                # escapes and the closing quote.
                if self._escape_next:
                    self._escape_next = False
                elif char == "\\":
                    self._escape_next = True
                elif char == '"':
                    self._in_string = False
            elif char == '"':
                self._in_string = True
            elif char == "{":
                self._brace_depth += 1
            elif char == "}":
                self._brace_depth -= 1
                if self._brace_depth == 0:
                    raw = self._buffer + text[start : index + 1]
                    self._buffer = ""
                    event = self._decode_object(raw)
                    if event is not None:
                        events.append(event)

        if self._brace_depth > 0:
            # Object still open: keep what we have for the next chunk.
            self._buffer += text[start:]
        return events

    @staticmethod
    def _decode_object(raw: str) -> Optional[Dict[str, Any]]:
        """Decode one brace-balanced span; return None if it is not valid JSON."""
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            # Best effort: a malformed object is dropped so that it cannot
            # poison the objects that follow it.
            return None
        return parsed if isinstance(parsed, dict) else None

    def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
        """Parse a single line that holds one JSON object.

        Args:
            line: One line of JSON data, optionally followed by a ",".

        Returns:
            The decoded dict, or None for empty lines, bare array
            punctuation (``[``, ``]``, ``,``), invalid JSON, or JSON that is
            not an object.
        """
        if not line:
            return None
        payload = line.strip()
        if payload in ("[", "]", ","):
            return None
        try:
            result = json.loads(payload.rstrip(","))
        except json.JSONDecodeError:
            return None
        return result if isinstance(result, dict) else None

    @staticmethod
    def _first_candidate(event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the first entry of ``event["candidates"]`` if it is a dict."""
        candidates = event.get("candidates")
        if isinstance(candidates, list) and candidates:
            first = candidates[0]
            if isinstance(first, dict):
                return first
        return None

    def is_done_event(self, event: Dict[str, Any]) -> bool:
        """Tell whether any candidate in the event carries a finish reason.

        Only the five reasons declared as class constants count; any other
        ``finishReason`` value is not treated as a done event.

        Args:
            event: Decoded chunk.

        Returns:
            True if at least one candidate has a recognised ``finishReason``.
        """
        done_reasons = (
            self.FINISH_REASON_STOP,
            self.FINISH_REASON_MAX_TOKENS,
            self.FINISH_REASON_SAFETY,
            self.FINISH_REASON_RECITATION,
            self.FINISH_REASON_OTHER,
        )
        candidates = event.get("candidates") or []
        return any(
            isinstance(candidate, dict)
            and candidate.get("finishReason") in done_reasons
            for candidate in candidates
        )

    def _find_error(self, event: Dict[str, Any]) -> Any:
        """Return the raw error payload of the event, or ``_MISSING``."""
        if "error" in event:
            return event["error"]
        chunks = event.get("chunks")
        if isinstance(chunks, list):
            for chunk in chunks:
                if isinstance(chunk, dict) and "error" in chunk:
                    return chunk["error"]
        return self._MISSING

    def is_error_event(self, event: Dict[str, Any]) -> bool:
        """Tell whether the event is an error.

        Two shapes are detected:
        1. top-level error: ``{"error": {...}}``
        2. error nested in chunks: ``{"chunks": [{"error": {...}}]}``

        Candidates are not inspected by this method.

        Args:
            event: Decoded chunk.

        Returns:
            True if an ``error`` key is present in either location.
        """
        return self._find_error(event) is not self._MISSING

    def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Extract normalised error details from the event.

        Args:
            event: Decoded chunk.

        Returns:
            ``{"code": ..., "message": ..., "status": ...}`` for the first
            error found (top-level first, then inside ``chunks``), or None
            when the event carries no error.  For a non-dict error payload
            ``code`` and ``status`` are None and ``message`` is its ``str()``.
        """
        error = self._find_error(event)
        if error is self._MISSING:
            return None
        if isinstance(error, dict):
            return {
                "code": error.get("code"),
                "message": error.get("message", str(error)),
                "status": error.get("status"),
            }
        return {"code": None, "message": str(error), "status": None}

    def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
        """Return the ``finishReason`` of the first candidate.

        Args:
            event: Decoded chunk.

        Returns:
            The finish reason as a string, or None if there is no candidate
            or the candidate has no ``finishReason``.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None
        reason = candidate.get("finishReason")
        return str(reason) if reason is not None else None

    def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
        """Concatenate the text parts of the first candidate.

        Args:
            event: Decoded chunk.

        Returns:
            The joined ``text`` of all parts in ``content.parts`` of the
            first candidate, or None when no part carries text.  Parts
            without a string ``text`` field are skipped.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None
        # "content" / "parts" may be absent or null; both mean "no parts".
        content = candidate.get("content") or {}
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            return None
        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ]
        return "".join(texts) if texts else None

    def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
        """Extract token usage from the event's ``usageMetadata``.

        Intended as the single implementation of Gemini token extraction;
        other code should call this method.

        Args:
            event: Decoded chunk (containing ``usageMetadata``).

        Returns:
            A dict with ``input_tokens``, ``output_tokens``,
            ``total_tokens`` and ``cached_tokens``, or None unless
            ``totalTokenCount`` is present (which marks the usage data as
            complete).  Output tokens are
            ``thoughtsTokenCount + candidatesTokenCount``; missing or null
            counts are treated as 0.
        """
        usage = event.get("usageMetadata")
        if not isinstance(usage, dict) or "totalTokenCount" not in usage:
            return None

        def count(key: str) -> int:
            # ``or 0`` also covers an explicit null in the payload.
            return usage.get(key) or 0

        return {
            "input_tokens": count("promptTokenCount"),
            "output_tokens": count("thoughtsTokenCount")
            + count("candidatesTokenCount"),
            "total_tokens": count("totalTokenCount"),
            "cached_tokens": count("cachedContentTokenCount"),
        }

    def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
        """Return ``modelVersion`` from the event as a string, or None.

        Args:
            event: Decoded chunk.
        """
        version = event.get("modelVersion")
        return str(version) if version is not None else None

    def extract_safety_ratings(
        self, event: Dict[str, Any]
    ) -> Optional[List[Dict[str, Any]]]:
        """Return the ``safetyRatings`` list of the first candidate.

        Args:
            event: Decoded chunk.

        Returns:
            The list of safety ratings, or None if there is no candidate or
            ``safetyRatings`` is absent or not a list.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None
        ratings = candidate.get("safetyRatings")
        return ratings if isinstance(ratings, list) else None
# Public API of this module: only the parser class is exported.
__all__ = ["GeminiStreamParser"]