feat: add TTFB timeout detection and improve stream handling

- Add stream first byte timeout (TTFB) detection to trigger failover
  when provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT)
- Add rate limit fail-open/fail-close strategy configuration
- Improve exception handling in stream prefetch with proper error classification
- Refactor UsageService with shared _prepare_usage_record method
- Add batch deletion for old usage records to avoid long transaction locks
- Update CLI adapters to use proper User-Agent headers for each CLI client
- Add composite indexes migration for usage table query optimization
- Fix streaming status display in frontend to show TTFB during streaming
- Remove sensitive JWT secret logging in auth service
This commit is contained in:
fawney19
2025-12-22 23:44:42 +08:00
parent 4e1aed9976
commit 1d5c378343
14 changed files with 588 additions and 181 deletions

View File

@@ -0,0 +1,65 @@
"""add usage table composite indexes for query optimization
Revision ID: b2c3d4e5f6g7
Revises: a1b2c3d4e5f6
Create Date: 2025-12-20 15:00:00.000000+00:00
"""
from alembic import op
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'b2c3d4e5f6g7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def index_exists(table_name: str, index_name: str) -> bool:
    """Return True if the named index already exists on the given table."""
    inspector = inspect(op.get_bind())
    existing_names = {idx["name"] for idx in inspector.get_indexes(table_name)}
    return index_name in existing_names
def upgrade() -> None:
    """Add composite indexes to the usage table for common query patterns."""
    # (name, columns) in creation order:
    # 1. user_id + created_at        -> per-user usage queries
    # 2. api_key_id + created_at     -> per-API-key usage queries
    # 3. provider + model + created_at -> model statistics queries
    composite_indexes = [
        ("idx_usage_user_created", ["user_id", "created_at"]),
        ("idx_usage_apikey_created", ["api_key_id", "created_at"]),
        ("idx_usage_provider_model_created", ["provider", "model", "created_at"]),
    ]
    for index_name, columns in composite_indexes:
        # Idempotent: skip indexes that already exist.
        if not index_exists("usage", index_name):
            # NOTE(review): postgresql_concurrently requires running outside a
            # transaction (e.g. op.get_context().autocommit_block()) — confirm
            # the alembic env is configured accordingly.
            op.create_index(
                index_name,
                "usage",
                columns,
                postgresql_concurrently=True,
            )
def downgrade() -> None:
    """Drop the composite indexes (reverse of creation order)."""
    for index_name in (
        "idx_usage_provider_model_created",
        "idx_usage_apikey_created",
        "idx_usage_user_created",
    ):
        # Idempotent: only drop indexes that are actually present.
        if index_exists("usage", index_name):
            op.drop_index(index_name, table_name="usage")

View File

@@ -366,14 +366,34 @@
</div>
</TableCell>
<TableCell class="text-right py-4 w-[70px]">
<!-- pending 状态只显示增长的总时间 -->
<div
v-if="record.status === 'pending' || record.status === 'streaming'"
v-if="record.status === 'pending'"
class="flex flex-col items-end text-xs gap-0.5"
>
<span class="text-muted-foreground">-</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<!-- streaming 状态首字固定 + 总时间增长 -->
<div
v-else-if="record.status === 'streaming'"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<!-- 已完成状态首字 + 总耗时 -->
<div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"

View File

@@ -376,6 +376,9 @@ class BaseMessageHandler:
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
注意TTFB首字节时间由 StreamContext.record_first_byte_time() 记录,
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
Args:
request_id: 请求 ID如果不传则使用 self.request_id
"""
@@ -407,6 +410,9 @@ class BaseMessageHandler:
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
注意TTFB首字节时间由 StreamContext.record_first_byte_time() 记录,
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
Args:
ctx: 流式上下文,包含 provider_name 和 mapped_model
"""

View File

