feat: streaming prefetch hardening and adaptive concurrency algorithm rework

Streaming prefetch hardening:
- Add a prefetch byte cap (64KB) so responses without newlines cannot grow memory unboundedly
- After prefetching, detect non-SSE error responses (HTML pages, plain JSON errors)
- Extract check_html_response and check_prefetched_response_error into utils.py

Adaptive concurrency rework (boundary memory + progressive probing):
- Decrease policy: boundary - 1 instead of multiplicative decrease; a single 429 converges to near the real limit
- Increase policy: normal increases never exceed the known boundary; probe increases may cautiously push past it (+1 at a time)
- Record the boundary only on concurrency-type 429s, so RPM/UNKNOWN 429s cannot overwrite it
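For reference, a minimal standalone sketch of the two rules described above (illustrative names and constants; the real implementation lives in the adaptive concurrency manager diff below, with cooldowns, sampling windows, and persistence omitted):

    from typing import Optional

    MIN_LIMIT = 1          # assumed floor; the real code uses MIN_CONCURRENT_LIMIT
    MAX_LIMIT = 200        # mirrors MAX_CONCURRENT_LIMIT
    INCREASE_STEP = 2      # mirrors ConcurrencyDefaults.INCREASE_STEP

    def decrease_limit(current_limit: int, concurrent_at_429: Optional[int]) -> int:
        # On a concurrency 429: converge in one step to boundary - 1.
        if concurrent_at_429 is not None and concurrent_at_429 > 0:
            candidate = concurrent_at_429 - 1
        else:
            candidate = current_limit - 1  # no concurrency info: conservative -1
        candidate = min(candidate, current_limit - 1)  # never "decrease" upward
        return max(candidate, MIN_LIMIT)

    def increase_limit(current_limit: int, known_boundary: Optional[int], is_probe: bool) -> int:
        # Normal increases stay at or below the boundary; probes may exceed it by 1.
        if is_probe:
            new_limit = current_limit + 1
        else:
            new_limit = current_limit + INCREASE_STEP
            if known_boundary is not None and new_limit > known_boundary:
                new_limit = known_boundary  # boundary protection
        new_limit = min(new_limit, MAX_LIMIT)
        return new_limit if new_limit > current_limit else current_limit

    # Example: real limit 8. A 429 at concurrency 8 gives 7; a normal increase returns
    # to the boundary (8); only a probe increase tries 9.
    assert decrease_limit(50, 8) == 7
    assert increase_limit(7, 8, is_probe=False) == 8
    assert increase_limit(8, 8, is_probe=True) == 9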
fawney19
2026-01-05 12:17:45 +08:00
parent 4fa9a1303a
commit e5f12fddd9
5 changed files with 314 additions and 60 deletions

View File

@@ -34,7 +34,11 @@ from src.api.handlers.base.base_handler import (
 from src.api.handlers.base.parsers import get_parser_for_format
 from src.api.handlers.base.request_builder import PassthroughRequestBuilder
 from src.api.handlers.base.stream_context import StreamContext
-from src.api.handlers.base.utils import build_sse_headers
+from src.api.handlers.base.utils import (
+    build_sse_headers,
+    check_html_response,
+    check_prefetched_response_error,
+)
 from src.core.error_utils import extract_error_message

 # Import directly from concrete modules to avoid circular imports
@@ -58,6 +62,7 @@ from src.models.database import (
     ProviderEndpoint,
     User,
 )
+from src.config.constants import StreamDefaults
 from src.config.settings import config
 from src.services.provider.transport import build_provider_url
 from src.utils.sse_parser import SSEEventParser
@@ -703,7 +708,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
             ProviderTimeoutException: if the first byte times out (TTFB timeout)
         """
         prefetched_chunks: list = []
-        max_prefetch_lines = 5  # read at most 5 lines to detect errors
+        max_prefetch_lines = config.stream_prefetch_lines  # read at most this many lines to detect errors
+        max_prefetch_bytes = StreamDefaults.MAX_PREFETCH_BYTES  # avoid buffer growth on responses without newlines
+        total_prefetched_bytes = 0
         buffer = b""
         line_count = 0
         should_stop = False
@@ -730,14 +737,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 provider_name=str(provider.name),
             )
             prefetched_chunks.append(first_chunk)
+            total_prefetched_bytes += len(first_chunk)
             buffer += first_chunk

             # Continue reading the remaining prefetch data
             async for chunk in aiter:
                 prefetched_chunks.append(chunk)
+                total_prefetched_bytes += len(chunk)
                 buffer += chunk

-                # Try to parse the buffer line by line
+                # Try to parse the buffer line by line (SSE format)
                 while b"\n" in buffer:
                     line_bytes, buffer = buffer.split(b"\n", 1)
                     try:
@@ -754,15 +763,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
                     normalized_line = line.rstrip("\r")

                     # Detect HTML responses (a common symptom of a misconfigured base_url)
-                    lower_line = normalized_line.lower()
-                    if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
+                    if check_html_response(normalized_line):
                         logger.error(
                             f" [{self.request_id}] Detected an HTML response; base_url is likely misconfigured: "
                             f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
                             f"base_url={endpoint.base_url}"
                         )
                         raise ProviderNotAvailableException(
-                            f"Provider '{provider.name}' returned an HTML page instead of an API response; please check the endpoint's base_url configuration"
+                            f"Provider '{provider.name}' returned an HTML page instead of an API response; "
+                            f"please check the endpoint's base_url configuration"
                         )

                     if not normalized_line or normalized_line.startswith(":"):
@@ -811,9 +820,30 @@ class CliMessageHandlerBase(BaseMessageHandler):
                         should_stop = True
                         break

+                # Byte cap reached: stop prefetching (avoids memory growth on responses without newlines)
+                if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
+                    logger.debug(
+                        f" [{self.request_id}] Prefetch byte cap reached, stopping: "
+                        f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
+                        f"max_bytes={max_prefetch_bytes}"
+                    )
+                    break
+
                 if should_stop or line_count >= max_prefetch_lines:
                     break

+            # After prefetching, check for a non-SSE HTML/JSON response: handles plain JSON errors
+            # from some proxies (possibly without newlines, or multi-line JSON) and HTML pages (misconfigured base_url)
+            if not should_stop and prefetched_chunks:
+                check_prefetched_response_error(
+                    prefetched_chunks=prefetched_chunks,
+                    parser=provider_parser,
+                    request_id=self.request_id,
+                    provider_name=str(provider.name),
+                    endpoint_id=endpoint.id,
+                    base_url=endpoint.base_url,
+                )
+
         except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
             # Re-raise retryable provider exceptions to trigger failover
             raise
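A condensed sketch of the prefetch loop with the new byte cap (generic names, error detection elided; the real logic is the handler diff above):

    from typing import AsyncIterator, List, Tuple

    MAX_PREFETCH_BYTES = 64 * 1024  # mirrors StreamDefaults.MAX_PREFETCH_BYTES
    MAX_PREFETCH_LINES = 5

    async def prefetch(stream: AsyncIterator[bytes]) -> Tuple[List[bytes], bytes]:
        # Read chunks until enough lines or enough bytes have been seen to classify the response.
        chunks: List[bytes] = []
        buffer = b""
        total_bytes = 0
        lines_seen = 0
        async for chunk in stream:
            chunks.append(chunk)
            total_bytes += len(chunk)
            buffer += chunk
            while b"\n" in buffer:
                _line, buffer = buffer.split(b"\n", 1)
                lines_seen += 1
            # The byte cap bounds the buffer even when the response never emits a newline.
            if lines_seen >= MAX_PREFETCH_LINES or total_bytes >= MAX_PREFETCH_BYTES:
                break
        return chunks, buffer

Without the byte cap, a response consisting of one endless unterminated line would keep growing the buffer until the connection closed.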

View File

@@ -25,8 +25,17 @@ from src.api.handlers.base.content_extractors import (
 from src.api.handlers.base.parsers import get_parser_for_format
 from src.api.handlers.base.response_parser import ResponseParser
 from src.api.handlers.base.stream_context import StreamContext
+from src.api.handlers.base.utils import (
+    check_html_response,
+    check_prefetched_response_error,
+)
+from src.config.constants import StreamDefaults
 from src.config.settings import config
-from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
+from src.core.exceptions import (
+    EmbeddedErrorException,
+    ProviderNotAvailableException,
+    ProviderTimeoutException,
+)
 from src.core.logger import logger
 from src.models.database import Provider, ProviderEndpoint
 from src.utils.sse_parser import SSEEventParser
@@ -165,6 +174,7 @@ class StreamProcessor:
         endpoint: ProviderEndpoint,
         ctx: StreamContext,
         max_prefetch_lines: int = 5,
+        max_prefetch_bytes: int = StreamDefaults.MAX_PREFETCH_BYTES,
     ) -> list:
         """
         Prefetch the first few lines of the stream to detect embedded errors
@@ -180,12 +190,14 @@ class StreamProcessor:
             endpoint: Endpoint object
             ctx: streaming context
             max_prefetch_lines: maximum number of lines to prefetch
+            max_prefetch_bytes: maximum number of bytes to prefetch (avoids buffer growth on responses without newlines)

         Returns:
             List of prefetched byte chunks

         Raises:
             EmbeddedErrorException: if an embedded error is detected
+            ProviderNotAvailableException: if an HTML response is detected (configuration error)
             ProviderTimeoutException: if the first byte times out (TTFB timeout)
         """
         prefetched_chunks: list = []
@@ -193,6 +205,7 @@ class StreamProcessor:
         buffer = b""
         line_count = 0
         should_stop = False
+        total_prefetched_bytes = 0

         # Use an incremental decoder to handle UTF-8 characters that span chunks
         decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -206,11 +219,13 @@ class StreamProcessor:
                 provider_name=str(provider.name),
             )
             prefetched_chunks.append(first_chunk)
+            total_prefetched_bytes += len(first_chunk)
             buffer += first_chunk

             # Continue reading the remaining prefetch data
             async for chunk in aiter:
                 prefetched_chunks.append(chunk)
+                total_prefetched_bytes += len(chunk)
                 buffer += chunk

                 # Try to parse the buffer line by line
@@ -228,10 +243,21 @@ class StreamProcessor:
                     line_count += 1

+                    # Detect HTML responses (a common symptom of a misconfigured base_url)
+                    if check_html_response(line):
+                        logger.error(
+                            f" [{self.request_id}] Detected an HTML response; base_url is likely misconfigured: "
+                            f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
+                            f"base_url={endpoint.base_url}"
+                        )
+                        raise ProviderNotAvailableException(
+                            f"Provider '{provider.name}' returned an HTML page instead of an API response; "
+                            f"please check the endpoint's base_url configuration"
+                        )
+
                     # Skip empty lines and comment lines
                     if not line or line.startswith(":"):
                         if line_count >= max_prefetch_lines:
+                            should_stop = True
                             break
                         continue
@@ -248,7 +274,6 @@ class StreamProcessor:
                             data = json.loads(data_str)
                         except json.JSONDecodeError:
                             if line_count >= max_prefetch_lines:
-                                should_stop = True
                                 break
                             continue
@@ -276,14 +301,34 @@ class StreamProcessor:
                         should_stop = True
                         break

+                # Byte cap reached: stop prefetching (avoids memory growth on responses without newlines)
+                if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
+                    logger.debug(
+                        f" [{self.request_id}] Prefetch byte cap reached, stopping: "
+                        f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
+                        f"max_bytes={max_prefetch_bytes}"
+                    )
+                    break
+
                 if should_stop or line_count >= max_prefetch_lines:
                     break

+            # After prefetching, check for a non-SSE HTML/JSON response
+            if not should_stop and prefetched_chunks:
+                check_prefetched_response_error(
+                    prefetched_chunks=prefetched_chunks,
+                    parser=parser,
+                    request_id=self.request_id,
+                    provider_name=str(provider.name),
+                    endpoint_id=endpoint.id,
+                    base_url=endpoint.base_url,
+                )
+
-        except (EmbeddedErrorException, ProviderTimeoutException):
+        except (EmbeddedErrorException, ProviderNotAvailableException, ProviderTimeoutException):
             # Re-raise retryable provider exceptions to trigger failover
             raise
         except (OSError, IOError) as e:
             # Network I/O exception: log a warning; a retry may be needed
             logger.warning(
                 f" [{self.request_id}] Network error while prefetching the stream: {type(e).__name__}: {e}"
             )
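Why the post-prefetch check matters (illustrative payload, not from the commit): a pretty-printed JSON error passes through the line loop without ever matching SSE framing, so the per-line embedded-error detection never fires; only parsing the joined prefetched bytes as one document classifies it.

    import json

    payload = b'{\n  "error": {\n    "message": "upstream unavailable",\n    "code": 503\n  }\n}'
    lines = payload.split(b"\n")
    assert not any(line.startswith(b"data:") for line in lines)  # no SSE frames to inspect
    doc = json.loads(payload.decode("utf-8"))
    assert "error" in doc  # the shape check_prefetched_response_error relies on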

View File

@@ -2,8 +2,10 @@
 Handler base utility functions
 """

+import json
 from typing import Any, Dict, Optional

+from src.core.exceptions import EmbeddedErrorException, ProviderNotAvailableException
 from src.core.logger import logger
@@ -107,3 +109,95 @@ def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[st
     if extra_headers:
         headers.update(extra_headers)
     return headers
+
+
+def check_html_response(line: str) -> bool:
+    """
+    Check whether a line looks like an HTML response (a common symptom of a misconfigured base_url)
+
+    Args:
+        line: the line to check
+
+    Returns:
+        True if an HTML response is detected
+    """
+    lower_line = line.lstrip().lower()
+    return lower_line.startswith("<!doctype") or lower_line.startswith("<html")
+
+
+def check_prefetched_response_error(
+    prefetched_chunks: list,
+    parser: Any,
+    request_id: str,
+    provider_name: str,
+    endpoint_id: Optional[str],
+    base_url: Optional[str],
+) -> None:
+    """
+    Check whether the prefetched response is a non-SSE error response (HTML or a plain JSON error)
+
+    Some proxies may return:
+    1. An HTML page (misconfigured base_url)
+    2. A plain JSON error (without newlines, or as multi-line JSON)
+
+    Args:
+        prefetched_chunks: list of prefetched byte chunks
+        parser: response parser (must provide is_error_response and parse_response)
+        request_id: request ID (for logging)
+        provider_name: provider name
+        endpoint_id: endpoint ID
+        base_url: the endpoint's base_url
+
+    Raises:
+        ProviderNotAvailableException: if an HTML response is detected
+        EmbeddedErrorException: if a JSON error response is detected
+    """
+    if not prefetched_chunks:
+        return
+
+    try:
+        prefetched_bytes = b"".join(prefetched_chunks)
+        stripped = prefetched_bytes.lstrip()
+
+        # Strip the BOM
+        if stripped.startswith(b"\xef\xbb\xbf"):
+            stripped = stripped[3:]
+
+        # HTML response (usually a misconfigured base_url returning a web page)
+        lower_prefix = stripped[:32].lower()
+        if lower_prefix.startswith(b"<!doctype") or lower_prefix.startswith(b"<html"):
+            endpoint_short = endpoint_id[:8] + "..." if endpoint_id else "N/A"
+            logger.error(
+                f" [{request_id}] Detected an HTML response; base_url is likely misconfigured: "
+                f"Provider={provider_name}, Endpoint={endpoint_short}, "
+                f"base_url={base_url}"
+            )
+            raise ProviderNotAvailableException(
+                f"Provider '{provider_name}' returned an HTML page instead of an API response; "
+                f"please check the endpoint's base_url configuration"
+            )
+
+        # Plain JSON (possibly without newlines, or multi-line JSON)
+        if stripped.startswith(b"{") or stripped.startswith(b"["):
+            payload_str = stripped.decode("utf-8", errors="replace").strip()
+            data = json.loads(payload_str)
+            if isinstance(data, dict) and parser.is_error_response(data):
+                parsed = parser.parse_response(data, 200)
+                logger.warning(
+                    f" [{request_id}] Detected a JSON error response: "
+                    f"Provider={provider_name}, "
+                    f"error_type={parsed.error_type}, "
+                    f"message={parsed.error_message}"
+                )
+                raise EmbeddedErrorException(
+                    provider_name=provider_name,
+                    error_code=(
+                        int(parsed.error_type)
+                        if parsed.error_type and parsed.error_type.isdigit()
+                        else None
+                    ),
+                    error_message=parsed.error_message,
+                    error_status=parsed.error_type,
+                )
+    except json.JSONDecodeError:
+        pass
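The helper only assumes the parser exposes is_error_response and parse_response, so a throwaway stub is enough to exercise it. A minimal sketch (the stub, payload, and IDs below are illustrative, not from the repository):

    from dataclasses import dataclass

    @dataclass
    class ParsedError:
        error_type: str
        error_message: str

    class StubParser:
        def is_error_response(self, data: dict) -> bool:
            return "error" in data

        def parse_response(self, data: dict, status_code: int) -> ParsedError:
            err = data["error"]
            return ParsedError(error_type=str(err.get("code", "")), error_message=err.get("message", ""))

    chunks = [b'{"error": {"code": 429, "message": "concurrency limit exceeded"}}']
    check_prefetched_response_error(
        prefetched_chunks=chunks,
        parser=StubParser(),
        request_id="req-123",
        provider_name="example-provider",
        endpoint_id="0123456789abcdef",
        base_url="https://api.example.com",
    )  # raises EmbeddedErrorException with error_code=429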

View File

@@ -41,8 +41,25 @@ class CacheSize:
 # ==============================================================================

+class StreamDefaults:
+    """Streaming defaults"""
+
+    # Prefetch byte cap (avoids memory growth on responses without newlines)
+    # 64KB is based on:
+    # 1. A single SSE message is normally far smaller than this
+    # 2. It is enough to detect HTML and JSON error responses
+    # 3. It does not consume excessive memory
+    MAX_PREFETCH_BYTES = 64 * 1024  # 64KB
+
+
 class ConcurrencyDefaults:
-    """Concurrency control defaults"""
+    """Concurrency control defaults
+
+    Algorithm: boundary memory + progressive probing
+    - On a 429, record the boundary (last_concurrent_peak); new limit = boundary - 1
+    - Increases never exceed the boundary, except probe increases (after a long period without 429)
+    - This converges quickly to near the real limit without being overly conservative
+    """

     # Initial adaptive concurrency limit (start loose, lower on 429)
     INITIAL_LIMIT = 50
@@ -72,10 +89,6 @@ class ConcurrencyDefaults:
     # Increase step - concurrency added per increase
     INCREASE_STEP = 2

-    # Decrease multiplier - on a 429, the proportion of the triggering concurrency to drop to
-    # 0.85 means dropping to 85% of the concurrency at which the 429 fired
-    DECREASE_MULTIPLIER = 0.85

     # Maximum concurrency limit cap
     MAX_CONCURRENT_LIMIT = 200
@@ -87,6 +100,7 @@ class ConcurrencyDefaults:
     # === Probe increase parameters ===

     # Probe increase interval (minutes) - try increasing after a long period without 429 and with traffic
+    # Probe increases may push past the known boundary to try higher concurrency
    PROBE_INCREASE_INTERVAL_MINUTES = 30

     # Probe increase minimum request count - at least this many requests within the probe interval

View File

@@ -1,14 +1,16 @@
 """
-Adaptive concurrency adjuster - concurrency limit adjustment based on sliding-window utilization
+Adaptive concurrency adjuster - concurrency limit adjustment based on boundary memory

-Core improvements (over the old "sustained high utilization" approach):
-- Sample with a sliding window, tolerating concurrency fluctuations
-- Decide from the share of high-utilization samples in the window rather than requiring consecutive highs
-- Add a probe-increase mechanism that proactively tries to grow after long stable periods
+Core algorithm: boundary memory + progressive probing
+- On a 429, record the boundary (last_concurrent_peak); that is the real upper bound
+- Decrease policy: new limit = boundary - 1 (instead of a multiplicative decrease)
+- Increase policy: never exceed the known boundary, except for probe increases
+- Probe increase: after a long period without 429, try to push past the boundary

-AIMD parameters:
-- Increase: additive (+INCREASE_STEP)
-- Decrease: multiplicative (*DECREASE_MULTIPLIER, default 0.85)
+Design principles:
+1. Fast convergence: a single 429 finds a limit close to the real one
+2. No excessive conservatism: repeated 429s cannot drive the limit down indefinitely
+3. Safe probing: allows trying higher concurrency after stable periods
 """

 from datetime import datetime, timezone
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -35,21 +37,21 @@ class AdaptiveConcurrencyManager:
     """
     Adaptive concurrency manager

-    Core algorithm: AIMD over sliding-window utilization
-    - A sliding window records the utilization of the last N requests
-    - Increase when the share of high-utilization samples in the window >= 60%
-    - On a 429 error, decrease multiplicatively (*0.85)
-    - Probe increase after a long period without 429 but with traffic
+    Core algorithm: boundary memory + progressive probing
+    - On a 429, record the boundary (last_concurrent_peak = the concurrency when it fired)
+    - Decrease: new limit = boundary - 1 (fast convergence to near the real limit)
+    - Increase: never exceed the boundary (i.e. last_concurrent_peak); returning to the boundary value is allowed
+    - Probe increase: after a long period (30 minutes) without 429, may try +1 past the boundary

     Increase conditions (either one suffices):
-    1. Sliding-window increase: >= 60% of samples in the window have utilization >= 70%, and not in cooldown
-    2. Probe increase: more than 30 minutes since the last 429, with enough requests in between
+    1. Utilization increase: high-utilization share in the window >= 60%, and current limit < boundary
+    2. Probe increase: more than 30 minutes since the last 429; may push past the boundary

     Key properties:
-    1. The sliding window tolerates concurrency fluctuations; a single low-utilization sample does not reset it
-    2. Distinguishes concurrency limits from RPM limits
-    3. Probe increases avoid getting stuck at a low limit forever
-    4. Records the adjustment history
+    1. Fast convergence: a single 429 learns a limit close to the real one
+    2. Boundary protection: normal increases never exceed the known boundary
+    3. Safe probing: allows trying higher concurrency after long stable periods
+    4. Distinguishes concurrency limits from RPM limits
     """

     # Default configuration - uses shared constants
@@ -59,7 +61,6 @@ class AdaptiveConcurrencyManager:
     # AIMD parameters
     INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
-    DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER

     # Sliding window parameters
     UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
@@ -115,7 +116,13 @@ class AdaptiveConcurrencyManager:
         # Update 429 statistics
         key.last_429_at = datetime.now(timezone.utc)  # type: ignore[assignment]
         key.last_429_type = rate_limit_info.limit_type  # type: ignore[assignment]
-        key.last_concurrent_peak = current_concurrent  # type: ignore[assignment]
+        # Record the boundary only for concurrency-type 429s that carry a concurrency value
+        # (RPM/UNKNOWN 429s must not overwrite the concurrency boundary memory)
+        if (
+            rate_limit_info.limit_type == RateLimitType.CONCURRENT
+            and current_concurrent is not None
+            and current_concurrent > 0
+        ):
+            key.last_concurrent_peak = current_concurrent  # type: ignore[assignment]

         # On a 429, clear the utilization sample window (start collecting again)
         key.utilization_samples = []  # type: ignore[assignment]
@@ -207,6 +214,9 @@ class AdaptiveConcurrencyManager:
         current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)

+        # Fetch the known boundary (the concurrency when the last 429 fired)
+        known_boundary = key.last_concurrent_peak
+
         # Compute the current utilization
         utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
@@ -217,22 +227,29 @@ class AdaptiveConcurrencyManager:
         samples = self._update_utilization_window(key, now_ts, utilization)

         # Check whether the increase conditions are met
-        increase_reason = self._check_increase_conditions(key, samples, now)
+        increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
         if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
             old_limit = current_limit
-            new_limit = self._increase_limit(current_limit)
+            is_probe = increase_reason == "probe_increase"
+            new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
+
+            # Skip if there was no actual growth (the boundary has been reached)
+            if new_limit <= old_limit:
+                return None

             # Compute window statistics for logging
             avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
             high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
             high_util_ratio = high_util_count / len(samples) if samples else 0
+            boundary_info = f"boundary: {known_boundary}" if known_boundary else "no boundary"

             logger.info(
                 f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
                 f"window samples: {len(samples)} | "
                 f"average utilization: {avg_util:.1%} | "
                 f"high-utilization share: {high_util_ratio:.1%} | "
+                f"{boundary_info} | "
                 f"adjustment: {old_limit} -> {new_limit}"
             )
@@ -246,13 +263,14 @@ class AdaptiveConcurrencyManager:
                 high_util_ratio=round(high_util_ratio, 2),
                 sample_count=len(samples),
                 current_concurrent=current_concurrent,
+                known_boundary=known_boundary,
             )

             # Update the limit
             key.learned_max_concurrent = new_limit  # type: ignore[assignment]

             # If this was a probe increase, update the probe timestamp
-            if increase_reason == "probe_increase":
+            if is_probe:
                 key.last_probe_increase_at = now  # type: ignore[assignment]

             # After increasing, clear the sample window and start collecting again
@@ -303,7 +321,11 @@ class AdaptiveConcurrencyManager:
         return samples

     def _check_increase_conditions(
-        self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
+        self,
+        key: ProviderAPIKey,
+        samples: List[Dict[str, Any]],
+        now: datetime,
+        known_boundary: Optional[int] = None,
     ) -> Optional[str]:
         """
         Check whether the increase conditions are met
@@ -312,6 +334,7 @@ class AdaptiveConcurrencyManager:
             key: API key object
             samples: list of utilization samples
             now: current time
+            known_boundary: known boundary (the concurrency when the 429 fired)

         Returns:
             The increase reason if the conditions are met, otherwise None
@@ -320,15 +343,25 @@ class AdaptiveConcurrencyManager:
         if self._is_in_cooldown(key):
             return None

-        # Condition 1: sliding-window increase
+        current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
+
+        # Condition 1: sliding-window increase (never past the boundary)
         if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
             high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
             high_util_ratio = high_util_count / len(samples)
             if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
-                return "high_utilization"
+                # Check whether there is still headroom (boundary protection)
+                if known_boundary:
+                    # Allow increasing back up to the boundary value (not boundary - 1, since the decrease already subtracted 1)
+                    if current_limit < known_boundary:
+                        return "high_utilization"
+                    # Boundary reached: no normal increase
+                else:
+                    # No boundary information: allow the increase
+                    return "high_utilization"

-        # Condition 2: probe increase (long period without 429 and with traffic)
+        # Condition 2: probe increase (long period without 429 and with traffic; may push past the boundary)
         if self._should_probe_increase(key, samples, now):
             return "probe_increase"
@@ -406,32 +439,65 @@ class AdaptiveConcurrencyManager:
         current_concurrent: Optional[int] = None,
     ) -> int:
         """
-        Decrease the concurrency limit
+        Decrease the concurrency limit (boundary-memory strategy)

         Strategy:
-        - If the current concurrency is known, set the limit to 70% of it
-        - Otherwise, decrease multiplicatively
+        - If the concurrency at the 429 is known, new limit = that concurrency - 1
+        - This converges quickly to near the real limit without being overly conservative
+        - Example: real limit 8, 429 fires at concurrency 8 -> new limit 7 (instead of 8*0.85=6)
         """
-        if current_concurrent:
-            # Decrease based on the current concurrency
-            new_limit = max(
-                int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
-            )
+        if current_concurrent is not None and current_concurrent > 0:
+            # Boundary-memory strategy: new limit = triggering boundary - 1
+            candidate = current_concurrent - 1
         else:
-            # Multiplicative decrease
-            new_limit = max(
-                int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
-            )
+            # Without concurrency information, conservatively decrease by 1
+            candidate = current_limit - 1
+
+        # Make sure a "decrease" never increases the limit (e.g. the unusual case current_concurrent > current_limit)
+        candidate = min(candidate, current_limit - 1)
+        new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)

         return new_limit

-    def _increase_limit(self, current_limit: int) -> int:
+    def _increase_limit(
+        self,
+        current_limit: int,
+        known_boundary: Optional[int] = None,
+        is_probe: bool = False,
+    ) -> int:
         """
-        Increase the concurrency limit
+        Increase the concurrency limit (with boundary protection)

-        Strategy: additive increase (+1)
+        Strategy:
+        - Normal increase: +INCREASE_STEP each time, but never past known_boundary
+          (the decrease already subtracted 1, so returning to the boundary value is allowed)
+        - Probe increase: only +1 each time; may push past the boundary, but cautiously
+
+        Args:
+            current_limit: the current limit
+            known_boundary: the known boundary (last_concurrent_peak, i.e. the concurrency when the 429 fired)
+            is_probe: whether this is a probe increase (which may push past the boundary)
         """
-        new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
+        if is_probe:
+            # Probe mode: only +1, cautiously pushing past the boundary
+            new_limit = current_limit + 1
+        else:
+            # Normal mode: +INCREASE_STEP
+            new_limit = current_limit + self.INCREASE_STEP
+
+            # Boundary protection: normal increases never exceed known_boundary (returning to the boundary value is allowed)
+            if known_boundary:
+                if new_limit > known_boundary:
+                    new_limit = known_boundary
+
+        # Global cap
+        new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
+
+        # Require actual growth (otherwise return the original value, meaning no increase)
+        if new_limit <= current_limit:
+            return current_limit
+
         return new_limit
     def _record_adjustment(
@@ -503,11 +569,16 @@ class AdaptiveConcurrencyManager:
         if key.last_probe_increase_at:
             last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()

+        # Boundary information
+        known_boundary = key.last_concurrent_peak
+
         return {
             "adaptive_mode": is_adaptive,
             "max_concurrent": key.max_concurrent,  # NULL = adaptive, number = fixed limit
             "effective_limit": effective_limit,  # current effective limit
             "learned_limit": key.learned_max_concurrent,  # learned limit
+            # Boundary-memory fields
+            "known_boundary": known_boundary,  # concurrency when the 429 fired (known upper bound)
             "concurrent_429_count": int(key.concurrent_429_count or 0),
             "rpm_429_count": int(key.rpm_429_count or 0),
             "last_429_at": last_429_at_str,