refactor: 完善 handler 基类类型注解和流式状态更新

- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解
- 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming
- 优化 chat/cli handler 的类型提示,提升代码可维护性
- 修复类型检查警告,确保 mypy 通过
This commit is contained in:
fawney19
2025-12-11 10:05:06 +08:00
parent 913a87d7f3
commit 0474f63403
3 changed files with 140 additions and 74 deletions

View File

@@ -116,7 +116,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# ==================== 抽象方法 ====================
@abstractmethod
async def _convert_request(self, request):
async def _convert_request(self, request: Any) -> Any:
"""
将请求转换为目标格式
@@ -261,7 +261,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
mapping = await mapper.get_mapping(source_model, provider_id)
if mapping and mapping.model:
mapped_name = mapping.model.provider_model_name
mapped_name = str(mapping.model.provider_model_name)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name
@@ -271,10 +271,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
async def process_stream(
self,
request,
request: Any,
http_request: Request,
original_headers: Dict,
original_request_body: Dict,
original_headers: Dict[str, Any],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> StreamingResponse:
"""处理流式响应"""
@@ -315,7 +315,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
):
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
provider,
@@ -401,16 +401,16 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
ctx["cached_tokens"] = 0
ctx["cache_creation_tokens"] = 0
ctx["provider_name"] = provider.name
ctx["provider_id"] = provider.id
ctx["endpoint_id"] = endpoint.id
ctx["key_id"] = key.id
ctx["provider_api_format"] = endpoint.api_format # Provider 的响应格式
ctx["provider_name"] = str(provider.name)
ctx["provider_id"] = str(provider.id)
ctx["endpoint_id"] = str(endpoint.id)
ctx["key_id"] = str(key.id)
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=ctx["model"],
provider_id=provider.id,
provider_id=str(provider.id),
)
# 应用模型映射到请求体
@@ -514,14 +514,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
"""创建响应流生成器"""
try:
sse_parser = SSEEventParser()
streaming_status_updated = False
async for line in stream_response.aiter_lines():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming()
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
@@ -530,11 +536,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
@@ -618,7 +624,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=provider.name,
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
@@ -652,6 +658,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
try:
sse_parser = SSEEventParser()
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming()
# 先输出预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
@@ -659,7 +669,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
@@ -667,7 +677,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 继续输出剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
@@ -676,7 +686,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
@@ -684,11 +694,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
@@ -853,8 +863,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# 根据状态码决定记录成功还是失败
# 499 = 客户端断开连接503 = 服务不可用(如流中断)
status_code = ctx.get("status_code")
if status_code and status_code >= 400:
status_code: int = ctx.get("status_code") or 200
if status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
await bg_telemetry.record_failure(
@@ -955,6 +965,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_message=f"记录统计信息失败: {str(e)[:200]}",
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def _update_usage_status_on_error(
self,
response_time_ms: int,
@@ -991,11 +1003,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
if usage:
usage.status = status
usage.status_code = status_code
usage.response_time_ms = response_time_ms
setattr(usage, "status", status)
setattr(usage, "status_code", status_code)
setattr(usage, "response_time_ms", response_time_ms)
if error_message:
usage.error_message = error_message
setattr(usage, "error_message", error_message)
db.commit()
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
except Exception as e:
@@ -1040,10 +1052,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
async def process_sync(
self,
request,
request: Any,
http_request: Request,
original_headers: Dict,
original_request_body: Dict,
original_headers: Dict[str, Any],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse:
"""处理非流式响应"""
@@ -1055,31 +1067,31 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
api_format = self.allowed_api_formats[0]
# 用于跟踪的变量
provider_name = None
response_json = None
provider_name: Optional[str] = None
response_json: Optional[Dict[str, Any]] = None
status_code = 200
response_headers = {}
provider_request_headers = {}
provider_request_body = None
provider_id = None # Provider ID用于失败记录
endpoint_id = None # Endpoint ID用于失败记录
key_id = None # Key ID用于失败记录
mapped_model_result = None # 映射后的目标模型名(用于 Usage 记录)
response_headers: Dict[str, str] = {}
provider_request_headers: Dict[str, str] = {}
provider_request_body: Optional[Dict[str, Any]] = None
provider_id: Optional[str] = None # Provider ID用于失败记录
endpoint_id: Optional[str] = None # Endpoint ID用于失败记录
key_id: Optional[str] = None # Key ID用于失败记录
mapped_model_result: Optional[str] = None # 映射后的目标模型名(用于 Usage 记录)
async def sync_request_func(
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
):
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers
nonlocal provider_request_headers, provider_request_body, mapped_model_result
provider_name = provider.name
provider_name = str(provider.name)
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=provider.id,
provider_id=str(provider.id),
)
# 应用模型映射
@@ -1131,7 +1143,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=provider.name,
provider_name=str(provider.name),
response_headers=response_headers,
)
elif resp.status_code >= 500:
@@ -1142,7 +1154,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
response_json = resp.json()
return response_json
return response_json if isinstance(response_json, dict) else {}
try:
# 解析能力需求
@@ -1170,6 +1182,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
provider_name = actual_provider_name
response_time_ms = self.elapsed_ms()
# 确保 response_json 不为 None
if response_json is None:
response_json = {}
# 规范化响应
response_json = self._normalize_response(response_json)