Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,68 @@
"""
Handler 基类模块
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
注意Handler 基类ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
因为它们依赖 services.usage.stream而后者又需要导入 response_parser
会形成循环导入。请直接从具体模块导入 Handler 基类。
"""
# Chat Adapter 基类(不会引起循环导入)
from src.api.handlers.base.chat_adapter_base import (
ChatAdapterBase,
get_adapter_class,
get_adapter_instance,
list_registered_formats,
register_adapter,
)
# CLI Adapter 基类
from src.api.handlers.base.cli_adapter_base import (
CliAdapterBase,
get_cli_adapter_class,
get_cli_adapter_instance,
list_registered_cli_formats,
register_cli_adapter,
)
# 请求构建器
from src.api.handlers.base.request_builder import (
SENSITIVE_HEADERS,
PassthroughRequestBuilder,
RequestBuilder,
build_passthrough_request,
)
# 响应解析器
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
__all__ = [
# Chat Adapter
"ChatAdapterBase",
"register_adapter",
"get_adapter_class",
"get_adapter_instance",
"list_registered_formats",
# CLI Adapter
"CliAdapterBase",
"register_cli_adapter",
"get_cli_adapter_class",
"get_cli_adapter_instance",
"list_registered_cli_formats",
# 请求构建器
"RequestBuilder",
"PassthroughRequestBuilder",
"build_passthrough_request",
"SENSITIVE_HEADERS",
# 响应解析器
"ResponseParser",
"ParsedChunk",
"ParsedResponse",
"StreamStats",
]

View File

@@ -0,0 +1,363 @@
"""
基础消息处理器,封装通用的编排、转换、遥测逻辑。
接口约定:
- process_stream: 处理流式请求,返回 StreamingResponse
- process_sync: 处理非流式请求,返回 JSONResponse
签名规范(推荐):
async def process_stream(
self,
request: Any, # 解析后的请求模型
http_request: Request, # FastAPI Request 对象
original_headers: Dict[str, str], # 原始请求头
original_request_body: Dict[str, Any], # 原始请求体
query_params: Optional[Dict[str, str]] = None, # 查询参数
) -> StreamingResponse: ...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse: ...
"""
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.clients.redis_client import get_redis_client_sync
from src.core.api_format_metadata import resolve_api_format
from src.core.enums import APIFormat
from src.core.logger import logger
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
class MessageTelemetry:
    """
    Records Usage/Audit entries so individual handlers do not repeat the
    same telemetry boilerplate.
    """

    def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
        # DB session plus request identity shared by every record_* call below.
        self.db = db
        self.user = user
        self.api_key = api_key
        self.request_id = request_id
        self.client_ip = client_ip

    async def calculate_cost(
        self,
        provider: str,
        model: str,
        *,
        input_tokens: int,
        output_tokens: int,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
    ) -> float:
        """
        Compute the total cost for one request's token usage.

        Prices are looked up per provider/model via UsageService; cache
        prices are appended as positional args via unpacking.

        Returns:
            Total cost (same unit as UsageService prices — presumably USD;
            confirm against UsageService).
        """
        input_price, output_price = await UsageService.get_model_price_async(
            self.db, provider, model
        )
        # UsageService.calculate_cost returns a 7-tuple; only the final
        # total_cost element is needed here.
        _, _, _, _, _, _, total_cost = UsageService.calculate_cost(
            input_tokens,
            output_tokens,
            input_price,
            output_price,
            cache_creation_tokens,
            cache_read_tokens,
            *await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
        )
        return total_cost

    async def record_success(
        self,
        *,
        provider: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        response_time_ms: int,
        status_code: int,
        request_body: Dict[str, Any],
        request_headers: Dict[str, Any],
        response_body: Any,
        response_headers: Dict[str, Any],
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        is_stream: bool = False,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        # Provider-side trace info (used to attribute real upstream cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        api_format: Optional[str] = None,
        # Model-mapping info
        target_model: Optional[str] = None,
        # Provider response metadata (e.g. Gemini's modelVersion)
        response_metadata: Optional[Dict[str, Any]] = None,
    ) -> float:
        """
        Record a successful request: one usage row plus (when a user and
        api_key are present) one audit-log entry.

        Returns:
            The computed total cost for the request.
        """
        total_cost = await self.calculate_cost(
            provider,
            model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_tokens=cache_creation_tokens,
            cache_read_tokens=cache_read_tokens,
        )
        await UsageService.record_usage(
            db=self.db,
            user=self.user,
            api_key=self.api_key,
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_tokens,
            cache_read_input_tokens=cache_read_tokens,
            request_type="chat",
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            status_code=status_code,
            request_headers=request_headers,
            request_body=request_body,
            provider_request_headers=provider_request_headers or {},
            response_headers=response_headers,
            response_body=response_body,
            request_id=self.request_id,
            # Provider-side trace info (used to attribute real upstream cost)
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
            # Model-mapping info
            target_model=target_model,
            # Provider response metadata
            metadata=response_metadata,
        )
        if self.user and self.api_key:
            # NOTE(review): log_api_request is called without await — it
            # appears to be synchronous; confirm against audit_service.
            audit_service.log_api_request(
                db=self.db,
                user_id=self.user.id,
                api_key_id=self.api_key.id,
                request_id=self.request_id,
                model=model,
                provider=provider,
                success=True,
                ip_address=self.client_ip,
                status_code=status_code,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cost_usd=total_cost,
            )
        return total_cost

    async def record_failure(
        self,
        *,
        provider: str,
        model: str,
        response_time_ms: int,
        status_code: int,
        error_message: str,
        request_body: Dict[str, Any],
        request_headers: Dict[str, Any],
        is_stream: bool,
        api_format: Optional[str] = None,
        provider_request_headers: Optional[Dict[str, Any]] = None,
        # Estimated token counts (from the message_start event; used for
        # cost estimation of interrupted requests)
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        response_body: Optional[Dict[str, Any]] = None,
        # Model-mapping info
        target_model: Optional[str] = None,
    ):
        """
        Record a failed request.

        NOTE: provider-chain info (provider_id, endpoint_id, key_id) is NOT
        recorded here, because the RequestCandidate table already stores the
        full per-request routing trace.

        Args:
            input_tokens: estimated input tokens (from message_start; used
                for cost estimation of interrupted requests)
            output_tokens: estimated output tokens (from content received so far)
            cache_creation_tokens: cache-creation tokens
            cache_read_tokens: cache-read tokens
            response_body: partial response body, if any
            target_model: mapped target model name (when a mapping occurred)
        """
        provider_name = provider or "unknown"
        if provider_name == "unknown":
            logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
        await UsageService.record_usage(
            db=self.db,
            user=self.user,
            api_key=self.api_key,
            provider=provider_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_tokens,
            cache_read_input_tokens=cache_read_tokens,
            request_type="chat",
            api_format=api_format,
            is_stream=is_stream,
            response_time_ms=response_time_ms,
            status_code=status_code,
            error_message=error_message,
            request_headers=request_headers,
            request_body=request_body,
            provider_request_headers=provider_request_headers or {},
            response_headers={},
            response_body=response_body or {"error": error_message},
            request_id=self.request_id,
            # Model-mapping info
            target_model=target_model,
        )
