mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
feat: add TTFB timeout detection and improve stream handling
- Add stream first-byte timeout (TTFB) detection to trigger failover when a provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT)
- Add rate limit fail-open/fail-close strategy configuration
- Improve exception handling in stream prefetch with proper error classification
- Refactor UsageService with a shared _prepare_usage_record method
- Add batch deletion for old usage records to avoid long transaction locks
- Update CLI adapters to use the proper User-Agent header for each CLI client
- Add a composite-index migration for usage table query optimization
- Fix streaming status display in the frontend to show TTFB during streaming
- Remove sensitive JWT secret logging in the auth service
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
|
||||
@dataclass
class UsageRecordParams:
    """Bundle of arguments describing a single usage record.

    Used to hand the data between internal methods as one object instead of
    a long keyword-argument list. Every field is required (none has a
    default) and the key numeric/status fields are validated in
    ``__post_init__``.
    """

    # --- request context ---
    db: Session
    user: Optional[User]
    api_key: Optional[ApiKey]
    provider: str
    model: str

    # --- token counts ---
    input_tokens: int
    output_tokens: int
    cache_creation_input_tokens: int
    cache_read_input_tokens: int

    # --- request / response description ---
    request_type: str
    api_format: Optional[str]
    is_stream: bool
    response_time_ms: Optional[int]
    first_byte_time_ms: Optional[int]
    status_code: int
    error_message: Optional[str]
    metadata: Optional[Dict[str, Any]]
    request_headers: Optional[Dict[str, Any]]
    request_body: Optional[Any]
    provider_request_headers: Optional[Dict[str, Any]]
    response_headers: Optional[Dict[str, Any]]
    response_body: Optional[Any]

    # --- identifiers and bookkeeping ---
    request_id: str
    provider_id: Optional[str]
    provider_endpoint_id: Optional[str]
    provider_api_key_id: Optional[str]
    status: str
    cache_ttl_minutes: Optional[int]
    use_tiered_pricing: bool
    target_model: Optional[str]

    def __post_init__(self) -> None:
        """Validate key fields to keep the record internally consistent.

        Raises:
            ValueError: If a token count or a provided timing is negative,
                if ``status_code`` is outside 100-599, or if ``status`` is
                not one of the accepted values.
        """
        # Token counts may never be negative.
        for field_name in (
            "input_tokens",
            "output_tokens",
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
        ):
            token_count = getattr(self, field_name)
            if token_count < 0:
                raise ValueError(f"{field_name} 不能为负数: {token_count}")

        # Timings are optional; when present they may not be negative.
        for field_name in ("response_time_ms", "first_byte_time_ms"):
            elapsed_ms = getattr(self, field_name)
            if elapsed_ms is not None and elapsed_ms < 0:
                raise ValueError(f"{field_name} 不能为负数: {elapsed_ms}")

        # HTTP status code range check.
        code_in_range = 100 <= self.status_code <= 599
        if not code_in_range:
            raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")

        # Status value check.
        valid_statuses = {"pending", "streaming", "completed", "failed"}
        if self.status in valid_statuses:
            return
        raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
