refactor: 重构流式处理模块,提取 StreamContext/Processor/Telemetry

- 将 chat_handler_base.py 中的流式处理逻辑拆分为三个独立模块:
  - StreamContext: 类型安全的流式上下文数据类,替代原有的 ctx dict
  - StreamProcessor: SSE 解析、预读、嵌套错误检测
  - StreamTelemetryRecorder: 统计记录(Usage/Audit/Candidate)
- 将硬编码配置外置到 settings.py,支持环境变量覆盖:
  - HTTP 超时配置(connect/write/pool)
  - 流式处理配置(预读行数、统计延迟)
  - 并发控制配置(槽位 TTL、缓存预留比例)
This commit is contained in:
fawney19
2025-12-12 15:42:45 +08:00
parent 39defce71c
commit 53bf74429e
6 changed files with 922 additions and 600 deletions

View File

@@ -12,10 +12,13 @@ Chat Handler Base - Chat API 格式的通用基类
- apply_mapped_model(): 模型映射
- get_model_for_url(): URL 模型名
- _extract_usage(): 使用量提取
重构说明:
- StreamContext: 类型安全的流式上下文,替代原有的 ctx dict
- StreamProcessor: 流式响应处理SSE 解析、预读、错误检测)
- StreamTelemetryRecorder: 统计记录Usage、Audit、Candidate
"""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, Optional
@@ -24,13 +27,14 @@ from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import (
BaseMessageHandler,
MessageTelemetry,
)
from src.api.handlers.base.base_handler import BaseMessageHandler
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.config.settings import config
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -39,7 +43,6 @@ from src.core.exceptions import (
ProviderTimeoutException,
)
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
@@ -48,7 +51,6 @@ from src.models.database import (
User,
)
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@@ -285,30 +287,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
api_format = self.allowed_api_formats[0]
# 用于跟踪的上下文
ctx = {
"model": model,
"api_format": api_format,
"provider_name": None,
"provider_id": None,
"endpoint_id": None,
"key_id": None,
"attempt_id": None,
"provider_api_format": None, # Provider 的响应格式(用于选择解析器)
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"collected_text": "",
"status_code": 200,
"response_headers": {},
"provider_request_headers": {},
"provider_request_body": None,
"data_count": 0,
"chunk_count": 0,
"has_completion": False,
"parsed_chunks": [], # 收集解析后的 chunks
}
# 创建类型安全的流式上下文
ctx = StreamContext(model=model, api_format=api_format)
# 创建流处理器
stream_processor = StreamProcessor(
request_id=self.request_id,
default_parser=self.parser,
on_streaming_start=self._update_usage_to_streaming,
)
# 定义请求函数
async def stream_request_func(
@@ -318,6 +305,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
stream_processor,
provider,
endpoint,
key,
@@ -350,23 +338,39 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
capability_requirements=capability_requirements or None,
)
ctx["attempt_id"] = attempt_id
ctx["provider_name"] = provider_name
ctx["provider_id"] = provider_id
ctx["endpoint_id"] = endpoint_id
ctx["key_id"] = key_id
# 更新上下文
ctx.attempt_id = attempt_id
ctx.provider_name = provider_name
ctx.provider_id = provider_id
ctx.endpoint_id = endpoint_id
ctx.key_id = key_id
# 创建遥测记录器
telemetry_recorder = StreamTelemetryRecorder(
request_id=self.request_id,
user_id=str(self.user.id),
api_key_id=str(self.api_key.id),
client_ip=self.client_ip,
format_id=self.FORMAT_ID,
)
# 创建后台任务记录统计
background_tasks = BackgroundTasks()
background_tasks.add_task(
self._record_stream_stats,
telemetry_recorder.record_stream_stats,
ctx,
original_headers,
original_request_body,
self.elapsed_ms(),
)
# 创建监控流
monitored_stream = self._create_monitored_stream(ctx, stream_generator, http_request)
monitored_stream = stream_processor.create_monitored_stream(
ctx,
stream_generator,
http_request.is_disconnected,
)
return StreamingResponse(
monitored_stream,
@@ -381,7 +385,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
async def _execute_stream_request(
self,
ctx: Dict,
ctx: StreamContext,
stream_processor: StreamProcessor,
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
@@ -390,37 +395,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
query_params: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积
ctx["parsed_chunks"] = []
ctx["chunk_count"] = 0
ctx["data_count"] = 0
ctx["has_completion"] = False
ctx["collected_text"] = ""
ctx["input_tokens"] = 0
ctx["output_tokens"] = 0
ctx["cached_tokens"] = 0
ctx["cache_creation_tokens"] = 0
# 重置上下文状态(重试时清除之前的数据)
ctx.reset_for_retry()
ctx["provider_name"] = str(provider.name)
ctx["provider_id"] = str(provider.id)
ctx["endpoint_id"] = str(endpoint.id)
ctx["key_id"] = str(key.id)
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
# 更新 Provider 信息
ctx.update_provider_info(
provider_name=str(provider.name),
provider_id=str(provider.id),
endpoint_id=str(endpoint.id),
key_id=str(key.id),
provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
)
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=ctx["model"],
source_model=ctx.model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体
if mapped_model:
ctx["mapped_model"] = mapped_model # 保存映射后的模型名,用于 Usage 记录
ctx.mapped_model = mapped_model
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = dict(original_request_body)
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
# 准备发送给 Provider 的请求体
request_body = self.prepare_provider_request_body(request_body)
# 构建请求
@@ -432,11 +432,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
)
ctx["provider_request_headers"] = provider_headers
ctx["provider_request_body"] = provider_payload
ctx.provider_request_headers = provider_headers
ctx.provider_request_body = provider_payload
# 获取 URL 模型名(兜底使用 ctx 中的 model确保 Gemini 等格式能正确构建 URL
url_model = self.get_model_for_url(request_body, mapped_model) or ctx["model"]
# 获取 URL 模型名
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
url = build_provider_url(
endpoint,
@@ -445,15 +445,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
)
logger.debug(f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
f"模型={ctx['model']} -> {mapped_model or '无映射'}")
logger.debug(
f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
f"模型={ctx.model} -> {mapped_model or '无映射'}"
)
# 发送请求
# 发送请求(使用配置中的超时设置)
timeout_config = httpx.Timeout(
connect=10.0,
connect=config.http_connect_timeout,
read=float(endpoint.timeout),
write=60.0, # 写入超时增加到60秒支持大请求体如包含图片的长对话
pool=10.0,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
@@ -463,17 +465,21 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
stream_response = await response_ctx.__aenter__()
ctx["status_code"] = stream_response.status_code
ctx["response_headers"] = dict(stream_response.headers)
ctx.status_code = stream_response.status_code
ctx.response_headers = dict(stream_response.headers)
stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用)
# 创建行迭代器
line_iterator = stream_response.aiter_lines()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
# 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator,
provider,
endpoint,
ctx,
max_prefetch_lines=config.stream_prefetch_lines,
)
except httpx.HTTPStatusError as e:
@@ -483,7 +489,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
raise
except EmbeddedErrorException:
# 嵌套错误需要触发重试,关闭连接后重新抛出
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
@@ -495,8 +500,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
# 创建流生成器
return stream_processor.create_response_stream(
ctx,
line_iterator,
response_ctx,
@@ -504,518 +509,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
prefetched_lines,
)
async def _create_response_stream(
self,
ctx: Dict,
stream_response: httpx.Response,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
try:
sse_parser = SSEEventParser()
streaming_status_updated = False
async for line in stream_response.aiter_lines():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming()
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: Dict,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器)
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 上下文字典
Returns:
预读的行列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
try:
# 获取对应格式的解析器
provider_parser = self._get_provider_parser(ctx)
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
# 解析数据
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
raise
except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
async def _create_response_stream_with_prefetch(
self,
ctx: Dict,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
try:
sse_parser = SSEEventParser()
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming()
# 先输出预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 继续输出剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
def _get_provider_parser(self, ctx: Dict) -> ResponseParser:
"""
获取 Provider 格式的解析器
根据 Provider 的 API 格式选择正确的解析器,
而不是根据请求格式选择。
"""
provider_format = ctx.get("provider_api_format")
if provider_format:
try:
return get_parser_for_format(provider_format)
except KeyError:
pass
# 回退到默认解析器
return self.parser
def _handle_sse_event(
self,
ctx: Dict,
event_name: Optional[str],
data_str: str,
) -> None:
"""处理 SSE 事件"""
if not data_str:
return
if data_str == "[DONE]":
ctx["has_completion"] = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx["data_count"] += 1
if not isinstance(data, dict):
return
# 收集原始 chunk 数据
ctx["parsed_chunks"].append(data)
# 根据 Provider 格式选择解析器
provider_parser = self._get_provider_parser(ctx)
# 使用解析器提取 usage
usage = provider_parser.extract_usage_from_response(data)
if usage:
ctx["input_tokens"] = usage.get("input_tokens", ctx["input_tokens"])
ctx["output_tokens"] = usage.get("output_tokens", ctx["output_tokens"])
ctx["cached_tokens"] = usage.get("cache_read_tokens", ctx["cached_tokens"])
ctx["cache_creation_tokens"] = usage.get(
"cache_creation_tokens", ctx["cache_creation_tokens"]
)
# 提取文本
text = provider_parser.extract_text_content(data)
if text:
ctx["collected_text"] += text
# 检查完成
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx["has_completion"] = True
async def _create_monitored_stream(
self,
ctx: Dict,
stream_generator: AsyncGenerator[bytes, None],
http_request: Request,
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
try:
async for chunk in stream_generator:
if await http_request.is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
# 客户端断开时设置 499 状态码Client Closed Request
# 注意Provider 可能已经成功返回数据,但客户端未完整接收
ctx["status_code"] = 499
break
yield chunk
except asyncio.CancelledError:
ctx["status_code"] = 499
raise
except Exception:
ctx["status_code"] = 500
raise
async def _record_stream_stats(
self,
ctx: Dict,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""记录流式统计信息"""
response_time_ms = self.elapsed_ms()
bg_db = None
try:
await asyncio.sleep(0.1)
if not ctx["provider_name"]:
# 即使没有 provider_name也要尝试更新状态为 failed
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message="Provider name not available",
)
return
db_gen = get_db()
bg_db = next(db_gen)
try:
from src.models.database import ApiKey as ApiKeyModel
user = bg_db.query(User).filter(User.id == self.user.id).first()
api_key_obj = (
bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == self.api_key.id).first()
)
if not user or not api_key_obj:
logger.warning(f"[{self.request_id}] User or ApiKey not found, updating status directly")
await self._update_usage_status_directly(
bg_db,
status="completed" if ctx["status_code"] == 200 else "failed",
response_time_ms=response_time_ms,
status_code=ctx["status_code"],
)
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key_obj, self.request_id, self.client_ip
)
actual_request_body = ctx["provider_request_body"] or original_request_body
# 构建响应体(与 CLI 模式一致)
response_body = {
"chunks": ctx["parsed_chunks"],
"metadata": {
"stream": True,
"total_chunks": len(ctx["parsed_chunks"]),
"data_count": ctx["data_count"],
"has_completion": ctx["has_completion"],
"response_time_ms": response_time_ms,
},
}
# 根据状态码决定记录成功还是失败
# 499 = 客户端断开连接503 = 服务不可用(如流中断)
status_code: int = ctx.get("status_code") or 200
if status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
await bg_telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown",
model=ctx["model"],
response_time_ms=response_time_ms,
status_code=status_code,
error_message=ctx.get("error_message") or f"HTTP {status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx["api_format"],
provider_request_headers=ctx["provider_request_headers"],
# 预估 token 信息(来自 message_start 事件)
input_tokens=ctx.get("input_tokens", 0),
output_tokens=ctx.get("output_tokens", 0),
cache_creation_tokens=ctx.get("cache_creation_tokens", 0),
cache_read_tokens=ctx.get("cached_tokens", 0),
response_body=response_body,
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# 简洁的请求失败摘要(包含预估 token 信息)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"{status_code} | in:{ctx.get('input_tokens', 0)} out:{ctx.get('output_tokens', 0)} cache:{ctx.get('cached_tokens', 0)}")
else:
await bg_telemetry.record_success(
provider=ctx["provider_name"],
model=ctx["model"],
input_tokens=ctx["input_tokens"],
output_tokens=ctx["output_tokens"],
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx["response_headers"],
response_body=response_body,
cache_creation_tokens=ctx["cache_creation_tokens"],
cache_read_tokens=ctx["cached_tokens"],
is_stream=True,
provider_request_headers=ctx["provider_request_headers"],
api_format=ctx["api_format"],
provider_id=ctx["provider_id"],
provider_endpoint_id=ctx["endpoint_id"],
provider_api_key_id=ctx["key_id"],
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"in:{ctx.get('input_tokens', 0) or 0} out:{ctx.get('output_tokens', 0) or 0}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
# 这里用流传输完成后的实际时间覆盖
if ctx.get("attempt_id"):
from src.services.request.candidate import RequestCandidateService
# 根据状态码决定是成功还是失败(复用上面已定义的 status_code
# 499 = 客户端断开连接,应标记为失败
# 503 = 服务不可用(如流中断),应标记为失败
if status_code and status_code >= 400:
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx["attempt_id"],
error_type="client_disconnected" if status_code == 499 else "stream_error",
error_message=ctx.get("error_message") or f"HTTP {status_code}",
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"data_count": ctx.get("data_count", 0),
},
)
else:
RequestCandidateService.mark_candidate_success(
db=bg_db,
candidate_id=ctx["attempt_id"],
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"data_count": ctx.get("data_count", 0),
},
)
finally:
if bg_db:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
# 确保即使出错也要更新状态,避免 pending 状态卡住
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message=f"记录统计信息失败: {str(e)[:200]}",
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def _update_usage_status_on_error(
self,
response_time_ms: int,
error_message: str,
) -> None:
"""在记录失败时更新 Usage 状态,避免卡在 pending"""
try:
db_gen = get_db()
error_db = next(db_gen)
try:
await self._update_usage_status_directly(
error_db,
status="failed",
response_time_ms=response_time_ms,
status_code=500,
error_message=error_message,
)
finally:
error_db.close()
except Exception as inner_e:
logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")
async def _update_usage_status_directly(
self,
db: Session,
status: str,
response_time_ms: int,
status_code: int = 200,
error_message: Optional[str] = None,
) -> None:
"""直接更新 Usage 表的状态字段"""
try:
from src.models.database import Usage
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
if usage:
setattr(usage, "status", status)
setattr(usage, "status_code", status_code)
setattr(usage, "response_time_ms", response_time_ms)
if error_message:
setattr(usage, "error_message", error_message)
db.commit()
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
except Exception as e:
logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
async def _record_stream_failure(
self,
ctx: Dict,
ctx: StreamContext,
error: Exception,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
@@ -1031,21 +527,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
elif isinstance(error, ProviderTimeoutException):
status_code = 504
actual_request_body = ctx.get("provider_request_body") or original_request_body
actual_request_body = ctx.provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown",
model=ctx["model"],
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx["api_format"],
provider_request_headers=ctx.get("provider_request_headers") or {},
# 模型映射信息
target_model=ctx.get("mapped_model"),
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
target_model=ctx.mapped_model,
)
# ==================== 非流式处理 ====================

