mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
- 为 BaseMessageHandler 和 MessageTelemetry 添加完整类型注解 - 新增 _update_usage_to_streaming 方法,异步更新 Usage 状态为 streaming - 优化 chat/cli handler 的类型提示,提升代码可维护性 - 修复类型检查警告,确保 mypy 通过
1631 lines
63 KiB
Python
1631 lines
63 KiB
Python
"""
|
||
CLI Message Handler 通用基类
|
||
|
||
将 CLI 格式处理器的通用逻辑(HTTP 请求、SSE 解析、统计记录)抽取到基类,
|
||
子类只需实现格式特定的事件解析逻辑。
|
||
|
||
设计目标:
|
||
1. 减少代码重复 - 原来每个 CLI Handler 900+ 行,抽取后子类只需 ~100 行
|
||
2. 统一错误处理 - 超时、空流、故障转移等逻辑集中管理
|
||
3. 简化新格式接入 - 只需实现 ResponseParser 和少量钩子方法
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import time
|
||
from abc import abstractmethod
|
||
from dataclasses import dataclass, field
|
||
from typing import (
|
||
Any,
|
||
AsyncGenerator,
|
||
Callable,
|
||
Dict,
|
||
Optional,
|
||
Tuple,
|
||
)
|
||
|
||
import httpx
|
||
from fastapi import BackgroundTasks
|
||
from fastapi.responses import JSONResponse, StreamingResponse
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.handlers.base.base_handler import (
|
||
BaseMessageHandler,
|
||
MessageTelemetry,
|
||
)
|
||
from src.api.handlers.base.parsers import get_parser_for_format
|
||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||
|
||
# 直接从具体模块导入,避免循环依赖
|
||
from src.api.handlers.base.response_parser import (
|
||
ResponseParser,
|
||
StreamStats,
|
||
)
|
||
from src.core.exceptions import (
|
||
EmbeddedErrorException,
|
||
ProviderAuthException,
|
||
ProviderNotAvailableException,
|
||
ProviderRateLimitException,
|
||
ProviderTimeoutException,
|
||
)
|
||
from src.core.logger import logger
|
||
from src.database import get_db
|
||
from src.models.database import (
|
||
ApiKey,
|
||
Provider,
|
||
ProviderAPIKey,
|
||
ProviderEndpoint,
|
||
User,
|
||
)
|
||
from src.services.provider.transport import build_provider_url
|
||
from src.utils.sse_parser import SSEEventParser
|
||
|
||
|
||
|
||
@dataclass
|
||
class StreamContext:
|
||
"""流式请求的上下文信息"""
|
||
|
||
# 请求信息
|
||
model: str = "unknown" # 用户请求的原始模型名
|
||
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
|
||
api_format: str = ""
|
||
request_id: str = ""
|
||
|
||
# 用户信息(提前提取避免 Session detached)
|
||
user_id: int = 0
|
||
api_key_id: int = 0
|
||
|
||
# 统计信息
|
||
input_tokens: int = 0
|
||
output_tokens: int = 0
|
||
cached_tokens: int = 0 # cache_read_input_tokens
|
||
cache_creation_tokens: int = 0 # cache_creation_input_tokens
|
||
collected_text: str = ""
|
||
response_id: Optional[str] = None
|
||
final_usage: Optional[Dict[str, Any]] = None
|
||
final_response: Optional[Dict[str, Any]] = None
|
||
parsed_chunks: list = field(default_factory=list)
|
||
|
||
# 流状态
|
||
start_time: float = field(default_factory=time.time)
|
||
chunk_count: int = 0
|
||
data_count: int = 0
|
||
has_completion: bool = False
|
||
|
||
# 响应信息
|
||
status_code: int = 200
|
||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||
|
||
# 请求信息(发送给 Provider 的)
|
||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
|
||
|
||
# Provider 信息
|
||
provider_name: Optional[str] = None
|
||
provider_id: Optional[str] = None # Provider ID(用于记录真实成本)
|
||
endpoint_id: Optional[str] = None
|
||
key_id: Optional[str] = None
|
||
attempt_id: Optional[str] = None
|
||
attempt_synced: bool = False
|
||
error_message: Optional[str] = None
|
||
|
||
# 格式转换信息
|
||
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
|
||
client_api_format: str = "" # 客户端请求的 API 格式
|
||
|
||
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion)
|
||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class CliMessageHandlerBase(BaseMessageHandler):
|
||
"""
|
||
CLI 格式消息处理器基类
|
||
|
||
提供 CLI 格式(直接透传请求)的通用处理逻辑:
|
||
- 流式请求的 HTTP 连接管理
|
||
- SSE 事件解析框架
|
||
- 统计信息收集和记录
|
||
- 错误处理和故障转移
|
||
|
||
子类需要实现:
|
||
- get_response_parser(): 返回格式特定的响应解析器
|
||
- 可选覆盖 handle_sse_event() 自定义事件处理
|
||
|
||
"""
|
||
|
||
# 子类可覆盖的配置
|
||
FORMAT_ID: str = "UNKNOWN" # API 格式标识
|
||
DATA_TIMEOUT: int = 30 # 流数据超时时间(秒)
|
||
EMPTY_CHUNK_THRESHOLD: int = 10 # 空流检测的 chunk 阈值
|
||
|
||
def __init__(
|
||
self,
|
||
db: Session,
|
||
user: User,
|
||
api_key: ApiKey,
|
||
request_id: str,
|
||
client_ip: str,
|
||
user_agent: str,
|
||
start_time: float,
|
||
allowed_api_formats: Optional[list] = None,
|
||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||
):
|
||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||
super().__init__(
|
||
db=db,
|
||
user=user,
|
||
api_key=api_key,
|
||
request_id=request_id,
|
||
client_ip=client_ip,
|
||
user_agent=user_agent,
|
||
start_time=start_time,
|
||
allowed_api_formats=allowed,
|
||
adapter_detector=adapter_detector,
|
||
)
|
||
self._parser: Optional[ResponseParser] = None
|
||
self._request_builder = PassthroughRequestBuilder()
|
||
|
||
@property
|
||
def parser(self) -> ResponseParser:
|
||
"""获取响应解析器(懒加载)"""
|
||
if self._parser is None:
|
||
self._parser = self.get_response_parser()
|
||
return self._parser
|
||
|
||
def get_response_parser(self) -> ResponseParser:
|
||
"""
|
||
获取格式特定的响应解析器
|
||
|
||
子类可覆盖此方法提供自定义解析器,
|
||
默认从解析器注册表获取
|
||
"""
|
||
return get_parser_for_format(self.FORMAT_ID)
|
||
|
||
async def _get_mapped_model(
|
||
self,
|
||
source_model: str,
|
||
provider_id: str,
|
||
) -> Optional[str]:
|
||
"""
|
||
获取模型映射后的实际模型名
|
||
|
||
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
|
||
|
||
Args:
|
||
source_model: 用户请求的模型名(可能是别名)
|
||
provider_id: Provider ID
|
||
|
||
Returns:
|
||
映射后的 provider_model_name,如果没有找到映射则返回 None
|
||
"""
|
||
from src.services.model.mapper import ModelMapperMiddleware
|
||
|
||
mapper = ModelMapperMiddleware(self.db)
|
||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||
|
||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||
|
||
if mapping and mapping.model:
|
||
mapped_name = str(mapping.model.provider_model_name)
|
||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||
return mapped_name
|
||
|
||
logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}")
|
||
return None
|
||
|
||
def extract_model_from_request(
|
||
self,
|
||
request_body: Dict[str, Any],
|
||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002 - 子类使用
|
||
) -> str:
|
||
"""
|
||
从请求中提取模型名 - 子类可覆盖
|
||
|
||
不同 API 格式的 model 位置不同:
|
||
- OpenAI/Claude: 在请求体中 request_body["model"]
|
||
- Gemini: 在 URL 路径中 path_params["model"]
|
||
|
||
子类应覆盖此方法实现各自的提取逻辑。
|
||
|
||
Args:
|
||
request_body: 请求体
|
||
path_params: URL 路径参数
|
||
|
||
Returns:
|
||
模型名,如果无法提取则返回 "unknown"
|
||
"""
|
||
# 默认实现:从请求体获取
|
||
model = request_body.get("model")
|
||
return str(model) if model else "unknown"
|
||
|
||
def apply_mapped_model(
|
||
self,
|
||
request_body: Dict[str, Any],
|
||
mapped_model: str, # noqa: ARG002 - 子类使用
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
将映射后的模型名应用到请求体
|
||
|
||
基类默认实现:不修改请求体,保持原样透传。
|
||
子类应覆盖此方法实现各自的模型名替换逻辑。
|
||
|
||
Args:
|
||
request_body: 原始请求体
|
||
mapped_model: 映射后的模型名(子类使用)
|
||
|
||
Returns:
|
||
请求体(默认不修改)
|
||
"""
|
||
# 基类不修改请求体,子类覆盖此方法实现特定格式的处理
|
||
return request_body
|
||
|
||
def prepare_provider_request_body(
|
||
self,
|
||
request_body: Dict[str, Any],
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
准备发送给 Provider 的请求体 - 子类可覆盖
|
||
|
||
在模型映射之后、发送请求之前调用,用于移除不需要发送给上游的字段。
|
||
例如 Gemini API 需要移除请求体中的 model 字段(因为 model 在 URL 路径中)。
|
||
|
||
Args:
|
||
request_body: 经过模型映射处理后的请求体
|
||
|
||
Returns:
|
||
准备好的请求体
|
||
"""
|
||
return request_body
|
||
|
||
def get_model_for_url(
|
||
self,
|
||
request_body: Dict[str, Any],
|
||
mapped_model: Optional[str],
|
||
) -> Optional[str]:
|
||
"""
|
||
获取用于 URL 路径的模型名
|
||
|
||
某些 API 格式(如 Gemini)需要将 model 放入 URL 路径中。
|
||
子类应覆盖此方法返回正确的值。
|
||
|
||
Args:
|
||
request_body: 请求体
|
||
mapped_model: 映射后的模型名(如果有)
|
||
|
||
Returns:
|
||
用于 URL 路径的模型名,默认优先使用映射后的名称
|
||
"""
|
||
return mapped_model or request_body.get("model")
|
||
|
||
def _extract_response_metadata(
|
||
self,
|
||
response: Dict[str, Any],
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
从响应中提取 Provider 特有的元数据 - 子类可覆盖
|
||
|
||
例如 Gemini 返回的 modelVersion 字段。
|
||
这些元数据会存储到 Usage.request_metadata 中。
|
||
|
||
Args:
|
||
response: Provider 返回的响应
|
||
|
||
Returns:
|
||
元数据字典,默认为空
|
||
"""
|
||
return {}
|
||
|
||
async def process_stream(
|
||
self,
|
||
original_request_body: Dict[str, Any],
|
||
original_headers: Dict[str, str],
|
||
query_params: Optional[Dict[str, str]] = None,
|
||
path_params: Optional[Dict[str, Any]] = None,
|
||
) -> StreamingResponse:
|
||
"""
|
||
处理流式请求
|
||
|
||
通用流程:
|
||
1. 创建流上下文
|
||
2. 定义请求函数(供 FallbackOrchestrator 调用)
|
||
3. 执行请求并返回 StreamingResponse
|
||
4. 后台任务记录统计信息
|
||
"""
|
||
logger.debug(f"开始流式响应处理 ({self.FORMAT_ID})")
|
||
|
||
# 使用子类实现的方法提取 model(不同 API 格式的 model 位置不同)
|
||
model = self.extract_model_from_request(original_request_body, path_params)
|
||
|
||
# 创建流上下文
|
||
ctx = StreamContext(
|
||
model=model,
|
||
api_format=self.allowed_api_formats[0],
|
||
request_id=self.request_id,
|
||
user_id=self.user.id,
|
||
api_key_id=self.api_key.id,
|
||
)
|
||
|
||
# 定义请求函数
|
||
async def stream_request_func(
|
||
provider: Provider,
|
||
endpoint: ProviderEndpoint,
|
||
key: ProviderAPIKey,
|
||
) -> AsyncGenerator[bytes, None]:
|
||
return await self._execute_stream_request(
|
||
ctx,
|
||
provider,
|
||
endpoint,
|
||
key,
|
||
original_request_body,
|
||
original_headers,
|
||
query_params,
|
||
)
|
||
|
||
try:
|
||
# 解析能力需求
|
||
capability_requirements = self._resolve_capability_requirements(
|
||
model_name=ctx.model,
|
||
request_headers=original_headers,
|
||
)
|
||
|
||
# 执行请求(通过 FallbackOrchestrator)
|
||
(
|
||
stream_generator,
|
||
provider_name,
|
||
attempt_id,
|
||
_provider_id,
|
||
_endpoint_id,
|
||
_key_id,
|
||
) = await self.orchestrator.execute_with_fallback(
|
||
api_format=ctx.api_format,
|
||
model_name=ctx.model,
|
||
user_api_key=self.api_key,
|
||
request_func=stream_request_func,
|
||
request_id=self.request_id,
|
||
is_stream=True,
|
||
capability_requirements=capability_requirements or None,
|
||
)
|
||
ctx.attempt_id = attempt_id
|
||
|
||
# 创建后台任务记录统计
|
||
background_tasks = BackgroundTasks()
|
||
background_tasks.add_task(
|
||
self._record_stream_stats,
|
||
ctx,
|
||
original_headers,
|
||
original_request_body,
|
||
)
|
||
|
||
# 创建监控流
|
||
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
|
||
|
||
return StreamingResponse(
|
||
monitored_stream,
|
||
media_type="text/event-stream",
|
||
background=background_tasks,
|
||
)
|
||
|
||
except Exception as e:
|
||
# 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈
|
||
from src.core.exceptions import (
|
||
ProviderException,
|
||
QuotaExceededException,
|
||
RateLimitException,
|
||
ModelNotSupportedException,
|
||
)
|
||
|
||
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
||
# 业务异常:简洁日志
|
||
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
|
||
else:
|
||
# 未知异常:完整堆栈
|
||
logger.exception(f"流式请求失败: {e}")
|
||
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
||
raise
|
||
|
||
async def _execute_stream_request(
|
||
self,
|
||
ctx: StreamContext,
|
||
provider: Provider,
|
||
endpoint: ProviderEndpoint,
|
||
key: ProviderAPIKey,
|
||
original_request_body: Dict[str, Any],
|
||
original_headers: Dict[str, str],
|
||
query_params: Optional[Dict[str, str]] = None,
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""执行流式请求并返回流生成器"""
|
||
# 重置上下文状态(重试时清除之前的数据,避免累积)
|
||
ctx.parsed_chunks = []
|
||
ctx.chunk_count = 0
|
||
ctx.data_count = 0
|
||
ctx.has_completion = False
|
||
ctx.collected_text = ""
|
||
ctx.input_tokens = 0
|
||
ctx.output_tokens = 0
|
||
ctx.cached_tokens = 0
|
||
ctx.cache_creation_tokens = 0
|
||
ctx.final_usage = None
|
||
ctx.final_response = None
|
||
ctx.response_id = None
|
||
ctx.response_metadata = {} # 重置 Provider 响应元数据
|
||
|
||
# 记录 Provider 信息
|
||
ctx.provider_name = str(provider.name)
|
||
ctx.provider_id = str(provider.id)
|
||
ctx.endpoint_id = str(endpoint.id)
|
||
ctx.key_id = str(key.id)
|
||
|
||
# 记录格式转换信息
|
||
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
|
||
|
||
# 获取模型映射(别名/映射 → 实际模型名)
|
||
mapped_model = await self._get_mapped_model(
|
||
source_model=ctx.model,
|
||
provider_id=str(provider.id),
|
||
)
|
||
|
||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||
if mapped_model:
|
||
ctx.mapped_model = mapped_model # 保存映射后的模型名,用于 Usage 记录
|
||
request_body = self.apply_mapped_model(original_request_body, mapped_model)
|
||
else:
|
||
request_body = original_request_body
|
||
|
||
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
|
||
request_body = self.prepare_provider_request_body(request_body)
|
||
|
||
# 使用 RequestBuilder 构建请求体和请求头
|
||
# 注意:mapped_model 已经应用到 request_body,这里不再传递
|
||
provider_payload, provider_headers = self._request_builder.build(
|
||
request_body,
|
||
original_headers,
|
||
endpoint,
|
||
key,
|
||
is_stream=True,
|
||
)
|
||
|
||
# 保存发送给 Provider 的请求信息(用于调试和统计)
|
||
ctx.provider_request_headers = provider_headers
|
||
ctx.provider_request_body = provider_payload
|
||
|
||
# 获取用于 URL 的模型名(子类可覆盖此方法,如 Gemini 需要特殊处理)
|
||
# 使用 ctx.model 作为 fallback(它已从 path_params 获取)
|
||
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
|
||
|
||
url = build_provider_url(
|
||
endpoint,
|
||
query_params=query_params,
|
||
path_params={"model": url_model},
|
||
is_stream=True, # CLI handler 处理流式请求
|
||
)
|
||
|
||
# 配置超时
|
||
timeout_config = httpx.Timeout(
|
||
connect=10.0,
|
||
read=float(endpoint.timeout),
|
||
write=10.0,
|
||
pool=10.0,
|
||
)
|
||
|
||
logger.debug(f" └─ [{self.request_id}] 发送流式请求: "
|
||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||
f"Key=***{key.api_key[-4:]}, "
|
||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||
|
||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
||
try:
|
||
response_ctx = http_client.stream(
|
||
"POST", url, json=provider_payload, headers=provider_headers
|
||
)
|
||
stream_response = await response_ctx.__aenter__()
|
||
|
||
ctx.status_code = stream_response.status_code
|
||
ctx.response_headers = dict(stream_response.headers)
|
||
|
||
logger.debug(f" └─ 收到响应: status={stream_response.status_code}")
|
||
|
||
stream_response.raise_for_status()
|
||
|
||
# 创建行迭代器(只创建一次,后续会继续使用)
|
||
line_iterator = stream_response.aiter_lines()
|
||
|
||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||
prefetched_lines = await self._prefetch_and_check_embedded_error(
|
||
line_iterator, provider, endpoint, ctx
|
||
)
|
||
|
||
except httpx.HTTPStatusError as e:
|
||
error_text = await self._extract_error_text(e)
|
||
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
|
||
await http_client.aclose()
|
||
raise
|
||
|
||
except EmbeddedErrorException:
|
||
# 嵌套错误需要触发重试,关闭连接后重新抛出
|
||
try:
|
||
await response_ctx.__aexit__(None, None, None)
|
||
except Exception:
|
||
pass
|
||
await http_client.aclose()
|
||
raise
|
||
|
||
except Exception:
|
||
await http_client.aclose()
|
||
raise
|
||
|
||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||
return self._create_response_stream_with_prefetch(
|
||
ctx,
|
||
line_iterator,
|
||
response_ctx,
|
||
http_client,
|
||
prefetched_lines,
|
||
)
|
||
|
||
async def _create_response_stream(
|
||
self,
|
||
ctx: StreamContext,
|
||
stream_response: httpx.Response,
|
||
response_ctx: Any,
|
||
http_client: httpx.AsyncClient,
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""创建响应流生成器"""
|
||
try:
|
||
sse_parser = SSEEventParser()
|
||
last_data_time = time.time()
|
||
streaming_status_updated = False
|
||
|
||
# 检查是否需要格式转换
|
||
needs_conversion = self._needs_format_conversion(ctx)
|
||
|
||
async for line in stream_response.aiter_lines():
|
||
# 在第一次输出数据前更新状态为 streaming
|
||
if not streaming_status_updated:
|
||
self._update_usage_to_streaming(ctx.request_id)
|
||
streaming_status_updated = True
|
||
|
||
normalized_line = line.rstrip("\r")
|
||
events = sse_parser.feed_line(normalized_line)
|
||
|
||
if normalized_line == "":
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
yield b"\n"
|
||
continue
|
||
|
||
ctx.chunk_count += 1
|
||
|
||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||
elapsed = time.time() - last_data_time
|
||
if elapsed > self.DATA_TIMEOUT:
|
||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "empty_stream_timeout",
|
||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
return # 结束生成器
|
||
|
||
# 格式转换或直接透传
|
||
if needs_conversion:
|
||
converted_line = self._convert_sse_line(ctx, line, events)
|
||
if converted_line:
|
||
yield (converted_line + "\n").encode("utf-8")
|
||
else:
|
||
yield (line + "\n").encode("utf-8")
|
||
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
|
||
if ctx.data_count > 0:
|
||
last_data_time = time.time()
|
||
|
||
# 处理剩余事件
|
||
for event in sse_parser.flush():
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
|
||
# 检查是否收到数据
|
||
if ctx.data_count == 0:
|
||
# 流已开始,无法抛出异常进行故障转移
|
||
# 发送错误事件并记录日志
|
||
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应")
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "empty_response",
|
||
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
else:
|
||
logger.debug("流式数据转发完成")
|
||
|
||
except GeneratorExit:
|
||
raise
|
||
except httpx.StreamClosed:
|
||
if ctx.data_count == 0:
|
||
# 流已开始,发送错误事件而不是抛出异常
|
||
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "stream_closed",
|
||
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
except httpx.RemoteProtocolError as e:
|
||
if ctx.data_count > 0:
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "connection_error",
|
||
"message": "上游连接意外关闭,部分响应已成功传输",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
else:
|
||
raise
|
||
finally:
|
||
try:
|
||
await response_ctx.__aexit__(None, None, None)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
await http_client.aclose()
|
||
except Exception:
|
||
pass
|
||
|
||
async def _prefetch_and_check_embedded_error(
|
||
self,
|
||
line_iterator: Any,
|
||
provider: Provider,
|
||
endpoint: ProviderEndpoint,
|
||
ctx: StreamContext,
|
||
) -> list:
|
||
"""
|
||
预读流的前几行,检测嵌套错误
|
||
|
||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||
|
||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||
|
||
Args:
|
||
line_iterator: 行迭代器(aiter_lines() 返回的迭代器)
|
||
provider: Provider 对象
|
||
endpoint: Endpoint 对象
|
||
ctx: 流上下文
|
||
|
||
Returns:
|
||
预读的行列表(需要在后续流中先输出)
|
||
|
||
Raises:
|
||
EmbeddedErrorException: 如果检测到嵌套错误
|
||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||
"""
|
||
prefetched_lines: list = []
|
||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||
|
||
try:
|
||
# 获取对应格式的解析器
|
||
provider_format = ctx.provider_api_format
|
||
if provider_format:
|
||
try:
|
||
provider_parser = get_parser_for_format(provider_format)
|
||
except KeyError:
|
||
provider_parser = self.parser
|
||
else:
|
||
provider_parser = self.parser
|
||
|
||
line_count = 0
|
||
async for line in line_iterator:
|
||
prefetched_lines.append(line)
|
||
line_count += 1
|
||
|
||
# 解析数据
|
||
normalized_line = line.rstrip("\r")
|
||
|
||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||
lower_line = normalized_line.lower()
|
||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||
logger.error(
|
||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||
f"base_url={endpoint.base_url}"
|
||
)
|
||
raise ProviderNotAvailableException(
|
||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||
)
|
||
|
||
if not normalized_line or normalized_line.startswith(":"):
|
||
# 空行或注释行,继续预读
|
||
if line_count >= max_prefetch_lines:
|
||
break
|
||
continue
|
||
|
||
# 尝试解析 SSE 数据
|
||
data_str = normalized_line
|
||
if normalized_line.startswith("data: "):
|
||
data_str = normalized_line[6:]
|
||
|
||
if data_str == "[DONE]":
|
||
break
|
||
|
||
try:
|
||
data = json.loads(data_str)
|
||
except json.JSONDecodeError:
|
||
# 不是有效 JSON,可能是部分数据,继续
|
||
if line_count >= max_prefetch_lines:
|
||
break
|
||
continue
|
||
|
||
# 使用解析器检查是否为错误响应
|
||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||
# 提取错误信息
|
||
parsed = provider_parser.parse_response(data, 200)
|
||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||
f"Provider={provider.name}, "
|
||
f"error_type={parsed.error_type}, "
|
||
f"message={parsed.error_message}")
|
||
raise EmbeddedErrorException(
|
||
provider_name=str(provider.name),
|
||
error_code=(
|
||
int(parsed.error_type)
|
||
if parsed.error_type and parsed.error_type.isdigit()
|
||
else None
|
||
),
|
||
error_message=parsed.error_message,
|
||
error_status=parsed.error_type,
|
||
)
|
||
|
||
# 预读到有效数据,没有错误,停止预读
|
||
break
|
||
|
||
except EmbeddedErrorException:
|
||
# 重新抛出嵌套错误
|
||
raise
|
||
except Exception as e:
|
||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||
|
||
return prefetched_lines
|
||
|
||
async def _create_response_stream_with_prefetch(
|
||
self,
|
||
ctx: StreamContext,
|
||
line_iterator: Any,
|
||
response_ctx: Any,
|
||
http_client: httpx.AsyncClient,
|
||
prefetched_lines: list,
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""创建响应流生成器(带预读数据)"""
|
||
try:
|
||
sse_parser = SSEEventParser()
|
||
last_data_time = time.time()
|
||
|
||
# 检查是否需要格式转换
|
||
needs_conversion = self._needs_format_conversion(ctx)
|
||
|
||
# 在第一次输出数据前更新状态为 streaming
|
||
if prefetched_lines:
|
||
self._update_usage_to_streaming(ctx.request_id)
|
||
|
||
# 先处理预读的数据
|
||
for line in prefetched_lines:
|
||
normalized_line = line.rstrip("\r")
|
||
events = sse_parser.feed_line(normalized_line)
|
||
|
||
if normalized_line == "":
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
yield b"\n"
|
||
continue
|
||
|
||
ctx.chunk_count += 1
|
||
|
||
# 格式转换或直接透传
|
||
if needs_conversion:
|
||
converted_line = self._convert_sse_line(ctx, line, events)
|
||
if converted_line:
|
||
yield (converted_line + "\n").encode("utf-8")
|
||
else:
|
||
yield (line + "\n").encode("utf-8")
|
||
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
|
||
if ctx.data_count > 0:
|
||
last_data_time = time.time()
|
||
|
||
# 继续处理剩余的流数据(使用同一个迭代器)
|
||
async for line in line_iterator:
|
||
normalized_line = line.rstrip("\r")
|
||
events = sse_parser.feed_line(normalized_line)
|
||
|
||
if normalized_line == "":
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
yield b"\n"
|
||
continue
|
||
|
||
ctx.chunk_count += 1
|
||
|
||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||
elapsed = time.time() - last_data_time
|
||
if elapsed > self.DATA_TIMEOUT:
|
||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "empty_stream_timeout",
|
||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
return
|
||
|
||
# 格式转换或直接透传
|
||
if needs_conversion:
|
||
converted_line = self._convert_sse_line(ctx, line, events)
|
||
if converted_line:
|
||
yield (converted_line + "\n").encode("utf-8")
|
||
else:
|
||
yield (line + "\n").encode("utf-8")
|
||
|
||
for event in events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
|
||
if ctx.data_count > 0:
|
||
last_data_time = time.time()
|
||
|
||
# 处理剩余事件
|
||
flushed_events = sse_parser.flush()
|
||
for event in flushed_events:
|
||
self._handle_sse_event(
|
||
ctx,
|
||
event.get("event"),
|
||
event.get("data") or "",
|
||
)
|
||
|
||
# 检查是否收到数据
|
||
if ctx.data_count == 0:
|
||
# 空流通常意味着配置错误(如 base_url 指向了网页而非 API)
|
||
logger.error(
|
||
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
|
||
f"可能是 endpoint base_url 配置错误"
|
||
)
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "empty_response",
|
||
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
else:
|
||
logger.debug("流式数据转发完成")
|
||
|
||
except GeneratorExit:
|
||
raise
|
||
except httpx.StreamClosed:
|
||
if ctx.data_count == 0:
|
||
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "stream_closed",
|
||
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
except httpx.RemoteProtocolError:
|
||
if ctx.data_count > 0:
|
||
error_event = {
|
||
"type": "error",
|
||
"error": {
|
||
"type": "connection_error",
|
||
"message": "上游连接意外关闭,部分响应已成功传输",
|
||
},
|
||
}
|
||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||
else:
|
||
raise
|
||
finally:
|
||
try:
|
||
await response_ctx.__aexit__(None, None, None)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
await http_client.aclose()
|
||
except Exception:
|
||
pass
|
||
|
||
def _handle_sse_event(
|
||
self,
|
||
ctx: StreamContext,
|
||
event_name: Optional[str],
|
||
data_str: str,
|
||
) -> None:
|
||
"""
|
||
处理 SSE 事件
|
||
|
||
通用框架:解析 JSON、更新计数器
|
||
子类可覆盖 _process_event_data() 实现格式特定逻辑
|
||
"""
|
||
if not data_str:
|
||
return
|
||
|
||
if data_str == "[DONE]":
|
||
ctx.has_completion = True
|
||
return
|
||
|
||
try:
|
||
data = json.loads(data_str)
|
||
except json.JSONDecodeError:
|
||
return
|
||
|
||
ctx.data_count += 1
|
||
ctx.parsed_chunks.append(data)
|
||
|
||
if not isinstance(data, dict):
|
||
return
|
||
|
||
event_type = event_name or data.get("type", "")
|
||
|
||
# 调用格式特定的处理逻辑
|
||
self._process_event_data(ctx, event_type, data)
|
||
|
||
def _process_event_data(
|
||
self,
|
||
ctx: StreamContext,
|
||
event_type: str,
|
||
data: Dict[str, Any],
|
||
) -> None:
|
||
"""
|
||
处理解析后的事件数据 - 子类应覆盖此方法
|
||
|
||
默认实现使用 ResponseParser 提取 usage
|
||
"""
|
||
# 提取 response_id
|
||
if not ctx.response_id:
|
||
response_obj = data.get("response")
|
||
if isinstance(response_obj, dict) and response_obj.get("id"):
|
||
ctx.response_id = response_obj["id"]
|
||
elif "id" in data:
|
||
ctx.response_id = data["id"]
|
||
|
||
# 使用解析器提取 usage
|
||
usage = self.parser.extract_usage_from_response(data)
|
||
if usage and not ctx.final_usage:
|
||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||
ctx.output_tokens = usage.get("output_tokens", 0)
|
||
ctx.cached_tokens = usage.get("cache_read_tokens", 0)
|
||
ctx.final_usage = usage
|
||
|
||
# 提取文本内容
|
||
text = self.parser.extract_text_content(data)
|
||
if text:
|
||
ctx.collected_text += text
|
||
|
||
# 检查完成事件
|
||
if event_type in ("response.completed", "message_stop"):
|
||
ctx.has_completion = True
|
||
response_obj = data.get("response")
|
||
if isinstance(response_obj, dict):
|
||
ctx.final_response = response_obj
|
||
|
||
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
|
||
"""
|
||
在记录统计前从 parsed_chunks 中提取额外的元数据 - 子类可覆盖
|
||
|
||
这是一个后处理钩子,在流传输完成后、记录 Usage 之前调用。
|
||
子类可以覆盖此方法从 ctx.parsed_chunks 中提取格式特定的元数据,
|
||
如 Gemini 的 modelVersion、token 统计等。
|
||
|
||
Args:
|
||
ctx: 流上下文,包含 parsed_chunks 和 response_metadata
|
||
"""
|
||
pass
|
||
|
||
async def _create_monitored_stream(
|
||
self,
|
||
ctx: StreamContext,
|
||
stream_generator: AsyncGenerator[bytes, None],
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""创建带监控的流生成器"""
|
||
try:
|
||
async for chunk in stream_generator:
|
||
yield chunk
|
||
except asyncio.CancelledError:
|
||
ctx.status_code = 499
|
||
ctx.error_message = "Client disconnected"
|
||
raise
|
||
except httpx.TimeoutException as e:
|
||
ctx.status_code = 504
|
||
ctx.error_message = str(e)
|
||
raise
|
||
except Exception as e:
|
||
ctx.status_code = 500
|
||
ctx.error_message = str(e)
|
||
raise
|
||
|
||
async def _record_stream_stats(
|
||
self,
|
||
ctx: StreamContext,
|
||
original_headers: Dict[str, str],
|
||
original_request_body: Dict[str, Any],
|
||
) -> None:
|
||
"""在流完成后记录统计信息"""
|
||
try:
|
||
await asyncio.sleep(0.1)
|
||
|
||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||
|
||
if not ctx.provider_name:
|
||
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
|
||
return
|
||
|
||
# Claude API 的 input_tokens 已经是非缓存部分,不需要再减去 cached_tokens
|
||
# 实际计费的输入 tokens = input_tokens + cache_creation_tokens(缓存读取免费或折扣)
|
||
actual_input_tokens = ctx.input_tokens
|
||
|
||
# 获取新的 DB session
|
||
db_gen = get_db()
|
||
bg_db = next(db_gen)
|
||
|
||
try:
|
||
from src.models.database import ApiKey as ApiKeyModel
|
||
|
||
user = bg_db.query(User).filter(User.id == ctx.user_id).first()
|
||
api_key = bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == ctx.api_key_id).first()
|
||
|
||
if not user or not api_key:
|
||
return
|
||
|
||
bg_telemetry = MessageTelemetry(
|
||
bg_db, user, api_key, ctx.request_id, self.client_ip
|
||
)
|
||
|
||
response_body = {
|
||
"chunks": ctx.parsed_chunks,
|
||
"metadata": {
|
||
"stream": True,
|
||
"total_chunks": len(ctx.parsed_chunks),
|
||
"data_count": ctx.data_count,
|
||
"has_completion": ctx.has_completion,
|
||
"response_time_ms": response_time_ms,
|
||
},
|
||
}
|
||
|
||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||
actual_request_body = ctx.provider_request_body or original_request_body
|
||
|
||
# 根据状态码决定记录成功还是失败
|
||
# 499 = 客户端断开连接,503 = 服务不可用(如流中断)
|
||
if ctx.status_code and ctx.status_code >= 400:
|
||
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
||
# 这样即使请求中断,也能记录预估成本
|
||
await bg_telemetry.record_failure(
|
||
provider=ctx.provider_name or "unknown",
|
||
model=ctx.model,
|
||
response_time_ms=response_time_ms,
|
||
status_code=ctx.status_code,
|
||
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
|
||
request_headers=original_headers,
|
||
request_body=actual_request_body,
|
||
is_stream=True,
|
||
api_format=ctx.api_format,
|
||
provider_request_headers=ctx.provider_request_headers,
|
||
# 预估 token 信息(来自 message_start 事件)
|
||
input_tokens=actual_input_tokens,
|
||
output_tokens=ctx.output_tokens,
|
||
cache_creation_tokens=ctx.cache_creation_tokens,
|
||
cache_read_tokens=ctx.cached_tokens,
|
||
response_body=response_body,
|
||
# 模型映射信息
|
||
target_model=ctx.mapped_model,
|
||
)
|
||
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
|
||
# 简洁的请求失败摘要(包含预估 token 信息)
|
||
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}")
|
||
else:
|
||
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
|
||
self._finalize_stream_metadata(ctx)
|
||
|
||
total_cost = await bg_telemetry.record_success(
|
||
provider=ctx.provider_name,
|
||
model=ctx.model,
|
||
input_tokens=actual_input_tokens,
|
||
output_tokens=ctx.output_tokens,
|
||
response_time_ms=response_time_ms,
|
||
status_code=ctx.status_code,
|
||
request_headers=original_headers,
|
||
request_body=actual_request_body,
|
||
response_headers=ctx.response_headers,
|
||
response_body=response_body,
|
||
cache_creation_tokens=ctx.cache_creation_tokens,
|
||
cache_read_tokens=ctx.cached_tokens,
|
||
is_stream=True,
|
||
provider_request_headers=ctx.provider_request_headers,
|
||
api_format=ctx.api_format,
|
||
# Provider 侧追踪信息(用于记录真实成本)
|
||
provider_id=ctx.provider_id,
|
||
provider_endpoint_id=ctx.endpoint_id,
|
||
provider_api_key_id=ctx.key_id,
|
||
# 模型映射信息
|
||
target_model=ctx.mapped_model,
|
||
# Provider 响应元数据(如 Gemini 的 modelVersion)
|
||
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
|
||
)
|
||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||
# 简洁的请求完成摘要
|
||
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
|
||
|
||
# 更新候选记录的最终状态和延迟时间
|
||
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
||
# 这里用流传输完成后的实际时间覆盖
|
||
if ctx.attempt_id:
|
||
from src.services.request.candidate import RequestCandidateService
|
||
|
||
# 根据状态码决定是成功还是失败
|
||
# 499 = 客户端断开连接,应标记为失败
|
||
# 503 = 服务不可用(如流中断),应标记为失败
|
||
if ctx.status_code and ctx.status_code >= 400:
|
||
RequestCandidateService.mark_candidate_failed(
|
||
db=bg_db,
|
||
candidate_id=ctx.attempt_id,
|
||
error_type="client_disconnected" if ctx.status_code == 499 else "stream_error",
|
||
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
|
||
status_code=ctx.status_code,
|
||
latency_ms=response_time_ms,
|
||
extra_data={
|
||
"stream_completed": False,
|
||
"chunk_count": ctx.chunk_count,
|
||
"data_count": ctx.data_count,
|
||
},
|
||
)
|
||
else:
|
||
RequestCandidateService.mark_candidate_success(
|
||
db=bg_db,
|
||
candidate_id=ctx.attempt_id,
|
||
status_code=ctx.status_code,
|
||
latency_ms=response_time_ms,
|
||
extra_data={
|
||
"stream_completed": True,
|
||
"chunk_count": ctx.chunk_count,
|
||
"data_count": ctx.data_count,
|
||
},
|
||
)
|
||
|
||
finally:
|
||
bg_db.close()
|
||
|
||
except Exception as e:
|
||
logger.exception("记录流式统计信息时出错")
|
||
|
||
async def _record_stream_failure(
|
||
self,
|
||
ctx: StreamContext,
|
||
error: Exception,
|
||
original_headers: Dict[str, str],
|
||
original_request_body: Dict[str, Any],
|
||
) -> None:
|
||
"""记录流式请求失败"""
|
||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||
|
||
status_code = 503
|
||
if isinstance(error, ProviderAuthException):
|
||
status_code = 503
|
||
elif isinstance(error, ProviderRateLimitException):
|
||
status_code = 429
|
||
elif isinstance(error, ProviderTimeoutException):
|
||
status_code = 504
|
||
|
||
ctx.status_code = status_code
|
||
ctx.error_message = str(error)
|
||
|
||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||
actual_request_body = ctx.provider_request_body or original_request_body
|
||
|
||
await self.telemetry.record_failure(
|
||
provider=ctx.provider_name or "unknown",
|
||
model=ctx.model,
|
||
response_time_ms=response_time_ms,
|
||
status_code=status_code,
|
||
error_message=str(error),
|
||
request_headers=original_headers,
|
||
request_body=actual_request_body,
|
||
is_stream=True,
|
||
api_format=ctx.api_format,
|
||
provider_request_headers=ctx.provider_request_headers,
|
||
# 模型映射信息
|
||
target_model=ctx.mapped_model,
|
||
)
|
||
|
||
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
|
||
|
||
async def process_sync(
|
||
self,
|
||
original_request_body: Dict[str, Any],
|
||
original_headers: Dict[str, str],
|
||
query_params: Optional[Dict[str, str]] = None,
|
||
path_params: Optional[Dict[str, Any]] = None,
|
||
) -> JSONResponse:
|
||
"""
|
||
处理非流式请求
|
||
|
||
通用流程:
|
||
1. 构建请求
|
||
2. 通过 FallbackOrchestrator 执行
|
||
3. 解析响应并记录统计
|
||
"""
|
||
logger.debug(f"开始非流式响应处理 ({self.FORMAT_ID})")
|
||
|
||
# 使用子类实现的方法提取 model(不同 API 格式的 model 位置不同)
|
||
model = self.extract_model_from_request(original_request_body, path_params)
|
||
api_format = self.allowed_api_formats[0]
|
||
sync_start_time = time.time()
|
||
|
||
provider_name = None
|
||
response_json = None
|
||
status_code = 200
|
||
response_headers = {}
|
||
provider_api_format = "" # 用于追踪 Provider 的 API 格式
|
||
provider_request_headers = {} # 发送给 Provider 的请求头
|
||
provider_request_body = None # 实际发送给 Provider 的请求体
|
||
provider_id = None # Provider ID(用于失败记录)
|
||
endpoint_id = None # Endpoint ID(用于失败记录)
|
||
key_id = None # Key ID(用于失败记录)
|
||
mapped_model_result = None # 映射后的目标模型名(用于 Usage 记录)
|
||
response_metadata_result: Dict[str, Any] = {} # Provider 响应元数据
|
||
|
||
async def sync_request_func(
|
||
provider: Provider,
|
||
endpoint: ProviderEndpoint,
|
||
key: ProviderAPIKey,
|
||
) -> Dict[str, Any]:
|
||
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
|
||
provider_name = str(provider.name)
|
||
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
|
||
|
||
# 获取模型映射(别名/映射 → 实际模型名)
|
||
mapped_model = await self._get_mapped_model(
|
||
source_model=model,
|
||
provider_id=str(provider.id),
|
||
)
|
||
|
||
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
|
||
if mapped_model:
|
||
mapped_model_result = mapped_model # 保存映射后的模型名,用于 Usage 记录
|
||
request_body = self.apply_mapped_model(original_request_body, mapped_model)
|
||
else:
|
||
request_body = original_request_body
|
||
|
||
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
|
||
request_body = self.prepare_provider_request_body(request_body)
|
||
|
||
# 使用 RequestBuilder 构建请求体和请求头
|
||
# 注意:mapped_model 已经应用到 request_body,这里不再传递
|
||
provider_payload, provider_headers = self._request_builder.build(
|
||
request_body,
|
||
original_headers,
|
||
endpoint,
|
||
key,
|
||
is_stream=False,
|
||
)
|
||
|
||
# 保存发送给 Provider 的请求信息(用于调试和统计)
|
||
provider_request_headers = provider_headers
|
||
provider_request_body = provider_payload
|
||
|
||
# 获取用于 URL 的模型名(子类可覆盖此方法,如 Gemini 需要特殊处理)
|
||
url_model = self.get_model_for_url(request_body, mapped_model)
|
||
|
||
url = build_provider_url(
|
||
endpoint,
|
||
query_params=query_params,
|
||
path_params={"model": url_model},
|
||
is_stream=False, # 非流式请求
|
||
)
|
||
|
||
logger.info(f" └─ [{self.request_id}] 发送非流式请求: "
|
||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||
f"Key=***{key.api_key[-4:]}, "
|
||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||
|
||
async with httpx.AsyncClient(
|
||
timeout=float(endpoint.timeout),
|
||
follow_redirects=True,
|
||
) as http_client:
|
||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||
|
||
status_code = resp.status_code
|
||
response_headers = dict(resp.headers)
|
||
|
||
if resp.status_code == 401:
|
||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
||
elif resp.status_code == 429:
|
||
raise ProviderRateLimitException(
|
||
f"提供商速率限制: {provider.name}",
|
||
provider_name=str(provider.name),
|
||
response_headers=response_headers,
|
||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||
)
|
||
elif resp.status_code >= 500:
|
||
raise ProviderNotAvailableException(
|
||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
|
||
)
|
||
elif 300 <= resp.status_code < 400:
|
||
redirect_url = resp.headers.get("location", "unknown")
|
||
raise ProviderNotAvailableException(
|
||
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
|
||
)
|
||
elif resp.status_code != 200:
|
||
error_text = resp.text
|
||
raise ProviderNotAvailableException(
|
||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
|
||
)
|
||
|
||
# 安全解析 JSON 响应,处理可能的编码错误
|
||
try:
|
||
response_json = resp.json()
|
||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||
# 记录原始响应信息用于调试
|
||
content_type = resp.headers.get("content-type", "unknown")
|
||
content_encoding = resp.headers.get("content-encoding", "none")
|
||
logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
||
f"响应长度: {len(resp.content)} bytes")
|
||
raise ProviderNotAvailableException(
|
||
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
|
||
)
|
||
|
||
# 提取 Provider 响应元数据(子类可覆盖)
|
||
response_metadata_result = self._extract_response_metadata(response_json)
|
||
|
||
return response_json if isinstance(response_json, dict) else {}
|
||
|
||
try:
|
||
# 解析能力需求
|
||
capability_requirements = self._resolve_capability_requirements(
|
||
model_name=model,
|
||
request_headers=original_headers,
|
||
)
|
||
|
||
(
|
||
result,
|
||
actual_provider_name,
|
||
attempt_id,
|
||
provider_id,
|
||
endpoint_id,
|
||
key_id,
|
||
) = await self.orchestrator.execute_with_fallback(
|
||
api_format=api_format,
|
||
model_name=model,
|
||
user_api_key=self.api_key,
|
||
request_func=sync_request_func,
|
||
request_id=self.request_id,
|
||
capability_requirements=capability_requirements or None,
|
||
)
|
||
|
||
provider_name = actual_provider_name
|
||
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
||
|
||
# 确保 response_json 不为 None
|
||
if response_json is None:
|
||
response_json = {}
|
||
|
||
# 检查是否需要格式转换
|
||
if (
|
||
provider_api_format
|
||
and api_format
|
||
and provider_api_format.upper() != api_format.upper()
|
||
):
|
||
from src.api.handlers.base.format_converter_registry import converter_registry
|
||
|
||
try:
|
||
response_json = converter_registry.convert_response(
|
||
response_json,
|
||
provider_api_format,
|
||
api_format,
|
||
)
|
||
logger.debug(f"非流式响应格式转换完成: {provider_api_format} -> {api_format}")
|
||
except Exception as conv_err:
|
||
logger.warning(f"非流式响应格式转换失败,使用原始响应: {conv_err}")
|
||
|
||
# 使用解析器提取 usage
|
||
usage = self.parser.extract_usage_from_response(response_json)
|
||
input_tokens = usage.get("input_tokens", 0)
|
||
output_tokens = usage.get("output_tokens", 0)
|
||
cached_tokens = usage.get("cache_read_tokens", 0)
|
||
cache_creation_tokens = usage.get("cache_creation_tokens", 0)
|
||
# Claude API 的 input_tokens 已经是非缓存部分,不需要再减去 cached_tokens
|
||
actual_input_tokens = input_tokens
|
||
|
||
output_text = self.parser.extract_text_content(response_json)[:200]
|
||
|
||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||
actual_request_body = provider_request_body or original_request_body
|
||
|
||
total_cost = await self.telemetry.record_success(
|
||
provider=provider_name,
|
||
model=model,
|
||
input_tokens=actual_input_tokens,
|
||
output_tokens=output_tokens,
|
||
response_time_ms=response_time_ms,
|
||
status_code=status_code,
|
||
request_headers=original_headers,
|
||
request_body=actual_request_body,
|
||
response_headers=response_headers,
|
||
response_body=response_json,
|
||
cache_creation_tokens=cache_creation_tokens,
|
||
cache_read_tokens=cached_tokens,
|
||
is_stream=False,
|
||
provider_request_headers=provider_request_headers,
|
||
api_format=api_format,
|
||
# Provider 侧追踪信息(用于记录真实成本)
|
||
provider_id=provider_id,
|
||
provider_endpoint_id=endpoint_id,
|
||
provider_api_key_id=key_id,
|
||
# 模型映射信息
|
||
target_model=mapped_model_result,
|
||
# Provider 响应元数据(如 Gemini 的 modelVersion)
|
||
response_metadata=response_metadata_result if response_metadata_result else None,
|
||
)
|
||
|
||
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
|
||
|
||
return JSONResponse(status_code=status_code, content=response_json)
|
||
|
||
except Exception as e:
|
||
response_time_ms = int((time.time() - sync_start_time) * 1000)
|
||
|
||
status_code = 503
|
||
if isinstance(e, ProviderAuthException):
|
||
status_code = 503
|
||
elif isinstance(e, ProviderRateLimitException):
|
||
status_code = 429
|
||
elif isinstance(e, ProviderTimeoutException):
|
||
status_code = 504
|
||
|
||
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
|
||
actual_request_body = provider_request_body or original_request_body
|
||
|
||
await self.telemetry.record_failure(
|
||
provider=provider_name or "unknown",
|
||
model=model,
|
||
response_time_ms=response_time_ms,
|
||
status_code=status_code,
|
||
error_message=str(e),
|
||
request_headers=original_headers,
|
||
request_body=actual_request_body,
|
||
is_stream=False,
|
||
api_format=api_format,
|
||
provider_request_headers=provider_request_headers,
|
||
# 模型映射信息
|
||
target_model=mapped_model_result,
|
||
)
|
||
|
||
raise
|
||
|
||
async def _extract_error_text(self, e: httpx.HTTPStatusError) -> str:
|
||
"""从 HTTP 错误中提取错误文本"""
|
||
try:
|
||
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
|
||
error_bytes = await e.response.aread()
|
||
|
||
for encoding in ["utf-8", "gbk", "latin1"]:
|
||
try:
|
||
return error_bytes.decode(encoding)[:500]
|
||
except (UnicodeDecodeError, LookupError):
|
||
continue
|
||
|
||
return error_bytes.decode("utf-8", errors="replace")[:500]
|
||
else:
|
||
return (
|
||
e.response.text[:500]
|
||
if hasattr(e.response, "_content")
|
||
else "Unable to read response"
|
||
)
|
||
except Exception as decode_error:
|
||
return f"Unable to read error response: {decode_error}"
|
||
|
||
def _needs_format_conversion(self, ctx: StreamContext) -> bool:
|
||
"""
|
||
检查是否需要进行格式转换
|
||
|
||
当 Provider 的 API 格式与客户端请求的 API 格式不同时,需要转换响应。
|
||
例如:客户端请求 Claude 格式,但 Provider 返回 OpenAI 格式。
|
||
"""
|
||
if not ctx.provider_api_format or not ctx.client_api_format:
|
||
return False
|
||
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
|
||
|
||
def _convert_sse_line(
|
||
self,
|
||
ctx: StreamContext,
|
||
line: str,
|
||
events: list,
|
||
) -> Optional[str]:
|
||
"""
|
||
将 SSE 行从 Provider 格式转换为客户端格式
|
||
|
||
Args:
|
||
ctx: 流上下文
|
||
line: 原始 SSE 行
|
||
events: 解析后的事件列表
|
||
|
||
Returns:
|
||
转换后的 SSE 行,如果无法转换则返回 None
|
||
"""
|
||
from src.api.handlers.base.format_converter_registry import converter_registry
|
||
|
||
# 如果是空行或特殊控制行,直接返回
|
||
if not line or line.strip() == "" or line == "data: [DONE]":
|
||
return line
|
||
|
||
# 如果不是 data 行,直接透传
|
||
if not line.startswith("data:"):
|
||
return line
|
||
|
||
# 提取 data 内容
|
||
data_content = line[5:].strip() # 去掉 "data:" 前缀
|
||
|
||
# 尝试解析 JSON
|
||
try:
|
||
data_obj = json.loads(data_content)
|
||
except json.JSONDecodeError:
|
||
# 无法解析,直接透传
|
||
return line
|
||
|
||
# 使用注册表进行格式转换
|
||
try:
|
||
converted_obj = converter_registry.convert_stream_chunk(
|
||
data_obj,
|
||
ctx.provider_api_format,
|
||
ctx.client_api_format,
|
||
)
|
||
# 重新构建 SSE 行
|
||
return f"data: {json.dumps(converted_obj, ensure_ascii=False)}"
|
||
except Exception as e:
|
||
logger.warning(f"格式转换失败,透传原始数据: {e}")
|
||
return line
|