diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py index 84e3d72..e599ac9 100644 --- a/src/api/handlers/base/chat_handler_base.py +++ b/src/api/handlers/base/chat_handler_base.py @@ -12,10 +12,13 @@ Chat Handler Base - Chat API 格式的通用基类 - apply_mapped_model(): 模型映射 - get_model_for_url(): URL 模型名 - _extract_usage(): 使用量提取 + +重构说明: +- StreamContext: 类型安全的流式上下文,替代原有的 ctx dict +- StreamProcessor: 流式响应处理(SSE 解析、预读、错误检测) +- StreamTelemetryRecorder: 统计记录(Usage、Audit、Candidate) """ -import asyncio -import json from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Callable, Dict, Optional @@ -24,13 +27,14 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse from sqlalchemy.orm import Session -from src.api.handlers.base.base_handler import ( - BaseMessageHandler, - MessageTelemetry, -) +from src.api.handlers.base.base_handler import BaseMessageHandler from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.request_builder import PassthroughRequestBuilder from src.api.handlers.base.response_parser import ResponseParser +from src.api.handlers.base.stream_context import StreamContext +from src.api.handlers.base.stream_processor import StreamProcessor +from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder +from src.config.settings import config from src.core.exceptions import ( EmbeddedErrorException, ProviderAuthException, @@ -39,7 +43,6 @@ from src.core.exceptions import ( ProviderTimeoutException, ) from src.core.logger import logger -from src.database import get_db from src.models.database import ( ApiKey, Provider, @@ -48,7 +51,6 @@ from src.models.database import ( User, ) from src.services.provider.transport import build_provider_url -from src.utils.sse_parser import SSEEventParser @@ -285,30 +287,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC): model = getattr(converted_request, "model", original_request_body.get("model", "unknown")) api_format = self.allowed_api_formats[0] - # 用于跟踪的上下文 - ctx = { - "model": model, - "api_format": api_format, - "provider_name": None, - "provider_id": None, - "endpoint_id": None, - "key_id": None, - "attempt_id": None, - "provider_api_format": None, # Provider 的响应格式(用于选择解析器) - "input_tokens": 0, - "output_tokens": 0, - "cached_tokens": 0, - "cache_creation_tokens": 0, - "collected_text": "", - "status_code": 200, - "response_headers": {}, - "provider_request_headers": {}, - "provider_request_body": None, - "data_count": 0, - "chunk_count": 0, - "has_completion": False, - "parsed_chunks": [], # 收集解析后的 chunks - } + # 创建类型安全的流式上下文 + ctx = StreamContext(model=model, api_format=api_format) + + # 创建流处理器 + stream_processor = StreamProcessor( + request_id=self.request_id, + default_parser=self.parser, + on_streaming_start=self._update_usage_to_streaming, + ) # 定义请求函数 async def stream_request_func( @@ -318,6 +305,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC): ) -> AsyncGenerator[bytes, None]: return await self._execute_stream_request( ctx, + stream_processor, provider, endpoint, key, @@ -350,23 +338,39 @@ class ChatHandlerBase(BaseMessageHandler, ABC): is_stream=True, capability_requirements=capability_requirements or None, ) - ctx["attempt_id"] = attempt_id - ctx["provider_name"] = provider_name - ctx["provider_id"] = provider_id - ctx["endpoint_id"] = endpoint_id - ctx["key_id"] = key_id + + # 更新上下文 + ctx.attempt_id = attempt_id + ctx.provider_name = provider_name + 
ctx.provider_id = provider_id + ctx.endpoint_id = endpoint_id + ctx.key_id = key_id + + # 创建遥测记录器 + telemetry_recorder = StreamTelemetryRecorder( + request_id=self.request_id, + user_id=str(self.user.id), + api_key_id=str(self.api_key.id), + client_ip=self.client_ip, + format_id=self.FORMAT_ID, + ) # 创建后台任务记录统计 background_tasks = BackgroundTasks() background_tasks.add_task( - self._record_stream_stats, + telemetry_recorder.record_stream_stats, ctx, original_headers, original_request_body, + self.elapsed_ms(), ) # 创建监控流 - monitored_stream = self._create_monitored_stream(ctx, stream_generator, http_request) + monitored_stream = stream_processor.create_monitored_stream( + ctx, + stream_generator, + http_request.is_disconnected, + ) return StreamingResponse( monitored_stream, @@ -381,7 +385,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC): async def _execute_stream_request( self, - ctx: Dict, + ctx: StreamContext, + stream_processor: StreamProcessor, provider: Provider, endpoint: ProviderEndpoint, key: ProviderAPIKey, @@ -390,37 +395,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC): query_params: Optional[Dict[str, str]] = None, ) -> AsyncGenerator[bytes, None]: """执行流式请求并返回流生成器""" - # 重置上下文状态(重试时清除之前的数据,避免累积) - ctx["parsed_chunks"] = [] - ctx["chunk_count"] = 0 - ctx["data_count"] = 0 - ctx["has_completion"] = False - ctx["collected_text"] = "" - ctx["input_tokens"] = 0 - ctx["output_tokens"] = 0 - ctx["cached_tokens"] = 0 - ctx["cache_creation_tokens"] = 0 + # 重置上下文状态(重试时清除之前的数据) + ctx.reset_for_retry() - ctx["provider_name"] = str(provider.name) - ctx["provider_id"] = str(provider.id) - ctx["endpoint_id"] = str(endpoint.id) - ctx["key_id"] = str(key.id) - ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else "" + # 更新 Provider 信息 + ctx.update_provider_info( + provider_name=str(provider.name), + provider_id=str(provider.id), + endpoint_id=str(endpoint.id), + key_id=str(key.id), + provider_api_format=str(endpoint.api_format) if endpoint.api_format else None, + ) # 获取模型映射 mapped_model = await self._get_mapped_model( - source_model=ctx["model"], + source_model=ctx.model, provider_id=str(provider.id), ) # 应用模型映射到请求体 if mapped_model: - ctx["mapped_model"] = mapped_model # 保存映射后的模型名,用于 Usage 记录 + ctx.mapped_model = mapped_model request_body = self.apply_mapped_model(original_request_body, mapped_model) else: request_body = dict(original_request_body) - # 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段) + # 准备发送给 Provider 的请求体 request_body = self.prepare_provider_request_body(request_body) # 构建请求 @@ -432,11 +432,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC): is_stream=True, ) - ctx["provider_request_headers"] = provider_headers - ctx["provider_request_body"] = provider_payload + ctx.provider_request_headers = provider_headers + ctx.provider_request_body = provider_payload - # 获取 URL 模型名(兜底使用 ctx 中的 model,确保 Gemini 等格式能正确构建 URL) - url_model = self.get_model_for_url(request_body, mapped_model) or ctx["model"] + # 获取 URL 模型名 + url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model url = build_provider_url( endpoint, @@ -445,15 +445,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC): is_stream=True, ) - logger.debug(f" [{self.request_id}] 发送流式请求: Provider={provider.name}, " - f"模型={ctx['model']} -> {mapped_model or '无映射'}") + logger.debug( + f" [{self.request_id}] 发送流式请求: Provider={provider.name}, " + f"模型={ctx.model} -> {mapped_model or '无映射'}" + ) - # 发送请求 + # 发送请求(使用配置中的超时设置) timeout_config = httpx.Timeout( - connect=10.0, + 
connect=config.http_connect_timeout, read=float(endpoint.timeout), - write=60.0, # 写入超时增加到60秒,支持大请求体(如包含图片的长对话) - pool=10.0, + write=config.http_write_timeout, + pool=config.http_pool_timeout, ) http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True) @@ -463,17 +465,21 @@ class ChatHandlerBase(BaseMessageHandler, ABC): ) stream_response = await response_ctx.__aenter__() - ctx["status_code"] = stream_response.status_code - ctx["response_headers"] = dict(stream_response.headers) + ctx.status_code = stream_response.status_code + ctx.response_headers = dict(stream_response.headers) stream_response.raise_for_status() - # 创建行迭代器(只创建一次,后续会继续使用) + # 创建行迭代器 line_iterator = stream_response.aiter_lines() - # 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误) - prefetched_lines = await self._prefetch_and_check_embedded_error( - line_iterator, provider, endpoint, ctx + # 预读检测嵌套错误 + prefetched_lines = await stream_processor.prefetch_and_check_error( + line_iterator, + provider, + endpoint, + ctx, + max_prefetch_lines=config.stream_prefetch_lines, ) except httpx.HTTPStatusError as e: @@ -483,7 +489,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC): raise except EmbeddedErrorException: - # 嵌套错误需要触发重试,关闭连接后重新抛出 try: await response_ctx.__aexit__(None, None, None) except Exception: @@ -495,8 +500,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC): await http_client.aclose() raise - # 创建流生成器(带预读数据,使用同一个迭代器) - return self._create_response_stream_with_prefetch( + # 创建流生成器 + return stream_processor.create_response_stream( ctx, line_iterator, response_ctx, @@ -504,518 +509,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC): prefetched_lines, ) - async def _create_response_stream( - self, - ctx: Dict, - stream_response: httpx.Response, - response_ctx: Any, - http_client: httpx.AsyncClient, - ) -> AsyncGenerator[bytes, None]: - """创建响应流生成器""" - try: - sse_parser = SSEEventParser() - streaming_status_updated = False - - async for line in stream_response.aiter_lines(): - # 在第一次输出数据前更新状态为 streaming - if not streaming_status_updated: - self._update_usage_to_streaming() - streaming_status_updated = True - - normalized_line = line.rstrip("\r") - events = sse_parser.feed_line(normalized_line) - - if normalized_line == "": - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - yield b"\n" - continue - - ctx["chunk_count"] += 1 - - yield (line + "\n").encode("utf-8") - - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - - # 处理剩余事件 - for event in sse_parser.flush(): - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - - except GeneratorExit: - raise - finally: - try: - await response_ctx.__aexit__(None, None, None) - except Exception: - pass - try: - await http_client.aclose() - except Exception: - pass - - async def _prefetch_and_check_embedded_error( - self, - line_iterator: Any, - provider: Provider, - endpoint: ProviderEndpoint, - ctx: Dict, - ) -> list: - """ - 预读流的前几行,检测嵌套错误 - - 某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。 - 这种情况需要在流开始输出之前检测,以便触发重试逻辑。 - - Args: - line_iterator: 行迭代器(aiter_lines() 返回的迭代器) - provider: Provider 对象 - endpoint: Endpoint 对象 - ctx: 上下文字典 - - Returns: - 预读的行列表(需要在后续流中先输出) - - Raises: - EmbeddedErrorException: 如果检测到嵌套错误 - """ - prefetched_lines: list = [] - max_prefetch_lines = 5 # 最多预读5行来检测错误 - - try: - # 获取对应格式的解析器 - provider_parser = self._get_provider_parser(ctx) - - line_count = 0 - async for line in line_iterator: - prefetched_lines.append(line) - line_count 
+= 1 - - # 解析数据 - normalized_line = line.rstrip("\r") - if not normalized_line or normalized_line.startswith(":"): - # 空行或注释行,继续预读 - if line_count >= max_prefetch_lines: - break - continue - - # 尝试解析 SSE 数据 - data_str = normalized_line - if normalized_line.startswith("data: "): - data_str = normalized_line[6:] - - if data_str == "[DONE]": - break - - try: - data = json.loads(data_str) - except json.JSONDecodeError: - # 不是有效 JSON,可能是部分数据,继续 - if line_count >= max_prefetch_lines: - break - continue - - # 使用解析器检查是否为错误响应 - if isinstance(data, dict) and provider_parser.is_error_response(data): - # 提取错误信息 - parsed = provider_parser.parse_response(data, 200) - logger.warning(f" [{self.request_id}] 检测到嵌套错误: " - f"Provider={provider.name}, " - f"error_type={parsed.error_type}, " - f"message={parsed.error_message}") - raise EmbeddedErrorException( - provider_name=str(provider.name), - error_code=( - int(parsed.error_type) - if parsed.error_type and parsed.error_type.isdigit() - else None - ), - error_message=parsed.error_message, - error_status=parsed.error_type, - ) - - # 预读到有效数据,没有错误,停止预读 - break - - except EmbeddedErrorException: - # 重新抛出嵌套错误 - raise - except Exception as e: - # 其他异常(如网络错误)在预读阶段发生,记录日志但不中断 - logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") - - return prefetched_lines - - async def _create_response_stream_with_prefetch( - self, - ctx: Dict, - line_iterator: Any, - response_ctx: Any, - http_client: httpx.AsyncClient, - prefetched_lines: list, - ) -> AsyncGenerator[bytes, None]: - """创建响应流生成器(带预读数据)""" - try: - sse_parser = SSEEventParser() - - # 在第一次输出数据前更新状态为 streaming - if prefetched_lines: - self._update_usage_to_streaming() - - # 先输出预读的数据 - for line in prefetched_lines: - normalized_line = line.rstrip("\r") - events = sse_parser.feed_line(normalized_line) - - if normalized_line == "": - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - yield b"\n" - continue - - ctx["chunk_count"] += 1 - yield (line + "\n").encode("utf-8") - - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - - # 继续输出剩余的流数据(使用同一个迭代器) - async for line in line_iterator: - normalized_line = line.rstrip("\r") - events = sse_parser.feed_line(normalized_line) - - if normalized_line == "": - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - yield b"\n" - continue - - ctx["chunk_count"] += 1 - yield (line + "\n").encode("utf-8") - - for event in events: - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - - # 处理剩余事件 - for event in sse_parser.flush(): - self._handle_sse_event(ctx, event.get("event"), event.get("data") or "") - - except GeneratorExit: - raise - finally: - try: - await response_ctx.__aexit__(None, None, None) - except Exception: - pass - try: - await http_client.aclose() - except Exception: - pass - - def _get_provider_parser(self, ctx: Dict) -> ResponseParser: - """ - 获取 Provider 格式的解析器 - - 根据 Provider 的 API 格式选择正确的解析器, - 而不是根据请求格式选择。 - """ - provider_format = ctx.get("provider_api_format") - if provider_format: - try: - return get_parser_for_format(provider_format) - except KeyError: - pass - # 回退到默认解析器 - return self.parser - - def _handle_sse_event( - self, - ctx: Dict, - event_name: Optional[str], - data_str: str, - ) -> None: - """处理 SSE 事件""" - if not data_str: - return - - if data_str == "[DONE]": - ctx["has_completion"] = True - return - - try: - data = json.loads(data_str) - except json.JSONDecodeError: - return - - 
ctx["data_count"] += 1 - - if not isinstance(data, dict): - return - - # 收集原始 chunk 数据 - ctx["parsed_chunks"].append(data) - - # 根据 Provider 格式选择解析器 - provider_parser = self._get_provider_parser(ctx) - - # 使用解析器提取 usage - usage = provider_parser.extract_usage_from_response(data) - if usage: - ctx["input_tokens"] = usage.get("input_tokens", ctx["input_tokens"]) - ctx["output_tokens"] = usage.get("output_tokens", ctx["output_tokens"]) - ctx["cached_tokens"] = usage.get("cache_read_tokens", ctx["cached_tokens"]) - ctx["cache_creation_tokens"] = usage.get( - "cache_creation_tokens", ctx["cache_creation_tokens"] - ) - - # 提取文本 - text = provider_parser.extract_text_content(data) - if text: - ctx["collected_text"] += text - - # 检查完成 - event_type = event_name or data.get("type", "") - if event_type in ("response.completed", "message_stop"): - ctx["has_completion"] = True - - async def _create_monitored_stream( - self, - ctx: Dict, - stream_generator: AsyncGenerator[bytes, None], - http_request: Request, - ) -> AsyncGenerator[bytes, None]: - """创建带监控的流生成器""" - try: - async for chunk in stream_generator: - if await http_request.is_disconnected(): - logger.warning(f"ID:{self.request_id} | Client disconnected") - # 客户端断开时设置 499 状态码(Client Closed Request) - # 注意:Provider 可能已经成功返回数据,但客户端未完整接收 - ctx["status_code"] = 499 - break - yield chunk - except asyncio.CancelledError: - ctx["status_code"] = 499 - raise - except Exception: - ctx["status_code"] = 500 - raise - - async def _record_stream_stats( - self, - ctx: Dict, - original_headers: Dict[str, str], - original_request_body: Dict[str, Any], - ) -> None: - """记录流式统计信息""" - response_time_ms = self.elapsed_ms() - bg_db = None - - try: - await asyncio.sleep(0.1) - - if not ctx["provider_name"]: - # 即使没有 provider_name,也要尝试更新状态为 failed - await self._update_usage_status_on_error( - response_time_ms=response_time_ms, - error_message="Provider name not available", - ) - return - - db_gen = get_db() - bg_db = next(db_gen) - - try: - from src.models.database import ApiKey as ApiKeyModel - - user = bg_db.query(User).filter(User.id == self.user.id).first() - api_key_obj = ( - bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == self.api_key.id).first() - ) - - if not user or not api_key_obj: - logger.warning(f"[{self.request_id}] User or ApiKey not found, updating status directly") - await self._update_usage_status_directly( - bg_db, - status="completed" if ctx["status_code"] == 200 else "failed", - response_time_ms=response_time_ms, - status_code=ctx["status_code"], - ) - return - - bg_telemetry = MessageTelemetry( - bg_db, user, api_key_obj, self.request_id, self.client_ip - ) - - actual_request_body = ctx["provider_request_body"] or original_request_body - - # 构建响应体(与 CLI 模式一致) - response_body = { - "chunks": ctx["parsed_chunks"], - "metadata": { - "stream": True, - "total_chunks": len(ctx["parsed_chunks"]), - "data_count": ctx["data_count"], - "has_completion": ctx["has_completion"], - "response_time_ms": response_time_ms, - }, - } - - # 根据状态码决定记录成功还是失败 - # 499 = 客户端断开连接,503 = 服务不可用(如流中断) - status_code: int = ctx.get("status_code") or 200 - if status_code >= 400: - # 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start) - # 这样即使请求中断,也能记录预估成本 - await bg_telemetry.record_failure( - provider=ctx.get("provider_name") or "unknown", - model=ctx["model"], - response_time_ms=response_time_ms, - status_code=status_code, - error_message=ctx.get("error_message") or f"HTTP {status_code}", - request_headers=original_headers, - request_body=actual_request_body, - is_stream=True, - 
api_format=ctx["api_format"], - provider_request_headers=ctx["provider_request_headers"], - # 预估 token 信息(来自 message_start 事件) - input_tokens=ctx.get("input_tokens", 0), - output_tokens=ctx.get("output_tokens", 0), - cache_creation_tokens=ctx.get("cache_creation_tokens", 0), - cache_read_tokens=ctx.get("cached_tokens", 0), - response_body=response_body, - # 模型映射信息 - target_model=ctx.get("mapped_model"), - ) - logger.debug(f"{self.FORMAT_ID} 流式响应中断") - # 简洁的请求失败摘要(包含预估 token 信息) - logger.info(f"[FAIL] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | " - f"{status_code} | in:{ctx.get('input_tokens', 0)} out:{ctx.get('output_tokens', 0)} cache:{ctx.get('cached_tokens', 0)}") - else: - await bg_telemetry.record_success( - provider=ctx["provider_name"], - model=ctx["model"], - input_tokens=ctx["input_tokens"], - output_tokens=ctx["output_tokens"], - response_time_ms=response_time_ms, - status_code=status_code, - request_headers=original_headers, - request_body=actual_request_body, - response_headers=ctx["response_headers"], - response_body=response_body, - cache_creation_tokens=ctx["cache_creation_tokens"], - cache_read_tokens=ctx["cached_tokens"], - is_stream=True, - provider_request_headers=ctx["provider_request_headers"], - api_format=ctx["api_format"], - provider_id=ctx["provider_id"], - provider_endpoint_id=ctx["endpoint_id"], - provider_api_key_id=ctx["key_id"], - # 模型映射信息 - target_model=ctx.get("mapped_model"), - ) - logger.debug(f"{self.FORMAT_ID} 流式响应完成") - # 简洁的请求完成摘要 - logger.info(f"[OK] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | " - f"in:{ctx.get('input_tokens', 0) or 0} out:{ctx.get('output_tokens', 0) or 0}") - - # 更新候选记录的最终状态和延迟时间 - # 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间) - # 这里用流传输完成后的实际时间覆盖 - if ctx.get("attempt_id"): - from src.services.request.candidate import RequestCandidateService - - # 根据状态码决定是成功还是失败(复用上面已定义的 status_code) - # 499 = 客户端断开连接,应标记为失败 - # 503 = 服务不可用(如流中断),应标记为失败 - if status_code and status_code >= 400: - RequestCandidateService.mark_candidate_failed( - db=bg_db, - candidate_id=ctx["attempt_id"], - error_type="client_disconnected" if status_code == 499 else "stream_error", - error_message=ctx.get("error_message") or f"HTTP {status_code}", - status_code=status_code, - latency_ms=response_time_ms, - extra_data={ - "stream_completed": False, - "data_count": ctx.get("data_count", 0), - }, - ) - else: - RequestCandidateService.mark_candidate_success( - db=bg_db, - candidate_id=ctx["attempt_id"], - status_code=status_code, - latency_ms=response_time_ms, - extra_data={ - "stream_completed": True, - "data_count": ctx.get("data_count", 0), - }, - ) - - finally: - if bg_db: - bg_db.close() - - except Exception as e: - logger.exception("记录流式统计信息时出错") - # 确保即使出错也要更新状态,避免 pending 状态卡住 - await self._update_usage_status_on_error( - response_time_ms=response_time_ms, - error_message=f"记录统计信息失败: {str(e)[:200]}", - ) - - # _update_usage_to_streaming 方法已移至基类 BaseMessageHandler - - async def _update_usage_status_on_error( - self, - response_time_ms: int, - error_message: str, - ) -> None: - """在记录失败时更新 Usage 状态,避免卡在 pending""" - try: - db_gen = get_db() - error_db = next(db_gen) - try: - await self._update_usage_status_directly( - error_db, - status="failed", - response_time_ms=response_time_ms, - status_code=500, - error_message=error_message, - ) - finally: - error_db.close() - except Exception as inner_e: - logger.error(f"[{self.request_id}] 更新 Usage 
状态失败: {inner_e}") - - async def _update_usage_status_directly( - self, - db: Session, - status: str, - response_time_ms: int, - status_code: int = 200, - error_message: Optional[str] = None, - ) -> None: - """直接更新 Usage 表的状态字段""" - try: - from src.models.database import Usage - - usage = db.query(Usage).filter(Usage.request_id == self.request_id).first() - if usage: - setattr(usage, "status", status) - setattr(usage, "status_code", status_code) - setattr(usage, "response_time_ms", response_time_ms) - if error_message: - setattr(usage, "error_message", error_message) - db.commit() - logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}") - except Exception as e: - logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}") - async def _record_stream_failure( self, - ctx: Dict, + ctx: StreamContext, error: Exception, original_headers: Dict[str, str], original_request_body: Dict[str, Any], @@ -1031,21 +527,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC): elif isinstance(error, ProviderTimeoutException): status_code = 504 - actual_request_body = ctx.get("provider_request_body") or original_request_body + actual_request_body = ctx.provider_request_body or original_request_body await self.telemetry.record_failure( - provider=ctx.get("provider_name") or "unknown", - model=ctx["model"], + provider=ctx.provider_name or "unknown", + model=ctx.model, response_time_ms=response_time_ms, status_code=status_code, error_message=str(error), request_headers=original_headers, request_body=actual_request_body, is_stream=True, - api_format=ctx["api_format"], - provider_request_headers=ctx.get("provider_request_headers") or {}, - # 模型映射信息 - target_model=ctx.get("mapped_model"), + api_format=ctx.api_format, + provider_request_headers=ctx.provider_request_headers, + target_model=ctx.mapped_model, ) # ==================== 非流式处理 ==================== diff --git a/src/api/handlers/base/stream_context.py b/src/api/handlers/base/stream_context.py new file mode 100644 index 0000000..711f6a1 --- /dev/null +++ b/src/api/handlers/base/stream_context.py @@ -0,0 +1,154 @@ +""" +流式处理上下文 - 类型安全的数据类替代 dict + +提供流式请求处理过程中的状态跟踪,包括: +- Provider/Endpoint/Key 信息 +- Token 统计 +- 响应状态 +- 请求/响应数据 +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +@dataclass +class StreamContext: + """ + 流式处理上下文 + + 用于在流式请求处理过程中跟踪状态,替代原有的 ctx dict。 + 所有字段都有类型注解,提供更好的 IDE 支持和运行时类型安全。 + """ + + # 请求基本信息 + model: str + api_format: str + + # Provider 信息(在请求执行时填充) + provider_name: Optional[str] = None + provider_id: Optional[str] = None + endpoint_id: Optional[str] = None + key_id: Optional[str] = None + attempt_id: Optional[str] = None + provider_api_format: Optional[str] = None # Provider 的响应格式 + + # 模型映射 + mapped_model: Optional[str] = None + + # Token 统计 + input_tokens: int = 0 + output_tokens: int = 0 + cached_tokens: int = 0 + cache_creation_tokens: int = 0 + + # 响应内容 + collected_text: str = "" + + # 响应状态 + status_code: int = 200 + error_message: Optional[str] = None + has_completion: bool = False + + # 请求/响应数据 + response_headers: Dict[str, str] = field(default_factory=dict) + provider_request_headers: Dict[str, str] = field(default_factory=dict) + provider_request_body: Optional[Dict[str, Any]] = None + + # 流式处理统计 + data_count: int = 0 + chunk_count: int = 0 + parsed_chunks: List[Dict[str, Any]] = field(default_factory=list) + + def reset_for_retry(self) -> None: + """ + 重试时重置状态 + + 在故障转移重试时调用,清除之前的数据避免累积。 + 保留 model 和 api_format,重置其他所有状态。 + """ + self.parsed_chunks = [] + self.chunk_count = 0 + 
self.data_count = 0 + self.has_completion = False + self.collected_text = "" + self.input_tokens = 0 + self.output_tokens = 0 + self.cached_tokens = 0 + self.cache_creation_tokens = 0 + self.error_message = None + self.status_code = 200 + self.response_headers = {} + self.provider_request_headers = {} + self.provider_request_body = None + + def update_provider_info( + self, + provider_name: str, + provider_id: str, + endpoint_id: str, + key_id: str, + provider_api_format: Optional[str] = None, + ) -> None: + """更新 Provider 信息""" + self.provider_name = provider_name + self.provider_id = provider_id + self.endpoint_id = endpoint_id + self.key_id = key_id + self.provider_api_format = provider_api_format + + def update_usage( + self, + input_tokens: Optional[int] = None, + output_tokens: Optional[int] = None, + cached_tokens: Optional[int] = None, + cache_creation_tokens: Optional[int] = None, + ) -> None: + """更新 Token 使用统计""" + if input_tokens is not None: + self.input_tokens = input_tokens + if output_tokens is not None: + self.output_tokens = output_tokens + if cached_tokens is not None: + self.cached_tokens = cached_tokens + if cache_creation_tokens is not None: + self.cache_creation_tokens = cache_creation_tokens + + def mark_failed(self, status_code: int, error_message: str) -> None: + """标记请求失败""" + self.status_code = status_code + self.error_message = error_message + + def is_success(self) -> bool: + """检查请求是否成功""" + return self.status_code < 400 + + def build_response_body(self, response_time_ms: int) -> Dict[str, Any]: + """ + 构建响应体元数据 + + 用于记录到 Usage 表的 response_body 字段。 + """ + return { + "chunks": self.parsed_chunks, + "metadata": { + "stream": True, + "total_chunks": len(self.parsed_chunks), + "data_count": self.data_count, + "has_completion": self.has_completion, + "response_time_ms": response_time_ms, + }, + } + + def get_log_summary(self, request_id: str, response_time_ms: int) -> str: + """ + 获取日志摘要 + + 用于请求完成/失败时的日志输出。 + """ + status = "OK" if self.is_success() else "FAIL" + return ( + f"[{status}] {request_id[:8]} | {self.model} | " + f"{self.provider_name or 'unknown'} | {response_time_ms}ms | " + f"in:{self.input_tokens} out:{self.output_tokens}" + ) diff --git a/src/api/handlers/base/stream_processor.py b/src/api/handlers/base/stream_processor.py new file mode 100644 index 0000000..5ff85ee --- /dev/null +++ b/src/api/handlers/base/stream_processor.py @@ -0,0 +1,349 @@ +""" +流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑 + +职责: +1. SSE 事件解析和处理 +2. 响应流生成 +3. 预读和嵌套错误检测 +4. 
客户端断开检测 +""" + +import asyncio +import json +from typing import Any, AsyncGenerator, Callable, Optional + +import httpx + +from src.api.handlers.base.parsers import get_parser_for_format +from src.api.handlers.base.response_parser import ResponseParser +from src.api.handlers.base.stream_context import StreamContext +from src.core.exceptions import EmbeddedErrorException +from src.core.logger import logger +from src.models.database import Provider, ProviderEndpoint +from src.utils.sse_parser import SSEEventParser + + +class StreamProcessor: + """ + 流式响应处理器 + + 负责处理 SSE 流的解析、错误检测和响应生成。 + 从 ChatHandlerBase 中提取,使其职责更加单一。 + """ + + def __init__( + self, + request_id: str, + default_parser: ResponseParser, + on_streaming_start: Optional[Callable[[], None]] = None, + ): + """ + 初始化流处理器 + + Args: + request_id: 请求 ID(用于日志) + default_parser: 默认响应解析器 + on_streaming_start: 流开始时的回调(用于更新状态) + """ + self.request_id = request_id + self.default_parser = default_parser + self.on_streaming_start = on_streaming_start + + def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser: + """ + 获取 Provider 格式的解析器 + + 根据 Provider 的 API 格式选择正确的解析器。 + """ + if ctx.provider_api_format: + try: + return get_parser_for_format(ctx.provider_api_format) + except KeyError: + pass + return self.default_parser + + def handle_sse_event( + self, + ctx: StreamContext, + event_name: Optional[str], + data_str: str, + ) -> None: + """ + 处理单个 SSE 事件 + + 解析事件数据,提取 usage 信息和文本内容。 + + Args: + ctx: 流式上下文 + event_name: 事件名称 + data_str: 事件数据字符串 + """ + if not data_str: + return + + if data_str == "[DONE]": + ctx.has_completion = True + return + + try: + data = json.loads(data_str) + except json.JSONDecodeError: + return + + ctx.data_count += 1 + + if not isinstance(data, dict): + return + + # 收集原始 chunk 数据 + ctx.parsed_chunks.append(data) + + # 根据 Provider 格式选择解析器 + parser = self.get_parser_for_provider(ctx) + + # 使用解析器提取 usage + usage = parser.extract_usage_from_response(data) + if usage: + ctx.update_usage( + input_tokens=usage.get("input_tokens"), + output_tokens=usage.get("output_tokens"), + cached_tokens=usage.get("cache_read_tokens"), + cache_creation_tokens=usage.get("cache_creation_tokens"), + ) + + # 提取文本 + text = parser.extract_text_content(data) + if text: + ctx.collected_text += text + + # 检查完成 + event_type = event_name or data.get("type", "") + if event_type in ("response.completed", "message_stop"): + ctx.has_completion = True + + async def prefetch_and_check_error( + self, + line_iterator: Any, + provider: Provider, + endpoint: ProviderEndpoint, + ctx: StreamContext, + max_prefetch_lines: int = 5, + ) -> list: + """ + 预读流的前几行,检测嵌套错误 + + 某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。 + 这种情况需要在流开始输出之前检测,以便触发重试逻辑。 + + Args: + line_iterator: 行迭代器 + provider: Provider 对象 + endpoint: Endpoint 对象 + ctx: 流式上下文 + max_prefetch_lines: 最多预读行数 + + Returns: + 预读的行列表 + + Raises: + EmbeddedErrorException: 如果检测到嵌套错误 + """ + prefetched_lines: list = [] + parser = self.get_parser_for_provider(ctx) + + try: + line_count = 0 + async for line in line_iterator: + prefetched_lines.append(line) + line_count += 1 + + normalized_line = line.rstrip("\r") + if not normalized_line or normalized_line.startswith(":"): + if line_count >= max_prefetch_lines: + break + continue + + # 尝试解析 SSE 数据 + data_str = normalized_line + if normalized_line.startswith("data: "): + data_str = normalized_line[6:] + + if data_str == "[DONE]": + break + + try: + data = json.loads(data_str) + except json.JSONDecodeError: + if line_count >= max_prefetch_lines: + 
break + continue + + # 使用解析器检查是否为错误响应 + if isinstance(data, dict) and parser.is_error_response(data): + parsed = parser.parse_response(data, 200) + logger.warning( + f" [{self.request_id}] 检测到嵌套错误: " + f"Provider={provider.name}, " + f"error_type={parsed.error_type}, " + f"message={parsed.error_message}" + ) + raise EmbeddedErrorException( + provider_name=str(provider.name), + error_code=( + int(parsed.error_type) + if parsed.error_type and parsed.error_type.isdigit() + else None + ), + error_message=parsed.error_message, + error_status=parsed.error_type, + ) + + # 预读到有效数据,没有错误,停止预读 + break + + except EmbeddedErrorException: + raise + except Exception as e: + logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") + + return prefetched_lines + + async def create_response_stream( + self, + ctx: StreamContext, + line_iterator: Any, + response_ctx: Any, + http_client: httpx.AsyncClient, + prefetched_lines: Optional[list] = None, + ) -> AsyncGenerator[bytes, None]: + """ + 创建响应流生成器 + + 统一的流生成器,支持带预读数据和不带预读数据两种情况。 + + Args: + ctx: 流式上下文 + line_iterator: 行迭代器 + response_ctx: HTTP 响应上下文管理器 + http_client: HTTP 客户端 + prefetched_lines: 预读的行列表(可选) + + Yields: + 编码后的响应数据块 + """ + try: + sse_parser = SSEEventParser() + streaming_started = False + + # 处理预读数据 + if prefetched_lines: + if not streaming_started and self.on_streaming_start: + self.on_streaming_start() + streaming_started = True + + for line in prefetched_lines: + for chunk in self._process_line(ctx, sse_parser, line): + yield chunk + + # 处理剩余的流数据 + async for line in line_iterator: + if not streaming_started and self.on_streaming_start: + self.on_streaming_start() + streaming_started = True + + for chunk in self._process_line(ctx, sse_parser, line): + yield chunk + + # 处理剩余事件 + for event in sse_parser.flush(): + self.handle_sse_event(ctx, event.get("event"), event.get("data") or "") + + except GeneratorExit: + raise + finally: + await self._cleanup(response_ctx, http_client) + + def _process_line( + self, + ctx: StreamContext, + sse_parser: SSEEventParser, + line: str, + ) -> list[bytes]: + """ + 处理单行数据 + + Args: + ctx: 流式上下文 + sse_parser: SSE 解析器 + line: 原始行数据 + + Returns: + 要发送的数据块列表 + """ + result: list[bytes] = [] + normalized_line = line.rstrip("\r") + events = sse_parser.feed_line(normalized_line) + + if normalized_line == "": + for event in events: + self.handle_sse_event(ctx, event.get("event"), event.get("data") or "") + result.append(b"\n") + else: + ctx.chunk_count += 1 + result.append((line + "\n").encode("utf-8")) + + for event in events: + self.handle_sse_event(ctx, event.get("event"), event.get("data") or "") + + return result + + async def create_monitored_stream( + self, + ctx: StreamContext, + stream_generator: AsyncGenerator[bytes, None], + is_disconnected: Callable[[], Any], + ) -> AsyncGenerator[bytes, None]: + """ + 创建带监控的流生成器 + + 检测客户端断开连接并更新状态码。 + + Args: + ctx: 流式上下文 + stream_generator: 原始流生成器 + is_disconnected: 检查客户端是否断开的函数 + + Yields: + 响应数据块 + """ + try: + async for chunk in stream_generator: + if await is_disconnected(): + logger.warning(f"ID:{self.request_id} | Client disconnected") + ctx.status_code = 499 # Client Closed Request + ctx.error_message = "client_disconnected" + break + yield chunk + except asyncio.CancelledError: + ctx.status_code = 499 + ctx.error_message = "client_disconnected" + raise + except Exception as e: + ctx.status_code = 500 + ctx.error_message = str(e) + raise + + async def _cleanup( + self, + response_ctx: Any, + http_client: httpx.AsyncClient, + ) -> None: + """清理资源""" + try: + await 
response_ctx.__aexit__(None, None, None) + except Exception: + pass + try: + await http_client.aclose() + except Exception: + pass diff --git a/src/api/handlers/base/stream_telemetry.py b/src/api/handlers/base/stream_telemetry.py new file mode 100644 index 0000000..2d3d7ba --- /dev/null +++ b/src/api/handlers/base/stream_telemetry.py @@ -0,0 +1,293 @@ +""" +流式遥测记录器 - 从 ChatHandlerBase 提取的统计记录逻辑 + +职责: +1. 记录流式请求的成功/失败统计 +2. 更新 Usage 状态 +3. 更新候选记录状态 +""" + +import asyncio +from typing import Any, Dict, Optional + +from sqlalchemy.orm import Session + +from src.api.handlers.base.base_handler import MessageTelemetry +from src.api.handlers.base.stream_context import StreamContext +from src.config.settings import config +from src.core.logger import logger +from src.database import get_db +from src.models.database import ApiKey, User + + +class StreamTelemetryRecorder: + """ + 流式遥测记录器 + + 负责在流式请求完成后记录统计信息。 + 从 ChatHandlerBase 中提取的 _record_stream_stats 逻辑。 + """ + + def __init__( + self, + request_id: str, + user_id: str, + api_key_id: str, + client_ip: str, + format_id: str, + ): + """ + 初始化遥测记录器 + + Args: + request_id: 请求 ID + user_id: 用户 ID + api_key_id: API Key ID + client_ip: 客户端 IP + format_id: API 格式标识 + """ + self.request_id = request_id + self.user_id = user_id + self.api_key_id = api_key_id + self.client_ip = client_ip + self.format_id = format_id + + async def record_stream_stats( + self, + ctx: StreamContext, + original_headers: Dict[str, str], + original_request_body: Dict[str, Any], + response_time_ms: int, + ) -> None: + """ + 记录流式统计信息 + + Args: + ctx: 流式上下文 + original_headers: 原始请求头 + original_request_body: 原始请求体 + response_time_ms: 响应时间(毫秒) + """ + bg_db = None + + try: + await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭 + + if not ctx.provider_name: + await self._update_usage_status_on_error( + response_time_ms=response_time_ms, + error_message="Provider name not available", + ) + return + + db_gen = get_db() + bg_db = next(db_gen) + + try: + user = bg_db.query(User).filter(User.id == self.user_id).first() + api_key_obj = bg_db.query(ApiKey).filter(ApiKey.id == self.api_key_id).first() + + if not user or not api_key_obj: + logger.warning( + f"[{self.request_id}] User or ApiKey not found, updating status directly" + ) + await self._update_usage_status_directly( + bg_db, + status="completed" if ctx.is_success() else "failed", + response_time_ms=response_time_ms, + status_code=ctx.status_code, + ) + return + + bg_telemetry = MessageTelemetry( + bg_db, user, api_key_obj, self.request_id, self.client_ip + ) + + actual_request_body = ctx.provider_request_body or original_request_body + response_body = ctx.build_response_body(response_time_ms) + + if ctx.is_success(): + await self._record_success( + bg_telemetry, + ctx, + original_headers, + actual_request_body, + response_body, + response_time_ms, + ) + else: + await self._record_failure( + bg_telemetry, + ctx, + original_headers, + actual_request_body, + response_body, + response_time_ms, + ) + + # 更新候选记录状态 + await self._update_candidate_status(bg_db, ctx, response_time_ms) + + finally: + if bg_db: + bg_db.close() + + except Exception as e: + logger.exception("记录流式统计信息时出错") + await self._update_usage_status_on_error( + response_time_ms=response_time_ms, + error_message=f"记录统计信息失败: {str(e)[:200]}", + ) + + async def _record_success( + self, + telemetry: MessageTelemetry, + ctx: StreamContext, + original_headers: Dict[str, str], + actual_request_body: Dict[str, Any], + response_body: Dict[str, Any], + response_time_ms: int, + ) 
-> None: + """记录成功的请求""" + await telemetry.record_success( + provider=ctx.provider_name or "unknown", + model=ctx.model, + input_tokens=ctx.input_tokens, + output_tokens=ctx.output_tokens, + response_time_ms=response_time_ms, + status_code=ctx.status_code, + request_headers=original_headers, + request_body=actual_request_body, + response_headers=ctx.response_headers, + response_body=response_body, + cache_creation_tokens=ctx.cache_creation_tokens, + cache_read_tokens=ctx.cached_tokens, + is_stream=True, + provider_request_headers=ctx.provider_request_headers, + api_format=ctx.api_format, + provider_id=ctx.provider_id, + provider_endpoint_id=ctx.endpoint_id, + provider_api_key_id=ctx.key_id, + target_model=ctx.mapped_model, + ) + + logger.debug(f"{self.format_id} 流式响应完成") + logger.info(ctx.get_log_summary(self.request_id, response_time_ms)) + + async def _record_failure( + self, + telemetry: MessageTelemetry, + ctx: StreamContext, + original_headers: Dict[str, str], + actual_request_body: Dict[str, Any], + response_body: Dict[str, Any], + response_time_ms: int, + ) -> None: + """记录失败的请求""" + await telemetry.record_failure( + provider=ctx.provider_name or "unknown", + model=ctx.model, + response_time_ms=response_time_ms, + status_code=ctx.status_code, + error_message=ctx.error_message or f"HTTP {ctx.status_code}", + request_headers=original_headers, + request_body=actual_request_body, + is_stream=True, + api_format=ctx.api_format, + provider_request_headers=ctx.provider_request_headers, + input_tokens=ctx.input_tokens, + output_tokens=ctx.output_tokens, + cache_creation_tokens=ctx.cache_creation_tokens, + cache_read_tokens=ctx.cached_tokens, + response_body=response_body, + target_model=ctx.mapped_model, + ) + + logger.debug(f"{self.format_id} 流式响应中断") + log_summary = ctx.get_log_summary(self.request_id, response_time_ms) + # 对于失败日志,添加缓存信息 + logger.info(f"{log_summary} cache:{ctx.cached_tokens}") + + async def _update_candidate_status( + self, + db: Session, + ctx: StreamContext, + response_time_ms: int, + ) -> None: + """更新候选记录状态""" + if not ctx.attempt_id: + return + + from src.services.request.candidate import RequestCandidateService + + if ctx.is_success(): + RequestCandidateService.mark_candidate_success( + db=db, + candidate_id=ctx.attempt_id, + status_code=ctx.status_code, + latency_ms=response_time_ms, + extra_data={ + "stream_completed": True, + "data_count": ctx.data_count, + }, + ) + else: + error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error" + RequestCandidateService.mark_candidate_failed( + db=db, + candidate_id=ctx.attempt_id, + error_type=error_type, + error_message=ctx.error_message or f"HTTP {ctx.status_code}", + status_code=ctx.status_code, + latency_ms=response_time_ms, + extra_data={ + "stream_completed": False, + "data_count": ctx.data_count, + }, + ) + + async def _update_usage_status_on_error( + self, + response_time_ms: int, + error_message: str, + ) -> None: + """在记录失败时更新 Usage 状态""" + try: + db_gen = get_db() + error_db = next(db_gen) + try: + await self._update_usage_status_directly( + error_db, + status="failed", + response_time_ms=response_time_ms, + status_code=500, + error_message=error_message, + ) + finally: + error_db.close() + except Exception as inner_e: + logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}") + + async def _update_usage_status_directly( + self, + db: Session, + status: str, + response_time_ms: int, + status_code: int = 200, + error_message: Optional[str] = None, + ) -> None: + """直接更新 Usage 表的状态字段""" + 
try: + from src.models.database import Usage + + usage = db.query(Usage).filter(Usage.request_id == self.request_id).first() + if usage: + setattr(usage, "status", status) + setattr(usage, "status_code", status_code) + setattr(usage, "response_time_ms", response_time_ms) + if error_message: + setattr(usage, "error_message", error_message) + db.commit() + logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}") + except Exception as e: + logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}") diff --git a/src/config/settings.py b/src/config/settings.py index 41d821d..d24d3b8 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -120,6 +120,23 @@ class Config: self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600")) self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70")) + # 并发控制配置 + # CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁 + # CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%) + self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600")) + self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3")) + + # HTTP 请求超时配置(秒) + self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) + self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) + self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) + + # 流式处理配置 + # STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误 + # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 + self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5")) + self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1")) + # 验证连接池配置 self._validate_pool_config() diff --git a/src/services/rate_limit/concurrency_manager.py b/src/services/rate_limit/concurrency_manager.py index 106509f..a968155 100644 --- a/src/services/rate_limit/concurrency_manager.py +++ b/src/services/rate_limit/concurrency_manager.py @@ -13,7 +13,7 @@ import asyncio import math import os from contextlib import asynccontextmanager -from datetime import timedelta +from datetime import timedelta # noqa: F401 - kept for potential future use from typing import Optional, Tuple import redis.asyncio as aioredis @@ -185,8 +185,8 @@ class ConcurrencyManager: key_id: str, key_max_concurrent: Optional[int], is_cached_user: bool = False, # 新增:是否是缓存用户 - cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例 - ttl_seconds: int = 600, # 10分钟 TTL,防止死锁 + cache_reservation_ratio: Optional[float] = None, # 缓存预留比例,None 时从配置读取 + ttl_seconds: Optional[int] = None, # TTL 秒数,None 时从配置读取 ) -> bool: """ 尝试获取并发槽位(支持缓存用户优先级) @@ -197,8 +197,8 @@ class ConcurrencyManager: key_id: ProviderAPIKey ID key_max_concurrent: Key 最大并发数(None 表示不限制) is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位) - cache_reservation_ratio: 缓存预留比例(默认30%,只对新用户生效) - ttl_seconds: TTL 秒数,防止异常情况下的死锁 + cache_reservation_ratio: 缓存预留比例,None 时从配置读取 + ttl_seconds: TTL 秒数,None 时从配置读取 Returns: 是否成功获取(True/False) @@ -209,6 +209,14 @@ class ConcurrencyManager: - 缓存用户最多使用: 10个槽位(全部) - 预留的3个槽位专门给缓存用户,保证他们的请求优先 """ + # 从配置读取默认值 + from src.config.settings import config + + if cache_reservation_ratio is None: + cache_reservation_ratio = config.cache_reservation_ratio + if ttl_seconds is None: + ttl_seconds = config.concurrency_slot_ttl + if self._redis is None: async with self._memory_lock: endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0) @@ -426,7 +434,7 @@ class ConcurrencyManager: key_id: str, key_max_concurrent: Optional[int], is_cached_user: bool = False, # 新增:是否是缓存用户 - cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例 + cache_reservation_ratio: 
Optional[float] = None, # 缓存预留比例,None 时从配置读取 ): """ 并发控制上下文管理器(支持缓存用户优先级) @@ -441,6 +449,12 @@ class ConcurrencyManager: 如果获取失败,会抛出 ConcurrencyLimitError 异常 """ + # 从配置读取默认值 + from src.config.settings import config + + if cache_reservation_ratio is None: + cache_reservation_ratio = config.cache_reservation_ratio + # 尝试获取槽位(传递缓存用户参数) acquired = await self.acquire_slot( endpoint_id,
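
Usage sketch (illustrative only, not part of the patch): the snippet below walks the new `StreamContext` through a single streaming attempt using only the methods defined in `stream_context.py` above; the model name, IDs, token counts and timings are made-up placeholder values.

```python
from src.api.handlers.base.stream_context import StreamContext

# Context is created once per request with the client-facing model and API format.
ctx = StreamContext(model="claude-3-5-sonnet", api_format="anthropic")

# Populated once the router has picked a provider/endpoint/key for this attempt
# (values here are placeholders).
ctx.update_provider_info(
    provider_name="example-provider",
    provider_id="prov-1",
    endpoint_id="ep-1",
    key_id="key-1",
    provider_api_format="anthropic",
)

# StreamProcessor.handle_sse_event() feeds token usage into the context as SSE
# chunks arrive; None values leave the existing counters untouched.
ctx.update_usage(input_tokens=120, output_tokens=42, cached_tokens=0)

# After the stream ends, StreamTelemetryRecorder relies on these helpers to build
# the Usage record and the one-line request summary.
response_body = ctx.build_response_body(response_time_ms=850)
print(ctx.get_log_summary(request_id="0123456789abcdef", response_time_ms=850))
# e.g. "[OK] 01234567 | claude-3-5-sonnet | example-provider | 850ms | in:120 out:42"

# On failover, ChatHandlerBase calls ctx.reset_for_retry() before the next attempt,
# clearing per-attempt tokens/chunks/status while keeping model and api_format.
```

The timeout, prefetch and concurrency defaults used above in `chat_handler_base.py` and `concurrency_manager.py` can now be overridden via the environment variables added in `settings.py` (`HTTP_CONNECT_TIMEOUT`, `HTTP_WRITE_TIMEOUT`, `HTTP_POOL_TIMEOUT`, `STREAM_PREFETCH_LINES`, `STREAM_STATS_DELAY`, `CONCURRENCY_SLOT_TTL`, `CACHE_RESERVATION_RATIO`).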