refactor: 完善 handler 基类类型注解和流式状态更新

- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解
- 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming
- 优化 chat/cli handler 的类型提示,提升代码可维护性
- 修复类型检查警告,确保 mypy 通过
This commit is contained in:
fawney19
2025-12-11 10:05:06 +08:00
parent 913a87d7f3
commit 0474f63403
3 changed files with 140 additions and 74 deletions

View File

@@ -207,7 +207,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
if mapping and mapping.model:
mapped_name = mapping.model.provider_model_name
mapped_name = str(mapping.model.provider_model_name)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
@@ -351,7 +351,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
):
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
provider,
@@ -451,19 +451,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.response_metadata = {} # 重置 Provider 响应元数据
# 记录 Provider 信息
ctx.provider_name = provider.name
ctx.provider_id = provider.id
ctx.endpoint_id = endpoint.id
ctx.key_id = key.id
ctx.provider_name = str(provider.name)
ctx.provider_id = str(provider.id)
ctx.endpoint_id = str(endpoint.id)
ctx.key_id = str(key.id)
# 记录格式转换信息
ctx.provider_api_format = endpoint.api_format if endpoint.api_format else ""
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
# 获取模型映射(别名/映射 → 实际模型名)
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=provider.id,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
@@ -575,11 +575,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
@@ -588,7 +594,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
yield b"\n"
continue
@@ -622,7 +628,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
if ctx.data_count > 0:
@@ -633,7 +639,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
# 检查是否收到数据
@@ -781,7 +787,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=provider.name,
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
@@ -819,6 +825,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming(ctx.request_id)
# 先处理预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
@@ -829,7 +839,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
yield b"\n"
continue
@@ -848,7 +858,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
if ctx.data_count > 0:
@@ -864,7 +874,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
yield b"\n"
continue
@@ -898,7 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
if ctx.data_count > 0:
@@ -910,7 +920,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data", ""),
event.get("data") or "",
)
# 检查是否收到数据
@@ -1270,6 +1280,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
target_model=ctx.mapped_model,
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def process_sync(
self,
original_request_body: Dict[str, Any],
@@ -1309,15 +1321,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
):
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
provider_name = provider.name
provider_api_format = endpoint.api_format if endpoint.api_format else ""
provider_name = str(provider.name)
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
# 获取模型映射(别名/映射 → 实际模型名)
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=provider.id,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
@@ -1373,7 +1385,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=provider.name,
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
@@ -1409,7 +1421,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 提取 Provider 响应元数据(子类可覆盖)
response_metadata_result = self._extract_response_metadata(response_json)
return response_json
return response_json if isinstance(response_json, dict) else {}
try:
# 解析能力需求
@@ -1437,6 +1449,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider_name = actual_provider_name
response_time_ms = int((time.time() - sync_start_time) * 1000)
# 确保 response_json 不为 None
if response_json is None:
response_json = {}
# 检查是否需要格式转换
if (
provider_api_format