mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
275 lines
7.7 KiB
Python
275 lines
7.7 KiB
Python
|
|
"""
|
|||
|
|
流式内容提取器 - 策略模式实现
|
|||
|
|
|
|||
|
|
为不同 API 格式(OpenAI、Claude、Gemini)提供内容提取和 chunk 构造的抽象。
|
|||
|
|
StreamSmoother 使用这些提取器来处理不同格式的 SSE 事件。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import copy
|
|||
|
|
import json
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
from typing import Optional
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ContentExtractor(ABC):
|
|||
|
|
"""
|
|||
|
|
流式内容提取器抽象基类
|
|||
|
|
|
|||
|
|
定义从 SSE 事件中提取文本内容和构造新 chunk 的接口。
|
|||
|
|
每种 API 格式(OpenAI、Claude、Gemini)需要实现自己的提取器。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
def extract_content(self, data: dict) -> Optional[str]:
|
|||
|
|
"""
|
|||
|
|
从 SSE 数据中提取可拆分的文本内容
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
data: 解析后的 JSON 数据
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
提取的文本内容,如果无法提取则返回 None
|
|||
|
|
"""
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
@abstractmethod
|
|||
|
|
def create_chunk(
|
|||
|
|
self,
|
|||
|
|
original_data: dict,
|
|||
|
|
new_content: str,
|
|||
|
|
event_type: str = "",
|
|||
|
|
is_first: bool = False,
|
|||
|
|
) -> bytes:
|
|||
|
|
"""
|
|||
|
|
使用新内容构造 SSE chunk
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
original_data: 原始 JSON 数据
|
|||
|
|
new_content: 新的文本内容
|
|||
|
|
event_type: SSE 事件类型(某些格式需要)
|
|||
|
|
is_first: 是否是第一个 chunk(用于保留 role 等字段)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
编码后的 SSE 字节数据
|
|||
|
|
"""
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OpenAIContentExtractor(ContentExtractor):
|
|||
|
|
"""
|
|||
|
|
OpenAI 格式内容提取器
|
|||
|
|
|
|||
|
|
处理 OpenAI Chat Completions API 的流式响应格式:
|
|||
|
|
- 数据结构: choices[0].delta.content
|
|||
|
|
- 只在 delta 仅包含 role/content 时允许拆分,避免破坏 tool_calls 等结构
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def extract_content(self, data: dict) -> Optional[str]:
|
|||
|
|
if not isinstance(data, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
choices = data.get("choices")
|
|||
|
|
if not isinstance(choices, list) or len(choices) != 1:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
first_choice = choices[0]
|
|||
|
|
if not isinstance(first_choice, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
delta = first_choice.get("delta")
|
|||
|
|
if not isinstance(delta, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
content = delta.get("content")
|
|||
|
|
if not isinstance(content, str):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 只有 delta 仅包含 role/content 时才允许拆分
|
|||
|
|
# 避免破坏 tool_calls、function_call 等复杂结构
|
|||
|
|
allowed_keys = {"role", "content"}
|
|||
|
|
if not all(key in allowed_keys for key in delta.keys()):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
return content
|
|||
|
|
|
|||
|
|
def create_chunk(
|
|||
|
|
self,
|
|||
|
|
original_data: dict,
|
|||
|
|
new_content: str,
|
|||
|
|
event_type: str = "",
|
|||
|
|
is_first: bool = False,
|
|||
|
|
) -> bytes:
|
|||
|
|
new_data = original_data.copy()
|
|||
|
|
|
|||
|
|
if "choices" in new_data and new_data["choices"]:
|
|||
|
|
new_choices = []
|
|||
|
|
for choice in new_data["choices"]:
|
|||
|
|
new_choice = choice.copy()
|
|||
|
|
if "delta" in new_choice:
|
|||
|
|
new_delta = {}
|
|||
|
|
# 只有第一个 chunk 保留 role
|
|||
|
|
if is_first and "role" in new_choice["delta"]:
|
|||
|
|
new_delta["role"] = new_choice["delta"]["role"]
|
|||
|
|
new_delta["content"] = new_content
|
|||
|
|
new_choice["delta"] = new_delta
|
|||
|
|
new_choices.append(new_choice)
|
|||
|
|
new_data["choices"] = new_choices
|
|||
|
|
|
|||
|
|
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ClaudeContentExtractor(ContentExtractor):
|
|||
|
|
"""
|
|||
|
|
Claude 格式内容提取器
|
|||
|
|
|
|||
|
|
处理 Claude Messages API 的流式响应格式:
|
|||
|
|
- 事件类型: content_block_delta
|
|||
|
|
- 数据结构: delta.type=text_delta, delta.text
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def extract_content(self, data: dict) -> Optional[str]:
|
|||
|
|
if not isinstance(data, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 检查事件类型
|
|||
|
|
if data.get("type") != "content_block_delta":
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
delta = data.get("delta", {})
|
|||
|
|
if not isinstance(delta, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
# 检查 delta 类型
|
|||
|
|
if delta.get("type") != "text_delta":
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
text = delta.get("text")
|
|||
|
|
if not isinstance(text, str):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
return text
|
|||
|
|
|
|||
|
|
def create_chunk(
|
|||
|
|
self,
|
|||
|
|
original_data: dict,
|
|||
|
|
new_content: str,
|
|||
|
|
event_type: str = "",
|
|||
|
|
is_first: bool = False,
|
|||
|
|
) -> bytes:
|
|||
|
|
new_data = original_data.copy()
|
|||
|
|
|
|||
|
|
if "delta" in new_data:
|
|||
|
|
new_delta = new_data["delta"].copy()
|
|||
|
|
new_delta["text"] = new_content
|
|||
|
|
new_data["delta"] = new_delta
|
|||
|
|
|
|||
|
|
# Claude 格式需要 event: 前缀
|
|||
|
|
event_name = event_type or "content_block_delta"
|
|||
|
|
return f"event: {event_name}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
|
|||
|
|
"utf-8"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GeminiContentExtractor(ContentExtractor):
|
|||
|
|
"""
|
|||
|
|
Gemini 格式内容提取器
|
|||
|
|
|
|||
|
|
处理 Gemini API 的流式响应格式:
|
|||
|
|
- 数据结构: candidates[0].content.parts[0].text
|
|||
|
|
- 只有纯文本块才拆分
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def extract_content(self, data: dict) -> Optional[str]:
|
|||
|
|
if not isinstance(data, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
candidates = data.get("candidates")
|
|||
|
|
if not isinstance(candidates, list) or len(candidates) != 1:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
first_candidate = candidates[0]
|
|||
|
|
if not isinstance(first_candidate, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
content = first_candidate.get("content", {})
|
|||
|
|
if not isinstance(content, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
parts = content.get("parts", [])
|
|||
|
|
if not isinstance(parts, list) or len(parts) != 1:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
first_part = parts[0]
|
|||
|
|
if not isinstance(first_part, dict):
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
text = first_part.get("text")
|
|||
|
|
# 只有纯文本块(只有 text 字段)才拆分
|
|||
|
|
if not isinstance(text, str) or len(first_part) != 1:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
return text
|
|||
|
|
|
|||
|
|
def create_chunk(
|
|||
|
|
self,
|
|||
|
|
original_data: dict,
|
|||
|
|
new_content: str,
|
|||
|
|
event_type: str = "",
|
|||
|
|
is_first: bool = False,
|
|||
|
|
) -> bytes:
|
|||
|
|
new_data = copy.deepcopy(original_data)
|
|||
|
|
|
|||
|
|
if "candidates" in new_data and new_data["candidates"]:
|
|||
|
|
first_candidate = new_data["candidates"][0]
|
|||
|
|
if "content" in first_candidate:
|
|||
|
|
content = first_candidate["content"]
|
|||
|
|
if "parts" in content and content["parts"]:
|
|||
|
|
content["parts"][0]["text"] = new_content
|
|||
|
|
|
|||
|
|
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 提取器注册表
|
|||
|
|
_EXTRACTORS: dict[str, type[ContentExtractor]] = {
|
|||
|
|
"openai": OpenAIContentExtractor,
|
|||
|
|
"claude": ClaudeContentExtractor,
|
|||
|
|
"gemini": GeminiContentExtractor,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_extractor(format_name: str) -> Optional[ContentExtractor]:
|
|||
|
|
"""
|
|||
|
|
根据格式名获取对应的内容提取器实例
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
format_name: 格式名称(openai, claude, gemini)
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
对应的提取器实例,如果格式不支持则返回 None
|
|||
|
|
"""
|
|||
|
|
extractor_class = _EXTRACTORS.get(format_name.lower())
|
|||
|
|
if extractor_class:
|
|||
|
|
return extractor_class()
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def register_extractor(format_name: str, extractor_class: type[ContentExtractor]) -> None:
|
|||
|
|
"""
|
|||
|
|
注册新的内容提取器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
format_name: 格式名称
|
|||
|
|
extractor_class: 提取器类
|
|||
|
|
"""
|
|||
|
|
_EXTRACTORS[format_name.lower()] = extractor_class
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_extractor_formats() -> list[str]:
|
|||
|
|
"""
|
|||
|
|
获取所有已注册的格式名称列表
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
格式名称列表
|
|||
|
|
"""
|
|||
|
|
return list(_EXTRACTORS.keys())
|