# Source: Aether/src/api/handlers/base/stream_telemetry.py
# (300 lines, 10 KiB, Python)

"""
流式遥测记录器 - ChatHandlerBase 提取的统计记录逻辑
职责
1. 记录流式请求的成功/失败统计
2. 更新 Usage 状态
3. 更新候选记录状态
"""
import asyncio
import time
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import MessageTelemetry
from src.api.handlers.base.stream_context import StreamContext
from src.config.settings import config
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
class StreamTelemetryRecorder:
    """
    Streaming telemetry recorder.

    Records statistics after a streaming request has completed. This is the
    ``_record_stream_stats`` logic extracted from ``ChatHandlerBase``.
    """

    def __init__(
        self,
        request_id: str,
        user_id: str,
        api_key_id: str,
        client_ip: str,
        format_id: str,
    ):
        """
        Initialize the telemetry recorder.

        Args:
            request_id: Request ID; also the key used to look up the Usage row.
            user_id: User ID.
            api_key_id: API Key ID.
            client_ip: Client IP address.
            format_id: API format identifier (used as a prefix in debug logs).
        """
        self.request_id = request_id
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.client_ip = client_ip
        self.format_id = format_id

    async def record_stream_stats(
        self,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        start_time: float,
    ) -> None:
        """
        Record statistics for a finished streaming request.

        The response time is measured first. The method then waits
        ``config.stream_stats_delay`` seconds and, in a fresh DB session,
        records success/failure telemetry and updates the candidate record.
        Any ``Exception`` raised along the way is logged, and the Usage row is
        marked as failed instead of propagating the error to the caller.

        Args:
            ctx: Streaming context.
            original_headers: Original request headers.
            original_request_body: Original request body; used only when
                ``ctx.provider_request_body`` is falsy.
            start_time: Request start time (from ``time.time()``).
        """
        # Bound before the try block so the except handler can always use it,
        # even if computing the elapsed time itself raises (e.g. bad start_time).
        response_time_ms = 0
        try:
            # Measure the response time right after the stream ends, on the same
            # time base as the first-byte time. The stats delay below must NOT
            # be counted in the response time, so measure before sleeping.
            response_time_ms = int((time.time() - start_time) * 1000)
            await asyncio.sleep(config.stream_stats_delay)  # wait for the stream to fully close

            if not ctx.provider_name:
                await self._update_usage_status_on_error(
                    response_time_ms=response_time_ms,
                    error_message="Provider name not available",
                )
                return

            db_gen = get_db()
            bg_db = next(db_gen)
            try:
                await self._record_with_session(
                    bg_db,
                    ctx,
                    original_headers,
                    original_request_body,
                    response_time_ms,
                )
            finally:
                self._close_db(db_gen, bg_db)
        except Exception as e:
            logger.exception(f"[{self.request_id}] 记录流式统计信息时出错")
            await self._update_usage_status_on_error(
                response_time_ms=response_time_ms,
                error_message=f"记录统计信息失败: {str(e)[:200]}",
            )

    async def _record_with_session(
        self,
        db: Session,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """
        Record telemetry and update the candidate record using an open session.

        If the user or API key row cannot be found, only the Usage row's
        status fields are updated and the candidate record is left untouched.

        Args:
            db: Open session; the caller is responsible for closing it.
            ctx: Streaming context.
            original_headers: Original request headers.
            original_request_body: Fallback request body used when
                ``ctx.provider_request_body`` is falsy.
            response_time_ms: Response time already measured by the caller.
        """
        user = db.query(User).filter(User.id == self.user_id).first()
        api_key_obj = db.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
        if not user or not api_key_obj:
            # MessageTelemetry needs both objects; fall back to a bare status update.
            logger.warning(
                f"[{self.request_id}] User or ApiKey not found, updating status directly"
            )
            await self._update_usage_status_directly(
                db,
                status="completed" if ctx.is_success() else "failed",
                response_time_ms=response_time_ms,
                status_code=ctx.status_code,
            )
            return

        telemetry = MessageTelemetry(
            db, user, api_key_obj, self.request_id, self.client_ip
        )
        # Prefer the body actually sent to the provider over the client's original.
        actual_request_body = ctx.provider_request_body or original_request_body
        response_body = ctx.build_response_body(response_time_ms)

        if ctx.is_success():
            await self._record_success(
                telemetry,
                ctx,
                original_headers,
                actual_request_body,
                response_body,
                response_time_ms,
            )
        else:
            await self._record_failure(
                telemetry,
                ctx,
                original_headers,
                actual_request_body,
                response_body,
                response_time_ms,
            )

        # Update the candidate record status.
        await self._update_candidate_status(db, ctx, response_time_ms)

    def _close_db(self, db_gen: Any, db: Session) -> None:
        """
        Close a session obtained via ``next(get_db())`` and finalize its generator.

        Closing the generator explicitly makes any cleanup inside ``get_db()``
        (not visible here) run now, rather than whenever the generator object
        happens to be garbage-collected. A failure while finalizing the
        generator is logged, not raised, so it cannot mask the real outcome.
        """
        try:
            db.close()
        finally:
            try:
                db_gen.close()
            except Exception as close_err:
                logger.warning(
                    f"[{self.request_id}] Failed to finalize DB session generator: {close_err}"
                )

    async def _record_success(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a successful streaming request and log its summary."""
        await telemetry.record_success(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            response_time_ms=response_time_ms,
            first_byte_time_ms=ctx.first_byte_time_ms,  # pass through first-byte time
            status_code=ctx.status_code,
            request_headers=original_headers,
            request_body=actual_request_body,
            response_headers=ctx.response_headers,
            response_body=response_body,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            is_stream=True,
            provider_request_headers=ctx.provider_request_headers,
            api_format=ctx.api_format,
            provider_id=ctx.provider_id,
            provider_endpoint_id=ctx.endpoint_id,
            provider_api_key_id=ctx.key_id,
            target_model=ctx.mapped_model,
        )
        logger.debug(f"{self.format_id} 流式响应完成")
        logger.info(ctx.get_log_summary(self.request_id, response_time_ms))

    async def _record_failure(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a failed/interrupted streaming request and log its summary."""
        await telemetry.record_failure(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            response_time_ms=response_time_ms,
            status_code=ctx.status_code,
            error_message=ctx.error_message or f"HTTP {ctx.status_code}",
            request_headers=original_headers,
            request_body=actual_request_body,
            is_stream=True,
            api_format=ctx.api_format,
            provider_request_headers=ctx.provider_request_headers,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            response_body=response_body,
            target_model=ctx.mapped_model,
        )
        logger.debug(f"{self.format_id} 流式响应中断")
        log_summary = ctx.get_log_summary(self.request_id, response_time_ms)
        # For failure logs, append the cached-token count.
        logger.info(f"{log_summary} cache:{ctx.cached_tokens}")

    async def _update_candidate_status(
        self,
        db: Session,
        ctx: StreamContext,
        response_time_ms: int,
    ) -> None:
        """
        Mark the candidate record for ``ctx.attempt_id`` as succeeded or failed.

        Does nothing when ``ctx.attempt_id`` is falsy. A status code of 499 is
        recorded with error type ``client_disconnected``; any other failure
        uses ``stream_error``.
        """
        if not ctx.attempt_id:
            return
        # Deferred import (presumably to avoid an import cycle — confirm).
        from src.services.request.candidate import RequestCandidateService

        if ctx.is_success():
            RequestCandidateService.mark_candidate_success(
                db=db,
                candidate_id=ctx.attempt_id,
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": True,
                    "data_count": ctx.data_count,
                },
            )
        else:
            error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
            RequestCandidateService.mark_candidate_failed(
                db=db,
                candidate_id=ctx.attempt_id,
                error_type=error_type,
                error_message=ctx.error_message or f"HTTP {ctx.status_code}",
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": False,
                    "data_count": ctx.data_count,
                },
            )

    async def _update_usage_status_on_error(
        self,
        response_time_ms: int,
        error_message: str,
    ) -> None:
        """
        Mark the Usage row as failed (status code 500) using a fresh session.

        Used when telemetry recording cannot proceed normally. Any ``Exception``
        (including failure to obtain a session) is logged, not raised.
        """
        try:
            db_gen = get_db()
            error_db = next(db_gen)
            try:
                await self._update_usage_status_directly(
                    error_db,
                    status="failed",
                    response_time_ms=response_time_ms,
                    status_code=500,
                    error_message=error_message,
                )
            finally:
                self._close_db(db_gen, error_db)
        except Exception as inner_e:
            logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")

    async def _update_usage_status_directly(
        self,
        db: Session,
        status: str,
        response_time_ms: int,
        status_code: int = 200,
        error_message: Optional[str] = None,
    ) -> None:
        """
        Directly update the status fields of this request's Usage row.

        Does nothing if no Usage row exists for ``self.request_id``. Errors are
        logged rather than raised; on failure the session's transaction is
        rolled back so the session is not left in a failed state.

        Args:
            db: Open session used for the query and commit.
            status: Value written to ``Usage.status``.
            response_time_ms: Value written to ``Usage.response_time_ms``.
            status_code: Value written to ``Usage.status_code``.
            error_message: Written to ``Usage.error_message`` only if truthy.
        """
        try:
            from src.models.database import Usage

            usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
            if usage:
                setattr(usage, "status", status)
                setattr(usage, "status_code", status_code)
                setattr(usage, "response_time_ms", response_time_ms)
                if error_message:
                    setattr(usage, "error_message", error_message)
                db.commit()
                logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
        except Exception as e:
            logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
            # Discard the failed transaction so the session stays usable/closable.
            try:
                db.rollback()
            except Exception as rollback_err:
                logger.debug(
                    f"[{self.request_id}] Usage rollback failed: {rollback_err}"
                )