diff --git a/src/api/handlers/base/base_handler.py b/src/api/handlers/base/base_handler.py
index a7d41c5..3796892 100644
--- a/src/api/handlers/base/base_handler.py
+++ b/src/api/handlers/base/base_handler.py
@@ -50,7 +50,9 @@ class MessageTelemetry:
     负责记录 Usage/Audit,避免处理器里重复代码。
     """
 
-    def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
+    def __init__(
+        self, db: Session, user: Any, api_key: Any, request_id: str, client_ip: str
+    ) -> None:
         self.db = db
         self.user = user
         self.api_key = api_key
@@ -187,7 +189,7 @@ class MessageTelemetry:
         response_body: Optional[Dict[str, Any]] = None,
         # 模型映射信息
         target_model: Optional[str] = None,
-    ):
+    ) -> None:
         """
         记录失败请求
 
@@ -283,15 +285,15 @@ class BaseMessageHandler:
         self,
         *,
         db: Session,
-        user,
-        api_key,
+        user: Any,
+        api_key: Any,
         request_id: str,
         client_ip: str,
         user_agent: str,
         start_time: float,
         allowed_api_formats: Optional[list[str]] = None,
         adapter_detector: Optional[AdapterDetectorType] = None,
-    ):
+    ) -> None:
         self.db = db
         self.user = user
         self.api_key = api_key
@@ -304,7 +306,7 @@ class BaseMessageHandler:
         self.adapter_detector = adapter_detector
 
         redis_client = get_redis_client_sync()
-        self.orchestrator = FallbackOrchestrator(db, redis_client)
+        self.orchestrator = FallbackOrchestrator(db, redis_client)  # type: ignore[arg-type]
         self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
 
     def elapsed_ms(self) -> int:
@@ -347,7 +349,8 @@ class BaseMessageHandler:
     def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
         """根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
         if provider_type:
-            return resolve_api_format(provider_type, default=APIFormat.OPENAI)
+            result = resolve_api_format(provider_type, default=APIFormat.OPENAI)
+            return result or APIFormat.OPENAI
         return self.primary_api_format
 
     def build_provider_payload(
@@ -361,3 +364,34 @@ class BaseMessageHandler:
         if mapped_model:
             payload["model"] = mapped_model
         return payload
+
+    def _update_usage_to_streaming(self, request_id: Optional[str] = None) -> None:
+        """更新 Usage 状态为 streaming(流式传输开始时调用)
+
+        使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
+
+        Args:
+            request_id: 请求 ID,如果不传则使用 self.request_id
+        """
+        import asyncio
+        from src.database.database import get_db
+
+        target_request_id = request_id or self.request_id
+
+        async def _do_update() -> None:
+            try:
+                db_gen = get_db()
+                db = next(db_gen)
+                try:
+                    UsageService.update_usage_status(
+                        db=db,
+                        request_id=target_request_id,
+                        status="streaming",
+                    )
+                finally:
+                    db.close()
+            except Exception as e:
+                logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
+
+        # 创建后台任务,不阻塞当前流
+        asyncio.create_task(_do_update())
diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py
index 5062964..4490d82 100644
--- a/src/api/handlers/base/chat_handler_base.py
+++ b/src/api/handlers/base/chat_handler_base.py
@@ -116,7 +116,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
     # ==================== 抽象方法 ====================
 
     @abstractmethod
-    async def _convert_request(self, request):
+    async def _convert_request(self, request: Any) -> Any:
         """
         将请求转换为目标格式
 
@@ -261,7 +261,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         mapping = await mapper.get_mapping(source_model, provider_id)
 
         if mapping and mapping.model:
-            mapped_name = mapping.model.provider_model_name
+            mapped_name = str(mapping.model.provider_model_name)
             logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
             return mapped_name
 
@@ -271,10 +271,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
 
     async def process_stream(
         self,
-        request,
+        request: Any,
         http_request: Request,
-        original_headers: Dict,
-        original_request_body: Dict,
+        original_headers: Dict[str, Any],
+        original_request_body: Dict[str, Any],
         query_params: Optional[Dict[str, str]] = None,
     ) -> StreamingResponse:
         """处理流式响应"""
@@ -315,7 +315,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             provider: Provider,
             endpoint: ProviderEndpoint,
             key: ProviderAPIKey,
