"""
流式响应用量统计服务
处理流式响应的token计算和使用量记录
"""
import json
import re
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from sqlalchemy.orm import Session
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import StreamStats
from src.core.exceptions import EmptyStreamException
from src.core.logger import logger
from src.database.database import create_session
from src.models.database import ApiKey, User
from src.services.usage.service import UsageService


class StreamUsageTracker:
    """Streaming-response usage tracker."""

def __init__(
self,
db: Session,
user: User,
api_key: ApiKey,
provider: str,
model: str,
request_headers: Optional[Dict[str, Any]] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
request_id: Optional[str] = None,
start_time: Optional[float] = None,
attempt_id: Optional[str] = None,
        # Provider-side tracking info (used to record actual cost)
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
        # API format (used to select the correct response parser)
api_format: Optional[str] = None,
):
"""
初始化流式用量跟踪器
Args:
db: 数据库会话
user: 用户对象
api_key: API密钥对象
provider: 提供商名称
model: 模型名称
request_headers: 实际的请求头
provider_request_headers: 向提供商发送的请求头
request_id: 请求ID用于日志关联
start_time: 请求开始时间用于计算总响应时间
attempt_id: RequestTrace 请求尝试ID
provider_id: Provider ID用于记录真实成本
provider_endpoint_id: Endpoint ID用于记录真实成本
provider_api_key_id: API Key ID用于记录真实成本
api_format: API 格式CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI
"""
        self.db = db
        # Store only IDs to avoid session-binding issues
        self.user_id = user.id if user else None
        self.api_key_id = api_key.id if api_key else None
        self.provider = provider
        self.model = model
        self.request_headers = request_headers or {}
        self.provider_request_headers = provider_request_headers or {}
        self.request_id = request_id
        self.request_start_time = start_time
        # Provider-side tracking info
        self.provider_id = provider_id
        self.provider_endpoint_id = provider_endpoint_id
        self.provider_api_key_id = provider_api_key_id
        # API format and response parser
        self.api_format = api_format or "CLAUDE"
        self.response_parser = get_parser_for_format(self.api_format)
        self.stream_stats = StreamStats()  # parser statistics
        # Token counts
        self.input_tokens = 0
        self.output_tokens = 0
        self.cache_creation_input_tokens = 0
        self.cache_read_input_tokens = 0
        self.accumulated_content = ""
        # Full response tracking (internal statistics only; not written to the database)
        self.complete_response = {
            "id": None,
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": model,
            "stop_reason": None,
            "stop_sequence": None,
            "usage": {},
        }
        self.response_chunks = []  # all parsed response chunks
        self.raw_chunks = []  # all raw byte chunks (kept for error diagnosis)
        # Timing
        self.start_time = None
        self.end_time = None
        # Response headers (set in track_stream)
        self.response_headers = {}
        # SSE parsing buffers
        self.buffer = b""  # holds incomplete byte sequences
        self.current_line = ""  # accumulates partial SSE lines
        self.sse_event_buffer = {
            "event": None,
            "data": [],
            "id": None,
            "retry": None,
        }  # buffered SSE event fields
        # Error-state tracking
        self.status_code = 200  # default to success
        self.error_message = None  # error message, if any
        self.attempt_id = attempt_id

    def set_error_status(self, status_code: int, error_message: str):
        """
        Set the error state for this stream.

        Args:
            status_code: HTTP status code
            error_message: Error message
        """
        self.status_code = status_code
        self.error_message = error_message
        logger.debug(
            f"ID:{self.request_id} | Stream error status set | "
            f"status:{status_code} | error:{error_message[:100]}"
        )

    def _update_complete_response(self, chunk: Dict[str, Any]):
        """Update the accumulated full-response structure from a parsed chunk."""
        try:
            # Update the response ID
            if chunk.get("id"):
                self.complete_response["id"] = chunk["id"]
            # Update the model
            if chunk.get("model"):
                self.complete_response["model"] = chunk["model"]
            # Handle the different event types
            event_type = chunk.get("type")
            if event_type == "message_start":
                # Message start event
                message = chunk.get("message", {})
                if message.get("id"):
                    self.complete_response["id"] = message["id"]
                if message.get("model"):
                    self.complete_response["model"] = message["model"]
                self.complete_response["usage"] = message.get("usage", {})
            elif event_type == "content_block_start":
                # Content block start
                content_block = chunk.get("content_block", {})
                self.complete_response["content"].append(content_block)
            elif event_type == "content_block_delta":
                # Incremental content-block update
                index = chunk.get("index", 0)
                delta = chunk.get("delta", {})
                # Make sure the content list is long enough
                while len(self.complete_response["content"]) <= index:
                    self.complete_response["content"].append({"type": "text", "text": ""})
                current_block = self.complete_response["content"][index]
                if delta.get("type") == "text_delta":
                    # Text delta
                    if current_block.get("type") == "text":
                        current_block["text"] = current_block.get("text", "") + delta.get(
                            "text", ""
                        )
                elif delta.get("type") == "input_json_delta":
                    # Tool-call input arrives as partial JSON strings; accumulate
                    # them as text (the initial block may carry "input": {}, so
                    # start from an empty string in that case)
                    if current_block.get("type") == "tool_use":
                        current_input = current_block.get("input", "")
                        if not isinstance(current_input, str):
                            current_input = ""
                        current_block["input"] = current_input + delta.get("partial_json", "")
            elif event_type == "content_block_stop":
                # Content block finished
                pass
            elif event_type == "message_delta":
                # Message-level incremental update
                delta = chunk.get("delta", {})
                if delta.get("stop_reason"):
                    self.complete_response["stop_reason"] = delta["stop_reason"]
                if delta.get("stop_sequence"):
                    self.complete_response["stop_sequence"] = delta["stop_sequence"]
            elif event_type == "message_stop":
                # Message finished
                pass
            # Merge usage information
            if chunk.get("usage"):
                self.complete_response["usage"].update(chunk["usage"])
        except Exception as e:
            # Log but do not interrupt stream processing
            logger.warning(f"Failed to update complete response: {e}")

    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_chars = 0
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_chars += len(content)
                elif isinstance(content, list):
                    for block in content:
                        if hasattr(block, "text"):
                            total_chars += len(block.text)
        # Rough estimate: 4 characters per token
        return max(1, total_chars // 4)
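    # Illustrative example of the heuristic above (not part of the runtime path):
    #   estimate_input_tokens([{"role": "user", "content": "hello world"}])
    # counts 11 characters, so it returns 11 // 4 = 2 tokens.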

    def _process_sse_event(self) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Process a complete SSE event from the buffer.

        Returns:
            (content text, usage info)
        """
        content = None
        usage = None
        # Nothing to do without any data lines
        if not self.sse_event_buffer["data"]:
            return None, None
        # Join all data lines; per the SSE spec, data lines are joined with newlines
        data_str = "\n".join(self.sse_event_buffer["data"])
        # Reset the buffer
        self.sse_event_buffer = {"event": None, "data": [], "id": None, "retry": None}
        if not data_str or data_str == "[DONE]":
            return None, None
        try:
            data = json.loads(data_str)
            if isinstance(data, dict):
                self.response_chunks.append(data)
                try:
                    self._update_complete_response(data)
                except Exception as update_error:
                    logger.warning(f"Failed to update complete response from chunk: {update_error}")
                # Claude format
                if "type" in data:
                    if data["type"] == "content_block_delta":
                        delta = data.get("delta", {})
                        content = delta.get("text", "")
                        if content:
                            logger.debug(f"Extracted content from delta: {len(content)} chars")
                    elif data["type"] == "message_delta":
                        usage_data = data.get("usage", {})
                        if usage_data:
                            usage = usage_data
                            logger.debug(f"Extracted usage from message_delta: {usage}")
                    elif data["type"] == "message_stop":
                        logger.debug("Received message_stop event")
                # OpenAI format
                elif "choices" in data:
                    choices = data.get("choices", [])
                    if choices:
                        delta = choices[0].get("delta", {})
                        content = delta.get("content", "")
                    if "usage" in data:
                        usage = data["usage"]
        except json.JSONDecodeError as e:
            # More detailed logging for JSON parse failures
            logger.warning(f"Failed to parse SSE JSON data: {str(e)}")
            return None, None
        except Exception as e:
            logger.error(f"Unexpected error processing SSE event: {type(e).__name__}: {str(e)}")
        return content, usage
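    # Illustrative sketch of what this method consumes: by the SSE spec an
    # event's data lines are buffered and joined with "\n", so
    #   data: {"type": "message_delta",
    #   data:  "usage": {"output_tokens": 42}}
    # parses as one JSON object, while a bare `data: [DONE]` ends the stream
    # without yielding content or usage.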

    def parse_sse_line(self, line: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Parse a single SSE line using the unified response parser.

        Args:
            line: One line in SSE format

        Returns:
            (content text, usage info) - the full event is processed once a blank line arrives
        """
        # Delegate to the unified response parser
        chunk = self.response_parser.parse_sse_line(line, self.stream_stats)
        if chunk is None:
            return None, None
        # Extract content and usage info from the ParsedChunk
        content = chunk.text_delta
        # Build a usage dict if any token counts are present
        usage = None
        if (
            chunk.input_tokens
            or chunk.output_tokens
            or chunk.cache_creation_tokens
            or chunk.cache_read_tokens
        ):
            usage = {
                "input_tokens": chunk.input_tokens or self.stream_stats.input_tokens,
                "output_tokens": chunk.output_tokens or self.stream_stats.output_tokens,
                "cache_creation_input_tokens": chunk.cache_creation_tokens
                or self.stream_stats.cache_creation_tokens,
                "cache_read_input_tokens": chunk.cache_read_tokens
                or self.stream_stats.cache_read_tokens,
            }
        # Record the response ID
        if chunk.response_id and not self.complete_response.get("id"):
            self.complete_response["id"] = chunk.response_id
        # Update the accumulated full response (when the chunk carries data)
        if chunk.data:
            self.response_chunks.append(chunk.data)
            try:
                self._update_complete_response(chunk.data)
            except Exception as update_error:
                logger.warning(f"Failed to update complete response from chunk: {update_error}")
        return content, usage
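    # Illustrative example (Claude format, assuming the default parser): feeding
    #   data: {"type": "content_block_delta", "index": 0,
    #          "delta": {"type": "text_delta", "text": "Hi"}}
    # as one line yields content == "Hi" with usage None; token counts only
    # appear once the provider emits usage (e.g. in a message_delta event).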

    def parse_stream_chunk(self, chunk: bytes) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Parse a streaming response chunk (raw bytes).

        Args:
            chunk: Raw byte chunk

        Returns:
            (accumulated content text, usage info)
        """
        total_content = ""
        final_usage = None
        # Append the new chunk to the buffer, coercing str to bytes
        if isinstance(chunk, str):
            chunk = chunk.encode("utf-8")
        self.buffer += chunk
        try:
            # Try to decode the whole buffer
            text = self.buffer.decode("utf-8")
            self.buffer = b""  # buffer fully consumed
            # Append the text to the current line
            self.current_line += text
            # Split on newlines to isolate the complete lines
            lines = self.current_line.split("\n")
            # The last element may be an incomplete line; keep it
            self.current_line = lines[-1]
            # Process the complete lines
            for line in lines[:-1]:
                line = line.rstrip("\r")
                content, usage = self.parse_sse_line(line)
                if content:
                    total_content += content
                if usage:
                    final_usage = usage
        except UnicodeDecodeError:
            # Decoding failed: the buffer ends with an incomplete UTF-8 sequence.
            # Walk backwards to find the last complete character boundary.
            for i in range(len(self.buffer) - 1, max(0, len(self.buffer) - 4), -1):
                try:
                    text = self.buffer[:i].decode("utf-8")
                    # Decoded successfully; keep the undecoded tail
                    remaining = self.buffer[i:]
                    self.buffer = remaining
                    # Process the decoded text
                    self.current_line += text
                    lines = self.current_line.split("\n")
                    self.current_line = lines[-1]
                    for line in lines[:-1]:
                        line = line.rstrip("\r")
                        content, usage = self.parse_sse_line(line)
                        if content:
                            total_content += content
                        if usage:
                            final_usage = usage
                    break
                except UnicodeDecodeError:
                    continue
        return total_content if total_content else None, final_usage
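    # Illustrative example of the buffering above: the UTF-8 bytes of "中"
    # (e4 b8 ad) may be split across network chunks, e.g.
    #   parse_stream_chunk(b"data: {...\xe4")   -> the lone \xe4 stays in self.buffer
    #   parse_stream_chunk(b"\xb8\xad...}\n\n") -> the line completes and is parsed
    # so multi-byte characters are never decoded across a chunk boundary.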

    async def track_stream(
        self,
        stream: AsyncIterator[str],
        request_data: Dict[str, Any],
        response_headers: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[str]:
        """
        Track a streaming response and account for its usage.

        Args:
            stream: The original streaming response
            request_data: Request data
            response_headers: Actual response headers

        Yields:
            Streaming response chunks
        """
        import time

        self.start_time = time.time()
        self.request_data = request_data  # keep the request data
        # Store the response headers (fall back to an empty dict rather than a
        # constructed default, so only real headers are recorded)
        self.response_headers = response_headers if response_headers is not None else {}
        # Estimate input tokens
        messages = request_data.get("messages", [])
        self.input_tokens = self.estimate_input_tokens(messages)
        logger.debug(f"ID:{self.request_id} | Started tracking stream | estimated input tokens:{self.input_tokens}")
        # Mark the usage record as streaming
        if self.request_id:
            try:
                UsageService.update_usage_status(
                    db=self.db,
                    request_id=self.request_id,
                    status="streaming",
                )
            except Exception as e:
                logger.warning(f"Failed to update usage record status to streaming: {e}")
        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1
                # Keep the raw bytes (for error diagnosis)
                self.raw_chunks.append(chunk)
                # Forward the raw chunk to the client
                yield chunk
                # Parse the chunk (raw bytes) for content and usage info
                content, usage = self.parse_stream_chunk(chunk)
                if content:
                    self.accumulated_content += content
                    # Estimate output tokens as content accumulates
                    self.output_tokens = max(1, len(self.accumulated_content) // 4)
                if usage:
                    # Prefer accurate usage info when the response provides it
                    self.input_tokens = usage.get("input_tokens", self.input_tokens)
                    self.output_tokens = usage.get("output_tokens", self.output_tokens)
                    self.cache_creation_input_tokens = usage.get(
                        "cache_creation_input_tokens", self.cache_creation_input_tokens
                    )
                    self.cache_read_input_tokens = usage.get(
                        "cache_read_input_tokens", self.cache_read_input_tokens
                    )
                    # Handle the newer cache_creation format
                    if "cache_creation" in usage:
                        cache_creation_data = usage.get("cache_creation", {})
                        # If cache_creation_input_tokens is absent, derive it
                        # from the cache_creation breakdown instead
                        if not self.cache_creation_input_tokens:
                            self.cache_creation_input_tokens = cache_creation_data.get(
                                "ephemeral_5m_input_tokens", 0
                            ) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
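                        # Illustrative example of the newer usage shape handled here:
                        #   {"input_tokens": 10, "output_tokens": 5,
                        #    "cache_creation": {"ephemeral_5m_input_tokens": 128,
                        #                       "ephemeral_1h_input_tokens": 0}}
                        # would set cache_creation_input_tokens to 128 + 0 = 128.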
        finally:
            # Record usage once the stream ends
            self.end_time = time.time()
            logger.debug(f"ID:{self.request_id} | Stream finished | processed {chunk_count} chunks | "
                         f"accumulated content length:{len(self.accumulated_content)} | output tokens:{self.output_tokens}")
            # Check whether any valid data was received.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = f"Stream finished but no valid data was parsed (received {chunk_count} raw chunks, none parseable)"
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Mark as failed so it is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so the FallbackOrchestrator can catch it and fail over
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )
            # Case 2: the stream finished without a complete message (no message_stop
            # event), which typically happens on server restarts or dropped connections
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = "Stream interrupted: no valid data received (possibly a dropped connection or service restart)"
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)
            # Make sure a summary is logged even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # Recording failed; still emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Emit the summary with multiple layers of protection
                try:
                    # Compute the response time with layered fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0
                    # Emit the summary log safely
                    logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | 200 | elapsed:{total_response_time}ms | "
                                f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost: unknown (recording failed)")
                except Exception as log_error:
                    # Last line of defense: log the simplest completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | recording failed but stream completed")
                    except Exception:
                        # Even the simplest log failed; give up
                        pass

    async def _record_usage(self):
        """Record the final usage."""
        try:
            if self.request_start_time and self.end_time:
                response_time_ms = int((self.end_time - self.request_start_time) * 1000)
            elif self.start_time and self.end_time:
                response_time_ms = int((self.end_time - self.start_time) * 1000)
            else:
                response_time_ms = None
            # Fall back to an estimate when no accurate token count is available
            if self.output_tokens == 0 and self.accumulated_content:
                self.output_tokens = max(1, len(self.accumulated_content) // 4)
            # Use the full response body (includes everything, tool calls included)
            # and merge in the final usage numbers
            self.complete_response["usage"].update(
                {
                    "input_tokens": self.input_tokens,
                    "output_tokens": self.output_tokens,
                    "cache_creation_input_tokens": self.cache_creation_input_tokens,
                    "cache_read_input_tokens": self.cache_read_input_tokens,
                }
            )
            # Record the response data:
            # if SSE chunks parsed successfully, store the structured data;
            # otherwise store the raw bytes for error diagnosis (e.g. a 403 HTML response)
            if self.response_chunks:
                # Normal case: the SSE JSON response parsed successfully
                response_body = {
                    "chunks": self.response_chunks,
                    "metadata": {
                        "stream": True,
                        "total_chunks": len(self.response_chunks),
                        "content_length": len(self.accumulated_content),
                        "response_time_ms": response_time_ms,
                    },
                }
            else:
                # Error case: the response was not JSON (e.g. an HTML error page).
                # Try to decode the raw bytes as text
                raw_response_text = ""
                for chunk in self.raw_chunks:
                    try:
                        if isinstance(chunk, bytes):
                            raw_response_text += chunk.decode("utf-8", errors="replace")
                        else:
                            raw_response_text += str(chunk)
                    except Exception:
                        pass
                response_body = {
                    "chunks": [],
                    "raw_response": raw_response_text[:10000],  # cap the size
                    "metadata": {
                        "stream": True,
                        "total_chunks": 0,
                        "raw_chunks_count": len(self.raw_chunks),
                        "content_length": len(raw_response_text),
                        "response_time_ms": response_time_ms,
                        "parse_error": "Failed to parse response as SSE JSON format",
                    },
                }
            # Check session health; if the session is unusable, roll back and
            # retry on a fresh transaction
            from sqlalchemy.exc import InvalidRequestError

            user = None
            api_key = None
            db_for_usage = self.db
            created_temp_session = False

            def _load_user_and_key(db_session: Session) -> Tuple[Optional[User], Optional[ApiKey]]:
                local_user = (
                    db_session.query(User).filter(User.id == self.user_id).first()
                    if self.user_id
                    else None
                )
                local_api_key = (
                    db_session.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
                    if self.api_key_id
                    else None
                )
                return local_user, local_api_key

            try:
                # Probe the session; raises InvalidRequestError if it is unusable
                db_for_usage.info
                # Re-query the user and API key so they are bound to this session
                user, api_key = _load_user_and_key(db_for_usage)
            except InvalidRequestError:
                # Session is in an invalid state; roll back and start over
                logger.warning(f"Session in invalid state for request {self.request_id}, rolling back and retrying")
                try:
                    db_for_usage.rollback()
                except Exception:
                    pass
                try:
                    db_for_usage.close()
                except Exception:
                    pass
                # Record usage on a fresh session so the prepared state cannot
                # keep poisoning queries
                try:
                    db_for_usage = create_session()
                    created_temp_session = True
                    user, api_key = _load_user_and_key(db_for_usage)
                except Exception as session_error:
                    logger.exception(f"Failed to recover from invalid session for request {self.request_id}: {session_error}")
                    return
            # Derive the request status from the status code
            final_status = "completed" if self.status_code == 200 else "failed"
            usage_record = await UsageService.record_usage_async(
                db=db_for_usage,
                user=user,
                api_key=api_key,
                provider=self.provider,
                model=self.model,
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                cache_creation_input_tokens=self.cache_creation_input_tokens,
                cache_read_input_tokens=self.cache_read_input_tokens,
                request_type="chat",
                api_format=self.api_format,
                is_stream=True,
                response_time_ms=response_time_ms,
                status_code=self.status_code,  # the actual status code
                error_message=self.error_message,  # the actual error message
                metadata={"stream": True, "content_length": len(self.accumulated_content)},
                request_body=self.request_data if hasattr(self, "request_data") else None,
                request_headers=self.request_headers,
                provider_request_headers=self.provider_request_headers,
                response_headers=self.response_headers,
                response_body=response_body,
                request_id=self.request_id,  # pass the request_id along
                # Provider-side tracking info (used to record actual cost)
                provider_id=self.provider_id,
                provider_endpoint_id=self.provider_endpoint_id,
                provider_api_key_id=self.provider_api_key_id,
                # Request status
                status=final_status,
            )
            # Read total_cost_usd immediately, before the object can be detached
            # from the session
            total_cost = 0.0
            if usage_record:
                try:
                    # Grab the attribute while usage_record is still session-bound
                    total_cost = usage_record.total_cost_usd or 0.0
                except Exception as e:
                    logger.warning(f"Failed to access total_cost_usd from usage_record: {e}")
                    total_cost = 0.0
            if db_for_usage and self.attempt_id:
                # RequestTrace has been removed; tracking now uses the RequestCandidate
                # table, and status updates happen in RequestCandidateService
                pass
            # Compute the total response time (from request start to stream end)
            if self.request_start_time:
                total_response_time = int((self.end_time - self.request_start_time) * 1000)
            else:
                total_response_time = response_time_ms
            # Emit a summary log (mirrors the completion log for non-streaming
            # requests), choosing the prefix by status code
            status_prefix = "[request completed]" if self.status_code == 200 else "[request failed]"
            # Pick a cost format appropriate to its magnitude
            if total_cost >= 0.01:
                cost_str = f"${total_cost:.4f}"
            elif total_cost > 0:
                cost_str = f"${total_cost:.6f}"
            else:
                cost_str = "$0"
            logger.info(f"{status_prefix} ID:{self.request_id} | {self.status_code} | elapsed:{total_response_time}ms | "
                        f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost:{cost_str}")
            # Provider-result health monitoring for dynamic weighting is handled
            # automatically by the FallbackOrchestrator; no manual recording here
        except Exception as e:
            logger.exception(f"Failed to record stream usage: {e}")
        finally:
            if created_temp_session:
                try:
                    db_for_usage.close()
                except Exception:
                    pass


class EnhancedStreamUsageTracker(StreamUsageTracker):
    """
    Enhanced streaming usage tracker.

    Supports more accurate token counting.
    """

    def __init__(
        self,
        db: Session,
        user: User,
        api_key: ApiKey,
        provider: str,
        model: str,
        request_headers: Optional[Dict[str, Any]] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
        start_time: Optional[float] = None,
        attempt_id: Optional[str] = None,
        # Provider-side tracking info (used to record actual cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        # API format (used to select the correct response parser)
        api_format: Optional[str] = None,
    ):
        super().__init__(
            db,
            user,
            api_key,
            provider,
            model,
            request_headers,
            provider_request_headers,
            request_id,
            start_time,
            attempt_id,
            provider_id,
            provider_endpoint_id,
            provider_api_key_id,
            api_format,
        )
        # Set up more accurate token counting
        self._init_tokenizer()
        # The SSE parsing buffers are inherited from the parent class,
        # which already initialized them

    def _init_tokenizer(self):
        """Initialize the tokenizer (if available)."""
        try:
            # tiktoken gives more accurate token counts when installed
            import tiktoken

            # Pick an encoding appropriate to the model
            if "gpt-4" in self.model or "gpt-3.5" in self.model:
                self.tokenizer = tiktoken.get_encoding("cl100k_base")
            else:
                # Claude and other models fall back to estimation
                self.tokenizer = None
        except ImportError:
            logger.debug("tiktoken not available, using estimation")
            self.tokenizer = None
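    # Illustrative example (assuming tiktoken is installed):
    #   tiktoken.get_encoding("cl100k_base").encode("hello world")
    # returns two token IDs, where the character heuristic below would likewise
    # estimate 11 // 4 = 2 tokens; the two methods diverge more on code,
    # non-English text, and unusual whitespace.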

    def count_tokens(self, text: str) -> int:
        """
        Count tokens more accurately.

        Args:
            text: Text content

        Returns:
            Token count
        """
        if self.tokenizer:
            try:
                return len(self.tokenizer.encode(text))
            except Exception as e:
                logger.warning(f"Token encoding failed: {e}")
        # Fallback estimate: a Chinese character is roughly 2 tokens and an
        # English word roughly 1.3 tokens
        chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
        english_words = len(re.findall(r"\b\w+\b", text))
        estimated_tokens = chinese_chars * 2 + english_words * 1.3
        return max(1, int(estimated_tokens))
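    # Worked example of the fallback heuristic: for "你好 world",
    # chinese_chars = 2 and r"\b\w+\b" finds 2 word runs ("你好" also matches,
    # since \w covers CJK characters), so the estimate is
    # 2 * 2 + 2 * 1.3 = 6.6 -> int(6.6) = 6 tokens.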

    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens more accurately.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_text = ""
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_text += content + " "
                elif isinstance(content, list):
                    for block in content:
                        if hasattr(block, "text"):
                            total_text += block.text + " "
        return self.count_tokens(total_text)

    async def track_stream(
        self,
        stream: AsyncIterator[str],
        request_data: Dict[str, Any],
        response_headers: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[str]:
        """
        Track a streaming response with more accurate usage accounting.

        Args:
            stream: The original streaming response
            request_data: Request data
            response_headers: Actual response headers

        Yields:
            Streaming response chunks
        """
        import time

        self.start_time = time.time()
        self.request_data = request_data  # keep the request data
        # Store the response headers (fall back to an empty dict rather than a
        # constructed default, so only real headers are recorded)
        self.response_headers = response_headers if response_headers is not None else {}
        # Estimate input tokens more accurately
        messages = request_data.get("messages", [])
        self.input_tokens = self.estimate_input_tokens(messages)
        logger.debug(f"ID:{self.request_id} | Started tracking stream (enhanced) | estimated input tokens:{self.input_tokens}")
        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1
                # Keep the raw bytes (for error diagnosis)
                self.raw_chunks.append(chunk)
                # Forward the raw chunk to the client
                yield chunk
                # Parse the chunk (raw bytes) for content and usage info
                content, usage = self.parse_stream_chunk(chunk)
                if content:
                    self.accumulated_content += content
                    # Use the more accurate method to count output tokens
                    self.output_tokens = self.count_tokens(self.accumulated_content)
                if usage:
                    # Prefer accurate usage info when the response provides it
                    if "input_tokens" in usage:
                        self.input_tokens = usage["input_tokens"]
                    if "output_tokens" in usage:
                        self.output_tokens = usage["output_tokens"]
                    if "cache_creation_input_tokens" in usage:
                        self.cache_creation_input_tokens = usage["cache_creation_input_tokens"]
                    if "cache_read_input_tokens" in usage:
                        self.cache_read_input_tokens = usage["cache_read_input_tokens"]
        finally:
            # Record usage once the stream ends
            self.end_time = time.time()
            logger.debug(f"ID:{self.request_id} | Stream finished | processed {chunk_count} chunks | "
                         f"accumulated content length:{len(self.accumulated_content)} | output tokens:{self.output_tokens}")
            # Check whether any valid data was received.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = f"Stream finished but no valid data was parsed (received {chunk_count} raw chunks, none parseable)"
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Mark as failed so it is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so the FallbackOrchestrator can catch it and fail over
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )
            # Case 2: the stream finished without a complete message (no message_stop
            # event), which typically happens on server restarts or dropped connections
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = "Stream interrupted: no valid data received (possibly a dropped connection or service restart)"
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)
            # Make sure a summary is logged even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # Recording failed; still emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Emit the summary with multiple layers of protection
                try:
                    # Compute the response time with layered fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0
                    # Emit the summary log safely
                    logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | 200 | elapsed:{total_response_time}ms | "
                                f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost: unknown (recording failed)")
                except Exception as log_error:
                    # Last line of defense: log the simplest completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | recording failed but stream completed")
                    except Exception:
                        # Even the simplest log failed; give up
                        pass


# Convenience factory
def create_stream_tracker(
    db: Session,
    user: User,
    api_key: ApiKey,
    provider: str,
    model: str,
    enhanced: bool = True,
    request_headers: Optional[Dict[str, Any]] = None,
    provider_request_headers: Optional[Dict[str, Any]] = None,
    request_id: Optional[str] = None,
    start_time: Optional[float] = None,
    attempt_id: Optional[str] = None,
    # Provider-side tracking info (used to record actual cost)
    provider_id: Optional[str] = None,
    provider_endpoint_id: Optional[str] = None,
    provider_api_key_id: Optional[str] = None,
    # API format (used to select the correct response parser)
    api_format: Optional[str] = None,
) -> StreamUsageTracker:
    """
    Create a streaming usage tracker.

    Args:
        db: Database session
        user: User object
        api_key: API key object
        provider: Provider name
        model: Model name
        enhanced: Whether to use the enhanced tracker
        request_headers: Actual request headers
        provider_request_headers: Request headers sent to the provider
        request_id: Request ID, used for log correlation
        start_time: Request start time, used to compute total response time
        attempt_id: RequestTrace attempt ID (optional)
        provider_id: Provider ID (used to record actual cost)
        provider_endpoint_id: Endpoint ID (used to record actual cost)
        provider_api_key_id: API key ID (used to record actual cost)
        api_format: API format (CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI)

    Returns:
        A streaming usage tracker instance
    """
if enhanced:
return EnhancedStreamUsageTracker(
db,
user,
api_key,
provider,
model,
request_headers,
provider_request_headers,
request_id,
start_time,
attempt_id,
provider_id,
provider_endpoint_id,
provider_api_key_id,
api_format,
)
else:
return StreamUsageTracker(
db,
user,
api_key,
provider,
model,
request_headers,
provider_request_headers,
request_id,
start_time,
attempt_id,
provider_id,
provider_endpoint_id,
provider_api_key_id,
api_format,
)
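

# Usage sketch (illustrative only; `db`, `user`, `api_key`, `upstream`,
# `request_data`, and `send_to_client` are assumed to come from the caller's
# request-handling context):
#
#     tracker = create_stream_tracker(
#         db=db,
#         user=user,
#         api_key=api_key,
#         provider="anthropic",
#         model="claude-3-5-sonnet",
#         enhanced=True,
#         request_id=request_id,
#         api_format="CLAUDE",
#     )
#     async for chunk in tracker.track_stream(upstream, request_data):
#         await send_to_client(chunk)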