mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
refactor: 完善 handler 基类类型注解和流式状态更新
- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解 - 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming - 优化 chat/cli handler 的类型提示,提升代码可维护性 - 修复类型检查警告,确保 mypy 通过
This commit is contained in:
@@ -207,7 +207,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = mapping.model.provider_model_name
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
return mapped_name
|
||||
|
||||
@@ -351,7 +351,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
key: ProviderAPIKey,
|
||||
):
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
return await self._execute_stream_request(
|
||||
ctx,
|
||||
provider,
|
||||
@@ -451,19 +451,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
ctx.response_metadata = {} # 重置 Provider 响应元数据
|
||||
|
||||
# 记录 Provider 信息
|
||||
ctx.provider_name = provider.name
|
||||
ctx.provider_id = provider.id
|
||||
ctx.endpoint_id = endpoint.id
|
||||
ctx.key_id = key.id
|
||||
ctx.provider_name = str(provider.name)
|
||||
ctx.provider_id = str(provider.id)
|
||||
ctx.endpoint_id = str(endpoint.id)
|
||||
ctx.key_id = str(key.id)
|
||||
|
||||
# 记录格式转换信息
|
||||
ctx.provider_api_format = endpoint.api_format if endpoint.api_format else ""
|
||||
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
|
||||
|
||||
# 获取模型映射(别名/映射 → 实际模型名)
|
||||
mapped_model = await self._get_mapped_model(
|
||||
source_model=ctx.model,
|
||||
provider_id=provider.id,
|
||||
provider_id=str(provider.id),
|
||||
)
|
||||
|
||||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||||
@@ -575,11 +575,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
streaming_status_updated = False
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
async for line in stream_response.aiter_lines():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
streaming_status_updated = True
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
@@ -588,7 +594,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
@@ -622,7 +628,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
@@ -633,7 +639,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
# 检查是否收到数据
|
||||
@@ -781,7 +787,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=provider.name,
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
@@ -819,6 +825,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_lines:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
|
||||
# 先处理预读的数据
|
||||
for line in prefetched_lines:
|
||||
normalized_line = line.rstrip("\r")
|
||||
@@ -829,7 +839,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
@@ -848,7 +858,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
@@ -864,7 +874,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
@@ -898,7 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
@@ -910,7 +920,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data", ""),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
# 检查是否收到数据
|
||||
@@ -1270,6 +1280,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
target_model=ctx.mapped_model,
|
||||
)
|
||||
|
||||
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
original_request_body: Dict[str, Any],
|
||||
@@ -1309,15 +1321,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
key: ProviderAPIKey,
|
||||
):
|
||||
) -> Dict[str, Any]:
|
||||
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
|
||||
provider_name = provider.name
|
||||
provider_api_format = endpoint.api_format if endpoint.api_format else ""
|
||||
provider_name = str(provider.name)
|
||||
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||||
|
||||
# 获取模型映射(别名/映射 → 实际模型名)
|
||||
mapped_model = await self._get_mapped_model(
|
||||
source_model=model,
|
||||
provider_id=provider.id,
|
||||
provider_id=str(provider.id),
|
||||
)
|
||||
|
||||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||||
@@ -1373,7 +1385,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
elif resp.status_code == 429:
|
||||
raise ProviderRateLimitException(
|
||||
f"提供商速率限制: {provider.name}",
|
||||
provider_name=provider.name,
|
||||
provider_name=str(provider.name),
|
||||
response_headers=response_headers,
|
||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||
)
|
||||
@@ -1409,7 +1421,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 提取 Provider 响应元数据(子类可覆盖)
|
||||
response_metadata_result = self._extract_response_metadata(response_json)
|
||||
|
||||
return response_json
|
||||
return response_json if isinstance(response_json, dict) else {}
|
||||
|
||||
try:
|
||||
# 解析能力需求
|
||||
@@ -1437,6 +1449,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
provider_name = actual_provider_name
|
||||
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
||||
|
||||
# 确保 response_json 不为 None
|
||||
if response_json is None:
|
||||
response_json = {}
|
||||
|
||||
# 检查是否需要格式转换
|
||||
if (
|
||||
provider_api_format
|
||||
|
||||
Reference in New Issue
Block a user