"""
Chat Handler Base - shared base class for Chat API formats.

Provides the common handling logic for Chat API formats (Claude Chat, OpenAI Chat).

Difference from CliMessageHandlerBase:
- CLI mode: the request body is passed through and forwarded as-is
- Chat mode: the request may need format conversion (e.g. OpenAI -> Claude)

Both share the same interface:
- process_stream(): streaming requests
- process_sync(): non-streaming requests
- apply_mapped_model(): model mapping
- get_model_for_url(): model name used in the URL
- _extract_usage(): usage extraction

Refactoring notes:
- StreamContext: type-safe streaming context, replaces the former ctx dict
- StreamProcessor: streaming response handling (SSE parsing, prefetch, error detection)
- StreamTelemetryRecorder: statistics recording (Usage, Audit, Candidate)
"""

from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, Optional

import httpx
from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session

from src.api.handlers.base.base_handler import BaseMessageHandler
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.config.settings import config
from src.core.exceptions import (
    EmbeddedErrorException,
    ProviderAuthException,
    ProviderNotAvailableException,
    ProviderRateLimitException,
    ProviderTimeoutException,
)
from src.core.logger import logger
from src.models.database import (
    ApiKey,
    Provider,
    ProviderAPIKey,
    ProviderEndpoint,
    User,
)
from src.services.provider.transport import build_provider_url


class ChatHandlerBase(BaseMessageHandler, ABC):
    """
    Base class for Chat handlers.

    Main responsibilities:
    - Select Provider/Endpoint/Key via the FallbackOrchestrator
    - Send the request and handle the response
    - Record logs, audit entries and statistics
    - Error handling

    Subclasses must provide:
    - FORMAT_ID: API format identifier
    - _convert_request(): request format conversion
    - _extract_usage(): extract token usage from the response
    - _normalize_response(): response normalization (optional)

    Interface shared with CliMessageHandlerBase:
    - apply_mapped_model(): apply the mapped model to the request body
    - get_model_for_url(): obtain the model name used in the URL
    """

    FORMAT_ID: str = ""  # overridden by subclasses
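
    # For orientation, a minimal subclass might look roughly like the sketch
    # below. This is illustrative only -- the class name, format id and usage
    # field names are assumptions, not this project's actual handlers:
    #
    #     class OpenAIChatHandler(ChatHandlerBase):
    #         FORMAT_ID = "openai_chat"  # hypothetical format identifier
    #
    #         async def _convert_request(self, request: Any) -> Any:
    #             # Already in the target format, pass through unchanged.
    #             return request
    #
    #         def _extract_usage(self, response: Dict) -> Dict[str, int]:
    #             usage = response.get("usage") or {}
    #             return {
    #                 "input_tokens": usage.get("prompt_tokens", 0),
    #                 "output_tokens": usage.get("completion_tokens", 0),
    #                 "cache_creation_input_tokens": 0,
    #                 "cache_read_input_tokens": 0,
    #             }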

    def __init__(
        self,
        db: Session,
        user: User,
        api_key: ApiKey,
        request_id: str,
        client_ip: str,
        user_agent: str,
        start_time: float,
        allowed_api_formats: Optional[list] = None,
        response_normalizer: Optional[Any] = None,
        enable_response_normalization: bool = False,
        adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
    ):
        allowed = allowed_api_formats or [self.FORMAT_ID]
        super().__init__(
            db=db,
            user=user,
            api_key=api_key,
            request_id=request_id,
            client_ip=client_ip,
            user_agent=user_agent,
            start_time=start_time,
            allowed_api_formats=allowed,
            adapter_detector=adapter_detector,
        )
        self._parser: Optional[ResponseParser] = None
        self._request_builder = PassthroughRequestBuilder()
        self.response_normalizer = response_normalizer
        self.enable_response_normalization = enable_response_normalization

    @property
    def parser(self) -> ResponseParser:
        """Response parser (lazily loaded)."""
        if self._parser is None:
            self._parser = get_parser_for_format(self.FORMAT_ID)
        return self._parser

    # ==================== Abstract methods ====================

    @abstractmethod
    async def _convert_request(self, request: Any) -> Any:
        """
        Convert the request into the target format.

        Args:
            request: the original request object

        Returns:
            The converted request object.
        """
        pass

    @abstractmethod
    def _extract_usage(self, response: Dict) -> Dict[str, int]:
        """
        Extract token usage from the response.

        Args:
            response: response data

        Returns:
            Dict with keys: input_tokens, output_tokens,
            cache_creation_input_tokens, cache_read_input_tokens
        """
        pass

    def _normalize_response(self, response: Dict) -> Dict:
        """
        Normalize the response (optional override).

        Args:
            response: the raw response

        Returns:
            The normalized response.
        """
        return response

    # ==================== Unified interface methods ====================

    def extract_model_from_request(
        self,
        request_body: Dict[str, Any],
        path_params: Optional[Dict[str, Any]] = None,  # noqa: ARG002 - used by subclasses
    ) -> str:
        """
        Extract the model name from the request - subclasses may override.

        Different API formats carry the model in different places:
        - OpenAI/Claude: in the request body, request_body["model"]
        - Gemini: in the URL path, path_params["model"]

        Subclasses should override this method with their own extraction
        logic (see the commented sketch below).

        Args:
            request_body: request body
            path_params: URL path parameters

        Returns:
            The model name, or "unknown" if it cannot be extracted.
        """
        # Default implementation: read it from the request body
        model = request_body.get("model")
        return str(model) if model else "unknown"
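
    # A path-based format such as Gemini might override the default roughly as
    # sketched below (illustrative only; the real Gemini handler may differ):
    #
    #     def extract_model_from_request(self, request_body, path_params=None):
    #         if path_params and path_params.get("model"):
    #             return str(path_params["model"])
    #         return super().extract_model_from_request(request_body, path_params)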

    def apply_mapped_model(
        self,
        request_body: Dict[str, Any],
        mapped_model: str,  # noqa: ARG002 - used by subclasses
    ) -> Dict[str, Any]:
        """
        Apply the mapped model name to the request body.

        The base implementation leaves the request body untouched and passes
        it through as-is. Subclasses should override this method with their
        own model-name replacement logic (see the commented sketch below).

        Args:
            request_body: the original request body
            mapped_model: the mapped model name (used by subclasses)

        Returns:
            The request body (unchanged by default).
        """
        # The base class does not modify the request body; subclasses override
        # this method for format-specific handling.
        return request_body
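
    # For body-based formats, a subclass would typically just replace the
    # model field (illustrative sketch, not this project's actual handler):
    #
    #     def apply_mapped_model(self, request_body, mapped_model):
    #         body = dict(request_body)
    #         body["model"] = mapped_model
    #         return body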

    def get_model_for_url(
        self,
        request_body: Dict[str, Any],
        mapped_model: Optional[str],
    ) -> Optional[str]:
        """
        Get the model name used in the URL path.

        Some API formats (such as Gemini) require the model to be placed in
        the URL path. Subclasses should override this method to return the
        correct value.

        Args:
            request_body: request body
            mapped_model: the mapped model name (if any)

        Returns:
            The model name to use in the URL path.
        """
        return mapped_model or request_body.get("model")

    def prepare_provider_request_body(
        self,
        request_body: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Prepare the request body to send to the provider - subclasses may override.

        Called after model mapping and before the request is sent, to strip
        fields that should not be forwarded upstream. For example, the Gemini
        API requires the model field to be removed from the body, because the
        model already appears in the URL path (see the commented sketch below).

        Args:
            request_body: the request body after model mapping

        Returns:
            The prepared request body.
        """
        return request_body
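
    # A Gemini-like subclass might drop the model field from the body since it
    # is already part of the URL (illustrative sketch only):
    #
    #     def prepare_provider_request_body(self, request_body):
    #         body = dict(request_body)
    #         body.pop("model", None)
    #         return body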

    async def _get_mapped_model(
        self,
        source_model: str,
        provider_id: str,
    ) -> Optional[str]:
        """
        Get the actual model name after model mapping.

        Args:
            source_model: the model name requested by the user
            provider_id: Provider ID

        Returns:
            The mapped provider_model_name, or None if there is no mapping.
        """
        from src.services.model.mapper import ModelMapperMiddleware

        mapper = ModelMapperMiddleware(self.db)
        mapping = await mapper.get_mapping(source_model, provider_id)

        if mapping and mapping.model:
            # Use select_provider_model_name to support the alias feature.
            # Pass api_key.id as the affinity_key so the same user keeps
            # selecting the same alias.
            affinity_key = self.api_key.id if self.api_key else None
            mapped_name = mapping.model.select_provider_model_name(affinity_key)
            logger.debug(f"[Chat] model mapping: {source_model} -> {mapped_name}")
            return mapped_name

        return None

    # ==================== Streaming ====================

    async def process_stream(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, Any],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> StreamingResponse:
        """Handle a streaming response."""
        logger.debug(f"Starting streaming response handling ({self.FORMAT_ID})")

        # Convert the request format
        converted_request = await self._convert_request(request)
        model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
        api_format = self.allowed_api_formats[0]

        # Create the type-safe streaming context
        ctx = StreamContext(model=model, api_format=api_format)

        # Create the stream processor
        stream_processor = StreamProcessor(
            request_id=self.request_id,
            default_parser=self.parser,
            on_streaming_start=self._update_usage_to_streaming,
        )

        # Define the request function
        async def stream_request_func(
            provider: Provider,
            endpoint: ProviderEndpoint,
            key: ProviderAPIKey,
        ) -> AsyncGenerator[bytes, None]:
            return await self._execute_stream_request(
                ctx,
                stream_processor,
                provider,
                endpoint,
                key,
                original_request_body,
                original_headers,
                query_params,
            )

        try:
            # Resolve capability requirements
            capability_requirements = self._resolve_capability_requirements(
                model_name=model,
                request_headers=original_headers,
            )

            # Execute the request (via the FallbackOrchestrator)
            (
                stream_generator,
                provider_name,
                attempt_id,
                provider_id,
                endpoint_id,
                key_id,
            ) = await self.orchestrator.execute_with_fallback(
                api_format=api_format,
                model_name=model,
                user_api_key=self.api_key,
                request_func=stream_request_func,
                request_id=self.request_id,
                is_stream=True,
                capability_requirements=capability_requirements or None,
            )

            # Update the context
            ctx.attempt_id = attempt_id
            ctx.provider_name = provider_name
            ctx.provider_id = provider_id
            ctx.endpoint_id = endpoint_id
            ctx.key_id = key_id

            # Create the telemetry recorder
            telemetry_recorder = StreamTelemetryRecorder(
                request_id=self.request_id,
                user_id=str(self.user.id),
                api_key_id=str(self.api_key.id),
                client_ip=self.client_ip,
                format_id=self.FORMAT_ID,
            )

            # Record statistics in a background task
            background_tasks = BackgroundTasks()
            background_tasks.add_task(
                telemetry_recorder.record_stream_stats,
                ctx,
                original_headers,
                original_request_body,
                self.elapsed_ms(),
            )

            # Create the monitored stream
            monitored_stream = stream_processor.create_monitored_stream(
                ctx,
                stream_generator,
                http_request.is_disconnected,
            )

            return StreamingResponse(
                monitored_stream,
                media_type="text/event-stream",
                background=background_tasks,
            )

        except Exception as e:
            self._log_request_error("Streaming request failed", e)
            await self._record_stream_failure(ctx, e, original_headers, original_request_body)
            raise

    async def _execute_stream_request(
        self,
        ctx: StreamContext,
        stream_processor: StreamProcessor,
        provider: Provider,
        endpoint: ProviderEndpoint,
        key: ProviderAPIKey,
        original_request_body: Dict[str, Any],
        original_headers: Dict[str, str],
        query_params: Optional[Dict[str, str]] = None,
    ) -> AsyncGenerator[bytes, None]:
        """Execute the streaming request and return the stream generator."""
        # Reset the context state (clears data from a previous attempt on retry)
        ctx.reset_for_retry()

        # Update the provider info
        ctx.update_provider_info(
            provider_name=str(provider.name),
            provider_id=str(provider.id),
            endpoint_id=str(endpoint.id),
            key_id=str(key.id),
            provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
        )

        # Look up the model mapping
        mapped_model = await self._get_mapped_model(
            source_model=ctx.model,
            provider_id=str(provider.id),
        )

        # Apply the model mapping to the request body
        if mapped_model:
            ctx.mapped_model = mapped_model
            request_body = self.apply_mapped_model(original_request_body, mapped_model)
        else:
            request_body = dict(original_request_body)

        # Prepare the request body to send to the provider
        request_body = self.prepare_provider_request_body(request_body)

        # Build the request
        provider_payload, provider_headers = self._request_builder.build(
            request_body,
            original_headers,
            endpoint,
            key,
            is_stream=True,
        )

        ctx.provider_request_headers = provider_headers
        ctx.provider_request_body = provider_payload

        # Get the model name for the URL
        url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model

        url = build_provider_url(
            endpoint,
            query_params=query_params,
            path_params={"model": url_model},
            is_stream=True,
        )

        logger.debug(
            f" [{self.request_id}] Sending streaming request: Provider={provider.name}, "
            f"model={ctx.model} -> {mapped_model or 'no mapping'}"
        )

        # Send the request (using the timeouts from configuration)
        timeout_config = httpx.Timeout(
            connect=config.http_connect_timeout,
            read=float(endpoint.timeout),
            write=config.http_write_timeout,
            pool=config.http_pool_timeout,
        )

        http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
        try:
            response_ctx = http_client.stream(
                "POST", url, json=provider_payload, headers=provider_headers
            )
            stream_response = await response_ctx.__aenter__()

            ctx.status_code = stream_response.status_code
            ctx.response_headers = dict(stream_response.headers)

            stream_response.raise_for_status()

            # Create the line iterator
            line_iterator = stream_response.aiter_lines()

            # Prefetch to detect embedded errors
            prefetched_lines = await stream_processor.prefetch_and_check_error(
                line_iterator,
                provider,
                endpoint,
                ctx,
                max_prefetch_lines=config.stream_prefetch_lines,
            )

        except httpx.HTTPStatusError as e:
            error_text = await self._extract_error_text(e)
            logger.error(f"Provider returned an error: {e.response.status_code}\n Response: {error_text}")
            await http_client.aclose()
            raise

        except EmbeddedErrorException:
            try:
                await response_ctx.__aexit__(None, None, None)
            except Exception:
                pass
            await http_client.aclose()
            raise

        except Exception:
            await http_client.aclose()
            raise

        # Create the stream generator
        return stream_processor.create_response_stream(
            ctx,
            line_iterator,
            response_ctx,
            http_client,
            prefetched_lines,
        )

    async def _record_stream_failure(
        self,
        ctx: StreamContext,
        error: Exception,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
    ) -> None:
        """Record a failed streaming request."""
        response_time_ms = self.elapsed_ms()

        status_code = 503
        if isinstance(error, ProviderAuthException):
            status_code = 503
        elif isinstance(error, ProviderRateLimitException):
            status_code = 429
        elif isinstance(error, ProviderTimeoutException):
            status_code = 504

        actual_request_body = ctx.provider_request_body or original_request_body

        await self.telemetry.record_failure(
            provider=ctx.provider_name or "unknown",
            model=ctx.model,
            response_time_ms=response_time_ms,
            status_code=status_code,
            error_message=str(error),
            request_headers=original_headers,
            request_body=actual_request_body,
            is_stream=True,
            api_format=ctx.api_format,
            provider_request_headers=ctx.provider_request_headers,
            target_model=ctx.mapped_model,
        )

    # ==================== Non-streaming ====================

    async def process_sync(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, Any],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> JSONResponse:
        """Handle a non-streaming response."""
        logger.debug(f"Starting non-streaming response handling ({self.FORMAT_ID})")

        # Convert the request format
        converted_request = await self._convert_request(request)
        model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
        api_format = self.allowed_api_formats[0]

        # Tracking variables
        provider_name: Optional[str] = None
        response_json: Optional[Dict[str, Any]] = None
        status_code = 200
        response_headers: Dict[str, str] = {}
        provider_request_headers: Dict[str, str] = {}
        provider_request_body: Optional[Dict[str, Any]] = None
        provider_id: Optional[str] = None  # Provider ID (for failure records)
        endpoint_id: Optional[str] = None  # Endpoint ID (for failure records)
        key_id: Optional[str] = None  # Key ID (for failure records)
        mapped_model_result: Optional[str] = None  # Mapped target model name (for Usage records)

        async def sync_request_func(
            provider: Provider,
            endpoint: ProviderEndpoint,
            key: ProviderAPIKey,
        ) -> Dict[str, Any]:
            nonlocal provider_name, response_json, status_code, response_headers
            nonlocal provider_request_headers, provider_request_body, mapped_model_result

            provider_name = str(provider.name)

            # Look up the model mapping
            mapped_model = await self._get_mapped_model(
                source_model=model,
                provider_id=str(provider.id),
            )

            # Apply the model mapping
            if mapped_model:
                mapped_model_result = mapped_model  # keep the mapped name for Usage records
                request_body = self.apply_mapped_model(original_request_body, mapped_model)
            else:
                request_body = dict(original_request_body)

            # Prepare the request body for the provider (subclasses may override to drop fields)
            request_body = self.prepare_provider_request_body(request_body)

            # Build the request
            provider_payload, provider_hdrs = self._request_builder.build(
                request_body,
                original_headers,
                endpoint,
                key,
                is_stream=False,
            )

            provider_request_headers = provider_hdrs
            provider_request_body = provider_payload

            # Get the model name for the URL (fall back to the outer model so
            # formats such as Gemini can still build a correct URL)
            url_model = self.get_model_for_url(request_body, mapped_model) or model

            url = build_provider_url(
                endpoint,
                query_params=query_params,
                path_params={"model": url_model},
                is_stream=False,
            )

            logger.info(
                f" [{self.request_id}] Sending non-streaming request: Provider={provider.name}, "
                f"model={model} -> {mapped_model or 'no mapping'}"
            )

            async with httpx.AsyncClient(
                timeout=float(endpoint.timeout),
                follow_redirects=True,
            ) as http_client:
                resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)

                status_code = resp.status_code
                response_headers = dict(resp.headers)

                if resp.status_code == 401:
                    raise ProviderAuthException(f"Provider authentication failed: {provider.name}")
                elif resp.status_code == 429:
                    raise ProviderRateLimitException(
                        f"Provider rate limited: {provider.name}",
                        provider_name=str(provider.name),
                        response_headers=response_headers,
                    )
                elif resp.status_code >= 500:
                    raise ProviderNotAvailableException(f"Provider unavailable: {provider.name}")
                elif resp.status_code != 200:
                    raise ProviderNotAvailableException(
                        f"Provider returned an error: {provider.name}, status: {resp.status_code}"
                    )

                response_json = resp.json()
                return response_json if isinstance(response_json, dict) else {}

        try:
            # Resolve capability requirements
            capability_requirements = self._resolve_capability_requirements(
                model_name=model,
                request_headers=original_headers,
            )

            (
                result,
                actual_provider_name,
                attempt_id,
                provider_id,
                endpoint_id,
                key_id,
            ) = await self.orchestrator.execute_with_fallback(
                api_format=api_format,
                model_name=model,
                user_api_key=self.api_key,
                request_func=sync_request_func,
                request_id=self.request_id,
                capability_requirements=capability_requirements or None,
            )

            provider_name = actual_provider_name
            response_time_ms = self.elapsed_ms()

            # Make sure response_json is not None
            if response_json is None:
                response_json = {}

            # Normalize the response
            response_json = self._normalize_response(response_json)

            # Extract usage
            usage_info = self._extract_usage(response_json)
            input_tokens = usage_info.get("input_tokens", 0)
            output_tokens = usage_info.get("output_tokens", 0)
            cache_creation_tokens = usage_info.get("cache_creation_input_tokens", 0)
            cached_tokens = usage_info.get("cache_read_input_tokens", 0)

            actual_request_body = provider_request_body or original_request_body

            total_cost = await self.telemetry.record_success(
                provider=provider_name,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                response_time_ms=response_time_ms,
                status_code=status_code,
                request_headers=original_headers,
                request_body=actual_request_body,
                response_headers=response_headers,
                response_body=response_json,
                cache_creation_tokens=cache_creation_tokens,
                cache_read_tokens=cached_tokens,
                is_stream=False,
                provider_request_headers=provider_request_headers,
                api_format=api_format,
                provider_id=provider_id,
                provider_endpoint_id=endpoint_id,
                provider_api_key_id=key_id,
                # Model mapping info
                target_model=mapped_model_result,
            )

            logger.debug(f"{self.FORMAT_ID} non-streaming response completed")

            # Concise request completion summary
            logger.info(
                f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | "
                f"in:{input_tokens or 0} out:{output_tokens or 0}"
            )

            return JSONResponse(status_code=status_code, content=response_json)

        except Exception as e:
            response_time_ms = self.elapsed_ms()

            status_code = 503
            if isinstance(e, ProviderAuthException):
                status_code = 503
            elif isinstance(e, ProviderRateLimitException):
                status_code = 429
            elif isinstance(e, ProviderTimeoutException):
                status_code = 504

            actual_request_body = provider_request_body or original_request_body

            await self.telemetry.record_failure(
                provider=provider_name or "unknown",
                model=model,
                response_time_ms=response_time_ms,
                status_code=status_code,
                error_message=str(e),
                request_headers=original_headers,
                request_body=actual_request_body,
                is_stream=False,
                api_format=api_format,
                provider_request_headers=provider_request_headers,
                # Model mapping info
                target_model=mapped_model_result,
            )

            raise

    async def _extract_error_text(self, e: httpx.HTTPStatusError) -> str:
        """Extract the error text from an HTTP error."""
        try:
            if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
                error_bytes = await e.response.aread()
                return error_bytes.decode("utf-8", errors="replace")[:500]
            else:
                return (
                    e.response.text[:500] if hasattr(e.response, "_content") else "Unable to read"
                )
        except Exception as decode_error:
            return f"Unable to read error: {decode_error}"