7 Commits

Author SHA1 Message Date
fawney19
d44cfaddf6 fix: rename variable to avoid shadowing in model mapping cache stats
A variable inside the loop shared the name provider_model_mappings with the outer list, so the outer list was overwritten with None, raising an AttributeError
2025-12-23 00:38:37 +08:00
fawney19
65225710a8 refactor: use ConcurrencyDefaults for CACHE_RESERVATION_RATIO constant 2025-12-23 00:34:18 +08:00
fawney19
d7f5b16359 fix: rebuild app image when migration files change
deploy.sh was only running alembic upgrade on the old container when
migration files changed, but the migration files are baked into the
Docker image. Now it rebuilds the app image when migrations change.
2025-12-23 00:23:22 +08:00
fawney19
7185818724 fix: remove index_exists check to avoid transaction conflict in migration
- Remove index_exists function that used op.get_bind() within transaction
- Use IF NOT EXISTS / IF EXISTS SQL syntax instead
- Fixes CREATE INDEX CONCURRENTLY error in Docker migration
2025-12-23 00:21:03 +08:00
fawney19
868f3349e5 fix: use AUTOCOMMIT mode for CREATE INDEX CONCURRENTLY in migration
PostgreSQL does not allow CREATE INDEX CONCURRENTLY inside a transaction block;
this is worked around by opening a separate connection with the AUTOCOMMIT isolation level.
2025-12-23 00:18:11 +08:00
fawney19
d7384e69d9 fix: improve code quality and add type safety for Key updates
- Replace f-string logging with lazy formatting in keys.py (lines 256, 265)
- Add EndpointAPIKeyUpdate type interface for frontend type safety
- Use typed EndpointAPIKeyUpdate instead of any in KeyFormDialog.vue
2025-12-23 00:11:10 +08:00
fawney19
1d5c378343 feat: add TTFB timeout detection and improve stream handling
- Add stream first byte timeout (TTFB) detection to trigger failover
  when provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT)
- Add rate limit fail-open/fail-close strategy configuration
- Improve exception handling in stream prefetch with proper error classification
- Refactor UsageService with shared _prepare_usage_record method
- Add batch deletion for old usage records to avoid long transaction locks
- Update CLI adapters to use proper User-Agent headers for each CLI client
- Add composite indexes migration for usage table query optimization
- Fix streaming status display in frontend to show TTFB during streaming
- Remove sensitive JWT secret logging in auth service
2025-12-22 23:44:42 +08:00
22 changed files with 659 additions and 204 deletions

View File

@@ -0,0 +1,63 @@
"""add usage table composite indexes for query optimization
Revision ID: b2c3d4e5f6g7
Revises: a1b2c3d4e5f6
Create Date: 2025-12-20 15:00:00.000000+00:00
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = 'b2c3d4e5f6g7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""为 usage 表添加复合索引以优化常见查询
使用 CONCURRENTLY 创建索引以避免锁表,
但需要在 AUTOCOMMIT 模式下执行(不能在事务内)
"""
conn = op.get_bind()
engine = conn.engine
# Open a new connection in AUTOCOMMIT mode to support CREATE INDEX CONCURRENTLY
with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as autocommit_conn:
# IF NOT EXISTS avoids duplicate creation, so no separate index-existence check is needed
# 1. Composite index on user_id + created_at (per-user usage queries)
autocommit_conn.execute(text(
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_user_created "
"ON usage (user_id, created_at)"
))
# 2. Composite index on api_key_id + created_at (per-API-key usage queries)
autocommit_conn.execute(text(
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_apikey_created "
"ON usage (api_key_id, created_at)"
))
# 3. Composite index on provider + model + created_at (model statistics queries)
autocommit_conn.execute(text(
"CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_provider_model_created "
"ON usage (provider, model, created_at)"
))
def downgrade() -> None:
"""删除复合索引"""
conn = op.get_bind()
# IF EXISTS avoids an error when an index does not exist
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_provider_model_created"
))
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_apikey_created"
))
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_user_created"
))
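For context, a rough sketch of the query shapes these three composite indexes are intended to serve (illustrative only; the usage table and its columns come from the migration above, everything else — DSN, bind values — is assumed):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/aether")  # hypothetical DSN

with engine.connect() as conn:
    # Served by idx_usage_user_created (user_id, created_at): per-user usage over a window
    conn.execute(text(
        "SELECT count(*) FROM usage "
        "WHERE user_id = :uid AND created_at >= now() - interval '7 days'"
    ), {"uid": "some-user-id"})
    # Served by idx_usage_apikey_created (api_key_id, created_at): per-API-key usage
    conn.execute(text(
        "SELECT count(*) FROM usage "
        "WHERE api_key_id = :kid AND created_at >= now() - interval '1 day'"
    ), {"kid": "some-key-id"})
    # Served by idx_usage_provider_model_created (provider, model, created_at): model stats
    conn.execute(text(
        "SELECT count(*) FROM usage "
        "WHERE provider = :p AND model = :m AND created_at >= :since"
    ), {"p": "some-provider", "m": "some-model", "since": "2025-12-01"})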

View File

@@ -179,7 +179,13 @@ else
echo ">>> Dependencies unchanged."
fi
# Check whether the code changed, or base was rebuilt (app depends on base)
# Check whether the code or migrations changed, or base was rebuilt (app depends on base)
# Note: migration files are baked into the image, so migration changes also require rebuilding the app image
MIGRATION_CHANGED=false
if check_migration_changed; then
MIGRATION_CHANGED=true
fi
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
echo ">>> App image not found, building..."
build_app
@@ -192,6 +198,10 @@ elif check_code_changed; then
echo ">>> Code changed, rebuilding app image..."
build_app
NEED_RESTART=true
elif [ "$MIGRATION_CHANGED" = true ]; then
echo ">>> Migration files changed, rebuilding app image..."
build_app
NEED_RESTART=true
else
echo ">>> Code unchanged."
fi
@@ -204,9 +214,9 @@ else
echo ">>> No changes detected, skipping restart."
fi
# Check for migration changes
if check_migration_changed; then
echo ">>> Migration files changed, running database migration..."
# Check for migration changes (if a change was already detected and the image rebuilt, just run the migration here)
if [ "$MIGRATION_CHANGED" = true ]; then
echo ">>> Running database migration..."
sleep 3
run_migration
else

View File

@@ -110,6 +110,24 @@ export interface EndpointAPIKey {
request_results_window?: Array<{ ts: number; ok: boolean }> // sliding window of request results
}
export interface EndpointAPIKeyUpdate {
name?: string
api_key?: string // provide only when updating
rate_multiplier?: number
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null // null switches to adaptive mode
rate_limit?: number
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null
capabilities?: Record<string, boolean> | null
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
note?: string
is_active?: boolean
}
export interface EndpointHealthDetail {
api_format: string
health_score: number

View File

@@ -260,6 +260,7 @@ import {
updateEndpointKey,
getAllCapabilities,
type EndpointAPIKey,
type EndpointAPIKeyUpdate,
type ProviderEndpoint,
type CapabilityDefinition
} from '@/api/endpoints'
@@ -386,10 +387,11 @@ function loadKeyData() {
api_key: '',
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
internal_priority: props.editingKey.internal_priority ?? 50,
max_concurrent: props.editingKey.max_concurrent || undefined,
rate_limit: props.editingKey.rate_limit || undefined,
daily_limit: props.editingKey.daily_limit || undefined,
monthly_limit: props.editingKey.monthly_limit || undefined,
// preserve the original null/undefined state (null means adaptive mode)
max_concurrent: props.editingKey.max_concurrent ?? undefined,
rate_limit: props.editingKey.rate_limit ?? undefined,
daily_limit: props.editingKey.daily_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit ?? undefined,
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
note: props.editingKey.note || '',
@@ -439,12 +441,17 @@ async function handleSave() {
saving.value = true
try {
if (props.editingKey) {
// update
const updateData: any = {
// update mode
// note: max_concurrent must explicitly send null to switch to adaptive mode
// undefined is dropped from JSON, so null is used to mean "clear / adaptive"
const updateData: EndpointAPIKeyUpdate = {
name: form.value.name,
rate_multiplier: form.value.rate_multiplier,
internal_priority: form.value.internal_priority,
max_concurrent: form.value.max_concurrent,
// explicitly use null for adaptive mode so the backend can distinguish "not provided" from "set to null"
// note: only max_concurrent needs this treatment, since it is the only field with an "adaptive mode" concept
// other limit fields (rate_limit etc.) do not support "clearing"; undefined is dropped from JSON, i.e. not updated
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,

View File

@@ -483,9 +483,9 @@
<span
v-if="key.max_concurrent || key.is_adaptive"
class="text-muted-foreground"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'}` : '固定并发限制'"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'}` : `固定并发限制: ${key.max_concurrent}`"
>
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.learned_max_concurrent || key.max_concurrent || 3 }}
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.is_adaptive ? (key.learned_max_concurrent ?? '学习中') : key.max_concurrent }}
</span>
</div>
</div>

View File

@@ -366,14 +366,34 @@
</div>
</TableCell>
<TableCell class="text-right py-4 w-[70px]">
<!-- pending: only show the growing total time -->
<div
v-if="record.status === 'pending' || record.status === 'streaming'"
v-if="record.status === 'pending'"
class="flex flex-col items-end text-xs gap-0.5"
>
<span class="text-muted-foreground">-</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<!-- streaming: first byte fixed + total time growing -->
<div
v-else-if="record.status === 'streaming'"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<!-- completed: first byte + total elapsed time -->
<div
v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5"

View File

@@ -246,6 +246,15 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
# Special-case max_concurrent: distinguish "not provided" from "explicitly set to null"
# When max_concurrent is explicitly set (present in model_fields_set), update it even if the value is None
if "max_concurrent" in self.key_data.model_fields_set:
update_data["max_concurrent"] = self.key_data.max_concurrent
# When switching to adaptive mode, clear the learned concurrency limit so the system relearns it
if self.key_data.max_concurrent is None:
update_data["learned_max_concurrent"] = None
logger.info("Key %s 切换为自适应并发模式", self.key_id)
for field, value in update_data.items():
setattr(key, field, value)
key.updated_at = datetime.now(timezone.utc)
@@ -253,7 +262,7 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit()
db.refresh(key)
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try:
decrypted_key = crypto_service.decrypt(key.api_key)

View File

@@ -947,7 +947,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.config.constants import ConcurrencyDefaults
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
# Fetch the dynamic reservation manager's configuration
@@ -958,7 +958,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
"status": "ok",
"data": {
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
"cache_reservation_ratio": ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
"dynamic_reservation": {
"enabled": True,
"config": reservation_stats["config"],
@@ -981,7 +981,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
context.add_audit_metadata(
action="cache_config",
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
cache_reservation_ratio=ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
dynamic_reservation_enabled=True,
)
return response
@@ -1236,7 +1236,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
try:
cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name")
provider_model_mappings = cached_data.get("provider_model_mappings", [])
cached_model_mappings = cached_data.get("provider_model_mappings", [])
# Fetch Provider and GlobalModel info
provider = provider_map.get(provider_id)
@@ -1245,8 +1245,8 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if provider and global_model:
# Extract the mapping names
mapping_names = []
if provider_model_mappings:
for mapping_entry in provider_model_mappings:
if cached_model_mappings:
for mapping_entry in cached_model_mappings:
if isinstance(mapping_entry, dict) and mapping_entry.get("name"):
mapping_names.append(mapping_entry["name"])
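For reference, a minimal, self-contained reproduction of the shadowing bug class fixed in d44cfaddf6 (data and surrounding names are illustrative, not the adapter's actual code):

# The outer list is needed after the loop.
provider_model_mappings = [{"name": "gpt-4"}, {"name": "claude-3"}]

cached_entries = [{"provider_model_name": "m1"}]  # entry without a mappings key
for cached_data in cached_entries:
    # BUG: reuses the outer name; .get(...) returns None here and silently
    # replaces the outer list.
    provider_model_mappings = cached_data.get("provider_model_mappings")

# Any later attribute access on the outer name now fails, e.g.:
# provider_model_mappings.append(...)  -> AttributeError: 'NoneType' has no attribute 'append'
# Fix, as in the diff above: bind the cached value to a distinct name:
# cached_model_mappings = cached_data.get("provider_model_mappings", [])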

View File

@@ -376,6 +376,9 @@ class BaseMessageHandler:
Run the database update in an asyncio background task to avoid blocking streaming
Note: TTFB (first byte time) is recorded by StreamContext.record_first_byte_time()
and passed to the database at the final record_success, avoiding the data
inconsistency that duplicate recording would cause.
Args:
request_id: request ID (falls back to self.request_id if not provided)
"""
@@ -407,6 +410,9 @@ class BaseMessageHandler:
Run the database update in an asyncio background task to avoid blocking streaming
Note: TTFB (first byte time) is recorded by StreamContext.record_first_byte_time()
and passed to the database at the final record_success, avoiding the data
inconsistency that duplicate recording would cause.
Args:
ctx: stream context, containing provider_name and mapped_model
"""

View File

@@ -57,8 +57,10 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.config.settings import config
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
class CliMessageHandlerBase(BaseMessageHandler):
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
Also detects HTML responses (usually a misconfigured base_url returning a web page).
The first read applies a TTFB (first byte) timeout; on timeout, failover is triggered.
Args:
byte_iterator: byte-stream iterator
provider: Provider object
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
Raises:
EmbeddedErrorException: if an embedded error is detected
ProviderNotAvailableException: if an HTML response is detected (configuration error)
ProviderTimeoutException: if the first byte times out (TTFB timeout)
"""
prefetched_chunks: list = []
max_prefetch_lines = 5 # prefetch at most 5 lines to detect errors
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
else:
provider_parser = self.parser
async for chunk in byte_iterator:
# Read the first chunk using the shared TTFB timeout helper
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# Continue reading the remaining prefetched data
async for chunk in aiter:
prefetched_chunks.append(chunk)
buffer += chunk
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# Re-raise the embedded error
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
# Re-raise retryable provider exceptions to trigger failover
raise
except (OSError, IOError) as e:
# Network I/O exception: log a warning; a retry may be needed
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e:
# Other exceptions (e.g. network errors) during prefetch: log but do not interrupt
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
# Unexpected severe exception: log and re-raise to avoid masking the problem
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks

View File

@@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.config.settings import config
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
@dataclass
@@ -170,6 +172,8 @@ class StreamProcessor:
Some providers (e.g. Gemini) may return HTTP 200 with an error embedded in the response body.
This must be detected before the stream starts emitting output so retry logic can be triggered.
The first read applies a TTFB (first byte) timeout; on timeout, failover is triggered.
Args:
byte_iterator: byte-stream iterator
provider: Provider object
@@ -182,6 +186,7 @@ class StreamProcessor:
Raises:
EmbeddedErrorException: if an embedded error is detected
ProviderTimeoutException: if the first byte times out (TTFB timeout)
"""
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
@@ -192,7 +197,19 @@ class StreamProcessor:
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
async for chunk in byte_iterator:
# Read the first chunk using the shared TTFB timeout helper
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# Continue reading the remaining prefetched data
async for chunk in aiter:
prefetched_chunks.append(chunk)
buffer += chunk
@@ -262,10 +279,21 @@ class StreamProcessor:
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
except (EmbeddedErrorException, ProviderTimeoutException):
# Re-raise retryable provider exceptions to trigger failover
raise
except (OSError, IOError) as e:
# Network I/O exception: log a warning; a retry may be needed
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
# Unexpected severe exception: log and re-raise to avoid masking the problem
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks

View File

@@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 Claude API 支持的模型列表(带 CLI User-Agent"""
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_claude}
cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await ClaudeChatAdapter.fetch_models(

View File

@@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent"""
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_gemini}
cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await GeminiChatAdapter.fetch_models(

View File

@@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]:
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent"""
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_openai}
cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
if extra_headers:
cli_headers.update(extra_headers)
models, error = await OpenAIChatAdapter.fetch_models(

View File

@@ -77,7 +77,10 @@ class ConcurrencyDefaults:
MAX_CONCURRENT_LIMIT = 200
# Lower bound on the minimum concurrency limit
MIN_CONCURRENT_LIMIT = 1
# Set to 3 rather than 1: the reservation mechanism (10% reserved for cache users)
# would leave new users 0 usable slots when learned_max_concurrent=1, so they could never get through
# Note: when limit < 10 the reservation effectively does not apply (reserved slots = 0), which is acceptable
MIN_CONCURRENT_LIMIT = 3
# === Probing expansion parameters ===
# Probing expansion interval (minutes) - try expanding after a long period with traffic and no 429s

View File

@@ -56,10 +56,11 @@ class Config:
# Redis dependency policy (required by default in production, optional in development; override via REDIS_REQUIRED)
redis_required_env = os.getenv("REDIS_REQUIRED")
if redis_required_env is None:
self.require_redis = self.environment not in {"development", "test", "testing"}
else:
if redis_required_env is not None:
self.require_redis = redis_required_env.lower() == "true"
else:
# Backward compatible: optional in development, required in production
self.require_redis = self.environment not in {"development", "test", "testing"}
# CORS configuration - allowed origins are configured via environment variable
# Format: comma-separated list of origins, e.g. "http://localhost:3000,https://example.com"
@@ -133,6 +134,18 @@ class Config:
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# Rate-limit degradation strategy
# RATE_LIMIT_FAIL_OPEN: behavior when the rate-limit service (Redis) fails
#
# True (default): fail-open - let requests through (availability first)
# Risk: no rate limiting while Redis is down; the service may be abused
# Fits: the API gateway is critical infrastructure and must stay highly available
#
# False: fail-close - reject all requests (safety first)
# Risk: a Redis outage makes the API gateway unavailable
# Fits: security-sensitive scenarios with strict rate-limit requirements
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
# HTTP request timeout configuration (seconds)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
@@ -141,19 +154,22 @@ class Config:
# Streaming configuration
# STREAM_PREFETCH_LINES: number of lines to prefetch for detecting embedded errors
# STREAM_STATS_DELAY: stats recording delay (seconds), waiting for the stream to fully close
# STREAM_FIRST_BYTE_TIMEOUT: first byte timeout (seconds); waiting longer than this triggers failover
# Range: 10-120 seconds, default 30 (must be less than http_write_timeout to avoid races)
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
# User-Agent configuration for internal requests (e.g. querying upstream model lists)
# Defaults can be overridden via environment variables
self.internal_user_agent_claude = os.getenv(
"CLAUDE_USER_AGENT", "claude-cli/1.0"
# Defaults can be overridden via environment variables; mimics the corresponding CLI client
self.internal_user_agent_claude_cli = os.getenv(
"CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
)
self.internal_user_agent_openai = os.getenv(
"OPENAI_USER_AGENT", "openai-cli/1.0"
self.internal_user_agent_openai_cli = os.getenv(
"OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
)
self.internal_user_agent_gemini = os.getenv(
"GEMINI_USER_AGENT", "gemini-cli/1.0"
self.internal_user_agent_gemini_cli = os.getenv(
"GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
)
# Validate the connection pool configuration
@@ -177,6 +193,39 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size
def _parse_ttfb_timeout(self) -> float:
"""
Parse the TTFB timeout configuration, with error handling and range clamping
TTFB (Time To First Byte) detects slow-responding providers; a timeout triggers failover.
The value must be less than http_write_timeout to avoid race conditions.
Returns:
Timeout in seconds, range 10-120, default 30
"""
default_timeout = 30.0
min_timeout = 10.0
max_timeout = 120.0  # capped at 2x http_write_timeout (default 60s)
raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout))
try:
timeout = float(raw_value)
except ValueError:
# Store the warning for later: logger may not be ready during Config initialization
self._ttfb_config_warning = (
f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}"
)
return default_timeout
# Clamp to range
clamped = max(min_timeout, min(max_timeout, timeout))
if clamped != timeout:
self._ttfb_config_warning = (
f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}]"
f"已调整为 {clamped}"
)
return clamped
def _validate_pool_config(self) -> None:
"""验证连接池配置是否安全"""
total_per_worker = self.db_pool_size + self.db_max_overflow
@@ -224,6 +273,10 @@ class Config:
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
logger.warning(self._pool_config_warning)
# TTFB timeout configuration warning
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
logger.warning(self._ttfb_config_warning)
# Admin password check (must be set via environment variables)
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")

View File

@@ -336,10 +336,44 @@ class PluginMiddleware:
)
return result
return None
except ConnectionError as e:
# Redis connection error: decide based on configuration
logger.warning(f"Rate limit connection error: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service unavailable"
)
except TimeoutError as e:
# Timeout error: possibly high load; decide based on configuration
logger.warning(f"Rate limit timeout: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service timeout"
)
except Exception as e:
logger.error(f"Rate limit error: {e}")
# Allow the request through on error
return None
logger.error(f"Rate limit error: {type(e).__name__}: {e}")
# Other exceptions: decide based on configuration
if config.rate_limit_fail_open:
# fail-open: let the request through on errors (availability first)
return None
else:
# fail-close: reject the request on errors (safety first)
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=60,
message="Rate limit service error"
)
async def _call_pre_request_plugins(self, request: Request) -> None:
"""调用请求前的插件(当前保留扩展点)"""

View File

@@ -226,8 +226,11 @@ class EndpointAPIKeyUpdate(BaseModel):
global_priority: Optional[int] = Field(
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
)
# Note: max_concurrent=None means no update; use the dedicated API to switch to adaptive mode
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
# max_concurrent: a special convention distinguishes "not provided" from "set to null (adaptive mode)"
# - field absent: no update
# - null provided: switch to adaptive mode
# - number provided: set a fixed concurrency limit
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数(null=自适应模式)")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")

View File

@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
if config.environment == "production":
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
config.jwt_secret_key = secrets.token_urlsafe(32)
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
logger.warning("JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发")
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
JWT_SECRET_KEY = config.jwt_secret_key

View File

@@ -3,6 +3,7 @@
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
from src.services.system.config import SystemConfigService
@dataclass
class UsageRecordParams:
"""用量记录参数数据类,用于在内部方法间传递数据"""
db: Session
user: Optional[User]
api_key: Optional[ApiKey]
provider: str
model: str
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int
cache_read_input_tokens: int
request_type: str
api_format: Optional[str]
is_stream: bool
response_time_ms: Optional[int]
first_byte_time_ms: Optional[int]
status_code: int
error_message: Optional[str]
metadata: Optional[Dict[str, Any]]
request_headers: Optional[Dict[str, Any]]
request_body: Optional[Any]
provider_request_headers: Optional[Dict[str, Any]]
response_headers: Optional[Dict[str, Any]]
response_body: Optional[Any]
request_id: str
provider_id: Optional[str]
provider_endpoint_id: Optional[str]
provider_api_key_id: Optional[str]
status: str
cache_ttl_minutes: Optional[int]
use_tiered_pricing: bool
target_model: Optional[str]
def __post_init__(self) -> None:
"""验证关键字段,确保数据完整性"""
# Token 数量不能为负数
if self.input_tokens < 0:
raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}")
if self.output_tokens < 0:
raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}")
if self.cache_creation_input_tokens < 0:
raise ValueError(
f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}"
)
if self.cache_read_input_tokens < 0:
raise ValueError(
f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}"
)
# Response times cannot be negative
if self.response_time_ms is not None and self.response_time_ms < 0:
raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}")
if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0:
raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}")
# Validate the HTTP status code range
if not (100 <= self.status_code <= 599):
raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
# Validate the status value
valid_statuses = {"pending", "streaming", "completed", "failed"}
if self.status not in valid_statuses:
raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
class UsageService:
"""用量统计服务"""
@@ -471,6 +537,97 @@ class UsageService:
cache_ttl_minutes=cache_ttl_minutes,
)
@classmethod
async def _prepare_usage_record(
cls,
params: UsageRecordParams,
) -> Tuple[Dict[str, Any], float]:
"""准备用量记录的共享逻辑
此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑:
- 获取费率倍数
- 计算成本
- 构建 Usage 参数
Args:
params: 用量记录参数数据类
Returns:
(usage_params 字典, total_cost 总成本)
"""
# Fetch the rate multiplier and the free-tier flag
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
params.db, params.provider_api_key_id, params.provider_id
)
# Calculate costs
is_failed_request = params.status_code >= 400 or params.error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, _tier_index
) = await cls._calculate_costs(
db=params.db,
provider=params.provider,
model=params.model,
input_tokens=params.input_tokens,
output_tokens=params.output_tokens,
cache_creation_input_tokens=params.cache_creation_input_tokens,
cache_read_input_tokens=params.cache_read_input_tokens,
api_format=params.api_format,
cache_ttl_minutes=params.cache_ttl_minutes,
use_tiered_pricing=params.use_tiered_pricing,
is_failed_request=is_failed_request,
)
# Build the Usage parameters
usage_params = cls._build_usage_params(
db=params.db,
user=params.user,
api_key=params.api_key,
provider=params.provider,
model=params.model,
input_tokens=params.input_tokens,
output_tokens=params.output_tokens,
cache_creation_input_tokens=params.cache_creation_input_tokens,
cache_read_input_tokens=params.cache_read_input_tokens,
request_type=params.request_type,
api_format=params.api_format,
is_stream=params.is_stream,
response_time_ms=params.response_time_ms,
first_byte_time_ms=params.first_byte_time_ms,
status_code=params.status_code,
error_message=params.error_message,
metadata=params.metadata,
request_headers=params.request_headers,
request_body=params.request_body,
provider_request_headers=params.provider_request_headers,
response_headers=params.response_headers,
response_body=params.response_body,
request_id=params.request_id,
provider_id=params.provider_id,
provider_endpoint_id=params.provider_endpoint_id,
provider_api_key_id=params.provider_api_key_id,
status=params.status,
target_model=params.target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
return usage_params, total_cost
@classmethod
async def record_usage_async(
cls,
@@ -516,76 +673,25 @@ class UsageService:
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# Fetch the rate multiplier and the free-tier flag
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# Calculate costs
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
# Prepare the record parameters via the shared logic
params = UsageRecordParams(
db=db, user=user, api_key=api_key, provider=provider, model=model,
input_tokens=input_tokens, output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# Build the Usage parameters
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
request_type=request_type, api_format=api_format, is_stream=is_stream,
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
response_headers=response_headers, response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
provider_api_key_id=provider_api_key_id, status=status,
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
usage_params, _ = await cls._prepare_usage_record(params)
# Create the Usage record
usage = Usage(**usage_params)
@@ -660,76 +766,25 @@ class UsageService:
if request_id is None:
request_id = str(uuid.uuid4())[:8]
# Fetch the rate multiplier and the free-tier flag
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
# Calculate costs
is_failed_request = status_code >= 400 or error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, _tier_index
) = await cls._calculate_costs(
db=db,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
# Prepare the record parameters via the shared logic
params = UsageRecordParams(
db=db, user=user, api_key=api_key, provider=provider, model=model,
input_tokens=input_tokens, output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
use_tiered_pricing=use_tiered_pricing,
is_failed_request=is_failed_request,
)
# Build the Usage parameters
usage_params = cls._build_usage_params(
db=db,
user=user,
api_key=api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
request_type=request_type,
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
first_byte_time_ms=first_byte_time_ms,
status_code=status_code,
error_message=error_message,
metadata=metadata,
request_headers=request_headers,
request_body=request_body,
request_type=request_type, api_format=api_format, is_stream=is_stream,
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers,
response_body=response_body,
request_id=request_id,
provider_id=provider_id,
response_headers=response_headers, response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id,
status=status,
provider_api_key_id=provider_api_key_id, status=status,
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
target_model=target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
usage_params, total_cost = await cls._prepare_usage_record(params)
# Check whether a record with the same request_id already exists
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
@@ -751,7 +806,7 @@ class UsageService:
api_key = db.merge(api_key)
# Use atomic updates to avoid concurrent race conditions
from sqlalchemy import func, update
from sqlalchemy import func as sql_func, update
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
# Update user usage (standalone keys are not counted against the creator's usage)
@@ -762,7 +817,7 @@ class UsageService:
.values(
used_usd=UserModel.used_usd + total_cost,
total_usd=UserModel.total_usd + total_cost,
updated_at=func.now(),
updated_at=sql_func.now(),
)
)
@@ -776,8 +831,8 @@ class UsageService:
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
last_used_at=sql_func.now(),
updated_at=sql_func.now(),
)
)
else:
@@ -787,8 +842,8 @@ class UsageService:
.values(
total_requests=ApiKeyModel.total_requests + 1,
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
last_used_at=func.now(),
updated_at=func.now(),
last_used_at=sql_func.now(),
updated_at=sql_func.now(),
)
)
@@ -1121,19 +1176,48 @@ class UsageService:
]
@staticmethod
def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
"""清理旧的使用记录"""
def cleanup_old_usage_records(
db: Session, days_to_keep: int = 90, batch_size: int = 1000
) -> int:
"""清理旧的使用记录(分批删除避免长事务锁定)
Args:
db: 数据库会话
days_to_keep: 保留天数,默认 90 天
batch_size: 每批删除数量,默认 1000 条
Returns:
删除的总记录数
"""
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
total_deleted = 0
# Delete old records
deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
while True:
# Query the IDs to delete (uses the new idx_usage_user_created index)
batch_ids = (
db.query(Usage.id)
.filter(Usage.created_at < cutoff_date)
.limit(batch_size)
.all()
)
db.commit()
if not batch_ids:
break
logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
# Batch delete
deleted_count = (
db.query(Usage)
.filter(Usage.id.in_([row.id for row in batch_ids]))
.delete(synchronize_session=False)
)
db.commit()
total_deleted += deleted_count
return deleted
logger.debug(f"清理使用记录: 本批删除 {deleted_count}")
logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")
return total_deleted
# ========== Request status tracking methods ==========
@@ -1219,6 +1303,7 @@ class UsageService:
error_message: Optional[str] = None,
provider: Optional[str] = None,
target_model: Optional[str] = None,
first_byte_time_ms: Optional[int] = None,
) -> Optional[Usage]:
"""
Quickly update the status of a usage record
@@ -1230,6 +1315,7 @@ class UsageService:
error_message: error message (only used in the failed status)
provider: provider name (optional; updated in the streaming status)
target_model: mapped target model name (optional)
first_byte_time_ms: first byte time / TTFB (optional; updated in the streaming status)
Returns:
The updated Usage record, or None if not found
@@ -1247,6 +1333,8 @@ class UsageService:
usage.provider = provider
if target_model:
usage.target_model = target_model
if first_byte_time_ms is not None:
usage.first_byte_time_ms = first_byte_time_ms
db.commit()

View File

@@ -5,6 +5,7 @@
import json
import re
import time
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from sqlalchemy.orm import Session
@@ -457,26 +458,32 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
# Update status to streaming (and update provider at the same time)
if self.request_id:
try:
from src.services.usage.service import UsageService
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
chunk_count = 0
first_chunk_received = False
try:
async for chunk in stream:
chunk_count += 1
# Save the raw byte stream (for error diagnostics)
self.raw_chunks.append(chunk)
# When the first chunk arrives, update status to streaming and record TTFB
if not first_chunk_received:
first_chunk_received = True
if self.request_id:
try:
# Calculate TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
# Return the raw chunk to the client
yield chunk

View File

@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
# Compatibility implementation for Python 3.10 and below
# Note: this simple implementation does not support nested cancellation
pass
async def read_first_chunk_with_ttfb_timeout(
byte_iterator: Any,
timeout: float,
request_id: str,
provider_name: str,
) -> tuple[bytes, Any]:
"""
Read the stream's first chunk and apply TTFB timeout detection
The first byte timeout (Time To First Byte) detects slow-responding providers;
on timeout, failover to another available provider is triggered.
Args:
byte_iterator: async byte-stream iterator
timeout: TTFB timeout in seconds
request_id: request ID (for logging)
provider_name: provider name (for logging and the exception)
Returns:
(first_chunk, aiter): the first chunk and the async iterator
Raises:
ProviderTimeoutException: if the first byte times out
"""
from src.core.exceptions import ProviderTimeoutException
aiter = byte_iterator.__aiter__()
try:
first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
return first_chunk, aiter
except asyncio.TimeoutError:
# Full resource cleanup: close the iterator first, then the underlying response
await _cleanup_iterator_resources(aiter, request_id)
logger.warning(
f" [{request_id}] 流首字节超时 (TTFB): "
f"Provider={provider_name}, timeout={timeout}s"
)
raise ProviderTimeoutException(
provider_name=provider_name,
timeout=int(timeout),
)
async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
"""
Clean up an async iterator and its underlying resources
Ensures the HTTP connection is properly released after a TTFB timeout, avoiding connection leaks.
Args:
aiter: async iterator
request_id: request ID (for logging)
"""
# 1. Close the iterator itself
if hasattr(aiter, "aclose"):
try:
await aiter.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭迭代器失败: {e}")
# 2. Close the underlying response object (httpx.Response)
# The iterator may hold a _response attribute pointing at the underlying response
response = getattr(aiter, "_response", None)
if response is not None and hasattr(response, "aclose"):
try:
await response.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭底层响应失败: {e}")
# 3. Try to close the httpx stream (if the iterator came from httpx's aiter_bytes)
# The generator returned by httpx's Response.aiter_bytes() may have a _stream attribute
stream = getattr(aiter, "_stream", None)
if stream is not None and hasattr(stream, "aclose"):
try:
await stream.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭流对象失败: {e}")
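A self-contained driver showing the TTFB helper's behavior, assuming read_first_chunk_with_ttfb_timeout above is in scope (the slow stream and the broad except are illustrative; real callers catch ProviderTimeoutException as in the handlers above):

import asyncio
from typing import AsyncIterator

async def slow_stream() -> AsyncIterator[bytes]:
    await asyncio.sleep(5)  # provider stalls for 5s before the first byte
    yield b"data: {}\n\n"

async def main() -> None:
    try:
        first, rest = await read_first_chunk_with_ttfb_timeout(
            slow_stream(), timeout=1.0, request_id="demo", provider_name="slow-provider",
        )
    except Exception as exc:  # ProviderTimeoutException in the real code
        print(f"TTFB timeout, failing over: {exc}")
    else:
        print(first)  # only reached if the first byte arrived in time
        async for chunk in rest:
            print(chunk)

asyncio.run(main())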