"""
CLI Message Handler 通用基类
CLI 格式处理器的通用逻辑HTTP 请求SSE 解析统计记录抽取到基类
子类只需实现格式特定的事件解析逻辑
设计目标
1. 减少代码重复 - 原来每个 CLI Handler 900+ 抽取后子类只需 ~100
2. 统一错误处理 - 超时空流故障转移等逻辑集中管理
3. 简化新格式接入 - 只需实现 ResponseParser 和少量钩子方法
"""
import asyncio
import json
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
Tuple,
)
import httpx
from fastapi import BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import (
BaseMessageHandler,
MessageTelemetry,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
# Import from the concrete module directly to avoid circular imports
from src.api.handlers.base.response_parser import (
ResponseParser,
StreamStats,
)
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
)
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
User,
)
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamContext:
"""流式请求的上下文信息"""
# 请求信息
model: str = "unknown" # 用户请求的原始模型名
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
api_format: str = ""
request_id: str = ""
# 用户信息(提前提取避免 Session detached
user_id: int = 0
api_key_id: int = 0
# 统计信息
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0 # cache_read_input_tokens
cache_creation_tokens: int = 0 # cache_creation_input_tokens
collected_text: str = ""
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
parsed_chunks: list = field(default_factory=list)
# 流状态
start_time: float = field(default_factory=time.time)
chunk_count: int = 0
data_count: int = 0
has_completion: bool = False
# 响应信息
status_code: int = 200
response_headers: Dict[str, str] = field(default_factory=dict)
# 请求信息(发送给 Provider 的)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
# Provider 信息
provider_name: Optional[str] = None
provider_id: Optional[str] = None # Provider ID用于记录真实成本
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
error_message: Optional[str] = None
# 格式转换信息
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
client_api_format: str = "" # 客户端请求的 API 格式
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion
response_metadata: Dict[str, Any] = field(default_factory=dict)
class CliMessageHandlerBase(BaseMessageHandler):
"""
CLI 格式消息处理器基类
提供 CLI 格式直接透传请求的通用处理逻辑
- 流式请求的 HTTP 连接管理
- SSE 事件解析框架
- 统计信息收集和记录
- 错误处理和故障转移
子类需要实现
- get_response_parser(): 返回格式特定的响应解析器
- 可选覆盖 handle_sse_event() 自定义事件处理
"""
# 子类可覆盖的配置
FORMAT_ID: str = "UNKNOWN" # API 格式标识
DATA_TIMEOUT: int = 30 # 流数据超时时间(秒)
EMPTY_CHUNK_THRESHOLD: int = 10 # 空流检测的 chunk 阈值
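# A minimal sketch (class name and format id are hypothetical, not part of this module)
# of what a concrete subclass typically looks like:
#
#     class MyCliMessageHandler(CliMessageHandlerBase):
#         FORMAT_ID = "MYFORMAT"
#
#         def get_response_parser(self) -> ResponseParser:
#             # assumes a parser for "MYFORMAT" is registered in the parser registry
#             return get_parser_for_format(self.FORMAT_ID)
#
#         def _process_event_data(self, ctx, event_type, data):
#             # add format-specific extraction on top of the default logic
#             super()._process_event_data(ctx, event_type, data)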
def __init__(
self,
db: Session,
user: User,
api_key: ApiKey,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
super().__init__(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=allowed,
adapter_detector=adapter_detector,
)
self._parser: Optional[ResponseParser] = None
self._request_builder = PassthroughRequestBuilder()
@property
def parser(self) -> ResponseParser:
"""获取响应解析器(懒加载)"""
if self._parser is None:
self._parser = self.get_response_parser()
return self._parser
def get_response_parser(self) -> ResponseParser:
"""
获取格式特定的响应解析器
子类可覆盖此方法提供自定义解析器
默认从解析器注册表获取
"""
return get_parser_for_format(self.FORMAT_ID)
async def _get_mapped_model(
self,
source_model: str,
provider_id: str,
) -> Optional[str]:
"""
获取模型映射后的实际模型名
按优先级查找映射 别名 直接匹配 GlobalModel
Args:
source_model: 用户请求的模型名可能是别名
provider_id: Provider ID
Returns:
映射后的 provider_model_name如果没有找到映射则返回 None
"""
from src.services.model.mapper import ModelMapperMiddleware
mapper = ModelMapperMiddleware(self.db)
mapping = await mapper.get_mapping(source_model, provider_id)
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
if mapping and mapping.model:
mapped_name = str(mapping.model.provider_model_name)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}")
return None
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,  # noqa: ARG002 - used by subclasses
) -> str:
"""
从请求中提取模型名 - 子类可覆盖
不同 API 格式的 model 位置不同
- OpenAI/Claude: 在请求体中 request_body["model"]
- Gemini: URL 路径中 path_params["model"]
子类应覆盖此方法实现各自的提取逻辑
Args:
request_body: 请求体
path_params: URL 路径参数
Returns:
模型名如果无法提取则返回 "unknown"
"""
# 默认实现:从请求体获取
model = request_body.get("model")
return str(model) if model else "unknown"
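# A minimal sketch (hypothetical override, not shipped here) for a format that carries
# the model in the URL path, as described in the docstring above:
#
#     def extract_model_from_request(self, request_body, path_params=None):
#         if path_params and path_params.get("model"):
#             return str(path_params["model"])
#         return super().extract_model_from_request(request_body, path_params)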
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,  # noqa: ARG002 - used by subclasses
) -> Dict[str, Any]:
"""
将映射后的模型名应用到请求体
基类默认实现不修改请求体保持原样透传
子类应覆盖此方法实现各自的模型名替换逻辑
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名子类使用
Returns:
请求体默认不修改
"""
# 基类不修改请求体,子类覆盖此方法实现特定格式的处理
return request_body
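# A minimal sketch (hypothetical override) for formats that carry the model in the
# request body, e.g. OpenAI/Claude-style payloads:
#
#     def apply_mapped_model(self, request_body, mapped_model):
#         body = dict(request_body)  # avoid mutating the caller's dict
#         body["model"] = mapped_model
#         return body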
def prepare_provider_request_body(
self,
request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
准备发送给 Provider 的请求体 - 子类可覆盖
在模型映射之后发送请求之前调用用于移除不需要发送给上游的字段
例如 Gemini API 需要移除请求体中的 model 字段因为 model URL 路径中
Args:
request_body: 经过模型映射处理后的请求体
Returns:
准备好的请求体
"""
return request_body
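# A minimal sketch (hypothetical override) of stripping a field that must not be
# forwarded upstream, per the Gemini example in the docstring:
#
#     def prepare_provider_request_body(self, request_body):
#         body = dict(request_body)
#         body.pop("model", None)  # the model travels in the URL path for this format
#         return body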
def get_model_for_url(
self,
request_body: Dict[str, Any],
mapped_model: Optional[str],
) -> Optional[str]:
"""
获取用于 URL 路径的模型名
某些 API 格式 Gemini需要将 model 放入 URL 路径中
子类应覆盖此方法返回正确的值
Args:
request_body: 请求体
mapped_model: 映射后的模型名如果有
Returns:
用于 URL 路径的模型名默认优先使用映射后的名称
"""
return mapped_model or request_body.get("model")
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
从响应中提取 Provider 特有的元数据 - 子类可覆盖
例如 Gemini 返回的 modelVersion 字段
这些元数据会存储到 Usage.request_metadata
Args:
response: Provider 返回的响应
Returns:
元数据字典默认为空
"""
return {}
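# A minimal sketch (hypothetical override) for a provider that reports a model version
# in its responses, as mentioned in the docstring:
#
#     def _extract_response_metadata(self, response):
#         metadata = {}
#         if isinstance(response, dict) and response.get("modelVersion"):
#             metadata["modelVersion"] = response["modelVersion"]
#         return metadata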
async def process_stream(
self,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> StreamingResponse:
"""
处理流式请求
通用流程
1. 创建流上下文
2. 定义请求函数 FallbackOrchestrator 调用
3. 执行请求并返回 StreamingResponse
4. 后台任务记录统计信息
"""
logger.debug(f"开始流式响应处理 ({self.FORMAT_ID})")
# Use the subclass hook to extract the model (its location differs across API formats)
model = self.extract_model_from_request(original_request_body, path_params)
# Create the stream context
ctx = StreamContext(
model=model,
api_format=self.allowed_api_formats[0],
request_id=self.request_id,
user_id=self.user.id,
api_key_id=self.api_key.id,
)
# Define the request function
async def stream_request_func(
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
provider,
endpoint,
key,
original_request_body,
original_headers,
query_params,
)
try:
# Resolve capability requirements
capability_requirements = self._resolve_capability_requirements(
model_name=ctx.model,
request_headers=original_headers,
)
# Execute the request (via the FallbackOrchestrator)
(
stream_generator,
provider_name,
attempt_id,
_provider_id,
_endpoint_id,
_key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=ctx.api_format,
model_name=ctx.model,
user_api_key=self.api_key,
request_func=stream_request_func,
request_id=self.request_id,
is_stream=True,
capability_requirements=capability_requirements or None,
)
ctx.attempt_id = attempt_id
# Record statistics in a background task
background_tasks = BackgroundTasks()
background_tasks.add_task(
self._record_stream_stats,
ctx,
original_headers,
original_request_body,
)
# Wrap the stream with monitoring
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
background=background_tasks,
)
except Exception as e:
# For known business exceptions, log a concise message without the full traceback
from src.core.exceptions import (
ProviderException,
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
)
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
# Business exception: log a concise message
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
else:
# Unknown exception: log the full traceback
logger.exception(f"流式请求失败: {e}")
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise
async def _execute_stream_request(
self,
ctx: StreamContext,
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积)
ctx.parsed_chunks = []
ctx.chunk_count = 0
ctx.data_count = 0
ctx.has_completion = False
ctx.collected_text = ""
ctx.input_tokens = 0
ctx.output_tokens = 0
ctx.cached_tokens = 0
ctx.cache_creation_tokens = 0
ctx.final_usage = None
ctx.final_response = None
ctx.response_id = None
ctx.response_metadata = {}  # Reset provider response metadata
# Record provider info
ctx.provider_name = str(provider.name)
ctx.provider_id = str(provider.id)
ctx.endpoint_id = str(endpoint.id)
ctx.key_id = str(key.id)
# Record format-conversion info
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
ctx.client_api_format = ctx.api_format  # Already set in process_stream
# Resolve the model mapping (alias/mapping -> actual model name)
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# Apply the model mapping to the request body (subclasses may override for other formats)
if mapped_model:
ctx.mapped_model = mapped_model  # Keep the mapped model name for the Usage record
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = original_request_body
# Prepare the body to send to the provider (subclasses may strip unwanted fields)
request_body = self.prepare_provider_request_body(request_body)
# Build the request body and headers via the RequestBuilder
# Note: mapped_model has already been applied to request_body, so it is not passed here
provider_payload, provider_headers = self._request_builder.build(
request_body,
original_headers,
endpoint,
key,
is_stream=True,
)
# Save the request info sent to the provider (for debugging and statistics)
ctx.provider_request_headers = provider_headers
ctx.provider_request_body = provider_payload
# Get the model name for the URL (subclasses may override, e.g. Gemini needs special handling)
# Use ctx.model as a fallback (it was already taken from path_params)
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
url = build_provider_url(
endpoint,
query_params=query_params,
path_params={"model": url_model},
is_stream=True,  # The CLI handler issues streaming requests
)
# Configure timeouts
timeout_config = httpx.Timeout(
connect=10.0,
read=float(endpoint.timeout),
write=10.0,
pool=10.0,
)
logger.debug(f" └─ [{self.request_id}] 发送流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
)
stream_response = await response_ctx.__aenter__()
ctx.status_code = stream_response.status_code
ctx.response_headers = dict(stream_response.headers)
logger.debug(f" └─ 收到响应: status={stream_response.status_code}")
stream_response.raise_for_status()
# Create the line iterator (created once and reused for the rest of the stream)
line_iterator = stream_response.aiter_lines()
# Prefetch the first chunk to detect embedded errors (HTTP 200 with an error in the body)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
)
except httpx.HTTPStatusError as e:
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
raise
except EmbeddedErrorException:
# An embedded error must trigger a retry: close the connection and re-raise
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
raise
except Exception:
await http_client.aclose()
raise
# Create the stream generator (with prefetched data, reusing the same iterator)
return self._create_response_stream_with_prefetch(
ctx,
line_iterator,
response_ctx,
http_client,
prefetched_lines,
)
async def _create_response_stream(
self,
ctx: StreamContext,
stream_response: httpx.Response,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
# Check whether format conversion is needed
needs_conversion = self._needs_format_conversion(ctx)
async for line in stream_response.aiter_lines():
# Update the usage status to streaming before the first data is emitted
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# Empty-stream detection: past the threshold with no data, emit an error event and stop
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return  # End the generator
# Convert the format or pass through unchanged
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# Flush any remaining events
for event in sse_parser.flush():
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# Check whether any data was received
if ctx.data_count == 0:
# The stream has already started, so we cannot raise an exception for failover;
# emit an error event and log it instead
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应")
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
logger.debug("流式数据转发完成")
except GeneratorExit:
raise
except httpx.StreamClosed:
if ctx.data_count == 0:
# The stream has already started; emit an error event instead of raising
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError as e:
if ctx.data_count > 0:
error_event = {
"type": "error",
"error": {
"type": "connection_error",
"message": "上游连接意外关闭,部分响应已成功传输",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
) -> list:
"""
预读流的前几行检测嵌套错误
某些 Provider Gemini可能返回 HTTP 200但在响应体中包含错误信息
这种情况需要在流开始输出之前检测以便触发重试逻辑
同时检测 HTML 响应通常是 base_url 配置错误导致返回网页
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流上下文
Returns:
预读的行列表需要在后续流中先输出
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应配置错误
"""
prefetched_lines: list = []
max_prefetch_lines = 5  # Prefetch at most 5 lines to detect errors
try:
# Get the parser for the provider's format
provider_format = ctx.provider_api_format
if provider_format:
try:
provider_parser = get_parser_for_format(provider_format)
except KeyError:
provider_parser = self.parser
else:
provider_parser = self.parser
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
# Parse the data
normalized_line = line.rstrip("\r")
# Detect HTML responses (a common symptom of a misconfigured base_url)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# Blank line or comment line: keep prefetching
if line_count >= max_prefetch_lines:
break
continue
# Try to parse the SSE data
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# Not valid JSON (possibly partial data), keep going
if line_count >= max_prefetch_lines:
break
continue
# Use the parser to check whether this is an error response
if isinstance(data, dict) and provider_parser.is_error_response(data):
# Extract the error details
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# Valid data prefetched with no error; stop prefetching
break
except EmbeddedErrorException:
# Re-raise embedded errors
raise
except Exception as e:
# Other exceptions during prefetch (e.g. network errors) are logged but do not abort
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
async def _create_response_stream_with_prefetch(
self,
ctx: StreamContext,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
# Check whether format conversion is needed
needs_conversion = self._needs_format_conversion(ctx)
# Update the usage status to streaming before the first data is emitted
if prefetched_lines:
self._update_usage_to_streaming(ctx.request_id)
# Emit the prefetched lines first
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# Convert the format or pass through unchanged
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# Continue with the remaining stream data (reusing the same iterator)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# Empty-stream detection: past the threshold with no data, emit an error event and stop
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# Convert the format or pass through unchanged
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# Flush any remaining events
flushed_events = sse_parser.flush()
for event in flushed_events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# Check whether any data was received
if ctx.data_count == 0:
# An empty stream usually means a configuration error (e.g. base_url pointing at a web page instead of an API)
logger.error(
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
f"可能是 endpoint base_url 配置错误"
)
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
logger.debug("流式数据转发完成")
except GeneratorExit:
raise
except httpx.StreamClosed:
if ctx.data_count == 0:
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError:
if ctx.data_count > 0:
error_event = {
"type": "error",
"error": {
"type": "connection_error",
"message": "上游连接意外关闭,部分响应已成功传输",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
def _handle_sse_event(
self,
ctx: StreamContext,
event_name: Optional[str],
data_str: str,
) -> None:
"""
处理 SSE 事件
通用框架解析 JSON更新计数器
子类可覆盖 _process_event_data() 实现格式特定逻辑
"""
if not data_str:
return
if data_str == "[DONE]":
ctx.has_completion = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx.data_count += 1
ctx.parsed_chunks.append(data)
if not isinstance(data, dict):
return
event_type = event_name or data.get("type", "")
# Invoke the format-specific handling
self._process_event_data(ctx, event_type, data)
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
处理解析后的事件数据 - 子类应覆盖此方法
默认实现使用 ResponseParser 提取 usage
"""
# 提取 response_id
if not ctx.response_id:
response_obj = data.get("response")
if isinstance(response_obj, dict) and response_obj.get("id"):
ctx.response_id = response_obj["id"]
elif "id" in data:
ctx.response_id = data["id"]
# Extract usage with the parser
usage = self.parser.extract_usage_from_response(data)
if usage and not ctx.final_usage:
ctx.input_tokens = usage.get("input_tokens", 0)
ctx.output_tokens = usage.get("output_tokens", 0)
ctx.cached_tokens = usage.get("cache_read_tokens", 0)
ctx.final_usage = usage
# Extract text content
text = self.parser.extract_text_content(data)
if text:
ctx.collected_text += text
# Check for a completion event
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
response_obj = data.get("response")
if isinstance(response_obj, dict):
ctx.final_response = response_obj
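# A minimal sketch (hypothetical override; event names depend on the concrete format)
# of layering format-specific handling on top of the default logic:
#
#     def _process_event_data(self, ctx, event_type, data):
#         super()._process_event_data(ctx, event_type, data)
#         if event_type == "message_delta":
#             usage = data.get("usage") or {}
#             ctx.output_tokens = usage.get("output_tokens", ctx.output_tokens)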
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
在记录统计前从 parsed_chunks 中提取额外的元数据 - 子类可覆盖
这是一个后处理钩子在流传输完成后记录 Usage 之前调用
子类可以覆盖此方法从 ctx.parsed_chunks 中提取格式特定的元数据
Gemini modelVersiontoken 统计等
Args:
ctx: 流上下文包含 parsed_chunks response_metadata
"""
pass
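# A minimal sketch (hypothetical override) of pulling metadata out of the collected
# chunks once the stream has finished:
#
#     def _finalize_stream_metadata(self, ctx):
#         for chunk in ctx.parsed_chunks:
#             if isinstance(chunk, dict) and chunk.get("modelVersion"):
#                 ctx.response_metadata["modelVersion"] = chunk["modelVersion"]
#                 break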
async def _create_monitored_stream(
self,
ctx: StreamContext,
stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
try:
async for chunk in stream_generator:
yield chunk
except asyncio.CancelledError:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
raise
except httpx.TimeoutException as e:
ctx.status_code = 504
ctx.error_message = str(e)
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def _record_stream_stats(
self,
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""在流完成后记录统计信息"""
try:
await asyncio.sleep(0.1)
response_time_ms = int((time.time() - ctx.start_time) * 1000)
if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
return
# Claude's input_tokens already excludes the cached portion, so do not subtract cached_tokens again.
# Billable input tokens = input_tokens + cache_creation_tokens (cache reads are free or discounted).
actual_input_tokens = ctx.input_tokens
# Get a fresh DB session
db_gen = get_db()
bg_db = next(db_gen)
try:
from src.models.database import ApiKey as ApiKeyModel
user = bg_db.query(User).filter(User.id == ctx.user_id).first()
api_key = bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == ctx.api_key_id).first()
if not user or not api_key:
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key, ctx.request_id, self.client_ip
)
response_body = {
"chunks": ctx.parsed_chunks,
"metadata": {
"stream": True,
"total_chunks": len(ctx.parsed_chunks),
"data_count": ctx.data_count,
"has_completion": ctx.has_completion,
"response_time_ms": response_time_ms,
},
}
# Prefer the request body actually sent to the provider; fall back to the original
actual_request_body = ctx.provider_request_body or original_request_body
# Decide whether to record success or failure based on the status code
# (499 = client disconnected, 503 = service unavailable, e.g. an interrupted stream)
if ctx.status_code and ctx.status_code >= 400:
# Record a failed Usage entry, but keep the estimated token info already received (from message_start)
# so that an interrupted request still records an estimated cost
await bg_telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=ctx.status_code,
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
# Estimated token info (from the message_start event)
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
response_body=response_body,
# Model mapping info
target_model=ctx.mapped_model,
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# Concise failure summary (including estimated token info)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}")
else:
# Before recording, allow subclasses to extract extra metadata from parsed_chunks
self._finalize_stream_metadata(ctx)
total_cost = await bg_telemetry.record_success(
provider=ctx.provider_name,
model=ctx.model,
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx.response_headers,
response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
is_stream=True,
provider_request_headers=ctx.provider_request_headers,
api_format=ctx.api_format,
# Provider-side tracking info (used to record real cost)
provider_id=ctx.provider_id,
provider_endpoint_id=ctx.endpoint_id,
provider_api_key_id=ctx.key_id,
# Model mapping info
target_model=ctx.mapped_model,
# Provider response metadata (e.g. Gemini's modelVersion)
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# Concise completion summary
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
# Update the candidate record's final status and latency.
# Note: RequestExecutor marks success too early, when the stream starts (it only measures connection setup);
# overwrite it here with the actual time after the stream has finished.
if ctx.attempt_id:
from src.services.request.candidate import RequestCandidateService
# Decide success vs failure based on the status code:
# 499 = client disconnected, should be marked failed
# 503 = service unavailable (e.g. an interrupted stream), should be marked failed
if ctx.status_code and ctx.status_code >= 400:
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx.attempt_id,
error_type="client_disconnected" if ctx.status_code == 499 else "stream_error",
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"chunk_count": ctx.chunk_count,
"data_count": ctx.data_count,
},
)
else:
RequestCandidateService.mark_candidate_success(
db=bg_db,
candidate_id=ctx.attempt_id,
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"chunk_count": ctx.chunk_count,
"data_count": ctx.data_count,
},
)
finally:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
async def _record_stream_failure(
self,
ctx: StreamContext,
error: Exception,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""记录流式请求失败"""
response_time_ms = int((time.time() - ctx.start_time) * 1000)
status_code = 503
if isinstance(error, ProviderAuthException):
status_code = 503
elif isinstance(error, ProviderRateLimitException):
status_code = 429
elif isinstance(error, ProviderTimeoutException):
status_code = 504
ctx.status_code = status_code
ctx.error_message = str(error)
# Prefer the request body actually sent to the provider; fall back to the original
actual_request_body = ctx.provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
# Model mapping info
target_model=ctx.mapped_model,
)
# _update_usage_to_streaming has been moved to the base class BaseMessageHandler
async def process_sync(
self,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> JSONResponse:
"""
处理非流式请求
通用流程
1. 构建请求
2. 通过 FallbackOrchestrator 执行
3. 解析响应并记录统计
"""
logger.debug(f"开始非流式响应处理 ({self.FORMAT_ID})")
# Use the subclass hook to extract the model (its location differs across API formats)
model = self.extract_model_from_request(original_request_body, path_params)
api_format = self.allowed_api_formats[0]
sync_start_time = time.time()
provider_name = None
response_json = None
status_code = 200
response_headers = {}
provider_api_format = "" # 用于追踪 Provider 的 API 格式
provider_request_headers = {} # 发送给 Provider 的请求头
provider_request_body = None # 实际发送给 Provider 的请求体
provider_id = None # Provider ID用于失败记录
endpoint_id = None # Endpoint ID用于失败记录
key_id = None # Key ID用于失败记录
mapped_model_result = None # 映射后的目标模型名(用于 Usage 记录)
response_metadata_result: Dict[str, Any] = {} # Provider 响应元数据
async def sync_request_func(
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
provider_name = str(provider.name)
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
# Resolve the model mapping (alias/mapping -> actual model name)
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# Apply the model mapping to the request body (subclasses may override for other formats)
if mapped_model:
mapped_model_result = mapped_model  # Keep the mapped model name for the Usage record
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = original_request_body
# Prepare the body to send to the provider (subclasses may strip unwanted fields)
request_body = self.prepare_provider_request_body(request_body)
# Build the request body and headers via the RequestBuilder
# Note: mapped_model has already been applied to request_body, so it is not passed here
provider_payload, provider_headers = self._request_builder.build(
request_body,
original_headers,
endpoint,
key,
is_stream=False,
)
# Save the request info sent to the provider (for debugging and statistics)
provider_request_headers = provider_headers
provider_request_body = provider_payload
# Get the model name for the URL (subclasses may override, e.g. Gemini needs special handling)
url_model = self.get_model_for_url(request_body, mapped_model)
url = build_provider_url(
endpoint,
query_params=query_params,
path_params={"model": url_model},
is_stream=False,  # Non-streaming request
)
logger.info(f" └─ [{self.request_id}] 发送非流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
async with httpx.AsyncClient(
timeout=float(endpoint.timeout),
follow_redirects=True,
) as http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
)
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
)
# Parse the JSON response defensively, handling possible encoding errors
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Log the raw response info for debugging
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes")
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
)
# Extract provider response metadata (subclasses may override)
response_metadata_result = self._extract_response_metadata(response_json)
return response_json if isinstance(response_json, dict) else {}
try:
# Resolve capability requirements
capability_requirements = self._resolve_capability_requirements(
model_name=model,
request_headers=original_headers,
)
(
result,
actual_provider_name,
attempt_id,
provider_id,
endpoint_id,
key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=api_format,
model_name=model,
user_api_key=self.api_key,
request_func=sync_request_func,
request_id=self.request_id,
capability_requirements=capability_requirements or None,
)
provider_name = actual_provider_name
response_time_ms = int((time.time() - sync_start_time) * 1000)
# Make sure response_json is not None
if response_json is None:
response_json = {}
# Check whether format conversion is needed
if (
provider_api_format
and api_format
and provider_api_format.upper() != api_format.upper()
):
from src.api.handlers.base.format_converter_registry import converter_registry
try:
response_json = converter_registry.convert_response(
response_json,
provider_api_format,
api_format,
)
logger.debug(f"非流式响应格式转换完成: {provider_api_format} -> {api_format}")
except Exception as conv_err:
logger.warning(f"非流式响应格式转换失败,使用原始响应: {conv_err}")
# Extract usage with the parser
usage = self.parser.extract_usage_from_response(response_json)
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cached_tokens = usage.get("cache_read_tokens", 0)
cache_creation_tokens = usage.get("cache_creation_tokens", 0)
# Claude's input_tokens already excludes the cached portion, so do not subtract cached_tokens again
actual_input_tokens = input_tokens
output_text = self.parser.extract_text_content(response_json)[:200]
# Prefer the request body actually sent to the provider; fall back to the original
actual_request_body = provider_request_body or original_request_body
total_cost = await self.telemetry.record_success(
provider=provider_name,
model=model,
input_tokens=actual_input_tokens,
output_tokens=output_tokens,
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=response_headers,
response_body=response_json,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens,
is_stream=False,
provider_request_headers=provider_request_headers,
api_format=api_format,
# Provider-side tracking info (used to record real cost)
provider_id=provider_id,
provider_endpoint_id=endpoint_id,
provider_api_key_id=key_id,
# Model mapping info
target_model=mapped_model_result,
# Provider response metadata (e.g. Gemini's modelVersion)
response_metadata=response_metadata_result if response_metadata_result else None,
)
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
return JSONResponse(status_code=status_code, content=response_json)
except Exception as e:
response_time_ms = int((time.time() - sync_start_time) * 1000)
status_code = 503
if isinstance(e, ProviderAuthException):
status_code = 503
elif isinstance(e, ProviderRateLimitException):
status_code = 429
elif isinstance(e, ProviderTimeoutException):
status_code = 504
# Prefer the request body actually sent to the provider; fall back to the original
actual_request_body = provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=provider_name or "unknown",
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
api_format=api_format,
provider_request_headers=provider_request_headers,
# Model mapping info
target_model=mapped_model_result,
)
raise
async def _extract_error_text(self, e: httpx.HTTPStatusError) -> str:
"""从 HTTP 错误中提取错误文本"""
try:
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
error_bytes = await e.response.aread()
for encoding in ["utf-8", "gbk", "latin1"]:
try:
return error_bytes.decode(encoding)[:500]
except (UnicodeDecodeError, LookupError):
continue
return error_bytes.decode("utf-8", errors="replace")[:500]
else:
return (
e.response.text[:500]
if hasattr(e.response, "_content")
else "Unable to read response"
)
except Exception as decode_error:
return f"Unable to read error response: {decode_error}"
def _needs_format_conversion(self, ctx: StreamContext) -> bool:
"""
检查是否需要进行格式转换
Provider API 格式与客户端请求的 API 格式不同时需要转换响应
例如客户端请求 Claude 格式 Provider 返回 OpenAI 格式
"""
if not ctx.provider_api_format or not ctx.client_api_format:
return False
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
def _convert_sse_line(
self,
ctx: StreamContext,
line: str,
events: list,
) -> Optional[str]:
"""
SSE 行从 Provider 格式转换为客户端格式
Args:
ctx: 流上下文
line: 原始 SSE
events: 解析后的事件列表
Returns:
转换后的 SSE 如果无法转换则返回 None
"""
from src.api.handlers.base.format_converter_registry import converter_registry
# Pass blank lines and special control lines through unchanged
if not line or line.strip() == "" or line == "data: [DONE]":
return line
# Pass non-data lines through unchanged
if not line.startswith("data:"):
return line
# Extract the data payload
data_content = line[5:].strip()  # Strip the "data:" prefix
# Try to parse the JSON
try:
data_obj = json.loads(data_content)
except json.JSONDecodeError:
# Cannot parse; pass through unchanged
return line
# Convert via the converter registry
try:
converted_obj = converter_registry.convert_stream_chunk(
data_obj,
ctx.provider_api_format,
ctx.client_api_format,
)
# Rebuild the SSE line
return f"data: {json.dumps(converted_obj, ensure_ascii=False)}"
except Exception as e:
logger.warning(f"格式转换失败,透传原始数据: {e}")
return line
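# A usage sketch (hypothetical wiring; MyCliMessageHandler is not part of this module)
# showing how a route layer typically drives a concrete handler:
#
#     handler = MyCliMessageHandler(
#         db=db, user=user, api_key=api_key, request_id=request_id,
#         client_ip=client_ip, user_agent=user_agent, start_time=time.time(),
#     )
#     if request_body.get("stream"):
#         return await handler.process_stream(request_body, dict(request.headers))
#     return await handler.process_sync(request_body, dict(request.headers))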