Files
Aether/src/api/handlers/base/cli_handler_base.py
fawney19 97425ac68f refactor: make stream smoothing parameters configurable and add models cache invalidation
- Move stream smoothing parameters (chunk_size, delay_ms) to database config
- Remove hardcoded stream smoothing constants from StreamProcessor
- Simplify dynamic delay calculation by using config values directly
- Add invalidate_models_list_cache() function to clear /v1/models endpoint cache
- Call cache invalidation on model create, update, delete, and bulk operations
- Update admin UI to allow runtime configuration of smoothing parameters
- Improve model listing freshness when models are modified
2025-12-19 11:03:46 +08:00

1691 lines
68 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
CLI Message Handler 通用基类
将 CLI 格式处理器的通用逻辑HTTP 请求、SSE 解析、统计记录)抽取到基类,
子类只需实现格式特定的事件解析逻辑。
设计目标:
1. 减少代码重复 - 原来每个 CLI Handler 900+ 行,抽取后子类只需 ~100 行
2. 统一错误处理 - 超时、空流、故障转移等逻辑集中管理
3. 简化新格式接入 - 只需实现 ResponseParser 和少量钩子方法
"""
import asyncio
import codecs
import json
import time
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Optional,
)
import httpx
from fastapi import BackgroundTasks
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import (
BaseMessageHandler,
MessageTelemetry,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
ResponseParser,
StreamStats,
)
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
)
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
User,
)
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
class CliMessageHandlerBase(BaseMessageHandler):
"""
CLI 格式消息处理器基类
提供 CLI 格式(直接透传请求)的通用处理逻辑:
- 流式请求的 HTTP 连接管理
- SSE 事件解析框架
- 统计信息收集和记录
- 错误处理和故障转移
子类需要实现:
- get_response_parser(): 返回格式特定的响应解析器
- 可选覆盖 handle_sse_event() 自定义事件处理
"""
# 子类可覆盖的配置
FORMAT_ID: str = "UNKNOWN" # API 格式标识
DATA_TIMEOUT: int = 30 # 流数据超时时间(秒)
EMPTY_CHUNK_THRESHOLD: int = 10 # 空流检测的 chunk 阈值
def __init__(
self,
db: Session,
user: User,
api_key: ApiKey,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
super().__init__(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=allowed,
adapter_detector=adapter_detector,
)
self._parser: Optional[ResponseParser] = None
self._request_builder = PassthroughRequestBuilder()
@property
def parser(self) -> ResponseParser:
"""获取响应解析器(懒加载)"""
if self._parser is None:
self._parser = self.get_response_parser()
return self._parser
def get_response_parser(self) -> ResponseParser:
"""
获取格式特定的响应解析器
子类可覆盖此方法提供自定义解析器,
默认从解析器注册表获取
"""
return get_parser_for_format(self.FORMAT_ID)
async def _get_mapped_model(
self,
source_model: str,
provider_id: str,
) -> Optional[str]:
"""
获取模型映射后的实际模型名
查找逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
Args:
source_model: 用户请求的模型名(必须是 GlobalModel.name
provider_id: Provider ID
Returns:
映射后的 Provider 模型名,如果没有找到映射则返回 None
"""
from src.services.model.mapper import ModelMapperMiddleware
mapper = ModelMapperMiddleware(self.db)
mapping = await mapper.get_mapping(source_model, provider_id)
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
if mapping and mapping.model:
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key实现相同用户稳定选择同一别名
# 传入 api_format 用于过滤适用的别名作用域
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}")
return None
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002 - 子类使用
) -> str:
"""
从请求中提取模型名 - 子类可覆盖
不同 API 格式的 model 位置不同:
- OpenAI/Claude: 在请求体中 request_body["model"]
- Gemini: 在 URL 路径中 path_params["model"]
子类应覆盖此方法实现各自的提取逻辑。
Args:
request_body: 请求体
path_params: URL 路径参数
Returns:
模型名,如果无法提取则返回 "unknown"
"""
# 默认实现:从请求体获取
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str, # noqa: ARG002 - 子类使用
) -> Dict[str, Any]:
"""
将映射后的模型名应用到请求体
基类默认实现:不修改请求体,保持原样透传。
子类应覆盖此方法实现各自的模型名替换逻辑。
Args:
request_body: 原始请求体
mapped_model: 映射后的模型名(子类使用)
Returns:
请求体(默认不修改)
"""
# 基类不修改请求体,子类覆盖此方法实现特定格式的处理
return request_body
def prepare_provider_request_body(
self,
request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
准备发送给 Provider 的请求体 - 子类可覆盖
在模型映射之后、发送请求之前调用,用于移除不需要发送给上游的字段。
例如 Gemini API 需要移除请求体中的 model 字段(因为 model 在 URL 路径中)。
Args:
request_body: 经过模型映射处理后的请求体
Returns:
准备好的请求体
"""
return request_body
def get_model_for_url(
self,
request_body: Dict[str, Any],
mapped_model: Optional[str],
) -> Optional[str]:
"""
获取用于 URL 路径的模型名
某些 API 格式(如 Gemini需要将 model 放入 URL 路径中。
子类应覆盖此方法返回正确的值。
Args:
request_body: 请求体
mapped_model: 映射后的模型名(如果有)
Returns:
用于 URL 路径的模型名,默认优先使用映射后的名称
"""
return mapped_model or request_body.get("model")
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
从响应中提取 Provider 特有的元数据 - 子类可覆盖
例如 Gemini 返回的 modelVersion 字段。
这些元数据会存储到 Usage.request_metadata 中。
Args:
response: Provider 返回的响应
Returns:
元数据字典,默认为空
"""
return {}
async def process_stream(
self,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> StreamingResponse:
"""
处理流式请求
通用流程:
1. 创建流上下文
2. 定义请求函数(供 FallbackOrchestrator 调用)
3. 执行请求并返回 StreamingResponse
4. 后台任务记录统计信息
"""
logger.debug(f"开始流式响应处理 ({self.FORMAT_ID})")
# 使用子类实现的方法提取 model不同 API 格式的 model 位置不同)
model = self.extract_model_from_request(original_request_body, path_params)
# 创建流上下文
ctx = StreamContext(
model=model,
api_format=self.allowed_api_formats[0],
request_id=self.request_id,
user_id=self.user.id,
api_key_id=self.api_key.id,
)
# 定义请求函数
async def stream_request_func(
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
provider,
endpoint,
key,
original_request_body,
original_headers,
query_params,
)
try:
# 解析能力需求
capability_requirements = self._resolve_capability_requirements(
model_name=ctx.model,
request_headers=original_headers,
)
# 执行请求(通过 FallbackOrchestrator
(
stream_generator,
provider_name,
attempt_id,
_provider_id,
_endpoint_id,
_key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=ctx.api_format,
model_name=ctx.model,
user_api_key=self.api_key,
request_func=stream_request_func,
request_id=self.request_id,
is_stream=True,
capability_requirements=capability_requirements or None,
)
ctx.attempt_id = attempt_id
# 创建后台任务记录统计
background_tasks = BackgroundTasks()
background_tasks.add_task(
self._record_stream_stats,
ctx,
original_headers,
original_request_body,
)
# 创建监控流
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
background=background_tasks,
)
except Exception as e:
self._log_request_error("流式请求失败", e)
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
raise
async def _execute_stream_request(
self,
ctx: StreamContext,
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积)
ctx.parsed_chunks = []
ctx.chunk_count = 0
ctx.data_count = 0
ctx.has_completion = False
ctx._collected_text_parts = [] # 重置文本收集
ctx.input_tokens = 0
ctx.output_tokens = 0
ctx.cached_tokens = 0
ctx.cache_creation_tokens = 0
ctx.final_usage = None
ctx.final_response = None
ctx.response_id = None
ctx.response_metadata = {} # 重置 Provider 响应元数据
# 记录 Provider 信息
ctx.provider_name = str(provider.name)
ctx.provider_id = str(provider.id)
ctx.endpoint_id = str(endpoint.id)
ctx.key_id = str(key.id)
# 记录格式转换信息
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
# 获取模型映射(别名/映射 → 实际模型名)
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
if mapped_model:
ctx.mapped_model = mapped_model # 保存映射后的模型名,用于 Usage 记录
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = original_request_body
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
request_body = self.prepare_provider_request_body(request_body)
# 使用 RequestBuilder 构建请求体和请求头
# 注意mapped_model 已经应用到 request_body这里不再传递
provider_payload, provider_headers = self._request_builder.build(
request_body,
original_headers,
endpoint,
key,
is_stream=True,
)
# 保存发送给 Provider 的请求信息(用于调试和统计)
ctx.provider_request_headers = provider_headers
ctx.provider_request_body = provider_payload
# 获取用于 URL 的模型名(子类可覆盖此方法,如 Gemini 需要特殊处理)
# 使用 ctx.model 作为 fallback它已从 path_params 获取)
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
url = build_provider_url(
endpoint,
query_params=query_params,
path_params={"model": url_model},
is_stream=True, # CLI handler 处理流式请求
)
# 配置超时
timeout_config = httpx.Timeout(
connect=10.0,
read=float(endpoint.timeout),
write=60.0, # 写入超时增加到60秒支持大请求体如包含图片的长对话
pool=10.0,
)
logger.debug(f" └─ [{self.request_id}] 发送流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
)
stream_response = await response_ctx.__aenter__()
ctx.status_code = stream_response.status_code
ctx.response_headers = dict(stream_response.headers)
logger.debug(f" └─ 收到响应: status={stream_response.status_code}")
stream_response.raise_for_status()
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
byte_iterator = stream_response.aiter_raw()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_chunks = await self._prefetch_and_check_embedded_error(
byte_iterator, provider, endpoint, ctx
)
except httpx.HTTPStatusError as e:
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
raise
except EmbeddedErrorException:
# 嵌套错误需要触发重试,关闭连接后重新抛出
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
raise
except Exception:
await http_client.aclose()
raise
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
ctx,
byte_iterator,
response_ctx,
http_client,
prefetched_chunks,
)
async def _create_response_stream(
self,
ctx: StreamContext,
stream_response: httpx.Response,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(使用字节流)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
yield (converted_line + "\n").encode("utf-8")
else:
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 检查是否收到数据
if ctx.data_count == 0:
# 流已开始,无法抛出异常进行故障转移
# 发送错误事件并记录日志
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应")
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
logger.debug("流式数据转发完成")
except GeneratorExit:
raise
except httpx.StreamClosed:
if ctx.data_count == 0:
# 流已开始,发送错误事件而不是抛出异常
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError as e:
if ctx.data_count > 0:
error_event = {
"type": "error",
"error": {
"type": "connection_error",
"message": "上游连接意外关闭,部分响应已成功传输",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
async def _prefetch_and_check_embedded_error(
self,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
Args:
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流上下文
Returns:
预读的字节块列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
"""
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
buffer = b""
line_count = 0
should_stop = False
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
# 获取对应格式的解析器
provider_format = ctx.provider_api_format
if provider_format:
try:
provider_parser = get_parser_for_format(provider_format)
except KeyError:
provider_parser = self.parser
else:
provider_parser = self.parser
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
# 尝试按行解析缓冲区
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
normalized_line = line.rstrip("\r")
# 检测 HTML 响应base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
raise
except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_chunks
async def _create_response_stream_with_prefetch(
self,
ctx: StreamContext,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_chunks: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据,使用字节流)"""
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# 继续处理剩余的流数据(使用同一个迭代器)
async for chunk in byte_iterator:
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield b"\n"
continue
ctx.chunk_count += 1
# 空流检测:超过阈值且无数据,发送错误事件并结束
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
# 格式转换或直接透传
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
if ctx.data_count > 0:
last_data_time = time.time()
# 处理剩余事件
flushed_events = sse_parser.flush()
for event in flushed_events:
self._handle_sse_event(
ctx,
event.get("event"),
event.get("data") or "",
)
# 检查是否收到数据
if ctx.data_count == 0:
# 空流通常意味着配置错误(如 base_url 指向了网页而非 API
logger.error(
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
f"可能是 endpoint base_url 配置错误"
)
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
logger.debug("流式数据转发完成")
except GeneratorExit:
raise
except httpx.StreamClosed:
if ctx.data_count == 0:
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError:
if ctx.data_count > 0:
error_event = {
"type": "error",
"error": {
"type": "connection_error",
"message": "上游连接意外关闭,部分响应已成功传输",
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
else:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
def _handle_sse_event(
self,
ctx: StreamContext,
event_name: Optional[str],
data_str: str,
) -> None:
"""
处理 SSE 事件
通用框架:解析 JSON、更新计数器
子类可覆盖 _process_event_data() 实现格式特定逻辑
"""
if not data_str:
return
if data_str == "[DONE]":
ctx.has_completion = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx.data_count += 1
ctx.parsed_chunks.append(data)
if not isinstance(data, dict):
return
event_type = event_name or data.get("type", "")
# 调用格式特定的处理逻辑
self._process_event_data(ctx, event_type, data)
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
处理解析后的事件数据 - 子类应覆盖此方法
默认实现使用 ResponseParser 提取 usage
"""
# 提取 response_id
if not ctx.response_id:
response_obj = data.get("response")
if isinstance(response_obj, dict) and response_obj.get("id"):
ctx.response_id = response_obj["id"]
elif "id" in data:
ctx.response_id = data["id"]
# 使用解析器提取 usage
usage = self.parser.extract_usage_from_response(data)
if usage and not ctx.final_usage:
ctx.input_tokens = usage.get("input_tokens", 0)
ctx.output_tokens = usage.get("output_tokens", 0)
ctx.cached_tokens = usage.get("cache_read_tokens", 0)
ctx.final_usage = usage
# 提取文本内容
text = self.parser.extract_text_content(data)
if text:
ctx.append_text(text)
# 检查完成事件
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
response_obj = data.get("response")
if isinstance(response_obj, dict):
ctx.final_response = response_obj
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
在记录统计前从 parsed_chunks 中提取额外的元数据 - 子类可覆盖
这是一个后处理钩子,在流传输完成后、记录 Usage 之前调用。
子类可以覆盖此方法从 ctx.parsed_chunks 中提取格式特定的元数据,
如 Gemini 的 modelVersion、token 统计等。
Args:
ctx: 流上下文,包含 parsed_chunks 和 response_metadata
"""
pass
async def _create_monitored_stream(
self,
ctx: StreamContext,
stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
try:
async for chunk in stream_generator:
yield chunk
except asyncio.CancelledError:
# 如果响应已完成,不标记为失败
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
raise
except httpx.TimeoutException as e:
ctx.status_code = 504
ctx.error_message = str(e)
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
async def _record_stream_stats(
self,
ctx: StreamContext,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""在流完成后记录统计信息"""
try:
# 使用 self.start_time 作为时间基准,与首字时间保持一致
# 注意:不要把统计延迟算进响应时间里
response_time_ms = int((time.time() - self.start_time) * 1000)
await asyncio.sleep(0.1)
if not ctx.provider_name:
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
return
# Claude API 的 input_tokens 已经是非缓存部分,不需要再减去 cached_tokens
# 实际计费的输入 tokens = input_tokens + cache_creation_tokens缓存读取免费或折扣
actual_input_tokens = ctx.input_tokens
# 获取新的 DB session
db_gen = get_db()
bg_db = next(db_gen)
try:
from src.models.database import ApiKey as ApiKeyModel
user = bg_db.query(User).filter(User.id == ctx.user_id).first()
api_key = bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == ctx.api_key_id).first()
if not user or not api_key:
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key, ctx.request_id, self.client_ip
)
response_body = {
"chunks": ctx.parsed_chunks,
"metadata": {
"stream": True,
"total_chunks": len(ctx.parsed_chunks),
"data_count": ctx.data_count,
"has_completion": ctx.has_completion,
"response_time_ms": response_time_ms,
},
}
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = ctx.provider_request_body or original_request_body
# 根据状态码决定记录成功还是失败
# 499 = 客户端断开连接503 = 服务不可用(如流中断)
if ctx.status_code and ctx.status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
await bg_telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=ctx.status_code,
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
# 预估 token 信息(来自 message_start 事件)
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
response_body=response_body,
# 模型映射信息
target_model=ctx.mapped_model,
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# 简洁的请求失败摘要(包含预估 token 信息)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}")
else:
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
self._finalize_stream_metadata(ctx)
total_cost = await bg_telemetry.record_success(
provider=ctx.provider_name,
model=ctx.model,
input_tokens=actual_input_tokens,
output_tokens=ctx.output_tokens,
response_time_ms=response_time_ms,
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
status_code=ctx.status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx.response_headers,
response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
is_stream=True,
provider_request_headers=ctx.provider_request_headers,
api_format=ctx.api_format,
# Provider 侧追踪信息(用于记录真实成本)
provider_id=ctx.provider_id,
provider_endpoint_id=ctx.endpoint_id,
provider_api_key_id=ctx.key_id,
# 模型映射信息
target_model=ctx.mapped_model,
# Provider 响应元数据(如 Gemini 的 modelVersion
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要(两行格式)
line1 = (
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
)
logger.info(f"{line1}\n{line2}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
# 这里用流传输完成后的实际时间覆盖
if ctx.attempt_id:
from src.services.request.candidate import RequestCandidateService
# 根据状态码决定是成功还是失败
# 499 = 客户端断开连接,应标记为失败
# 503 = 服务不可用(如流中断),应标记为失败
if ctx.status_code and ctx.status_code >= 400:
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx.attempt_id,
error_type="client_disconnected" if ctx.status_code == 499 else "stream_error",
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"chunk_count": ctx.chunk_count,
"data_count": ctx.data_count,
},
)
else:
RequestCandidateService.mark_candidate_success(
db=bg_db,
candidate_id=ctx.attempt_id,
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"chunk_count": ctx.chunk_count,
"data_count": ctx.data_count,
},
)
finally:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
async def _record_stream_failure(
self,
ctx: StreamContext,
error: Exception,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""记录流式请求失败"""
# 使用 self.start_time 作为时间基准,与首字时间保持一致
response_time_ms = int((time.time() - self.start_time) * 1000)
status_code = 503
if isinstance(error, ProviderAuthException):
status_code = 503
elif isinstance(error, ProviderRateLimitException):
status_code = 429
elif isinstance(error, ProviderTimeoutException):
status_code = 504
ctx.status_code = status_code
ctx.error_message = str(error)
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = ctx.provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
# 模型映射信息
target_model=ctx.mapped_model,
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def process_sync(
self,
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
path_params: Optional[Dict[str, Any]] = None,
) -> JSONResponse:
"""
处理非流式请求
通用流程:
1. 构建请求
2. 通过 FallbackOrchestrator 执行
3. 解析响应并记录统计
"""
logger.debug(f"开始非流式响应处理 ({self.FORMAT_ID})")
# 使用子类实现的方法提取 model不同 API 格式的 model 位置不同)
model = self.extract_model_from_request(original_request_body, path_params)
api_format = self.allowed_api_formats[0]
sync_start_time = time.time()
provider_name = None
response_json = None
status_code = 200
response_headers = {}
provider_api_format = "" # 用于追踪 Provider 的 API 格式
provider_request_headers = {} # 发送给 Provider 的请求头
provider_request_body = None # 实际发送给 Provider 的请求体
provider_id = None # Provider ID用于失败记录
endpoint_id = None # Endpoint ID用于失败记录
key_id = None # Key ID用于失败记录
mapped_model_result = None # 映射后的目标模型名(用于 Usage 记录)
response_metadata_result: Dict[str, Any] = {} # Provider 响应元数据
async def sync_request_func(
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
provider_name = str(provider.name)
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
# 获取模型映射(别名/映射 → 实际模型名)
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
if mapped_model:
mapped_model_result = mapped_model # 保存映射后的模型名,用于 Usage 记录
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = original_request_body
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
request_body = self.prepare_provider_request_body(request_body)
# 使用 RequestBuilder 构建请求体和请求头
# 注意mapped_model 已经应用到 request_body这里不再传递
provider_payload, provider_headers = self._request_builder.build(
request_body,
original_headers,
endpoint,
key,
is_stream=False,
)
# 保存发送给 Provider 的请求信息(用于调试和统计)
provider_request_headers = provider_headers
provider_request_body = provider_payload
# 获取用于 URL 的模型名(子类可覆盖此方法,如 Gemini 需要特殊处理)
url_model = self.get_model_for_url(request_body, mapped_model)
url = build_provider_url(
endpoint,
query_params=query_params,
path_params={"model": url_model},
is_stream=False, # 非流式请求
)
logger.info(f" └─ [{self.request_id}] 发送非流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
)
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
)
# 安全解析 JSON 响应,处理可能的编码错误
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# 记录原始响应信息用于调试
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes")
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
)
# 提取 Provider 响应元数据(子类可覆盖)
response_metadata_result = self._extract_response_metadata(response_json)
return response_json if isinstance(response_json, dict) else {}
try:
# 解析能力需求
capability_requirements = self._resolve_capability_requirements(
model_name=model,
request_headers=original_headers,
)
(
result,
actual_provider_name,
attempt_id,
provider_id,
endpoint_id,
key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=api_format,
model_name=model,
user_api_key=self.api_key,
request_func=sync_request_func,
request_id=self.request_id,
capability_requirements=capability_requirements or None,
)
provider_name = actual_provider_name
response_time_ms = int((time.time() - sync_start_time) * 1000)
# 确保 response_json 不为 None
if response_json is None:
response_json = {}
# 检查是否需要格式转换
if (
provider_api_format
and api_format
and provider_api_format.upper() != api_format.upper()
):
from src.api.handlers.base.format_converter_registry import converter_registry
try:
response_json = converter_registry.convert_response(
response_json,
provider_api_format,
api_format,
)
logger.debug(f"非流式响应格式转换完成: {provider_api_format} -> {api_format}")
except Exception as conv_err:
logger.warning(f"非流式响应格式转换失败,使用原始响应: {conv_err}")
# 使用解析器提取 usage
usage = self.parser.extract_usage_from_response(response_json)
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cached_tokens = usage.get("cache_read_tokens", 0)
cache_creation_tokens = usage.get("cache_creation_tokens", 0)
# Claude API 的 input_tokens 已经是非缓存部分,不需要再减去 cached_tokens
actual_input_tokens = input_tokens
output_text = self.parser.extract_text_content(response_json)[:200]
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body
total_cost = await self.telemetry.record_success(
provider=provider_name,
model=model,
input_tokens=actual_input_tokens,
output_tokens=output_tokens,
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=response_headers,
response_body=response_json,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens,
is_stream=False,
provider_request_headers=provider_request_headers,
api_format=api_format,
# Provider 侧追踪信息(用于记录真实成本)
provider_id=provider_id,
provider_endpoint_id=endpoint_id,
provider_api_key_id=key_id,
# 模型映射信息
target_model=mapped_model_result,
# Provider 响应元数据(如 Gemini 的 modelVersion
response_metadata=response_metadata_result if response_metadata_result else None,
)
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
return JSONResponse(status_code=status_code, content=response_json)
except Exception as e:
response_time_ms = int((time.time() - sync_start_time) * 1000)
status_code = 503
if isinstance(e, ProviderAuthException):
status_code = 503
elif isinstance(e, ProviderRateLimitException):
status_code = 429
elif isinstance(e, ProviderTimeoutException):
status_code = 504
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=provider_name or "unknown",
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
api_format=api_format,
provider_request_headers=provider_request_headers,
# 模型映射信息
target_model=mapped_model_result,
)
raise
async def _extract_error_text(self, e: httpx.HTTPStatusError) -> str:
"""从 HTTP 错误中提取错误文本"""
try:
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
error_bytes = await e.response.aread()
for encoding in ["utf-8", "gbk", "latin1"]:
try:
return error_bytes.decode(encoding)[:500]
except (UnicodeDecodeError, LookupError):
continue
return error_bytes.decode("utf-8", errors="replace")[:500]
else:
return (
e.response.text[:500]
if hasattr(e.response, "_content")
else "Unable to read response"
)
except Exception as decode_error:
return f"Unable to read error response: {decode_error}"
def _needs_format_conversion(self, ctx: StreamContext) -> bool:
"""
检查是否需要进行格式转换
当 Provider 的 API 格式与客户端请求的 API 格式不同时,需要转换响应。
例如:客户端请求 Claude 格式,但 Provider 返回 OpenAI 格式。
"""
if not ctx.provider_api_format or not ctx.client_api_format:
return False
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
def _convert_sse_line(
self,
ctx: StreamContext,
line: str,
events: list,
) -> Optional[str]:
"""
将 SSE 行从 Provider 格式转换为客户端格式
Args:
ctx: 流上下文
line: 原始 SSE 行
events: 解析后的事件列表
Returns:
转换后的 SSE 行,如果无法转换则返回 None
"""
from src.api.handlers.base.format_converter_registry import converter_registry
# 如果是空行或特殊控制行,直接返回
if not line or line.strip() == "" or line == "data: [DONE]":
return line
# 如果不是 data 行,直接透传
if not line.startswith("data:"):
return line
# 提取 data 内容
data_content = line[5:].strip() # 去掉 "data:" 前缀
# 尝试解析 JSON
try:
data_obj = json.loads(data_content)
except json.JSONDecodeError:
# 无法解析,直接透传
return line
# 使用注册表进行格式转换
try:
converted_obj = converter_registry.convert_stream_chunk(
data_obj,
ctx.provider_api_format,
ctx.client_api_format,
)
# 重新构建 SSE 行
return f"data: {json.dumps(converted_obj, ensure_ascii=False)}"
except Exception as e:
logger.warning(f"格式转换失败,透传原始数据: {e}")
return line