|
||||
|
||||
|
||||
class UsageService:
|
||||
"""用量统计服务"""
|
||||
@@ -471,6 +537,97 @@ class UsageService:
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
)
|
||||
|
||||
@classmethod
async def _prepare_usage_record(
    cls,
    params: UsageRecordParams,
) -> Tuple[Dict[str, Any], float]:
    """Run the preparation steps shared by the usage-recording entry points.

    Factored out of ``record_usage`` and ``record_usage_async``. It performs
    three steps in order:

    1. look up the rate multiplier and the free-tier flag,
    2. compute prices and costs,
    3. build the keyword arguments for a ``Usage`` row.

    Args:
        params: The usage record parameters.

    Returns:
        A ``(usage_params, total_cost)`` tuple: the dict produced by
        ``_build_usage_params`` and the total cost from ``_calculate_costs``.
    """
    # Step 1: rate multiplier and free-tier flag.
    rate_multiplier, free_tier = await cls._get_rate_multiplier_and_free_tier(
        params.db, params.provider_api_key_id, params.provider_id
    )

    # Step 2: costs. A request is treated as failed when the status code is
    # 400 or above, or when an error message is present.
    request_failed = params.status_code >= 400 or params.error_message is not None
    cost_breakdown = await cls._calculate_costs(
        db=params.db,
        provider=params.provider,
        model=params.model,
        input_tokens=params.input_tokens,
        output_tokens=params.output_tokens,
        cache_creation_input_tokens=params.cache_creation_input_tokens,
        cache_read_input_tokens=params.cache_read_input_tokens,
        api_format=params.api_format,
        cache_ttl_minutes=params.cache_ttl_minutes,
        use_tiered_pricing=params.use_tiered_pricing,
        is_failed_request=request_failed,
    )
    # The last element (tier index) is unpacked but not used here.
    (
        price_input,
        price_output,
        price_cache_creation,
        price_cache_read,
        price_request,
        cost_input,
        cost_output,
        cost_cache_creation,
        cost_cache_read,
        cost_cache,
        cost_request,
        cost_total,
        _tier_index,
    ) = cost_breakdown

    # Step 3: keyword arguments for the Usage row.
    usage_kwargs = cls._build_usage_params(
        # Fields forwarded unchanged from the params object.
        db=params.db,
        user=params.user,
        api_key=params.api_key,
        provider=params.provider,
        model=params.model,
        input_tokens=params.input_tokens,
        output_tokens=params.output_tokens,
        cache_creation_input_tokens=params.cache_creation_input_tokens,
        cache_read_input_tokens=params.cache_read_input_tokens,
        request_type=params.request_type,
        api_format=params.api_format,
        is_stream=params.is_stream,
        response_time_ms=params.response_time_ms,
        first_byte_time_ms=params.first_byte_time_ms,
        status_code=params.status_code,
        error_message=params.error_message,
        metadata=params.metadata,
        request_headers=params.request_headers,
        request_body=params.request_body,
        provider_request_headers=params.provider_request_headers,
        response_headers=params.response_headers,
        response_body=params.response_body,
        request_id=params.request_id,
        provider_id=params.provider_id,
        provider_endpoint_id=params.provider_endpoint_id,
        provider_api_key_id=params.provider_api_key_id,
        status=params.status,
        target_model=params.target_model,
        # Values computed in steps 1 and 2.
        input_cost=cost_input,
        output_cost=cost_output,
        cache_creation_cost=cost_cache_creation,
        cache_read_cost=cost_cache_read,
        cache_cost=cost_cache,
        request_cost=cost_request,
        total_cost=cost_total,
        input_price=price_input,
        output_price=price_output,
        cache_creation_price=price_cache_creation,
        cache_read_price=price_cache_read,
        request_price=price_request,
        actual_rate_multiplier=rate_multiplier,
        is_free_tier=free_tier,
    )

    return usage_kwargs, cost_total
