mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-05 09:12:27 +08:00
316 lines
9.6 KiB
Python
316 lines
9.6 KiB
Python
"""
Gemini SSE/JSON stream parser.

Gemini's streaming response format differs from Claude/OpenAI:

- It uses a JSON array format (not SSE).
- Each element is a complete JSON object.
- The response starts with ``[`` and ends with ``]``; elements are separated by ``,``.

Reference: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
"""
|
||
|
||
import codecs
import json
from typing import Any, Dict, List, Optional, Union
|
||
|
||
|
||
class GeminiStreamParser:
    """
    Incremental parser for Gemini ``streamGenerateContent`` responses.

    The stream is a JSON array, ``[{chunk1}, {chunk2}, ...]``, whose elements
    may be split arbitrarily across network reads. :meth:`parse_chunk`
    reassembles complete objects from those reads; :meth:`parse_line` parses
    input that carries one object per line. The remaining helpers read fields
    such as ``candidates`` and ``usageMetadata`` out of a parsed chunk.

    Known ``finishReason`` values include STOP, MAX_TOKENS, SAFETY, RECITATION
    and OTHER (see :meth:`is_done_event` for how other values are treated).
    """

    # Finish reasons (values of candidates[].finishReason)
    FINISH_REASON_STOP = "STOP"
    FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
    FINISH_REASON_SAFETY = "SAFETY"
    FINISH_REASON_RECITATION = "RECITATION"
    FINISH_REASON_OTHER = "OTHER"
    # Default/placeholder value in the Gemini FinishReason enum; not terminal.
    FINISH_REASON_UNSPECIFIED = "FINISH_REASON_UNSPECIFIED"

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Discard any partially parsed data and return to the initial state."""
        # Characters of the array element currently being assembled.
        self._buffer: List[str] = []
        # True between the top-level '[' and its matching ']'.
        self._in_array = False
        # '{' nesting depth, counted only outside JSON string literals.
        self._brace_depth = 0
        # True while inside a JSON string literal.
        self._in_string = False
        # True when the previous character inside a string was a backslash.
        self._escape = False
        # Incremental decoder so a multi-byte UTF-8 character split across
        # two bytes chunks is decoded correctly instead of raising.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
        """
        Feed one piece of the raw response stream and return completed chunks.

        State is kept between calls, so an array element (or, for ``bytes``
        input, a UTF-8 character) may be split across any number of calls.
        Braces appearing inside JSON string values do not affect framing.
        Input outside the top-level array (before ``[`` or after ``]``) is
        ignored.

        Args:
            chunk: Raw stream data, as ``bytes`` (decoded as UTF-8) or ``str``.

        Returns:
            The JSON objects completed by this piece, in stream order; empty if
            none were completed. Elements that fail to decode are dropped.

        Raises:
            UnicodeDecodeError: If ``bytes`` input is not valid UTF-8.
        """
        text = self._decoder.decode(chunk) if isinstance(chunk, bytes) else chunk

        events: List[Dict[str, Any]] = []
        for char in text:
            if not self._in_array:
                # Everything before the opening '[' is ignored.
                if char == "[":
                    self._in_array = True
                continue

            if self._brace_depth == 0:
                # Between array elements: skip whitespace and ',' separators so
                # they never end up in front of the next object. Only '{'
                # starts an element; ']' closes the array.
                if char == "{":
                    self._brace_depth = 1
                    self._buffer.append(char)
                elif char == "]":
                    self._in_array = False
                continue

            self._buffer.append(char)

            if self._in_string:
                # Inside a string only the closing quote matters; an escaped
                # quote (\") does not end the string.
                if self._escape:
                    self._escape = False
                elif char == "\\":
                    self._escape = True
                elif char == '"':
                    self._in_string = False
            elif char == '"':
                self._in_string = True
            elif char == "{":
                self._brace_depth += 1
            elif char == "}":
                self._brace_depth -= 1
                if self._brace_depth == 0:
                    obj = self._decode_buffer()
                    if obj is not None:
                        events.append(obj)

        return events

    def _decode_buffer(self) -> Optional[Dict[str, Any]]:
        """Decode and clear the buffered element; return None if it is not a JSON object."""
        raw = "".join(self._buffer)
        self._buffer = []
        try:
            obj = json.loads(raw)
        except json.JSONDecodeError:
            # The braces are balanced, so more input cannot repair this
            # element; drop it rather than letting it corrupt later elements.
            return None
        return obj if isinstance(obj, dict) else None

    def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
        """
        Parse one line that holds a single JSON object.

        A leading SSE ``data:`` prefix and a trailing ``,`` are tolerated.
        Lines consisting only of array punctuation (``[``, ``]``, ``,``) are
        skipped.

        Args:
            line: One line of stream text.

        Returns:
            The decoded object, or None if the line is empty, punctuation only,
            not valid JSON, or decodes to something other than an object.
        """
        if not line:
            return None

        text = line.strip()
        if text.startswith("data:"):
            text = text[len("data:"):].strip()
        text = text.rstrip(",")
        if not text or text in ("[", "]"):
            return None

        try:
            result = json.loads(text)
        except json.JSONDecodeError:
            return None
        return result if isinstance(result, dict) else None

    def is_done_event(self, event: Dict[str, Any]) -> bool:
        """
        Return True if any candidate in ``event`` has finished.

        Any non-empty ``finishReason`` other than ``FINISH_REASON_UNSPECIFIED``
        counts as terminal. Besides the constants on this class, the Gemini
        ``FinishReason`` enum defines further terminal values (e.g. BLOCKLIST,
        PROHIBITED_CONTENT, SPII, MALFORMED_FUNCTION_CALL).

        Args:
            event: A parsed stream chunk.

        Returns:
            True if the chunk marks the end of generation.
        """
        candidates = event.get("candidates")
        if not isinstance(candidates, list):
            return False
        for candidate in candidates:
            if not isinstance(candidate, dict):
                continue
            reason = candidate.get("finishReason")
            if reason and reason != self.FINISH_REASON_UNSPECIFIED:
                return True
        return False

    def is_error_event(self, event: Dict[str, Any]) -> bool:
        """
        Return True if ``event`` carries an error.

        Recognised shapes:
        1. Top-level error: ``{"error": {...}}``
        2. Error nested in chunks: ``{"chunks": [{"error": {...}}]}``

        Args:
            event: A parsed stream chunk.

        Returns:
            True if :meth:`extract_error_info` would find an error.
        """
        return self.extract_error_info(event) is not None

    @staticmethod
    def _normalize_error(error: Any) -> Dict[str, Any]:
        """Convert a raw ``error`` value to ``{"code", "message", "status"}``."""
        if isinstance(error, dict):
            return {
                "code": error.get("code"),
                "message": error.get("message", str(error)),
                "status": error.get("status"),
            }
        return {"code": None, "message": str(error), "status": None}

    def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Extract error details from ``event``.

        A top-level ``error`` takes precedence over errors nested in
        ``chunks``; for nested errors the first one found is returned.

        Args:
            event: A parsed stream chunk.

        Returns:
            ``{"code": ..., "message": str, "status": ...}``, or None if no
            error is present.
        """
        if "error" in event:
            return self._normalize_error(event["error"])

        chunks = event.get("chunks", [])
        if isinstance(chunks, list):
            for item in chunks:
                if isinstance(item, dict) and "error" in item:
                    return self._normalize_error(item["error"])

        return None

    @staticmethod
    def _first_candidate(event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return ``event["candidates"][0]`` if it exists and is a dict, else None."""
        candidates = event.get("candidates")
        if isinstance(candidates, list) and candidates and isinstance(candidates[0], dict):
            return candidates[0]
        return None

    def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
        """
        Return the first candidate's ``finishReason``.

        Args:
            event: A parsed stream chunk.

        Returns:
            The finish reason as a string, or None if absent.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None
        reason = candidate.get("finishReason")
        return str(reason) if reason is not None else None

    def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
        """
        Concatenate the ``text`` of every part in the first candidate.

        Args:
            event: A parsed stream chunk.

        Returns:
            The concatenated text, or None if the first candidate has no
            string ``text`` parts.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None

        content = candidate.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            return None

        texts = [
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ]
        return "".join(texts) if texts else None

    def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
        """
        Extract token usage from ``event["usageMetadata"]``.

        This is the single source of truth for Gemini token extraction; other
        code should call this method.

        Usage is only returned when ``totalTokenCount`` is present (i.e. the
        usage block is complete). Output tokens are
        ``thoughtsTokenCount + candidatesTokenCount``; missing or null counts
        are treated as 0.

        Args:
            event: A parsed stream chunk.

        Returns:
            ``{"input_tokens", "output_tokens", "total_tokens",
            "cached_tokens"}``, or None if usage is absent or incomplete.
        """
        usage_metadata = event.get("usageMetadata")
        if not isinstance(usage_metadata, dict) or "totalTokenCount" not in usage_metadata:
            return None

        thoughts_tokens = usage_metadata.get("thoughtsTokenCount") or 0
        candidates_tokens = usage_metadata.get("candidatesTokenCount") or 0

        return {
            "input_tokens": usage_metadata.get("promptTokenCount") or 0,
            "output_tokens": thoughts_tokens + candidates_tokens,
            "total_tokens": usage_metadata.get("totalTokenCount") or 0,
            "cached_tokens": usage_metadata.get("cachedContentTokenCount") or 0,
        }

    def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
        """
        Return ``event["modelVersion"]`` as a string.

        Args:
            event: A parsed stream chunk.

        Returns:
            The model version, or None if absent.
        """
        version = event.get("modelVersion")
        return str(version) if version is not None else None

    def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        """
        Return the first candidate's ``safetyRatings`` list.

        Args:
            event: A parsed stream chunk.

        Returns:
            The ratings list, or None if absent or not a list.
        """
        candidate = self._first_candidate(event)
        if candidate is None:
            return None
        ratings = candidate.get("safetyRatings")
        return ratings if isinstance(ratings, list) else None
|
||
|
||
|
||
# Public API of this module: only the parser class is exported.
__all__ = ["GeminiStreamParser"]
|