Files
Aether/src/api/handlers/base/stream_telemetry.py
fawney19 ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00

300 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
流式遥测记录器 - 从 ChatHandlerBase 提取的统计记录逻辑
职责:
1. 记录流式请求的成功/失败统计
2. 更新 Usage 状态
3. 更新候选记录状态
"""
import asyncio
import time
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import MessageTelemetry
from src.api.handlers.base.stream_context import StreamContext
from src.config.settings import config
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
class StreamTelemetryRecorder:
    """
    Telemetry recorder for streaming requests.

    Records statistics after a streaming request has finished: the
    success/failure telemetry entry, the Usage row status and the request
    candidate status. The logic was extracted from
    ChatHandlerBase._record_stream_stats.
    """

    def __init__(
        self,
        request_id: str,
        user_id: str,
        api_key_id: str,
        client_ip: str,
        format_id: str,
    ):
        """
        Initialise the telemetry recorder.

        Args:
            request_id: Request ID.
            user_id: User ID.
            api_key_id: API key ID.
            client_ip: Client IP address.
            format_id: API format identifier (used as a log prefix).
        """
        self.request_id = request_id
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.client_ip = client_ip
        self.format_id = format_id

    async def record_stream_stats(
        self,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        start_time: float,
    ) -> None:
        """
        Record the statistics of a finished streaming request.

        Any exception raised while recording is logged and converted into a
        best-effort "failed" update of the Usage row; it is not re-raised.

        Args:
            ctx: Streaming context of the request.
            original_headers: Original request headers.
            original_request_body: Original request body; used when the
                context carries no provider request body.
            start_time: Request start time, as returned by time.time().
        """
        # Pre-bind the value so the error handler below can always report
        # one, even when the elapsed-time computation itself is what failed.
        response_time_ms = 0

        try:
            # Measure the response time right when the stream ends, on the
            # same time base as the first-byte time. It must be taken before
            # the settle delay below so that stream_stats_delay is not
            # counted as part of the response time.
            response_time_ms = int((time.time() - start_time) * 1000)

            await asyncio.sleep(config.stream_stats_delay)  # wait for the stream to close fully

            if not ctx.provider_name:
                await self._update_usage_status_on_error(
                    response_time_ms=response_time_ms,
                    error_message="Provider name not available",
                )
                return

            db_gen = get_db()
            bg_db = next(db_gen)

            try:
                user = bg_db.query(User).filter(User.id == self.user_id).first()
                api_key_obj = bg_db.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()

                if not user or not api_key_obj:
                    logger.warning(
                        f"[{self.request_id}] User or ApiKey not found, updating status directly"
                    )
                    await self._update_usage_status_directly(
                        bg_db,
                        status="completed" if ctx.is_success() else "failed",
                        response_time_ms=response_time_ms,
                        status_code=ctx.status_code,
                    )
                    return

                bg_telemetry = MessageTelemetry(
                    bg_db, user, api_key_obj, self.request_id, self.client_ip
                )

                # Prefer the body that was actually sent to the provider.
                actual_request_body = ctx.provider_request_body or original_request_body
                response_body = ctx.build_response_body(response_time_ms)

                if ctx.is_success():
                    await self._record_success(
                        bg_telemetry,
                        ctx,
                        original_headers,
                        actual_request_body,
                        response_body,
                        response_time_ms,
                    )
                else:
                    await self._record_failure(
                        bg_telemetry,
                        ctx,
                        original_headers,
                        actual_request_body,
                        response_body,
                        response_time_ms,
                    )

                # Update the request candidate status.
                # NOTE(review): an exception raised here reaches the outer
                # handler, which rewrites the Usage row as failed/500 even if
                # the success record above was already written - confirm that
                # this is intended.
                await self._update_candidate_status(bg_db, ctx, response_time_ms)

            finally:
                # The session always exists here: this block is only entered
                # after next(db_gen) has returned it.
                bg_db.close()

        except Exception as e:
            logger.exception("记录流式统计信息时出错")
            await self._update_usage_status_on_error(
                response_time_ms=response_time_ms,
                error_message=f"记录统计信息失败: {str(e)[:200]}",
            )

    async def _record_success(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a successful request through the telemetry object and log a summary."""
        await telemetry.record_success(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            response_time_ms=response_time_ms,
            first_byte_time_ms=ctx.first_byte_time_ms,  # pass the first-byte time through
            status_code=ctx.status_code,
            request_headers=original_headers,
            request_body=actual_request_body,
            response_headers=ctx.response_headers,
            response_body=response_body,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            is_stream=True,
            provider_request_headers=ctx.provider_request_headers,
            api_format=ctx.api_format,
            provider_id=ctx.provider_id,
            provider_endpoint_id=ctx.endpoint_id,
            provider_api_key_id=ctx.key_id,
            target_model=ctx.mapped_model,
        )

        logger.debug(f"{self.format_id} 流式响应完成")
        logger.info(ctx.get_log_summary(self.request_id, response_time_ms))

    async def _record_failure(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a failed request through the telemetry object and log a summary."""
        await telemetry.record_failure(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            response_time_ms=response_time_ms,
            status_code=ctx.status_code,
            # Fall back to the HTTP status when the context has no message.
            error_message=ctx.error_message or f"HTTP {ctx.status_code}",
            request_headers=original_headers,
            request_body=actual_request_body,
            is_stream=True,
            api_format=ctx.api_format,
            provider_request_headers=ctx.provider_request_headers,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            response_body=response_body,
            target_model=ctx.mapped_model,
        )

        logger.debug(f"{self.format_id} 流式响应中断")
        log_summary = ctx.get_log_summary(self.request_id, response_time_ms)
        # For failure logs, append the cached-token count.
        logger.info(f"{log_summary} cache:{ctx.cached_tokens}")

    async def _update_candidate_status(
        self,
        db: Session,
        ctx: StreamContext,
        response_time_ms: int,
    ) -> None:
        """
        Update the request candidate record for this attempt.

        Does nothing when the context has no attempt_id. A status code of 499
        is reported as "client_disconnected"; every other failure is reported
        as "stream_error".
        """
        if not ctx.attempt_id:
            return

        # Imported locally, as in the original code (presumably to avoid an
        # import cycle - TODO confirm).
        from src.services.request.candidate import RequestCandidateService

        if ctx.is_success():
            RequestCandidateService.mark_candidate_success(
                db=db,
                candidate_id=ctx.attempt_id,
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": True,
                    "data_count": ctx.data_count,
                },
            )
        else:
            error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
            RequestCandidateService.mark_candidate_failed(
                db=db,
                candidate_id=ctx.attempt_id,
                error_type=error_type,
                error_message=ctx.error_message or f"HTTP {ctx.status_code}",
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": False,
                    "data_count": ctx.data_count,
                },
            )

    async def _update_usage_status_on_error(
        self,
        response_time_ms: int,
        error_message: str,
    ) -> None:
        """
        Mark the Usage row as failed (status code 500) when recording failed.

        Opens its own database session and closes it afterwards. Errors are
        logged and swallowed: this is a best-effort fallback.
        """
        try:
            db_gen = get_db()
            error_db = next(db_gen)
            try:
                await self._update_usage_status_directly(
                    error_db,
                    status="failed",
                    response_time_ms=response_time_ms,
                    status_code=500,
                    error_message=error_message,
                )
            finally:
                error_db.close()
        except Exception as inner_e:
            logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")

    async def _update_usage_status_directly(
        self,
        db: Session,
        status: str,
        response_time_ms: int,
        status_code: int = 200,
        error_message: Optional[str] = None,
    ) -> None:
        """
        Update the status fields of the Usage row for this request directly.

        Nothing is written when no Usage row matches the request ID. Errors
        are logged and swallowed; the session is rolled back so the caller is
        not left holding a failed transaction.

        Args:
            db: Database session to use; it is committed on success.
            status: New value for the status field.
            response_time_ms: Response time in milliseconds.
            status_code: Status code to store.
            error_message: Error message to store; skipped when empty.
        """
        try:
            from src.models.database import Usage

            usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
            if usage:
                # setattr is kept from the original code (presumably to keep
                # static type checkers quiet about column assignment).
                setattr(usage, "status", status)
                setattr(usage, "status_code", status_code)
                setattr(usage, "response_time_ms", response_time_ms)
                if error_message:
                    setattr(usage, "error_message", error_message)
                db.commit()
                logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
        except Exception as e:
            logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
            # A failed flush/commit leaves the session's transaction unusable
            # until it is rolled back. Best effort: the rollback itself may
            # fail, e.g. when the connection is gone.
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.error(f"[{self.request_id}] Usage rollback failed: {rollback_error}")