mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
feat: add TTFB timeout detection and improve stream handling
- Add stream first byte timeout (TTFB) detection to trigger failover when provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT) - Add rate limit fail-open/fail-close strategy configuration - Improve exception handling in stream prefetch with proper error classification - Refactor UsageService with shared _prepare_usage_record method - Add batch deletion for old usage records to avoid long transaction locks - Update CLI adapters to use proper User-Agent headers for each CLI client - Add composite indexes migration for usage table query optimization - Fix streaming status display in frontend to show TTFB during streaming - Remove sensitive JWT secret logging in auth service
This commit is contained in:
@@ -0,0 +1,65 @@
|
|||||||
|
"""add usage table composite indexes for query optimization
|
||||||
|
|
||||||
|
Revision ID: b2c3d4e5f6g7
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2025-12-20 15:00:00.000000+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'b2c3d4e5f6g7'
|
||||||
|
down_revision = 'a1b2c3d4e5f6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def index_exists(table_name: str, index_name: str) -> bool:
|
||||||
|
"""检查索引是否存在"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
indexes = [idx['name'] for idx in inspector.get_indexes(table_name)]
|
||||||
|
return index_name in indexes
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""为 usage 表添加复合索引以优化常见查询"""
|
||||||
|
# 1. user_id + created_at 复合索引 (用户用量查询)
|
||||||
|
if not index_exists('usage', 'idx_usage_user_created'):
|
||||||
|
op.create_index(
|
||||||
|
'idx_usage_user_created',
|
||||||
|
'usage',
|
||||||
|
['user_id', 'created_at'],
|
||||||
|
postgresql_concurrently=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. api_key_id + created_at 复合索引 (API Key 用量查询)
|
||||||
|
if not index_exists('usage', 'idx_usage_apikey_created'):
|
||||||
|
op.create_index(
|
||||||
|
'idx_usage_apikey_created',
|
||||||
|
'usage',
|
||||||
|
['api_key_id', 'created_at'],
|
||||||
|
postgresql_concurrently=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. provider + model + created_at 复合索引 (模型统计查询)
|
||||||
|
if not index_exists('usage', 'idx_usage_provider_model_created'):
|
||||||
|
op.create_index(
|
||||||
|
'idx_usage_provider_model_created',
|
||||||
|
'usage',
|
||||||
|
['provider', 'model', 'created_at'],
|
||||||
|
postgresql_concurrently=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""删除复合索引"""
|
||||||
|
if index_exists('usage', 'idx_usage_provider_model_created'):
|
||||||
|
op.drop_index('idx_usage_provider_model_created', table_name='usage')
|
||||||
|
|
||||||
|
if index_exists('usage', 'idx_usage_apikey_created'):
|
||||||
|
op.drop_index('idx_usage_apikey_created', table_name='usage')
|
||||||
|
|
||||||
|
if index_exists('usage', 'idx_usage_user_created'):
|
||||||
|
op.drop_index('idx_usage_user_created', table_name='usage')
|
||||||
@@ -366,14 +366,34 @@
|
|||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell class="text-right py-4 w-[70px]">
|
<TableCell class="text-right py-4 w-[70px]">
|
||||||
|
<!-- pending 状态:只显示增长的总时间 -->
|
||||||
<div
|
<div
|
||||||
v-if="record.status === 'pending' || record.status === 'streaming'"
|
v-if="record.status === 'pending'"
|
||||||
class="flex flex-col items-end text-xs gap-0.5"
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
>
|
>
|
||||||
|
<span class="text-muted-foreground">-</span>
|
||||||
<span class="text-primary tabular-nums">
|
<span class="text-primary tabular-nums">
|
||||||
{{ getElapsedTime(record) }}
|
{{ getElapsedTime(record) }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- streaming 状态:首字固定 + 总时间增长 -->
|
||||||
|
<div
|
||||||
|
v-else-if="record.status === 'streaming'"
|
||||||
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
v-if="record.first_byte_time_ms != null"
|
||||||
|
class="tabular-nums"
|
||||||
|
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
|
||||||
|
<span
|
||||||
|
v-else
|
||||||
|
class="text-muted-foreground"
|
||||||
|
>-</span>
|
||||||
|
<span class="text-primary tabular-nums">
|
||||||
|
{{ getElapsedTime(record) }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 已完成状态:首字 + 总耗时 -->
|
||||||
<div
|
<div
|
||||||
v-else-if="record.response_time_ms != null"
|
v-else-if="record.response_time_ms != null"
|
||||||
class="flex flex-col items-end text-xs gap-0.5"
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
|
|||||||
@@ -376,6 +376,9 @@ class BaseMessageHandler:
|
|||||||
|
|
||||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||||
|
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: 请求 ID,如果不传则使用 self.request_id
|
request_id: 请求 ID,如果不传则使用 self.request_id
|
||||||
"""
|
"""
|
||||||
@@ -407,6 +410,9 @@ class BaseMessageHandler:
|
|||||||
|
|
||||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||||
|
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -57,8 +57,10 @@ from src.models.database import (
|
|||||||
ProviderEndpoint,
|
ProviderEndpoint,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
from src.config.settings import config
|
||||||
from src.services.provider.transport import build_provider_url
|
from src.services.provider.transport import build_provider_url
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||||
|
|
||||||
|
|
||||||
class CliMessageHandlerBase(BaseMessageHandler):
|
class CliMessageHandlerBase(BaseMessageHandler):
|
||||||
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||||
|
|
||||||
|
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
byte_iterator: 字节流迭代器
|
byte_iterator: 字节流迭代器
|
||||||
provider: Provider 对象
|
provider: Provider 对象
|
||||||
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||||
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||||
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
else:
|
else:
|
||||||
provider_parser = self.parser
|
provider_parser = self.parser
|
||||||
|
|
||||||
async for chunk in byte_iterator:
|
# 使用共享的 TTFB 超时函数读取首字节
|
||||||
|
ttfb_timeout = config.stream_first_byte_timeout
|
||||||
|
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator,
|
||||||
|
timeout=ttfb_timeout,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
)
|
||||||
|
prefetched_chunks.append(first_chunk)
|
||||||
|
buffer += first_chunk
|
||||||
|
|
||||||
|
# 继续读取剩余的预读数据
|
||||||
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
||||||
# 重新抛出嵌套错误
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
# 网络 I/O 异常:记录警告,可能需要重试
|
||||||
|
logger.warning(
|
||||||
|
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
logger.error(
|
||||||
|
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return prefetched_chunks
|
return prefetched_chunks
|
||||||
|
|
||||||
|
|||||||
@@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import (
|
|||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.response_parser import ResponseParser
|
from src.api.handlers.base.response_parser import ResponseParser
|
||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
from src.core.exceptions import EmbeddedErrorException
|
from src.config.settings import config
|
||||||
|
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.database import Provider, ProviderEndpoint
|
from src.models.database import Provider, ProviderEndpoint
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -170,6 +172,8 @@ class StreamProcessor:
|
|||||||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||||
|
|
||||||
|
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
byte_iterator: 字节流迭代器
|
byte_iterator: 字节流迭代器
|
||||||
provider: Provider 对象
|
provider: Provider 对象
|
||||||
@@ -182,6 +186,7 @@ class StreamProcessor:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
parser = self.get_parser_for_provider(ctx)
|
parser = self.get_parser_for_provider(ctx)
|
||||||
@@ -192,7 +197,19 @@ class StreamProcessor:
|
|||||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in byte_iterator:
|
# 使用共享的 TTFB 超时函数读取首字节
|
||||||
|
ttfb_timeout = config.stream_first_byte_timeout
|
||||||
|
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator,
|
||||||
|
timeout=ttfb_timeout,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
)
|
||||||
|
prefetched_chunks.append(first_chunk)
|
||||||
|
buffer += first_chunk
|
||||||
|
|
||||||
|
# 继续读取剩余的预读数据
|
||||||
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
@@ -262,10 +279,21 @@ class StreamProcessor:
|
|||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except (EmbeddedErrorException, ProviderTimeoutException):
|
||||||
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
# 网络 I/O <20><><EFBFBD>常:记录警告,可能需要重试
|
||||||
|
logger.warning(
|
||||||
|
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||||
|
logger.error(
|
||||||
|
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return prefetched_chunks
|
return prefetched_chunks
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_claude}
|
cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await ClaudeChatAdapter.fetch_models(
|
models, error = await ClaudeChatAdapter.fetch_models(
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_gemini}
|
cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await GeminiChatAdapter.fetch_models(
|
models, error = await GeminiChatAdapter.fetch_models(
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_openai}
|
cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await OpenAIChatAdapter.fetch_models(
|
models, error = await OpenAIChatAdapter.fetch_models(
|
||||||
|
|||||||
@@ -56,10 +56,11 @@ class Config:
|
|||||||
|
|
||||||
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
||||||
redis_required_env = os.getenv("REDIS_REQUIRED")
|
redis_required_env = os.getenv("REDIS_REQUIRED")
|
||||||
if redis_required_env is None:
|
if redis_required_env is not None:
|
||||||
self.require_redis = self.environment not in {"development", "test", "testing"}
|
|
||||||
else:
|
|
||||||
self.require_redis = redis_required_env.lower() == "true"
|
self.require_redis = redis_required_env.lower() == "true"
|
||||||
|
else:
|
||||||
|
# 保持向后兼容:开发环境可选,生产环境必需
|
||||||
|
self.require_redis = self.environment not in {"development", "test", "testing"}
|
||||||
|
|
||||||
# CORS配置 - 使用环境变量配置允许的源
|
# CORS配置 - 使用环境变量配置允许的源
|
||||||
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
||||||
@@ -133,6 +134,18 @@ class Config:
|
|||||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||||
|
|
||||||
|
# 限流降级策略配置
|
||||||
|
# RATE_LIMIT_FAIL_OPEN: 当限流服务(Redis)异常时的行为
|
||||||
|
#
|
||||||
|
# True (默认): fail-open - 放行请求(优先可用性)
|
||||||
|
# 风险:Redis 故障期间无法限流,可能被滥用
|
||||||
|
# 适用:API 网关作为关键基础设施,必须保持高可用
|
||||||
|
#
|
||||||
|
# False: fail-close - 拒绝所有请求(优先安全性)
|
||||||
|
# 风险:Redis 故障会导致 API 网关不可用
|
||||||
|
# 适用:有严格速率限制要求的安全敏感场景
|
||||||
|
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
|
||||||
|
|
||||||
# HTTP 请求超时配置(秒)
|
# HTTP 请求超时配置(秒)
|
||||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||||
@@ -141,19 +154,22 @@ class Config:
|
|||||||
# 流式处理配置
|
# 流式处理配置
|
||||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||||
|
# STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移
|
||||||
|
# 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态)
|
||||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||||
|
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||||
|
|
||||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||||
# 可通过环境变量覆盖默认值
|
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||||
self.internal_user_agent_claude = os.getenv(
|
self.internal_user_agent_claude_cli = os.getenv(
|
||||||
"CLAUDE_USER_AGENT", "claude-cli/1.0"
|
"CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
|
||||||
)
|
)
|
||||||
self.internal_user_agent_openai = os.getenv(
|
self.internal_user_agent_openai_cli = os.getenv(
|
||||||
"OPENAI_USER_AGENT", "openai-cli/1.0"
|
"OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
|
||||||
)
|
)
|
||||||
self.internal_user_agent_gemini = os.getenv(
|
self.internal_user_agent_gemini_cli = os.getenv(
|
||||||
"GEMINI_USER_AGENT", "gemini-cli/1.0"
|
"GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证连接池配置
|
# 验证连接池配置
|
||||||
@@ -177,6 +193,39 @@ class Config:
|
|||||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||||
return self.db_pool_size
|
return self.db_pool_size
|
||||||
|
|
||||||
|
def _parse_ttfb_timeout(self) -> float:
|
||||||
|
"""
|
||||||
|
解析 TTFB 超时配置,带错误处理和范围限制
|
||||||
|
|
||||||
|
TTFB (Time To First Byte) 用于检测慢响应的 Provider,超时触发故障转移。
|
||||||
|
此值必须小于 http_write_timeout,避免竞态条件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
超时时间(秒),范围 10-120,默认 30
|
||||||
|
"""
|
||||||
|
default_timeout = 30.0
|
||||||
|
min_timeout = 10.0
|
||||||
|
max_timeout = 120.0 # 必须小于 http_write_timeout (默认 60s) 的 2 倍
|
||||||
|
|
||||||
|
raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout))
|
||||||
|
try:
|
||||||
|
timeout = float(raw_value)
|
||||||
|
except ValueError:
|
||||||
|
# 延迟导入,避免循环依赖(Config 初始化时 logger 可能未就绪)
|
||||||
|
self._ttfb_config_warning = (
|
||||||
|
f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}秒"
|
||||||
|
)
|
||||||
|
return default_timeout
|
||||||
|
|
||||||
|
# 范围限制
|
||||||
|
clamped = max(min_timeout, min(max_timeout, timeout))
|
||||||
|
if clamped != timeout:
|
||||||
|
self._ttfb_config_warning = (
|
||||||
|
f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}],"
|
||||||
|
f"已调整为 {clamped}秒"
|
||||||
|
)
|
||||||
|
return clamped
|
||||||
|
|
||||||
def _validate_pool_config(self) -> None:
|
def _validate_pool_config(self) -> None:
|
||||||
"""验证连接池配置是否安全"""
|
"""验证连接池配置是否安全"""
|
||||||
total_per_worker = self.db_pool_size + self.db_max_overflow
|
total_per_worker = self.db_pool_size + self.db_max_overflow
|
||||||
@@ -224,6 +273,10 @@ class Config:
|
|||||||
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
||||||
logger.warning(self._pool_config_warning)
|
logger.warning(self._pool_config_warning)
|
||||||
|
|
||||||
|
# TTFB 超时配置警告
|
||||||
|
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
|
||||||
|
logger.warning(self._ttfb_config_warning)
|
||||||
|
|
||||||
# 管理员密码检查(必须在环境变量中设置)
|
# 管理员密码检查(必须在环境变量中设置)
|
||||||
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
||||||
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
||||||
|
|||||||
@@ -336,10 +336,44 @@ class PluginMiddleware:
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
|
except ConnectionError as e:
|
||||||
|
# Redis 连接错误:根据配置决定
|
||||||
|
logger.warning(f"Rate limit connection error: {e}")
|
||||||
|
if config.rate_limit_fail_open:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=30,
|
||||||
|
message="Rate limit service unavailable"
|
||||||
|
)
|
||||||
|
except TimeoutError as e:
|
||||||
|
# 超时错误:可能是负载过高,根据配置决定
|
||||||
|
logger.warning(f"Rate limit timeout: {e}")
|
||||||
|
if config.rate_limit_fail_open:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=30,
|
||||||
|
message="Rate limit service timeout"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Rate limit error: {e}")
|
logger.error(f"Rate limit error: {type(e).__name__}: {e}")
|
||||||
# 发生错误时允许请求通过
|
# 其他异常:根据配置决定
|
||||||
return None
|
if config.rate_limit_fail_open:
|
||||||
|
# fail-open: 异常时放行请求(优先可用性)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# fail-close: 异常时拒绝请求(优先安全性)
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=60,
|
||||||
|
message="Rate limit service error"
|
||||||
|
)
|
||||||
|
|
||||||
async def _call_pre_request_plugins(self, request: Request) -> None:
|
async def _call_pre_request_plugins(self, request: Request) -> None:
|
||||||
"""调用请求前的插件(当前保留扩展点)"""
|
"""调用请求前的插件(当前保留扩展点)"""
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
|
|||||||
if config.environment == "production":
|
if config.environment == "production":
|
||||||
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
||||||
config.jwt_secret_key = secrets.token_urlsafe(32)
|
config.jwt_secret_key = secrets.token_urlsafe(32)
|
||||||
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
|
logger.warning("JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发")
|
||||||
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
||||||
|
|
||||||
JWT_SECRET_KEY = config.jwt_secret_key
|
JWT_SECRET_KEY = config.jwt_secret_key
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
|
|||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageRecordParams:
|
||||||
|
"""用量记录参数数据类,用于在内部方法间传递数据"""
|
||||||
|
db: Session
|
||||||
|
user: Optional[User]
|
||||||
|
api_key: Optional[ApiKey]
|
||||||
|
provider: str
|
||||||
|
model: str
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
cache_creation_input_tokens: int
|
||||||
|
cache_read_input_tokens: int
|
||||||
|
request_type: str
|
||||||
|
api_format: Optional[str]
|
||||||
|
is_stream: bool
|
||||||
|
response_time_ms: Optional[int]
|
||||||
|
first_byte_time_ms: Optional[int]
|
||||||
|
status_code: int
|
||||||
|
error_message: Optional[str]
|
||||||
|
metadata: Optional[Dict[str, Any]]
|
||||||
|
request_headers: Optional[Dict[str, Any]]
|
||||||
|
request_body: Optional[Any]
|
||||||
|
provider_request_headers: Optional[Dict[str, Any]]
|
||||||
|
response_headers: Optional[Dict[str, Any]]
|
||||||
|
response_body: Optional[Any]
|
||||||
|
request_id: str
|
||||||
|
provider_id: Optional[str]
|
||||||
|
provider_endpoint_id: Optional[str]
|
||||||
|
provider_api_key_id: Optional[str]
|
||||||
|
status: str
|
||||||
|
cache_ttl_minutes: Optional[int]
|
||||||
|
use_tiered_pricing: bool
|
||||||
|
target_model: Optional[str]
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""验证关键字段,确保数据完整性"""
|
||||||
|
# Token 数量不能为负数
|
||||||
|
if self.input_tokens < 0:
|
||||||
|
raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}")
|
||||||
|
if self.output_tokens < 0:
|
||||||
|
raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}")
|
||||||
|
if self.cache_creation_input_tokens < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}"
|
||||||
|
)
|
||||||
|
if self.cache_read_input_tokens < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 响应时间不能为负数
|
||||||
|
if self.response_time_ms is not None and self.response_time_ms < 0:
|
||||||
|
raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}")
|
||||||
|
if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0:
|
||||||
|
raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}")
|
||||||
|
|
||||||
|
# HTTP 状态码范围校验
|
||||||
|
if not (100 <= self.status_code <= 599):
|
||||||
|
raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
|
||||||
|
|
||||||
|
# 状态值校验
|
||||||
|
valid_statuses = {"pending", "streaming", "completed", "failed"}
|
||||||
|
if self.status not in valid_statuses:
|
||||||
|
raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
|
||||||
|
|
||||||
|
|
||||||
class UsageService:
|
class UsageService:
|
||||||
"""用量统计服务"""
|
"""用量统计服务"""
|
||||||
@@ -471,6 +537,97 @@ class UsageService:
|
|||||||
cache_ttl_minutes=cache_ttl_minutes,
|
cache_ttl_minutes=cache_ttl_minutes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _prepare_usage_record(
|
||||||
|
cls,
|
||||||
|
params: UsageRecordParams,
|
||||||
|
) -> Tuple[Dict[str, Any], float]:
|
||||||
|
"""准备用量记录的共享逻辑
|
||||||
|
|
||||||
|
此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑:
|
||||||
|
- 获取费率倍数
|
||||||
|
- 计算成本
|
||||||
|
- 构建 Usage 参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 用量记录参数数据类
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(usage_params 字典, total_cost 总成本)
|
||||||
|
"""
|
||||||
|
# 获取费率倍数和是否免费套餐
|
||||||
|
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
||||||
|
params.db, params.provider_api_key_id, params.provider_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算成本
|
||||||
|
is_failed_request = params.status_code >= 400 or params.error_message is not None
|
||||||
|
(
|
||||||
|
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
||||||
|
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
||||||
|
request_cost, total_cost, _tier_index
|
||||||
|
) = await cls._calculate_costs(
|
||||||
|
db=params.db,
|
||||||
|
provider=params.provider,
|
||||||
|
model=params.model,
|
||||||
|
input_tokens=params.input_tokens,
|
||||||
|
output_tokens=params.output_tokens,
|
||||||
|
cache_creation_input_tokens=params.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens=params.cache_read_input_tokens,
|
||||||
|
api_format=params.api_format,
|
||||||
|
cache_ttl_minutes=params.cache_ttl_minutes,
|
||||||
|
use_tiered_pricing=params.use_tiered_pricing,
|
||||||
|
is_failed_request=is_failed_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建 Usage 参数
|
||||||
|
usage_params = cls._build_usage_params(
|
||||||
|
db=params.db,
|
||||||
|
user=params.user,
|
||||||
|
api_key=params.api_key,
|
||||||
|
provider=params.provider,
|
||||||
|
model=params.model,
|
||||||
|
input_tokens=params.input_tokens,
|
||||||
|
output_tokens=params.output_tokens,
|
||||||
|
cache_creation_input_tokens=params.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens=params.cache_read_input_tokens,
|
||||||
|
request_type=params.request_type,
|
||||||
|
api_format=params.api_format,
|
||||||
|
is_stream=params.is_stream,
|
||||||
|
response_time_ms=params.response_time_ms,
|
||||||
|
first_byte_time_ms=params.first_byte_time_ms,
|
||||||
|
status_code=params.status_code,
|
||||||
|
error_message=params.error_message,
|
||||||
|
metadata=params.metadata,
|
||||||
|
request_headers=params.request_headers,
|
||||||
|
request_body=params.request_body,
|
||||||
|
provider_request_headers=params.provider_request_headers,
|
||||||
|
response_headers=params.response_headers,
|
||||||
|
response_body=params.response_body,
|
||||||
|
request_id=params.request_id,
|
||||||
|
provider_id=params.provider_id,
|
||||||
|
provider_endpoint_id=params.provider_endpoint_id,
|
||||||
|
provider_api_key_id=params.provider_api_key_id,
|
||||||
|
status=params.status,
|
||||||
|
target_model=params.target_model,
|
||||||
|
input_cost=input_cost,
|
||||||
|
output_cost=output_cost,
|
||||||
|
cache_creation_cost=cache_creation_cost,
|
||||||
|
cache_read_cost=cache_read_cost,
|
||||||
|
cache_cost=cache_cost,
|
||||||
|
request_cost=request_cost,
|
||||||
|
total_cost=total_cost,
|
||||||
|
input_price=input_price,
|
||||||
|
output_price=output_price,
|
||||||
|
cache_creation_price=cache_creation_price,
|
||||||
|
cache_read_price=cache_read_price,
|
||||||
|
request_price=request_price,
|
||||||
|
actual_rate_multiplier=actual_rate_multiplier,
|
||||||
|
is_free_tier=is_free_tier,
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage_params, total_cost
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def record_usage_async(
|
async def record_usage_async(
|
||||||
cls,
|
cls,
|
||||||
@@ -516,76 +673,25 @@ class UsageService:
|
|||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())[:8]
|
request_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# 获取费率倍数和是否免费套餐
|
# 使用共享逻辑准备记录参数
|
||||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
params = UsageRecordParams(
|
||||||
db, provider_api_key_id, provider_id
|
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||||
)
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
|
|
||||||
# 计算成本
|
|
||||||
is_failed_request = status_code >= 400 or error_message is not None
|
|
||||||
(
|
|
||||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
|
||||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
|
||||||
request_cost, total_cost, tier_index
|
|
||||||
) = await cls._calculate_costs(
|
|
||||||
db=db,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
api_format=api_format,
|
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||||
cache_ttl_minutes=cache_ttl_minutes,
|
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||||
use_tiered_pricing=use_tiered_pricing,
|
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||||
is_failed_request=is_failed_request,
|
request_headers=request_headers, request_body=request_body,
|
||||||
)
|
|
||||||
|
|
||||||
# 构建 Usage 参数
|
|
||||||
usage_params = cls._build_usage_params(
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
|
||||||
request_type=request_type,
|
|
||||||
api_format=api_format,
|
|
||||||
is_stream=is_stream,
|
|
||||||
response_time_ms=response_time_ms,
|
|
||||||
first_byte_time_ms=first_byte_time_ms,
|
|
||||||
status_code=status_code,
|
|
||||||
error_message=error_message,
|
|
||||||
metadata=metadata,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_body=request_body,
|
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers, response_body=response_body,
|
||||||
response_body=response_body,
|
request_id=request_id, provider_id=provider_id,
|
||||||
request_id=request_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_endpoint_id=provider_endpoint_id,
|
provider_endpoint_id=provider_endpoint_id,
|
||||||
provider_api_key_id=provider_api_key_id,
|
provider_api_key_id=provider_api_key_id, status=status,
|
||||||
status=status,
|
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
input_cost=input_cost,
|
|
||||||
output_cost=output_cost,
|
|
||||||
cache_creation_cost=cache_creation_cost,
|
|
||||||
cache_read_cost=cache_read_cost,
|
|
||||||
cache_cost=cache_cost,
|
|
||||||
request_cost=request_cost,
|
|
||||||
total_cost=total_cost,
|
|
||||||
input_price=input_price,
|
|
||||||
output_price=output_price,
|
|
||||||
cache_creation_price=cache_creation_price,
|
|
||||||
cache_read_price=cache_read_price,
|
|
||||||
request_price=request_price,
|
|
||||||
actual_rate_multiplier=actual_rate_multiplier,
|
|
||||||
is_free_tier=is_free_tier,
|
|
||||||
)
|
)
|
||||||
|
usage_params, _ = await cls._prepare_usage_record(params)
|
||||||
|
|
||||||
# 创建 Usage 记录
|
# 创建 Usage 记录
|
||||||
usage = Usage(**usage_params)
|
usage = Usage(**usage_params)
|
||||||
@@ -660,76 +766,25 @@ class UsageService:
|
|||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())[:8]
|
request_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# 获取费率倍数和是否免费套餐
|
# 使用共享逻辑准备记录参数
|
||||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
params = UsageRecordParams(
|
||||||
db, provider_api_key_id, provider_id
|
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||||
)
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
|
|
||||||
# 计算成本
|
|
||||||
is_failed_request = status_code >= 400 or error_message is not None
|
|
||||||
(
|
|
||||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
|
||||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
|
||||||
request_cost, total_cost, _tier_index
|
|
||||||
) = await cls._calculate_costs(
|
|
||||||
db=db,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
api_format=api_format,
|
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||||
cache_ttl_minutes=cache_ttl_minutes,
|
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||||
use_tiered_pricing=use_tiered_pricing,
|
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||||
is_failed_request=is_failed_request,
|
request_headers=request_headers, request_body=request_body,
|
||||||
)
|
|
||||||
|
|
||||||
# 构建 Usage 参数
|
|
||||||
usage_params = cls._build_usage_params(
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
|
||||||
request_type=request_type,
|
|
||||||
api_format=api_format,
|
|
||||||
is_stream=is_stream,
|
|
||||||
response_time_ms=response_time_ms,
|
|
||||||
first_byte_time_ms=first_byte_time_ms,
|
|
||||||
status_code=status_code,
|
|
||||||
error_message=error_message,
|
|
||||||
metadata=metadata,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_body=request_body,
|
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers, response_body=response_body,
|
||||||
response_body=response_body,
|
request_id=request_id, provider_id=provider_id,
|
||||||
request_id=request_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_endpoint_id=provider_endpoint_id,
|
provider_endpoint_id=provider_endpoint_id,
|
||||||
provider_api_key_id=provider_api_key_id,
|
provider_api_key_id=provider_api_key_id, status=status,
|
||||||
status=status,
|
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
input_cost=input_cost,
|
|
||||||
output_cost=output_cost,
|
|
||||||
cache_creation_cost=cache_creation_cost,
|
|
||||||
cache_read_cost=cache_read_cost,
|
|
||||||
cache_cost=cache_cost,
|
|
||||||
request_cost=request_cost,
|
|
||||||
total_cost=total_cost,
|
|
||||||
input_price=input_price,
|
|
||||||
output_price=output_price,
|
|
||||||
cache_creation_price=cache_creation_price,
|
|
||||||
cache_read_price=cache_read_price,
|
|
||||||
request_price=request_price,
|
|
||||||
actual_rate_multiplier=actual_rate_multiplier,
|
|
||||||
is_free_tier=is_free_tier,
|
|
||||||
)
|
)
|
||||||
|
usage_params, total_cost = await cls._prepare_usage_record(params)
|
||||||
|
|
||||||
# 检查是否已存在相同 request_id 的记录
|
# 检查是否已存在相同 request_id 的记录
|
||||||
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
|
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
|
||||||
@@ -751,7 +806,7 @@ class UsageService:
|
|||||||
api_key = db.merge(api_key)
|
api_key = db.merge(api_key)
|
||||||
|
|
||||||
# 使用原子更新避免并发竞态条件
|
# 使用原子更新避免并发竞态条件
|
||||||
from sqlalchemy import func, update
|
from sqlalchemy import func as sql_func, update
|
||||||
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
|
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
|
||||||
|
|
||||||
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
|
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
|
||||||
@@ -762,7 +817,7 @@ class UsageService:
|
|||||||
.values(
|
.values(
|
||||||
used_usd=UserModel.used_usd + total_cost,
|
used_usd=UserModel.used_usd + total_cost,
|
||||||
total_usd=UserModel.total_usd + total_cost,
|
total_usd=UserModel.total_usd + total_cost,
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -776,8 +831,8 @@ class UsageService:
|
|||||||
total_requests=ApiKeyModel.total_requests + 1,
|
total_requests=ApiKeyModel.total_requests + 1,
|
||||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||||
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
|
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
|
||||||
last_used_at=func.now(),
|
last_used_at=sql_func.now(),
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -787,8 +842,8 @@ class UsageService:
|
|||||||
.values(
|
.values(
|
||||||
total_requests=ApiKeyModel.total_requests + 1,
|
total_requests=ApiKeyModel.total_requests + 1,
|
||||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||||
last_used_at=func.now(),
|
last_used_at=sql_func.now(),
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1121,19 +1176,48 @@ class UsageService:
|
|||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
|
def cleanup_old_usage_records(
|
||||||
"""清理旧的使用记录"""
|
db: Session, days_to_keep: int = 90, batch_size: int = 1000
|
||||||
|
) -> int:
|
||||||
|
"""清理旧的使用记录(分批删除避免长事务锁定)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
days_to_keep: 保留天数,默认 90 天
|
||||||
|
batch_size: 每批删除数量,默认 1000 条
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的总记录数
|
||||||
|
"""
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||||
|
total_deleted = 0
|
||||||
|
|
||||||
# 删除旧记录
|
while True:
|
||||||
deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
|
# 查询待删除的 ID(使用新索引 idx_usage_user_created)
|
||||||
|
batch_ids = (
|
||||||
|
db.query(Usage.id)
|
||||||
|
.filter(Usage.created_at < cutoff_date)
|
||||||
|
.limit(batch_size)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
db.commit()
|
if not batch_ids:
|
||||||
|
break
|
||||||
|
|
||||||
logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
|
# 批量删除
|
||||||
|
deleted_count = (
|
||||||
|
db.query(Usage)
|
||||||
|
.filter(Usage.id.in_([row.id for row in batch_ids]))
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
total_deleted += deleted_count
|
||||||
|
|
||||||
return deleted
|
logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")
|
||||||
|
|
||||||
|
logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")
|
||||||
|
|
||||||
|
return total_deleted
|
||||||
|
|
||||||
# ========== 请求状态追踪方法 ==========
|
# ========== 请求状态追踪方法 ==========
|
||||||
|
|
||||||
@@ -1219,6 +1303,7 @@ class UsageService:
|
|||||||
error_message: Optional[str] = None,
|
error_message: Optional[str] = None,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
target_model: Optional[str] = None,
|
target_model: Optional[str] = None,
|
||||||
|
first_byte_time_ms: Optional[int] = None,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
"""
|
"""
|
||||||
快速更新使用记录状态
|
快速更新使用记录状态
|
||||||
@@ -1230,6 +1315,7 @@ class UsageService:
|
|||||||
error_message: 错误消息(仅在 failed 状态时使用)
|
error_message: 错误消息(仅在 failed 状态时使用)
|
||||||
provider: 提供商名称(可选,streaming 状态时更新)
|
provider: 提供商名称(可选,streaming 状态时更新)
|
||||||
target_model: 映射后的目标模型名(可选)
|
target_model: 映射后的目标模型名(可选)
|
||||||
|
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的 Usage 记录,如果未找到则返回 None
|
更新后的 Usage 记录,如果未找到则返回 None
|
||||||
@@ -1247,6 +1333,8 @@ class UsageService:
|
|||||||
usage.provider = provider
|
usage.provider = provider
|
||||||
if target_model:
|
if target_model:
|
||||||
usage.target_model = target_model
|
usage.target_model = target_model
|
||||||
|
if first_byte_time_ms is not None:
|
||||||
|
usage.first_byte_time_ms = first_byte_time_ms
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -457,26 +458,32 @@ class StreamUsageTracker:
|
|||||||
|
|
||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
# 更新状态为 streaming,同时更新 provider
|
|
||||||
if self.request_id:
|
|
||||||
try:
|
|
||||||
from src.services.usage.service import UsageService
|
|
||||||
UsageService.update_usage_status(
|
|
||||||
db=self.db,
|
|
||||||
request_id=self.request_id,
|
|
||||||
status="streaming",
|
|
||||||
provider=self.provider,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
first_chunk_received = False
|
||||||
try:
|
try:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
# 保存原始字节流(用于错误诊断)
|
# 保存原始字节流(用于错误诊断)
|
||||||
self.raw_chunks.append(chunk)
|
self.raw_chunks.append(chunk)
|
||||||
|
|
||||||
|
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
|
||||||
|
if not first_chunk_received:
|
||||||
|
first_chunk_received = True
|
||||||
|
if self.request_id:
|
||||||
|
try:
|
||||||
|
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||||
|
base_time = self.request_start_time or self.start_time
|
||||||
|
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=self.db,
|
||||||
|
request_id=self.request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
|
first_byte_time_ms=first_byte_time_ms,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
# 返回原始块给客户端
|
# 返回原始块给客户端
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|||||||
@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
|
|||||||
# Python 3.10 及以下版本的兼容实现
|
# Python 3.10 及以下版本的兼容实现
|
||||||
# 注意:这个简单实现不支持嵌套取消
|
# 注意:这个简单实现不支持嵌套取消
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator: Any,
|
||||||
|
timeout: float,
|
||||||
|
request_id: str,
|
||||||
|
provider_name: str,
|
||||||
|
) -> tuple[bytes, Any]:
|
||||||
|
"""
|
||||||
|
读取流的首字节并应用 TTFB 超时检测
|
||||||
|
|
||||||
|
首字节超时(Time To First Byte)用于检测慢响应的 Provider,
|
||||||
|
超时时触发故障转移到其他可用的 Provider。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
byte_iterator: 异步字节流迭代器
|
||||||
|
timeout: TTFB 超时时间(秒)
|
||||||
|
request_id: 请求 ID(用于日志)
|
||||||
|
provider_name: Provider 名称(用于日志和异常)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(first_chunk, aiter): 首个字节块和异步迭代器
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProviderTimeoutException: 如果首字节超时
|
||||||
|
"""
|
||||||
|
from src.core.exceptions import ProviderTimeoutException
|
||||||
|
|
||||||
|
aiter = byte_iterator.__aiter__()
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
|
||||||
|
return first_chunk, aiter
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 完整的资源清理:先关闭迭代器,再关闭底层响应
|
||||||
|
await _cleanup_iterator_resources(aiter, request_id)
|
||||||
|
logger.warning(
|
||||||
|
f" [{request_id}] 流首字节超时 (TTFB): "
|
||||||
|
f"Provider={provider_name}, timeout={timeout}s"
|
||||||
|
)
|
||||||
|
raise ProviderTimeoutException(
|
||||||
|
provider_name=provider_name,
|
||||||
|
timeout=int(timeout),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
|
||||||
|
"""
|
||||||
|
清理异步迭代器及其底层资源
|
||||||
|
|
||||||
|
确保在 TTFB 超时后正确释放 HTTP 连接,避免连接泄漏。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aiter: 异步迭代器
|
||||||
|
request_id: 请求 ID(用于日志)
|
||||||
|
"""
|
||||||
|
# 1. 关闭迭代器本身
|
||||||
|
if hasattr(aiter, "aclose"):
|
||||||
|
try:
|
||||||
|
await aiter.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭迭代器失败: {e}")
|
||||||
|
|
||||||
|
# 2. 关闭底层响应对象(httpx.Response)
|
||||||
|
# 迭代器可能持有 _response 属性指向底层响应
|
||||||
|
response = getattr(aiter, "_response", None)
|
||||||
|
if response is not None and hasattr(response, "aclose"):
|
||||||
|
try:
|
||||||
|
await response.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭底层响应失败: {e}")
|
||||||
|
|
||||||
|
# 3. 尝试关闭 httpx 流(如果迭代器是 httpx 的 aiter_bytes)
|
||||||
|
# httpx 的 Response.aiter_bytes() 返回的生成器可能有 _stream 属性
|
||||||
|
stream = getattr(aiter, "_stream", None)
|
||||||
|
if stream is not None and hasattr(stream, "aclose"):
|
||||||
|
try:
|
||||||
|
await stream.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭流对象失败: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user