mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
refactor(backend): update handlers, utilities and core modules after models restructure
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Sequence, Tuple, TypeVar
|
||||
from typing import Any, List, Sequence, Tuple, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Query
|
||||
|
||||
@@ -40,10 +40,10 @@ def paginate_sequence(
|
||||
return sliced, meta
|
||||
|
||||
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
|
||||
"""
|
||||
构建标准分页响应 payload。
|
||||
"""
|
||||
payload = {"items": items, "meta": meta.to_dict()}
|
||||
payload: dict = {"items": items, "meta": meta.to_dict()}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
|
||||
@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"""
|
||||
获取模型映射后的实际模型名
|
||||
|
||||
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名(可能是别名)
|
||||
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
映射后的 provider_model_name,如果没有找到映射则返回 None
|
||||
映射后的 Provider 模型名,如果没有找到映射则返回 None
|
||||
"""
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
|
||||
@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
不再经过 Protocol 抽象层。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
|
||||
class OpenAIResponseParser(ResponseParser):
|
||||
"""OpenAI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||
|
||||
self._parser = OpenAIStreamParser()
|
||||
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return ""
|
||||
|
||||
@@ -158,7 +158,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
"""OpenAI CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "OPENAI_CLI"
|
||||
self.api_format = "OPENAI_CLI"
|
||||
@@ -167,7 +167,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
class ClaudeResponseParser(ResponseParser):
|
||||
"""Claude 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||
|
||||
self._parser = ClaudeStreamParser()
|
||||
@@ -291,7 +291,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
"""Claude CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "CLAUDE_CLI"
|
||||
self.api_format = "CLAUDE_CLI"
|
||||
@@ -300,7 +300,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
class GeminiResponseParser(ResponseParser):
|
||||
"""Gemini 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
self._parser = GeminiStreamParser()
|
||||
@@ -443,20 +443,20 @@ class GeminiResponseParser(ResponseParser):
|
||||
|
||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||
"""
|
||||
return self._parser.is_error_event(response)
|
||||
return bool(self._parser.is_error_event(response))
|
||||
|
||||
|
||||
class GeminiCliResponseParser(GeminiResponseParser):
|
||||
"""Gemini CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "GEMINI_CLI"
|
||||
self.api_format = "GEMINI_CLI"
|
||||
|
||||
|
||||
# 解析器注册表
|
||||
_PARSERS = {
|
||||
_PARSERS: Dict[str, Type[ResponseParser]] = {
|
||||
"CLAUDE": ClaudeResponseParser,
|
||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||
"OPENAI": OpenAIResponseParser,
|
||||
@@ -498,6 +498,5 @@ __all__ = [
|
||||
"GeminiResponseParser",
|
||||
"GeminiCliResponseParser",
|
||||
"get_parser_for_format",
|
||||
"get_parser_from_protocol",
|
||||
"is_cli_format",
|
||||
]
|
||||
|
||||
@@ -108,7 +108,10 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -147,7 +150,8 @@ class ClaudeStreamParser:
|
||||
Returns:
|
||||
事件类型字符串
|
||||
"""
|
||||
return event.get("type")
|
||||
event_type = event.get("type")
|
||||
return str(event_type) if event_type is not None else None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
|
||||
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == self.DELTA_TEXT:
|
||||
return delta.get("text")
|
||||
text = delta.get("text")
|
||||
return str(text) if text is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@@ -219,7 +224,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
message = event.get("message", {})
|
||||
return message.get("id")
|
||||
msg_id = message.get("id")
|
||||
return str(msg_id) if msg_id is not None else None
|
||||
|
||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
return delta.get("stop_reason")
|
||||
reason = delta.get("stop_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
|
||||
|
||||
__all__ = ["ClaudeStreamParser"]
|
||||
|
||||
@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append({"text": block})
|
||||
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
|
||||
"RECITATION": "content_filtered",
|
||||
"OTHER": "stop_sequence",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "end_turn"
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
def _create_empty_response(self) -> Dict[str, Any]:
|
||||
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "stop",
|
||||
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
|
||||
"RECITATION": "content_filter",
|
||||
"OTHER": "stop",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "stop"
|
||||
return mapping.get(finish_reason, "stop")
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class GeminiStreamParser:
|
||||
@@ -32,18 +32,18 @@ class GeminiStreamParser:
|
||||
FINISH_REASON_RECITATION = "RECITATION"
|
||||
FINISH_REASON_OTHER = "OTHER"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""重置解析器状态"""
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析流式数据块
|
||||
|
||||
@@ -111,7 +111,10 @@ class GeminiStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line.strip().rstrip(","))
|
||||
result = json.loads(line.strip().rstrip(","))
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -216,7 +219,8 @@ class GeminiStreamParser:
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
return candidates[0].get("finishReason")
|
||||
reason = candidates[0].get("finishReason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -285,7 +289,8 @@ class GeminiStreamParser:
|
||||
Returns:
|
||||
模型版本,如果没有返回 None
|
||||
"""
|
||||
return event.get("modelVersion")
|
||||
version = event.get("modelVersion")
|
||||
return str(version) if version is not None else None
|
||||
|
||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
@@ -301,7 +306,10 @@ class GeminiStreamParser:
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0].get("safetyRatings")
|
||||
ratings = candidates[0].get("safetyRatings")
|
||||
if isinstance(ratings, list):
|
||||
return ratings
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["GeminiStreamParser"]
|
||||
|
||||
@@ -78,7 +78,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("finish_reason")
|
||||
reason = choices[0].get("finish_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("tool_calls")
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
return tool_calls
|
||||
return None
|
||||
|
||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("role")
|
||||
role = delta.get("role")
|
||||
return str(role) if role is not None else None
|
||||
|
||||
|
||||
__all__ = ["OpenAIStreamParser"]
|
||||
|
||||
@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return {
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from src.models.database import Model, ModelMapping, ProviderEndpoint
|
||||
from src.models.database import Model, ProviderEndpoint
|
||||
|
||||
db = context.db
|
||||
|
||||
@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
}
|
||||
)
|
||||
|
||||
# 查询该 Provider 所有 Model 对应的 GlobalModel 的别名/映射
|
||||
provider_model_global_ids = {
|
||||
m.global_model_id for m in provider.models if m.global_model_id
|
||||
}
|
||||
if provider_model_global_ids:
|
||||
# 查询全局别名 + Provider 特定映射
|
||||
alias_mappings = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id.in_(provider_model_global_ids),
|
||||
ModelMapping.is_active == True,
|
||||
(ModelMapping.provider_id == provider.id)
|
||||
| (ModelMapping.provider_id == None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for alias_obj in alias_mappings:
|
||||
# 为这个别名找到该 Provider 的 Model 实现
|
||||
model = next(
|
||||
(
|
||||
m
|
||||
for m in provider.models
|
||||
if m.global_model_id == alias_obj.target_global_model_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if model:
|
||||
models_data.append(
|
||||
{
|
||||
"id": alias_obj.id,
|
||||
"name": alias_obj.source_model,
|
||||
"display_name": (
|
||||
alias_obj.target_global_model.display_name
|
||||
if alias_obj.target_global_model
|
||||
else alias_obj.source_model
|
||||
),
|
||||
"input_price_per_1m": model.input_price_per_1m,
|
||||
"output_price_per_1m": model.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": model.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": model.cache_read_price_per_1m,
|
||||
"supports_vision": model.supports_vision,
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
}
|
||||
)
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": provider.id,
|
||||
|
||||
Reference in New Issue
Block a user