mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(backend): update handlers, utilities and core modules after models restructure
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import List, Sequence, Tuple, TypeVar
|
from typing import Any, List, Sequence, Tuple, TypeVar
|
||||||
|
|
||||||
from sqlalchemy.orm import Query
|
from sqlalchemy.orm import Query
|
||||||
|
|
||||||
@@ -40,10 +40,10 @@ def paginate_sequence(
|
|||||||
return sliced, meta
|
return sliced, meta
|
||||||
|
|
||||||
|
|
||||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
|
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
|
||||||
"""
|
"""
|
||||||
构建标准分页响应 payload。
|
构建标准分页响应 payload。
|
||||||
"""
|
"""
|
||||||
payload = {"items": items, "meta": meta.to_dict()}
|
payload: dict = {"items": items, "meta": meta.to_dict()}
|
||||||
payload.update(extra)
|
payload.update(extra)
|
||||||
return payload
|
return payload
|
||||||
|
|||||||
@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||||
|
|
||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
mapped_name = str(mapping.model.provider_model_name)
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
|
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
|
|||||||
@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
"""
|
"""
|
||||||
获取模型映射后的实际模型名
|
获取模型映射后的实际模型名
|
||||||
|
|
||||||
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
|
查找逻辑:
|
||||||
|
1. 直接通过 GlobalModel.name 匹配
|
||||||
|
2. 查找该 Provider 的 Model 实现
|
||||||
|
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_model: 用户请求的模型名(可能是别名)
|
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||||
provider_id: Provider ID
|
provider_id: Provider ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
映射后的 provider_model_name,如果没有找到映射则返回 None
|
映射后的 Provider 模型名,如果没有找到映射则返回 None
|
||||||
"""
|
"""
|
||||||
from src.services.model.mapper import ModelMapperMiddleware
|
from src.services.model.mapper import ModelMapperMiddleware
|
||||||
|
|
||||||
@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||||
|
|
||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
mapped_name = str(mapping.model.provider_model_name)
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
|
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
不再经过 Protocol 抽象层。
|
不再经过 Protocol 抽象层。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
from src.api.handlers.base.response_parser import (
|
from src.api.handlers.base.response_parser import (
|
||||||
ParsedChunk,
|
ParsedChunk,
|
||||||
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
|
|||||||
class OpenAIResponseParser(ResponseParser):
|
class OpenAIResponseParser(ResponseParser):
|
||||||
"""OpenAI 格式响应解析器"""
|
"""OpenAI 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||||
|
|
||||||
self._parser = OpenAIStreamParser()
|
self._parser = OpenAIStreamParser()
|
||||||
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
|
|||||||
if choices:
|
if choices:
|
||||||
message = choices[0].get("message", {})
|
message = choices[0].get("message", {})
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
if content:
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -158,7 +158,7 @@ class OpenAIResponseParser(ResponseParser):
|
|||||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||||
"""OpenAI CLI 格式响应解析器"""
|
"""OpenAI CLI 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = "OPENAI_CLI"
|
self.name = "OPENAI_CLI"
|
||||||
self.api_format = "OPENAI_CLI"
|
self.api_format = "OPENAI_CLI"
|
||||||
@@ -167,7 +167,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
|
|||||||
class ClaudeResponseParser(ResponseParser):
|
class ClaudeResponseParser(ResponseParser):
|
||||||
"""Claude 格式响应解析器"""
|
"""Claude 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||||
|
|
||||||
self._parser = ClaudeStreamParser()
|
self._parser = ClaudeStreamParser()
|
||||||
@@ -291,7 +291,7 @@ class ClaudeResponseParser(ResponseParser):
|
|||||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||||
"""Claude CLI 格式响应解析器"""
|
"""Claude CLI 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = "CLAUDE_CLI"
|
self.name = "CLAUDE_CLI"
|
||||||
self.api_format = "CLAUDE_CLI"
|
self.api_format = "CLAUDE_CLI"
|
||||||
@@ -300,7 +300,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
|
|||||||
class GeminiResponseParser(ResponseParser):
|
class GeminiResponseParser(ResponseParser):
|
||||||
"""Gemini 格式响应解析器"""
|
"""Gemini 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||||
|
|
||||||
self._parser = GeminiStreamParser()
|
self._parser = GeminiStreamParser()
|
||||||
@@ -443,20 +443,20 @@ class GeminiResponseParser(ResponseParser):
|
|||||||
|
|
||||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||||
"""
|
"""
|
||||||
return self._parser.is_error_event(response)
|
return bool(self._parser.is_error_event(response))
|
||||||
|
|
||||||
|
|
||||||
class GeminiCliResponseParser(GeminiResponseParser):
|
class GeminiCliResponseParser(GeminiResponseParser):
|
||||||
"""Gemini CLI 格式响应解析器"""
|
"""Gemini CLI 格式响应解析器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = "GEMINI_CLI"
|
self.name = "GEMINI_CLI"
|
||||||
self.api_format = "GEMINI_CLI"
|
self.api_format = "GEMINI_CLI"
|
||||||
|
|
||||||
|
|
||||||
# 解析器注册表
|
# 解析器注册表
|
||||||
_PARSERS = {
|
_PARSERS: Dict[str, Type[ResponseParser]] = {
|
||||||
"CLAUDE": ClaudeResponseParser,
|
"CLAUDE": ClaudeResponseParser,
|
||||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||||
"OPENAI": OpenAIResponseParser,
|
"OPENAI": OpenAIResponseParser,
|
||||||
@@ -498,6 +498,5 @@ __all__ = [
|
|||||||
"GeminiResponseParser",
|
"GeminiResponseParser",
|
||||||
"GeminiCliResponseParser",
|
"GeminiCliResponseParser",
|
||||||
"get_parser_for_format",
|
"get_parser_for_format",
|
||||||
"get_parser_from_protocol",
|
|
||||||
"is_cli_format",
|
"is_cli_format",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -108,7 +108,10 @@ class ClaudeStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(line)
|
result = json.loads(line)
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
return None
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -147,7 +150,8 @@ class ClaudeStreamParser:
|
|||||||
Returns:
|
Returns:
|
||||||
事件类型字符串
|
事件类型字符串
|
||||||
"""
|
"""
|
||||||
return event.get("type")
|
event_type = event.get("type")
|
||||||
|
return str(event_type) if event_type is not None else None
|
||||||
|
|
||||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
|
|||||||
|
|
||||||
delta = event.get("delta", {})
|
delta = event.get("delta", {})
|
||||||
if delta.get("type") == self.DELTA_TEXT:
|
if delta.get("type") == self.DELTA_TEXT:
|
||||||
return delta.get("text")
|
text = delta.get("text")
|
||||||
|
return str(text) if text is not None else None
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -219,7 +224,8 @@ class ClaudeStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
message = event.get("message", {})
|
message = event.get("message", {})
|
||||||
return message.get("id")
|
msg_id = message.get("id")
|
||||||
|
return str(msg_id) if msg_id is not None else None
|
||||||
|
|
||||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
delta = event.get("delta", {})
|
delta = event.get("delta", {})
|
||||||
return delta.get("stop_reason")
|
reason = delta.get("stop_reason")
|
||||||
|
return str(reason) if reason is not None else None
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeStreamParser"]
|
__all__ = ["ClaudeStreamParser"]
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
|
|||||||
return [{"text": content}]
|
return [{"text": content}]
|
||||||
|
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
parts = []
|
parts: List[Dict[str, Any]] = []
|
||||||
for block in content:
|
for block in content:
|
||||||
if isinstance(block, str):
|
if isinstance(block, str):
|
||||||
parts.append({"text": block})
|
parts.append({"text": block})
|
||||||
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
|
|||||||
"RECITATION": "content_filtered",
|
"RECITATION": "content_filtered",
|
||||||
"OTHER": "stop_sequence",
|
"OTHER": "stop_sequence",
|
||||||
}
|
}
|
||||||
|
if finish_reason is None:
|
||||||
|
return "end_turn"
|
||||||
return mapping.get(finish_reason, "end_turn")
|
return mapping.get(finish_reason, "end_turn")
|
||||||
|
|
||||||
def _create_empty_response(self) -> Dict[str, Any]:
|
def _create_empty_response(self) -> Dict[str, Any]:
|
||||||
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
|
|||||||
return [{"text": content}]
|
return [{"text": content}]
|
||||||
|
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
parts = []
|
parts: List[Dict[str, Any]] = []
|
||||||
for item in content:
|
for item in content:
|
||||||
if isinstance(item, str):
|
if isinstance(item, str):
|
||||||
parts.append({"text": item})
|
parts.append({"text": item})
|
||||||
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
|
|||||||
"total_tokens": prompt_tokens + completion_tokens,
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
|
||||||
"""转换停止原因"""
|
"""转换停止原因"""
|
||||||
mapping = {
|
mapping = {
|
||||||
"STOP": "stop",
|
"STOP": "stop",
|
||||||
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
|
|||||||
"RECITATION": "content_filter",
|
"RECITATION": "content_filter",
|
||||||
"OTHER": "stop",
|
"OTHER": "stop",
|
||||||
}
|
}
|
||||||
|
if finish_reason is None:
|
||||||
|
return "stop"
|
||||||
return mapping.get(finish_reason, "stop")
|
return mapping.get(finish_reason, "stop")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class GeminiStreamParser:
|
class GeminiStreamParser:
|
||||||
@@ -32,18 +32,18 @@ class GeminiStreamParser:
|
|||||||
FINISH_REASON_RECITATION = "RECITATION"
|
FINISH_REASON_RECITATION = "RECITATION"
|
||||||
FINISH_REASON_OTHER = "OTHER"
|
FINISH_REASON_OTHER = "OTHER"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._buffer = ""
|
self._buffer = ""
|
||||||
self._in_array = False
|
self._in_array = False
|
||||||
self._brace_depth = 0
|
self._brace_depth = 0
|
||||||
|
|
||||||
def reset(self):
|
def reset(self) -> None:
|
||||||
"""重置解析器状态"""
|
"""重置解析器状态"""
|
||||||
self._buffer = ""
|
self._buffer = ""
|
||||||
self._in_array = False
|
self._in_array = False
|
||||||
self._brace_depth = 0
|
self._brace_depth = 0
|
||||||
|
|
||||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
解析流式数据块
|
解析流式数据块
|
||||||
|
|
||||||
@@ -111,7 +111,10 @@ class GeminiStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(line.strip().rstrip(","))
|
result = json.loads(line.strip().rstrip(","))
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
return None
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -216,7 +219,8 @@ class GeminiStreamParser:
|
|||||||
"""
|
"""
|
||||||
candidates = event.get("candidates", [])
|
candidates = event.get("candidates", [])
|
||||||
if candidates:
|
if candidates:
|
||||||
return candidates[0].get("finishReason")
|
reason = candidates[0].get("finishReason")
|
||||||
|
return str(reason) if reason is not None else None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||||
@@ -285,7 +289,8 @@ class GeminiStreamParser:
|
|||||||
Returns:
|
Returns:
|
||||||
模型版本,如果没有返回 None
|
模型版本,如果没有返回 None
|
||||||
"""
|
"""
|
||||||
return event.get("modelVersion")
|
version = event.get("modelVersion")
|
||||||
|
return str(version) if version is not None else None
|
||||||
|
|
||||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
@@ -301,7 +306,10 @@ class GeminiStreamParser:
|
|||||||
if not candidates:
|
if not candidates:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return candidates[0].get("safetyRatings")
|
ratings = candidates[0].get("safetyRatings")
|
||||||
|
if isinstance(ratings, list):
|
||||||
|
return ratings
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["GeminiStreamParser"]
|
__all__ = ["GeminiStreamParser"]
|
||||||
|
|||||||
@@ -78,7 +78,10 @@ class OpenAIStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return json.loads(line)
|
result = json.loads(line)
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return result
|
||||||
|
return None
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
|
|||||||
"""
|
"""
|
||||||
choices = chunk.get("choices", [])
|
choices = chunk.get("choices", [])
|
||||||
if choices:
|
if choices:
|
||||||
return choices[0].get("finish_reason")
|
reason = choices[0].get("finish_reason")
|
||||||
|
return str(reason) if reason is not None else None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||||
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
delta = choices[0].get("delta", {})
|
delta = choices[0].get("delta", {})
|
||||||
return delta.get("tool_calls")
|
tool_calls = delta.get("tool_calls")
|
||||||
|
if isinstance(tool_calls, list):
|
||||||
|
return tool_calls
|
||||||
|
return None
|
||||||
|
|
||||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
delta = choices[0].get("delta", {})
|
delta = choices[0].get("delta", {})
|
||||||
return delta.get("role")
|
role = delta.get("role")
|
||||||
|
return str(role) if role is not None else None
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAIStreamParser"]
|
__all__ = ["OpenAIStreamParser"]
|
||||||
|
|||||||
@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
|
|||||||
获取指定模型支持的能力列表
|
获取指定模型支持的能力列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型支持的能力列表,以及每个能力的详细定义
|
模型支持的能力列表,以及每个能力的详细定义
|
||||||
"""
|
"""
|
||||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
from src.models.database import GlobalModel
|
||||||
|
|
||||||
mapping_resolver = get_model_mapping_resolver()
|
global_model = (
|
||||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
db.query(GlobalModel)
|
||||||
|
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
if not global_model:
|
if not global_model:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Optional
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy import and_, func
|
from sqlalchemy import and_, func
|
||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||||
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
|||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from src.models.database import Model, ModelMapping, ProviderEndpoint
|
from src.models.database import Model, ProviderEndpoint
|
||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
|
|
||||||
@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 查询该 Provider 所有 Model 对应的 GlobalModel 的别名/映射
|
|
||||||
provider_model_global_ids = {
|
|
||||||
m.global_model_id for m in provider.models if m.global_model_id
|
|
||||||
}
|
|
||||||
if provider_model_global_ids:
|
|
||||||
# 查询全局别名 + Provider 特定映射
|
|
||||||
alias_mappings = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.options(joinedload(ModelMapping.target_global_model))
|
|
||||||
.filter(
|
|
||||||
ModelMapping.target_global_model_id.in_(provider_model_global_ids),
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
(ModelMapping.provider_id == provider.id)
|
|
||||||
| (ModelMapping.provider_id == None),
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
for alias_obj in alias_mappings:
|
|
||||||
# 为这个别名找到该 Provider 的 Model 实现
|
|
||||||
model = next(
|
|
||||||
(
|
|
||||||
m
|
|
||||||
for m in provider.models
|
|
||||||
if m.global_model_id == alias_obj.target_global_model_id
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if model:
|
|
||||||
models_data.append(
|
|
||||||
{
|
|
||||||
"id": alias_obj.id,
|
|
||||||
"name": alias_obj.source_model,
|
|
||||||
"display_name": (
|
|
||||||
alias_obj.target_global_model.display_name
|
|
||||||
if alias_obj.target_global_model
|
|
||||||
else alias_obj.source_model
|
|
||||||
),
|
|
||||||
"input_price_per_1m": model.input_price_per_1m,
|
|
||||||
"output_price_per_1m": model.output_price_per_1m,
|
|
||||||
"cache_creation_price_per_1m": model.cache_creation_price_per_1m,
|
|
||||||
"cache_read_price_per_1m": model.cache_read_price_per_1m,
|
|
||||||
"supports_vision": model.supports_vision,
|
|
||||||
"supports_function_calling": model.supports_function_calling,
|
|
||||||
"supports_streaming": model.supports_streaming,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
result.append(
|
result.append(
|
||||||
{
|
{
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ class CacheTTL:
|
|||||||
# Provider/Model 缓存 - 配置变更不频繁
|
# Provider/Model 缓存 - 配置变更不频繁
|
||||||
PROVIDER = 300 # 5分钟
|
PROVIDER = 300 # 5分钟
|
||||||
MODEL = 300 # 5分钟
|
MODEL = 300 # 5分钟
|
||||||
MODEL_MAPPING = 300 # 5分钟
|
|
||||||
|
|
||||||
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
|
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
|
||||||
CACHE_AFFINITY = 300 # 5分钟
|
CACHE_AFFINITY = 300 # 5分钟
|
||||||
@@ -33,9 +32,6 @@ class CacheSize:
|
|||||||
# 默认 LRU 缓存大小
|
# 默认 LRU 缓存大小
|
||||||
DEFAULT = 1000
|
DEFAULT = 1000
|
||||||
|
|
||||||
# ModelMapping 缓存(可能有较多别名)
|
|
||||||
MODEL_MAPPING = 2000
|
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# 并发和限流常量
|
# 并发和限流常量
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class SyncLRUCache:
|
|||||||
"""删除缓存值(通过索引)"""
|
"""删除缓存值(通过索引)"""
|
||||||
self.delete(key)
|
self.delete(key)
|
||||||
|
|
||||||
def keys(self):
|
def keys(self) -> list:
|
||||||
"""返回所有未过期的 key"""
|
"""返回所有未过期的 key"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:
|
|||||||
logger.remove()
|
logger.remove()
|
||||||
|
|
||||||
|
|
||||||
def _log_filter(record):
|
def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
||||||
return "watchfiles" not in record["name"]
|
return "watchfiles" not in record["name"]
|
||||||
|
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ if IS_DOCKER:
|
|||||||
sys.stdout,
|
sys.stdout,
|
||||||
format=CONSOLE_FORMAT_PROD,
|
format=CONSOLE_FORMAT_PROD,
|
||||||
level=LOG_LEVEL,
|
level=LOG_LEVEL,
|
||||||
filter=_log_filter,
|
filter=_log_filter, # type: ignore[arg-type]
|
||||||
colorize=False,
|
colorize=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -84,7 +84,7 @@ else:
|
|||||||
sys.stdout,
|
sys.stdout,
|
||||||
format=CONSOLE_FORMAT_DEV,
|
format=CONSOLE_FORMAT_DEV,
|
||||||
level=LOG_LEVEL,
|
level=LOG_LEVEL,
|
||||||
filter=_log_filter,
|
filter=_log_filter, # type: ignore[arg-type]
|
||||||
colorize=True,
|
colorize=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ if not DISABLE_FILE_LOG:
|
|||||||
log_dir / "app.log",
|
log_dir / "app.log",
|
||||||
format=FILE_FORMAT,
|
format=FILE_FORMAT,
|
||||||
level="DEBUG",
|
level="DEBUG",
|
||||||
filter=_log_filter,
|
filter=_log_filter, # type: ignore[arg-type]
|
||||||
rotation="00:00",
|
rotation="00:00",
|
||||||
retention="30 days",
|
retention="30 days",
|
||||||
compression="gz",
|
compression="gz",
|
||||||
@@ -110,7 +110,7 @@ if not DISABLE_FILE_LOG:
|
|||||||
log_dir / "error.log",
|
log_dir / "error.log",
|
||||||
format=FILE_FORMAT,
|
format=FILE_FORMAT,
|
||||||
level="ERROR",
|
level="ERROR",
|
||||||
filter=_log_filter,
|
filter=_log_filter, # type: ignore[arg-type]
|
||||||
rotation="00:00",
|
rotation="00:00",
|
||||||
retention="30 days",
|
retention="30 days",
|
||||||
compression="gz",
|
compression="gz",
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
class ProviderHealthTracker:
|
class ProviderHealthTracker:
|
||||||
@@ -32,7 +32,7 @@ class ProviderHealthTracker:
|
|||||||
# 存储优先级调整
|
# 存储优先级调整
|
||||||
self.priority_adjustments: Dict[str, int] = {}
|
self.priority_adjustments: Dict[str, int] = {}
|
||||||
|
|
||||||
def record_success(self, provider_name: str):
|
def record_success(self, provider_name: str) -> None:
|
||||||
"""记录成功的请求"""
|
"""记录成功的请求"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class ProviderHealthTracker:
|
|||||||
if self.priority_adjustments.get(provider_name, 0) < 0:
|
if self.priority_adjustments.get(provider_name, 0) < 0:
|
||||||
self.priority_adjustments[provider_name] += 1
|
self.priority_adjustments[provider_name] += 1
|
||||||
|
|
||||||
def record_failure(self, provider_name: str):
|
def record_failure(self, provider_name: str) -> None:
|
||||||
"""记录失败的请求"""
|
"""记录失败的请求"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class ProviderHealthTracker:
|
|||||||
"status": self._get_status_label(failure_rate, recent_failures),
|
"status": self._get_status_label(failure_rate, recent_failures),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _cleanup_old_records(self, provider_name: str, current_time: float):
|
def _cleanup_old_records(self, provider_name: str, current_time: float) -> None:
|
||||||
"""清理超出时间窗口的记录"""
|
"""清理超出时间窗口的记录"""
|
||||||
# 清理失败记录
|
# 清理失败记录
|
||||||
self.failures[provider_name] = [
|
self.failures[provider_name] = [
|
||||||
@@ -130,7 +130,7 @@ class ProviderHealthTracker:
|
|||||||
adjustment = self.get_priority_adjustment(provider_name)
|
adjustment = self.get_priority_adjustment(provider_name)
|
||||||
return adjustment > -3
|
return adjustment > -3
|
||||||
|
|
||||||
def reset_provider_health(self, provider_name: str):
|
def reset_provider_health(self, provider_name: str) -> None:
|
||||||
"""重置提供商的健康状态(管理员手动操作)"""
|
"""重置提供商的健康状态(管理员手动操作)"""
|
||||||
self.failures[provider_name] = []
|
self.failures[provider_name] = []
|
||||||
self.successes[provider_name] = []
|
self.successes[provider_name] = []
|
||||||
@@ -146,7 +146,7 @@ class SimpleProviderSelector:
|
|||||||
def __init__(self, health_tracker: ProviderHealthTracker):
|
def __init__(self, health_tracker: ProviderHealthTracker):
|
||||||
self.health_tracker = health_tracker
|
self.health_tracker = health_tracker
|
||||||
|
|
||||||
def select_provider(self, providers: list, specified_provider: Optional[str] = None):
|
def select_provider(self, providers: list, specified_provider: Optional[str] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
选择提供商
|
选择提供商
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,12 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Callable, TypeVar
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
async def run_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||||
"""
|
"""
|
||||||
在线程池中运行同步函数,避免阻塞事件循环
|
在线程池中运行同步函数,避免阻塞事件循环
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
|||||||
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||||
"""
|
"""
|
||||||
装饰器:包装同步数据库函数为异步函数
|
装饰器:包装同步数据库函数为异步函数
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||||
return await run_in_executor(func, *args, **kwargs)
|
return await run_in_executor(func, *args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|||||||
@@ -347,7 +347,7 @@ def init_default_models(db: Session):
|
|||||||
"""初始化默认模型配置"""
|
"""初始化默认模型配置"""
|
||||||
|
|
||||||
# 注意:作为中转代理服务,不再预设模型配置
|
# 注意:作为中转代理服务,不再预设模型配置
|
||||||
# 模型配置应该通过 Model 和 ModelMapping 表动态管理
|
# 模型配置应该通过 GlobalModel 和 Model 表动态管理
|
||||||
# 这个函数保留用于未来可能的默认模型初始化
|
# 这个函数保留用于未来可能的默认模型初始化
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
中间件模块
|
中间件模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__all__ = []
|
__all__: list[str] = []
|
||||||
|
|||||||
@@ -334,7 +334,6 @@ class ProviderResponse(BaseModel):
|
|||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
models_count: int = 0
|
models_count: int = 0
|
||||||
active_models_count: int = 0
|
active_models_count: int = 0
|
||||||
model_mappings_count: int = 0
|
|
||||||
api_keys_count: int = 0
|
api_keys_count: int = 0
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -346,7 +345,11 @@ class ModelCreate(BaseModel):
|
|||||||
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
|
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
|
||||||
|
|
||||||
provider_model_name: str = Field(
|
provider_model_name: str = Field(
|
||||||
..., min_length=1, max_length=200, description="Provider 侧的模型名称"
|
..., min_length=1, max_length=200, description="Provider 侧的主模型名称"
|
||||||
|
)
|
||||||
|
provider_model_aliases: Optional[List[dict]] = Field(
|
||||||
|
None,
|
||||||
|
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||||
)
|
)
|
||||||
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
||||||
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||||
@@ -374,6 +377,10 @@ class ModelUpdate(BaseModel):
|
|||||||
"""更新模型请求"""
|
"""更新模型请求"""
|
||||||
|
|
||||||
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||||
|
provider_model_aliases: Optional[List[dict]] = Field(
|
||||||
|
None,
|
||||||
|
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||||
|
)
|
||||||
global_model_id: Optional[str] = None
|
global_model_id: Optional[str] = None
|
||||||
# 按次计费配置
|
# 按次计费配置
|
||||||
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||||
@@ -398,6 +405,7 @@ class ModelResponse(BaseModel):
|
|||||||
provider_id: str
|
provider_id: str
|
||||||
global_model_id: Optional[str]
|
global_model_id: Optional[str]
|
||||||
provider_model_name: str
|
provider_model_name: str
|
||||||
|
provider_model_aliases: Optional[List[dict]] = None
|
||||||
|
|
||||||
# 按次计费配置
|
# 按次计费配置
|
||||||
price_per_request: Optional[float] = None
|
price_per_request: Optional[float] = None
|
||||||
@@ -465,54 +473,6 @@ class ModelDetailResponse(BaseModel):
|
|||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
# ========== 模型映射 ==========
|
|
||||||
class ModelMappingCreate(BaseModel):
|
|
||||||
"""创建模型映射请求(源模型到目标模型的映射)"""
|
|
||||||
|
|
||||||
source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
|
|
||||||
target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
|
|
||||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
|
||||||
mapping_type: str = Field(
|
|
||||||
"alias",
|
|
||||||
description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)",
|
|
||||||
)
|
|
||||||
is_active: bool = Field(True, description="是否启用")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMappingUpdate(BaseModel):
|
|
||||||
"""更新模型映射请求"""
|
|
||||||
|
|
||||||
source_model: Optional[str] = Field(
|
|
||||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
|
||||||
)
|
|
||||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
|
||||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
|
||||||
mapping_type: Optional[str] = Field(
|
|
||||||
None, description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)"
|
|
||||||
)
|
|
||||||
is_active: Optional[bool] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMappingResponse(BaseModel):
|
|
||||||
"""模型映射响应"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
source_model: str
|
|
||||||
target_global_model_id: str
|
|
||||||
target_global_model_name: Optional[str]
|
|
||||||
target_global_model_display_name: Optional[str]
|
|
||||||
provider_id: Optional[str]
|
|
||||||
provider_name: Optional[str]
|
|
||||||
scope: str = Field(..., description="global 或 provider")
|
|
||||||
mapping_type: str = Field(..., description="映射类型:alias 或 mapping")
|
|
||||||
is_active: bool
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 系统设置 ==========
|
# ========== 系统设置 ==========
|
||||||
class SystemSettingsRequest(BaseModel):
|
class SystemSettingsRequest(BaseModel):
|
||||||
"""系统设置请求"""
|
"""系统设置请求"""
|
||||||
@@ -558,7 +518,6 @@ class PublicProviderResponse(BaseModel):
|
|||||||
# 统计信息
|
# 统计信息
|
||||||
models_count: int
|
models_count: int
|
||||||
active_models_count: int
|
active_models_count: int
|
||||||
mappings_count: int
|
|
||||||
endpoints_count: int # 端点总数
|
endpoints_count: int # 端点总数
|
||||||
active_endpoints_count: int # 活跃端点数
|
active_endpoints_count: int # 活跃端点数
|
||||||
|
|
||||||
@@ -587,19 +546,6 @@ class PublicModelResponse(BaseModel):
|
|||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
|
||||||
|
|
||||||
class PublicModelMappingResponse(BaseModel):
|
|
||||||
"""公开的模型映射信息响应"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
source_model: str
|
|
||||||
target_global_model_id: str
|
|
||||||
target_global_model_name: Optional[str]
|
|
||||||
target_global_model_display_name: Optional[str]
|
|
||||||
provider_id: Optional[str] = None
|
|
||||||
scope: str = Field(..., description="global 或 provider")
|
|
||||||
is_active: bool
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderStatsResponse(BaseModel):
|
class ProviderStatsResponse(BaseModel):
|
||||||
"""提供商统计信息响应"""
|
"""提供商统计信息响应"""
|
||||||
|
|
||||||
@@ -607,7 +553,6 @@ class ProviderStatsResponse(BaseModel):
|
|||||||
active_providers: int
|
active_providers: int
|
||||||
total_models: int
|
total_models: int
|
||||||
active_models: int
|
active_models: int
|
||||||
total_mappings: int
|
|
||||||
supported_formats: List[str]
|
supported_formats: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class SSEEventParser:
|
class SSEEventParser:
|
||||||
@@ -8,7 +8,7 @@ class SSEEventParser:
|
|||||||
self._reset_buffer()
|
self._reset_buffer()
|
||||||
|
|
||||||
def _reset_buffer(self) -> None:
|
def _reset_buffer(self) -> None:
|
||||||
self._buffer: Dict[str, Optional[str] | List[str]] = {
|
self._buffer: Dict[str, Union[Optional[str], List[str]]] = {
|
||||||
"event": None,
|
"event": None,
|
||||||
"data": [],
|
"data": [],
|
||||||
"id": None,
|
"id": None,
|
||||||
@@ -17,16 +17,19 @@ class SSEEventParser:
|
|||||||
|
|
||||||
def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
|
def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
|
||||||
data_lines = self._buffer.get("data", [])
|
data_lines = self._buffer.get("data", [])
|
||||||
if not data_lines:
|
if not isinstance(data_lines, list) or not data_lines:
|
||||||
self._reset_buffer()
|
self._reset_buffer()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data_str = "\n".join(data_lines)
|
data_str = "\n".join(data_lines)
|
||||||
event = {
|
event_val = self._buffer.get("event")
|
||||||
"event": self._buffer.get("event"),
|
id_val = self._buffer.get("id")
|
||||||
|
retry_val = self._buffer.get("retry")
|
||||||
|
event: Dict[str, Optional[str]] = {
|
||||||
|
"event": event_val if isinstance(event_val, str) else None,
|
||||||
"data": data_str,
|
"data": data_str,
|
||||||
"id": self._buffer.get("id"),
|
"id": id_val if isinstance(id_val, str) else None,
|
||||||
"retry": self._buffer.get("retry"),
|
"retry": retry_val if isinstance(retry_val, str) else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._reset_buffer()
|
self._reset_buffer()
|
||||||
|
|||||||
Reference in New Issue
Block a user