View File

@@ -0,0 +1,154 @@
"""
流式处理上下文 - 类型安全的数据类替代 dict
提供流式请求处理过程中的状态跟踪,包括:
- Provider/Endpoint/Key 信息
- Token 统计
- 响应状态
- 请求/响应数据
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class StreamContext:
"""
流式处理上下文
用于在流式请求处理过程中跟踪状态,替代原有的 ctx dict。
所有字段都有类型注解,提供更好的 IDE 支持和运行时类型安全。
"""
# 请求基本信息
model: str
api_format: str
# Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射
mapped_model: Optional[str] = None
# Token 统计
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
# 响应内容
collected_text: str = ""
# 响应状态
status_code: int = 200
error_message: Optional[str] = None
has_completion: bool = False
# 请求/响应数据
response_headers: Dict[str, str] = field(default_factory=dict)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
# 流式处理统计
data_count: int = 0
chunk_count: int = 0
parsed_chunks: List[Dict[str, Any]] = field(default_factory=list)
def reset_for_retry(self) -> None:
"""
重试时重置状态
在故障转移重试时调用,清除之前的数据避免累积。
保留 model 和 api_format重置其他所有状态。
"""
self.parsed_chunks = []
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self.collected_text = ""
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
def update_provider_info(
self,
provider_name: str,
provider_id: str,
endpoint_id: str,
key_id: str,
provider_api_format: Optional[str] = None,
) -> None:
"""更新 Provider 信息"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.key_id = key_id
self.provider_api_format = provider_api_format
def update_usage(
self,
input_tokens: Optional[int] = None,
output_tokens: Optional[int] = None,
cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None,
) -> None:
"""更新 Token 使用统计"""
if input_tokens is not None:
self.input_tokens = input_tokens
if output_tokens is not None:
self.output_tokens = output_tokens
if cached_tokens is not None:
self.cached_tokens = cached_tokens
if cache_creation_tokens is not None:
self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None:
"""标记请求失败"""
self.status_code = status_code
self.error_message = error_message
def is_success(self) -> bool:
"""检查请求是否成功"""
return self.status_code < 400
def build_response_body(self, response_time_ms: int) -> Dict[str, Any]:
"""
构建响应体元数据
用于记录到 Usage 表的 response_body 字段。
"""
return {
"chunks": self.parsed_chunks,
"metadata": {
"stream": True,
"total_chunks": len(self.parsed_chunks),
"data_count": self.data_count,
"has_completion": self.has_completion,
"response_time_ms": response_time_ms,
},
}
def get_log_summary(self, request_id: str, response_time_ms: int) -> str:
"""
获取日志摘要
用于请求完成/失败时的日志输出。
"""
status = "OK" if self.is_success() else "FAIL"
return (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)

