""" 流式响应用量统计服务 处理流式响应的token计算和使用量记录 """ import json import re from typing import Any, AsyncIterator, Dict, Optional, Tuple from sqlalchemy.orm import Session from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.response_parser import StreamStats from src.core.exceptions import EmptyStreamException from src.core.logger import logger from src.database.database import create_session from src.models.database import ApiKey, User from src.services.usage.service import UsageService class StreamUsageTracker: """流式响应用量跟踪器""" def __init__( self, db: Session, user: User, api_key: ApiKey, provider: str, model: str, request_headers: Optional[Dict[str, Any]] = None, provider_request_headers: Optional[Dict[str, Any]] = None, request_id: Optional[str] = None, start_time: Optional[float] = None, attempt_id: Optional[str] = None, # Provider 侧追踪信息(用于记录真实成本) provider_id: Optional[str] = None, provider_endpoint_id: Optional[str] = None, provider_api_key_id: Optional[str] = None, # API 格式(用于选择正确的响应解析器) api_format: Optional[str] = None, ): """ 初始化流式用量跟踪器 Args: db: 数据库会话 user: 用户对象 api_key: API密钥对象 provider: 提供商名称 model: 模型名称 request_headers: 实际的请求头 provider_request_headers: 向提供商发送的请求头 request_id: 请求ID(用于日志关联) start_time: 请求开始时间(用于计算总响应时间) attempt_id: RequestTrace 请求尝试ID provider_id: Provider ID(用于记录真实成本) provider_endpoint_id: Endpoint ID(用于记录真实成本) provider_api_key_id: API Key ID(用于记录真实成本) api_format: API 格式(CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI) """ self.db = db # 只存储ID,避免会话绑定问题 self.user_id = user.id if user else None self.api_key_id = api_key.id if api_key else None self.provider = provider self.model = model self.request_headers = request_headers or {} self.provider_request_headers = provider_request_headers or {} self.request_id = request_id self.request_start_time = start_time # Provider 侧追踪信息 self.provider_id = provider_id self.provider_endpoint_id = provider_endpoint_id self.provider_api_key_id = provider_api_key_id # API 格式和响应解析器 self.api_format = api_format or "CLAUDE" self.response_parser = get_parser_for_format(self.api_format) self.stream_stats = StreamStats() # 解析器统计信息 # Token计数 self.input_tokens = 0 self.output_tokens = 0 self.cache_creation_input_tokens = 0 self.cache_read_input_tokens = 0 self.accumulated_content = "" # 完整响应跟踪(仅用于内部统计,不记录到数据库) self.complete_response = { "id": None, "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "stop_sequence": None, "usage": {}, } self.response_chunks = [] # 保存所有原始响应块 self.raw_chunks = [] # 保存所有原始字节流(用于错误诊断) # 时间跟踪 self.start_time = None self.end_time = None # 响应头 (将在track_stream中设置) self.response_headers = {} # SSE解析缓冲区 self.buffer = b"" # 用于处理不完整的字节流 self.current_line = "" # 用于累积SSE行 self.sse_event_buffer = { "event": None, "data": [], "id": None, "retry": None, } # SSE事件缓冲 # 错误状态跟踪 self.status_code = 200 # 默认成功状态码 self.error_message = None # 错误消息(如果有) self.attempt_id = attempt_id def set_error_status(self, status_code: int, error_message: str): """ 设置错误状态 Args: status_code: HTTP状态码 error_message: 错误消息 """ self.status_code = status_code self.error_message = error_message logger.debug(f"ID:{self.request_id} | 流式响应错误状态已设置 | 状态码:{status_code} | 错误:{error_message[:100]}") def _update_complete_response(self, chunk: Dict[str, Any]): """根据响应块更新完整响应结构""" try: # 更新响应ID if chunk.get("id"): self.complete_response["id"] = chunk["id"] # 更新模型 if chunk.get("model"): self.complete_response["model"] = chunk["model"] # 处理不同类型的事件 event_type = chunk.get("type") if event_type == "message_start": # 
                message = chunk.get("message", {})
                if message.get("id"):
                    self.complete_response["id"] = message["id"]
                if message.get("model"):
                    self.complete_response["model"] = message["model"]
                self.complete_response["usage"] = message.get("usage", {})

            elif event_type == "content_block_start":
                # Content block start
                content_block = chunk.get("content_block", {})
                self.complete_response["content"].append(content_block)

            elif event_type == "content_block_delta":
                # Incremental content-block update
                index = chunk.get("index", 0)
                delta = chunk.get("delta", {})

                # Make sure the content list is long enough
                while len(self.complete_response["content"]) <= index:
                    self.complete_response["content"].append({"type": "text", "text": ""})

                current_block = self.complete_response["content"][index]

                if delta.get("type") == "text_delta":
                    # Text delta
                    if current_block.get("type") == "text":
                        current_block["text"] = current_block.get("text", "") + delta.get(
                            "text", ""
                        )
                elif delta.get("type") == "input_json_delta":
                    # Tool-call input delta
                    if current_block.get("type") == "tool_use":
                        current_input = current_block.get("input", {})
                        if isinstance(current_input, str):
                            current_input += delta.get("partial_json", "")
                            current_block["input"] = current_input

            elif event_type == "content_block_stop":
                # Content block end
                pass

            elif event_type == "message_delta":
                # Message-level delta update
                delta = chunk.get("delta", {})
                if delta.get("stop_reason"):
                    self.complete_response["stop_reason"] = delta["stop_reason"]
                if delta.get("stop_sequence"):
                    self.complete_response["stop_sequence"] = delta["stop_sequence"]

            elif event_type == "message_stop":
                # Message end
                pass

            # Update usage info
            if chunk.get("usage"):
                self.complete_response["usage"].update(chunk["usage"])

        except Exception as e:
            # Log the error but do not interrupt stream processing
            logger.warning(f"Failed to update complete response: {e}")

    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_chars = 0
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_chars += len(content)
                elif isinstance(content, list):
                    for block in content:
                        if hasattr(block, "text"):
                            total_chars += len(block.text)

        # Rough estimate: about 4 characters per token
        return max(1, total_chars // 4)

    def _process_sse_event(self) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Process the complete SSE event currently in the buffer.

        Returns:
            (content text, usage info)
        """
        content = None
        usage = None

        # Nothing to do without data lines
        if not self.sse_event_buffer["data"]:
            return None, None

        # Join all data lines (per the SSE spec, data lines are joined with newlines)
        data_str = "\n".join(self.sse_event_buffer["data"])

        # Reset the buffer
        self.sse_event_buffer = {"event": None, "data": [], "id": None, "retry": None}

        if not data_str or data_str == "[DONE]":
            return None, None

        try:
            data = json.loads(data_str)
            if isinstance(data, dict):
                self.response_chunks.append(data)
                try:
                    self._update_complete_response(data)
                except Exception as update_error:
                    logger.warning(f"Failed to update complete response from chunk: {update_error}")

                # Claude format
                if "type" in data:
                    if data["type"] == "content_block_delta":
                        delta = data.get("delta", {})
                        content = delta.get("text", "")
                        if content:
                            logger.debug(f"Extracted content from delta: {len(content)} chars")
                    elif data["type"] == "message_delta":
                        usage_data = data.get("usage", {})
                        if usage_data:
                            usage = usage_data
                            logger.debug(f"Extracted usage from message_delta: {usage}")
                    elif data["type"] == "message_stop":
                        logger.debug("Received message_stop event")
                # OpenAI format
                elif "choices" in data:
                    choices = data.get("choices", [])
                    if choices:
                        delta = choices[0].get("delta", {})
                        content = delta.get("content", "")
                    if "usage" in data:
                        usage = data["usage"]
        except json.JSONDecodeError as e:
            # More detailed logging for JSON parse errors
logger.warning(f"Failed to parse SSE JSON data: {str(e)}") return None, None except Exception as e: logger.error(f"Unexpected error processing SSE event: {type(e).__name__}: {str(e)}") return content, usage def parse_sse_line(self, line: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]: """ 解析单行SSE事件(使用统一响应解析器) Args: line: SSE格式的一行 Returns: (内容文本, 使用信息) - 当遇到空行时处理完整事件 """ # 使用统一的响应解析器 chunk = self.response_parser.parse_sse_line(line, self.stream_stats) if chunk is None: return None, None # 从 ParsedChunk 中提取内容和使用信息 content = chunk.text_delta # 构建 usage 字典(如果有 token 信息) usage = None if ( chunk.input_tokens or chunk.output_tokens or chunk.cache_creation_tokens or chunk.cache_read_tokens ): usage = { "input_tokens": chunk.input_tokens or self.stream_stats.input_tokens, "output_tokens": chunk.output_tokens or self.stream_stats.output_tokens, "cache_creation_input_tokens": chunk.cache_creation_tokens or self.stream_stats.cache_creation_tokens, "cache_read_input_tokens": chunk.cache_read_tokens or self.stream_stats.cache_read_tokens, } # 更新响应 ID if chunk.response_id and not self.complete_response.get("id"): self.complete_response["id"] = chunk.response_id # 更新完整响应(如果有数据) if chunk.data: self.response_chunks.append(chunk.data) try: self._update_complete_response(chunk.data) except Exception as update_error: logger.warning(f"Failed to update complete response from chunk: {update_error}") return content, usage def parse_stream_chunk(self, chunk: bytes) -> Tuple[Optional[str], Optional[Dict[str, Any]]]: """ 解析流式响应块(处理原始字节流) Args: chunk: 原始字节流 Returns: (累积的内容文本, 使用信息) """ total_content = "" final_usage = None # 将新chunk添加到缓冲区 # 确保 chunk 是字节类型 if isinstance(chunk, str): chunk = chunk.encode("utf-8") self.buffer += chunk # 尝试解码并处理完整的行 try: # 尝试解码整个缓冲区 text = self.buffer.decode("utf-8") self.buffer = b"" # 清空缓冲区 # 将文本添加到当前行 self.current_line += text # 按换行符分割,处理完整的行 lines = self.current_line.split("\n") # 最后一个可能是不完整的行,保留它 self.current_line = lines[-1] # 处理完整的行 for line in lines[:-1]: line = line.rstrip("\r") content, usage = self.parse_sse_line(line) if content: total_content += content if usage: final_usage = usage except UnicodeDecodeError: # 如果解码失败,说明缓冲区中有不完整的UTF-8序列 # 尝试找到最后一个完整的字符边界 for i in range(len(self.buffer) - 1, max(0, len(self.buffer) - 4), -1): try: text = self.buffer[:i].decode("utf-8") # 成功解码,处理这部分 remaining = self.buffer[i:] self.buffer = remaining # 保留未解码的部分 # 处理解码的文本 self.current_line += text lines = self.current_line.split("\n") self.current_line = lines[-1] for line in lines[:-1]: line = line.rstrip("\r") content, usage = self.parse_sse_line(line) if content: total_content += content if usage: final_usage = usage break except UnicodeDecodeError: continue return total_content if total_content else None, final_usage async def track_stream( self, stream: AsyncIterator[str], request_data: Dict[str, Any], response_headers: Optional[Dict[str, Any]] = None, ) -> AsyncIterator[str]: """ 跟踪流式响应并计算用量 Args: stream: 原始流式响应 request_data: 请求数据 response_headers: 实际的响应头 Yields: 流式响应块 """ import time self.start_time = time.time() self.request_data = request_data # 保存请求数据 # 保存响应头(如果没有提供,使用空字典而不是默认值) # 这样可以确保记录的是实际的响应头,而不是构造的默认值 self.response_headers = response_headers if response_headers is not None else {} # 估算输入tokens messages = request_data.get("messages", []) self.input_tokens = self.estimate_input_tokens(messages) logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}") # 更新状态为 streaming,同时更新 provider if self.request_id: try: from src.services.usage.service import 
                UsageService.update_usage_status(
                    db=self.db,
                    request_id=self.request_id,
                    status="streaming",
                    provider=self.provider,
                )
            except Exception as e:
                logger.warning(f"Failed to update usage record status to streaming: {e}")

        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1

                # Keep the raw bytes (for error diagnostics)
                self.raw_chunks.append(chunk)

                # Forward the raw chunk to the client
                yield chunk

                # Parse the chunk for content and usage info (chunk is raw bytes)
                content, usage = self.parse_stream_chunk(chunk)

                if content:
                    self.accumulated_content += content
                    # Rough, incremental estimate of output tokens
                    self.output_tokens = max(1, len(self.accumulated_content) // 4)

                if usage:
                    # Prefer exact usage info when the response includes it
                    self.input_tokens = usage.get("input_tokens", self.input_tokens)
                    self.output_tokens = usage.get("output_tokens", self.output_tokens)
                    self.cache_creation_input_tokens = usage.get(
                        "cache_creation_input_tokens", self.cache_creation_input_tokens
                    )
                    self.cache_read_input_tokens = usage.get(
                        "cache_read_input_tokens", self.cache_read_input_tokens
                    )

                    # Handle the newer cache_creation format
                    if "cache_creation" in usage:
                        cache_creation_data = usage.get("cache_creation", {})
                        # Fall back to cache_creation when cache_creation_input_tokens is absent
                        if not self.cache_creation_input_tokens:
                            self.cache_creation_input_tokens = cache_creation_data.get(
                                "ephemeral_5m_input_tokens", 0
                            ) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
        finally:
            # Record usage once the stream ends
            self.end_time = time.time()

            logger.debug(
                f"ID:{self.request_id} | Stream finished | {chunk_count} chunks processed | "
                f"Accumulated content length:{len(self.accumulated_content)} | "
                f"Output tokens:{self.output_tokens}"
            )

            # Check whether any valid data arrived.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = (
                    f"Stream finished without any parseable data "
                    f"(received {chunk_count} raw chunks but none could be parsed)"
                )
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Flag the error so this is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so the FallbackOrchestrator can catch it and trigger failover
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )

            # Case 2: the stream finished without a complete message (no message_stop event).
            # This usually happens after a server restart or a dropped connection.
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = (
                    "Stream interrupted: no valid data received "
                    "(possibly a dropped connection or a service restart)"
                )
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)

            # Make sure a summary log is always emitted, even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # If recording fails, at least emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Emit a basic summary log, with layered fallbacks
                try:
                    # Compute the response time with multiple fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0

                    # Emit the summary log safely
                    logger.info(
                        f"[Request completed] ID:{self.request_id or 'unknown'} | 200 | "
                        f"Duration:{total_response_time}ms | "
                        f"Tokens:in {self.input_tokens}/out {self.output_tokens} | "
                        f"Cost:unknown (recording failed)"
                    )
                except Exception as log_error:
                    # Last line of defense: emit the simplest possible completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(
                            f"[Request completed] ID:{self.request_id or 'unknown'} | "
                            f"recording failed but stream completed"
                        )
                    except Exception:
                        # Even the simplest log failed; give up
                        pass

    async def _record_usage(self):
        """Record the final usage."""
        try:
            if self.request_start_time and self.end_time:
                response_time_ms = int((self.end_time - self.request_start_time) * 1000)
            elif self.start_time and self.end_time:
                response_time_ms = int((self.end_time - self.start_time) * 1000)
            else:
                response_time_ms = None

            # Fall back to an estimate when no exact token count is available
            if self.output_tokens == 0 and self.accumulated_content:
                self.output_tokens = max(1, len(self.accumulated_content) // 4)
            # Use the complete response body (includes everything, including tool calls).
            # Update the final usage info
            self.complete_response["usage"].update(
                {
                    "input_tokens": self.input_tokens,
                    "output_tokens": self.output_tokens,
                    "cache_creation_input_tokens": self.cache_creation_input_tokens,
                    "cache_read_input_tokens": self.cache_read_input_tokens,
                }
            )

            # Record the response data.
            # If SSE chunks parsed successfully, use the structured data;
            # otherwise keep the raw bytes (for diagnosing errors such as 403 HTML responses)
            if self.response_chunks:
                # Normal case: the SSE JSON response parsed successfully
                response_body = {
                    "chunks": self.response_chunks,
                    "metadata": {
                        "stream": True,
                        "total_chunks": len(self.response_chunks),
                        "content_length": len(self.accumulated_content),
                        "response_time_ms": response_time_ms,
                    },
                }
            else:
                # Error case: the response is not JSON (e.g. an HTML error page).
                # Try to decode the raw bytes as text.
                raw_response_text = ""
                for chunk in self.raw_chunks:
                    try:
                        if isinstance(chunk, bytes):
                            raw_response_text += chunk.decode("utf-8", errors="replace")
                        else:
                            raw_response_text += str(chunk)
                    except Exception:
                        pass

                response_body = {
                    "chunks": [],
                    "raw_response": raw_response_text[:10000],  # Cap the size
                    "metadata": {
                        "stream": True,
                        "total_chunks": 0,
                        "raw_chunks_count": len(self.raw_chunks),
                        "content_length": len(raw_response_text),
                        "response_time_ms": response_time_ms,
                        "parse_error": "Failed to parse response as SSE JSON format",
                    },
                }

            # Check the session state; if it is unusable, roll back and start a new transaction
            from sqlalchemy.exc import InvalidRequestError

            user = None
            api_key = None
            db_for_usage = self.db
            created_temp_session = False

            def _load_user_and_key(db_session: Session) -> Tuple[Optional[User], Optional[ApiKey]]:
                local_user = (
                    db_session.query(User).filter(User.id == self.user_id).first()
                    if self.user_id
                    else None
                )
                local_api_key = (
                    db_session.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
                    if self.api_key_id
                    else None
                )
                return local_user, local_api_key

            try:
                # Check that the session is usable
                db_for_usage.info
                # Re-query the user and API key so they are attached to the session
                user, api_key = _load_user_and_key(db_for_usage)
            except InvalidRequestError:
                # The session is in an unusable state; roll back and start over
                logger.warning(
                    f"Session in invalid state for request {self.request_id}, "
                    f"rolling back and retrying"
                )
                try:
                    db_for_usage.rollback()
                except Exception:
                    pass
                try:
                    db_for_usage.close()
                except Exception:
                    pass
                # Record usage on a fresh session so the prepared state cannot poison queries
                try:
                    db_for_usage = create_session()
                    created_temp_session = True
                    user, api_key = _load_user_and_key(db_for_usage)
                except Exception as session_error:
                    logger.exception(
                        f"Failed to recover from invalid session for request "
                        f"{self.request_id}: {session_error}"
                    )
                    return

            # Derive the request status from the status code
            final_status = "completed" if self.status_code == 200 else "failed"

            usage_record = await UsageService.record_usage_async(
                db=db_for_usage,
                user=user,
                api_key=api_key,
                provider=self.provider,
                model=self.model,
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                cache_creation_input_tokens=self.cache_creation_input_tokens,
                cache_read_input_tokens=self.cache_read_input_tokens,
                request_type="chat",
                api_format=self.api_format,
                is_stream=True,
                response_time_ms=response_time_ms,
                status_code=self.status_code,  # The actual status code
                error_message=self.error_message,  # The actual error message
                metadata={"stream": True, "content_length": len(self.accumulated_content)},
                request_body=self.request_data if hasattr(self, "request_data") else None,
                request_headers=self.request_headers,
                provider_request_headers=self.provider_request_headers,
                response_headers=self.response_headers,
                response_body=response_body,
                request_id=self.request_id,  # Pass through the request_id
                # Provider-side tracking info (used to record actual cost)
                provider_id=self.provider_id,
                provider_endpoint_id=self.provider_endpoint_id,
                provider_api_key_id=self.provider_api_key_id,
                # Request status
                status=final_status,
            )

            # Read total_cost_usd immediately, before the object is detached from the session
            total_cost = 0.0
            if usage_record:
                try:
                    # Grab the attributes we need while usage_record is still in the session
                    total_cost = usage_record.total_cost_usd or 0.0
                except Exception as e:
                    logger.warning(f"Failed to access total_cost_usd from usage_record: {e}")
                    total_cost = 0.0

            if db_for_usage and self.attempt_id:
                # RequestTrace has been removed; tracking now lives in the RequestCandidate
                # table, and the status update is handled by RequestCandidateService
                pass

            # Total response time, measured from request start to stream end
            if self.request_start_time:
                total_response_time = int((self.end_time - self.request_start_time) * 1000)
            else:
                total_response_time = response_time_ms

            # Emit a summary log (mirrors the completion log of non-streaming requests).
            # Pick the prefix and level based on the status code
            status_prefix = "[Request completed]" if self.status_code == 200 else "[Request failed]"

            # Choose a cost format appropriate for its magnitude
            if total_cost >= 0.01:
                cost_str = f"${total_cost:.4f}"
            elif total_cost > 0:
                cost_str = f"${total_cost:.6f}"
            else:
                cost_str = "$0"

            logger.info(
                f"{status_prefix} ID:{self.request_id} | {self.status_code} | "
                f"Duration:{total_response_time}ms | "
                f"Tokens:in {self.input_tokens}/out {self.output_tokens} | Cost:{cost_str}"
            )

            # Provider results for dynamic weight adjustment are recorded automatically
            # by the FallbackOrchestrator's health monitoring; no manual recording needed here

        except Exception as e:
            logger.exception(f"Failed to record stream usage: {e}")
        finally:
            if created_temp_session:
                try:
                    db_for_usage.close()
                except Exception:
                    pass


class EnhancedStreamUsageTracker(StreamUsageTracker):
    """
    Enhanced stream usage tracker.

    Supports more accurate token counting.
    """

    def __init__(
        self,
        db: Session,
        user: User,
        api_key: ApiKey,
        provider: str,
        model: str,
        request_headers: Optional[Dict[str, Any]] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
        start_time: Optional[float] = None,
        attempt_id: Optional[str] = None,
        # Provider-side tracking info (used to record actual cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        # API format (used to pick the right response parser)
        api_format: Optional[str] = None,
    ):
        super().__init__(
            db,
            user,
            api_key,
            provider,
            model,
            request_headers,
            provider_request_headers,
            request_id,
            start_time,
            attempt_id,
            provider_id,
            provider_endpoint_id,
            provider_api_key_id,
            api_format,
        )
        # Set up more accurate token counting
        self._init_tokenizer()
        # The SSE parsing buffers are inherited from the parent class,
        # which already initializes them

    def _init_tokenizer(self):
        """Initialize the tokenizer, if available."""
        try:
            # Try tiktoken for more accurate token counting
            import tiktoken

            # Pick an encoding appropriate for the model
            if "gpt-4" in self.model or "gpt-3.5" in self.model:
                self.tokenizer = tiktoken.get_encoding("cl100k_base")
            else:
                # Claude and other models use the approximation instead
                self.tokenizer = None
        except ImportError:
            logger.debug("tiktoken not available, using estimation")
            self.tokenizer = None

    def count_tokens(self, text: str) -> int:
        """
        Count tokens more accurately.

        Args:
            text: Text content

        Returns:
            Token count
        """
        if self.tokenizer:
            try:
                return len(self.tokenizer.encode(text))
            except Exception as e:
                logger.warning(f"Token encoding failed: {e}")

        # Fallback estimate:
        # a Chinese character is roughly 2 tokens, an English word roughly 1.3
        chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
        english_words = len(re.findall(r"\b\w+\b", text))
        estimated_tokens = chinese_chars * 2 + english_words * 1.3
        return max(1, int(estimated_tokens))

    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens more accurately.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_text = ""
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_text += content + " "
                elif isinstance(content, list):
                    for block in content:
                        if hasattr(block, "text"):
                            total_text += block.text + " "

        return self.count_tokens(total_text)

    async def track_stream(
        self,
        stream: AsyncIterator[str],
        request_data: Dict[str, Any],
        response_headers: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[str]:
        """
        Track a streaming response with more accurate usage accounting.

        Args:
            stream: The original response stream
            request_data: Request data
            response_headers: Actual response headers

        Yields:
            Stream response chunks
        """
        import time

        self.start_time = time.time()
        self.request_data = request_data  # Keep the request data

        # Save the response headers (fall back to an empty dict rather than defaults),
        # so that the recorded headers are the real ones, not fabricated values
        self.response_headers = response_headers if response_headers is not None else {}

        # Estimate input tokens more accurately
        messages = request_data.get("messages", [])
        self.input_tokens = self.estimate_input_tokens(messages)

        logger.debug(
            f"ID:{self.request_id} | Started tracking stream (Enhanced) | "
            f"Estimated input tokens:{self.input_tokens}"
        )

        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1

                # Keep the raw bytes (for error diagnostics)
                self.raw_chunks.append(chunk)

                # Forward the raw chunk to the client
                yield chunk

                # Parse the chunk for content and usage info (chunk is raw bytes)
                content, usage = self.parse_stream_chunk(chunk)

                if content:
                    self.accumulated_content += content
                    # Count output tokens with the more accurate method
                    self.output_tokens = self.count_tokens(self.accumulated_content)

                if usage:
                    # Prefer exact usage info when the response includes it
                    if "input_tokens" in usage:
                        self.input_tokens = usage["input_tokens"]
                    if "output_tokens" in usage:
                        self.output_tokens = usage["output_tokens"]
                    if "cache_creation_input_tokens" in usage:
                        self.cache_creation_input_tokens = usage["cache_creation_input_tokens"]
                    if "cache_read_input_tokens" in usage:
                        self.cache_read_input_tokens = usage["cache_read_input_tokens"]
        finally:
            # Record usage once the stream ends
            self.end_time = time.time()

            logger.debug(
                f"ID:{self.request_id} | Stream finished | {chunk_count} chunks processed | "
                f"Accumulated content length:{len(self.accumulated_content)} | "
                f"Output tokens:{self.output_tokens}"
            )

            # Check whether any valid data arrived.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = (
                    f"Stream finished without any parseable data "
                    f"(received {chunk_count} raw chunks but none could be parsed)"
                )
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Flag the error so this is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so the FallbackOrchestrator can catch it and trigger failover
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )

            # Case 2: the stream finished without a complete message (no message_stop event).
            # This usually happens after a server restart or a dropped connection.
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = (
                    "Stream interrupted: no valid data received "
                    "(possibly a dropped connection or a service restart)"
                )
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)

            # Make sure a summary log is always emitted, even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # If recording fails, at least emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Emit a basic summary log, with layered fallbacks
                try:
                    # Compute the response time with multiple fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0

                    # Emit the summary log safely
                    logger.info(
                        f"[Request completed] ID:{self.request_id or 'unknown'} | 200 | "
                        f"Duration:{total_response_time}ms | "
                        f"Tokens:in {self.input_tokens}/out {self.output_tokens} | "
                        f"Cost:unknown (recording failed)"
                    )
                except Exception as log_error:
                    # Last line of defense: emit the simplest possible completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(
                            f"[Request completed] ID:{self.request_id or 'unknown'} | "
                            f"recording failed but stream completed"
                        )
                    except Exception:
                        # Even the simplest log failed; give up
                        pass


# Convenience factory
def create_stream_tracker(
    db: Session,
    user: User,
    api_key: ApiKey,
    provider: str,
    model: str,
    enhanced: bool = True,
    request_headers: Optional[Dict[str, Any]] = None,
    provider_request_headers: Optional[Dict[str, Any]] = None,
    request_id: Optional[str] = None,
    start_time: Optional[float] = None,
    attempt_id: Optional[str] = None,
    # Provider-side tracking info (used to record actual cost)
    provider_id: Optional[str] = None,
    provider_endpoint_id: Optional[str] = None,
    provider_api_key_id: Optional[str] = None,
    # API format (used to pick the right response parser)
    api_format: Optional[str] = None,
) -> StreamUsageTracker:
    """
    Create a stream usage tracker.

    Args:
        db: Database session
        user: User object
        api_key: API key object
        provider: Provider name
        model: Model name
        enhanced: Whether to use the enhanced tracker
        request_headers: Actual request headers
        provider_request_headers: Headers sent to the provider
        request_id: Request ID (for log correlation)
        start_time: Request start time (for computing total response time)
        attempt_id: RequestTrace attempt ID (optional)
        provider_id: Provider ID (used to record actual cost)
        provider_endpoint_id: Endpoint ID (used to record actual cost)
        provider_api_key_id: API key ID (used to record actual cost)
        api_format: API format (CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI)

    Returns:
        A stream usage tracker instance
    """
    if enhanced:
        return EnhancedStreamUsageTracker(
            db,
            user,
            api_key,
            provider,
            model,
            request_headers,
            provider_request_headers,
            request_id,
            start_time,
            attempt_id,
            provider_id,
            provider_endpoint_id,
            provider_api_key_id,
            api_format,
        )
    else:
        return StreamUsageTracker(
            db,
            user,
            api_key,
            provider,
            model,
            request_headers,
            provider_request_headers,
            request_id,
            start_time,
            attempt_id,
            provider_id,
            provider_endpoint_id,
            provider_api_key_id,
            api_format,
        )
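
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not wired into any route): a minimal example
# of how a streaming handler might wrap an upstream provider stream with
# create_stream_tracker. The tracker re-yields each chunk to the client while
# accumulating token counts, then records usage when the stream ends. The
# function name, the `provider_stream` argument, and the literal provider and
# request-ID strings below are hypothetical stand-ins for objects the real
# handlers obtain elsewhere; only create_stream_tracker / track_stream come
# from this module.
# ---------------------------------------------------------------------------
def example_stream_handler_sketch(
    db: Session,
    user: User,
    api_key: ApiKey,
    request_data: Dict[str, Any],
    provider_stream: AsyncIterator[bytes],
):
    """Illustrative sketch, assuming a FastAPI app; not part of the service."""
    import time

    from fastapi.responses import StreamingResponse  # assumed available in this app

    tracker = create_stream_tracker(
        db=db,
        user=user,
        api_key=api_key,
        provider="anthropic",  # hypothetical provider name
        model=request_data.get("model", "unknown"),
        enhanced=True,  # use tiktoken-backed counting when available
        request_id="req-example",  # hypothetical request ID
        start_time=time.time(),
        api_format="CLAUDE",
    )
    # track_stream transparently forwards each chunk while tracking usage;
    # it raises EmptyStreamException when nothing parseable arrives, which an
    # upstream failover layer can catch.
    return StreamingResponse(
        tracker.track_stream(provider_stream, request_data),
        media_type="text/event-stream",
    )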