-        ):
+        ) -> AsyncGenerator[bytes, None]:
             return await self._execute_stream_request(
                 ctx,
                 provider,
@@ -401,16 +401,16 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         ctx["cached_tokens"] = 0
         ctx["cache_creation_tokens"] = 0
 
-        ctx["provider_name"] = provider.name
-        ctx["provider_id"] = provider.id
-        ctx["endpoint_id"] = endpoint.id
-        ctx["key_id"] = key.id
-        ctx["provider_api_format"] = endpoint.api_format  # Provider 的响应格式
+        ctx["provider_name"] = str(provider.name)
+        ctx["provider_id"] = str(provider.id)
+        ctx["endpoint_id"] = str(endpoint.id)
+        ctx["key_id"] = str(key.id)
+        ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
 
         # 获取模型映射
         mapped_model = await self._get_mapped_model(
             source_model=ctx["model"],
-            provider_id=provider.id,
+            provider_id=str(provider.id),
         )
 
         # 应用模型映射到请求体
@@ -514,14 +514,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         """创建响应流生成器"""
         try:
             sse_parser = SSEEventParser()
+            streaming_status_updated = False
 
             async for line in stream_response.aiter_lines():
+                # 在第一次输出数据前更新状态为 streaming
+                if not streaming_status_updated:
+                    self._update_usage_to_streaming()
+                    streaming_status_updated = True
+
                 normalized_line = line.rstrip("\r")
                 events = sse_parser.feed_line(normalized_line)
 
                 if normalized_line == "":
                     for event in events:
-                        self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                        self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
                     yield b"\n"
                     continue
 
@@ -530,11 +536,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 yield (line + "\n").encode("utf-8")
 
                 for event in events:
-                    self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                    self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
 
             # 处理剩余事件
             for event in sse_parser.flush():
-                self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
 
         except GeneratorExit:
             raise
@@ -618,7 +624,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                                f"error_type={parsed.error_type}, "
                                f"message={parsed.error_message}")
                 raise EmbeddedErrorException(
-                    provider_name=provider.name,
+                    provider_name=str(provider.name),
                     error_code=(
                         int(parsed.error_type)
                         if parsed.error_type and parsed.error_type.isdigit()
@@ -652,6 +658,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         try:
             sse_parser = SSEEventParser()
 
+            # 在第一次输出数据前更新状态为 streaming
+            if prefetched_lines:
+                self._update_usage_to_streaming()
+
             # 先输出预读的数据
             for line in prefetched_lines:
                 normalized_line = line.rstrip("\r")
@@ -659,7 +669,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
 
                 if normalized_line == "":
                     for event in events:
-                        self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                        self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
                     yield b"\n"
                     continue
 
@@ -667,7 +677,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 yield (line + "\n").encode("utf-8")
 
                 for event in events:
-                    self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                    self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
 
             # 继续输出剩余的流数据(使用同一个迭代器)
             async for line in line_iterator:
@@ -676,7 +686,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
 
                 if normalized_line == "":
                     for event in events:
-                        self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                        self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
                     yield b"\n"
                     continue
 
@@ -684,11 +694,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 yield (line + "\n").encode("utf-8")
 
                 for event in events:
-                    self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                    self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
 
             # 处理剩余事件
             for event in sse_parser.flush():
-                self._handle_sse_event(ctx, event.get("event"), event.get("data", ""))
+                self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
 
         except GeneratorExit:
             raise
@@ -853,8 +863,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
 
                 # 根据状态码决定记录成功还是失败
                 # 499 = 客户端断开连接,503 = 服务不可用(如流中断)
-                status_code = ctx.get("status_code")
-                if status_code and status_code >= 400:
+                status_code: int = ctx.get("status_code") or 200
+                if status_code >= 400:
                     # 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
                     # 这样即使请求中断,也能记录预估成本
                     await bg_telemetry.record_failure(
@@ -955,6 +965,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 error_message=f"记录统计信息失败: {str(e)[:200]}",
             )
 
+    # _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
+
     async def _update_usage_status_on_error(
         self,
         response_time_ms: int,
@@ -991,11 +1003,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
             if usage:
-                usage.status = status
-                usage.status_code = status_code
-                usage.response_time_ms = response_time_ms
+                setattr(usage, "status", status)
+                setattr(usage, "status_code", status_code)
+                setattr(usage, "response_time_ms", response_time_ms)
                 if error_message:
-                    usage.error_message = error_message
+                    setattr(usage, "error_message", error_message)
                 db.commit()
                 logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
 
         except Exception as e:
@@ -1040,10 +1052,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
 
     async def process_sync(
        self,
-        request,
+        request: Any,
        http_request: Request,
-        original_headers: Dict,
-        original_request_body: Dict,
+        original_headers: Dict[str, Any],
+        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> JSONResponse:
        """处理非流式响应"""
@@ -1055,31 +1067,31 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         api_format = self.allowed_api_formats[0]
 
         # 用于跟踪的变量
-        provider_name = None
-        response_json = None
+        provider_name: Optional[str] = None
+        response_json: Optional[Dict[str, Any]] = None
         status_code = 200
-        response_headers = {}
-        provider_request_headers = {}
-        provider_request_body = None
-        provider_id = None  # Provider ID(用于失败记录)
-        endpoint_id = None  # Endpoint ID(用于失败记录)
-        key_id = None  # Key ID(用于失败记录)
-        mapped_model_result = None  # 映射后的目标模型名(用于 Usage 记录)
+        response_headers: Dict[str, str] = {}
+        provider_request_headers: Dict[str, str] = {}
+        provider_request_body: Optional[Dict[str, Any]] = None
+        provider_id: Optional[str] = None  # Provider ID(用于失败记录)
+        endpoint_id: Optional[str] = None  # Endpoint ID(用于失败记录)
+        key_id: Optional[str] = None  # Key ID(用于失败记录)
+        mapped_model_result: Optional[str] = None  # 映射后的目标模型名(用于 Usage 记录)
 
         async def sync_request_func(
             provider: Provider,
             endpoint: ProviderEndpoint,
             key: ProviderAPIKey,
-        ):
+        ) -> Dict[str, Any]:
             nonlocal provider_name, response_json, status_code, response_headers
             nonlocal provider_request_headers, provider_request_body, mapped_model_result
 
-            provider_name = provider.name
+            provider_name = str(provider.name)
 
             # 获取模型映射
             mapped_model = await self._get_mapped_model(
                 source_model=model,
-                provider_id=provider.id,
+                provider_id=str(provider.id),
             )
 
             # 应用模型映射
@@ -1131,7 +1143,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             elif resp.status_code == 429:
                 raise ProviderRateLimitException(
                     f"提供商速率限制: {provider.name}",
-                    provider_name=provider.name,
+                    provider_name=str(provider.name),
                     response_headers=response_headers,
                 )
             elif resp.status_code >= 500:
@@ -1142,7 +1154,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 )
 
             response_json = resp.json()
-            return response_json
+            return response_json if isinstance(response_json, dict) else {}
 
         try:
             # 解析能力需求
@@ -1170,6 +1182,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             provider_name = actual_provider_name
             response_time_ms = self.elapsed_ms()
 
+            # 确保 response_json 不为 None
+            if response_json is None:
+                response_json = {}
+
             # 规范化响应
             response_json = self._normalize_response(response_json)
 
diff --git a/src/api/handlers/base/cli_handler_base.py b/src/api/handlers/base/cli_handler_base.py
index b893cd9..6f9947e 100644
--- a/src/api/handlers/base/cli_handler_base.py
+++ b/src/api/handlers/base/cli_handler_base.py
@@ -207,7 +207,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
         logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
 
         if mapping and mapping.model:
-            mapped_name = mapping.model.provider_model_name
+            mapped_name = str(mapping.model.provider_model_name)
             logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
             return mapped_name
 
@@ -351,7 +351,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
             provider: Provider,
             endpoint: ProviderEndpoint,
             key: ProviderAPIKey,
-        ):
+        ) -> AsyncGenerator[bytes, None]:
             return await self._execute_stream_request(
                 ctx,
                 provider,
@@ -451,19 +451,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
         ctx.response_metadata = {}  # 重置 Provider 响应元数据
 
         # 记录 Provider 信息
-        ctx.provider_name = provider.name
-        ctx.provider_id = provider.id
-        ctx.endpoint_id = endpoint.id
-        ctx.key_id = key.id
+        ctx.provider_name = str(provider.name)
+        ctx.provider_id = str(provider.id)
+        ctx.endpoint_id = str(endpoint.id)
+        ctx.key_id = str(key.id)
 
         # 记录格式转换信息
-        ctx.provider_api_format = endpoint.api_format if endpoint.api_format else ""
+        ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
         ctx.client_api_format = ctx.api_format  # 已在 process_stream 中设置
 
         # 获取模型映射(别名/映射 → 实际模型名)
         mapped_model = await self._get_mapped_model(
             source_model=ctx.model,
-            provider_id=provider.id,
+            provider_id=str(provider.id),
         )
 
         # 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
@@ -575,11 +575,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
         try:
             sse_parser = SSEEventParser()
             last_data_time = time.time()
+            streaming_status_updated = False
 
             # 检查是否需要格式转换
             needs_conversion = self._needs_format_conversion(ctx)
 
             async for line in stream_response.aiter_lines():
+                # 在第一次输出数据前更新状态为 streaming
+                if not streaming_status_updated:
+                    self._update_usage_to_streaming(ctx.request_id)
+                    streaming_status_updated = True
+
                 normalized_line = line.rstrip("\r")
                 events = sse_parser.feed_line(normalized_line)
 
@@ -588,7 +594,7 @@
                         self._handle_sse_event(
                             ctx,
                             event.get("event"),
-                            event.get("data", ""),
+                            event.get("data") or "",
                         )
                     yield b"\n"
                     continue
@@ -622,7 +628,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                     self._handle_sse_event(
                         ctx,
                         event.get("event"),
-                        event.get("data", ""),
+                        event.get("data") or "",
                     )
 
             if ctx.data_count > 0:
@@ -633,7 +639,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 self._handle_sse_event(
                     ctx,
                     event.get("event"),
-                    event.get("data", ""),
+                    event.get("data") or "",
                 )
 
             # 检查是否收到数据
@@ -781,7 +787,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                                f"error_type={parsed.error_type}, "
                                f"message={parsed.error_message}")
                 raise EmbeddedErrorException(
-                    provider_name=provider.name,
+                    provider_name=str(provider.name),
                     error_code=(
                         int(parsed.error_type)
                         if parsed.error_type and parsed.error_type.isdigit()
@@ -819,6 +825,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
             # 检查是否需要格式转换
             needs_conversion = self._needs_format_conversion(ctx)
 
+            # 在第一次输出数据前更新状态为 streaming
+            if prefetched_lines:
+                self._update_usage_to_streaming(ctx.request_id)
+
             # 先处理预读的数据
             for line in prefetched_lines:
                 normalized_line = line.rstrip("\r")
@@ -829,7 +839,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                         self._handle_sse_event(
                             ctx,
                             event.get("event"),
-                            event.get("data", ""),
+                            event.get("data") or "",
                         )
                     yield b"\n"
                     continue
@@ -848,7 +858,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                     self._handle_sse_event(
                         ctx,
                         event.get("event"),
-                        event.get("data", ""),
+                        event.get("data") or "",
                     )
 
             if ctx.data_count > 0:
@@ -864,7 +874,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                         self._handle_sse_event(
                             ctx,
                             event.get("event"),
-                            event.get("data", ""),
+                            event.get("data") or "",
                         )
                     yield b"\n"
                     continue
@@ -898,7 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                     self._handle_sse_event(
                         ctx,
                         event.get("event"),
-                        event.get("data", ""),
+                        event.get("data") or "",
                     )
 
             if ctx.data_count > 0:
@@ -910,7 +920,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 self._handle_sse_event(
                     ctx,
                     event.get("event"),
-                    event.get("data", ""),
+                    event.get("data") or "",
                 )
 
             # 检查是否收到数据
@@ -1270,6 +1280,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 target_model=ctx.mapped_model,
             )
 
+    # _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
+
     async def process_sync(
         self,
         original_request_body: Dict[str, Any],
@@ -1309,15 +1321,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
             provider: Provider,
             endpoint: ProviderEndpoint,
             key: ProviderAPIKey,
-        ):
+        ) -> Dict[str, Any]:
             nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