View File

@@ -0,0 +1,349 @@
"""
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
职责:
1. SSE 事件解析和处理
2. 响应流生成
3. 预读和嵌套错误检测
4. 客户端断开检测
"""
import asyncio
import json
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
class StreamProcessor:
"""
流式响应处理器
负责处理 SSE 流的解析、错误检测和响应生成。
从 ChatHandlerBase 中提取,使其职责更加单一。
"""
def __init__(
self,
request_id: str,
default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None,
):
"""
初始化流处理器
Args:
request_id: 请求 ID用于日志
default_parser: 默认响应解析器
on_streaming_start: 流开始时的回调(用于更新状态)
"""
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
获取 Provider 格式的解析器
根据 Provider 的 API 格式选择正确的解析器。
"""
if ctx.provider_api_format:
try:
return get_parser_for_format(ctx.provider_api_format)
except KeyError:
pass
return self.default_parser
def handle_sse_event(
self,
ctx: StreamContext,
event_name: Optional[str],
data_str: str,
) -> None:
"""
处理单个 SSE 事件
解析事件数据,提取 usage 信息和文本内容。
Args:
ctx: 流式上下文
event_name: 事件名称
data_str: 事件数据字符串
"""
if not data_str:
return
if data_str == "[DONE]":
ctx.has_completion = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx.data_count += 1
if not isinstance(data, dict):
return
# 收集原始 chunk 数据
ctx.parsed_chunks.append(data)
# 根据 Provider 格式选择解析器
parser = self.get_parser_for_provider(ctx)
# 使用解析器提取 usage
usage = parser.extract_usage_from_response(data)
if usage:
ctx.update_usage(
input_tokens=usage.get("input_tokens"),
output_tokens=usage.get("output_tokens"),
cached_tokens=usage.get("cache_read_tokens"),
cache_creation_tokens=usage.get("cache_creation_tokens"),
)
# 提取文本
text = parser.extract_text_content(data)
if text:
ctx.collected_text += text
# 检查完成
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
async def prefetch_and_check_error(
self,
line_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
max_prefetch_lines: int = 5,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 行迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
Returns:
预读的行列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
parser = self.get_parser_for_provider(ctx)
try:
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException:
raise
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
async def create_response_stream(
self,
ctx: StreamContext,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: Optional[list] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
统一的流生成器,支持带预读数据和不带预读数据两种情况。
Args:
ctx: 流式上下文
line_iterator: 行迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_lines: 预读的行列表(可选)
Yields:
编码后的响应数据块
"""
try:
sse_parser = SSEEventParser()
streaming_started = False
# 处理预读数据
if prefetched_lines:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for line in prefetched_lines:
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 处理剩余的流数据
async for line in line_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in self._process_line(ctx, sse_parser, line):
yield chunk
# 处理剩余事件
for event in sse_parser.flush():
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
await self._cleanup(response_ctx, http_client)
def _process_line(
self,
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> list[bytes]:
"""
处理单行数据
Args:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
Returns:
要发送的数据块列表
"""
result: list[bytes] = []
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
result.append(b"\n")
else:
ctx.chunk_count += 1
result.append((line + "\n").encode("utf-8"))
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
return result
async def create_monitored_stream(
self,
ctx: StreamContext,
stream_generator: AsyncGenerator[bytes, None],
is_disconnected: Callable[[], Any],
) -> AsyncGenerator[bytes, None]:
"""
创建带监控的流生成器
检测客户端断开连接并更新状态码。
Args:
ctx: 流式上下文
stream_generator: 原始流生成器
is_disconnected: 检查客户端是否断开的函数
Yields:
响应数据块
"""
try:
async for chunk in stream_generator:
if await is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def _cleanup(
self,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> None:
"""清理资源"""
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass

View File

@@ -0,0 +1,293 @@
"""
流式遥测记录器 - 从 ChatHandlerBase 提取的统计记录逻辑
职责:
1. 记录流式请求的成功/失败统计
2. 更新 Usage 状态
3. 更新候选记录状态
"""
import asyncio
from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import MessageTelemetry
from src.api.handlers.base.stream_context import StreamContext
from src.config.settings import config
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
class StreamTelemetryRecorder:
"""
流式遥测记录器
负责在流式请求完成后记录统计信息。
从 ChatHandlerBase 中提取的 _record_stream_stats 逻辑。
"""
def __init__(
self,
request_id: str,
user_id: str,
api_key_id: str,
client_ip: str,
format_id: str,
):
"""
初始化遥测记录器
Args:
request_id: 请求 ID
user_id: 用户 ID
api_key_id: API Key ID
client_ip: 客户端 IP
format_id: API 格式标识
"""
self.request_id = request_id
self.user_id = user_id
self.api_key_id = api_key_id
self.client_ip = client_ip
self.format_id = format_id
async def record_stream_stats(
self,
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
response_time_ms: int,
) -> None:
"""
记录流式统计信息
Args:
ctx: 流式上下文
original_headers: 原始请求头
original_request_body: 原始请求体
response_time_ms: 响应时间(毫秒)
"""
bg_db = None
try:
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
if not ctx.provider_name:
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message="Provider name not available",
)
return
db_gen = get_db()
bg_db = next(db_gen)
try:
user = bg_db.query(User).filter(User.id == self.user_id).first()
api_key_obj = bg_db.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
if not user or not api_key_obj:
logger.warning(
f"[{self.request_id}] User or ApiKey not found, updating status directly"
)
await self._update_usage_status_directly(
bg_db,
status="completed" if ctx.is_success() else "failed",
response_time_ms=response_time_ms,
status_code=ctx.status_code,
)
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key_obj, self.request_id, self.client_ip
)
actual_request_body = ctx.provider_request_body or original_request_body
response_body = ctx.build_response_body(response_time_ms)
if ctx.is_success():
await self._record_success(
bg_telemetry,
ctx,
original_headers,
actual_request_body,
response_body,
response_time_ms,
)
else:
await self._record_failure(
bg_telemetry,
ctx,
original_headers,
actual_request_body,
response_body,
response_time_ms,
)
# 更新候选记录状态
await self._update_candidate_status(bg_db, ctx, response_time_ms)
finally:
if bg_db:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message=f"记录统计信息失败: {str(e)[:200]}",
)
async def _record_success(
self,
telemetry: MessageTelemetry,
ctx: StreamContext,
original_headers: Dict[str, str],
actual_request_body: Dict[str, Any],
response_body: Dict[str, Any],
response_time_ms: int,
) -> None:
"""记录成功的请求"""
await telemetry.record_success(
provider=ctx.provider_name or "unknown",
model=ctx.model,
input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx.response_headers,
response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
is_stream=True,
provider_request_headers=ctx.provider_request_headers,
api_format=ctx.api_format,
provider_id=ctx.provider_id,
provider_endpoint_id=ctx.endpoint_id,
provider_api_key_id=ctx.key_id,
target_model=ctx.mapped_model,
)
logger.debug(f"{self.format_id} 流式响应完成")
logger.info(ctx.get_log_summary(self.request_id, response_time_ms))
async def _record_failure(
self,
telemetry: MessageTelemetry,
ctx: StreamContext,
original_headers: Dict[str, str],
actual_request_body: Dict[str, Any],
response_body: Dict[str, Any],
response_time_ms: int,
) -> None:
"""记录失败的请求"""
await telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=ctx.status_code,
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
input_tokens=ctx.input_tokens,
output_tokens=ctx.output_tokens,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
response_body=response_body,
target_model=ctx.mapped_model,
)
logger.debug(f"{self.format_id} 流式响应中断")
log_summary = ctx.get_log_summary(self.request_id, response_time_ms)
# 对于失败日志,添加缓存信息
logger.info(f"{log_summary} cache:{ctx.cached_tokens}")
async def _update_candidate_status(
self,
db: Session,
ctx: StreamContext,
response_time_ms: int,
) -> None:
"""更新候选记录状态"""
if not ctx.attempt_id:
return
from src.services.request.candidate import RequestCandidateService
if ctx.is_success():
RequestCandidateService.mark_candidate_success(
db=db,
candidate_id=ctx.attempt_id,
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"data_count": ctx.data_count,
},
)
else:
error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
RequestCandidateService.mark_candidate_failed(
db=db,
candidate_id=ctx.attempt_id,
error_type=error_type,
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"data_count": ctx.data_count,
},
)
async def _update_usage_status_on_error(
self,
response_time_ms: int,
error_message: str,
) -> None:
"""在记录失败时更新 Usage 状态"""
try:
db_gen = get_db()
error_db = next(db_gen)
try:
await self._update_usage_status_directly(
error_db,
status="failed",
response_time_ms=response_time_ms,
status_code=500,
error_message=error_message,
)
finally:
error_db.close()
except Exception as inner_e:
logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")
async def _update_usage_status_directly(
self,
db: Session,
status: str,
response_time_ms: int,
status_code: int = 200,
error_message: Optional[str] = None,
) -> None:
"""直接更新 Usage 表的状态字段"""
try:
from src.models.database import Usage
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
if usage:
setattr(usage, "status", status)
setattr(usage, "status_code", status_code)
setattr(usage, "response_time_ms", response_time_ms)
if error_message:
setattr(usage, "error_message", error_message)
db.commit()
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
except Exception as e:
logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")

