"""
流式响应用量统计服务
处理流式响应的token计算和使用量记录
"""
import json
import re
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from sqlalchemy.orm import Session
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import StreamStats
from src.core.exceptions import EmptyStreamException
from src.core.logger import logger
from src.database.database import create_session
from src.models.database import ApiKey, User
from src.services.usage.service import UsageService


class StreamUsageTracker:
    """Usage tracker for streaming responses."""

    def __init__(
        self,
        db: Session,
        user: User,
        api_key: ApiKey,
        provider: str,
        model: str,
        request_headers: Optional[Dict[str, Any]] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
        start_time: Optional[float] = None,
        attempt_id: Optional[str] = None,
        # Provider-side tracking info (used to record true cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        # API format (used to select the correct response parser)
        api_format: Optional[str] = None,
    ):
        """
        Initialize the streaming usage tracker.

        Args:
            db: Database session
            user: User object
            api_key: API key object
            provider: Provider name
            model: Model name
            request_headers: Actual request headers
            provider_request_headers: Headers sent to the provider
            request_id: Request ID (for log correlation)
            start_time: Request start time (used to compute total response time)
            attempt_id: RequestTrace attempt ID
            provider_id: Provider ID (used to record true cost)
            provider_endpoint_id: Endpoint ID (used to record true cost)
            provider_api_key_id: API key ID (used to record true cost)
            api_format: API format (CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI)
        """
        self.db = db
        # Store only IDs to avoid session-binding problems
        self.user_id = user.id if user else None
        self.api_key_id = api_key.id if api_key else None
        self.provider = provider
        self.model = model
        self.request_headers = request_headers or {}
        self.provider_request_headers = provider_request_headers or {}
        self.request_id = request_id
        self.request_start_time = start_time
        # Provider-side tracking info
        self.provider_id = provider_id
        self.provider_endpoint_id = provider_endpoint_id
        self.provider_api_key_id = provider_api_key_id
        # API format and response parser
        self.api_format = api_format or "CLAUDE"
        self.response_parser = get_parser_for_format(self.api_format)
        self.stream_stats = StreamStats()  # Parser statistics
        # Token counters
        self.input_tokens = 0
        self.output_tokens = 0
        self.cache_creation_input_tokens = 0
        self.cache_read_input_tokens = 0
        self.accumulated_content = ""
        # Full-response tracking (internal statistics only; not written to the database)
        self.complete_response = {
            "id": None,
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": model,
            "stop_reason": None,
            "stop_sequence": None,
            "usage": {},
        }
        self.response_chunks = []  # All parsed response chunks
        self.raw_chunks = []  # All raw byte chunks (kept for error diagnostics)
        # Timing
        self.start_time = None
        self.end_time = None
        # Response headers (set later in track_stream)
        self.response_headers = {}
        # SSE parsing buffers
        self.buffer = b""  # Holds incomplete byte sequences
        self.current_line = ""  # Accumulates partial SSE lines
        self.sse_event_buffer = {
            "event": None,
            "data": [],
            "id": None,
            "retry": None,
        }  # SSE event buffer
        # Error-state tracking
        self.status_code = 200  # Default to success
        self.error_message = None  # Error message, if any
        self.attempt_id = attempt_id

    def set_error_status(self, status_code: int, error_message: str):
        """
        Set the error state.

        Args:
            status_code: HTTP status code
            error_message: Error message
        """
        self.status_code = status_code
        self.error_message = error_message
        logger.debug(
            f"ID:{self.request_id} | stream error state set | "
            f"status:{status_code} | error:{error_message[:100]}"
        )

    def _update_complete_response(self, chunk: Dict[str, Any]):
        """Update the aggregated response structure from a parsed chunk."""
        try:
            # Update the response ID
            if chunk.get("id"):
                self.complete_response["id"] = chunk["id"]
            # Update the model
            if chunk.get("model"):
                self.complete_response["model"] = chunk["model"]
            # Handle the different event types
            event_type = chunk.get("type")
            if event_type == "message_start":
                # Message start event
                message = chunk.get("message", {})
                if message.get("id"):
                    self.complete_response["id"] = message["id"]
                if message.get("model"):
                    self.complete_response["model"] = message["model"]
                self.complete_response["usage"] = message.get("usage", {})
            elif event_type == "content_block_start":
                # Content block start
                content_block = chunk.get("content_block", {})
                self.complete_response["content"].append(content_block)
            elif event_type == "content_block_delta":
                # Incremental content-block update
                index = chunk.get("index", 0)
                delta = chunk.get("delta", {})
                # Make sure the content list is long enough
                while len(self.complete_response["content"]) <= index:
                    self.complete_response["content"].append({"type": "text", "text": ""})
                current_block = self.complete_response["content"][index]
                if delta.get("type") == "text_delta":
                    # Text delta
                    if current_block.get("type") == "text":
                        current_block["text"] = current_block.get("text", "") + delta.get(
                            "text", ""
                        )
                elif delta.get("type") == "input_json_delta":
                    # Tool-call input arrives as partial JSON strings; accumulate
                    # them into a string (an untouched block starts as an empty dict)
                    if current_block.get("type") == "tool_use":
                        current_input = current_block.get("input", {})
                        if isinstance(current_input, dict) and not current_input:
                            current_input = ""
                        if isinstance(current_input, str):
                            current_input += delta.get("partial_json", "")
                            current_block["input"] = current_input
            elif event_type == "content_block_stop":
                # Content block finished
                pass
            elif event_type == "message_delta":
                # Message-level delta
                delta = chunk.get("delta", {})
                if delta.get("stop_reason"):
                    self.complete_response["stop_reason"] = delta["stop_reason"]
                if delta.get("stop_sequence"):
                    self.complete_response["stop_sequence"] = delta["stop_sequence"]
            elif event_type == "message_stop":
                # Message finished
                pass
            # Update usage info
            if chunk.get("usage"):
                self.complete_response["usage"].update(chunk["usage"])
        except Exception as e:
            # Log the error but do not interrupt stream processing
            logger.warning(f"Failed to update complete response: {e}")
    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_chars = 0
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_chars += len(content)
                elif isinstance(content, list):
                    for block in content:
                        # Blocks may be plain dicts or objects with a .text attribute
                        if isinstance(block, dict):
                            total_chars += len(block.get("text", "") or "")
                        elif hasattr(block, "text"):
                            total_chars += len(block.text)
        # Rough estimate: about 4 characters per token
        return max(1, total_chars // 4)

    def _process_sse_event(self) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Process a complete SSE event from the buffer.

        Returns:
            (content text, usage info)
        """
        content = None
        usage = None
        # Nothing to do if there is no data
        if not self.sse_event_buffer["data"]:
            return None, None
        # Join all data lines; per the SSE spec, data lines are joined with newlines
        data_str = "\n".join(self.sse_event_buffer["data"])
        # Reset the buffer
        self.sse_event_buffer = {"event": None, "data": [], "id": None, "retry": None}
        if not data_str or data_str == "[DONE]":
            return None, None
        try:
            data = json.loads(data_str)
            if isinstance(data, dict):
                self.response_chunks.append(data)
                try:
                    self._update_complete_response(data)
                except Exception as update_error:
                    logger.warning(f"Failed to update complete response from chunk: {update_error}")
                # Claude format
                if "type" in data:
                    if data["type"] == "content_block_delta":
                        delta = data.get("delta", {})
                        content = delta.get("text", "")
                        if content:
                            logger.debug(f"Extracted content from delta: {len(content)} chars")
                    elif data["type"] == "message_delta":
                        usage_data = data.get("usage", {})
                        if usage_data:
                            usage = usage_data
                            logger.debug(f"Extracted usage from message_delta: {usage}")
                    elif data["type"] == "message_stop":
                        logger.debug("Received message_stop event")
                # OpenAI format
                elif "choices" in data:
                    choices = data.get("choices", [])
                    if choices:
                        delta = choices[0].get("delta", {})
                        content = delta.get("content", "")
                    if "usage" in data:
                        usage = data["usage"]
        except json.JSONDecodeError as e:
            # More detailed logging for JSON parse failures
            logger.warning(f"Failed to parse SSE JSON data: {str(e)}")
            return None, None
        except Exception as e:
            logger.error(f"Unexpected error processing SSE event: {type(e).__name__}: {str(e)}")
        return content, usage
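
    # The raw wire format this event buffer accumulates looks like
    # (illustrative example, not captured traffic):
    #
    #   event: content_block_delta
    #   data: {"type": "content_block_delta", "index": 0, "delta": {...}}
    #   <blank line terminates the event>
    #
    # and OpenAI-style streams end with the sentinel "data: [DONE]".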
    def parse_sse_line(self, line: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Parse a single SSE line using the unified response parser.

        Args:
            line: One line in SSE format

        Returns:
            (content text, usage info) - the full event is processed when a blank line arrives
        """
        # Use the unified response parser
        chunk = self.response_parser.parse_sse_line(line, self.stream_stats)
        if chunk is None:
            return None, None
        # Extract content and usage info from the ParsedChunk
        content = chunk.text_delta
        # Build a usage dict if any token info is present
        usage = None
        if (
            chunk.input_tokens
            or chunk.output_tokens
            or chunk.cache_creation_tokens
            or chunk.cache_read_tokens
        ):
            usage = {
                "input_tokens": chunk.input_tokens or self.stream_stats.input_tokens,
                "output_tokens": chunk.output_tokens or self.stream_stats.output_tokens,
                "cache_creation_input_tokens": chunk.cache_creation_tokens
                or self.stream_stats.cache_creation_tokens,
                "cache_read_input_tokens": chunk.cache_read_tokens
                or self.stream_stats.cache_read_tokens,
            }
        # Update the response ID
        if chunk.response_id and not self.complete_response.get("id"):
            self.complete_response["id"] = chunk.response_id
        # Update the aggregated response (if the chunk carried data)
        if chunk.data:
            self.response_chunks.append(chunk.data)
            try:
                self._update_complete_response(chunk.data)
            except Exception as update_error:
                logger.warning(f"Failed to update complete response from chunk: {update_error}")
        return content, usage

    def parse_stream_chunk(self, chunk: bytes) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """
        Parse a streaming response chunk (raw bytes).

        Args:
            chunk: Raw byte chunk

        Returns:
            (accumulated content text, usage info)
        """
        total_content = ""
        final_usage = None
        # Append the new chunk to the buffer,
        # making sure the chunk is bytes first
        if isinstance(chunk, str):
            chunk = chunk.encode("utf-8")
        self.buffer += chunk
        # Try to decode and process complete lines
        try:
            # Attempt to decode the whole buffer
            text = self.buffer.decode("utf-8")
            self.buffer = b""  # Buffer fully consumed
            # Append the text to the current line
            self.current_line += text
            # Split on newlines to find complete lines
            lines = self.current_line.split("\n")
            # The last element may be an incomplete line; keep it
            self.current_line = lines[-1]
            # Process the complete lines
            for line in lines[:-1]:
                line = line.rstrip("\r")
                content, usage = self.parse_sse_line(line)
                if content:
                    total_content += content
                if usage:
                    final_usage = usage
        except UnicodeDecodeError:
            # Decoding failed: the buffer ends in an incomplete UTF-8 sequence.
            # Walk back to find the last complete character boundary.
            for i in range(len(self.buffer) - 1, max(0, len(self.buffer) - 4), -1):
                try:
                    text = self.buffer[:i].decode("utf-8")
                    # Decoded successfully; process this part
                    remaining = self.buffer[i:]
                    self.buffer = remaining  # Keep the undecoded tail
                    # Process the decoded text
                    self.current_line += text
                    lines = self.current_line.split("\n")
                    self.current_line = lines[-1]
                    for line in lines[:-1]:
                        line = line.rstrip("\r")
                        content, usage = self.parse_sse_line(line)
                        if content:
                            total_content += content
                        if usage:
                            final_usage = usage
                    break
                except UnicodeDecodeError:
                    continue
        return total_content if total_content else None, final_usage
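
    # Illustration of the boundary handling above (an example, not captured
    # traffic): the character "中" encodes to the three bytes b"\xe4\xb8\xad".
    # If one network chunk ends with b"...\xe4\xb8" and the next begins with
    # b"\xad...", decoding the first chunk alone raises UnicodeDecodeError;
    # the partial bytes stay in self.buffer until the continuation arrives.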
    async def track_stream(
        self,
        stream: AsyncIterator[str],
        request_data: Dict[str, Any],
        response_headers: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[str]:
        """
        Track a streaming response and compute usage.

        Args:
            stream: Original response stream
            request_data: Request data
            response_headers: Actual response headers

        Yields:
            Streaming response chunks
        """
        import time

        self.start_time = time.time()
        self.request_data = request_data  # Keep the request data
        # Store the response headers (fall back to an empty dict rather than defaults,
        # so what gets recorded is the real headers, not constructed values)
        self.response_headers = response_headers if response_headers is not None else {}
        # Estimate input tokens
        messages = request_data.get("messages", [])
        self.input_tokens = self.estimate_input_tokens(messages)
        logger.debug(f"ID:{self.request_id} | started tracking stream | estimated input tokens:{self.input_tokens}")
        # Update the status to "streaming"
        if self.request_id:
            try:
                UsageService.update_usage_status(
                    db=self.db,
                    request_id=self.request_id,
                    status="streaming",
                )
            except Exception as e:
                logger.warning(f"Failed to update usage record status to streaming: {e}")
        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1
                # Keep the raw bytes (for error diagnostics)
                self.raw_chunks.append(chunk)
                # Forward the original chunk to the client
                yield chunk
                # Parse the chunk to extract content and usage info (chunk is raw bytes)
                content, usage = self.parse_stream_chunk(chunk)
                if content:
                    self.accumulated_content += content
                    # Estimate output tokens on the fly
                    self.output_tokens = max(1, len(self.accumulated_content) // 4)
                if usage:
                    # Prefer exact usage info when the response provides it
                    self.input_tokens = usage.get("input_tokens", self.input_tokens)
                    self.output_tokens = usage.get("output_tokens", self.output_tokens)
                    self.cache_creation_input_tokens = usage.get(
                        "cache_creation_input_tokens", self.cache_creation_input_tokens
                    )
                    self.cache_read_input_tokens = usage.get(
                        "cache_read_input_tokens", self.cache_read_input_tokens
                    )
                    # Handle the newer cache_creation format
                    if "cache_creation" in usage:
                        cache_creation_data = usage.get("cache_creation", {})
                        # If cache_creation_input_tokens is missing, derive it from cache_creation
                        if not self.cache_creation_input_tokens:
                            self.cache_creation_input_tokens = cache_creation_data.get(
                                "ephemeral_5m_input_tokens", 0
                            ) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
        finally:
            # Record usage once the stream has ended
            self.end_time = time.time()
            logger.debug(f"ID:{self.request_id} | stream finished | processed {chunk_count} chunks | "
                         f"accumulated content length:{len(self.accumulated_content)} | output tokens:{self.output_tokens}")
            # Check whether any valid data was received.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = f"Stream finished but no valid data was parsed (received {chunk_count} raw chunks, none parseable)"
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Mark this as an error so it is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so FallbackOrchestrator can catch it and trigger failover
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )
            # Case 2: the stream finished without a complete message (no message_stop event),
            # which usually happens on server restarts or dropped connections
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = "Stream interrupted: no valid data received (possibly a dropped connection or service restart)"
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)
            # Make sure a log line is always emitted, even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # If recording fails, at least emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Try to emit a basic summary log, with layered fallbacks
                try:
                    # Compute the response time with multiple fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0
                    # Emit the summary log safely
                    logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | 200 | elapsed:{total_response_time}ms | "
                                f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost: unknown (recording failed)")
                except Exception as log_error:
                    # Last line of defense: emit the simplest possible completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | recording failed but stream completed")
                    except Exception:
                        # Even the simplest log failed; give up
                        pass
    async def _record_usage(self):
        """Record the final usage."""
        # Defined up front so the finally block can always reference them
        db_for_usage = self.db
        created_temp_session = False
        try:
            if self.request_start_time and self.end_time:
                response_time_ms = int((self.end_time - self.request_start_time) * 1000)
            elif self.start_time and self.end_time:
                response_time_ms = int((self.end_time - self.start_time) * 1000)
            else:
                response_time_ms = None
            # Fall back to an estimate when there is no exact token count
            if self.output_tokens == 0 and self.accumulated_content:
                self.output_tokens = max(1, len(self.accumulated_content) // 4)
            # Use the full response body (includes everything, tool calls included).
            # Update the final usage info
            self.complete_response["usage"].update(
                {
                    "input_tokens": self.input_tokens,
                    "output_tokens": self.output_tokens,
                    "cache_creation_input_tokens": self.cache_creation_input_tokens,
                    "cache_read_input_tokens": self.cache_read_input_tokens,
                }
            )
            # Record the response data.
            # If SSE chunks parsed successfully, use the structured data;
            # otherwise keep the raw bytes for error diagnostics (e.g. a 403 HTML response)
            if self.response_chunks:
                # Normal case: successfully parsed SSE JSON response
                response_body = {
                    "chunks": self.response_chunks,
                    "metadata": {
                        "stream": True,
                        "total_chunks": len(self.response_chunks),
                        "content_length": len(self.accumulated_content),
                        "response_time_ms": response_time_ms,
                    },
                }
            else:
                # Error case: response could not be parsed as JSON (e.g. an HTML error page).
                # Try to decode the raw bytes as text
                raw_response_text = ""
                for chunk in self.raw_chunks:
                    try:
                        if isinstance(chunk, bytes):
                            raw_response_text += chunk.decode("utf-8", errors="replace")
                        else:
                            raw_response_text += str(chunk)
                    except Exception:
                        pass
                response_body = {
                    "chunks": [],
                    "raw_response": raw_response_text[:10000],  # Cap the size
                    "metadata": {
                        "stream": True,
                        "total_chunks": 0,
                        "raw_chunks_count": len(self.raw_chunks),
                        "content_length": len(raw_response_text),
                        "response_time_ms": response_time_ms,
                        "parse_error": "Failed to parse response as SSE JSON format",
                    },
                }
            # Check the session state; if the session is unusable,
            # roll back and start a new transaction
            from sqlalchemy.exc import InvalidRequestError

            user = None
            api_key = None

            def _load_user_and_key(db_session: Session) -> Tuple[Optional[User], Optional[ApiKey]]:
                local_user = (
                    db_session.query(User).filter(User.id == self.user_id).first()
                    if self.user_id
                    else None
                )
                local_api_key = (
                    db_session.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
                    if self.api_key_id
                    else None
                )
                return local_user, local_api_key

            try:
                # Check whether the session is usable
                db_for_usage.info
                # Re-query the user and API key objects so they are bound to the session
                user, api_key = _load_user_and_key(db_for_usage)
            except InvalidRequestError:
                # Session is in an invalid state: roll back and start over
                logger.warning(f"Session in invalid state for request {self.request_id}, rolling back and retrying")
                try:
                    db_for_usage.rollback()
                except Exception:
                    pass
                try:
                    db_for_usage.close()
                except Exception:
                    pass
                # Record usage on a fresh session so the prepared state cannot affect queries
                try:
                    db_for_usage = create_session()
                    created_temp_session = True
                    user, api_key = _load_user_and_key(db_for_usage)
                except Exception as session_error:
                    logger.exception(f"Failed to recover from invalid session for request {self.request_id}: {session_error}")
                    return
            # Derive the request status from the status code
            final_status = "completed" if self.status_code == 200 else "failed"
            usage_record = await UsageService.record_usage_async(
                db=db_for_usage,
                user=user,
                api_key=api_key,
                provider=self.provider,
                model=self.model,
                input_tokens=self.input_tokens,
                output_tokens=self.output_tokens,
                cache_creation_input_tokens=self.cache_creation_input_tokens,
                cache_read_input_tokens=self.cache_read_input_tokens,
                request_type="chat",
                api_format=self.api_format,
                is_stream=True,
                response_time_ms=response_time_ms,
                status_code=self.status_code,  # The actual status code
                error_message=self.error_message,  # The actual error message
                metadata={"stream": True, "content_length": len(self.accumulated_content)},
                request_body=self.request_data if hasattr(self, "request_data") else None,
                request_headers=self.request_headers,
                provider_request_headers=self.provider_request_headers,
                response_headers=self.response_headers,
                response_body=response_body,
                request_id=self.request_id,  # Pass through the request_id
                # Provider-side tracking info (used to record true cost)
                provider_id=self.provider_id,
                provider_endpoint_id=self.provider_endpoint_id,
                provider_api_key_id=self.provider_api_key_id,
                # Request status
                status=final_status,
            )
            # Read total_cost_usd immediately, before the object is detached from the session
            total_cost = 0.0
            if usage_record:
                try:
                    # Grab the needed attributes while usage_record is still session-bound
                    total_cost = usage_record.total_cost_usd or 0.0
                except Exception as e:
                    logger.warning(f"Failed to access total_cost_usd from usage_record: {e}")
                    total_cost = 0.0
            if db_for_usage and self.attempt_id:
                # The RequestTrace feature was removed; tracking now uses the RequestCandidate
                # table, and the status update is already done in RequestCandidateService
                pass
            # Compute the total response time (from request start to stream end)
            if self.request_start_time:
                total_response_time = int((self.end_time - self.request_start_time) * 1000)
            else:
                total_response_time = response_time_ms
            # Emit the summary log (mirrors the completion log for non-streaming requests).
            # Choose the prefix and log level based on the status code
            status_prefix = "[request completed]" if self.status_code == 200 else "[request failed]"
            # Pick a cost format that matches the magnitude
            if total_cost >= 0.01:
                cost_str = f"${total_cost:.4f}"
            elif total_cost > 0:
                cost_str = f"${total_cost:.6f}"
            else:
                cost_str = "$0"
            logger.info(f"{status_prefix} ID:{self.request_id} | {self.status_code} | elapsed:{total_response_time}ms | "
                        f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost:{cost_str}")
            # Provider-result recording for dynamic weight adjustment and health
            # monitoring is handled automatically by FallbackOrchestrator;
            # no manual recording is needed here
        except Exception as e:
            logger.exception(f"Failed to record stream usage: {e}")
        finally:
            if created_temp_session:
                try:
                    db_for_usage.close()
                except Exception:
                    pass


class EnhancedStreamUsageTracker(StreamUsageTracker):
    """
    Enhanced streaming usage tracker.

    Supports more accurate token counting.
    """

    def __init__(
        self,
        db: Session,
        user: User,
        api_key: ApiKey,
        provider: str,
        model: str,
        request_headers: Optional[Dict[str, Any]] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        request_id: Optional[str] = None,
        start_time: Optional[float] = None,
        attempt_id: Optional[str] = None,
        # Provider-side tracking info (used to record true cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        # API format (used to select the correct response parser)
        api_format: Optional[str] = None,
    ):
        super().__init__(
            db,
            user,
            api_key,
            provider,
            model,
            request_headers,
            provider_request_headers,
            request_id,
            start_time,
            attempt_id,
            provider_id,
            provider_endpoint_id,
            provider_api_key_id,
            api_format,
        )
        # For more accurate token counting
        self._init_tokenizer()
        # The SSE parsing buffers are inherited from the parent class,
        # where they are already initialized

    def _init_tokenizer(self):
        """Initialize the tokenizer (if available)."""
        try:
            # Try to import tiktoken for more accurate token counting
            import tiktoken

            # Pick the encoding that matches the model
            if "gpt-4" in self.model or "gpt-3.5" in self.model:
                self.tokenizer = tiktoken.get_encoding("cl100k_base")
            else:
                # Claude and other models use the approximation instead
                self.tokenizer = None
        except ImportError:
            logger.debug("tiktoken not available, using estimation")
            self.tokenizer = None

    def count_tokens(self, text: str) -> int:
        """
        Count tokens more accurately.

        Args:
            text: Text content

        Returns:
            Token count
        """
        if self.tokenizer:
            try:
                return len(self.tokenizer.encode(text))
            except Exception as e:
                logger.warning(f"Token encoding failed: {e}")
        # Fall back to estimation:
        # CJK characters are roughly 2 tokens each, English words about 1.3 tokens.
        # Count only ASCII word runs so CJK characters are not counted twice.
        chinese_chars = len(re.findall(r"[\u4e00-\u9fff]", text))
        english_words = len(re.findall(r"\b[A-Za-z0-9_]+\b", text))
        estimated_tokens = chinese_chars * 2 + english_words * 1.3
        return max(1, int(estimated_tokens))
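
    # Worked example of the fallback estimate above (illustrative, assuming
    # no tokenizer is available): for text = "你好 world",
    # chinese_chars = 2 and english_words = 1, so the estimate is
    # 2 * 2 + 1 * 1.3 = 5.3, truncated to 5 tokens.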
    def estimate_input_tokens(self, messages: list) -> int:
        """
        Estimate input tokens more accurately.

        Args:
            messages: Message list

        Returns:
            Estimated token count
        """
        total_text = ""
        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content", "")
                if isinstance(content, str):
                    total_text += content + " "
                elif isinstance(content, list):
                    for block in content:
                        # Blocks may be plain dicts or objects with a .text attribute
                        if isinstance(block, dict):
                            total_text += (block.get("text", "") or "") + " "
                        elif hasattr(block, "text"):
                            total_text += block.text + " "
        return self.count_tokens(total_text)

    async def track_stream(
        self,
        stream: AsyncIterator[str],
        request_data: Dict[str, Any],
        response_headers: Optional[Dict[str, Any]] = None,
    ) -> AsyncIterator[str]:
        """
        Track a streaming response and compute usage more accurately.

        Args:
            stream: Original response stream
            request_data: Request data
            response_headers: Actual response headers

        Yields:
            Streaming response chunks
        """
        import time

        self.start_time = time.time()
        self.request_data = request_data  # Keep the request data
        # Store the response headers (fall back to an empty dict rather than defaults,
        # so what gets recorded is the real headers, not constructed values)
        self.response_headers = response_headers if response_headers is not None else {}
        # Estimate input tokens more accurately
        messages = request_data.get("messages", [])
        self.input_tokens = self.estimate_input_tokens(messages)
        logger.debug(f"ID:{self.request_id} | started tracking stream (enhanced) | estimated input tokens:{self.input_tokens}")
        chunk_count = 0
        try:
            async for chunk in stream:
                chunk_count += 1
                # Keep the raw bytes (for error diagnostics)
                self.raw_chunks.append(chunk)
                # Forward the original chunk to the client
                yield chunk
                # Parse the chunk to extract content and usage info (chunk is raw bytes)
                content, usage = self.parse_stream_chunk(chunk)
                if content:
                    self.accumulated_content += content
                    # Use the more accurate method to count output tokens
                    self.output_tokens = self.count_tokens(self.accumulated_content)
                if usage:
                    # Prefer exact usage info when the response provides it
                    if "input_tokens" in usage:
                        self.input_tokens = usage["input_tokens"]
                    if "output_tokens" in usage:
                        self.output_tokens = usage["output_tokens"]
                    if "cache_creation_input_tokens" in usage:
                        self.cache_creation_input_tokens = usage["cache_creation_input_tokens"]
                    if "cache_read_input_tokens" in usage:
                        self.cache_read_input_tokens = usage["cache_read_input_tokens"]
        finally:
            # Record usage once the stream has ended
            self.end_time = time.time()
            logger.debug(f"ID:{self.request_id} | stream finished | processed {chunk_count} chunks | "
                         f"accumulated content length:{len(self.accumulated_content)} | output tokens:{self.output_tokens}")
            # Check whether any valid data was received.
            # Case 1: raw data arrived but none of it parsed as valid SSE JSON
            if chunk_count > 0 and not self.response_chunks:
                error_msg = f"Stream finished but no valid data was parsed (received {chunk_count} raw chunks, none parseable)"
                logger.error(f"ID:{self.request_id} | {error_msg}")
                # Mark this as an error so it is not recorded as a success
                self.set_error_status(502, error_msg)
                # Raise so FallbackOrchestrator can catch it and trigger failover
                raise EmptyStreamException(
                    provider_name=self.provider,
                    chunk_count=chunk_count,
                    request_metadata=None,
                )
            # Case 2: the stream finished without a complete message (no message_stop event),
            # which usually happens on server restarts or dropped connections
            if not self.stream_stats.has_completion and not self.response_chunks:
                error_msg = "Stream interrupted: no valid data received (possibly a dropped connection or service restart)"
                logger.warning(f"ID:{self.request_id} | {error_msg}")
                self.set_error_status(503, error_msg)
            # Make sure a log line is always emitted, even if recording usage fails
            try:
                await self._record_usage()
            except Exception as e:
                # If recording fails, at least emit a basic summary log
                logger.exception(f"Failed to record stream usage for request {self.request_id}: {e}")
                # Try to emit a basic summary log, with layered fallbacks
                try:
                    # Compute the response time with multiple fallbacks
                    try:
                        if self.request_start_time and self.end_time:
                            total_response_time = int(
                                (self.end_time - self.request_start_time) * 1000
                            )
                        elif self.start_time and self.end_time:
                            total_response_time = int((self.end_time - self.start_time) * 1000)
                        else:
                            total_response_time = 0
                    except Exception:
                        total_response_time = 0
                    # Emit the summary log safely
                    logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | 200 | elapsed:{total_response_time}ms | "
                                f"tokens: in {self.input_tokens}/out {self.output_tokens} | cost: unknown (recording failed)")
                except Exception as log_error:
                    # Last line of defense: emit the simplest possible completion marker
                    logger.error(f"Failed to output summary log: {log_error}")
                    try:
                        logger.info(f"[request completed] ID:{self.request_id or 'unknown'} | recording failed but stream completed")
                    except Exception:
                        # Even the simplest log failed; give up
                        pass


# Convenience factory
def create_stream_tracker(
    db: Session,
    user: User,
    api_key: ApiKey,
    provider: str,
    model: str,
    enhanced: bool = True,
    request_headers: Optional[Dict[str, Any]] = None,
    provider_request_headers: Optional[Dict[str, Any]] = None,
    request_id: Optional[str] = None,
    start_time: Optional[float] = None,
    attempt_id: Optional[str] = None,
    # Provider-side tracking info (used to record true cost)
    provider_id: Optional[str] = None,
    provider_endpoint_id: Optional[str] = None,
    provider_api_key_id: Optional[str] = None,
    # API format (used to select the correct response parser)
    api_format: Optional[str] = None,
) -> StreamUsageTracker:
    """
    Create a streaming usage tracker.

    Args:
        db: Database session
        user: User object
        api_key: API key object
        provider: Provider name
        model: Model name
        enhanced: Whether to use the enhanced tracker
        request_headers: Actual request headers
        provider_request_headers: Headers sent to the provider
        request_id: Request ID (for log correlation)
        start_time: Request start time (used to compute total response time)
        attempt_id: RequestTrace attempt ID (optional)
        provider_id: Provider ID (used to record true cost)
        provider_endpoint_id: Endpoint ID (used to record true cost)
        provider_api_key_id: API key ID (used to record true cost)
        api_format: API format (CLAUDE, CLAUDE_CLI, OPENAI, OPENAI_CLI)

    Returns:
        A streaming usage tracker instance
    """
    tracker_cls = EnhancedStreamUsageTracker if enhanced else StreamUsageTracker
    return tracker_cls(
        db,
        user,
        api_key,
        provider,
        model,
        request_headers,
        provider_request_headers,
        request_id,
        start_time,
        attempt_id,
        provider_id,
        provider_endpoint_id,
        provider_api_key_id,
        api_format,
    )
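

# A minimal wiring sketch (comments only; not executed): how a streaming
# endpoint would typically use create_stream_tracker. The names `db`, `user`,
# `api_key`, `body`, `upstream_sse_stream`, and the FastAPI StreamingResponse
# wrapper are hypothetical placeholders, not part of this module.
#
#     tracker = create_stream_tracker(
#         db=db,
#         user=user,
#         api_key=api_key,
#         provider="anthropic",
#         model=body.get("model", ""),
#         enhanced=True,
#         request_id=request_id,
#         start_time=time.time(),
#         api_format="CLAUDE",
#     )
#     return StreamingResponse(
#         tracker.track_stream(upstream_sse_stream, request_data=body),
#         media_type="text/event-stream",
#     )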