-            provider_name = provider.name
-            provider_api_format = endpoint.api_format if endpoint.api_format else ""
+            provider_name = str(provider.name)
+            provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
 
             # 获取模型映射(别名/映射 → 实际模型名)
             mapped_model = await self._get_mapped_model(
                 source_model=model,
-                provider_id=provider.id,
+                provider_id=str(provider.id),
             )
 
             # 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
@@ -1373,7 +1385,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
             elif resp.status_code == 429:
                 raise ProviderRateLimitException(
                     f"提供商速率限制: {provider.name}",
-                    provider_name=provider.name,
+                    provider_name=str(provider.name),
                     response_headers=response_headers,
                     retry_after=int(resp.headers.get("retry-after", 0)) or None,
                 )
@@ -1409,7 +1421,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
             # 提取 Provider 响应元数据(子类可覆盖)
             response_metadata_result = self._extract_response_metadata(response_json)
 
-            return response_json
+            return response_json if isinstance(response_json, dict) else {}
 
         try:
             # 解析能力需求
@@ -1437,6 +1449,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
             provider_name = actual_provider_name
             response_time_ms = int((time.time() - sync_start_time) * 1000)
 
+            # 确保 response_json 不为 None
+            if response_json is None:
+                response_json = {}
+
             # 检查是否需要格式转换
             if (
                 provider_api_format