View File

@@ -120,6 +120,23 @@ class Config:
self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70"))
# 并发控制配置
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL防止死锁
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
# HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
# 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
# 验证连接池配置
self._validate_pool_config()

View File

@@ -13,7 +13,7 @@ import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta
from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple
import redis.asyncio as aioredis
@@ -185,8 +185,8 @@ class ConcurrencyManager:
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例
ttl_seconds: int = 600, # 10分钟 TTL防止死锁
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
ttl_seconds: Optional[int] = None, # TTL 秒数None 时从配置读取
) -> bool:
"""
尝试获取并发槽位(支持缓存用户优先级)
@@ -197,8 +197,8 @@ class ConcurrencyManager:
key_id: ProviderAPIKey ID
key_max_concurrent: Key 最大并发数None 表示不限制)
is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位)
cache_reservation_ratio: 缓存预留比例默认30%,只对新用户生效)
ttl_seconds: TTL 秒数,防止异常情况下的死锁
cache_reservation_ratio: 缓存预留比例None 时从配置读取
ttl_seconds: TTL 秒数,None 时从配置读取
Returns:
是否成功获取True/False
@@ -209,6 +209,14 @@ class ConcurrencyManager:
- 缓存用户最多使用: 10个槽位全部
- 预留的3个槽位专门给缓存用户保证他们的请求优先
"""
# 从配置读取默认值
from src.config.settings import config
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
if ttl_seconds is None:
ttl_seconds = config.concurrency_slot_ttl
if self._redis is None:
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
@@ -426,7 +434,7 @@ class ConcurrencyManager:
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
):
"""
并发控制上下文管理器(支持缓存用户优先级)
@@ -441,6 +449,12 @@ class ConcurrencyManager:
如果获取失败,会抛出 ConcurrencyLimitError 异常
"""
# 从配置读取默认值
from src.config.settings import config
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
# 尝试获取槽位(传递缓存用户参数)
acquired = await self.acquire_slot(
endpoint_id,