@@ -57,8 +57,10 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.config.settings import config
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
class CliMessageHandlerBase(BaseMessageHandler):
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
首次读取时会应用 TTFB首字节超时检测超时则触发故障转移。
Args:
byte_iterator: 字节流迭代器
provider: Provider 对象
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
ProviderTimeoutException: 如果首字节超时TTFB timeout
"""
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
else:
provider_parser = self.parser
async for chunk in byte_iterator:
# 使用共享的 TTFB 超时函数读取首字节
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
buffer += chunk
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks

View File

@@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.config.settings import config
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
@dataclass
@@ -170,6 +172,8 @@ class StreamProcessor:
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
首次读取时会应用 TTFB首字节超时检测超时则触发故障转移。
Args:
byte_iterator: 字节流迭代器
provider: Provider 对象
@@ -182,6 +186,7 @@ class StreamProcessor:
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderTimeoutException: 如果首字节超时TTFB timeout
"""
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
@@ -192,7 +197,19 @@ class StreamProcessor:
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
async for chunk in byte_iterator:
# 使用共享的 TTFB 超时函数读取首字节
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
buffer += chunk
@@ -262,10 +279,21 @@ class StreamProcessor:
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
except (EmbeddedErrorException, ProviderTimeoutException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks

View File

@@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 Claude API 支持的模型列表(带 CLI User-Agent"""
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_claude}
cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await ClaudeChatAdapter.fetch_models(

View File

@@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent"""
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_gemini}
cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await GeminiChatAdapter.fetch_models(

View File

@@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent"""
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_openai}
cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await OpenAIChatAdapter.fetch_models(

View File

@@ -56,10 +56,11 @@ class Config:
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
redis_required_env = os.getenv("REDIS_REQUIRED")
if redis_required_env is None:
self.require_redis = self.environment not in {"development", "test", "testing"}
else:
if redis_required_env is not None:
self.require_redis = redis_required_env.lower() == "true"
else:
# 保持向后兼容:开发环境可选,生产环境必需
self.require_redis = self.environment not in {"development", "test", "testing"}
# CORS配置 - 使用环境变量配置允许的源
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
@@ -133,6 +134,18 @@ class Config:
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# 限流降级策略配置
# RATE_LIMIT_FAIL_OPEN: 当限流服务Redis异常时的行为
#
# True (默认): fail-open - 放行请求(优先可用性)
# 风险Redis 故障期间无法限流,可能被滥用
# 适用API 网关作为关键基础设施,必须保持高可用
#
# False: fail-close - 拒绝所有请求(优先安全性)
# 风险Redis 故障会导致 API 网关不可用
# 适用:有严格速率限制要求的安全敏感场景
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
# HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
@@ -141,19 +154,22 @@ class Config:
# 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
# STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移
# 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态)
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
# 可通过环境变量覆盖默认值
self.internal_user_agent_claude = os.getenv(
"CLAUDE_USER_AGENT", "claude-cli/1.0"
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
self.internal_user_agent_claude_cli = os.getenv(
"CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
)
self.internal_user_agent_openai = os.getenv(
"OPENAI_USER_AGENT", "openai-cli/1.0"
self.internal_user_agent_openai_cli = os.getenv(
"OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
)
self.internal_user_agent_gemini = os.getenv(
"GEMINI_USER_AGENT", "gemini-cli/1.0"
self.internal_user_agent_gemini_cli = os.getenv(
"GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
)
# 验证连接池配置
@@ -177,6 +193,39 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size
def _parse_ttfb_timeout(self) -> float:
    """
    Parse the TTFB timeout setting with error handling and range clamping.

    TTFB (Time To First Byte) detects slow-responding providers; exceeding
    the timeout triggers failover. The value should stay below
    http_write_timeout to avoid racing the write timeout.

    Returns:
        Timeout in seconds, clamped to [10, 120]; defaults to 30.
    """
    default_value = 30.0
    lower_bound = 10.0
    upper_bound = 120.0  # keep below 2x http_write_timeout (default 60s)

    raw = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_value))
    try:
        parsed = float(raw)
    except ValueError:
        # Stash the warning; Config logs it later once the logger is ready.
        self._ttfb_config_warning = (
            f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw}',使用默认值 {default_value}"
        )
        return default_value

    # Clamp out-of-range values instead of rejecting them.
    bounded = min(max(parsed, lower_bound), upper_bound)
    if bounded != parsed:
        self._ttfb_config_warning = (
            f"STREAM_FIRST_BYTE_TIMEOUT={parsed}秒超出范围 [{lower_bound}-{upper_bound}]"
            f"已调整为 {bounded}"
        )
    return bounded
def _validate_pool_config(self) -> None:
"""验证连接池配置是否安全"""
total_per_worker = self.db_pool_size + self.db_max_overflow
@@ -224,6 +273,10 @@ class Config:
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
logger.warning(self._pool_config_warning)
# TTFB 超时配置警告
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
logger.warning(self._ttfb_config_warning)
# 管理员密码检查(必须在环境变量中设置)
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")

View File

@@ -336,10 +336,44 @@ class PluginMiddleware:
)
return result
return None
except Exception as e:
logger.error(f"Rate limit error: {e}")
# 发生错误时允许请求通过
except ConnectionError as e:
# Redis 连接错误:根据配置决定
logger.warning(f"Rate limit connection error: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service unavailable"
)
except TimeoutError as e:
# 超时错误:可能是负载过高,根据配置决定
logger.warning(f"Rate limit timeout: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service timeout"
)
except Exception as e:
logger.error(f"Rate limit error: {type(e).__name__}: {e}")
# 其他异常:根据配置决定
if config.rate_limit_fail_open:
# fail-open: 异常时放行请求(优先可用性)
return None
else:
# fail-close: 异常时拒绝请求(优先安全性)
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=60,
message="Rate limit service error"
)
async def _call_pre_request_plugins(self, request: Request) -> None:
"""调用请求前的插件(当前保留扩展点)"""

View File

@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
if config.environment == "production":
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
config.jwt_secret_key = secrets.token_urlsafe(32)
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
logger.warning("JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发")
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
JWT_SECRET_KEY = config.jwt_secret_key

View File

@@ -3,6 +3,7 @@
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
from src.services.system.config import SystemConfigService
@dataclass
class UsageRecordParams:
    """Parameter bag passed between the internal usage-recording helpers."""

    db: Session
    user: Optional[User]
    api_key: Optional[ApiKey]
    provider: str
    model: str
    input_tokens: int
    output_tokens: int
    cache_creation_input_tokens: int
    cache_read_input_tokens: int
    request_type: str
    api_format: Optional[str]
    is_stream: bool
    response_time_ms: Optional[int]
    first_byte_time_ms: Optional[int]
    status_code: int
    error_message: Optional[str]
    metadata: Optional[Dict[str, Any]]
    request_headers: Optional[Dict[str, Any]]
    request_body: Optional[Any]
    provider_request_headers: Optional[Dict[str, Any]]
    response_headers: Optional[Dict[str, Any]]
    response_body: Optional[Any]
    request_id: str
    provider_id: Optional[str]
    provider_endpoint_id: Optional[str]
    provider_api_key_id: Optional[str]
    status: str
    cache_ttl_minutes: Optional[int]
    use_tiered_pricing: bool
    target_model: Optional[str]

    def __post_init__(self) -> None:
        """Validate critical fields to guarantee record integrity."""
        # Token counters must never be negative.
        for field_name in (
            "input_tokens",
            "output_tokens",
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
        ):
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(f"{field_name} 不能为负数: {value}")
        # Timing fields are optional, but must be non-negative when present.
        for field_name in ("response_time_ms", "first_byte_time_ms"):
            value = getattr(self, field_name)
            if value is not None and value < 0:
                raise ValueError(f"{field_name} 不能为负数: {value}")
        # HTTP status codes are confined to the standard 1xx-5xx range.
        if not (100 <= self.status_code <= 599):
            raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
        # Only the known lifecycle states are accepted.
        valid_statuses = {"pending", "streaming", "completed", "failed"}
        if self.status not in valid_statuses:
            raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
class UsageService:
"""用量统计服务"""
@@ -471,6 +537,97 @@ class UsageService:
cache_ttl_minutes=cache_ttl_minutes,
)
@classmethod
async def _prepare_usage_record(
    cls,
    params: UsageRecordParams,
) -> Tuple[Dict[str, Any], float]:
    """Shared preparation logic for building a usage record.

    Extracts the processing common to record_usage and record_usage_async:
    - fetch the rate multiplier (and free-tier flag)
    - calculate per-component costs
    - assemble the keyword arguments for the Usage model

    Args:
        params: Usage-record parameter dataclass (validated in __post_init__).

    Returns:
        Tuple of (usage_params dict ready for Usage(**...), total_cost).
    """
    # Resolve the billing rate multiplier and whether this provider key
    # is on a free tier.
    actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
        params.db, params.provider_api_key_id, params.provider_id
    )

    # A request counts as failed when the status code is an error or an
    # error message is present; failed requests may be priced differently.
    is_failed_request = params.status_code >= 400 or params.error_message is not None
    (
        input_price, output_price, cache_creation_price, cache_read_price, request_price,
        input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
        request_cost, total_cost, _tier_index
    ) = await cls._calculate_costs(
        db=params.db,
        provider=params.provider,
        model=params.model,
        input_tokens=params.input_tokens,
        output_tokens=params.output_tokens,
        cache_creation_input_tokens=params.cache_creation_input_tokens,
        cache_read_input_tokens=params.cache_read_input_tokens,
        api_format=params.api_format,
        cache_ttl_minutes=params.cache_ttl_minutes,
        use_tiered_pricing=params.use_tiered_pricing,
        is_failed_request=is_failed_request,
    )

    # Assemble the full keyword set for the Usage model constructor.
    usage_params = cls._build_usage_params(
        db=params.db,
        user=params.user,
        api_key=params.api_key,
        provider=params.provider,
        model=params.model,
        input_tokens=params.input_tokens,
        output_tokens=params.output_tokens,
        cache_creation_input_tokens=params.cache_creation_input_tokens,
        cache_read_input_tokens=params.cache_read_input_tokens,
        request_type=params.request_type,
        api_format=params.api_format,
        is_stream=params.is_stream,
        response_time_ms=params.response_time_ms,
        first_byte_time_ms=params.first_byte_time_ms,
        status_code=params.status_code,
        error_message=params.error_message,
        metadata=params.metadata,
        request_headers=params.request_headers,
        request_body=params.request_body,
        provider_request_headers=params.provider_request_headers,
        response_headers=params.response_headers,
        response_body=params.response_body,
        request_id=params.request_id,
        provider_id=params.provider_id,
        provider_endpoint_id=params.provider_endpoint_id,
        provider_api_key_id=params.provider_api_key_id,
        status=params.status,
        target_model=params.target_model,
        input_cost=input_cost,
        output_cost=output_cost,
        cache_creation_cost=cache_creation_cost,
        cache_read_cost=cache_read_cost,
        cache_cost=cache_cost,
        request_cost=request_cost,
        total_cost=total_cost,
        input_price=input_price,
        output_price=output_price,
        cache_creation_price=cache_creation_price,
        cache_read_price=cache_read_price,
        request_price=request_price,
        actual_rate_multiplier=actual_rate_multiplier,
        is_free_tier=is_free_tier,
    )
    return usage_params, total_cost
@classmethod
async def record_usage_async(
cls,
@@ -516,76 +673,25 @@ class UsageService:
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# 获取费率倍数和是否免费套餐
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# 计算成本
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
# 使用共享逻辑准备记录参数
params = UsageRecordParams(
db=db, user=user, api_key=api_key, provider=provider, model=model,
input_tokens=input_tokens, output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# 构建 Usage 参数
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
request_type=request_type, api_format=api_format, is_stream=is_stream,
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
response_headers=response_headers, response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
provider_api_key_id=provider_api_key_id, status=status,
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
usage_params, _ = await cls._prepare_usage_record(params)
# 创建 Usage 记录
usage = Usage(**usage_params)
@@ -660,76 +766,25 @@ class UsageService:
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# 获取费率倍数和是否免费套餐
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# 计算成本
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, _tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
# 使用共享逻辑准备记录参数
params = UsageRecordParams(
db=db, user=user, api_key=api_key, provider=provider, model=model,
input_tokens=input_tokens, output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# 构建 Usage 参数
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
request_type=request_type, api_format=api_format, is_stream=is_stream,
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
response_headers=response_headers, response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
provider_api_key_id=provider_api_key_id, status=status,
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
usage_params, total_cost = await cls._prepare_usage_record(params)
# 检查是否已存在相同 request_id 的记录
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
@@ -751,7 +806,7 @@ class UsageService:
api_key = db.merge(api_key)
# 使用原子更新避免并发竞态条件
from sqlalchemy import func, update
from sqlalchemy import func as sql_func, update
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
@@ -762,7 +817,7 @@ class UsageService:
.values(
used_usd=UserModel.used_usd + total_cost,
total_usd=UserModel.total_usd + total_cost,
updated_at=func.now(),
updated_at=sql_func.now(),
)
)
@@ -776,8 +831,8 @@ class UsageService:
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
last_used_at=sql_func.now(),
updated_at=sql_func.now(),
)
)
else:
@@ -787,8 +842,8 @@ class UsageService:
.values(
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
last_used_at=sql_func.now(),
updated_at=sql_func.now(),
)
)
@@ -1121,19 +1176,48 @@ class UsageService:
]
@staticmethod
def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
"""清理旧的使用记录"""
def cleanup_old_usage_records(
db: Session, days_to_keep: int = 90, batch_size: int = 1000
) -> int:
"""清理旧的使用记录(分批删除避免长事务锁定)
Args:
db: 数据库会话
days_to_keep: 保留天数,默认 90 天
batch_size: 每批删除数量,默认 1000 条
Returns:
删除的总记录数
"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
total_deleted = 0
# 删除旧记录
deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
while True:
# 查询待删除的 ID使用新索引 idx_usage_user_created
batch_ids = (
db.query(Usage.id)
.filter(Usage.created_at < cutoff_date)
.limit(batch_size)
.all()
)
if not batch_ids:
break
# 批量删除
deleted_count = (
db.query(Usage)
.filter(Usage.id.in_([row.id for row in batch_ids]))
.delete(synchronize_session=False)
)
db.commit()
total_deleted += deleted_count
logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")
return deleted
logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")
return total_deleted
# ========== 请求状态追踪方法 ==========
@@ -1219,6 +1303,7 @@ class UsageService:
error_message: Optional[str] = None,
provider: Optional[str] = None,
target_model: Optional[str] = None,
first_byte_time_ms: Optional[int] = None,
) -> Optional[Usage]:
"""
快速更新使用记录状态
@@ -1230,6 +1315,7 @@ class UsageService:
error_message: 错误消息(仅在 failed 状态时使用)
provider: 提供商名称可选streaming 状态时更新)
target_model: 映射后的目标模型名(可选)
first_byte_time_ms: 首字时间/TTFB可选streaming 状态时更新)
Returns:
更新后的 Usage 记录,如果未找到则返回 None
@@ -1247,6 +1333,8 @@ class UsageService:
usage.provider = provider
if target_model:
usage.target_model = target_model
if first_byte_time_ms is not None:
usage.first_byte_time_ms = first_byte_time_ms
db.commit()

View File

@@ -5,6 +5,7 @@
import json
import re
import time
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from sqlalchemy.orm import Session
@@ -457,26 +458,32 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
# 更新状态为 streaming同时更新 provider
if self.request_id:
try:
from src.services.usage.service import UsageService
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
chunk_count = 0
first_chunk_received = False
try:
async for chunk in stream:
chunk_count += 1
# 保存原始字节流(用于错误诊断)
self.raw_chunks.append(chunk)
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
if not first_chunk_received:
first_chunk_received = True
if self.request_id:
try:
# 计算 TTFB使用请求原始开始时间或 track_stream 开始时间)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
# 返回原始块给客户端
yield chunk

View File

@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
# Python 3.10 及以下版本的兼容实现
# 注意:这个简单实现不支持嵌套取消
pass
async def read_first_chunk_with_ttfb_timeout(
byte_iterator: Any,
timeout: float,
request_id: str,
provider_name: str,
) -> tuple[bytes, Any]:
"""
读取流的首字节并应用 TTFB 超时检测
首字节超时Time To First Byte用于检测慢响应的 Provider
超时时触发故障转移到其他可用的 Provider。
Args:
byte_iterator: 异步字节流迭代器
timeout: TTFB 超时时间(秒)
request_id: 请求 ID用于日志
provider_name: Provider 名称(用于日志和异常)
Returns:
(first_chunk, aiter): 首个字节块和异步迭代器
Raises:
ProviderTimeoutException: 如果首字节超时
"""
from src.core.exceptions import ProviderTimeoutException
aiter = byte_iterator.__aiter__()
try:
first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
return first_chunk, aiter
except asyncio.TimeoutError:
# 完整的资源清理:先关闭迭代器,再关闭底层响应
await _cleanup_iterator_resources(aiter, request_id)
logger.warning(
f" [{request_id}] 流首字节超时 (TTFB): "
f"Provider={provider_name}, timeout={timeout}s"
)
raise ProviderTimeoutException(
provider_name=provider_name,
timeout=int(timeout),
)
async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
    """
    Release an async iterator and the underlying resources it may hold.

    Ensures the HTTP connection is properly freed after a TTFB timeout,
    avoiding connection leaks.

    Args:
        aiter: The async iterator to clean up.
        request_id: Request ID (for logging).
    """
    # Close, in order: the iterator itself, then its `_response` attribute
    # (iterators may point at an underlying httpx.Response), then `_stream`
    # (httpx's Response.aiter_bytes() generators may carry one). Each target
    # is resolved lazily so a close cannot invalidate a later lookup; every
    # failure is logged at debug level and never propagated.
    cleanup_plan = (
        (lambda: aiter, "关闭迭代器失败"),
        (lambda: getattr(aiter, "_response", None), "关闭底层响应失败"),
        (lambda: getattr(aiter, "_stream", None), "关闭流对象失败"),
    )
    for resolve_target, failure_label in cleanup_plan:
        target = resolve_target()
        if target is None or not hasattr(target, "aclose"):
            continue
        try:
            await target.aclose()
        except Exception as e:
            logger.debug(f" [{request_id}] {failure_label}: {e}")