refactor: 重构流式处理模块,提取 StreamContext/Processor/Telemetry

- 将 chat_handler_base.py 中的流式处理逻辑拆分为三个独立模块:
  - StreamContext: 类型安全的流式上下文数据类,替代原有的 ctx dict
  - StreamProcessor: SSE 解析、预读、嵌套错误检测
  - StreamTelemetryRecorder: 统计记录(Usage/Audit/Candidate)
- 将硬编码配置外置到 settings.py,支持环境变量覆盖:
  - HTTP 超时配置(connect/write/pool)
  - 流式处理配置(预读行数、统计延迟)
  - 并发控制配置(槽位 TTL、缓存预留比例)
This commit is contained in:
fawney19
2025-12-12 15:42:45 +08:00
parent 39defce71c
commit 53bf74429e
6 changed files with 922 additions and 600 deletions

View File

@@ -12,10 +12,13 @@ Chat Handler Base - Chat API 格式的通用基类
- apply_mapped_model(): 模型映射 - apply_mapped_model(): 模型映射
- get_model_for_url(): URL 模型名 - get_model_for_url(): URL 模型名
- _extract_usage(): 使用量提取 - _extract_usage(): 使用量提取
重构说明:
- StreamContext: 类型安全的流式上下文,替代原有的 ctx dict
- StreamProcessor: 流式响应处理(SSE 解析、预读、错误检测)
- StreamTelemetryRecorder: 统计记录(Usage、Audit、Candidate)
""" """
import asyncio
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, Optional from typing import Any, AsyncGenerator, Callable, Dict, Optional
@@ -24,13 +27,14 @@ from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import ( from src.api.handlers.base.base_handler import BaseMessageHandler
BaseMessageHandler,
MessageTelemetry,
)
from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.response_parser import ResponseParser from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.config.settings import config
from src.core.exceptions import ( from src.core.exceptions import (
EmbeddedErrorException, EmbeddedErrorException,
ProviderAuthException, ProviderAuthException,
@@ -39,7 +43,6 @@ from src.core.exceptions import (
ProviderTimeoutException, ProviderTimeoutException,
) )
from src.core.logger import logger from src.core.logger import logger
from src.database import get_db
from src.models.database import ( from src.models.database import (
ApiKey, ApiKey,
Provider, Provider,
@@ -48,7 +51,6 @@ from src.models.database import (
User, User,
) )
from src.services.provider.transport import build_provider_url from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@@ -285,30 +287,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model = getattr(converted_request, "model", original_request_body.get("model", "unknown")) model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
api_format = self.allowed_api_formats[0] api_format = self.allowed_api_formats[0]
# 用于跟踪的上下文 # 创建类型安全的流式上下文
ctx = { ctx = StreamContext(model=model, api_format=api_format)
"model": model,
"api_format": api_format, # 创建流处理器
"provider_name": None, stream_processor = StreamProcessor(
"provider_id": None, request_id=self.request_id,
"endpoint_id": None, default_parser=self.parser,
"key_id": None, on_streaming_start=self._update_usage_to_streaming,
"attempt_id": None, )
"provider_api_format": None, # Provider 的响应格式(用于选择解析器)
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"collected_text": "",
"status_code": 200,
"response_headers": {},
"provider_request_headers": {},
"provider_request_body": None,
"data_count": 0,
"chunk_count": 0,
"has_completion": False,
"parsed_chunks": [], # 收集解析后的 chunks
}
# 定义请求函数 # 定义请求函数
async def stream_request_func( async def stream_request_func(
@@ -318,6 +305,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request( return await self._execute_stream_request(
ctx, ctx,
stream_processor,
provider, provider,
endpoint, endpoint,
key, key,
@@ -350,23 +338,39 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True, is_stream=True,
capability_requirements=capability_requirements or None, capability_requirements=capability_requirements or None,
) )
ctx["attempt_id"] = attempt_id
ctx["provider_name"] = provider_name # 更新上下文
ctx["provider_id"] = provider_id ctx.attempt_id = attempt_id
ctx["endpoint_id"] = endpoint_id ctx.provider_name = provider_name
ctx["key_id"] = key_id ctx.provider_id = provider_id
ctx.endpoint_id = endpoint_id
ctx.key_id = key_id
# 创建遥测记录器
telemetry_recorder = StreamTelemetryRecorder(
request_id=self.request_id,
user_id=str(self.user.id),
api_key_id=str(self.api_key.id),
client_ip=self.client_ip,
format_id=self.FORMAT_ID,
)
# 创建后台任务记录统计 # 创建后台任务记录统计
background_tasks = BackgroundTasks() background_tasks = BackgroundTasks()
background_tasks.add_task( background_tasks.add_task(
self._record_stream_stats, telemetry_recorder.record_stream_stats,
ctx, ctx,
original_headers, original_headers,
original_request_body, original_request_body,
self.elapsed_ms(),
) )
# 创建监控流 # 创建监控流
monitored_stream = self._create_monitored_stream(ctx, stream_generator, http_request) monitored_stream = stream_processor.create_monitored_stream(
ctx,
stream_generator,
http_request.is_disconnected,
)
return StreamingResponse( return StreamingResponse(
monitored_stream, monitored_stream,
@@ -381,7 +385,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
async def _execute_stream_request( async def _execute_stream_request(
self, self,
ctx: Dict, ctx: StreamContext,
stream_processor: StreamProcessor,
provider: Provider, provider: Provider,
endpoint: ProviderEndpoint, endpoint: ProviderEndpoint,
key: ProviderAPIKey, key: ProviderAPIKey,
@@ -390,37 +395,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
query_params: Optional[Dict[str, str]] = None, query_params: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器""" """执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积 # 重置上下文状态(重试时清除之前的数据)
ctx["parsed_chunks"] = [] ctx.reset_for_retry()
ctx["chunk_count"] = 0
ctx["data_count"] = 0
ctx["has_completion"] = False
ctx["collected_text"] = ""
ctx["input_tokens"] = 0
ctx["output_tokens"] = 0
ctx["cached_tokens"] = 0
ctx["cache_creation_tokens"] = 0
ctx["provider_name"] = str(provider.name) # 更新 Provider 信息
ctx["provider_id"] = str(provider.id) ctx.update_provider_info(
ctx["endpoint_id"] = str(endpoint.id) provider_name=str(provider.name),
ctx["key_id"] = str(key.id) provider_id=str(provider.id),
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else "" endpoint_id=str(endpoint.id),
key_id=str(key.id),
provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
)
# 获取模型映射 # 获取模型映射
mapped_model = await self._get_mapped_model( mapped_model = await self._get_mapped_model(
source_model=ctx["model"], source_model=ctx.model,
provider_id=str(provider.id), provider_id=str(provider.id),
) )
# 应用模型映射到请求体 # 应用模型映射到请求体
if mapped_model: if mapped_model:
ctx["mapped_model"] = mapped_model # 保存映射后的模型名,用于 Usage 记录 ctx.mapped_model = mapped_model
request_body = self.apply_mapped_model(original_request_body, mapped_model) request_body = self.apply_mapped_model(original_request_body, mapped_model)
else: else:
request_body = dict(original_request_body) request_body = dict(original_request_body)
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段) # 准备发送给 Provider 的请求体
request_body = self.prepare_provider_request_body(request_body) request_body = self.prepare_provider_request_body(request_body)
# 构建请求 # 构建请求
@@ -432,11 +432,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True, is_stream=True,
) )
ctx["provider_request_headers"] = provider_headers ctx.provider_request_headers = provider_headers
ctx["provider_request_body"] = provider_payload ctx.provider_request_body = provider_payload
# 获取 URL 模型名(兜底使用 ctx 中的 model确保 Gemini 等格式能正确构建 URL # 获取 URL 模型名
url_model = self.get_model_for_url(request_body, mapped_model) or ctx["model"] url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
url = build_provider_url( url = build_provider_url(
endpoint, endpoint,
@@ -445,15 +445,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True, is_stream=True,
) )
logger.debug(f" [{self.request_id}] 发送流式请求: Provider={provider.name}, " logger.debug(
f"模型={ctx['model']} -> {mapped_model or '无映射'}") f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
f"模型={ctx.model} -> {mapped_model or '无映射'}"
)
# 发送请求 # 发送请求(使用配置中的超时设置)
timeout_config = httpx.Timeout( timeout_config = httpx.Timeout(
connect=10.0, connect=config.http_connect_timeout,
read=float(endpoint.timeout), read=float(endpoint.timeout),
write=60.0, # 写入超时增加到60秒支持大请求体如包含图片的长对话 write=config.http_write_timeout,
pool=10.0, pool=config.http_pool_timeout,
) )
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True) http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
@@ -463,17 +465,21 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
) )
stream_response = await response_ctx.__aenter__() stream_response = await response_ctx.__aenter__()
ctx["status_code"] = stream_response.status_code ctx.status_code = stream_response.status_code
ctx["response_headers"] = dict(stream_response.headers) ctx.response_headers = dict(stream_response.headers)
stream_response.raise_for_status() stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用) # 创建行迭代器
line_iterator = stream_response.aiter_lines() line_iterator = stream_response.aiter_lines()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误) # 预读检测嵌套错误
prefetched_lines = await self._prefetch_and_check_embedded_error( prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator, provider, endpoint, ctx line_iterator,
provider,
endpoint,
ctx,
max_prefetch_lines=config.stream_prefetch_lines,
) )
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
@@ -483,7 +489,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
raise raise
except EmbeddedErrorException: except EmbeddedErrorException:
# 嵌套错误需要触发重试,关闭连接后重新抛出
try: try:
await response_ctx.__aexit__(None, None, None) await response_ctx.__aexit__(None, None, None)
except Exception: except Exception:
@@ -495,8 +500,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose() await http_client.aclose()
raise raise
# 创建流生成器(带预读数据,使用同一个迭代器) # 创建流生成器
return self._create_response_stream_with_prefetch( return stream_processor.create_response_stream(
ctx, ctx,
line_iterator, line_iterator,
response_ctx, response_ctx,
@@ -504,518 +509,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
prefetched_lines, prefetched_lines,
) )
async def _create_response_stream(
self,
ctx: Dict,
stream_response: httpx.Response,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
try:
sse_parser = SSEEventParser()
streaming_status_updated = False
async for line in stream_response.aiter_lines():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming()
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: Dict,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器)
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 上下文字典
Returns:
预读的行列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
try:
# 获取对应格式的解析器
provider_parser = self._get_provider_parser(ctx)
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
# 解析数据
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
raise
except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
async def _create_response_stream_with_prefetch(
self,
ctx: Dict,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
try:
sse_parser = SSEEventParser()
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming()
# 先输出预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 继续输出剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
def _get_provider_parser(self, ctx: Dict) -> ResponseParser:
"""
获取 Provider 格式的解析器
根据 Provider 的 API 格式选择正确的解析器,
而不是根据请求格式选择。
"""
provider_format = ctx.get("provider_api_format")
if provider_format:
try:
return get_parser_for_format(provider_format)
except KeyError:
pass
# 回退到默认解析器
return self.parser
def _handle_sse_event(
self,
ctx: Dict,
event_name: Optional[str],
data_str: str,
) -> None:
"""处理 SSE 事件"""
if not data_str:
return
if data_str == "[DONE]":
ctx["has_completion"] = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx["data_count"] += 1
if not isinstance(data, dict):
return
# 收集原始 chunk 数据
ctx["parsed_chunks"].append(data)
# 根据 Provider 格式选择解析器
provider_parser = self._get_provider_parser(ctx)
# 使用解析器提取 usage
usage = provider_parser.extract_usage_from_response(data)
if usage:
ctx["input_tokens"] = usage.get("input_tokens", ctx["input_tokens"])
ctx["output_tokens"] = usage.get("output_tokens", ctx["output_tokens"])
ctx["cached_tokens"] = usage.get("cache_read_tokens", ctx["cached_tokens"])
ctx["cache_creation_tokens"] = usage.get(
"cache_creation_tokens", ctx["cache_creation_tokens"]
)
# 提取文本
text = provider_parser.extract_text_content(data)
if text:
ctx["collected_text"] += text
# 检查完成
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx["has_completion"] = True
async def _create_monitored_stream(
self,
ctx: Dict,
stream_generator: AsyncGenerator[bytes, None],
http_request: Request,
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
try:
async for chunk in stream_generator:
if await http_request.is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
# 客户端断开时设置 499 状态码Client Closed Request
# 注意Provider 可能已经成功返回数据,但客户端未完整接收
ctx["status_code"] = 499
break
yield chunk
except asyncio.CancelledError:
ctx["status_code"] = 499
raise
except Exception:
ctx["status_code"] = 500
raise
async def _record_stream_stats(
self,
ctx: Dict,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""记录流式统计信息"""
response_time_ms = self.elapsed_ms()
bg_db = None
try:
await asyncio.sleep(0.1)
if not ctx["provider_name"]:
# 即使没有 provider_name也要尝试更新状态为 failed
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message="Provider name not available",
)
return
db_gen = get_db()
bg_db = next(db_gen)
try:
from src.models.database import ApiKey as ApiKeyModel
user = bg_db.query(User).filter(User.id == self.user.id).first()
api_key_obj = (
bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == self.api_key.id).first()
)
if not user or not api_key_obj:
logger.warning(f"[{self.request_id}] User or ApiKey not found, updating status directly")
await self._update_usage_status_directly(
bg_db,
status="completed" if ctx["status_code"] == 200 else "failed",
response_time_ms=response_time_ms,
status_code=ctx["status_code"],
)
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key_obj, self.request_id, self.client_ip
)
actual_request_body = ctx["provider_request_body"] or original_request_body
# 构建响应体(与 CLI 模式一致)
response_body = {
"chunks": ctx["parsed_chunks"],
"metadata": {
"stream": True,
"total_chunks": len(ctx["parsed_chunks"]),
"data_count": ctx["data_count"],
"has_completion": ctx["has_completion"],
"response_time_ms": response_time_ms,
},
}
# 根据状态码决定记录成功还是失败
# 499 = 客户端断开连接503 = 服务不可用(如流中断)
status_code: int = ctx.get("status_code") or 200
if status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
await bg_telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown",
model=ctx["model"],
response_time_ms=response_time_ms,
status_code=status_code,
error_message=ctx.get("error_message") or f"HTTP {status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx["api_format"],
provider_request_headers=ctx["provider_request_headers"],
# 预估 token 信息(来自 message_start 事件)
input_tokens=ctx.get("input_tokens", 0),
output_tokens=ctx.get("output_tokens", 0),
cache_creation_tokens=ctx.get("cache_creation_tokens", 0),
cache_read_tokens=ctx.get("cached_tokens", 0),
response_body=response_body,
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# 简洁的请求失败摘要(包含预估 token 信息)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"{status_code} | in:{ctx.get('input_tokens', 0)} out:{ctx.get('output_tokens', 0)} cache:{ctx.get('cached_tokens', 0)}")
else:
await bg_telemetry.record_success(
provider=ctx["provider_name"],
model=ctx["model"],
input_tokens=ctx["input_tokens"],
output_tokens=ctx["output_tokens"],
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx["response_headers"],
response_body=response_body,
cache_creation_tokens=ctx["cache_creation_tokens"],
cache_read_tokens=ctx["cached_tokens"],
is_stream=True,
provider_request_headers=ctx["provider_request_headers"],
api_format=ctx["api_format"],
provider_id=ctx["provider_id"],
provider_endpoint_id=ctx["endpoint_id"],
provider_api_key_id=ctx["key_id"],
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"in:{ctx.get('input_tokens', 0) or 0} out:{ctx.get('output_tokens', 0) or 0}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
# 这里用流传输完成后的实际时间覆盖
if ctx.get("attempt_id"):
from src.services.request.candidate import RequestCandidateService
# 根据状态码决定是成功还是失败(复用上面已定义的 status_code
# 499 = 客户端断开连接,应标记为失败
# 503 = 服务不可用(如流中断),应标记为失败
if status_code and status_code >= 400:
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx["attempt_id"],
error_type="client_disconnected" if status_code == 499 else "stream_error",
error_message=ctx.get("error_message") or f"HTTP {status_code}",
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"data_count": ctx.get("data_count", 0),
},
)
else:
RequestCandidateService.mark_candidate_success(
db=bg_db,
candidate_id=ctx["attempt_id"],
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"data_count": ctx.get("data_count", 0),
},
)
finally:
if bg_db:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
# 确保即使出错也要更新状态,避免 pending 状态卡住
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message=f"记录统计信息失败: {str(e)[:200]}",
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def _update_usage_status_on_error(
self,
response_time_ms: int,
error_message: str,
) -> None:
"""在记录失败时更新 Usage 状态,避免卡在 pending"""
try:
db_gen = get_db()
error_db = next(db_gen)
try:
await self._update_usage_status_directly(
error_db,
status="failed",
response_time_ms=response_time_ms,
status_code=500,
error_message=error_message,
)
finally:
error_db.close()
except Exception as inner_e:
logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")
async def _update_usage_status_directly(
self,
db: Session,
status: str,
response_time_ms: int,
status_code: int = 200,
error_message: Optional[str] = None,
) -> None:
"""直接更新 Usage 表的状态字段"""
try:
from src.models.database import Usage
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
if usage:
setattr(usage, "status", status)
setattr(usage, "status_code", status_code)
setattr(usage, "response_time_ms", response_time_ms)
if error_message:
setattr(usage, "error_message", error_message)
db.commit()
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
except Exception as e:
logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
async def _record_stream_failure( async def _record_stream_failure(
self, self,
ctx: Dict, ctx: StreamContext,
error: Exception, error: Exception,
original_headers: Dict[str, str], original_headers: Dict[str, str],
original_request_body: Dict[str, Any], original_request_body: Dict[str, Any],
@@ -1031,21 +527,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
elif isinstance(error, ProviderTimeoutException): elif isinstance(error, ProviderTimeoutException):
status_code = 504 status_code = 504
actual_request_body = ctx.get("provider_request_body") or original_request_body actual_request_body = ctx.provider_request_body or original_request_body
await self.telemetry.record_failure( await self.telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown", provider=ctx.provider_name or "unknown",
model=ctx["model"], model=ctx.model,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
status_code=status_code, status_code=status_code,
error_message=str(error), error_message=str(error),
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
is_stream=True, is_stream=True,
api_format=ctx["api_format"], api_format=ctx.api_format,
provider_request_headers=ctx.get("provider_request_headers") or {}, provider_request_headers=ctx.provider_request_headers,
# 模型映射信息 target_model=ctx.mapped_model,
target_model=ctx.get("mapped_model"),
) )
# ==================== 非流式处理 ==================== # ==================== 非流式处理 ====================

View File

@@ -0,0 +1,154 @@
"""
流式处理上下文 - 类型安全的数据类替代 dict
提供流式请求处理过程中的状态跟踪,包括:
- Provider/Endpoint/Key 信息
- Token 统计
- 响应状态
- 请求/响应数据
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class StreamContext:
    """
    Streaming request context.

    Type-safe replacement for the previous ``ctx`` dict used while handling
    a streaming request. All fields are annotated, giving better IDE
    support and runtime type safety.
    """

    # Basic request info
    model: str
    api_format: str

    # Provider info (filled in when the request is executed)
    provider_name: Optional[str] = None
    provider_id: Optional[str] = None
    endpoint_id: Optional[str] = None
    key_id: Optional[str] = None
    attempt_id: Optional[str] = None
    # Response format of the provider (used to pick the right parser)
    provider_api_format: Optional[str] = None

    # Model mapping
    mapped_model: Optional[str] = None

    # Token statistics
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    cache_creation_tokens: int = 0

    # Collected response text
    collected_text: str = ""

    # Response status
    status_code: int = 200
    error_message: Optional[str] = None
    has_completion: bool = False

    # Request/response data
    response_headers: Dict[str, str] = field(default_factory=dict)
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_request_body: Optional[Dict[str, Any]] = None

    # Streaming statistics
    data_count: int = 0
    chunk_count: int = 0
    parsed_chunks: List[Dict[str, Any]] = field(default_factory=list)

    def reset_for_retry(self) -> None:
        """
        Reset per-attempt state before a failover retry.

        Clears everything accumulated by the previous attempt so that data
        does not leak across providers. ``model`` and ``api_format`` are
        kept; provider identifiers are overwritten separately through
        ``update_provider_info()``.
        """
        self.parsed_chunks = []
        self.chunk_count = 0
        self.data_count = 0
        self.has_completion = False
        self.collected_text = ""
        self.input_tokens = 0
        self.output_tokens = 0
        self.cached_tokens = 0
        self.cache_creation_tokens = 0
        self.error_message = None
        self.status_code = 200
        self.response_headers = {}
        self.provider_request_headers = {}
        self.provider_request_body = None
        # Fix: the mapped model belongs to the previous provider attempt.
        # Callers only assign ctx.mapped_model when the new provider has a
        # mapping, so without clearing it here a retry against a provider
        # with no mapping would record a stale target model.
        self.mapped_model = None

    def update_provider_info(
        self,
        provider_name: str,
        provider_id: str,
        endpoint_id: str,
        key_id: str,
        provider_api_format: Optional[str] = None,
    ) -> None:
        """Record the provider/endpoint/key chosen for this attempt."""
        self.provider_name = provider_name
        self.provider_id = provider_id
        self.endpoint_id = endpoint_id
        self.key_id = key_id
        self.provider_api_format = provider_api_format

    def update_usage(
        self,
        input_tokens: Optional[int] = None,
        output_tokens: Optional[int] = None,
        cached_tokens: Optional[int] = None,
        cache_creation_tokens: Optional[int] = None,
    ) -> None:
        """Update token usage statistics; ``None`` keeps the current value."""
        if input_tokens is not None:
            self.input_tokens = input_tokens
        if output_tokens is not None:
            self.output_tokens = output_tokens
        if cached_tokens is not None:
            self.cached_tokens = cached_tokens
        if cache_creation_tokens is not None:
            self.cache_creation_tokens = cache_creation_tokens

    def mark_failed(self, status_code: int, error_message: str) -> None:
        """Mark the request as failed with the given status and message."""
        self.status_code = status_code
        self.error_message = error_message

    def is_success(self) -> bool:
        """Return True when the current status code is below 400."""
        return self.status_code < 400

    def build_response_body(self, response_time_ms: int) -> Dict[str, Any]:
        """
        Build the response-body metadata.

        Used for the ``response_body`` field when recording to the Usage
        table.
        """
        return {
            "chunks": self.parsed_chunks,
            "metadata": {
                "stream": True,
                "total_chunks": len(self.parsed_chunks),
                "data_count": self.data_count,
                "has_completion": self.has_completion,
                "response_time_ms": response_time_ms,
            },
        }

    def get_log_summary(self, request_id: str, response_time_ms: int) -> str:
        """
        Build a one-line log summary.

        Used when logging request completion/failure.
        """
        status = "OK" if self.is_success() else "FAIL"
        return (
            f"[{status}] {request_id[:8]} | {self.model} | "
            f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
            f"in:{self.input_tokens} out:{self.output_tokens}"
        )

View File

@@ -0,0 +1,349 @@
"""
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
职责:
1. SSE 事件解析和处理
2. 响应流生成
3. 预读和嵌套错误检测
4. 客户端断开检测
"""
import asyncio
import json
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
class StreamProcessor:
"""
流式响应处理器
负责处理 SSE 流的解析、错误检测和响应生成。
从 ChatHandlerBase 中提取,使其职责更加单一。
"""
def __init__(
    self,
    request_id: str,
    default_parser: ResponseParser,
    on_streaming_start: Optional[Callable[[], None]] = None,
):
    """
    Initialize the stream processor.

    Args:
        request_id: Request ID, used in log messages.
        default_parser: Fallback response parser, used when no parser is
            registered for the provider's API format.
        on_streaming_start: Optional callback invoked when streaming
            begins (used to flip the usage record's status to streaming).
    """
    self.request_id = request_id
    self.default_parser = default_parser
    self.on_streaming_start = on_streaming_start
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
    """
    Resolve the parser matching the provider's API format.

    Falls back to the default parser when the provider format is unset
    or no parser is registered for it.
    """
    provider_format = ctx.provider_api_format
    if not provider_format:
        return self.default_parser
    try:
        return get_parser_for_format(provider_format)
    except KeyError:
        return self.default_parser
def handle_sse_event(
    self,
    ctx: StreamContext,
    event_name: Optional[str],
    data_str: str,
) -> None:
    """
    Process a single SSE event.

    Parses the event payload, accumulates usage and text statistics on
    the context, and flags stream completion.

    Args:
        ctx: Streaming context being updated.
        event_name: SSE event name, if any.
        data_str: Raw event data string.
    """
    if not data_str:
        return
    if data_str == "[DONE]":
        ctx.has_completion = True
        return

    try:
        payload = json.loads(data_str)
    except json.JSONDecodeError:
        return

    # Counted even for non-dict payloads, mirroring the data accounting.
    ctx.data_count += 1
    if not isinstance(payload, dict):
        return

    # Keep the raw chunk for later persistence.
    ctx.parsed_chunks.append(payload)

    # Pick the parser matching the provider's response format.
    provider_parser = self.get_parser_for_provider(ctx)

    # Token usage, when this chunk carries it.
    if usage := provider_parser.extract_usage_from_response(payload):
        ctx.update_usage(
            input_tokens=usage.get("input_tokens"),
            output_tokens=usage.get("output_tokens"),
            cached_tokens=usage.get("cache_read_tokens"),
            cache_creation_tokens=usage.get("cache_creation_tokens"),
        )

    # Incremental text content.
    if text := provider_parser.extract_text_content(payload):
        ctx.collected_text += text

    # Completion markers differ per API format.
    if (event_name or payload.get("type", "")) in ("response.completed", "message_stop"):
        ctx.has_completion = True
async def prefetch_and_check_error(
    self,
    line_iterator: Any,
    provider: Provider,
    endpoint: ProviderEndpoint,
    ctx: StreamContext,
    max_prefetch_lines: int = 5,
) -> list:
    """
    Prefetch the first few lines of the stream to detect embedded errors.

    Some providers (e.g. Gemini) may return HTTP 200 while the response
    body actually carries an error. That case must be detected before any
    data is forwarded to the client so that retry logic can kick in.

    Args:
        line_iterator: Line iterator over the response stream.
        provider: Provider object.
        endpoint: Endpoint object.
        ctx: Streaming context.
        max_prefetch_lines: Maximum number of lines to prefetch.

    Returns:
        The list of prefetched lines (they must be re-emitted first by
        the subsequent stream).

    Raises:
        EmbeddedErrorException: If an embedded error is detected.
    """
    prefetched_lines: list = []
    parser = self.get_parser_for_provider(ctx)
    try:
        line_count = 0
        async for line in line_iterator:
            prefetched_lines.append(line)
            line_count += 1
            normalized_line = line.rstrip("\r")
            # Blank or SSE comment line: keep prefetching up to the cap.
            if not normalized_line or normalized_line.startswith(":"):
                if line_count >= max_prefetch_lines:
                    break
                continue
            # Extract the SSE data payload, if the line has a data prefix.
            data_str = normalized_line
            if normalized_line.startswith("data: "):
                data_str = normalized_line[6:]
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Not valid JSON (possibly partial data): keep prefetching.
                if line_count >= max_prefetch_lines:
                    break
                continue
            # Ask the parser whether this payload is an error response.
            if isinstance(data, dict) and parser.is_error_response(data):
                parsed = parser.parse_response(data, 200)
                logger.warning(
                    f" [{self.request_id}] 检测到嵌套错误: "
                    f"Provider={provider.name}, "
                    f"error_type={parsed.error_type}, "
                    f"message={parsed.error_message}"
                )
                raise EmbeddedErrorException(
                    provider_name=str(provider.name),
                    error_code=(
                        int(parsed.error_type)
                        if parsed.error_type and parsed.error_type.isdigit()
                        else None
                    ),
                    error_message=parsed.error_message,
                    error_status=parsed.error_type,
                )
            # Valid data with no error: stop prefetching.
            break
    except EmbeddedErrorException:
        # Re-raise so the caller can trigger a retry.
        raise
    except Exception as e:
        # Other errors during prefetch (e.g. network) are logged but not
        # propagated here — the prefetch is best-effort.
        logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
    return prefetched_lines
async def create_response_stream(
self,
ctx: StreamContext,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
统一的流生成器,支持带预读数据和不带预读数据两种情况。
Args:
ctx: 流式上下文
line_iterator: 行迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_lines: 预读的行列表(可选)
Yields:
编码后的响应数据块
"""
try:
sse_parser = SSEEventParser()
streaming_started = False
# 处理预读数据
if prefetched_lines:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for line in prefetched_lines:
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 处理剩余的流数据
async for line in line_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 处理剩余事件
for event in sse_parser.flush():
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
await self._cleanup(response_ctx, http_client)
def _process_line(
self,
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> list[bytes]:
"""
处理单行数据
Args:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
Returns:
要发送的数据块列表
"""
result: list[bytes] = []
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
async def create_monitored_stream(
self,
ctx: StreamContext,
stream_generator: AsyncGenerator[bytes, None],
is_disconnected: Callable[[], Any],
) -> AsyncGenerator[bytes, None]:
"""
创建带监控的流生成器
检测客户端断开连接并更新状态码。
Args:
ctx: 流式上下文
stream_generator: 原始流生成器
is_disconnected: 检查客户端是否断开的函数
Yields:
响应数据块
"""
try:
async for chunk in stream_generator:
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def _cleanup(
self,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> None:
"""清理资源"""
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass

View File

@@ -0,0 +1,293 @@
"""
流式遥测记录器 - 从 ChatHandlerBase 提取的统计记录逻辑
职责:
1. 记录流式请求的成功/失败统计
2. 更新 Usage 状态
3. 更新候选记录状态
"""
import asyncio
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import MessageTelemetry
from src.api.handlers.base.stream_context import StreamContext
from src.config.settings import config
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
class StreamTelemetryRecorder:
    """
    Stream telemetry recorder.

    Records statistics after a streaming request finishes: success/failure
    usage records, candidate-record status, and (as a fallback) direct
    Usage-row status updates. Extracted from the ``_record_stream_stats``
    logic of ChatHandlerBase.
    """

    def __init__(
        self,
        request_id: str,
        user_id: str,
        api_key_id: str,
        client_ip: str,
        format_id: str,
    ):
        """
        Initialize the telemetry recorder.

        Args:
            request_id: Request ID.
            user_id: User ID.
            api_key_id: API key ID.
            client_ip: Client IP address.
            format_id: API format identifier (used in log messages).
        """
        self.request_id = request_id
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.client_ip = client_ip
        self.format_id = format_id

    async def record_stream_stats(
        self,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """
        Record statistics for a finished stream.

        Waits briefly (``config.stream_stats_delay``) so the stream is fully
        closed, then opens a fresh DB session and writes either a success or
        a failure record plus the candidate status. On any internal failure
        it falls back to updating the Usage row's status directly.

        Args:
            ctx: Streaming context.
            original_headers: Original client request headers.
            original_request_body: Original client request body.
            response_time_ms: Response time in milliseconds.
        """
        bg_db = None
        try:
            await asyncio.sleep(config.stream_stats_delay)  # wait for the stream to fully close
            if not ctx.provider_name:
                # Without a provider name a full record cannot be written;
                # mark the Usage row as failed instead.
                await self._update_usage_status_on_error(
                    response_time_ms=response_time_ms,
                    error_message="Provider name not available",
                )
                return
            db_gen = get_db()
            bg_db = next(db_gen)
            try:
                user = bg_db.query(User).filter(User.id == self.user_id).first()
                api_key_obj = bg_db.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
                if not user or not api_key_obj:
                    logger.warning(
                        f"[{self.request_id}] User or ApiKey not found, updating status directly"
                    )
                    await self._update_usage_status_directly(
                        bg_db,
                        status="completed" if ctx.is_success() else "failed",
                        response_time_ms=response_time_ms,
                        status_code=ctx.status_code,
                    )
                    return
                bg_telemetry = MessageTelemetry(
                    bg_db, user, api_key_obj, self.request_id, self.client_ip
                )
                # Prefer the body actually sent to the provider, if captured.
                actual_request_body = ctx.provider_request_body or original_request_body
                response_body = ctx.build_response_body(response_time_ms)
                if ctx.is_success():
                    await self._record_success(
                        bg_telemetry,
                        ctx,
                        original_headers,
                        actual_request_body,
                        response_body,
                        response_time_ms,
                    )
                else:
                    await self._record_failure(
                        bg_telemetry,
                        ctx,
                        original_headers,
                        actual_request_body,
                        response_body,
                        response_time_ms,
                    )
                # Update the candidate record status
                await self._update_candidate_status(bg_db, ctx, response_time_ms)
            finally:
                if bg_db:
                    bg_db.close()
        except Exception as e:
            logger.exception("记录流式统计信息时出错")
            await self._update_usage_status_on_error(
                response_time_ms=response_time_ms,
                error_message=f"记录统计信息失败: {str(e)[:200]}",
            )

    async def _record_success(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Write a success usage record for the finished stream."""
        await telemetry.record_success(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            response_time_ms=response_time_ms,
            status_code=ctx.status_code,
            request_headers=original_headers,
            request_body=actual_request_body,
            response_headers=ctx.response_headers,
            response_body=response_body,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            is_stream=True,
            provider_request_headers=ctx.provider_request_headers,
            api_format=ctx.api_format,
            provider_id=ctx.provider_id,
            provider_endpoint_id=ctx.endpoint_id,
            provider_api_key_id=ctx.key_id,
            target_model=ctx.mapped_model,
        )
        logger.debug(f"{self.format_id} 流式响应完成")
        logger.info(ctx.get_log_summary(self.request_id, response_time_ms))

    async def _record_failure(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Write a failure usage record for the interrupted stream."""
        await telemetry.record_failure(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            response_time_ms=response_time_ms,
            status_code=ctx.status_code,
            error_message=ctx.error_message or f"HTTP {ctx.status_code}",
            request_headers=original_headers,
            request_body=actual_request_body,
            is_stream=True,
            api_format=ctx.api_format,
            provider_request_headers=ctx.provider_request_headers,
            input_tokens=ctx.input_tokens,
            output_tokens=ctx.output_tokens,
            cache_creation_tokens=ctx.cache_creation_tokens,
            cache_read_tokens=ctx.cached_tokens,
            response_body=response_body,
            target_model=ctx.mapped_model,
        )
        logger.debug(f"{self.format_id} 流式响应中断")
        log_summary = ctx.get_log_summary(self.request_id, response_time_ms)
        # For failure logs, append cache info
        logger.info(f"{log_summary} cache:{ctx.cached_tokens}")

    async def _update_candidate_status(
        self,
        db: Session,
        ctx: StreamContext,
        response_time_ms: int,
    ) -> None:
        """Mark the request candidate record as succeeded or failed.

        No-op when the context carries no attempt_id.
        """
        if not ctx.attempt_id:
            return
        from src.services.request.candidate import RequestCandidateService

        if ctx.is_success():
            RequestCandidateService.mark_candidate_success(
                db=db,
                candidate_id=ctx.attempt_id,
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": True,
                    "data_count": ctx.data_count,
                },
            )
        else:
            # 499 means the client went away; everything else is a stream error.
            error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
            RequestCandidateService.mark_candidate_failed(
                db=db,
                candidate_id=ctx.attempt_id,
                error_type=error_type,
                error_message=ctx.error_message or f"HTTP {ctx.status_code}",
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": False,
                    "data_count": ctx.data_count,
                },
            )

    async def _update_usage_status_on_error(
        self,
        response_time_ms: int,
        error_message: str,
    ) -> None:
        """Fallback: mark the Usage row failed when full recording is impossible.

        Opens its own short-lived DB session; errors here are only logged.
        """
        try:
            db_gen = get_db()
            error_db = next(db_gen)
            try:
                await self._update_usage_status_directly(
                    error_db,
                    status="failed",
                    response_time_ms=response_time_ms,
                    status_code=500,
                    error_message=error_message,
                )
            finally:
                error_db.close()
        except Exception as inner_e:
            logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")

    async def _update_usage_status_directly(
        self,
        db: Session,
        status: str,
        response_time_ms: int,
        status_code: int = 200,
        error_message: Optional[str] = None,
    ) -> None:
        """Directly update the status fields of the Usage row for this request.

        Args:
            db: Active DB session (caller owns its lifecycle).
            status: New status value (e.g. "completed" / "failed").
            response_time_ms: Response time in milliseconds.
            status_code: HTTP status code to store.
            error_message: Optional error message; only written when truthy.
        """
        try:
            from src.models.database import Usage

            usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
            if usage:
                # setattr avoids static-typing complaints on ORM columns.
                setattr(usage, "status", status)
                setattr(usage, "status_code", status_code)
                setattr(usage, "response_time_ms", response_time_ms)
                if error_message:
                    setattr(usage, "error_message", error_message)
                db.commit()
                logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
        except Exception as e:
            logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")

View File

@@ -120,6 +120,23 @@ class Config:
self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600")) self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70")) self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70"))
# 并发控制配置
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL防止死锁
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
# HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
# 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
# 验证连接池配置 # 验证连接池配置
self._validate_pool_config() self._validate_pool_config()

View File

@@ -13,7 +13,7 @@ import asyncio
import math import math
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple from typing import Optional, Tuple
import redis.asyncio as aioredis import redis.asyncio as aioredis
@@ -185,8 +185,8 @@ class ConcurrencyManager:
key_id: str, key_id: str,
key_max_concurrent: Optional[int], key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户 is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例 cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
ttl_seconds: int = 600, # 10分钟 TTL防止死锁 ttl_seconds: Optional[int] = None, # TTL 秒数None 时从配置读取
) -> bool: ) -> bool:
""" """
尝试获取并发槽位(支持缓存用户优先级) 尝试获取并发槽位(支持缓存用户优先级)
@@ -197,8 +197,8 @@ class ConcurrencyManager:
key_id: ProviderAPIKey ID key_id: ProviderAPIKey ID
key_max_concurrent: Key 最大并发数None 表示不限制) key_max_concurrent: Key 最大并发数None 表示不限制)
is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位) is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位)
cache_reservation_ratio: 缓存预留比例默认30%,只对新用户生效) cache_reservation_ratio: 缓存预留比例None 时从配置读取
ttl_seconds: TTL 秒数,防止异常情况下的死锁 ttl_seconds: TTL 秒数,None 时从配置读取
Returns: Returns:
是否成功获取True/False 是否成功获取True/False
@@ -209,6 +209,14 @@ class ConcurrencyManager:
- 缓存用户最多使用: 10个槽位全部 - 缓存用户最多使用: 10个槽位全部
- 预留的3个槽位专门给缓存用户保证他们的请求优先 - 预留的3个槽位专门给缓存用户保证他们的请求优先
""" """
# 从配置读取默认值
from src.config.settings import config
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
if ttl_seconds is None:
ttl_seconds = config.concurrency_slot_ttl
if self._redis is None: if self._redis is None:
async with self._memory_lock: async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0) endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
@@ -426,7 +434,7 @@ class ConcurrencyManager:
key_id: str, key_id: str,
key_max_concurrent: Optional[int], key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户 is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例 cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
): ):
""" """
并发控制上下文管理器(支持缓存用户优先级) 并发控制上下文管理器(支持缓存用户优先级)
@@ -441,6 +449,12 @@ class ConcurrencyManager:
如果获取失败,会抛出 ConcurrencyLimitError 异常 如果获取失败,会抛出 ConcurrencyLimitError 异常
""" """
# 从配置读取默认值
from src.config.settings import config
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
# 尝试获取槽位(传递缓存用户参数) # 尝试获取槽位(传递缓存用户参数)
acquired = await self.acquire_slot( acquired = await self.acquire_slot(
endpoint_id, endpoint_id,