From 1d5c378343b00e50e5243de82d1ee291c5fa5dc4 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 22 Dec 2025 23:44:42 +0800 Subject: [PATCH] feat: add TTFB timeout detection and improve stream handling - Add stream first byte timeout (TTFB) detection to trigger failover when provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT) - Add rate limit fail-open/fail-close strategy configuration - Improve exception handling in stream prefetch with proper error classification - Refactor UsageService with shared _prepare_usage_record method - Add batch deletion for old usage records to avoid long transaction locks - Update CLI adapters to use proper User-Agent headers for each CLI client - Add composite indexes migration for usage table query optimization - Fix streaming status display in frontend to show TTFB during streaming - Remove sensitive JWT secret logging in auth service --- ...251220_1500_add_usage_composite_indexes.py | 65 +++ .../usage/components/UsageRecordsTable.vue | 22 +- src/api/handlers/base/base_handler.py | 6 + src/api/handlers/base/cli_handler_base.py | 36 +- src/api/handlers/base/stream_processor.py | 36 +- src/api/handlers/claude_cli/adapter.py | 2 +- src/api/handlers/gemini_cli/adapter.py | 2 +- src/api/handlers/openai_cli/adapter.py | 2 +- src/config/settings.py | 73 +++- src/middleware/plugin_middleware.py | 40 +- src/services/auth/service.py | 2 +- src/services/usage/service.py | 370 +++++++++++------- src/services/usage/stream.py | 33 +- src/utils/timeout.py | 80 ++++ 14 files changed, 588 insertions(+), 181 deletions(-) create mode 100644 alembic/versions/20251220_1500_add_usage_composite_indexes.py diff --git a/alembic/versions/20251220_1500_add_usage_composite_indexes.py b/alembic/versions/20251220_1500_add_usage_composite_indexes.py new file mode 100644 index 0000000..122da57 --- /dev/null +++ b/alembic/versions/20251220_1500_add_usage_composite_indexes.py @@ -0,0 +1,65 @@ +"""add usage table composite indexes for query optimization + +Revision ID: b2c3d4e5f6g7 +Revises: a1b2c3d4e5f6 +Create Date: 2025-12-20 15:00:00.000000+00:00 + +""" +from alembic import op +from sqlalchemy import inspect + +# revision identifiers, used by Alembic. +revision = 'b2c3d4e5f6g7' +down_revision = 'a1b2c3d4e5f6' +branch_labels = None +depends_on = None + + +def index_exists(table_name: str, index_name: str) -> bool: + """检查索引是否存在""" + bind = op.get_bind() + inspector = inspect(bind) + indexes = [idx['name'] for idx in inspector.get_indexes(table_name)] + return index_name in indexes + + +def upgrade() -> None: + """为 usage 表添加复合索引以优化常见查询""" + # 1. user_id + created_at 复合索引 (用户用量查询) + if not index_exists('usage', 'idx_usage_user_created'): + op.create_index( + 'idx_usage_user_created', + 'usage', + ['user_id', 'created_at'], + postgresql_concurrently=True + ) + + # 2. api_key_id + created_at 复合索引 (API Key 用量查询) + if not index_exists('usage', 'idx_usage_apikey_created'): + op.create_index( + 'idx_usage_apikey_created', + 'usage', + ['api_key_id', 'created_at'], + postgresql_concurrently=True + ) + + # 3. provider + model + created_at 复合索引 (模型统计查询) + if not index_exists('usage', 'idx_usage_provider_model_created'): + op.create_index( + 'idx_usage_provider_model_created', + 'usage', + ['provider', 'model', 'created_at'], + postgresql_concurrently=True + ) + + +def downgrade() -> None: + """删除复合索引""" + if index_exists('usage', 'idx_usage_provider_model_created'): + op.drop_index('idx_usage_provider_model_created', table_name='usage') + + if index_exists('usage', 'idx_usage_apikey_created'): + op.drop_index('idx_usage_apikey_created', table_name='usage') + + if index_exists('usage', 'idx_usage_user_created'): + op.drop_index('idx_usage_user_created', table_name='usage') diff --git a/frontend/src/features/usage/components/UsageRecordsTable.vue b/frontend/src/features/usage/components/UsageRecordsTable.vue index dc28d17..72395a2 100644 --- a/frontend/src/features/usage/components/UsageRecordsTable.vue +++ b/frontend/src/features/usage/components/UsageRecordsTable.vue @@ -366,14 +366,34 @@ +
+ - {{ getElapsedTime(record) }}
+ +
+ {{ (record.first_byte_time_ms / 1000).toFixed(2) }}s + - + + {{ getElapsedTime(record) }} + +
+
= max_prefetch_lines: break - except EmbeddedErrorException: - # 重新抛出嵌套错误 + except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException): + # 重新抛出可重试的 Provider 异常,触发故障转移 raise + except (OSError, IOError) as e: + # 网络 I/O 异常:记录警告,可能需要重试 + logger.warning( + f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}" + ) except Exception as e: - # 其他异常(如网络错误)在预读阶段发生,记录日志但不中断 - logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") + # 未预期的严重异常:记录错误并重新抛出,避免掩盖问题 + logger.error( + f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}", + exc_info=True + ) + raise return prefetched_chunks diff --git a/src/api/handlers/base/stream_processor.py b/src/api/handlers/base/stream_processor.py index 275de5d..7399dcd 100644 --- a/src/api/handlers/base/stream_processor.py +++ b/src/api/handlers/base/stream_processor.py @@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import ( from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.response_parser import ResponseParser from src.api.handlers.base.stream_context import StreamContext -from src.core.exceptions import EmbeddedErrorException +from src.config.settings import config +from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException from src.core.logger import logger from src.models.database import Provider, ProviderEndpoint from src.utils.sse_parser import SSEEventParser +from src.utils.timeout import read_first_chunk_with_ttfb_timeout @dataclass @@ -170,6 +172,8 @@ class StreamProcessor: 某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。 这种情况需要在流开始输出之前检测,以便触发重试逻辑。 + 首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。 + Args: byte_iterator: 字节流迭代器 provider: Provider 对象 @@ -182,6 +186,7 @@ class StreamProcessor: Raises: EmbeddedErrorException: 如果检测到嵌套错误 + ProviderTimeoutException: 如果首字节超时(TTFB timeout) """ prefetched_chunks: list = [] parser = self.get_parser_for_provider(ctx) @@ -192,7 +197,19 @@ class StreamProcessor: decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") try: - async for chunk in byte_iterator: + # 使用共享的 TTFB 超时函数读取首字节 + ttfb_timeout = config.stream_first_byte_timeout + first_chunk, aiter = await read_first_chunk_with_ttfb_timeout( + byte_iterator, + timeout=ttfb_timeout, + request_id=self.request_id, + provider_name=str(provider.name), + ) + prefetched_chunks.append(first_chunk) + buffer += first_chunk + + # 继续读取剩余的预读数据 + async for chunk in aiter: prefetched_chunks.append(chunk) buffer += chunk @@ -262,10 +279,21 @@ class StreamProcessor: if should_stop or line_count >= max_prefetch_lines: break - except EmbeddedErrorException: + except (EmbeddedErrorException, ProviderTimeoutException): + # 重新抛出可重试的 Provider 异常,触发故障转移 raise + except (OSError, IOError) as e: + # 网络 I/O ���常:记录警告,可能需要重试 + logger.warning( + f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}" + ) except Exception as e: - logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") + # 未预期的严重异常:记录错误并重新抛出,避免掩盖问题 + logger.error( + f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}", + exc_info=True + ) + raise return prefetched_chunks diff --git a/src/api/handlers/claude_cli/adapter.py b/src/api/handlers/claude_cli/adapter.py index 386bd6e..ca16eb1 100644 --- a/src/api/handlers/claude_cli/adapter.py +++ b/src/api/handlers/claude_cli/adapter.py @@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase): ) -> Tuple[list, Optional[str]]: """查询 Claude API 支持的模型列表(带 CLI User-Agent)""" # 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent - cli_headers = {"User-Agent": config.internal_user_agent_claude} + cli_headers = {"User-Agent": config.internal_user_agent_claude_cli} if extra_headers: cli_headers.update(extra_headers) models, error = await ClaudeChatAdapter.fetch_models( diff --git a/src/api/handlers/gemini_cli/adapter.py b/src/api/handlers/gemini_cli/adapter.py index 19ab657..e416680 100644 --- a/src/api/handlers/gemini_cli/adapter.py +++ b/src/api/handlers/gemini_cli/adapter.py @@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase): ) -> Tuple[list, Optional[str]]: """查询 Gemini API 支持的模型列表(带 CLI User-Agent)""" # 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent - cli_headers = {"User-Agent": config.internal_user_agent_gemini} + cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli} if extra_headers: cli_headers.update(extra_headers) models, error = await GeminiChatAdapter.fetch_models( diff --git a/src/api/handlers/openai_cli/adapter.py b/src/api/handlers/openai_cli/adapter.py index a72469a..80f458e 100644 --- a/src/api/handlers/openai_cli/adapter.py +++ b/src/api/handlers/openai_cli/adapter.py @@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase): ) -> Tuple[list, Optional[str]]: """查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)""" # 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent - cli_headers = {"User-Agent": config.internal_user_agent_openai} + cli_headers = {"User-Agent": config.internal_user_agent_openai_cli} if extra_headers: cli_headers.update(extra_headers) models, error = await OpenAIChatAdapter.fetch_models( diff --git a/src/config/settings.py b/src/config/settings.py index bec4dcb..44acf33 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -56,10 +56,11 @@ class Config: # Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖) redis_required_env = os.getenv("REDIS_REQUIRED") - if redis_required_env is None: - self.require_redis = self.environment not in {"development", "test", "testing"} - else: + if redis_required_env is not None: self.require_redis = redis_required_env.lower() == "true" + else: + # 保持向后兼容:开发环境可选,生产环境必需 + self.require_redis = self.environment not in {"development", "test", "testing"} # CORS配置 - 使用环境变量配置允许的源 # 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com" @@ -133,6 +134,18 @@ class Config: self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600")) self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1")) + # 限流降级策略配置 + # RATE_LIMIT_FAIL_OPEN: 当限流服务(Redis)异常时的行为 + # + # True (默认): fail-open - 放行请求(优先可用性) + # 风险:Redis 故障期间无法限流,可能被滥用 + # 适用:API 网关作为关键基础设施,必须保持高可用 + # + # False: fail-close - 拒绝所有请求(优先安全性) + # 风险:Redis 故障会导致 API 网关不可用 + # 适用:有严格速率限制要求的安全敏感场景 + self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true" + # HTTP 请求超时配置(秒) self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) @@ -141,19 +154,22 @@ class Config: # 流式处理配置 # STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误 # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 + # STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移 + # 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态) self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5")) self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1")) + self.stream_first_byte_timeout = self._parse_ttfb_timeout() # 内部请求 User-Agent 配置(用于查询上游模型列表等) - # 可通过环境变量覆盖默认值 - self.internal_user_agent_claude = os.getenv( - "CLAUDE_USER_AGENT", "claude-cli/1.0" + # 可通过环境变量覆盖默认值,模拟对应 CLI 客户端 + self.internal_user_agent_claude_cli = os.getenv( + "CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1" ) - self.internal_user_agent_openai = os.getenv( - "OPENAI_USER_AGENT", "openai-cli/1.0" + self.internal_user_agent_openai_cli = os.getenv( + "OPENAI_CLI_USER_AGENT", "openai-codex/1.0" ) - self.internal_user_agent_gemini = os.getenv( - "GEMINI_USER_AGENT", "gemini-cli/1.0" + self.internal_user_agent_gemini_cli = os.getenv( + "GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0" ) # 验证连接池配置 @@ -177,6 +193,39 @@ class Config: """智能计算最大溢出连接数 - 与 pool_size 相同""" return self.db_pool_size + def _parse_ttfb_timeout(self) -> float: + """ + 解析 TTFB 超时配置,带错误处理和范围限制 + + TTFB (Time To First Byte) 用于检测慢响应的 Provider,超时触发故障转移。 + 此值必须小于 http_write_timeout,避免竞态条件。 + + Returns: + 超时时间(秒),范围 10-120,默认 30 + """ + default_timeout = 30.0 + min_timeout = 10.0 + max_timeout = 120.0 # 必须小于 http_write_timeout (默认 60s) 的 2 倍 + + raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout)) + try: + timeout = float(raw_value) + except ValueError: + # 延迟导入,避免循环依赖(Config 初始化时 logger 可能未就绪) + self._ttfb_config_warning = ( + f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}秒" + ) + return default_timeout + + # 范围限制 + clamped = max(min_timeout, min(max_timeout, timeout)) + if clamped != timeout: + self._ttfb_config_warning = ( + f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}]," + f"已调整为 {clamped}秒" + ) + return clamped + def _validate_pool_config(self) -> None: """验证连接池配置是否安全""" total_per_worker = self.db_pool_size + self.db_max_overflow @@ -224,6 +273,10 @@ class Config: if hasattr(self, "_pool_config_warning") and self._pool_config_warning: logger.warning(self._pool_config_warning) + # TTFB 超时配置警告 + if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning: + logger.warning(self._ttfb_config_warning) + # 管理员密码检查(必须在环境变量中设置) if hasattr(self, "_missing_admin_password") and self._missing_admin_password: logger.error("必须设置 ADMIN_PASSWORD 环境变量!") diff --git a/src/middleware/plugin_middleware.py b/src/middleware/plugin_middleware.py index 7fcc33c..6ca59c5 100644 --- a/src/middleware/plugin_middleware.py +++ b/src/middleware/plugin_middleware.py @@ -336,10 +336,44 @@ class PluginMiddleware: ) return result return None + except ConnectionError as e: + # Redis 连接错误:根据配置决定 + logger.warning(f"Rate limit connection error: {e}") + if config.rate_limit_fail_open: + return None + else: + return RateLimitResult( + allowed=False, + remaining=0, + retry_after=30, + message="Rate limit service unavailable" + ) + except TimeoutError as e: + # 超时错误:可能是负载过高,根据配置决定 + logger.warning(f"Rate limit timeout: {e}") + if config.rate_limit_fail_open: + return None + else: + return RateLimitResult( + allowed=False, + remaining=0, + retry_after=30, + message="Rate limit service timeout" + ) except Exception as e: - logger.error(f"Rate limit error: {e}") - # 发生错误时允许请求通过 - return None + logger.error(f"Rate limit error: {type(e).__name__}: {e}") + # 其他异常:根据配置决定 + if config.rate_limit_fail_open: + # fail-open: 异常时放行请求(优先可用性) + return None + else: + # fail-close: 异常时拒绝请求(优先安全性) + return RateLimitResult( + allowed=False, + remaining=0, + retry_after=60, + message="Rate limit service error" + ) async def _call_pre_request_plugins(self, request: Request) -> None: """调用请求前的插件(当前保留扩展点)""" diff --git a/src/services/auth/service.py b/src/services/auth/service.py index 9eb04a0..ecdf57b 100644 --- a/src/services/auth/service.py +++ b/src/services/auth/service.py @@ -27,7 +27,7 @@ if not config.jwt_secret_key: if config.environment == "production": raise ValueError("JWT_SECRET_KEY must be set in production environment!") config.jwt_secret_key = secrets.token_urlsafe(32) - logger.warning(f"JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...") + logger.warning("JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发") logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!") JWT_SECRET_KEY = config.jwt_secret_key diff --git a/src/services/usage/service.py b/src/services/usage/service.py index 4f62e5c..9bb01c7 100644 --- a/src/services/usage/service.py +++ b/src/services/usage/service.py @@ -3,6 +3,7 @@ """ import uuid +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Tuple @@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService from src.services.system.config import SystemConfigService +@dataclass +class UsageRecordParams: + """用量记录参数数据类,用于在内部方法间传递数据""" + db: Session + user: Optional[User] + api_key: Optional[ApiKey] + provider: str + model: str + input_tokens: int + output_tokens: int + cache_creation_input_tokens: int + cache_read_input_tokens: int + request_type: str + api_format: Optional[str] + is_stream: bool + response_time_ms: Optional[int] + first_byte_time_ms: Optional[int] + status_code: int + error_message: Optional[str] + metadata: Optional[Dict[str, Any]] + request_headers: Optional[Dict[str, Any]] + request_body: Optional[Any] + provider_request_headers: Optional[Dict[str, Any]] + response_headers: Optional[Dict[str, Any]] + response_body: Optional[Any] + request_id: str + provider_id: Optional[str] + provider_endpoint_id: Optional[str] + provider_api_key_id: Optional[str] + status: str + cache_ttl_minutes: Optional[int] + use_tiered_pricing: bool + target_model: Optional[str] + + def __post_init__(self) -> None: + """验证关键字段,确保数据完整性""" + # Token 数量不能为负数 + if self.input_tokens < 0: + raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}") + if self.output_tokens < 0: + raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}") + if self.cache_creation_input_tokens < 0: + raise ValueError( + f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}" + ) + if self.cache_read_input_tokens < 0: + raise ValueError( + f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}" + ) + + # 响应时间不能为负数 + if self.response_time_ms is not None and self.response_time_ms < 0: + raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}") + if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0: + raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}") + + # HTTP 状态码范围校验 + if not (100 <= self.status_code <= 599): + raise ValueError(f"无效的 HTTP 状态码: {self.status_code}") + + # 状态值校验 + valid_statuses = {"pending", "streaming", "completed", "failed"} + if self.status not in valid_statuses: + raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}") + class UsageService: """用量统计服务""" @@ -471,6 +537,97 @@ class UsageService: cache_ttl_minutes=cache_ttl_minutes, ) + @classmethod + async def _prepare_usage_record( + cls, + params: UsageRecordParams, + ) -> Tuple[Dict[str, Any], float]: + """准备用量记录的共享逻辑 + + 此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑: + - 获取费率倍数 + - 计算成本 + - 构建 Usage 参数 + + Args: + params: 用量记录参数数据类 + + Returns: + (usage_params 字典, total_cost 总成本) + """ + # 获取费率倍数和是否免费套餐 + actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier( + params.db, params.provider_api_key_id, params.provider_id + ) + + # 计算成本 + is_failed_request = params.status_code >= 400 or params.error_message is not None + ( + input_price, output_price, cache_creation_price, cache_read_price, request_price, + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, + request_cost, total_cost, _tier_index + ) = await cls._calculate_costs( + db=params.db, + provider=params.provider, + model=params.model, + input_tokens=params.input_tokens, + output_tokens=params.output_tokens, + cache_creation_input_tokens=params.cache_creation_input_tokens, + cache_read_input_tokens=params.cache_read_input_tokens, + api_format=params.api_format, + cache_ttl_minutes=params.cache_ttl_minutes, + use_tiered_pricing=params.use_tiered_pricing, + is_failed_request=is_failed_request, + ) + + # 构建 Usage 参数 + usage_params = cls._build_usage_params( + db=params.db, + user=params.user, + api_key=params.api_key, + provider=params.provider, + model=params.model, + input_tokens=params.input_tokens, + output_tokens=params.output_tokens, + cache_creation_input_tokens=params.cache_creation_input_tokens, + cache_read_input_tokens=params.cache_read_input_tokens, + request_type=params.request_type, + api_format=params.api_format, + is_stream=params.is_stream, + response_time_ms=params.response_time_ms, + first_byte_time_ms=params.first_byte_time_ms, + status_code=params.status_code, + error_message=params.error_message, + metadata=params.metadata, + request_headers=params.request_headers, + request_body=params.request_body, + provider_request_headers=params.provider_request_headers, + response_headers=params.response_headers, + response_body=params.response_body, + request_id=params.request_id, + provider_id=params.provider_id, + provider_endpoint_id=params.provider_endpoint_id, + provider_api_key_id=params.provider_api_key_id, + status=params.status, + target_model=params.target_model, + input_cost=input_cost, + output_cost=output_cost, + cache_creation_cost=cache_creation_cost, + cache_read_cost=cache_read_cost, + cache_cost=cache_cost, + request_cost=request_cost, + total_cost=total_cost, + input_price=input_price, + output_price=output_price, + cache_creation_price=cache_creation_price, + cache_read_price=cache_read_price, + request_price=request_price, + actual_rate_multiplier=actual_rate_multiplier, + is_free_tier=is_free_tier, + ) + + return usage_params, total_cost + @classmethod async def record_usage_async( cls, @@ -516,76 +673,25 @@ class UsageService: if request_id is None: request_id = str(uuid.uuid4())[:8] - # 获取费率倍数和是否免费套餐 - actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier( - db, provider_api_key_id, provider_id - ) - - # 计算成本 - is_failed_request = status_code >= 400 or error_message is not None - ( - input_price, output_price, cache_creation_price, cache_read_price, request_price, - input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, - request_cost, total_cost, tier_index - ) = await cls._calculate_costs( - db=db, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, + # 使用共享逻辑准备记录参数 + params = UsageRecordParams( + db=db, user=user, api_key=api_key, provider=provider, model=model, + input_tokens=input_tokens, output_tokens=output_tokens, cache_creation_input_tokens=cache_creation_input_tokens, cache_read_input_tokens=cache_read_input_tokens, - api_format=api_format, - cache_ttl_minutes=cache_ttl_minutes, - use_tiered_pricing=use_tiered_pricing, - is_failed_request=is_failed_request, - ) - - # 构建 Usage 参数 - usage_params = cls._build_usage_params( - db=db, - user=user, - api_key=api_key, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - request_type=request_type, - api_format=api_format, - is_stream=is_stream, - response_time_ms=response_time_ms, - first_byte_time_ms=first_byte_time_ms, - status_code=status_code, - error_message=error_message, - metadata=metadata, - request_headers=request_headers, - request_body=request_body, + request_type=request_type, api_format=api_format, is_stream=is_stream, + response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms, + status_code=status_code, error_message=error_message, metadata=metadata, + request_headers=request_headers, request_body=request_body, provider_request_headers=provider_request_headers, - response_headers=response_headers, - response_body=response_body, - request_id=request_id, - provider_id=provider_id, + response_headers=response_headers, response_body=response_body, + request_id=request_id, provider_id=provider_id, provider_endpoint_id=provider_endpoint_id, - provider_api_key_id=provider_api_key_id, - status=status, + provider_api_key_id=provider_api_key_id, status=status, + cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing, target_model=target_model, - input_cost=input_cost, - output_cost=output_cost, - cache_creation_cost=cache_creation_cost, - cache_read_cost=cache_read_cost, - cache_cost=cache_cost, - request_cost=request_cost, - total_cost=total_cost, - input_price=input_price, - output_price=output_price, - cache_creation_price=cache_creation_price, - cache_read_price=cache_read_price, - request_price=request_price, - actual_rate_multiplier=actual_rate_multiplier, - is_free_tier=is_free_tier, ) + usage_params, _ = await cls._prepare_usage_record(params) # 创建 Usage 记录 usage = Usage(**usage_params) @@ -660,76 +766,25 @@ class UsageService: if request_id is None: request_id = str(uuid.uuid4())[:8] - # 获取费率倍数和是否免费套餐 - actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier( - db, provider_api_key_id, provider_id - ) - - # 计算成本 - is_failed_request = status_code >= 400 or error_message is not None - ( - input_price, output_price, cache_creation_price, cache_read_price, request_price, - input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, - request_cost, total_cost, _tier_index - ) = await cls._calculate_costs( - db=db, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, + # 使用共享逻辑准备记录参数 + params = UsageRecordParams( + db=db, user=user, api_key=api_key, provider=provider, model=model, + input_tokens=input_tokens, output_tokens=output_tokens, cache_creation_input_tokens=cache_creation_input_tokens, cache_read_input_tokens=cache_read_input_tokens, - api_format=api_format, - cache_ttl_minutes=cache_ttl_minutes, - use_tiered_pricing=use_tiered_pricing, - is_failed_request=is_failed_request, - ) - - # 构建 Usage 参数 - usage_params = cls._build_usage_params( - db=db, - user=user, - api_key=api_key, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - request_type=request_type, - api_format=api_format, - is_stream=is_stream, - response_time_ms=response_time_ms, - first_byte_time_ms=first_byte_time_ms, - status_code=status_code, - error_message=error_message, - metadata=metadata, - request_headers=request_headers, - request_body=request_body, + request_type=request_type, api_format=api_format, is_stream=is_stream, + response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms, + status_code=status_code, error_message=error_message, metadata=metadata, + request_headers=request_headers, request_body=request_body, provider_request_headers=provider_request_headers, - response_headers=response_headers, - response_body=response_body, - request_id=request_id, - provider_id=provider_id, + response_headers=response_headers, response_body=response_body, + request_id=request_id, provider_id=provider_id, provider_endpoint_id=provider_endpoint_id, - provider_api_key_id=provider_api_key_id, - status=status, + provider_api_key_id=provider_api_key_id, status=status, + cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing, target_model=target_model, - input_cost=input_cost, - output_cost=output_cost, - cache_creation_cost=cache_creation_cost, - cache_read_cost=cache_read_cost, - cache_cost=cache_cost, - request_cost=request_cost, - total_cost=total_cost, - input_price=input_price, - output_price=output_price, - cache_creation_price=cache_creation_price, - cache_read_price=cache_read_price, - request_price=request_price, - actual_rate_multiplier=actual_rate_multiplier, - is_free_tier=is_free_tier, ) + usage_params, total_cost = await cls._prepare_usage_record(params) # 检查是否已存在相同 request_id 的记录 existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first() @@ -751,7 +806,7 @@ class UsageService: api_key = db.merge(api_key) # 使用原子更新避免并发竞态条件 - from sqlalchemy import func, update + from sqlalchemy import func as sql_func, update from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel # 更新用户使用量(独立 Key 不计入创建者的使用记录) @@ -762,7 +817,7 @@ class UsageService: .values( used_usd=UserModel.used_usd + total_cost, total_usd=UserModel.total_usd + total_cost, - updated_at=func.now(), + updated_at=sql_func.now(), ) ) @@ -776,8 +831,8 @@ class UsageService: total_requests=ApiKeyModel.total_requests + 1, total_cost_usd=ApiKeyModel.total_cost_usd + total_cost, balance_used_usd=ApiKeyModel.balance_used_usd + total_cost, - last_used_at=func.now(), - updated_at=func.now(), + last_used_at=sql_func.now(), + updated_at=sql_func.now(), ) ) else: @@ -787,8 +842,8 @@ class UsageService: .values( total_requests=ApiKeyModel.total_requests + 1, total_cost_usd=ApiKeyModel.total_cost_usd + total_cost, - last_used_at=func.now(), - updated_at=func.now(), + last_used_at=sql_func.now(), + updated_at=sql_func.now(), ) ) @@ -1121,19 +1176,48 @@ class UsageService: ] @staticmethod - def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int: - """清理旧的使用记录""" + def cleanup_old_usage_records( + db: Session, days_to_keep: int = 90, batch_size: int = 1000 + ) -> int: + """清理旧的使用记录(分批删除避免长事务锁定) + Args: + db: 数据库会话 + days_to_keep: 保留天数,默认 90 天 + batch_size: 每批删除数量,默认 1000 条 + + Returns: + 删除的总记录数 + """ cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep) + total_deleted = 0 - # 删除旧记录 - deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete() + while True: + # 查询待删除的 ID(使用新索引 idx_usage_user_created) + batch_ids = ( + db.query(Usage.id) + .filter(Usage.created_at < cutoff_date) + .limit(batch_size) + .all() + ) - db.commit() + if not batch_ids: + break - logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录") + # 批量删除 + deleted_count = ( + db.query(Usage) + .filter(Usage.id.in_([row.id for row in batch_ids])) + .delete(synchronize_session=False) + ) + db.commit() + total_deleted += deleted_count - return deleted + logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条") + + logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录") + + return total_deleted # ========== 请求状态追踪方法 ========== @@ -1219,6 +1303,7 @@ class UsageService: error_message: Optional[str] = None, provider: Optional[str] = None, target_model: Optional[str] = None, + first_byte_time_ms: Optional[int] = None, ) -> Optional[Usage]: """ 快速更新使用记录状态 @@ -1230,6 +1315,7 @@ class UsageService: error_message: 错误消息(仅在 failed 状态时使用) provider: 提供商名称(可选,streaming 状态时更新) target_model: 映射后的目标模型名(可选) + first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新) Returns: 更新后的 Usage 记录,如果未找到则返回 None @@ -1247,6 +1333,8 @@ class UsageService: usage.provider = provider if target_model: usage.target_model = target_model + if first_byte_time_ms is not None: + usage.first_byte_time_ms = first_byte_time_ms db.commit() diff --git a/src/services/usage/stream.py b/src/services/usage/stream.py index 68e10ef..753f350 100644 --- a/src/services/usage/stream.py +++ b/src/services/usage/stream.py @@ -5,6 +5,7 @@ import json import re +import time from typing import Any, AsyncIterator, Dict, Optional, Tuple from sqlalchemy.orm import Session @@ -457,26 +458,32 @@ class StreamUsageTracker: logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}") - # 更新状态为 streaming,同时更新 provider - if self.request_id: - try: - from src.services.usage.service import UsageService - UsageService.update_usage_status( - db=self.db, - request_id=self.request_id, - status="streaming", - provider=self.provider, - ) - except Exception as e: - logger.warning(f"更新使用记录状态为 streaming 失败: {e}") - chunk_count = 0 + first_chunk_received = False try: async for chunk in stream: chunk_count += 1 # 保存原始字节流(用于错误诊断) self.raw_chunks.append(chunk) + # 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB + if not first_chunk_received: + first_chunk_received = True + if self.request_id: + try: + # 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间) + base_time = self.request_start_time or self.start_time + first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None + UsageService.update_usage_status( + db=self.db, + request_id=self.request_id, + status="streaming", + provider=self.provider, + first_byte_time_ms=first_byte_time_ms, + ) + except Exception as e: + logger.warning(f"更新使用记录状态为 streaming 失败: {e}") + # 返回原始块给客户端 yield chunk diff --git a/src/utils/timeout.py b/src/utils/timeout.py index 392e73b..0497a3e 100644 --- a/src/utils/timeout.py +++ b/src/utils/timeout.py @@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation" # Python 3.10 及以下版本的兼容实现 # 注意:这个简单实现不支持嵌套取消 pass + + +async def read_first_chunk_with_ttfb_timeout( + byte_iterator: Any, + timeout: float, + request_id: str, + provider_name: str, +) -> tuple[bytes, Any]: + """ + 读取流的首字节并应用 TTFB 超时检测 + + 首字节超时(Time To First Byte)用于检测慢响应的 Provider, + 超时时触发故障转移到其他可用的 Provider。 + + Args: + byte_iterator: 异步字节流迭代器 + timeout: TTFB 超时时间(秒) + request_id: 请求 ID(用于日志) + provider_name: Provider 名称(用于日志和异常) + + Returns: + (first_chunk, aiter): 首个字节块和异步迭代器 + + Raises: + ProviderTimeoutException: 如果首字节超时 + """ + from src.core.exceptions import ProviderTimeoutException + + aiter = byte_iterator.__aiter__() + + try: + first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout) + return first_chunk, aiter + except asyncio.TimeoutError: + # 完整的资源清理:先关闭迭代器,再关闭底层响应 + await _cleanup_iterator_resources(aiter, request_id) + logger.warning( + f" [{request_id}] 流首字节超时 (TTFB): " + f"Provider={provider_name}, timeout={timeout}s" + ) + raise ProviderTimeoutException( + provider_name=provider_name, + timeout=int(timeout), + ) + + +async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None: + """ + 清理异步迭代器及其底层资源 + + 确保在 TTFB 超时后正确释放 HTTP 连接,避免连接泄漏。 + + Args: + aiter: 异步迭代器 + request_id: 请求 ID(用于日志) + """ + # 1. 关闭迭代器本身 + if hasattr(aiter, "aclose"): + try: + await aiter.aclose() + except Exception as e: + logger.debug(f" [{request_id}] 关闭迭代器失败: {e}") + + # 2. 关闭底层响应对象(httpx.Response) + # 迭代器可能持有 _response 属性指向底层响应 + response = getattr(aiter, "_response", None) + if response is not None and hasattr(response, "aclose"): + try: + await response.aclose() + except Exception as e: + logger.debug(f" [{request_id}] 关闭底层响应失败: {e}") + + # 3. 尝试关闭 httpx 流(如果迭代器是 httpx 的 aiter_bytes) + # httpx 的 Response.aiter_bytes() 返回的生成器可能有 _stream 属性 + stream = getattr(aiter, "_stream", None) + if stream is not None and hasattr(stream, "aclose"): + try: + await stream.aclose() + except Exception as e: + logger.debug(f" [{request_id}] 关闭流对象失败: {e}")