refactor(backend): update handlers, utilities and core modules after models restructure

This commit is contained in:
fawney19
2025-12-15 14:30:53 +08:00
parent 03ee6c16d9
commit 88e37594cf
19 changed files with 121 additions and 186 deletions

View File

@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
mapping = await mapper.get_mapping(source_model, provider_id)
if mapping and mapping.model:
mapped_name = str(mapping.model.provider_model_name)
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name

View File

@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
"""
获取模型映射后的实际模型名
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
查找逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
Args:
source_model: 用户请求的模型名(可能是别名)
source_model: 用户请求的模型名(必须是 GlobalModel.name)
provider_id: Provider ID
Returns:
映射后的 provider_model_name,如果没有找到映射则返回 None
映射后的 Provider 模型名,如果没有找到映射则返回 None
"""
from src.services.model.mapper import ModelMapperMiddleware
@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
if mapping and mapping.model:
mapped_name = str(mapping.model.provider_model_name)
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name

View File

@@ -5,7 +5,7 @@
不再经过 Protocol 抽象层。
"""
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Type
from src.api.handlers.base.response_parser import (
ParsedChunk,
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
class OpenAIResponseParser(ResponseParser):
"""OpenAI 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
self._parser = OpenAIStreamParser()
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
if choices:
message = choices[0].get("message", {})
content = message.get("content")
if content:
if isinstance(content, str):
return content
return ""
@@ -158,7 +158,7 @@ class OpenAIResponseParser(ResponseParser):
class OpenAICliResponseParser(OpenAIResponseParser):
"""OpenAI CLI 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.name = "OPENAI_CLI"
self.api_format = "OPENAI_CLI"
@@ -167,7 +167,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
class ClaudeResponseParser(ResponseParser):
"""Claude 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
self._parser = ClaudeStreamParser()
@@ -291,7 +291,7 @@ class ClaudeResponseParser(ResponseParser):
class ClaudeCliResponseParser(ClaudeResponseParser):
"""Claude CLI 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.name = "CLAUDE_CLI"
self.api_format = "CLAUDE_CLI"
@@ -300,7 +300,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
class GeminiResponseParser(ResponseParser):
"""Gemini 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
self._parser = GeminiStreamParser()
@@ -443,20 +443,20 @@ class GeminiResponseParser(ResponseParser):
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
"""
return self._parser.is_error_event(response)
return bool(self._parser.is_error_event(response))
class GeminiCliResponseParser(GeminiResponseParser):
"""Gemini CLI 格式响应解析器"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.name = "GEMINI_CLI"
self.api_format = "GEMINI_CLI"
# 解析器注册表
_PARSERS = {
_PARSERS: Dict[str, Type[ResponseParser]] = {
"CLAUDE": ClaudeResponseParser,
"CLAUDE_CLI": ClaudeCliResponseParser,
"OPENAI": OpenAIResponseParser,
@@ -498,6 +498,5 @@ __all__ = [
"GeminiResponseParser",
"GeminiCliResponseParser",
"get_parser_for_format",
"get_parser_from_protocol",
"is_cli_format",
]

View File

@@ -108,7 +108,10 @@ class ClaudeStreamParser:
return None
try:
return json.loads(line)
result = json.loads(line)
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError:
return None
@@ -147,7 +150,8 @@ class ClaudeStreamParser:
Returns:
事件类型字符串
"""
return event.get("type")
event_type = event.get("type")
return str(event_type) if event_type is not None else None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
"""
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
delta = event.get("delta", {})
if delta.get("type") == self.DELTA_TEXT:
return delta.get("text")
text = delta.get("text")
return str(text) if text is not None else None
return None
@@ -219,7 +224,8 @@ class ClaudeStreamParser:
return None
message = event.get("message", {})
return message.get("id")
msg_id = message.get("id")
return str(msg_id) if msg_id is not None else None
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
"""
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
return None
delta = event.get("delta", {})
return delta.get("stop_reason")
reason = delta.get("stop_reason")
return str(reason) if reason is not None else None
__all__ = ["ClaudeStreamParser"]

View File

@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
return [{"text": content}]
if isinstance(content, list):
parts = []
parts: List[Dict[str, Any]] = []
for block in content:
if isinstance(block, str):
parts.append({"text": block})
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
"RECITATION": "content_filtered",
"OTHER": "stop_sequence",
}
if finish_reason is None:
return "end_turn"
return mapping.get(finish_reason, "end_turn")
def _create_empty_response(self) -> Dict[str, Any]:
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
return [{"text": content}]
if isinstance(content, list):
parts = []
parts: List[Dict[str, Any]] = []
for item in content:
if isinstance(item, str):
parts.append({"text": item})
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
"total_tokens": prompt_tokens + completion_tokens,
}
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
"""转换停止原因"""
mapping = {
"STOP": "stop",
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
"RECITATION": "content_filter",
"OTHER": "stop",
}
if finish_reason is None:
return "stop"
return mapping.get(finish_reason, "stop")

View File

@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
"""
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
class GeminiStreamParser:
@@ -32,18 +32,18 @@ class GeminiStreamParser:
FINISH_REASON_RECITATION = "RECITATION"
FINISH_REASON_OTHER = "OTHER"
def __init__(self):
def __init__(self) -> None:
self._buffer = ""
self._in_array = False
self._brace_depth = 0
def reset(self):
def reset(self) -> None:
"""重置解析器状态"""
self._buffer = ""
self._in_array = False
self._brace_depth = 0
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
"""
解析流式数据块
@@ -111,7 +111,10 @@ class GeminiStreamParser:
return None
try:
return json.loads(line.strip().rstrip(","))
result = json.loads(line.strip().rstrip(","))
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError:
return None
@@ -216,7 +219,8 @@ class GeminiStreamParser:
"""
candidates = event.get("candidates", [])
if candidates:
return candidates[0].get("finishReason")
reason = candidates[0].get("finishReason")
return str(reason) if reason is not None else None
return None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
@@ -285,7 +289,8 @@ class GeminiStreamParser:
Returns:
模型版本,如果没有返回 None
"""
return event.get("modelVersion")
version = event.get("modelVersion")
return str(version) if version is not None else None
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
"""
@@ -301,7 +306,10 @@ class GeminiStreamParser:
if not candidates:
return None
return candidates[0].get("safetyRatings")
ratings = candidates[0].get("safetyRatings")
if isinstance(ratings, list):
return ratings
return None
__all__ = ["GeminiStreamParser"]

View File

@@ -78,7 +78,10 @@ class OpenAIStreamParser:
return None
try:
return json.loads(line)
result = json.loads(line)
if isinstance(result, dict):
return result
return None
except json.JSONDecodeError:
return None
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
"""
choices = chunk.get("choices", [])
if choices:
return choices[0].get("finish_reason")
reason = choices[0].get("finish_reason")
return str(reason) if reason is not None else None
return None
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
return None
delta = choices[0].get("delta", {})
return delta.get("tool_calls")
tool_calls = delta.get("tool_calls")
if isinstance(tool_calls, list):
return tool_calls
return None
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
return None
delta = choices[0].get("delta", {})
return delta.get("role")
role = delta.get("role")
return str(role) if role is not None else None
__all__ = ["OpenAIStreamParser"]