|
||||
|
||||
@classmethod
|
||||
async def record_usage_async(
|
||||
cls,
|
||||
@@ -516,76 +673,25 @@ class UsageService:
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 获取费率倍数和是否免费套餐
|
||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
||||
db, provider_api_key_id, provider_id
|
||||
)
|
||||
|
||||
# 计算成本
|
||||
is_failed_request = status_code >= 400 or error_message is not None
|
||||
(
|
||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
||||
request_cost, total_cost, tier_index
|
||||
) = await cls._calculate_costs(
|
||||
db=db,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
# 使用共享逻辑准备记录参数
|
||||
params = UsageRecordParams(
|
||||
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
api_format=api_format,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
use_tiered_pricing=use_tiered_pricing,
|
||||
is_failed_request=is_failed_request,
|
||||
)
|
||||
|
||||
# 构建 Usage 参数
|
||||
usage_params = cls._build_usage_params(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
request_type=request_type,
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
metadata=metadata,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||
request_headers=request_headers, request_body=request_body,
|
||||
provider_request_headers=provider_request_headers,
|
||||
response_headers=response_headers,
|
||||
response_body=response_body,
|
||||
request_id=request_id,
|
||||
provider_id=provider_id,
|
||||
response_headers=response_headers, response_body=response_body,
|
||||
request_id=request_id, provider_id=provider_id,
|
||||
provider_endpoint_id=provider_endpoint_id,
|
||||
provider_api_key_id=provider_api_key_id,
|
||||
status=status,
|
||||
provider_api_key_id=provider_api_key_id, status=status,
|
||||
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||
target_model=target_model,
|
||||
input_cost=input_cost,
|
||||
output_cost=output_cost,
|
||||
cache_creation_cost=cache_creation_cost,
|
||||
cache_read_cost=cache_read_cost,
|
||||
cache_cost=cache_cost,
|
||||
request_cost=request_cost,
|
||||
total_cost=total_cost,
|
||||
input_price=input_price,
|
||||
output_price=output_price,
|
||||
cache_creation_price=cache_creation_price,
|
||||
cache_read_price=cache_read_price,
|
||||
request_price=request_price,
|
||||
actual_rate_multiplier=actual_rate_multiplier,
|
||||
is_free_tier=is_free_tier,
|
||||
)
|
||||
usage_params, _ = await cls._prepare_usage_record(params)
|
||||
|
||||
# 创建 Usage 记录
|
||||
usage = Usage(**usage_params)
|
||||
@@ -660,76 +766,25 @@ class UsageService:
|
||||
if request_id is None:
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 获取费率倍数和是否免费套餐
|
||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
||||
db, provider_api_key_id, provider_id
|
||||
)
|
||||
|
||||
# 计算成本
|
||||
is_failed_request = status_code >= 400 or error_message is not None
|
||||
(
|
||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
||||
request_cost, total_cost, _tier_index
|
||||
) = await cls._calculate_costs(
|
||||
db=db,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
# 使用共享逻辑准备记录参数
|
||||
params = UsageRecordParams(
|
||||
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
api_format=api_format,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
use_tiered_pricing=use_tiered_pricing,
|
||||
is_failed_request=is_failed_request,
|
||||
)
|
||||
|
||||
# 构建 Usage 参数
|
||||
usage_params = cls._build_usage_params(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
request_type=request_type,
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
metadata=metadata,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||
request_headers=request_headers, request_body=request_body,
|
||||
provider_request_headers=provider_request_headers,
|
||||
response_headers=response_headers,
|
||||
response_body=response_body,
|
||||
request_id=request_id,
|
||||
provider_id=provider_id,
|
||||
response_headers=response_headers, response_body=response_body,
|
||||
request_id=request_id, provider_id=provider_id,
|
||||
provider_endpoint_id=provider_endpoint_id,
|
||||
provider_api_key_id=provider_api_key_id,
|
||||
status=status,
|
||||
provider_api_key_id=provider_api_key_id, status=status,
|
||||
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||
target_model=target_model,
|
||||
input_cost=input_cost,
|
||||
output_cost=output_cost,
|
||||
cache_creation_cost=cache_creation_cost,
|
||||
cache_read_cost=cache_read_cost,
|
||||
cache_cost=cache_cost,
|
||||
request_cost=request_cost,
|
||||
total_cost=total_cost,
|
||||
input_price=input_price,
|
||||
output_price=output_price,
|
||||
cache_creation_price=cache_creation_price,
|
||||
cache_read_price=cache_read_price,
|
||||
request_price=request_price,
|
||||
actual_rate_multiplier=actual_rate_multiplier,
|
||||
is_free_tier=is_free_tier,
|
||||
)
|
||||
usage_params, total_cost = await cls._prepare_usage_record(params)
|
||||
|
||||
# 检查是否已存在相同 request_id 的记录
|
||||
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
|
||||
@@ -751,7 +806,7 @@ class UsageService:
|
||||
api_key = db.merge(api_key)
|
||||
|
||||
# 使用原子更新避免并发竞态条件
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy import func as sql_func, update
|
||||
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
|
||||
|
||||
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
|
||||
@@ -762,7 +817,7 @@ class UsageService:
|
||||
.values(
|
||||
used_usd=UserModel.used_usd + total_cost,
|
||||
total_usd=UserModel.total_usd + total_cost,
|
||||
updated_at=func.now(),
|
||||
updated_at=sql_func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -776,8 +831,8 @@ class UsageService:
|
||||
total_requests=ApiKeyModel.total_requests + 1,
|
||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
|
||||
last_used_at=func.now(),
|
||||
updated_at=func.now(),
|
||||
last_used_at=sql_func.now(),
|
||||
updated_at=sql_func.now(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -787,8 +842,8 @@ class UsageService:
|
||||
.values(
|
||||
total_requests=ApiKeyModel.total_requests + 1,
|
||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||
last_used_at=func.now(),
|
||||
updated_at=func.now(),
|
||||
last_used_at=sql_func.now(),
|
||||
updated_at=sql_func.now(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1121,19 +1176,48 @@ class UsageService:
|
||||
]
|
||||
|
||||
@staticmethod
def cleanup_old_usage_records(
    db: Session, days_to_keep: int = 90, batch_size: int = 1000
) -> int:
    """Delete old usage records in batches.

    Rows whose ``created_at`` is earlier than ``now(UTC) - days_to_keep`` are
    removed. Each batch is committed separately so that no single
    transaction has to cover the whole cleanup.

    Args:
        db: Database session.
        days_to_keep: Number of days of records to keep (default 90). Must
            not be negative; a negative value would move the cutoff into
            the future and delete every record.
        batch_size: Maximum number of rows deleted per batch (default
            1000). Must be at least 1; with 0 the id query returns no rows
            and nothing would ever be deleted.

    Returns:
        Total number of deleted records.

    Raises:
        ValueError: If ``days_to_keep`` is negative or ``batch_size`` is
            smaller than 1.
    """
    # Reject bad arguments before touching the database.
    if days_to_keep < 0:
        raise ValueError(f"days_to_keep 不能为负数: {days_to_keep}")
    if batch_size < 1:
        raise ValueError(f"batch_size 必须大于 0: {batch_size}")

    cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
    total_deleted = 0

    while True:
        # Fetch the ids of the next batch of expired rows.
        # NOTE(review): the original comment says this query uses the new
        # index idx_usage_user_created - confirm that index actually serves
        # a filter on created_at alone.
        batch_ids = (
            db.query(Usage.id)
            .filter(Usage.created_at < cutoff_date)
            .limit(batch_size)
            .all()
        )

        if not batch_ids:
            break

        # Bulk-delete by primary key. synchronize_session=False skips
        # reconciling in-session objects with the deleted rows.
        deleted_count = (
            db.query(Usage)
            .filter(Usage.id.in_([row.id for row in batch_ids]))
            .delete(synchronize_session=False)
        )
        # Commit per batch.
        db.commit()
        total_deleted += deleted_count

        logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")

    logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")

    return total_deleted
|
||||
|
||||
# ========== 请求状态追踪方法 ==========
|
||||
|
||||
@@ -1219,6 +1303,7 @@ class UsageService:
|
||||
error_message: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
target_model: Optional[str] = None,
|
||||
first_byte_time_ms: Optional[int] = None,
|
||||
) -> Optional[Usage]:
|
||||
"""
|
||||
快速更新使用记录状态
|
||||
@@ -1230,6 +1315,7 @@ class UsageService:
|
||||
error_message: 错误消息(仅在 failed 状态时使用)
|
||||
provider: 提供商名称(可选,streaming 状态时更新)
|
||||
target_model: 映射后的目标模型名(可选)
|
||||
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
||||
|
||||
Returns:
|
||||
更新后的 Usage 记录,如果未找到则返回 None
|
||||
@@ -1247,6 +1333,8 @@ class UsageService:
|
||||
usage.provider = provider
|
||||
if target_model:
|
||||
usage.target_model = target_model
|
||||
if first_byte_time_ms is not None:
|
||||
usage.first_byte_time_ms = first_byte_time_ms
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user