@runtime_checkable
class MessageHandlerProtocol(Protocol):
    """
    Message-handler protocol defining the standard interface.

    ChatHandlerBase uses the full signature (including request, http_request);
    CliMessageHandlerBase uses a simplified one (only original_request_body
    and original_headers).
    """

    async def process_stream(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> StreamingResponse:
        """Handle a streaming request."""
        ...

    async def process_sync(
        self,
        request: Any,
        http_request: Request,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        query_params: Optional[Dict[str, str]] = None,
    ) -> JSONResponse:
        """Handle a non-streaming request."""
        ...
class BaseMessageHandler:
    """
    Common base class that concrete per-format handlers inherit from.

    Subclasses must implement:
    - process_stream: handle streaming requests
    - process_sync: handle non-streaming requests
    The signatures declared on MessageHandlerProtocol are the recommended shape.
    """

    # Shape of an optional adapter-level capability detector callback.
    AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]

    def __init__(
        self,
        *,
        db: Session,
        user,
        api_key,
        request_id: str,
        client_ip: str,
        user_agent: str,
        start_time: float,
        allowed_api_formats: Optional[list[str]] = None,
        adapter_detector: Optional[AdapterDetectorType] = None,
    ):
        # Request identity / context.
        self.db = db
        self.user = user
        self.api_key = api_key
        self.request_id = request_id
        self.client_ip = client_ip
        self.user_agent = user_agent
        self.start_time = start_time
        # Default to the Claude format when no explicit list is supplied.
        formats = allowed_api_formats or [APIFormat.CLAUDE.value]
        self.allowed_api_formats = formats
        self.primary_api_format = normalize_api_format(formats[0])
        self.adapter_detector = adapter_detector
        # Collaborators: fallback orchestration and telemetry recording.
        self.orchestrator = FallbackOrchestrator(db, get_redis_client_sync())
        self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)

    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since this handler's start_time."""
        elapsed_seconds = time.time() - self.start_time
        return int(elapsed_seconds * 1000)

    def _resolve_capability_requirements(
        self,
        model_name: str,
        request_headers: Optional[Dict[str, str]] = None,
        request_body: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, bool]:
        """
        Resolve the capability requirements of a request.

        Sources (per CapabilityResolver):
        1. Per-model user settings (User.model_capability_settings)
        2. Forced settings on the user's API key (ApiKey.force_capabilities)
        3. The X-Require-Capability request header
        4. The adapter's detect_capability_requirements (e.g. Claude's
           anthropic-beta header)

        Args:
            model_name: model name
            request_headers: request headers
            request_body: request body (optional)

        Returns:
            Mapping of capability name to required flag.
        """
        # Imported locally to avoid a module-level import cycle.
        from src.services.capability.resolver import CapabilityResolver

        return CapabilityResolver.resolve_requirements(
            user=self.user,
            user_api_key=self.api_key,
            model_name=model_name,
            request_headers=request_headers,
            request_body=request_body,
            adapter_detector=self.adapter_detector,
        )

    def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
        """Resolve the API format for provider_type; unknown types fall back to OPENAI."""
        if not provider_type:
            return self.primary_api_format
        return resolve_api_format(provider_type, default=APIFormat.OPENAI)

    def build_provider_payload(
        self,
        original_body: Dict[str, Any],
        *,
        mapped_model: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Shallow-copy the client body, substituting the model name when mapped."""
        payload = {**original_body}
        if mapped_model:
            payload["model"] = mapped_model
        return payload

View File

@@ -0,0 +1,724 @@
"""
Chat Adapter 通用基类
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
- _validate_request_body(): 可选覆盖请求验证逻辑
- _build_audit_metadata(): 可选覆盖审计元数据构建
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
"""
import time
import traceback
from abc import abstractmethod
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class ChatAdapterBase(ApiAdapter):
    """
    Common base class for Chat adapters.

    Provides the generic adapter logic for Chat-style formats; subclasses
    only need to configure:
    - FORMAT_ID: API format identifier
    - HANDLER_CLASS: a ChatHandlerBase subclass
    - name: adapter name
    """

    # Subclasses must override these.
    FORMAT_ID: str = "UNKNOWN"
    HANDLER_CLASS: Type[ChatHandlerBase]

    # Adapter configuration.
    name: str = "chat.base"
    mode = ApiMode.STANDARD

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        # Default to this adapter's own format when no list is supplied.
        self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
        self.response_normalizer = None
        # Response normalization is optional and enabled when available.
        self._init_response_normalizer()

    def _init_response_normalizer(self):
        """Initialize the response normalizer - subclasses may override."""
        try:
            from src.services.provider.response_normalizer import ResponseNormalizer
            self.response_normalizer = ResponseNormalizer()
        except ImportError:
            # Normalizer module absent: leave normalization disabled.
            pass

    async def handle(self, context: ApiRequestContext):
        """
        Handle one Chat API request end-to-end: validate, dispatch to the
        handler, and convert known exception families into error responses.
        """
        http_request = context.request
        user = context.user
        api_key = context.api_key
        db = context.db
        request_id = context.request_id
        quota_remaining_value = context.quota_remaining
        start_time = context.start_time
        client_ip = context.client_ip
        user_agent = context.user_agent
        original_headers = context.original_headers
        query_params = context.query_params
        original_request_body = context.ensure_json_body()
        # Merge path_params into the body (e.g. Gemini carries the model in
        # the URL path).
        if context.path_params:
            original_request_body = self._merge_path_params(
                original_request_body, context.path_params
            )
        # Validate and parse the request; a JSONResponse means validation
        # failed and is returned to the client verbatim.
        request_obj = self._validate_request_body(original_request_body, context.path_params)
        if isinstance(request_obj, JSONResponse):
            return request_obj
        stream = getattr(request_obj, "stream", False)
        model = getattr(request_obj, "model", "unknown")
        # Attach audit metadata to the request context.
        audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
        context.add_audit_metadata(**audit_metadata)
        # Format the remaining-quota display for the log line.
        quota_display = (
            "unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
        )
        # Request-start log line.
        logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
            f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
        try:
            # Bail out early if the client has already disconnected.
            if await http_request.is_disconnected():
                logger.warning("客户端连接断开")
                raise HTTPException(status_code=499, detail="Client disconnected")
            # Create the handler.
            handler = self._create_handler(
                db=db,
                user=user,
                api_key=api_key,
                request_id=request_id,
                client_ip=client_ip,
                user_agent=user_agent,
                start_time=start_time,
            )
            # Dispatch by streaming mode.
            if stream:
                return await handler.process_stream(
                    request=request_obj,
                    http_request=http_request,
                    original_headers=original_headers,
                    original_request_body=original_request_body,
                    query_params=query_params,
                )
            return await handler.process_sync(
                request=request_obj,
                http_request=http_request,
                original_headers=original_headers,
                original_request_body=original_request_body,
                query_params=query_params,
            )
        except HTTPException:
            # Already an HTTP-shaped error; let FastAPI render it.
            raise
        except (
            ModelNotSupportedException,
            QuotaExceededException,
            InvalidRequestException,
        ) as e:
            # Client-side errors: no usage is recorded for these.
            logger.info(f"客户端请求错误: {e.error_type}")
            return self._error_response(
                status_code=e.status_code,
                error_type=(
                    "invalid_request_error" if e.status_code == 400 else "quota_exceeded"
                ),
                message=e.message,
            )
        except (
            ProviderAuthException,
            ProviderRateLimitException,
            ProviderNotAvailableException,
            ProviderTimeoutException,
            UpstreamClientException,
        ) as e:
            return await self._handle_provider_exception(
                e,
                db=db,
                user=user,
                api_key=api_key,
                model=model,
                stream=stream,
                start_time=start_time,
                original_headers=original_headers,
                original_request_body=original_request_body,
                client_ip=client_ip,
                request_id=request_id,
            )
        except Exception as e:
            return await self._handle_unexpected_exception(
                e,
                db=db,
                user=user,
                api_key=api_key,
                model=model,
                stream=stream,
                start_time=start_time,
                original_headers=original_headers,
                original_request_body=original_request_body,
                client_ip=client_ip,
                request_id=request_id,
            )

    def _create_handler(
        self,
        *,
        db,
        user,
        api_key,
        request_id: str,
        client_ip: str,
        user_agent: str,
        start_time: float,
    ):
        """Create the handler instance - subclasses may override."""
        return self.HANDLER_CLASS(
            db=db,
            user=user,
            api_key=api_key,
            request_id=request_id,
            client_ip=client_ip,
            user_agent=user_agent,
            start_time=start_time,
            allowed_api_formats=self.allowed_api_formats,
            response_normalizer=self.response_normalizer,
            enable_response_normalization=self.response_normalizer is not None,
            adapter_detector=self.detect_capability_requirements,
        )

    def _merge_path_params(
        self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Merge URL path parameters into the request body - subclasses may override.

        Default implementation: copy path_params fields into the body without
        overwriting fields that are already present.

        Args:
            original_request_body: original request-body dict
            path_params: URL path-parameter dict

        Returns:
            The merged request-body dict (the original is not mutated).
        """
        merged = original_request_body.copy()
        for key, value in path_params.items():
            if key not in merged:
                merged[key] = value
        return merged

    @abstractmethod
    def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
        """
        Validate the request body - subclasses must implement.

        Args:
            original_request_body: original request-body dict
            path_params: URL path parameters (e.g. Gemini passes `stream`
                via the URL endpoint)

        Returns:
            The validated request object, or a JSONResponse error response.
        """
        pass

    def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
        """
        Extract the message count - subclasses may override.

        Default implementation: read the `messages` field, preferring the
        parsed request object's attribute over the raw payload.
        """
        messages = payload.get("messages", [])
        if hasattr(request_obj, "messages"):
            messages = request_obj.messages
        return len(messages) if isinstance(messages, list) else 0

    def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
        """
        Build audit-log metadata - subclasses may override.

        Values are read from the parsed request object first, falling back
        to the raw payload.
        """
        model = getattr(request_obj, "model", payload.get("model", "unknown"))
        stream = getattr(request_obj, "stream", payload.get("stream", False))
        messages_count = self._extract_message_count(payload, request_obj)
        return {
            "action": f"{self.FORMAT_ID.lower()}_request",
            "model": model,
            "stream": bool(stream),
            "max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
            "messages_count": messages_count,
            "temperature": getattr(request_obj, "temperature", payload.get("temperature")),
            "top_p": getattr(request_obj, "top_p", payload.get("top_p")),
        }

    async def _handle_provider_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """Handle provider-related exceptions: record the failure and build an error response."""
        logger.debug(f"Caught provider exception: {type(e).__name__}")
        response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # Key point: api_format comes from FORMAT_ID so it always has a value.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # use the adapter's FORMAT_ID as the default
            model=model,
            response_time_ms=response_time,
            is_stream=stream,
        )
        result.request_headers = original_headers
        result.request_body = original_request_body
        # Determine the error message.
        if isinstance(e, ProviderAuthException):
            error_message = (
                f"提供商认证失败: {str(e)}"
                if result.metadata.provider != "unknown"
                else "服务端错误: 无可用提供商"
            )
            result.error_message = error_message
        # Upstream client errors (e.g. image processing failed):
        # surface the upstream status code and a clear message.
        if isinstance(e, UpstreamClientException):
            result.status_code = e.status_code
            result.error_message = e.message
        # Record the failure through UsageRecorder.
        recorder = UsageRecorder(
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
        )
        await recorder.record_failure(result, original_headers, original_request_body)
        # Map the exception type to a client-facing error type.
        # NOTE(review): any status other than 503 that is not an
        # UpstreamClientException (including auth failures and timeouts)
        # is reported as "rate_limit_exceeded" — confirm this is intended.
        if isinstance(e, UpstreamClientException):
            error_type = "invalid_request_error"
        elif result.status_code == 503:
            error_type = "internal_server_error"
        else:
            error_type = "rate_limit_exceeded"
        return self._error_response(
            status_code=result.status_code,
            error_type=error_type,
            message=result.error_message or str(e),
        )

    async def _handle_unexpected_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """Handle unexpected exceptions: log, record the failure, return a 500 response."""
        if isinstance(e, ProxyException):
            # Known business exception: log name only, no traceback.
            logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
        else:
            logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
                exception=e,
                extra_data={
                    "exception_class": e.__class__.__name__,
                    "processing_stage": "request_processing",
                    "model": model,
                    "stream": stream,
                    "traceback_preview": str(traceback.format_exc())[:500],
                },
            )
        response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # Key point: api_format comes from FORMAT_ID so it always has a value.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # use the adapter's FORMAT_ID as the default
            model=model,
            response_time_ms=response_time,
            is_stream=stream,
        )
        # For unexpected exceptions, force a 500 status.
        result.status_code = 500
        result.error_type = "internal_error"
        result.request_headers = original_headers
        result.request_body = original_request_body
        try:
            # Record the failure through UsageRecorder; never let recording
            # errors mask the original failure.
            recorder = UsageRecorder(
                db=db,
                user=user,
                api_key=api_key,
                client_ip=client_ip,
                request_id=request_id,
            )
            await recorder.record_failure(result, original_headers, original_request_body)
        except Exception as record_error:
            logger.error(f"记录失败请求时出错: {record_error}")
        return self._error_response(
            status_code=500,
            error_type="internal_server_error",
            message="处理请求时发生内部错误")

    def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
        """Build an error response - subclasses may override to customize the format."""
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "type": error_type,
                    "message": message,
                }
            },
        )

    # =========================================================================
    # Pricing-strategy methods - subclasses may override to implement
    # format-specific billing differences.
    # =========================================================================
    def compute_total_input_context(
        self,
        input_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int = 0,
    ) -> int:
        """
        Compute the total input context (used to pick the pricing tier).

        Default implementation: input_tokens + cache_read_input_tokens.
        Subclasses may override with a different formula.

        Args:
            input_tokens: input token count
            cache_read_input_tokens: cache-read token count
            cache_creation_input_tokens: cache-creation token count (some
                formats may need it; unused by the default)

        Returns:
            Total input-context token count.
        """
        return input_tokens + cache_read_input_tokens

    def get_cache_read_price_for_ttl(
        self,
        tier: dict,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Optional[float]:
        """
        Resolve the cache-read price for a given cache TTL.

        Default implementation: consult the tier's cache_ttl_pricing list and
        pick the price by TTL. Subclasses may override with different TTL
        pricing logic.

        Args:
            tier: current pricing-tier config
            cache_ttl_minutes: cache duration in minutes

        Returns:
            Cache-read price (per 1M tokens), or None when unconfigured.
        """
        ttl_pricing = tier.get("cache_ttl_pricing")
        if ttl_pricing and cache_ttl_minutes is not None:
            matched_price = None
            # Picks the first entry whose ttl_minutes >= the requested TTL.
            # NOTE(review): assumes ttl_pricing is sorted ascending by
            # ttl_minutes — confirm against the config producer.
            for ttl_config in ttl_pricing:
                ttl_limit = ttl_config.get("ttl_minutes", 0)
                if cache_ttl_minutes <= ttl_limit:
                    matched_price = ttl_config.get("cache_read_price_per_1m")
                    break
            if matched_price is not None:
                return matched_price
            # TTL exceeds every configured limit: use the last entry.
            if ttl_pricing:
                return ttl_pricing[-1].get("cache_read_price_per_1m")
        return tier.get("cache_read_price_per_1m")

    def compute_cost(
        self,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int,
        cache_read_input_tokens: int,
        input_price_per_1m: float,
        output_price_per_1m: float,
        cache_creation_price_per_1m: Optional[float],
        cache_read_price_per_1m: Optional[float],
        price_per_request: Optional[float],
        tiered_pricing: Optional[dict] = None,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Compute the cost of a request.

        Default implementation: supports flat prices plus tiered pricing.
        Subclasses may override to implement entirely different billing.

        Args:
            input_tokens: input token count
            output_tokens: output token count
            cache_creation_input_tokens: cache-creation token count
            cache_read_input_tokens: cache-read token count
            input_price_per_1m: input price (per 1M tokens)
            output_price_per_1m: output price (per 1M tokens)
            cache_creation_price_per_1m: cache-creation price (per 1M tokens)
            cache_read_price_per_1m: cache-read price (per 1M tokens)
            price_per_request: per-request flat fee
            tiered_pricing: tiered-pricing config
            cache_ttl_minutes: cache duration in minutes

        Returns:
            Dict of cost components:
            {
                "input_cost": float,
                "output_cost": float,
                "cache_creation_cost": float,
                "cache_read_cost": float,
                "cache_cost": float,
                "request_cost": float,
                "total_cost": float,
                "tier_index": Optional[int],  # index of the matched tier
            }
        """
        tier_index = None
        effective_input_price = input_price_per_1m
        effective_output_price = output_price_per_1m
        effective_cache_creation_price = cache_creation_price_per_1m
        effective_cache_read_price = cache_read_price_per_1m
        # Tiered pricing: pick the tier by total input context and override
        # the flat prices with the tier's values.
        if tiered_pricing and tiered_pricing.get("tiers"):
            total_input_context = self.compute_total_input_context(
                input_tokens, cache_read_input_tokens, cache_creation_input_tokens
            )
            tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
            if tier:
                tier_index = tiered_pricing["tiers"].index(tier)
                effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
                effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
                effective_cache_creation_price = tier.get(
                    "cache_creation_price_per_1m", cache_creation_price_per_1m
                )
                effective_cache_read_price = self.get_cache_read_price_for_ttl(
                    tier, cache_ttl_minutes
                )
                if effective_cache_read_price is None:
                    effective_cache_read_price = cache_read_price_per_1m
        # Compute each component.
        input_cost = (input_tokens / 1_000_000) * effective_input_price
        output_cost = (output_tokens / 1_000_000) * effective_output_price
        cache_creation_cost = 0.0
        cache_read_cost = 0.0
        if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
            cache_creation_cost = (
                cache_creation_input_tokens / 1_000_000
            ) * effective_cache_creation_price
        if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
            cache_read_cost = (
                cache_read_input_tokens / 1_000_000
            ) * effective_cache_read_price
        cache_cost = cache_creation_cost + cache_read_cost
        request_cost = price_per_request if price_per_request else 0.0
        total_cost = input_cost + output_cost + cache_cost + request_cost
        return {
            "input_cost": input_cost,
            "output_cost": output_cost,
            "cache_creation_cost": cache_creation_cost,
            "cache_read_cost": cache_read_cost,
            "cache_cost": cache_cost,
            "request_cost": request_cost,
            "total_cost": total_cost,
            "tier_index": tier_index,
        }

    @staticmethod
    def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
        """
        Pick the pricing tier for a total input token count.

        Args:
            tiered_pricing: tiered-pricing config {"tiers": [...]}
            total_input_tokens: total input token count

        Returns:
            The matching tier config, or None when no tiers are configured.
        """
        if not tiered_pricing or "tiers" not in tiered_pricing:
            return None
        tiers = tiered_pricing.get("tiers", [])
        if not tiers:
            return None
        # A tier with up_to=None is unbounded and matches everything.
        for tier in tiers:
            up_to = tier.get("up_to")
            if up_to is None or total_input_tokens <= up_to:
                return tier
        # Every tier is bounded and exceeded: fall back to the last one.
        return tiers[-1] if tiers else None
# =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
# =========================================================================
# Adapter registry - maps an upper-cased API format id to its Adapter class.
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
_ADAPTERS_LOADED = False


def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
    """
    Register an Adapter class in the registry.

    Usage:
        @register_adapter
        class ClaudeChatAdapter(ChatAdapterBase):
            FORMAT_ID = "CLAUDE"
            ...

    Args:
        adapter_class: the Adapter class to register.

    Returns:
        The same class, so this also works as a decorator.
    """
    fmt = adapter_class.FORMAT_ID
    # Placeholder ids (empty or the base-class "UNKNOWN") are not registered.
    if fmt and fmt != "UNKNOWN":
        _ADAPTER_REGISTRY[fmt.upper()] = adapter_class
    return adapter_class
def _ensure_adapters_loaded():
    """Import the concrete adapter modules once, so that their
    @register_adapter decorators run and populate the registry."""
    global _ADAPTERS_LOADED
    if _ADAPTERS_LOADED:
        return
    import importlib

    # Importing each module triggers its @register_adapter decorator;
    # missing modules are simply skipped.
    for module_path in (
        "src.api.handlers.claude.adapter",
        "src.api.handlers.openai.adapter",
        "src.api.handlers.gemini.adapter",
    ):
        try:
            importlib.import_module(module_path)
        except ImportError:
            pass
    _ADAPTERS_LOADED = True
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
    """
    Look up the Adapter class registered for an API format.

    Args:
        api_format: API format identifier (e.g. "CLAUDE", "OPENAI", "GEMINI")

    Returns:
        The registered Adapter class, or None when not found.
    """
    _ensure_adapters_loaded()
    if not api_format:
        return None
    return _ADAPTER_REGISTRY.get(api_format.upper())
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
    """
    Build an Adapter instance for an API format.

    Args:
        api_format: API format identifier

    Returns:
        A fresh Adapter instance, or None when no class is registered.
    """
    cls = get_adapter_class(api_format)
    return cls() if cls else None
def list_registered_formats() -> list[str]:
    """Return every currently registered API format identifier."""
    _ensure_adapters_loaded()
    return [*_ADAPTER_REGISTRY]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,648 @@
"""
CLI Adapter 通用基类
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 MessageHandler 类
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
"""
import time
import traceback
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class CliAdapterBase(ApiAdapter):
"""
CLI Adapter 通用基类
提供 CLI 格式的通用适配器逻辑,子类只需配置:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: MessageHandler 类
- name: 适配器名称
"""
# 子类必须覆盖
FORMAT_ID: str = "UNKNOWN"
HANDLER_CLASS: Type[CliMessageHandlerBase]
# 适配器配置
name: str = "cli.base"
mode = ApiMode.PROXY
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
    # Default to this adapter's own format when no explicit list is given.
    self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
async def handle(self, context: ApiRequestContext):
    """
    Handle one CLI (passthrough) API request end-to-end: extract metadata,
    dispatch to the handler, and convert known exception families into
    error responses.
    """
    http_request = context.request
    user = context.user
    api_key = context.api_key
    db = context.db
    request_id = context.request_id
    quota_remaining_value = context.quota_remaining
    start_time = context.start_time
    client_ip = context.client_ip
    user_agent = context.user_agent
    original_headers = context.original_headers
    query_params = context.query_params  # query parameters
    original_request_body = context.ensure_json_body()
    # Merge path_params into the body (e.g. Gemini carries the model in the URL path).
    if context.path_params:
        original_request_body = self._merge_path_params(
            original_request_body, context.path_params
        )
    # Resolve `stream`: prefer the body, then path_params (e.g. Gemini
    # distinguishes streaming via the URL endpoint).
    stream = original_request_body.get("stream")
    if stream is None and context.path_params:
        stream = context.path_params.get("stream", False)
    stream = bool(stream)
    # Resolve `model`: prefer the body, then path_params (e.g. Gemini's
    # model lives in the URL path).
    model = original_request_body.get("model")
    if model is None and context.path_params:
        model = context.path_params.get("model", "unknown")
    model = model or "unknown"
    # Attach request metadata for auditing.
    audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
    context.add_audit_metadata(**audit_metadata)
    # Format the remaining-quota display for the log line.
    quota_display = (
        "unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
    )
    # Request-start log line.
    logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
        f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
    try:
        # Bail out early if the client has already disconnected.
        if await http_request.is_disconnected():
            logger.warning("客户端连接断开")
            raise HTTPException(status_code=499, detail="Client disconnected")
        # Create the handler (CLI handlers take the simplified signature).
        handler = self.HANDLER_CLASS(
            db=db,
            user=user,
            api_key=api_key,
            request_id=request_id,
            client_ip=client_ip,
            user_agent=user_agent,
            start_time=start_time,
            allowed_api_formats=self.allowed_api_formats,
            adapter_detector=self.detect_capability_requirements,
        )
        # Dispatch by streaming mode.
        if stream:
            return await handler.process_stream(
                original_request_body=original_request_body,
                original_headers=original_headers,
                query_params=query_params,
                path_params=context.path_params,
            )
        return await handler.process_sync(
            original_request_body=original_request_body,
            original_headers=original_headers,
            query_params=query_params,
            path_params=context.path_params,
        )
    except HTTPException:
        # Already an HTTP-shaped error; let FastAPI render it.
        raise
    except (
        ModelNotSupportedException,
        QuotaExceededException,
        InvalidRequestException,
    ) as e:
        # Client-side errors: returned directly without usage recording.
        logger.debug(f"客户端请求错误: {e.error_type}")
        return self._error_response(
            status_code=e.status_code,
            error_type=(
                "invalid_request_error" if e.status_code == 400 else "quota_exceeded"
            ),
            message=e.message,
        )
    except (
        ProviderAuthException,
        ProviderRateLimitException,
        ProviderNotAvailableException,
        ProviderTimeoutException,
        UpstreamClientException,
    ) as e:
        return await self._handle_provider_exception(
            e,
            db=db,
            user=user,
            api_key=api_key,
            model=model,
            stream=stream,
            start_time=start_time,
            original_headers=original_headers,
            original_request_body=original_request_body,
            client_ip=client_ip,
            request_id=request_id,
        )
    except Exception as e:
        return await self._handle_unexpected_exception(
            e,
            db=db,
            user=user,
            api_key=api_key,
            model=model,
            stream=stream,
            start_time=start_time,
            original_headers=original_headers,
            original_request_body=original_request_body,
            client_ip=client_ip,
            request_id=request_id,
        )
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - 子类可覆盖
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典
Returns:
合并后的请求体字典
"""
merged = original_request_body.copy()
for key, value in path_params.items():
if key not in merged:
merged[key] = value
return merged
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""
提取消息数量 - 子类可覆盖
默认实现:从 input 字段提取
"""
if "input" not in payload:
return 0
input_data = payload["input"]
if isinstance(input_data, list):
return len(input_data)
if isinstance(input_data, dict) and "messages" in input_data:
return len(input_data.get("messages", []))
return 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
构建审计日志元数据 - 子类可覆盖
Args:
payload: 请求体
path_params: URL 路径参数(用于获取 model 等)
"""
# 优先从请求体获取 model其次从 path_params
model = payload.get("model")
if model is None and path_params:
model = path_params.get("model", "unknown")
model = model or "unknown"
stream = payload.get("stream", False)
messages_count = self._extract_message_count(payload)
return {
"action": f"{self.FORMAT_ID.lower()}_request",
"model": model,
"stream": bool(stream),
"max_tokens": payload.get("max_tokens"),
"messages_count": messages_count,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"tool_count": len(payload.get("tools") or []),
"instructions_present": bool(payload.get("instructions")),
}
    async def _handle_provider_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """Handle provider-related exceptions (auth / rate limit / availability /
        timeout / upstream client errors).

        Builds a unified failure result, records it through UsageRecorder, and
        returns a JSON error envelope to the client.

        Args:
            e: The provider exception that was raised.
            db / user / api_key: Context objects forwarded to UsageRecorder.
            model: Requested model name (telemetry only).
            stream: Whether the failed request was streaming.
            start_time: Request start timestamp from time.time(), used for latency.
            original_headers / original_request_body: Captured into the failure record.
            client_ip / request_id: Request identity for telemetry.

        Returns:
            JSONResponse carrying the unified {"error": {...}} payload.
        """
        logger.debug(f"Caught provider exception: {type(e).__name__}")
        response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # Important: api_format falls back to this adapter's FORMAT_ID so the
        # field is always populated even when the exception carries none.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # adapter FORMAT_ID as the default value
            model=model,
            response_time_ms=response_time,
            is_stream=stream,
        )
        result.request_headers = original_headers
        result.request_body = original_request_body
        # Determine the error message. For auth failures, distinguish "a
        # specific provider rejected the credentials" from "no provider was
        # available at all" (provider still "unknown").
        if isinstance(e, ProviderAuthException):
            error_message = (
                f"提供商认证失败: {str(e)}"
                if result.metadata.provider != "unknown"
                else "服务端错误: 无可用提供商"
            )
            result.error_message = error_message
        # Upstream client errors (e.g. image processing failures) keep the
        # upstream's own status code and message so the client sees a clear 4xx.
        if isinstance(e, UpstreamClientException):
            result.status_code = e.status_code
            result.error_message = e.message
        # Record the failure through UsageRecorder before responding.
        recorder = UsageRecorder(
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
        )
        await recorder.record_failure(result, original_headers, original_request_body)
        # Map the exception to a client-facing error type. NOTE(review): every
        # provider error that is neither an UpstreamClientException nor a 503
        # is reported as "rate_limit_exceeded" — confirm this is intended for
        # timeout errors as well.
        if isinstance(e, UpstreamClientException):
            error_type = "invalid_request_error"
        elif result.status_code == 503:
            error_type = "internal_server_error"
        else:
            error_type = "rate_limit_exceeded"
        return self._error_response(
            status_code=result.status_code,
            error_type=error_type,
            message=result.error_message or str(e),
        )
    async def _handle_unexpected_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """Handle unexpected (uncategorized) exceptions.

        Logs the failure, records it through UsageRecorder with the status
        code forced to 500, and returns a generic internal-error response so
        that no internal details leak to the client.
        """
        # Known business exceptions get a terse log line; truly unexpected
        # exceptions are logged with structured context and a traceback preview.
        if isinstance(e, ProxyException):
            logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
        else:
            logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
                exception=e,
                extra_data={
                    "exception_class": e.__class__.__name__,
                    "processing_stage": "request_processing",
                    "model": model,
                    "stream": stream,
                    "traceback_preview": str(traceback.format_exc())[:500],
                },
            )
        response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # api_format falls back to this adapter's FORMAT_ID so it is always set.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # adapter FORMAT_ID as the default value
            model=model,
            response_time_ms=response_time,
            is_stream=stream,
        )
        # For unexpected exceptions, force a 500 regardless of what
        # from_exception inferred.
        result.status_code = 500
        result.error_type = "internal_error"
        result.request_headers = original_headers
        result.request_body = original_request_body
        # Record the failure through UsageRecorder before responding.
        recorder = UsageRecorder(
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
        )
        await recorder.record_failure(result, original_headers, original_request_body)
        return self._error_response(
            status_code=500,
            error_type="internal_server_error",
            message="处理请求时发生内部错误")
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成错误响应"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
}
},
)
# =========================================================================
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认实现input_tokens + cache_read_input_tokens
子类可覆盖此方法实现不同的计算逻辑
Args:
input_tokens: 输入 token 数
cache_read_input_tokens: 缓存读取 token 数
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
Returns:
总输入上下文 token 数
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Dict[str, Any]:
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens
output_price_per_1m: 输出价格(每 1M tokens
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens
cache_read_price_per_1m: 缓存读取价格(每 1M tokens
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
包含各项成本的字典
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
# =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
# =========================================================================
# Maps upper-cased FORMAT_ID -> CLI Adapter class, populated by @register_cli_adapter.
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
# Flipped to True once the adapter modules have been imported (lazy load guard).
_CLI_ADAPTERS_LOADED = False
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
    """
    Class decorator that registers a CLI Adapter in the registry.

    Usage:
        @register_cli_adapter
        class ClaudeCliAdapter(CliAdapterBase):
            FORMAT_ID = "CLAUDE_CLI"
            ...

    Classes whose FORMAT_ID is empty or the "UNKNOWN" placeholder are skipped.
    """
    format_id = adapter_class.FORMAT_ID
    registrable = bool(format_id) and format_id != "UNKNOWN"
    if registrable:
        _CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
    return adapter_class
def _ensure_cli_adapters_loaded():
    """Import all CLI adapter modules once so their @register_cli_adapter decorators run."""
    global _CLI_ADAPTERS_LOADED
    if _CLI_ADAPTERS_LOADED:
        return
    # Importing each adapter module triggers its @register_cli_adapter
    # decorators; missing optional adapters are simply skipped.
    for module_path in (
        "src.api.handlers.claude_cli.adapter",
        "src.api.handlers.openai_cli.adapter",
        "src.api.handlers.gemini_cli.adapter",
    ):
        try:
            __import__(module_path)
        except ImportError:
            pass
    _CLI_ADAPTERS_LOADED = True
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
    """Look up the CLI Adapter class for an API format (case-insensitive), or None."""
    _ensure_cli_adapters_loaded()
    if not api_format:
        return None
    return _CLI_ADAPTER_REGISTRY.get(api_format.upper())
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
    """Instantiate the CLI Adapter registered for an API format, or None."""
    adapter_class = get_cli_adapter_class(api_format)
    return adapter_class() if adapter_class is not None else None
def list_registered_cli_formats() -> list[str]:
    """Return the FORMAT_IDs of every registered CLI adapter."""
    _ensure_cli_adapters_loaded()
    return [fmt for fmt in _CLI_ADAPTER_REGISTRY]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,279 @@
"""
格式转换器注册表
自动管理不同 API 格式之间的转换器,支持:
- 请求转换:客户端格式 → Provider 格式
- 响应转换Provider 格式 → 客户端格式
使用方法:
1. 实现 Converter 类(需要有 convert_request 和/或 convert_response 方法)
2. 调用 registry.register() 注册转换器
3. 在 Handler 中调用 registry.convert_request/convert_response
示例:
from src.api.handlers.base.format_converter_registry import converter_registry
# 注册转换器
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
# 使用转换器
gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
"""
from typing import Any, Dict, Optional, Protocol, Tuple
from src.core.logger import logger
class RequestConverter(Protocol):
    """Structural protocol for request converters (client format -> provider format)."""

    def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...
class ResponseConverter(Protocol):
    """Structural protocol for response converters (provider format -> client format)."""

    def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...
class StreamChunkConverter(Protocol):
    """Structural protocol for streaming-chunk converters (optional converter capability)."""

    def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...
class FormatConverterRegistry:
    """
    Format converter registry.

    Manages bidirectional converters between API formats. Lookups are keyed
    by upper-cased ``(source_format, target_format)`` pairs, so every public
    method is case-insensitive with respect to format names. All conversion
    entry points fail open: on a missing or broken converter they return the
    input unchanged.
    """

    def __init__(self):
        # key: (source_format, target_format), value: converter instance
        self._converters: Dict[Tuple[str, str], Any] = {}

    @staticmethod
    def _key(source_format: str, target_format: str) -> Tuple[str, str]:
        """Normalize a format pair into the upper-cased registry key."""
        return (source_format.upper(), target_format.upper())

    def register(
        self,
        source_format: str,
        target_format: str,
        converter: Any,
    ) -> None:
        """
        Register a format converter.

        Args:
            source_format: Source format (e.g. "CLAUDE", "OPENAI", "GEMINI").
            target_format: Target format.
            converter: Converter instance (should expose convert_request /
                convert_response, optionally convert_stream_chunk).
        """
        self._converters[self._key(source_format, target_format)] = converter
        logger.info(f"[ConverterRegistry] 注册转换器: {source_format} -> {target_format}")

    def get_converter(
        self,
        source_format: str,
        target_format: str,
    ) -> Optional[Any]:
        """
        Get a converter for the pair.

        Args:
            source_format: Source format.
            target_format: Target format.

        Returns:
            Converter instance, or None when not registered.
        """
        return self._converters.get(self._key(source_format, target_format))

    def has_converter(
        self,
        source_format: str,
        target_format: str,
    ) -> bool:
        """Whether a converter is registered for the pair."""
        return self._key(source_format, target_format) in self._converters

    def convert_request(
        self,
        request: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a request.

        Args:
            request: Original request dict.
            source_format: Source format (client format).
            target_format: Target format (provider format).

        Returns:
            The converted request, or the original request when the formats
            match, no converter exists, or conversion fails.
        """
        # Same format requires no conversion.
        if source_format.upper() == target_format.upper():
            return request
        converter = self.get_converter(source_format, target_format)
        if converter is None:
            logger.warning(f"[ConverterRegistry] 未找到请求转换器: {source_format} -> {target_format},返回原始请求")
            return request
        if not hasattr(converter, "convert_request"):
            logger.warning(f"[ConverterRegistry] 转换器缺少 convert_request 方法: {source_format} -> {target_format}")
            return request
        try:
            converted = converter.convert_request(request)
            logger.debug(f"[ConverterRegistry] 请求转换成功: {source_format} -> {target_format}")
            return converted
        except Exception as e:
            logger.error(f"[ConverterRegistry] 请求转换失败: {source_format} -> {target_format}: {e}")
            return request

    def convert_response(
        self,
        response: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a response.

        Args:
            response: Original response dict.
            source_format: Source format (provider format).
            target_format: Target format (client format).

        Returns:
            The converted response, or the original response when the formats
            match, no converter exists, or conversion fails.
        """
        # Same format requires no conversion.
        if source_format.upper() == target_format.upper():
            return response
        converter = self.get_converter(source_format, target_format)
        if converter is None:
            logger.warning(f"[ConverterRegistry] 未找到响应转换器: {source_format} -> {target_format},返回原始响应")
            return response
        if not hasattr(converter, "convert_response"):
            logger.warning(f"[ConverterRegistry] 转换器缺少 convert_response 方法: {source_format} -> {target_format}")
            return response
        try:
            converted = converter.convert_response(response)
            logger.debug(f"[ConverterRegistry] 响应转换成功: {source_format} -> {target_format}")
            return converted
        except Exception as e:
            logger.error(f"[ConverterRegistry] 响应转换失败: {source_format} -> {target_format}: {e}")
            return response

    def convert_stream_chunk(
        self,
        chunk: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a streaming response chunk.

        Prefers the converter's dedicated convert_stream_chunk method, falling
        back to convert_response; returns the chunk unchanged on any failure.

        Args:
            chunk: Original streaming chunk.
            source_format: Source format (provider format).
            target_format: Target format (client format).

        Returns:
            The converted chunk (or the original on no-op/failure).
        """
        # Same format requires no conversion.
        if source_format.upper() == target_format.upper():
            return chunk
        converter = self.get_converter(source_format, target_format)
        if converter is None:
            return chunk
        # Prefer the dedicated streaming conversion method.
        if hasattr(converter, "convert_stream_chunk"):
            try:
                return converter.convert_stream_chunk(chunk)
            except Exception as e:
                logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
                return chunk
        # Degrade to the plain response converter.
        if hasattr(converter, "convert_response"):
            try:
                return converter.convert_response(chunk)
            except Exception as e:
                # Fix: this fallback used to swallow failures silently, making
                # misbehaving converters invisible; log like the primary path.
                logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
                return chunk
        return chunk

    def list_converters(self) -> list[Tuple[str, str]]:
        """List all registered (source, target) pairs."""
        return list(self._converters.keys())
# Global singleton shared by all handlers.
converter_registry = FormatConverterRegistry()
def register_all_converters():
    """
    Register all built-in format converters.

    Call once at application startup. Each converter pair is wrapped in
    try/except ImportError so that a missing optional handler package only
    disables that pair instead of failing startup.
    """
    # Claude <-> OpenAI
    try:
        from src.api.handlers.claude.converter import OpenAIToClaudeConverter
        from src.api.handlers.openai.converter import ClaudeToOpenAIConverter

        converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
        converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] 无法加载 Claude/OpenAI 转换器: {e}")
    # Claude <-> Gemini
    try:
        from src.api.handlers.gemini.converter import (
            ClaudeToGeminiConverter,
            GeminiToClaudeConverter,
        )

        converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
        converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] 无法加载 Claude/Gemini 转换器: {e}")
    # OpenAI <-> Gemini
    try:
        from src.api.handlers.gemini.converter import (
            GeminiToOpenAIConverter,
            OpenAIToGeminiConverter,
        )

        converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
        converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] 无法加载 OpenAI/Gemini 转换器: {e}")
    logger.info(f"[ConverterRegistry] 已注册 {len(converter_registry.list_converters())} 个格式转换器")
# Public names exported via `from ... import *`.
__all__ = [
    "FormatConverterRegistry",
    "converter_registry",
    "register_all_converters",
]

View File

@@ -0,0 +1,465 @@
"""
响应解析器工厂
直接根据格式 ID 创建对应的 ResponseParser 实现,
不再经过 Protocol 抽象层。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
class OpenAIResponseParser(ResponseParser):
    """Response parser for the OpenAI chat-completions format."""

    def __init__(self):
        # Imported lazily to avoid module-level import cycles.
        from src.api.handlers.openai.stream_parser import OpenAIStreamParser

        self._parser = OpenAIStreamParser()
        self.name = "OPENAI"
        self.api_format = "OPENAI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """Parse one SSE line; returns None for blank or unparseable lines.

        Updates ``stats`` in place (collected text, completion flag, counters).
        """
        if not line or not line.strip():
            return None
        # SSE payload lines carry a "data: " prefix; accept bare JSON too.
        payload = line[6:] if line.startswith("data: ") else line
        data = self._parser.parse_line(payload)
        if data is None:
            return None
        chunk = ParsedChunk(raw_line=line, event_type=None, data=data)
        # Accumulate streamed text.
        delta = self._parser.extract_text_delta(data)
        if delta:
            chunk.text_delta = delta
            stats.collected_text += delta
        # Terminal chunk for the stream.
        if self._parser.is_done_chunk(data):
            chunk.is_done = True
            stats.has_completion = True
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a non-streaming OpenAI response into a ParsedResponse."""
        result = ParsedResponse(raw_response=response, status_code=status_code)
        # First choice's message content, when present.
        text = self.extract_text_content(response)
        if text:
            result.text_content = text
        result.response_id = response.get("id")
        # OpenAI usage fields.
        usage = response.get("usage", {})
        result.input_tokens = usage.get("prompt_tokens", 0)
        result.output_tokens = usage.get("completion_tokens", 0)
        # A top-level "error" key marks an error payload.
        if "error" in response:
            result.is_error = True
            error = response.get("error", {})
            if isinstance(error, dict):
                result.error_type = error.get("type")
                result.error_message = error.get("message")
            else:
                result.error_message = str(error)
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """Map OpenAI usage fields onto the unified token-usage dict."""
        usage = response.get("usage", {})
        return {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
            # Chat-completions usage carries no cache token split here.
            "cache_creation_tokens": 0,
            "cache_read_tokens": 0,
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Return the first choice's message content, or "" when absent."""
        choices = response.get("choices", [])
        if not choices:
            return ""
        content = choices[0].get("message", {}).get("content")
        return content or ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """A top-level "error" key marks an error payload."""
        return "error" in response
class OpenAICliResponseParser(OpenAIResponseParser):
    """OpenAI CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self):
        super().__init__()
        self.name = self.api_format = "OPENAI_CLI"
class ClaudeResponseParser(ResponseParser):
    """Response parser for the Claude (Anthropic Messages) format."""

    def __init__(self):
        # Imported lazily to avoid module-level import cycles.
        from src.api.handlers.claude.stream_parser import ClaudeStreamParser

        self._parser = ClaudeStreamParser()
        self.name = "CLAUDE"
        self.api_format = "CLAUDE"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """Parse a single SSE line from a Claude stream.

        Returns None for blank or unparseable lines; otherwise returns a
        ParsedChunk and updates ``stats`` in place (collected text, token
        usage, completion/error flags, chunk counters).
        """
        if not line or not line.strip():
            return None
        # SSE data lines carry a "data: " prefix; accept bare JSON too.
        if line.startswith("data: "):
            data_str = line[6:]
        else:
            data_str = line
        parsed = self._parser.parse_line(data_str)
        if parsed is None:
            return None
        chunk = ParsedChunk(
            raw_line=line,
            event_type=self._parser.get_event_type(parsed),
            data=parsed,
        )
        # Accumulate streamed text.
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta
        # Terminal event for the stream.
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True
        # Token usage. NOTE(review): stats fields are overwritten rather than
        # accumulated — assumes ClaudeStreamParser.extract_usage returns
        # cumulative totals per event; confirm against that implementation.
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
            chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_creation_tokens = chunk.cache_creation_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens
        # Error events carry an "error" object (dict or plain value).
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            error = parsed.get("error", {})
            if isinstance(error, dict):
                chunk.error_message = error.get("message", str(error))
            else:
                chunk.error_message = str(error)
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a non-streaming Claude response into a ParsedResponse."""
        result = ParsedResponse(
            raw_response=response,
            status_code=status_code,
        )
        # Concatenate all "text"-type blocks of the content array.
        content = response.get("content", [])
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            result.text_content = "".join(text_parts)
        result.response_id = response.get("id")
        # Token usage, including Claude's cache-specific counters.
        usage = response.get("usage", {})
        result.input_tokens = usage.get("input_tokens", 0)
        result.output_tokens = usage.get("output_tokens", 0)
        result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
        result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
        # Claude signals errors via an "error" key or type == "error".
        if "error" in response or response.get("type") == "error":
            result.is_error = True
            error = response.get("error", {})
            if isinstance(error, dict):
                result.error_type = error.get("type")
                result.error_message = error.get("message")
            else:
                result.error_message = str(error)
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """Map Claude usage fields onto the unified token-usage dict."""
        usage = response.get("usage", {})
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
            "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Concatenate the "text"-type content blocks; "" when absent."""
        content = response.get("content", [])
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            return "".join(text_parts)
        return ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """True when the payload carries an "error" key or type == "error"."""
        return "error" in response or response.get("type") == "error"
class ClaudeCliResponseParser(ClaudeResponseParser):
    """Claude CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self):
        super().__init__()
        self.name = self.api_format = "CLAUDE_CLI"
class GeminiResponseParser(ResponseParser):
    """Response parser for the Gemini (generateContent) format."""

    def __init__(self):
        # Imported lazily to avoid module-level import cycles.
        from src.api.handlers.gemini.stream_parser import GeminiStreamParser

        self._parser = GeminiStreamParser()
        self.name = "GEMINI"
        self.api_format = "GEMINI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """
        Parse one Gemini SSE line.

        Gemini streaming responses use SSE framing (``data: {...}``). Updates
        ``stats`` in place and returns a ParsedChunk, or None for blank or
        unparseable lines.
        """
        if not line or not line.strip():
            return None
        # Gemini SSE format: data: {...}; accept bare JSON too.
        if line.startswith("data: "):
            data_str = line[6:]
        else:
            data_str = line
        parsed = self._parser.parse_line(data_str)
        if parsed is None:
            return None
        chunk = ParsedChunk(
            raw_line=line,
            event_type="content",
            data=parsed,
        )
        # Accumulate streamed text.
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta
        # Terminal event for the stream.
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True
        # Token usage. Gemini reports cached reads only — no cache-creation counter.
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_read_tokens = usage.get("cached_tokens", 0)
            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens
        # Error events carry an "error" object (dict or plain value).
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            error = parsed.get("error", {})
            if isinstance(error, dict):
                chunk.error_message = error.get("message", str(error))
            else:
                chunk.error_message = str(error)
        stats.chunk_count += 1
        stats.data_count += 1
        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """Parse a non-streaming Gemini response into a ParsedResponse."""
        result = ParsedResponse(
            raw_response=response,
            status_code=status_code,
        )
        # Concatenate the text parts of the first candidate.
        candidates = response.get("candidates", [])
        if candidates:
            content = candidates[0].get("content", {})
            parts = content.get("parts", [])
            text_parts = []
            for part in parts:
                if "text" in part:
                    text_parts.append(part["text"])
            result.text_content = "".join(text_parts)
        # NOTE(review): Gemini carries no response id; modelVersion is used as
        # a stand-in — confirm downstream consumers expect that.
        result.response_id = response.get("modelVersion")
        # Token usage via GeminiStreamParser.extract_usage (single source of truth).
        usage = self._parser.extract_usage(response)
        if usage:
            result.input_tokens = usage.get("input_tokens", 0)
            result.output_tokens = usage.get("output_tokens", 0)
            result.cache_read_tokens = usage.get("cached_tokens", 0)
        # Error detection (enhanced: also finds errors nested in chunks).
        error_info = self._parser.extract_error_info(response)
        if error_info:
            result.is_error = True
            result.error_type = error_info.get("status")
            result.error_message = error_info.get("message")
        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """
        Extract token usage from a Gemini response.

        Delegates to GeminiStreamParser.extract_usage as the single source of
        truth; cache-creation tokens are always 0 for Gemini.
        """
        usage = self._parser.extract_usage(response)
        if not usage:
            return {
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_creation_tokens": 0,
                "cache_read_tokens": 0,
            }
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "cache_creation_tokens": 0,
            "cache_read_tokens": usage.get("cached_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """Concatenate the text parts of the first candidate; "" when absent."""
        candidates = response.get("candidates", [])
        if candidates:
            content = candidates[0].get("content", {})
            parts = content.get("parts", [])
            text_parts = []
            for part in parts:
                if "text" in part:
                    text_parts.append(part["text"])
            return "".join(text_parts)
        return ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """
        Whether the payload is an error response.

        Uses the enhanced detection logic, which also catches errors nested
        inside stream chunks.
        """
        return self._parser.is_error_event(response)
class GeminiCliResponseParser(GeminiResponseParser):
    """Gemini CLI variant: identical parsing, distinct format identifiers."""

    def __init__(self):
        super().__init__()
        self.name = self.api_format = "GEMINI_CLI"
# Parser registry: format ID -> parser class. Instances are created on
# demand by get_parser_for_format.
_PARSERS = {
    "CLAUDE": ClaudeResponseParser,
    "CLAUDE_CLI": ClaudeCliResponseParser,
    "OPENAI": OpenAIResponseParser,
    "OPENAI_CLI": OpenAICliResponseParser,
    "GEMINI": GeminiResponseParser,
    "GEMINI_CLI": GeminiCliResponseParser,
}
def get_parser_for_format(format_id: str) -> ResponseParser:
    """
    Look up and instantiate the ResponseParser for a format ID.

    Args:
        format_id: Format ID such as "CLAUDE", "OPENAI", "CLAUDE_CLI",
            "OPENAI_CLI" (case-insensitive).

    Returns:
        A fresh ResponseParser instance.

    Raises:
        KeyError: When the format is not registered.
    """
    normalized = format_id.upper()
    if normalized not in _PARSERS:
        raise KeyError(f"Unknown format: {normalized}")
    return _PARSERS[normalized]()
def is_cli_format(format_id: str) -> bool:
    """True when the format ID names a CLI variant (suffix "_CLI", case-insensitive)."""
    cli_suffix = "_CLI"
    return format_id.upper().endswith(cli_suffix)
# Public names exported via `from ... import *`.
# Fix: the list previously exported "get_parser_from_protocol", which does
# not appear to be defined in this module and would break wildcard imports
# with an AttributeError; it has been removed.
__all__ = [
    "OpenAIResponseParser",
    "OpenAICliResponseParser",
    "ClaudeResponseParser",
    "ClaudeCliResponseParser",
    "GeminiResponseParser",
    "GeminiCliResponseParser",
    "get_parser_for_format",
    "is_cli_format",
]

View File

@@ -0,0 +1,207 @@
"""
请求构建器 - 透传模式
透传模式 (Passthrough): CLI 和 Chat 等场景,原样转发请求体和头部
- 清理敏感头部authorization, x-api-key, host, content-length 等
- 保留所有其他头部和请求体字段
- 适用于Claude CLI、OpenAI CLI、Chat API 等场景
使用方式:
builder = PassthroughRequestBuilder()
payload, headers = builder.build(original_body, original_headers, endpoint, key)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, FrozenSet, Optional, Tuple
from src.core.crypto import crypto_service
# ==============================================================================
# Unified header configuration constants
# ==============================================================================
# Sensitive headers - stripped during passthrough (blacklist).
# These either carry authentication material or are regenerated by the proxy layer.
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
    {
        "authorization",
        "x-api-key",
        "x-goog-api-key",  # Gemini API authentication header
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        # Do not forward accept-encoding; let httpx negotiate compression itself.
        # Avoids clients requesting brotli/zstd that httpx cannot decompress.
        "accept-encoding",
    }
)
# ==============================================================================
# 请求构建器
# ==============================================================================
class RequestBuilder(ABC):
    """Abstract base class for request builders (payload + header construction)."""

    @abstractmethod
    def build_payload(
        self,
        original_body: Dict[str, Any],
        *,
        mapped_model: Optional[str] = None,
        is_stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the outgoing request body."""
        ...

    @abstractmethod
    def build_headers(
        self,
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Build the outgoing request headers."""
        ...

    def build(
        self,
        original_body: Dict[str, Any],
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        mapped_model: Optional[str] = None,
        is_stream: bool = False,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """
        Build the complete request (body + headers).

        Returns:
            Tuple of (payload, headers).
        """
        # Tuple elements evaluate left-to-right: payload first, then headers.
        return (
            self.build_payload(
                original_body,
                mapped_model=mapped_model,
                is_stream=is_stream,
            ),
            self.build_headers(
                original_headers,
                endpoint,
                key,
                extra_headers=extra_headers,
            ),
        )
class PassthroughRequestBuilder(RequestBuilder):
    """
    Passthrough-mode request builder.

    For CLI-style scenarios where the request should be forwarded as-is:
    - body: copied verbatim; only necessary fields (model, stream) are
      touched, and that happens elsewhere
    - headers: sensitive headers stripped (blacklist), everything else forwarded
    """

    def build_payload(
        self,
        original_body: Dict[str, Any],
        *,
        mapped_model: Optional[str] = None,  # noqa: ARG002 - handled by apply_mapped_model
        is_stream: bool = False,  # noqa: ARG002 - original value is preserved
    ) -> Dict[str, Any]:
        """
        Passthrough body - shallow copy, no modification.

        In passthrough mode:
        - model: handled by each handler's apply_mapped_model method
        - stream: the client's original value is kept (APIs treat it differently)
        """
        return dict(original_body)

    def build_headers(
        self,
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """
        Passthrough headers - strip the sensitive blacklist, forward everything else.
        """
        from src.core.api_format_metadata import get_auth_config, resolve_api_format

        headers: Dict[str, str] = {}
        # 1. Authentication header derived from the endpoint's API format.
        decrypted_key = crypto_service.decrypt(key.api_key)
        api_format = getattr(endpoint, "api_format", None)
        resolved_format = resolve_api_format(api_format)
        auth_header, auth_type = (
            get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
        )
        if auth_type == "bearer":
            headers[auth_header] = f"Bearer {decrypted_key}"
        else:
            headers[auth_header] = decrypted_key
        # 2. Extra headers configured on the endpoint.
        if endpoint.headers:
            headers.update(endpoint.headers)
        # 3. Forward the original headers (blacklist mode: skip sensitive ones).
        if original_headers:
            for name, value in original_headers.items():
                if name.lower() in SENSITIVE_HEADERS:
                    continue
                headers[name] = value
        # 4. Caller-supplied extra headers win last.
        if extra_headers:
            headers.update(extra_headers)
        # 5. Guarantee a Content-Type. Fix: HTTP field names are
        # case-insensitive, and forwarded headers keep the client's original
        # casing, so check all keys case-insensitively instead of only the
        # "Content-Type"/"content-type" spellings (prevents duplicates like
        # "CONTENT-TYPE" + default "Content-Type").
        if not any(name.lower() == "content-type" for name in headers):
            headers["Content-Type"] = "application/json"
        return headers
# ==============================================================================
# Convenience functions
# ==============================================================================


def build_passthrough_request(
    original_body: Dict[str, Any],
    original_headers: Dict[str, str],
    endpoint: Any,
    key: Any,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """
    Build a request in passthrough mode.

    Pure passthrough: the body is copied verbatim and only the headers
    (authentication, etc.) are processed. Model mapping and the ``stream``
    flag are left to the caller, since different API formats handle them
    differently.
    """
    return PassthroughRequestBuilder().build(
        original_body,
        original_headers,
        endpoint,
        key,
    )

View File

@@ -0,0 +1,174 @@
"""
响应解析器基类 - 定义统一的响应解析接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class ParsedChunk:
    """One parsed chunk of a streaming (SSE) response."""

    # Raw data
    raw_line: str  # the raw SSE line this chunk was parsed from
    event_type: Optional[str] = None  # SSE event name, if any — set by the parser
    data: Optional[Dict[str, Any]] = None  # parsed payload, if any — set by the parser
    # Extracted content
    text_delta: str = ""  # incremental text carried by this chunk
    is_done: bool = False  # stream signalled completion
    is_error: bool = False  # chunk represents an error
    error_message: Optional[str] = None
    # Usage information (typically present only in the final chunk)
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0
    # Response ID
    response_id: Optional[str] = None
@dataclass
class StreamStats:
    """Aggregated statistics for a single streaming response.

    Mutated in place by ``ResponseParser.parse_sse_line`` implementations.
    """

    # Counters (presumably chunks seen / data payloads seen — semantics
    # defined by each parser implementation)
    chunk_count: int = 0
    data_count: int = 0
    # Token usage
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0
    # Content
    collected_text: str = ""  # accumulated text from the stream
    response_id: Optional[str] = None
    # Status
    has_completion: bool = False  # stream reached a completion event
    status_code: int = 200
    error_message: Optional[str] = None
    # Provider information
    provider_name: Optional[str] = None
    endpoint_id: Optional[str] = None
    key_id: Optional[str] = None
    # Response headers and the full (final) response, when available
    response_headers: Dict[str, str] = field(default_factory=dict)
    final_response: Optional[Dict[str, Any]] = None
@dataclass
class ParsedResponse:
    """A parsed non-streaming response."""

    # Raw response
    raw_response: Dict[str, Any]  # the upstream response JSON, unmodified
    status_code: int  # upstream HTTP status code
    # Extracted content
    text_content: str = ""
    response_id: Optional[str] = None
    # Usage
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0
    # Error information
    is_error: bool = False
    error_type: Optional[str] = None
    error_message: Optional[str] = None
class ResponseParser(ABC):
    """
    Base class for response parsers.

    Defines a unified interface for parsing responses in different API
    formats. Subclasses implement the concrete parsing logic.
    """

    # Parser name (used in logs)
    name: str = "base"
    # API format this parser supports
    api_format: str = "UNKNOWN"

    @abstractmethod
    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """
        Parse a single SSE line.

        Args:
            line: Raw SSE line data.
            stats: Stream statistics object (updated in place).

        Returns:
            The parsed chunk, or None if the line carries no usable data.
        """
        pass

    @abstractmethod
    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """
        Parse a non-streaming response.

        Args:
            response: Response JSON.
            status_code: HTTP status code.

        Returns:
            The parsed response object.
        """
        pass

    @abstractmethod
    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """
        Extract token usage from a response.

        Args:
            response: Response JSON.

        Returns:
            Dict containing input_tokens, output_tokens,
            cache_creation_tokens and cache_read_tokens.
        """
        pass

    @abstractmethod
    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """
        Extract the text content from a response.

        Args:
            response: Response JSON.

        Returns:
            The extracted text content.
        """
        pass

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """
        Check whether a response is an error response.

        Default behaviour only checks for a top-level "error" key;
        subclasses may override for format-specific error shapes.

        Args:
            response: Response JSON.

        Returns:
            True if the response is an error response.
        """
        return "error" in response

    def create_stats(self) -> StreamStats:
        """Create a fresh stream-statistics object."""
        return StreamStats()