Initial commit

fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions


@@ -0,0 +1,99 @@
"""
API Handlers - 请求处理器
按 API 格式组织的 Adapter 和 Handler
- Adapter: 请求验证、格式转换、错误处理
- Handler: 业务逻辑、调用 Provider、记录用量
支持的格式:
- claude: Claude Chat API (/v1/messages)
- claude_cli: Claude CLI 透传模式
- openai: OpenAI Chat API (/v1/chat/completions)
- openai_cli: OpenAI CLI 透传模式
注意Handler 基类和具体 Handler 使用延迟导入以避免循环依赖。
"""
# Adapter base classes (no circular imports; safe to import directly)
from src.api.handlers.base import (
ChatAdapterBase,
CliAdapterBase,
)
__all__ = [
    # Adapter base classes
    "ChatAdapterBase",
    "CliAdapterBase",
    # Handler base classes (lazily imported)
"ChatHandlerBase",
"CliMessageHandlerBase",
"BaseMessageHandler",
"MessageHandlerProtocol",
"MessageTelemetry",
"StreamContext",
# Claude
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
"ClaudeChatHandler",
# Claude CLI
"ClaudeCliAdapter",
"ClaudeCliMessageHandler",
# OpenAI
"OpenAIChatAdapter",
"OpenAIChatHandler",
# OpenAI CLI
"OpenAICliAdapter",
"OpenAICliMessageHandler",
]
# Lazy-import mapping table
_LAZY_IMPORTS = {
    # Handler base classes
"ChatHandlerBase": ("src.api.handlers.base.chat_handler_base", "ChatHandlerBase"),
"CliMessageHandlerBase": (
"src.api.handlers.base.cli_handler_base",
"CliMessageHandlerBase",
),
"StreamContext": ("src.api.handlers.base.cli_handler_base", "StreamContext"),
"BaseMessageHandler": ("src.api.handlers.base.base_handler", "BaseMessageHandler"),
"MessageHandlerProtocol": (
"src.api.handlers.base.base_handler",
"MessageHandlerProtocol",
),
"MessageTelemetry": ("src.api.handlers.base.base_handler", "MessageTelemetry"),
# Claude
"ClaudeChatAdapter": ("src.api.handlers.claude.adapter", "ClaudeChatAdapter"),
"ClaudeTokenCountAdapter": (
"src.api.handlers.claude.adapter",
"ClaudeTokenCountAdapter",
),
"build_claude_adapter": ("src.api.handlers.claude.adapter", "build_claude_adapter"),
"ClaudeChatHandler": ("src.api.handlers.claude.handler", "ClaudeChatHandler"),
# Claude CLI
"ClaudeCliAdapter": ("src.api.handlers.claude_cli.adapter", "ClaudeCliAdapter"),
"ClaudeCliMessageHandler": (
"src.api.handlers.claude_cli.handler",
"ClaudeCliMessageHandler",
),
# OpenAI
"OpenAIChatAdapter": ("src.api.handlers.openai.adapter", "OpenAIChatAdapter"),
"OpenAIChatHandler": ("src.api.handlers.openai.handler", "OpenAIChatHandler"),
# OpenAI CLI
"OpenAICliAdapter": ("src.api.handlers.openai_cli.adapter", "OpenAICliAdapter"),
"OpenAICliMessageHandler": (
"src.api.handlers.openai_cli.handler",
"OpenAICliMessageHandler",
),
}
def __getattr__(name: str):
"""延迟导入以避免循环依赖"""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
import importlib
module = importlib.import_module(module_path)
return getattr(module, attr_name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
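# --- Illustrative sketch (not part of the original file) ---
# A minimal, self-contained demonstration of the lazy-import pattern used
# above (PEP 562 module-level __getattr__). The "demo_lazy" module and its
# mapping are hypothetical; the real mapping lives in _LAZY_IMPORTS.
if __name__ == "__main__":
    import importlib
    import sys
    import types

    demo = types.ModuleType("demo_lazy")
    demo._LAZY = {"path": ("os", "path")}  # attr name -> (module, attribute)

    def _demo_getattr(name: str):
        if name in demo._LAZY:
            module_path, attr_name = demo._LAZY[name]
            return getattr(importlib.import_module(module_path), attr_name)
        raise AttributeError(f"module 'demo_lazy' has no attribute {name!r}")

    demo.__getattr__ = _demo_getattr  # PEP 562 hook, found in the module dict
    sys.modules["demo_lazy"] = demo

    import demo_lazy
    # The "os" module is imported only here, on first attribute access.
    print(demo_lazy.path.join("a", "b"))  # e.g. "a/b" on POSIX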


@@ -0,0 +1,68 @@
"""
Handler 基类模块
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
注意Handler 基类ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
因为它们依赖 services.usage.stream而后者又需要导入 response_parser
会形成循环导入。请直接从具体模块导入 Handler 基类。
"""
# Chat Adapter base class (does not cause circular imports)
from src.api.handlers.base.chat_adapter_base import (
ChatAdapterBase,
get_adapter_class,
get_adapter_instance,
list_registered_formats,
register_adapter,
)
# CLI Adapter base class
from src.api.handlers.base.cli_adapter_base import (
CliAdapterBase,
get_cli_adapter_class,
get_cli_adapter_instance,
list_registered_cli_formats,
register_cli_adapter,
)
# Request builder
from src.api.handlers.base.request_builder import (
SENSITIVE_HEADERS,
PassthroughRequestBuilder,
RequestBuilder,
build_passthrough_request,
)
# Response parser
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
__all__ = [
# Chat Adapter
"ChatAdapterBase",
"register_adapter",
"get_adapter_class",
"get_adapter_instance",
"list_registered_formats",
# CLI Adapter
"CliAdapterBase",
"register_cli_adapter",
"get_cli_adapter_class",
"get_cli_adapter_instance",
"list_registered_cli_formats",
    # Request builder
"RequestBuilder",
"PassthroughRequestBuilder",
"build_passthrough_request",
"SENSITIVE_HEADERS",
    # Response parser
"ResponseParser",
"ParsedChunk",
"ParsedResponse",
"StreamStats",
]
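# --- Illustrative sketch (not part of the original file) ---
# Per the note above: consuming modules may import the Adapter base classes
# from this package, but must import the Handler base classes from their
# concrete modules to avoid the circular import. A sketch of both styles,
# assuming the repository's packages are importable:
#
#     from src.api.handlers.base import ChatAdapterBase, CliAdapterBase
#     from src.api.handlers.base.chat_handler_base import ChatHandlerBase
#     from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase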


@@ -0,0 +1,363 @@
"""
基础消息处理器,封装通用的编排、转换、遥测逻辑。
接口约定:
- process_stream: 处理流式请求,返回 StreamingResponse
- process_sync: 处理非流式请求,返回 JSONResponse
签名规范(推荐):
async def process_stream(
self,
request: Any, # 解析后的请求模型
http_request: Request, # FastAPI Request 对象
original_headers: Dict[str, str], # 原始请求头
original_request_body: Dict[str, Any], # 原始请求体
query_params: Optional[Dict[str, str]] = None, # 查询参数
) -> StreamingResponse: ...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse: ...
"""
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.clients.redis_client import get_redis_client_sync
from src.core.api_format_metadata import resolve_api_format
from src.core.enums import APIFormat
from src.core.logger import logger
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
class MessageTelemetry:
"""
负责记录 Usage/Audit避免处理器里重复代码。
"""
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
self.db = db
self.user = user
self.api_key = api_key
self.request_id = request_id
self.client_ip = client_ip
async def calculate_cost(
self,
provider: str,
model: str,
*,
input_tokens: int,
output_tokens: int,
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
) -> float:
input_price, output_price = await UsageService.get_model_price_async(
self.db, provider, model
)
_, _, _, _, _, _, total_cost = UsageService.calculate_cost(
input_tokens,
output_tokens,
input_price,
output_price,
cache_creation_tokens,
cache_read_tokens,
*await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
)
return total_cost
async def record_success(
self,
*,
provider: str,
model: str,
input_tokens: int,
output_tokens: int,
response_time_ms: int,
status_code: int,
request_body: Dict[str, Any],
request_headers: Dict[str, Any],
response_body: Any,
response_headers: Dict[str, Any],
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
is_stream: bool = False,
provider_request_headers: Optional[Dict[str, Any]] = None,
        # Provider-side tracing info (used to record actual cost)
        provider_id: Optional[str] = None,
        provider_endpoint_id: Optional[str] = None,
        provider_api_key_id: Optional[str] = None,
        api_format: Optional[str] = None,
        # Model mapping info
        target_model: Optional[str] = None,
        # Provider response metadata (e.g. Gemini's modelVersion)
        response_metadata: Optional[Dict[str, Any]] = None,
) -> float:
total_cost = await self.calculate_cost(
provider,
model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cache_read_tokens,
)
await UsageService.record_usage(
db=self.db,
user=self.user,
api_key=self.api_key,
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_tokens,
cache_read_input_tokens=cache_read_tokens,
request_type="chat",
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers=response_headers,
response_body=response_body,
request_id=self.request_id,
            # Provider-side tracing info (used to record actual cost)
            provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id,
            # Model mapping info
            target_model=target_model,
            # Provider response metadata
            metadata=response_metadata,
)
if self.user and self.api_key:
audit_service.log_api_request(
db=self.db,
user_id=self.user.id,
api_key_id=self.api_key.id,
request_id=self.request_id,
model=model,
provider=provider,
success=True,
ip_address=self.client_ip,
status_code=status_code,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost_usd=total_cost,
)
return total_cost
async def record_failure(
self,
*,
provider: str,
model: str,
response_time_ms: int,
status_code: int,
error_message: str,
request_body: Dict[str, Any],
request_headers: Dict[str, Any],
is_stream: bool,
api_format: Optional[str] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
        # Estimated token info (from the message_start event, used to
        # estimate the cost of interrupted requests)
        input_tokens: int = 0,
        output_tokens: int = 0,
        cache_creation_tokens: int = 0,
        cache_read_tokens: int = 0,
        response_body: Optional[Dict[str, Any]] = None,
        # Model mapping info
        target_model: Optional[str] = None,
):
"""
记录失败请求
注意Provider 链路信息provider_id, endpoint_id, key_id不在此处记录
因为 RequestCandidate 表已经记录了完整的请求链路追踪信息。
Args:
input_tokens: 预估输入 tokens来自 message_start用于中断请求的成本估算
output_tokens: 预估输出 tokens来自已收到的内容
cache_creation_tokens: 缓存创建 tokens
cache_read_tokens: 缓存读取 tokens
response_body: 响应体(如果有部分响应)
target_model: 映射后的目标模型名(如果发生了映射)
"""
provider_name = provider or "unknown"
if provider_name == "unknown":
logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
await UsageService.record_usage(
db=self.db,
user=self.user,
api_key=self.api_key,
provider=provider_name,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_tokens,
cache_read_input_tokens=cache_read_tokens,
request_type="chat",
api_format=api_format,
is_stream=is_stream,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=error_message,
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers={},
response_body=response_body or {"error": error_message},
request_id=self.request_id,
            # Model mapping info
target_model=target_model,
)
@runtime_checkable
class MessageHandlerProtocol(Protocol):
"""
消息处理器协议 - 定义标准接口
ChatHandlerBase 使用完整签名(含 request, http_request
CliMessageHandlerBase 使用简化签名(仅 original_request_body, original_headers
"""
async def process_stream(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> StreamingResponse:
"""处理流式请求"""
...
async def process_sync(
self,
request: Any,
http_request: Request,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
query_params: Optional[Dict[str, str]] = None,
) -> JSONResponse:
"""处理非流式请求"""
...
class BaseMessageHandler:
"""
消息处理器基类,所有具体格式的 handler 可以继承它。
子类需要实现:
- process_stream: 处理流式请求
- process_sync: 处理非流式请求
推荐使用 MessageHandlerProtocol 中定义的签名。
"""
    # Type of the adapter's capability-detector callable
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
def __init__(
self,
*,
db: Session,
user,
api_key,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list[str]] = None,
adapter_detector: Optional[AdapterDetectorType] = None,
):
self.db = db
self.user = user
self.api_key = api_key
self.request_id = request_id
self.client_ip = client_ip
self.user_agent = user_agent
self.start_time = start_time
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
self.adapter_detector = adapter_detector
redis_client = get_redis_client_sync()
self.orchestrator = FallbackOrchestrator(db, redis_client)
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
def elapsed_ms(self) -> int:
return int((time.time() - self.start_time) * 1000)
def _resolve_capability_requirements(
self,
model_name: str,
request_headers: Optional[Dict[str, str]] = None,
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
解析请求的能力需求
来源:
1. 用户模型级配置 (User.model_capability_settings)
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
3. 请求头 X-Require-Capability
4. Adapter 的 detect_capability_requirements如 Claude 的 anthropic-beta
Args:
model_name: 模型名称
request_headers: 请求头
request_body: 请求体(可选)
Returns:
能力需求字典
"""
from src.services.capability.resolver import CapabilityResolver
return CapabilityResolver.resolve_requirements(
user=self.user,
user_api_key=self.api_key,
model_name=model_name,
request_headers=request_headers,
request_body=request_body,
adapter_detector=self.adapter_detector,
)
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
if provider_type:
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
return self.primary_api_format
def build_provider_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None,
) -> Dict[str, Any]:
"""构建发送给 Provider 的请求体,替换 model 名称"""
payload = dict(original_body)
if mapped_model:
payload["model"] = mapped_model
return payload
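# --- Illustrative sketch (not part of the original file) ---
# MessageHandlerProtocol is @runtime_checkable, so conformance of any handler
# can be verified structurally with isinstance(). The _ProbeHandler below is
# hypothetical and exists only to demonstrate the check.
class _ProbeHandler:
    async def process_stream(self, request, http_request, original_headers,
                             original_request_body, query_params=None):
        raise NotImplementedError

    async def process_sync(self, request, http_request, original_headers,
                           original_request_body, query_params=None):
        raise NotImplementedError

# Passes: both protocol methods are present on the probe object.
assert isinstance(_ProbeHandler(), MessageHandlerProtocol)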


@@ -0,0 +1,724 @@
"""
Chat Adapter 通用基类
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
- _validate_request_body(): 可选覆盖请求验证逻辑
- _build_audit_metadata(): 可选覆盖审计元数据构建
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
"""
import time
import traceback
from abc import abstractmethod
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class ChatAdapterBase(ApiAdapter):
"""
Chat Adapter 通用基类
提供 Chat 格式的通用适配器逻辑,子类只需配置:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: ChatHandlerBase 子类
- name: 适配器名称
"""
    # Subclasses must override these
    FORMAT_ID: str = "UNKNOWN"
    HANDLER_CLASS: Type[ChatHandlerBase]
    # Adapter configuration
    name: str = "chat.base"
    mode = ApiMode.STANDARD
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
self.response_normalizer = None
        # Optionally enable response normalization
self._init_response_normalizer()
def _init_response_normalizer(self):
"""初始化响应规范化器 - 子类可覆盖"""
try:
from src.services.provider.response_normalizer import ResponseNormalizer
self.response_normalizer = ResponseNormalizer()
except ImportError:
pass
async def handle(self, context: ApiRequestContext):
"""处理 Chat API 请求"""
http_request = context.request
user = context.user
api_key = context.api_key
db = context.db
request_id = context.request_id
quota_remaining_value = context.quota_remaining
start_time = context.start_time
client_ip = context.client_ip
user_agent = context.user_agent
original_headers = context.original_headers
query_params = context.query_params
original_request_body = context.ensure_json_body()
        # Merge path_params into the request body (e.g. the Gemini API carries model in the URL path)
if context.path_params:
original_request_body = self._merge_path_params(
original_request_body, context.path_params
)
        # Validate and parse the request
request_obj = self._validate_request_body(original_request_body, context.path_params)
if isinstance(request_obj, JSONResponse):
return request_obj
stream = getattr(request_obj, "stream", False)
model = getattr(request_obj, "model", "unknown")
        # Add audit metadata
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
context.add_audit_metadata(**audit_metadata)
        # Format the quota for display
quota_display = (
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
)
        # Log request start
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
try:
            # Check the client connection
            if await http_request.is_disconnected():
                logger.warning("Client disconnected")
                raise HTTPException(status_code=499, detail="Client disconnected")
            # Create the Handler
handler = self._create_handler(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
)
            # Process the request
if stream:
return await handler.process_stream(
request=request_obj,
http_request=http_request,
original_headers=original_headers,
original_request_body=original_request_body,
query_params=query_params,
)
return await handler.process_sync(
request=request_obj,
http_request=http_request,
original_headers=original_headers,
original_request_body=original_request_body,
query_params=query_params,
)
except HTTPException:
raise
except (
ModelNotSupportedException,
QuotaExceededException,
InvalidRequestException,
) as e:
logger.info(f"客户端请求错误: {e.error_type}")
return self._error_response(
status_code=e.status_code,
error_type=(
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
),
message=e.message,
)
except (
ProviderAuthException,
ProviderRateLimitException,
ProviderNotAvailableException,
ProviderTimeoutException,
UpstreamClientException,
) as e:
return await self._handle_provider_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
except Exception as e:
return await self._handle_unexpected_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
def _create_handler(
self,
*,
db,
user,
api_key,
request_id: str,
client_ip: str,
user_agent: str,
start_time: float,
):
"""创建 Handler 实例 - 子类可覆盖"""
return self.HANDLER_CLASS(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
response_normalizer=self.response_normalizer,
enable_response_normalization=self.response_normalizer is not None,
adapter_detector=self.detect_capability_requirements,
)
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - 子类可覆盖
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典
Returns:
合并后的请求体字典
"""
merged = original_request_body.copy()
for key, value in path_params.items():
if key not in merged:
merged[key] = value
return merged
@abstractmethod
    def _validate_request_body(self, original_request_body: dict, path_params: Optional[dict] = None):
        """
        Validate the request body - subclasses must implement.

        Args:
            original_request_body: original request-body dict
            path_params: URL path parameters (e.g. Gemini's stream flag is
                carried by the URL endpoint)

        Returns:
            The validated request object, or a JSONResponse error response
        """
        pass
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
"""
提取消息数量 - 子类可覆盖
默认实现:从 messages 字段提取
"""
messages = payload.get("messages", [])
if hasattr(request_obj, "messages"):
messages = request_obj.messages
return len(messages) if isinstance(messages, list) else 0
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""
构建审计日志元数据 - 子类可覆盖
"""
model = getattr(request_obj, "model", payload.get("model", "unknown"))
stream = getattr(request_obj, "stream", payload.get("stream", False))
messages_count = self._extract_message_count(payload, request_obj)
return {
"action": f"{self.FORMAT_ID.lower()}_request",
"model": model,
"stream": bool(stream),
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
"messages_count": messages_count,
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
}
async def _handle_provider_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理 Provider 相关异常"""
logger.debug(f"Caught provider exception: {type(e).__name__}")
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
result.request_headers = original_headers
result.request_body = original_request_body
        # Determine the error message
        if isinstance(e, ProviderAuthException):
            error_message = (
                f"Provider authentication failed: {str(e)}"
                if result.metadata.provider != "unknown"
                else "Server error: no provider available"
            )
            result.error_message = error_message
        # Handle upstream client errors (e.g. image-processing failures)
        if isinstance(e, UpstreamClientException):
            # Return a 400 status code and a clear error message
            result.status_code = e.status_code
            result.error_message = e.message
        # Record the failure via UsageRecorder
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
        # Map the exception type to an error type
if isinstance(e, UpstreamClientException):
error_type = "invalid_request_error"
elif result.status_code == 503:
error_type = "internal_server_error"
else:
error_type = "rate_limit_exceeded"
return self._error_response(
status_code=result.status_code,
error_type=error_type,
message=result.error_message or str(e),
)
async def _handle_unexpected_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理未预期的异常"""
if isinstance(e, ProxyException):
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
else:
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
exception=e,
extra_data={
"exception_class": e.__class__.__name__,
"processing_stage": "request_processing",
"model": model,
"stream": stream,
"traceback_preview": str(traceback.format_exc())[:500],
},
)
response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # Key point: api_format comes from FORMAT_ID, so it always has a value.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # use the Adapter's FORMAT_ID as the default
model=model,
response_time_ms=response_time,
is_stream=stream,
)
        # For unexpected exceptions, force the status code to 500
result.status_code = 500
result.error_type = "internal_error"
result.request_headers = original_headers
result.request_body = original_request_body
try:
            # Record the failure via UsageRecorder
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
except Exception as record_error:
logger.error(f"记录失败请求时出错: {record_error}")
        return self._error_response(
            status_code=500,
            error_type="internal_server_error",
            message="Internal error while processing the request")
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成错误响应 - 子类可覆盖以自定义格式"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
}
},
)
    # =========================================================================
    # Billing-strategy methods - subclasses may override for format-specific pricing
    # =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认实现input_tokens + cache_read_input_tokens
子类可覆盖此方法实现不同的计算逻辑
Args:
input_tokens: 输入 token 数
cache_read_input_tokens: 缓存读取 token 数
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
Returns:
总输入上下文 token 数
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
            # Past every configured TTL bound: use the last entry
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Dict[str, Any]:
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens
output_price_per_1m: 输出价格(每 1M tokens
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens
cache_read_price_per_1m: 缓存读取价格(每 1M tokens
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
包含各项成本的字典:
{
"input_cost": float,
"output_cost": float,
"cache_creation_cost": float,
"cache_read_cost": float,
"cache_cost": float,
"request_cost": float,
"total_cost": float,
"tier_index": Optional[int], # 命中的阶梯索引
}
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
        # Check tiered pricing
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
        # Compute each cost component
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""
根据总输入 token 数确定价格阶梯
Args:
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
total_input_tokens: 总输入 token 数
Returns:
匹配的阶梯配置
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
        # If every tier has an upper bound and all were exceeded, return the last tier
return tiers[-1] if tiers else None
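# --- Illustrative sketch (not part of the original file) ---
# Tier selection for tiered pricing: a request is billed at the first tier
# whose "up_to" bound covers the total input context (input + cache reads);
# an "up_to" of None means unbounded. The prices below are made up.
_demo_pricing = {
    "tiers": [
        {"up_to": 200_000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
        {"up_to": None, "input_price_per_1m": 6.0, "output_price_per_1m": 22.5},
    ]
}
assert ChatAdapterBase._get_tier_for_tokens(_demo_pricing, 150_000)["input_price_per_1m"] == 3.0
assert ChatAdapterBase._get_tier_for_tokens(_demo_pricing, 300_000)["input_price_per_1m"] == 6.0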
# =========================================================================
# Adapter registry - resolve Adapter instances by API format
# =========================================================================
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
_ADAPTERS_LOADED = False
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
"""
注册 Adapter 类到注册表
用法:
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
FORMAT_ID = "CLAUDE"
...
Args:
adapter_class: Adapter 类
Returns:
注册的 Adapter 类(支持作为装饰器使用)
"""
format_id = adapter_class.FORMAT_ID
if format_id and format_id != "UNKNOWN":
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
return adapter_class
def _ensure_adapters_loaded():
"""确保所有 Adapter 已被加载(触发注册)"""
global _ADAPTERS_LOADED
if _ADAPTERS_LOADED:
return
    # Import each Adapter module so its @register_adapter decorator runs
try:
from src.api.handlers.claude import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.openai import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.gemini import adapter as _ # noqa: F401
except ImportError:
pass
_ADAPTERS_LOADED = True
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
"""
根据 API format 获取 Adapter 类
Args:
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI"
Returns:
对应的 Adapter 类,如果未找到返回 None
"""
_ensure_adapters_loaded()
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
"""
根据 API format 获取 Adapter 实例
Args:
api_format: API 格式标识
Returns:
Adapter 实例,如果未找到返回 None
"""
adapter_class = get_adapter_class(api_format)
if adapter_class:
return adapter_class()
return None
def list_registered_formats() -> list[str]:
"""返回所有已注册的 API 格式"""
_ensure_adapters_loaded()
return list(_ADAPTER_REGISTRY.keys())
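# --- Illustrative sketch (not part of the original file) ---
# Registering a hypothetical adapter and resolving it by format ID. The "ACME"
# format, the adapter class, and its trivial validation are all made up; real
# adapters live in src.api.handlers.<format>.adapter.
@register_adapter
class _AcmeChatAdapter(ChatAdapterBase):
    FORMAT_ID = "ACME"
    name = "chat.acme"

    def _validate_request_body(self, original_request_body: dict, path_params: Optional[dict] = None):
        return original_request_body  # no validation in this sketch

assert get_adapter_class("acme") is _AcmeChatAdapter  # lookup is case-insensitive
assert "ACME" in list_registered_formats()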

File diff suppressed because it is too large


@@ -0,0 +1,648 @@
"""
CLI Adapter 通用基类
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 MessageHandler 类
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
"""
import time
import traceback
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class CliAdapterBase(ApiAdapter):
"""
CLI Adapter 通用基类
提供 CLI 格式的通用适配器逻辑,子类只需配置:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: MessageHandler 类
- name: 适配器名称
"""
# 子类必须覆盖
FORMAT_ID: str = "UNKNOWN"
HANDLER_CLASS: Type[CliMessageHandlerBase]
# 适配器配置
name: str = "cli.base"
mode = ApiMode.PROXY
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
async def handle(self, context: ApiRequestContext):
"""处理 CLI API 请求"""
http_request = context.request
user = context.user
api_key = context.api_key
db = context.db
request_id = context.request_id
quota_remaining_value = context.quota_remaining
start_time = context.start_time
client_ip = context.client_ip
user_agent = context.user_agent
original_headers = context.original_headers
        query_params = context.query_params  # query parameters
original_request_body = context.ensure_json_body()
        # Merge path_params into the request body (e.g. the Gemini API carries model in the URL path)
if context.path_params:
original_request_body = self._merge_path_params(
original_request_body, context.path_params
)
        # Get stream: prefer the request body, then path_params (e.g. Gemini
        # distinguishes streaming via the URL endpoint)
        stream = original_request_body.get("stream")
        if stream is None and context.path_params:
            stream = context.path_params.get("stream", False)
        stream = bool(stream)
        # Get model: prefer the request body, then path_params (e.g. Gemini
        # carries model in the URL path)
        model = original_request_body.get("model")
        if model is None and context.path_params:
            model = context.path_params.get("model", "unknown")
        model = model or "unknown"
        # Extract request metadata for auditing
audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
context.add_audit_metadata(**audit_metadata)
        # Format the quota for display
quota_display = (
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
)
        # Log request start
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
try:
            # Check the client connection
            if await http_request.is_disconnected():
                logger.warning("Client disconnected")
                raise HTTPException(status_code=499, detail="Client disconnected")
            # Create the Handler
handler = self.HANDLER_CLASS(
db=db,
user=user,
api_key=api_key,
request_id=request_id,
client_ip=client_ip,
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
adapter_detector=self.detect_capability_requirements,
)
            # Process the request
if stream:
return await handler.process_stream(
original_request_body=original_request_body,
original_headers=original_headers,
query_params=query_params,
path_params=context.path_params,
)
return await handler.process_sync(
original_request_body=original_request_body,
original_headers=original_headers,
query_params=query_params,
path_params=context.path_params,
)
except HTTPException:
raise
except (
ModelNotSupportedException,
QuotaExceededException,
InvalidRequestException,
) as e:
logger.debug(f"客户端请求错误: {e.error_type}")
return self._error_response(
status_code=e.status_code,
error_type=(
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
),
message=e.message,
)
except (
ProviderAuthException,
ProviderRateLimitException,
ProviderNotAvailableException,
ProviderTimeoutException,
UpstreamClientException,
) as e:
return await self._handle_provider_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
except Exception as e:
return await self._handle_unexpected_exception(
e,
db=db,
user=user,
api_key=api_key,
model=model,
stream=stream,
start_time=start_time,
original_headers=original_headers,
original_request_body=original_request_body,
client_ip=client_ip,
request_id=request_id,
)
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
) -> Dict[str, Any]:
"""
合并 URL 路径参数到请求体 - 子类可覆盖
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
Args:
original_request_body: 原始请求体字典
path_params: URL 路径参数字典
Returns:
合并后的请求体字典
"""
merged = original_request_body.copy()
for key, value in path_params.items():
if key not in merged:
merged[key] = value
return merged
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""
提取消息数量 - 子类可覆盖
默认实现:从 input 字段提取
"""
if "input" not in payload:
return 0
input_data = payload["input"]
if isinstance(input_data, list):
return len(input_data)
if isinstance(input_data, dict) and "messages" in input_data:
return len(input_data.get("messages", []))
return 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
构建审计日志元数据 - 子类可覆盖
Args:
payload: 请求体
path_params: URL 路径参数(用于获取 model 等)
"""
# 优先从请求体获取 model其次从 path_params
model = payload.get("model")
if model is None and path_params:
model = path_params.get("model", "unknown")
model = model or "unknown"
stream = payload.get("stream", False)
messages_count = self._extract_message_count(payload)
return {
"action": f"{self.FORMAT_ID.lower()}_request",
"model": model,
"stream": bool(stream),
"max_tokens": payload.get("max_tokens"),
"messages_count": messages_count,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"tool_count": len(payload.get("tools") or []),
"instructions_present": bool(payload.get("instructions")),
}
async def _handle_provider_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理 Provider 相关异常"""
logger.debug(f"Caught provider exception: {type(e).__name__}")
response_time = int((time.time() - start_time) * 1000)
# 使用 RequestResult.from_exception 创建统一的失败结果
# 关键api_format 从 FORMAT_ID 获取,确保始终有值
result = RequestResult.from_exception(
exception=e,
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
model=model,
response_time_ms=response_time,
is_stream=stream,
)
result.request_headers = original_headers
result.request_body = original_request_body
        # Determine the error message
        if isinstance(e, ProviderAuthException):
            error_message = (
                f"Provider authentication failed: {str(e)}"
                if result.metadata.provider != "unknown"
                else "Server error: no provider available"
            )
            result.error_message = error_message
        # Handle upstream client errors (e.g. image-processing failures)
        if isinstance(e, UpstreamClientException):
            # Return a 400 status code and a clear error message
            result.status_code = e.status_code
            result.error_message = e.message
        # Record the failure via UsageRecorder
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
        # Map the exception type to an error type
if isinstance(e, UpstreamClientException):
error_type = "invalid_request_error"
elif result.status_code == 503:
error_type = "internal_server_error"
else:
error_type = "rate_limit_exceeded"
return self._error_response(
status_code=result.status_code,
error_type=error_type,
message=result.error_message or str(e),
)
async def _handle_unexpected_exception(
self,
e: Exception,
*,
db,
user,
api_key,
model: str,
stream: bool,
start_time: float,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
client_ip: str,
request_id: str,
) -> JSONResponse:
"""处理未预期的异常"""
if isinstance(e, ProxyException):
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
else:
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
exception=e,
extra_data={
"exception_class": e.__class__.__name__,
"processing_stage": "request_processing",
"model": model,
"stream": stream,
"traceback_preview": str(traceback.format_exc())[:500],
},
)
response_time = int((time.time() - start_time) * 1000)
        # Build a unified failure result via RequestResult.from_exception.
        # Key point: api_format comes from FORMAT_ID, so it always has a value.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,  # use the Adapter's FORMAT_ID as the default
model=model,
response_time_ms=response_time,
is_stream=stream,
)
        # For unexpected exceptions, force the status code to 500
result.status_code = 500
result.error_type = "internal_error"
result.request_headers = original_headers
result.request_body = original_request_body
        # Record the failure via UsageRecorder
recorder = UsageRecorder(
db=db,
user=user,
api_key=api_key,
client_ip=client_ip,
request_id=request_id,
)
await recorder.record_failure(result, original_headers, original_request_body)
        return self._error_response(
            status_code=500,
            error_type="internal_server_error",
            message="Internal error while processing the request")
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成错误响应"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
}
},
)
    # =========================================================================
    # Billing-strategy methods - subclasses may override for format-specific pricing
    # =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
计算总输入上下文(用于阶梯计费判定)
默认实现input_tokens + cache_read_input_tokens
子类可覆盖此方法实现不同的计算逻辑
Args:
input_tokens: 输入 token 数
cache_read_input_tokens: 缓存读取 token 数
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
Returns:
总输入上下文 token 数
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
            # Past every configured TTL bound: use the last entry
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int,
cache_read_input_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_price_per_1m: Optional[float],
cache_read_price_per_1m: Optional[float],
price_per_request: Optional[float],
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Dict[str, Any]:
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
input_price_per_1m: 输入价格(每 1M tokens
output_price_per_1m: 输出价格(每 1M tokens
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens
cache_read_price_per_1m: 缓存读取价格(每 1M tokens
price_per_request: 按次计费价格
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
包含各项成本的字典
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
        # Check tiered pricing
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
        # Compute each cost component
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
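# --- Illustrative sketch (not part of the original file) ---
# TTL-aware cache-read pricing: the cache_ttl_pricing entries are scanned in
# order and the first entry whose ttl_minutes covers the request's cache TTL
# wins; past the last bound, the last entry's price applies. The prices here
# are made-up demonstration values.
_demo_tier = {
    "cache_read_price_per_1m": 0.30,
    "cache_ttl_pricing": [
        {"ttl_minutes": 5, "cache_read_price_per_1m": 0.30},
        {"ttl_minutes": 60, "cache_read_price_per_1m": 0.60},
    ],
}
# Called unbound; self is not used by get_cache_read_price_for_ttl.
assert CliAdapterBase.get_cache_read_price_for_ttl(None, _demo_tier, cache_ttl_minutes=30) == 0.60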
# =========================================================================
# CLI Adapter registry - resolve CLI Adapter instances by API format
# =========================================================================
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
_CLI_ADAPTERS_LOADED = False
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
"""
注册 CLI Adapter 类到注册表
用法:
@register_cli_adapter
class ClaudeCliAdapter(CliAdapterBase):
FORMAT_ID = "CLAUDE_CLI"
...
"""
format_id = adapter_class.FORMAT_ID
if format_id and format_id != "UNKNOWN":
_CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
return adapter_class
def _ensure_cli_adapters_loaded():
"""确保所有 CLI Adapter 已被加载(触发注册)"""
global _CLI_ADAPTERS_LOADED
if _CLI_ADAPTERS_LOADED:
return
    # Import each CLI Adapter module so its @register_cli_adapter decorator runs
try:
from src.api.handlers.claude_cli import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.openai_cli import adapter as _ # noqa: F401
except ImportError:
pass
try:
from src.api.handlers.gemini_cli import adapter as _ # noqa: F401
except ImportError:
pass
_CLI_ADAPTERS_LOADED = True
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
"""根据 API format 获取 CLI Adapter 类"""
_ensure_cli_adapters_loaded()
return _CLI_ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
"""根据 API format 获取 CLI Adapter 实例"""
adapter_class = get_cli_adapter_class(api_format)
if adapter_class:
return adapter_class()
return None
def list_registered_cli_formats() -> list[str]:
"""返回所有已注册的 CLI API 格式"""
_ensure_cli_adapters_loaded()
return list(_CLI_ADAPTER_REGISTRY.keys())
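# --- Illustrative sketch (not part of the original file) ---
# Dispatch sketch: resolve a CLI adapter by format and fall back gracefully
# when nothing is registered for it. The "UNKNOWN_FMT" format is made up.
def _resolve_cli_adapter_or_none(api_format: str) -> Optional[CliAdapterBase]:
    adapter = get_cli_adapter_instance(api_format)
    if adapter is None:
        logger.warning(f"No CLI adapter registered for format {api_format!r}")
    return adapter

assert _resolve_cli_adapter_or_none("UNKNOWN_FMT") is None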

File diff suppressed because it is too large


@@ -0,0 +1,279 @@
"""
格式转换器注册表
自动管理不同 API 格式之间的转换器,支持:
- 请求转换:客户端格式 → Provider 格式
- 响应转换Provider 格式 → 客户端格式
使用方法:
1. 实现 Converter 类(需要有 convert_request 和/或 convert_response 方法)
2. 调用 registry.register() 注册转换器
3. 在 Handler 中调用 registry.convert_request/convert_response
示例:
from src.api.handlers.base.format_converter_registry import converter_registry
# 注册转换器
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
# 使用转换器
gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
"""
from typing import Any, Dict, Optional, Protocol, Tuple
from src.core.logger import logger
class RequestConverter(Protocol):
"""请求转换器协议"""
def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...
class ResponseConverter(Protocol):
"""响应转换器协议"""
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...
class StreamChunkConverter(Protocol):
"""流式响应块转换器协议"""
def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...
class FormatConverterRegistry:
"""
格式转换器注册表
管理不同 API 格式之间的双向转换器
"""
def __init__(self):
# key: (source_format, target_format), value: converter instance
self._converters: Dict[Tuple[str, str], Any] = {}
def register(
self,
source_format: str,
target_format: str,
converter: Any,
) -> None:
"""
注册格式转换器
Args:
source_format: 源格式(如 "CLAUDE", "OPENAI", "GEMINI"
target_format: 目标格式
converter: 转换器实例(需要有 convert_request/convert_response 方法)
"""
key = (source_format.upper(), target_format.upper())
self._converters[key] = converter
logger.info(f"[ConverterRegistry] 注册转换器: {source_format} -> {target_format}")
def get_converter(
self,
source_format: str,
target_format: str,
) -> Optional[Any]:
"""
获取转换器
Args:
source_format: 源格式
target_format: 目标格式
Returns:
转换器实例,如果不存在返回 None
"""
key = (source_format.upper(), target_format.upper())
return self._converters.get(key)
def has_converter(
self,
source_format: str,
target_format: str,
) -> bool:
"""检查是否存在转换器"""
key = (source_format.upper(), target_format.upper())
return key in self._converters
def convert_request(
self,
request: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换请求
Args:
request: 原始请求字典
source_format: 源格式(客户端格式)
target_format: 目标格式Provider 格式)
Returns:
转换后的请求字典,如果无需转换或没有转换器则返回原始请求
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return request
converter = self.get_converter(source_format, target_format)
if converter is None:
logger.warning(f"[ConverterRegistry] 未找到请求转换器: {source_format} -> {target_format},返回原始请求")
return request
if not hasattr(converter, "convert_request"):
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_request 方法: {source_format} -> {target_format}")
return request
try:
converted = converter.convert_request(request)
logger.debug(f"[ConverterRegistry] 请求转换成功: {source_format} -> {target_format}")
return converted
except Exception as e:
logger.error(f"[ConverterRegistry] 请求转换失败: {source_format} -> {target_format}: {e}")
return request
def convert_response(
self,
response: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换响应
Args:
response: 原始响应字典
source_format: 源格式Provider 格式)
target_format: 目标格式(客户端格式)
Returns:
转换后的响应字典,如果无需转换或没有转换器则返回原始响应
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return response
converter = self.get_converter(source_format, target_format)
if converter is None:
logger.warning(f"[ConverterRegistry] 未找到响应转换器: {source_format} -> {target_format},返回原始响应")
return response
if not hasattr(converter, "convert_response"):
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_response 方法: {source_format} -> {target_format}")
return response
try:
converted = converter.convert_response(response)
logger.debug(f"[ConverterRegistry] 响应转换成功: {source_format} -> {target_format}")
return converted
except Exception as e:
logger.error(f"[ConverterRegistry] 响应转换失败: {source_format} -> {target_format}: {e}")
return response
def convert_stream_chunk(
self,
chunk: Dict[str, Any],
source_format: str,
target_format: str,
) -> Dict[str, Any]:
"""
转换流式响应块
Args:
chunk: 原始流式响应块
source_format: 源格式Provider 格式)
target_format: 目标格式(客户端格式)
Returns:
转换后的流式响应块
"""
# 同格式无需转换
if source_format.upper() == target_format.upper():
return chunk
converter = self.get_converter(source_format, target_format)
if converter is None:
return chunk
# 优先使用专门的流式转换方法
if hasattr(converter, "convert_stream_chunk"):
try:
return converter.convert_stream_chunk(chunk)
except Exception as e:
logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
return chunk
# 降级到普通响应转换
if hasattr(converter, "convert_response"):
try:
return converter.convert_response(chunk)
except Exception:
return chunk
return chunk
def list_converters(self) -> list[Tuple[str, str]]:
"""列出所有已注册的转换器"""
return list(self._converters.keys())
# Global singleton
converter_registry = FormatConverterRegistry()
def register_all_converters():
"""
注册所有内置的格式转换器
在应用启动时调用此函数
"""
# Claude <-> OpenAI
try:
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 Claude/OpenAI 转换器: {e}")
# Claude <-> Gemini
try:
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
GeminiToClaudeConverter,
)
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 Claude/Gemini 转换器: {e}")
# OpenAI <-> Gemini
try:
from src.api.handlers.gemini.converter import (
GeminiToOpenAIConverter,
OpenAIToGeminiConverter,
)
converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
except ImportError as e:
logger.warning(f"[ConverterRegistry] 无法加载 OpenAI/Gemini 转换器: {e}")
logger.info(f"[ConverterRegistry] 已注册 {len(converter_registry.list_converters())} 个格式转换器")
__all__ = [
"FormatConverterRegistry",
"converter_registry",
"register_all_converters",
]
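# --- Illustrative sketch (not part of the original file) ---
# A minimal converter that satisfies the RequestConverter/ResponseConverter
# protocols. The "ACME" format and the field rename are made up; the real
# converters live in src.api.handlers.*.converter.
class _AcmeToOpenAIConverter:
    def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        converted = dict(request)
        converted["messages"] = converted.pop("inputs", [])  # hypothetical rename
        return converted

    def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        return response  # passthrough for the sketch

converter_registry.register("ACME", "OPENAI", _AcmeToOpenAIConverter())
assert converter_registry.has_converter("acme", "openai")  # lookups are case-insensitive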


@@ -0,0 +1,465 @@
"""
响应解析器工厂
直接根据格式 ID 创建对应的 ResponseParser 实现,
不再经过 Protocol 抽象层。
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.response_parser import (
ParsedChunk,
ParsedResponse,
ResponseParser,
StreamStats,
)
class OpenAIResponseParser(ResponseParser):
"""OpenAI 格式响应解析器"""
def __init__(self):
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
self._parser = OpenAIStreamParser()
self.name = "OPENAI"
self.api_format = "OPENAI"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
if not line or not line.strip():
return None
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type=None,
data=parsed,
)
        # Extract the text delta
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta
        # Check for completion
if self._parser.is_done_chunk(parsed):
chunk.is_done = True
stats.has_completion = True
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
        # Extract the text content
choices = response.get("choices", [])
if choices:
message = choices[0].get("message", {})
content = message.get("content")
if content:
result.text_content = content
result.response_id = response.get("id")
        # Extract usage
usage = response.get("usage", {})
result.input_tokens = usage.get("prompt_tokens", 0)
result.output_tokens = usage.get("completion_tokens", 0)
        # Check for errors
if "error" in response:
result.is_error = True
error = response.get("error", {})
if isinstance(error, dict):
result.error_type = error.get("type")
result.error_message = error.get("message")
else:
result.error_message = str(error)
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
usage = response.get("usage", {})
return {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
choices = response.get("choices", [])
if choices:
message = choices[0].get("message", {})
content = message.get("content")
if content:
return content
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
return "error" in response
class OpenAICliResponseParser(OpenAIResponseParser):
"""OpenAI CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "OPENAI_CLI"
self.api_format = "OPENAI_CLI"
class ClaudeResponseParser(ResponseParser):
"""Claude 格式响应解析器"""
def __init__(self):
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
self._parser = ClaudeStreamParser()
self.name = "CLAUDE"
self.api_format = "CLAUDE"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
if not line or not line.strip():
return None
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type=self._parser.get_event_type(parsed),
data=parsed,
)
        # Extract the text delta
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta
        # Check for completion
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True
        # Extract usage
usage = self._parser.extract_usage(parsed)
if usage:
chunk.input_tokens = usage.get("input_tokens", 0)
chunk.output_tokens = usage.get("output_tokens", 0)
chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.cache_creation_tokens = chunk.cache_creation_tokens
stats.cache_read_tokens = chunk.cache_read_tokens
        # Check for errors
if self._parser.is_error_event(parsed):
chunk.is_error = True
error = parsed.get("error", {})
if isinstance(error, dict):
chunk.error_message = error.get("message", str(error))
else:
chunk.error_message = str(error)
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
        # Extract the text content
content = response.get("content", [])
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
result.text_content = "".join(text_parts)
result.response_id = response.get("id")
        # Extract usage
usage = response.get("usage", {})
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
        # Check for an error
if "error" in response or response.get("type") == "error":
result.is_error = True
error = response.get("error", {})
if isinstance(error, dict):
result.error_type = error.get("type")
result.error_message = error.get("message")
else:
result.error_message = str(error)
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
usage = response.get("usage", {})
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
content = response.get("content", [])
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text_parts.append(block.get("text", ""))
return "".join(text_parts)
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
return "error" in response or response.get("type") == "error"
class ClaudeCliResponseParser(ClaudeResponseParser):
"""Claude CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "CLAUDE_CLI"
self.api_format = "CLAUDE_CLI"
class GeminiResponseParser(ResponseParser):
"""Gemini 格式响应解析器"""
def __init__(self):
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
self._parser = GeminiStreamParser()
self.name = "GEMINI"
self.api_format = "GEMINI"
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
"""
        Parse a single Gemini SSE line.
        Gemini streaming responses use the SSE format (data: {...}).
"""
if not line or not line.strip():
return None
        # Gemini SSE format: data: {...}
if line.startswith("data: "):
data_str = line[6:]
else:
data_str = line
parsed = self._parser.parse_line(data_str)
if parsed is None:
return None
chunk = ParsedChunk(
raw_line=line,
event_type="content",
data=parsed,
)
        # Extract the text delta
text_delta = self._parser.extract_text_delta(parsed)
if text_delta:
chunk.text_delta = text_delta
stats.collected_text += text_delta
        # Check whether the stream is done
if self._parser.is_done_event(parsed):
chunk.is_done = True
stats.has_completion = True
        # Extract usage
usage = self._parser.extract_usage(parsed)
if usage:
chunk.input_tokens = usage.get("input_tokens", 0)
chunk.output_tokens = usage.get("output_tokens", 0)
chunk.cache_read_tokens = usage.get("cached_tokens", 0)
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.cache_read_tokens = chunk.cache_read_tokens
        # Check for an error
if self._parser.is_error_event(parsed):
chunk.is_error = True
error = parsed.get("error", {})
if isinstance(error, dict):
chunk.error_message = error.get("message", str(error))
else:
chunk.error_message = str(error)
stats.chunk_count += 1
stats.data_count += 1
return chunk
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
result = ParsedResponse(
raw_response=response,
status_code=status_code,
)
        # Extract the text content
candidates = response.get("candidates", [])
if candidates:
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
result.text_content = "".join(text_parts)
result.response_id = response.get("modelVersion")
        # Extract usage (delegates to GeminiStreamParser.extract_usage as the single source of truth)
usage = self._parser.extract_usage(response)
if usage:
result.input_tokens = usage.get("input_tokens", 0)
result.output_tokens = usage.get("output_tokens", 0)
result.cache_read_tokens = usage.get("cached_tokens", 0)
        # Check for errors (uses the enhanced error detection)
error_info = self._parser.extract_error_info(response)
if error_info:
result.is_error = True
result.error_type = error_info.get("status")
result.error_message = error_info.get("message")
return result
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
"""
        Extract token usage from a Gemini response.
        Delegates to GeminiStreamParser.extract_usage as the single source of truth.
"""
usage = self._parser.extract_usage(response)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": 0,
"cache_read_tokens": usage.get("cached_tokens", 0),
}
def extract_text_content(self, response: Dict[str, Any]) -> str:
candidates = response.get("candidates", [])
if candidates:
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
return "".join(text_parts)
return ""
def is_error_response(self, response: Dict[str, Any]) -> bool:
"""
        Determine whether the response is an error response.
        Uses the enhanced error detection, which also catches errors nested in chunks.
"""
return self._parser.is_error_event(response)
class GeminiCliResponseParser(GeminiResponseParser):
"""Gemini CLI 格式响应解析器"""
def __init__(self):
super().__init__()
self.name = "GEMINI_CLI"
self.api_format = "GEMINI_CLI"
# Parser registry
_PARSERS = {
"CLAUDE": ClaudeResponseParser,
"CLAUDE_CLI": ClaudeCliResponseParser,
"OPENAI": OpenAIResponseParser,
"OPENAI_CLI": OpenAICliResponseParser,
"GEMINI": GeminiResponseParser,
"GEMINI_CLI": GeminiCliResponseParser,
}
def get_parser_for_format(format_id: str) -> ResponseParser:
"""
    Get a ResponseParser instance for the given format ID.
    Args:
        format_id: Format ID ("CLAUDE", "OPENAI", "CLAUDE_CLI", "OPENAI_CLI", "GEMINI", "GEMINI_CLI")
    Returns:
        A ResponseParser instance
    Raises:
        KeyError: if the format is unknown
"""
format_id = format_id.upper()
if format_id not in _PARSERS:
raise KeyError(f"Unknown format: {format_id}")
return _PARSERS[format_id]()
def is_cli_format(format_id: str) -> bool:
"""判断是否为 CLI 格式"""
return format_id.upper().endswith("_CLI")
__all__ = [
"OpenAIResponseParser",
"OpenAICliResponseParser",
"ClaudeResponseParser",
"ClaudeCliResponseParser",
"GeminiResponseParser",
"GeminiCliResponseParser",
"get_parser_for_format",
"get_parser_from_protocol",
"is_cli_format",
]
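
# Illustrative usage of the registry above (a minimal sketch, not part of the
# public API; the format IDs are exactly the keys registered in _PARSERS):
#
#     parser = get_parser_for_format("claude")  # lookup is case-insensitive
#     stats = parser.create_stats()
#     chunk = parser.parse_sse_line('data: {"type": "ping"}', stats)
#     if is_cli_format("CLAUDE_CLI"):
#         ...  # CLI formats are passed through verbatim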


@@ -0,0 +1,207 @@
"""
Request builder - passthrough mode
Passthrough mode: forward the request body and headers as-is (CLI, Chat, and similar scenarios)
- Strips sensitive headers (authorization, x-api-key, host, content-length, etc.)
- Preserves all other headers and request-body fields
- Intended for Claude CLI, OpenAI CLI, Chat API, and similar scenarios
Usage:
builder = PassthroughRequestBuilder()
payload, headers = builder.build(original_body, original_headers, endpoint, key)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, FrozenSet, Optional, Tuple
from src.core.crypto import crypto_service
# ==============================================================================
# Shared header configuration constants
# ==============================================================================
# Sensitive headers - stripped in passthrough mode (blacklist).
# These headers either carry credentials or are regenerated by the proxy layer.
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
{
"authorization",
"x-api-key",
"x-goog-api-key", # Gemini API 认证头
"host",
"content-length",
"transfer-encoding",
"connection",
        # Do not forward accept-encoding; let httpx negotiate compression itself.
        # This avoids clients requesting brotli/zstd that httpx cannot decompress.
"accept-encoding",
}
)
# ==============================================================================
# Request builders
# ==============================================================================
class RequestBuilder(ABC):
"""请求构建器抽象基类"""
@abstractmethod
def build_payload(
self,
original_body: Dict[str, Any],
*,
mapped_model: Optional[str] = None,
is_stream: bool = False,
) -> Dict[str, Any]:
"""构建请求体"""
pass
@abstractmethod
def build_headers(
self,
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""构建请求头"""
pass
def build(
self,
original_body: Dict[str, Any],
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
mapped_model: Optional[str] = None,
is_stream: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
        Build the complete request (body + headers).
Returns:
Tuple[payload, headers]
"""
payload = self.build_payload(
original_body,
mapped_model=mapped_model,
is_stream=is_stream,
)
headers = self.build_headers(
original_headers,
endpoint,
key,
extra_headers=extra_headers,
)
return payload, headers
class PassthroughRequestBuilder(RequestBuilder):
"""
    Passthrough request builder
    For CLI and similar scenarios; keeps the request as close to the original as possible:
    - Body: copied verbatim; only the required fields (model, stream) are ever touched
    - Headers: strips sensitive headers (blacklist) and forwards everything else
"""
def build_payload(
self,
original_body: Dict[str, Any],
*,
        mapped_model: Optional[str] = None,  # noqa: ARG002 - handled by apply_mapped_model
        is_stream: bool = False,  # noqa: ARG002 - keep the client's original value; never added here
) -> Dict[str, Any]:
"""
        Pass the request body through - copy it verbatim, with no modifications.
        In passthrough mode:
        - model: handled by each handler's apply_mapped_model method
        - stream: keeps the client's original value (each API treats it differently)
"""
return dict(original_body)
def build_headers(
self,
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
*,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""
        Pass the request headers through - strip sensitive headers (blacklist), forward everything else.
"""
from src.core.api_format_metadata import get_auth_config, resolve_api_format
headers: Dict[str, str] = {}
        # 1. Set the auth header according to the API format
decrypted_key = crypto_service.decrypt(key.api_key)
api_format = getattr(endpoint, "api_format", None)
resolved_format = resolve_api_format(api_format)
auth_header, auth_type = (
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
)
if auth_type == "bearer":
headers[auth_header] = f"Bearer {decrypted_key}"
else:
headers[auth_header] = decrypted_key
        # 2. Add extra headers configured on the endpoint
if endpoint.headers:
headers.update(endpoint.headers)
        # 3. Forward the original headers (excluding sensitive ones - blacklist mode)
if original_headers:
for name, value in original_headers.items():
lower_name = name.lower()
                # Skip sensitive headers
if lower_name in SENSITIVE_HEADERS:
continue
headers[name] = value
        # 4. Add extra headers
if extra_headers:
headers.update(extra_headers)
        # 5. Ensure a Content-Type is present
if "Content-Type" not in headers and "content-type" not in headers:
headers["Content-Type"] = "application/json"
return headers
# ==============================================================================
# Convenience functions
# ==============================================================================
def build_passthrough_request(
original_body: Dict[str, Any],
original_headers: Dict[str, str],
endpoint: Any,
key: Any,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
    Build a passthrough-mode request.
    Pure passthrough: the body is copied verbatim; only the headers (auth, etc.) are processed.
    Model mapping and stream handling are left to the caller (each API format treats them differently).
"""
builder = PassthroughRequestBuilder()
return builder.build(
original_body,
original_headers,
endpoint,
key,
)
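
# A minimal usage sketch (illustrative; `endpoint` and `key` stand in for the
# ORM objects used elsewhere in this codebase - any objects with `.headers`,
# `.api_format`, and an encrypted `.api_key` attribute would do):
#
#     payload, headers = build_passthrough_request(
#         original_body={"model": "claude-sonnet", "messages": [], "max_tokens": 16},
#         original_headers={"authorization": "Bearer client-key", "x-trace": "abc"},
#         endpoint=endpoint,
#         key=key,
#     )
#     # "authorization" is stripped and regenerated; "x-trace" is forwarded as-is.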


@@ -0,0 +1,174 @@
"""
Response parser base classes - define the unified response-parsing interface
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class ParsedChunk:
"""解析后的流式数据块"""
    # Raw data
raw_line: str
event_type: Optional[str] = None
data: Optional[Dict[str, Any]] = None
    # Extracted content
text_delta: str = ""
is_done: bool = False
is_error: bool = False
error_message: Optional[str] = None
    # Usage info (usually carried by the last chunk)
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
    # Response ID
response_id: Optional[str] = None
@dataclass
class StreamStats:
"""流式响应统计信息"""
    # Counters
chunk_count: int = 0
data_count: int = 0
    # Token usage
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
    # Content
collected_text: str = ""
response_id: Optional[str] = None
    # State
has_completion: bool = False
status_code: int = 200
error_message: Optional[str] = None
    # Provider info
provider_name: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
    # Response headers and the final full response
response_headers: Dict[str, str] = field(default_factory=dict)
final_response: Optional[Dict[str, Any]] = None
@dataclass
class ParsedResponse:
"""解析后的非流式响应"""
    # Raw response
raw_response: Dict[str, Any]
status_code: int
    # Extracted content
text_content: str = ""
response_id: Optional[str] = None
    # Usage
input_tokens: int = 0
output_tokens: int = 0
cache_creation_tokens: int = 0
cache_read_tokens: int = 0
    # Error info
is_error: bool = False
error_type: Optional[str] = None
error_message: Optional[str] = None
class ResponseParser(ABC):
"""
    Response parser base class
    Defines a unified interface for parsing responses across API formats.
    Subclasses implement the concrete parsing logic.
"""
    # Parser name (used in logs)
name: str = "base"
    # Supported API format
api_format: str = "UNKNOWN"
@abstractmethod
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
"""
        Parse a single SSE line.
        Args:
            line: The SSE line
            stats: The stream statistics object (updated in place)
        Returns:
            The parsed chunk, or None if the line carries no usable data
"""
pass
@abstractmethod
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
"""
        Parse a non-streaming response.
        Args:
            response: The response JSON
            status_code: The HTTP status code
        Returns:
            The parsed response object
"""
pass
@abstractmethod
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
"""
        Extract token usage from a response.
        Args:
            response: The response JSON
        Returns:
            A dict with input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens
"""
pass
@abstractmethod
def extract_text_content(self, response: Dict[str, Any]) -> str:
"""
        Extract the text content from a response.
        Args:
            response: The response JSON
        Returns:
            The extracted text content
"""
pass
def is_error_response(self, response: Dict[str, Any]) -> bool:
"""
        Determine whether a response is an error response.
        Args:
            response: The response JSON
        Returns:
            True if the response is an error response
"""
return "error" in response
def create_stats(self) -> StreamStats:
"""创建新的流统计对象"""
return StreamStats()


@@ -0,0 +1,17 @@
"""
Claude Chat API handlers
"""
from src.api.handlers.claude.adapter import (
ClaudeChatAdapter,
ClaudeTokenCountAdapter,
build_claude_adapter,
)
from src.api.handlers.claude.handler import ClaudeChatHandler
__all__ = [
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
"ClaudeChatHandler",
]


@@ -0,0 +1,228 @@
"""
Claude Chat Adapter - Claude Chat API adapter built on ChatAdapterBase
Handles Claude Chat requests on the /v1/messages endpoint.
"""
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.core.optimization_utils import TokenCounter
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
class ClaudeCapabilityDetector:
"""Claude API 能力检测器"""
@staticmethod
def detect_from_headers(
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""
        Detect capability requirements from Claude request headers.
        Detection rules:
        - anthropic-beta: context-1m-xxx -> context_1m: True
        Args:
            headers: Request header dict
            request_body: Request body (unused for Claude; kept for interface parity)
"""
requirements: Dict[str, bool] = {}
        # Look for the anthropic-beta header (case-insensitive)
beta_header = None
for key, value in headers.items():
if key.lower() == "anthropic-beta":
beta_header = value
break
if beta_header:
            # Check for the context-1m marker
if "context-1m" in beta_header.lower():
requirements["context_1m"] = True
return requirements
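
# A minimal illustration of the detection rule above (sketch only; the beta
# token value is hypothetical):
#
#     ClaudeCapabilityDetector.detect_from_headers(
#         {"Anthropic-Beta": "context-1m-beta"}
#     )
#     # -> {"context_1m": True}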
@register_adapter
class ClaudeChatAdapter(ChatAdapterBase):
"""
    Claude Chat API adapter
    Handles Claude Chat requests (the /v1/messages endpoint) with format validation.
"""
FORMAT_ID = "CLAUDE"
name = "claude.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.claude.handler import ClaudeChatHandler
return ClaudeChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["CLAUDE"])
logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-api-key)"""
return request.headers.get("x-api-key")
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""检测 Claude 请求中隐含的能力需求"""
return ClaudeCapabilityDetector.detect_from_headers(headers)
# =========================================================================
    # Claude-specific billing logic
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
        Compute Claude's total input context (used for tiered-pricing decisions).
        Claude's total input = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
"""
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
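
    # Worked example (illustrative numbers only): with input_tokens=120_000,
    # cache_creation_input_tokens=50_000 and cache_read_input_tokens=40_000, the
    # total input context is 210_000 tokens - the figure a tiered-pricing check
    # would compare against a threshold such as 200k.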
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""验证请求体"""
try:
if not isinstance(original_request_body, dict):
raise ValueError("Request body must be a JSON object")
required_fields = ["model", "messages", "max_tokens"]
missing_fields = [f for f in required_fields if f not in original_request_body]
if missing_fields:
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
request = ClaudeMessagesRequest.model_validate(
original_request_body,
strict=False,
)
except ValueError as e:
logger.error(f"请求体基本验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
request = ClaudeMessagesRequest.model_construct(
model=original_request_body.get("model"),
max_tokens=original_request_body.get("max_tokens"),
messages=original_request_body.get("messages", []),
stream=original_request_body.get("stream", False),
)
return request
def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 Claude Chat 特定的审计元数据"""
role_counts: dict[str, int] = {}
for message in request_obj.messages:
role_counts[message.role] = role_counts.get(message.role, 0) + 1
return {
"action": "claude_messages",
"model": request_obj.model,
"stream": bool(request_obj.stream),
"max_tokens": request_obj.max_tokens,
"temperature": getattr(request_obj, "temperature", None),
"top_p": getattr(request_obj, "top_p", None),
"top_k": getattr(request_obj, "top_k", None),
"messages_count": len(request_obj.messages),
"message_roles": role_counts,
"stop_sequences": len(request_obj.stop_sequences or []),
"tools_count": len(request_obj.tools or []),
"system_present": bool(request_obj.system),
"metadata_present": bool(request_obj.metadata),
"thinking_enabled": bool(request_obj.thinking),
}
def build_claude_adapter(x_app_header: Optional[str]):
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
if x_app_header and x_app_header.lower() == "cli":
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
return ClaudeCliAdapter()
return ClaudeChatAdapter()
class ClaudeTokenCountAdapter(ApiAdapter):
"""计算 Claude 请求 Token 数的轻量适配器。"""
name = "claude.token_count"
mode = ApiMode.STANDARD
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-api-key 或 Authorization: Bearer)"""
# 优先检查 x-api-key
api_key = request.headers.get("x-api-key")
if api_key:
return api_key
        # Fall back to Authorization: Bearer
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
            return authorization.removeprefix("Bearer ")  # strip only the leading prefix
return None
async def handle(self, context: ApiRequestContext):
payload = context.ensure_json_body()
try:
request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
except Exception as e:
logger.error(f"Token count payload invalid: {e}")
raise HTTPException(status_code=400, detail="Invalid token count payload") from e
token_counter = TokenCounter()
total_tokens = 0
if request.system:
if isinstance(request.system, str):
total_tokens += token_counter.count_tokens(request.system, request.model)
elif isinstance(request.system, list):
for block in request.system:
if hasattr(block, "text"):
total_tokens += token_counter.count_tokens(block.text, request.model)
messages_dict = [
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
]
total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)
context.add_audit_metadata(
action="claude_token_count",
model=request.model,
messages_count=len(request.messages),
system_present=bool(request.system),
tools_count=len(request.tools or []),
thinking_enabled=bool(request.thinking),
input_tokens=total_tokens,
)
return JSONResponse({"input_tokens": total_tokens})
__all__ = [
"ClaudeChatAdapter",
"ClaudeTokenCountAdapter",
"build_claude_adapter",
]


@@ -0,0 +1,490 @@
"""
OpenAI -> Claude format converter
Converts the OpenAI Chat Completions API format into the Claude Messages API format.
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
class OpenAIToClaudeConverter:
"""
    OpenAI -> Claude format converter
    Supports:
    - Request conversion: OpenAI Chat Request -> Claude Request
    - Response conversion: OpenAI Chat Response -> Claude Response
    - Stream conversion: OpenAI SSE -> Claude SSE
"""
    # Content type constants
CONTENT_TYPE_TEXT = "text"
CONTENT_TYPE_IMAGE = "image"
CONTENT_TYPE_TOOL_USE = "tool_use"
CONTENT_TYPE_TOOL_RESULT = "tool_result"
    # Stop-reason mapping (OpenAI -> Claude)
FINISH_REASON_MAP = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"function_call": "tool_use",
"content_filter": "end_turn",
}
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
"""
Args:
            model_mapping: Mapping from OpenAI model names to Claude model names
"""
self._model_mapping = model_mapping or {}
    # ==================== Request conversion ====================
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""
        Convert an OpenAI request into Claude format.
        Args:
            request: The OpenAI request (dict or Pydantic model)
        Returns:
            The request as a Claude-format dict
"""
if hasattr(request, "model_dump"):
data = request.model_dump(exclude_none=True)
else:
data = dict(request)
        # Model mapping
model = data.get("model", "")
claude_model = self._model_mapping.get(model, model)
        # Process messages
system_content: Optional[str] = None
claude_messages: List[Dict[str, Any]] = []
for message in data.get("messages", []):
role = message.get("role")
            # Extract system messages
if role == "system":
system_content = self._collapse_content(message.get("content"))
continue
            # Convert the remaining messages
converted = self._convert_message(message)
if converted:
claude_messages.append(converted)
        # Build the Claude request
result: Dict[str, Any] = {
"model": claude_model,
"messages": claude_messages,
"max_tokens": data.get("max_tokens") or 4096,
}
        # Optional parameters
if data.get("temperature") is not None:
result["temperature"] = data["temperature"]
if data.get("top_p") is not None:
result["top_p"] = data["top_p"]
if data.get("stream"):
result["stream"] = data["stream"]
if data.get("stop"):
result["stop_sequences"] = self._convert_stop(data["stop"])
if system_content:
result["system"] = system_content
        # Tool conversion
tools = self._convert_tools(data.get("tools"))
if tools:
result["tools"] = tools
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
if tool_choice:
result["tool_choice"] = tool_choice
return result
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""转换单条消息"""
role = message.get("role")
if role == "user":
return self._convert_user_message(message)
if role == "assistant":
return self._convert_assistant_message(message)
if role == "tool":
return self._convert_tool_message(message)
return None
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换用户消息"""
content = message.get("content")
if isinstance(content, str) or content is None:
return {"role": "user", "content": content or ""}
        # Convert the content array
claude_content: List[Dict[str, Any]] = []
for item in content:
item_type = item.get("type")
if item_type == "text":
claude_content.append(
{"type": self.CONTENT_TYPE_TEXT, "text": item.get("text", "")}
)
elif item_type == "image_url":
image_url = (item.get("image_url") or {}).get("url", "")
claude_content.append(self._convert_image_url(image_url))
return {"role": "user", "content": claude_content}
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换助手消息"""
content_blocks: List[Dict[str, Any]] = []
        # Handle text content
content = message.get("content")
if isinstance(content, str):
content_blocks.append({"type": self.CONTENT_TYPE_TEXT, "text": content})
elif isinstance(content, list):
for part in content:
if part.get("type") == "text":
content_blocks.append(
{"type": self.CONTENT_TYPE_TEXT, "text": part.get("text", "")}
)
        # Handle tool calls
for tool_call in message.get("tool_calls") or []:
if tool_call.get("type") == "function":
function = tool_call.get("function", {})
arguments = function.get("arguments", "{}")
try:
input_data = json.loads(arguments)
except json.JSONDecodeError:
input_data = {"raw": arguments}
content_blocks.append(
{
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call.get("id", ""),
"name": function.get("name", ""),
"input": input_data,
}
)
        # Collapse a single text block into a plain string
if not content_blocks:
return {"role": "assistant", "content": ""}
if len(content_blocks) == 1 and content_blocks[0]["type"] == self.CONTENT_TYPE_TEXT:
return {"role": "assistant", "content": content_blocks[0]["text"]}
return {"role": "assistant", "content": content_blocks}
def _convert_tool_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换工具结果消息"""
tool_content = message.get("content", "")
        # Try to parse as JSON
parsed_content = tool_content
if isinstance(tool_content, str):
try:
parsed_content = json.loads(tool_content)
except json.JSONDecodeError:
pass
tool_block = {
"type": self.CONTENT_TYPE_TOOL_RESULT,
"tool_use_id": message.get("tool_call_id", ""),
"content": parsed_content,
}
return {"role": "user", "content": [tool_block]}
def _convert_tools(
self, tools: Optional[List[Dict[str, Any]]]
) -> Optional[List[Dict[str, Any]]]:
"""转换工具定义"""
if not tools:
return None
result: List[Dict[str, Any]] = []
for tool in tools:
if tool.get("type") != "function":
continue
function = tool.get("function", {})
result.append(
{
"name": function.get("name", ""),
"description": function.get("description"),
"input_schema": function.get("parameters") or {},
}
)
return result if result else None
def _convert_tool_choice(
self, tool_choice: Optional[Union[str, Dict[str, Any]]]
) -> Optional[Dict[str, Any]]:
"""转换工具选择"""
if tool_choice is None:
return None
if tool_choice == "none":
return {"type": "none"}
if tool_choice == "auto":
return {"type": "auto"}
if tool_choice == "required":
return {"type": "any"}
if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
function = tool_choice.get("function", {})
return {"type": "tool_use", "name": function.get("name", "")}
return {"type": "auto"}
def _convert_image_url(self, image_url: str) -> Dict[str, Any]:
"""转换图片 URL"""
if image_url.startswith("data:"):
header, _, data = image_url.partition(",")
media_type = "image/jpeg"
if ";" in header:
media_type = header.split(";")[0].split(":")[-1]
return {
"type": self.CONTENT_TYPE_IMAGE,
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
return {"type": self.CONTENT_TYPE_TEXT, "text": f"[Image: {image_url}]"}
def _convert_stop(self, stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
"""转换停止序列"""
if stop is None:
return None
if isinstance(stop, str):
return [stop]
return stop
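
    # A minimal request-conversion sketch (illustrative values only):
    #
    #     converter = OpenAIToClaudeConverter()
    #     converter.convert_request({
    #         "model": "gpt-4o",
    #         "max_tokens": 64,
    #         "messages": [
    #             {"role": "system", "content": "Be brief."},
    #             {"role": "user", "content": "Hi"},
    #         ],
    #     })
    #     # -> {"model": "gpt-4o", "system": "Be brief.",
    #     #     "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 64}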
    # ==================== Response conversion ====================
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert an OpenAI response into Claude format.
        Args:
            response: The OpenAI response dict
        Returns:
            The response as a Claude-format dict
"""
choices = response.get("choices", [])
if not choices:
return self._empty_claude_response(response)
choice = choices[0]
message = choice.get("message", {})
        # Build the content array
content: List[Dict[str, Any]] = []
        # Handle text
text_content = message.get("content")
if text_content:
content.append(
{
"type": self.CONTENT_TYPE_TEXT,
"text": text_content,
}
)
        # Handle tool calls
for tool_call in message.get("tool_calls") or []:
if tool_call.get("type") == "function":
function = tool_call.get("function", {})
arguments = function.get("arguments", "{}")
try:
input_data = json.loads(arguments)
except json.JSONDecodeError:
input_data = {"raw": arguments}
content.append(
{
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call.get("id", ""),
"name": function.get("name", ""),
"input": input_data,
}
)
        # Convert finish_reason
finish_reason = choice.get("finish_reason")
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
        # Convert usage
usage = response.get("usage", {})
claude_usage = {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
}
return {
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
"type": "message",
"role": "assistant",
"model": response.get("model", ""),
"content": content,
"stop_reason": stop_reason,
"stop_sequence": None,
"usage": claude_usage,
}
def _empty_claude_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""构建空的 Claude 响应"""
return {
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
"type": "message",
"role": "assistant",
"model": response.get("model", ""),
"content": [],
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {"input_tokens": 0, "output_tokens": 0},
}
    # ==================== Stream conversion ====================
def convert_stream_chunk(
self,
chunk: Dict[str, Any],
model: str = "",
message_id: Optional[str] = None,
message_started: bool = False,
) -> List[Dict[str, Any]]:
"""
        Convert an OpenAI SSE chunk into Claude SSE events.
        Args:
            chunk: The OpenAI SSE chunk
            model: The model name
            message_id: The message ID
            message_started: Whether message_start has already been emitted
        Returns:
            A list of Claude SSE events
"""
events: List[Dict[str, Any]] = []
choices = chunk.get("choices", [])
if not choices:
return events
choice = choices[0]
delta = choice.get("delta", {})
finish_reason = choice.get("finish_reason")
        # Handle the role (first chunk)
role = delta.get("role")
if role and not message_started:
msg_id = message_id or f"msg_{uuid.uuid4().hex[:8]}"
events.append(
{
"type": "message_start",
"message": {
"id": msg_id,
"type": "message",
"role": role,
"model": model,
"content": [],
"stop_reason": None,
"stop_sequence": None,
},
}
)
        # Handle text content
content_delta = delta.get("content")
if isinstance(content_delta, str):
events.append(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": content_delta},
}
)
        # Handle tool calls
tool_calls = delta.get("tool_calls", [])
for tool_call in tool_calls:
index = tool_call.get("index", 0)
            # Tool call start
if "id" in tool_call:
function = tool_call.get("function", {})
events.append(
{
"type": "content_block_start",
"index": index,
"content_block": {
"type": self.CONTENT_TYPE_TOOL_USE,
"id": tool_call["id"],
"name": function.get("name", ""),
},
}
)
            # Tool call argument delta
function = tool_call.get("function", {})
if "arguments" in function:
events.append(
{
"type": "content_block_delta",
"index": index,
"delta": {
"type": "input_json_delta",
"partial_json": function.get("arguments", ""),
},
}
)
        # Handle completion
if finish_reason:
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
events.append(
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason},
}
)
return events
    # ==================== Utilities ====================
def _collapse_content(
self, content: Optional[Union[str, List[Dict[str, Any]]]]
) -> Optional[str]:
"""折叠内容为字符串"""
if isinstance(content, str):
return content
if not content:
return None
text_parts = [part.get("text", "") for part in content if part.get("type") == "text"]
return "\n\n".join(filter(None, text_parts)) or None
__all__ = ["OpenAIToClaudeConverter"]


@@ -0,0 +1,150 @@
"""
Claude Chat Handler - a slim implementation on top of the generic chat handler base
Inherits ChatHandlerBase; only the format-specific methods are overridden.
This cut the code from roughly 1470 lines down to about 120.
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class ClaudeChatHandler(ChatHandlerBase):
"""
    Claude Chat Handler - handles requests in the Claude Chat/CLI API format
    Format traits:
    - Uses input_tokens/output_tokens
    - Supports cache_creation_input_tokens/cache_read_input_tokens
    - Request format: ClaudeMessagesRequest
"""
FORMAT_ID = "CLAUDE"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
        Extract the model name from the request - Claude implementation.
        The Claude API carries model as a top-level field of the request body.
        Args:
            request_body: The request body
            path_params: URL path parameters (unused for Claude)
        Returns:
            The model name
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
        Apply the mapped model name to the request body.
        The Claude API carries model as a top-level field of the request body.
        Args:
            request_body: The original request body
            mapped_model: The mapped model name
        Returns:
            The request body with its model field updated
"""
result = dict(request_body)
result["model"] = mapped_model
return result
async def _convert_request(self, request):
"""
        Convert the request into Claude format.
        Args:
            request: The original request object
        Returns:
            A ClaudeMessagesRequest object
"""
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
from src.models.claude import ClaudeMessagesRequest
from src.models.openai import OpenAIRequest
        # Already in Claude format - return as-is
if isinstance(request, ClaudeMessagesRequest):
return request
        # OpenAI format - convert to Claude format
if isinstance(request, OpenAIRequest):
converter = OpenAIToClaudeConverter()
            claude_dict = converter.convert_request(request.model_dump())
return ClaudeMessagesRequest(**claude_dict)
        # A plain dict - infer the format from its contents
if isinstance(request, dict):
if "messages" in request and len(request["messages"]) > 0:
first_msg = request["messages"][0]
if "role" in first_msg and "content" in first_msg:
                    # Probably OpenAI format
converter = OpenAIToClaudeConverter()
claude_dict = converter.convert_request(request)
return ClaudeMessagesRequest(**claude_dict)
            # Otherwise assume it is already Claude format
return ClaudeMessagesRequest(**request)
return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
        Extract token usage from a Claude response.
        The Claude format uses:
- input_tokens / output_tokens
- cache_creation_input_tokens / cache_read_input_tokens
"""
usage = response.get("usage", {})
input_tokens = usage.get("input_tokens", 0)
output_tokens = usage.get("output_tokens", 0)
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
        # Handle the newer cache_creation format
if "cache_creation" in usage:
cache_creation_data = usage.get("cache_creation", {})
if not cache_creation_input_tokens:
cache_creation_input_tokens = cache_creation_data.get(
"ephemeral_5m_input_tokens", 0
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cache_creation_input_tokens": cache_creation_input_tokens,
"cache_read_input_tokens": cache_read_input_tokens,
}
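
    # The newer usage shape handled above looks roughly like this (field names
    # taken from the code path above; values illustrative):
    #
    #     {"usage": {"input_tokens": 10, "output_tokens": 5,
    #                "cache_creation": {"ephemeral_5m_input_tokens": 3,
    #                                   "ephemeral_1h_input_tokens": 2}}}
    #     # -> cache_creation_input_tokens = 3 + 2 = 5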
def _normalize_response(self, response: Dict) -> Dict:
"""
        Normalize a Claude response.
        Args:
            response: The raw response
        Returns:
            The normalized response
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return response


@@ -0,0 +1,241 @@
"""
Claude SSE stream parser
Parses the Server-Sent Events stream of the Claude Messages API.
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
class ClaudeStreamParser:
"""
    Claude SSE stream parser
    Parses the SSE event stream of the Claude Messages API.
    Event types:
    - message_start: message begins; carries the initial message object
    - content_block_start: a content block begins
    - content_block_delta: content block delta (text, tool input, etc.)
    - content_block_stop: a content block ends
    - message_delta: message delta; carries stop_reason and the final usage
    - message_stop: message ends
    - ping: heartbeat event
    - error: error event
"""
    # Claude SSE event types
EVENT_MESSAGE_START = "message_start"
EVENT_MESSAGE_STOP = "message_stop"
EVENT_MESSAGE_DELTA = "message_delta"
EVENT_CONTENT_BLOCK_START = "content_block_start"
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
EVENT_PING = "ping"
EVENT_ERROR = "error"
    # Delta types
DELTA_TEXT = "text_delta"
DELTA_INPUT_JSON = "input_json_delta"
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
"""
        Parse an SSE data block.
        Args:
            chunk: Raw SSE data (bytes or str)
        Returns:
            A list of parsed events
"""
if isinstance(chunk, bytes):
text = chunk.decode("utf-8")
else:
text = chunk
events: List[Dict[str, Any]] = []
lines = text.strip().split("\n")
current_event_type: Optional[str] = None
for line in lines:
line = line.strip()
if not line:
continue
            # Parse an event-type line
if line.startswith("event: "):
current_event_type = line[7:]
continue
            # Parse a data line
if line.startswith("data: "):
data_str = line[6:]
                # Handle the [DONE] marker
if data_str == "[DONE]":
events.append({"type": "__done__", "raw": "[DONE]"})
continue
try:
data = json.loads(data_str)
                    # If the payload has no type, use the type from the event line
if "type" not in data and current_event_type:
data["type"] = current_event_type
events.append(data)
except json.JSONDecodeError:
                    # Skip unparseable data
pass
current_event_type = None
return events
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
        Parse a single SSE line.
        Args:
            line: The SSE data line (with the "data: " prefix already stripped)
        Returns:
            The parsed event dict, or None if the line cannot be parsed
"""
if not line or line == "[DONE]":
return None
try:
return json.loads(line)
except json.JSONDecodeError:
return None
def is_done_event(self, event: Dict[str, Any]) -> bool:
"""
        Determine whether this is a terminating event.
        Args:
            event: The event dict
        Returns:
            True if this is a terminating event
"""
event_type = event.get("type")
return event_type in (self.EVENT_MESSAGE_STOP, "__done__")
def is_error_event(self, event: Dict[str, Any]) -> bool:
"""
        Determine whether this is an error event.
        Args:
            event: The event dict
        Returns:
            True if this is an error event
"""
return event.get("type") == self.EVENT_ERROR
def get_event_type(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Get the event type.
        Args:
            event: The event dict
        Returns:
            The event type string
"""
return event.get("type")
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Extract the text delta from a content_block_delta event.
        Args:
            event: The event dict
        Returns:
            The text delta, or None if this is not a text delta
"""
if event.get("type") != self.EVENT_CONTENT_BLOCK_DELTA:
return None
delta = event.get("delta", {})
if delta.get("type") == self.DELTA_TEXT:
return delta.get("text")
return None
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
"""
        Extract token usage from an event.
        Args:
            event: The event dict
        Returns:
            A usage dict, or None if the event carries no usage info
"""
event_type = event.get("type")
        # The message_start event carries the initial usage
if event_type == self.EVENT_MESSAGE_START:
message = event.get("message", {})
usage = message.get("usage", {})
if usage:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
        # The message_delta event carries the final usage
if event_type == self.EVENT_MESSAGE_DELTA:
usage = event.get("usage", {})
if usage:
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
}
return None
def extract_message_id(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Extract the message ID from a message_start event.
        Args:
            event: The event dict
        Returns:
            The message ID, or None if this is not a message_start event
"""
if event.get("type") != self.EVENT_MESSAGE_START:
return None
message = event.get("message", {})
return message.get("id")
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Extract the stop reason from a message_delta event.
        Args:
            event: The event dict
        Returns:
            The stop reason, or None if absent
"""
if event.get("type") != self.EVENT_MESSAGE_DELTA:
return None
delta = event.get("delta", {})
return delta.get("stop_reason")
__all__ = ["ClaudeStreamParser"]


@@ -0,0 +1,11 @@
"""
Claude CLI passthrough handlers
"""
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
__all__ = [
"ClaudeCliAdapter",
"ClaudeCliMessageHandler",
]


@@ -0,0 +1,103 @@
"""
Claude CLI Adapter - a slim implementation on top of the generic CLI adapter base
Inherits CliAdapterBase; only FORMAT_ID and HANDLER_CLASS need to be configured.
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
@register_cli_adapter
class ClaudeCliAdapter(CliAdapterBase):
"""
    Claude CLI API adapter
    Handles Claude CLI requests (the /v1/messages endpoint, with Bearer auth).
"""
FORMAT_ID = "CLAUDE_CLI"
name = "claude.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
return ClaudeCliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["CLAUDE_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
            return authorization.removeprefix("Bearer ")  # strip only the leading prefix
return None
def detect_capability_requirements(
self,
headers: Dict[str, str],
request_body: Optional[Dict[str, Any]] = None,
) -> Dict[str, bool]:
"""检测 Claude CLI 请求中隐含的能力需求"""
return ClaudeCapabilityDetector.detect_from_headers(headers)
# =========================================================================
    # Claude CLI-specific billing logic
# =========================================================================
def compute_total_input_context(
self,
input_tokens: int,
cache_read_input_tokens: int,
cache_creation_input_tokens: int = 0,
) -> int:
"""
        Compute Claude CLI's total input context (used for tiered-pricing decisions).
        Claude's total input = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
"""
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""Claude CLI 使用 messages 字段"""
messages = payload.get("messages", [])
return len(messages) if isinstance(messages, list) else 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> Dict[str, Any]:
"""Claude CLI 特定的审计元数据"""
model = payload.get("model", "unknown")
stream = payload.get("stream", False)
messages = payload.get("messages", [])
role_counts = {}
for msg in messages:
role = msg.get("role", "unknown")
role_counts[role] = role_counts.get(role, 0) + 1
return {
"action": "claude_cli_request",
"model": model,
"stream": bool(stream),
"max_tokens": payload.get("max_tokens"),
"messages_count": len(messages),
"message_roles": role_counts,
"temperature": payload.get("temperature"),
"top_p": payload.get("top_p"),
"tool_count": len(payload.get("tools") or []),
"system_present": bool(payload.get("system")),
}
__all__ = ["ClaudeCliAdapter"]


@@ -0,0 +1,195 @@
"""
Claude CLI Message Handler - a slim implementation on top of the generic CLI handler base
Inherits CliMessageHandlerBase; only the format-specific configuration and event handling are overridden.
It validates the new architecture: the code shrinks from several hundred lines to roughly 80.
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class ClaudeCliMessageHandler(CliMessageHandlerBase):
"""
    Claude CLI Message Handler - handles the Claude CLI API format
    Uses the new three-tier architecture (Provider -> ProviderEndpoint -> ProviderAPIKey),
    with automatic failover, health monitoring, and concurrency control via FallbackOrchestrator.
    Response format traits:
    - Uses a content[] array
    - Uses the text block type
    - Stream events: message_start, content_block_delta, message_delta, message_stop
    - Supports cache_creation_input_tokens and cache_read_input_tokens
    Model field: the top-level model field of the request body
    """
"""
FORMAT_ID = "CLAUDE_CLI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
        Extract the model name from the request - Claude implementation.
        The Claude API carries model as a top-level field of the request body.
        Args:
            request_body: The request body
            path_params: URL path parameters (unused for Claude)
        Returns:
            The model name
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
        The Claude API carries model as a top-level field of the request body.
        Args:
            request_body: The original request body
            mapped_model: The mapped model name
        Returns:
            The request body with its model field updated
"""
result = dict(request_body)
result["model"] = mapped_model
return result
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
        Process a Claude CLI-format SSE event.
        Event types:
        - message_start: message begins; carries the initial usage (including cache tokens)
        - content_block_delta: text delta
        - message_delta: message delta; carries the final usage
        - message_stop: message ends
"""
        # Handle the message_start event
if event_type == "message_start":
message = data.get("message", {})
if message.get("id"):
ctx.response_id = message["id"]
            # Extract the initial usage (including cache tokens)
usage = message.get("usage", {})
if usage:
ctx.input_tokens = usage.get("input_tokens", 0)
                # Claude uses distinct field names for cache tokens
cache_read = usage.get("cache_read_input_tokens", 0)
if cache_read:
ctx.cached_tokens = cache_read
cache_creation = usage.get("cache_creation_input_tokens", 0)
if cache_creation:
ctx.cache_creation_tokens = cache_creation
        # Handle text deltas
elif event_type == "content_block_delta":
delta = data.get("delta", {})
if delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
ctx.collected_text += text
        # Handle message deltas (they carry the final usage)
elif event_type == "message_delta":
usage = data.get("usage", {})
if usage:
if "input_tokens" in usage:
ctx.input_tokens = usage["input_tokens"]
if "output_tokens" in usage:
ctx.output_tokens = usage["output_tokens"]
                # Update cache tokens
if "cache_read_input_tokens" in usage:
ctx.cached_tokens = usage["cache_read_input_tokens"]
if "cache_creation_input_tokens" in usage:
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
            # Check for completion
delta = data.get("delta", {})
if delta.get("stop_reason"):
ctx.has_completion = True
ctx.final_response = data
        # Handle message stop
elif event_type == "message_stop":
ctx.has_completion = True
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
        Extract metadata from a Claude response.
        Pulls out fields such as model and stop_reason as metadata.
        Args:
            response: The Claude API response
        Returns:
            The extracted metadata dict
"""
metadata: Dict[str, Any] = {}
        # Extract the model name (the model actually used)
if "model" in response:
metadata["model"] = response["model"]
        # Extract the stop reason
if "stop_reason" in response:
metadata["stop_reason"] = response["stop_reason"]
        # Extract the message ID
if "id" in response:
metadata["message_id"] = response["id"]
        # Extract the message type
if "type" in response:
metadata["type"] = response["type"]
return metadata
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
        Extract the final metadata from the stream context.
        Called after streaming completes; pulls metadata out of the collected events.
        Args:
            ctx: The stream context
"""
        # Take the message ID from response_id
if ctx.response_id:
ctx.response_metadata["message_id"] = ctx.response_id
        # Take the stop reason from final_response (delta.stop_reason of the message_delta event)
if ctx.final_response:
delta = ctx.final_response.get("delta", {})
if "stop_reason" in delta:
ctx.response_metadata["stop_reason"] = delta["stop_reason"]
        # Record the model name
if ctx.model:
ctx.response_metadata["model"] = ctx.model


@@ -0,0 +1,26 @@
"""
Gemini API handler module
Provides request handling for the Gemini API format.
"""
from src.api.handlers.gemini.adapter import GeminiChatAdapter, build_gemini_adapter
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
GeminiToClaudeConverter,
GeminiToOpenAIConverter,
OpenAIToGeminiConverter,
)
from src.api.handlers.gemini.handler import GeminiChatHandler
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
__all__ = [
"GeminiChatAdapter",
"GeminiChatHandler",
"GeminiStreamParser",
"ClaudeToGeminiConverter",
"GeminiToClaudeConverter",
"OpenAIToGeminiConverter",
"GeminiToOpenAIConverter",
"build_gemini_adapter",
]


@@ -0,0 +1,170 @@
"""
Gemini Chat Adapter
Adapts requests in the Gemini API format.
"""
from typing import Any, Dict, Optional, Type
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.models.gemini import GeminiRequest
@register_adapter
class GeminiChatAdapter(ChatAdapterBase):
"""
    Gemini Chat API adapter
    Handles Gemini Chat-format requests.
    Endpoint: /v1beta/models/{model}:generateContent
"""
FORMAT_ID = "GEMINI"
name = "gemini.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.gemini.handler import GeminiChatHandler
return GeminiChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["GEMINI"])
logger.info(f"[{self.name}] 初始化 Gemini Chat 适配器 | API格式: {self.allowed_api_formats}")
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-goog-api-key)"""
return request.headers.get("x-goog-api-key")
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
) -> Dict[str, Any]:
"""
        Merge URL path parameters into the request body - Gemini specialization.
        Gemini API traits:
        - model is not merged into the body (extract_model_from_request reads it from path_params)
        - stream is not merged into the body (the Gemini API distinguishes streaming from non-streaming by URL endpoint)
        The handler's extract_model_from_request takes model from path_params, and
        prepare_provider_request_body ensures the body sent to the Gemini API contains no model field.
        Args:
            original_request_body: The original request body dict
            path_params: URL path parameter dict (unused)
        Returns:
            The original request body (no path_params merged in)
"""
return original_request_body.copy()
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
"""验证请求体"""
path_params = path_params or {}
is_stream = path_params.get("stream", False)
model = path_params.get("model", "unknown")
try:
if not isinstance(original_request_body, dict):
raise ValueError("Request body must be a JSON object")
            # Required Gemini field: contents
if "contents" not in original_request_body:
raise ValueError("Missing required field: contents")
request = GeminiRequest.model_validate(
original_request_body,
strict=False,
)
except ValueError as e:
logger.error(f"请求体基本验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
request = GeminiRequest.model_construct(
contents=original_request_body.get("contents", []),
)
        # Set model (taken from path_params; used for logging and auditing)
request.model = model
        # Set the stream attribute (ChatAdapterBase uses it to detect streaming mode)
request.stream = is_stream
return request
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
"""提取消息数量"""
contents = payload.get("contents", [])
if hasattr(request_obj, "contents"):
contents = request_obj.contents
return len(contents) if isinstance(contents, list) else 0
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 Gemini Chat 特定的审计元数据"""
role_counts: dict[str, int] = {}
contents = getattr(request_obj, "contents", []) or []
for content in contents:
role = getattr(content, "role", None) or content.get("role", "unknown")
role_counts[role] = role_counts.get(role, 0) + 1
generation_config = getattr(request_obj, "generation_config", None) or {}
if hasattr(generation_config, "dict"):
generation_config = generation_config.dict()
elif not isinstance(generation_config, dict):
generation_config = {}
        # Determine streaming mode
stream = getattr(request_obj, "stream", False)
return {
"action": "gemini_generate_content",
"model": getattr(request_obj, "model", payload.get("model", "unknown")),
"stream": bool(stream),
"max_output_tokens": generation_config.get("max_output_tokens"),
"temperature": generation_config.get("temperature"),
"top_p": generation_config.get("top_p"),
"top_k": generation_config.get("top_k"),
"contents_count": len(contents),
"content_roles": role_counts,
"tools_count": len(getattr(request_obj, "tools", None) or []),
"system_instruction_present": bool(getattr(request_obj, "system_instruction", None)),
"safety_settings_count": len(getattr(request_obj, "safety_settings", None) or []),
}
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
"""生成 Gemini 格式的错误响应"""
# Gemini 错误响应格式
return JSONResponse(
status_code=status_code,
content={
"error": {
"code": status_code,
"message": message,
"status": error_type.upper(),
}
},
)
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
"""
    Build the appropriate Gemini adapter based on request headers.
    Args:
        x_app_header: The X-App header value
    Returns:
        A GeminiChatAdapter instance
"""
    # There is only one Gemini adapter for now.
    # In the future this could return different adapters (e.g. a CLI mode) based on x_app_header.
return GeminiChatAdapter()
__all__ = ["GeminiChatAdapter", "build_gemini_adapter"]


@@ -0,0 +1,544 @@
"""
Gemini format converters
Conversions between Gemini and other API formats (Claude, OpenAI).
"""
import json
from typing import Any, Dict, List, Optional
class ClaudeToGeminiConverter:
"""
    Claude -> Gemini request converter
    Converts the Claude Messages API format into the Gemini generateContent format.
"""
def convert_request(self, claude_request: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert a Claude request into a Gemini request.
        Args:
            claude_request: The Claude-format request dict
        Returns:
            The Gemini-format request dict
"""
gemini_request: Dict[str, Any] = {
"contents": self._convert_messages(claude_request.get("messages", [])),
}
        # Convert the system prompt
system = claude_request.get("system")
if system:
gemini_request["system_instruction"] = self._convert_system(system)
        # Convert the generation config
generation_config = self._build_generation_config(claude_request)
if generation_config:
gemini_request["generation_config"] = generation_config
        # Convert tools
tools = claude_request.get("tools")
if tools:
gemini_request["tools"] = self._convert_tools(tools)
return gemini_request
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换消息列表"""
contents = []
for msg in messages:
role = msg.get("role", "user")
# Gemini 使用 "model" 而不是 "assistant"
gemini_role = "model" if role == "assistant" else "user"
content = msg.get("content", "")
parts = self._convert_content_to_parts(content)
contents.append(
{
"role": gemini_role,
"parts": parts,
}
)
return contents
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
"""将 Claude 内容转换为 Gemini parts"""
if isinstance(content, str):
return [{"text": content}]
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, str):
parts.append({"text": block})
elif isinstance(block, dict):
block_type = block.get("type")
if block_type == "text":
parts.append({"text": block.get("text", "")})
elif block_type == "image":
                        # Convert an image
source = block.get("source", {})
if source.get("type") == "base64":
parts.append(
{
"inline_data": {
"mime_type": source.get("media_type", "image/png"),
"data": source.get("data", ""),
}
}
)
elif block_type == "tool_use":
                        # Convert a tool call
parts.append(
{
"function_call": {
"name": block.get("name", ""),
"args": block.get("input", {}),
}
}
)
elif block_type == "tool_result":
                        # Convert a tool result
parts.append(
{
"function_response": {
"name": block.get("tool_use_id", ""),
"response": {"result": block.get("content", "")},
}
}
)
return parts
return [{"text": str(content)}]
def _convert_system(self, system: Any) -> Dict[str, Any]:
"""转换 system prompt"""
if isinstance(system, str):
return {"parts": [{"text": system}]}
if isinstance(system, list):
parts = []
for item in system:
if isinstance(item, str):
parts.append({"text": item})
elif isinstance(item, dict) and item.get("type") == "text":
parts.append({"text": item.get("text", "")})
return {"parts": parts}
return {"parts": [{"text": str(system)}]}
def _build_generation_config(self, claude_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建生成配置"""
config: Dict[str, Any] = {}
if "max_tokens" in claude_request:
config["max_output_tokens"] = claude_request["max_tokens"]
if "temperature" in claude_request:
config["temperature"] = claude_request["temperature"]
if "top_p" in claude_request:
config["top_p"] = claude_request["top_p"]
if "top_k" in claude_request:
config["top_k"] = claude_request["top_k"]
if "stop_sequences" in claude_request:
config["stop_sequences"] = claude_request["stop_sequences"]
return config if config else None
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换工具定义"""
function_declarations = []
for tool in tools:
func_decl = {
"name": tool.get("name", ""),
}
if "description" in tool:
func_decl["description"] = tool["description"]
if "input_schema" in tool:
func_decl["parameters"] = tool["input_schema"]
function_declarations.append(func_decl)
return [{"function_declarations": function_declarations}]
class GeminiToClaudeConverter:
"""
    Gemini -> Claude response converter
    Converts a Gemini generateContent response into the Claude Messages API format.
"""
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert a Gemini response into a Claude response.
        Args:
            gemini_response: The Gemini-format response dict
        Returns:
            The Claude-format response dict
"""
candidates = gemini_response.get("candidates", [])
if not candidates:
return self._create_empty_response()
candidate = candidates[0]
content = candidate.get("content", {})
parts = content.get("parts", [])
        # Convert the content blocks
claude_content = self._convert_parts_to_content(parts)
        # Convert usage
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
        # Convert the stop reason
stop_reason = self._convert_finish_reason(candidate.get("finishReason"))
return {
"id": f"msg_{gemini_response.get('modelVersion', 'gemini')}",
"type": "message",
"role": "assistant",
"content": claude_content,
"model": gemini_response.get("modelVersion", "gemini"),
"stop_reason": stop_reason,
"stop_sequence": None,
"usage": usage,
}
def _convert_parts_to_content(self, parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将 Gemini parts 转换为 Claude content blocks"""
content = []
for part in parts:
if "text" in part:
content.append(
{
"type": "text",
"text": part["text"],
}
)
elif "functionCall" in part:
func_call = part["functionCall"]
content.append(
{
"type": "tool_use",
"id": f"toolu_{func_call.get('name', '')}",
"name": func_call.get("name", ""),
"input": func_call.get("args", {}),
}
)
return content
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
"""转换使用量信息"""
return {
"input_tokens": usage_metadata.get("promptTokenCount", 0),
"output_tokens": usage_metadata.get("candidatesTokenCount", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": usage_metadata.get("cachedContentTokenCount", 0),
}
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
"""转换停止原因"""
mapping = {
"STOP": "end_turn",
"MAX_TOKENS": "max_tokens",
"SAFETY": "content_filtered",
"RECITATION": "content_filtered",
"OTHER": "stop_sequence",
}
return mapping.get(finish_reason, "end_turn")
def _create_empty_response(self) -> Dict[str, Any]:
"""创建空响应"""
return {
"id": "msg_empty",
"type": "message",
"role": "assistant",
"content": [],
"model": "gemini",
"stop_reason": "end_turn",
"stop_sequence": None,
"usage": {
"input_tokens": 0,
"output_tokens": 0,
},
}
class OpenAIToGeminiConverter:
"""
    OpenAI -> Gemini request converter
    Converts the OpenAI Chat Completions API format into the Gemini generateContent format
"""
def convert_request(self, openai_request: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert an OpenAI request into a Gemini request
        Args:
            openai_request: request dict in OpenAI format
        Returns:
            request dict in Gemini format
"""
messages = openai_request.get("messages", [])
        # Separate system messages from the rest
system_messages = []
other_messages = []
for msg in messages:
if msg.get("role") == "system":
system_messages.append(msg)
else:
other_messages.append(msg)
gemini_request: Dict[str, Any] = {
"contents": self._convert_messages(other_messages),
}
        # Merge system messages into a system_instruction
        if system_messages:
            system_text = "\n".join(msg.get("content", "") for msg in system_messages)
            gemini_request["system_instruction"] = {"parts": [{"text": system_text}]}
        # Build the generation config
        generation_config = self._build_generation_config(openai_request)
        if generation_config:
            gemini_request["generation_config"] = generation_config
        # Convert tools
        tools = openai_request.get("tools")
        if tools:
            gemini_request["tools"] = self._convert_tools(tools)
return gemini_request
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换消息列表"""
contents = []
for msg in messages:
role = msg.get("role", "user")
gemini_role = "model" if role == "assistant" else "user"
content = msg.get("content", "")
parts = self._convert_content_to_parts(content)
            # Convert tool calls attached to the message
            tool_calls = msg.get("tool_calls", [])
            if tool_calls:
                import json
            for tc in tool_calls:
                if tc.get("type") == "function":
                    func = tc.get("function", {})
                    try:
                        args = json.loads(func.get("arguments", "{}"))
                    except json.JSONDecodeError:
                        args = {}
parts.append(
{
"function_call": {
"name": func.get("name", ""),
"args": args,
}
}
)
if parts:
contents.append(
{
"role": gemini_role,
"parts": parts,
}
)
return contents
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
"""将 OpenAI 内容转换为 Gemini parts"""
if content is None:
return []
if isinstance(content, str):
return [{"text": content}]
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, str):
parts.append({"text": item})
elif isinstance(item, dict):
item_type = item.get("type")
if item_type == "text":
parts.append({"text": item.get("text", "")})
elif item_type == "image_url":
# OpenAI 图片 URL 格式
image_url = item.get("image_url", {})
url = image_url.get("url", "")
if url.startswith("data:"):
# base64 数据 URL
# 格式: data:image/png;base64,xxxxx
try:
header, data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
parts.append(
{
"inline_data": {
"mime_type": mime_type,
"data": data,
}
}
)
except (ValueError, IndexError):
pass
return parts
return [{"text": str(content)}]
def _build_generation_config(self, openai_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建生成配置"""
config: Dict[str, Any] = {}
if "max_tokens" in openai_request:
config["max_output_tokens"] = openai_request["max_tokens"]
if "temperature" in openai_request:
config["temperature"] = openai_request["temperature"]
if "top_p" in openai_request:
config["top_p"] = openai_request["top_p"]
if "stop" in openai_request:
stop = openai_request["stop"]
if isinstance(stop, str):
config["stop_sequences"] = [stop]
elif isinstance(stop, list):
config["stop_sequences"] = stop
if "n" in openai_request:
config["candidate_count"] = openai_request["n"]
return config if config else None
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""转换工具定义"""
function_declarations = []
for tool in tools:
if tool.get("type") == "function":
func = tool.get("function", {})
func_decl = {
"name": func.get("name", ""),
}
if "description" in func:
func_decl["description"] = func["description"]
if "parameters" in func:
func_decl["parameters"] = func["parameters"]
function_declarations.append(func_decl)
return [{"function_declarations": function_declarations}]
class GeminiToOpenAIConverter:
"""
    Gemini -> OpenAI response converter
    Converts Gemini generateContent responses into the OpenAI Chat Completions API format
"""
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert a Gemini response into an OpenAI response
        Args:
            gemini_response: response dict in Gemini format
        Returns:
            response dict in OpenAI format
"""
import time
candidates = gemini_response.get("candidates", [])
choices = []
for i, candidate in enumerate(candidates):
content = candidate.get("content", {})
parts = content.get("parts", [])
            # Extract text content and tool calls
            text_parts = []
            tool_calls = []
            import json
            for part in parts:
                if "text" in part:
                    text_parts.append(part["text"])
                elif "functionCall" in part:
                    func_call = part["functionCall"]
tool_calls.append(
{
"id": f"call_{func_call.get('name', '')}_{i}",
"type": "function",
"function": {
"name": func_call.get("name", ""),
"arguments": json.dumps(func_call.get("args", {})),
},
}
)
message: Dict[str, Any] = {
"role": "assistant",
"content": "".join(text_parts) if text_parts else None,
}
if tool_calls:
message["tool_calls"] = tool_calls
finish_reason = self._convert_finish_reason(candidate.get("finishReason"))
choices.append(
{
"index": i,
"message": message,
"finish_reason": finish_reason,
}
)
        # Convert usage
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
return {
"id": f"chatcmpl-{gemini_response.get('modelVersion', 'gemini')}",
"object": "chat.completion",
"created": int(time.time()),
"model": gemini_response.get("modelVersion", "gemini"),
"choices": choices,
"usage": usage,
}
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
"""转换使用量信息"""
prompt_tokens = usage_metadata.get("promptTokenCount", 0)
completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
return {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
"""转换停止原因"""
mapping = {
"STOP": "stop",
"MAX_TOKENS": "length",
"SAFETY": "content_filter",
"RECITATION": "content_filter",
"OTHER": "stop",
}
return mapping.get(finish_reason, "stop")
__all__ = [
"ClaudeToGeminiConverter",
"GeminiToClaudeConverter",
"OpenAIToGeminiConverter",
"GeminiToOpenAIConverter",
]
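# A minimal, illustrative smoke run for the converters above. The sample
# payloads below are handcrafted (not captured API traffic) and the model
# names are placeholders.
if __name__ == "__main__":
    claude_req = {
        "model": "claude-3-5-sonnet",
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "Hello"}],
    }
    print(ClaudeToGeminiConverter().convert_request(claude_req))
    gemini_resp = {
        "candidates": [
            {"content": {"parts": [{"text": "Hi"}], "role": "model"}, "finishReason": "STOP"}
        ],
        "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 1, "totalTokenCount": 4},
        "modelVersion": "gemini-1.5-pro",
    }
    claude_resp = GeminiToClaudeConverter().convert_response(gemini_resp)
    assert claude_resp["stop_reason"] == "end_turn"
    assert claude_resp["usage"]["input_tokens"] == 3
    print(claude_resp["content"])  # [{'type': 'text', 'text': 'Hi'}]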

View File

@@ -0,0 +1,164 @@
"""
Gemini Chat Handler
Handles requests in the Gemini API format
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class GeminiChatHandler(ChatHandlerBase):
"""
    Gemini Chat Handler - handles requests in the Google Gemini API format
    Format traits:
    - uses promptTokenCount / candidatesTokenCount
    - supports cachedContentTokenCount
    - request format: GeminiRequest
    - response format: JSON array stream (not SSE)
"""
FORMAT_ID = "GEMINI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> str:
"""
        Extract the model name from the request - Gemini Chat implementation
        In Gemini Chat mode the model is in the (converted GeminiRequest) request body,
        unlike Gemini CLI, where the model is in the URL path.
        Args:
            request_body: request body
            path_params: URL path parameters (usually unused in Chat mode)
        Returns:
            the model name
        """
        # Prefer the request body, then fall back to path_params
model = request_body.get("model")
if model:
return str(model)
if path_params and "model" in path_params:
return str(path_params["model"])
return "unknown"
async def _convert_request(self, request):
"""
        Convert the request into Gemini format
        Automatic conversion is supported for:
        - Claude format -> Gemini format
        - OpenAI format -> Gemini format
        Args:
            request: the original request object (Gemini/Claude/OpenAI format)
        Returns:
            a GeminiRequest object
"""
from src.api.handlers.gemini.converter import (
ClaudeToGeminiConverter,
OpenAIToGeminiConverter,
)
from src.models.claude import ClaudeMessagesRequest
from src.models.gemini import GeminiRequest
from src.models.openai import OpenAIRequest
        # Already Gemini format: return as-is
        if isinstance(request, GeminiRequest):
            return request
        # Claude format: convert to Gemini
        if isinstance(request, ClaudeMessagesRequest):
            converter = ClaudeToGeminiConverter()
            gemini_dict = converter.convert_request(request.model_dump())
            return GeminiRequest(**gemini_dict)
        # OpenAI format: convert to Gemini
        if isinstance(request, OpenAIRequest):
            converter = OpenAIToGeminiConverter()
            gemini_dict = converter.convert_request(request.model_dump())
            return GeminiRequest(**gemini_dict)
        # Plain dict: sniff the format from its fields and convert
        if isinstance(request, dict):
            # Gemini signature: a "contents" field
            if "contents" in request:
                return GeminiRequest(**request)
            # Claude/OpenAI signature: "messages" without "choices"
            if "messages" in request and "choices" not in request:
                # Distinguish Claude from OpenAI: both may carry max_tokens,
                # but Claude's messages[].content is typically a list of blocks
                # while OpenAI's is usually a string.
                messages = request.get("messages", [])
                if messages and isinstance(messages[0].get("content"), list):
                    # Likely Claude format
                    converter = ClaudeToGeminiConverter()
                    gemini_dict = converter.convert_request(request)
                    return GeminiRequest(**gemini_dict)
                else:
                    # Likely OpenAI format
                    converter = OpenAIToGeminiConverter()
                    gemini_dict = converter.convert_request(request)
                    return GeminiRequest(**gemini_dict)
            # Fall back to treating it as Gemini format
            return GeminiRequest(**request)
        return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
        Extract token usage from a Gemini response
        Delegates to GeminiStreamParser.extract_usage as the single source of truth
"""
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
usage = GeminiStreamParser().extract_usage(response)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cache_creation_input_tokens": 0, # Gemini 不区分缓存创建
"cache_read_input_tokens": usage.get("cached_tokens", 0),
}
def _normalize_response(self, response: Dict) -> Dict:
"""
        Normalize a Gemini response
        Args:
            response: the raw response
        Returns:
            the normalized response
        TODO: implement normalization here if it becomes necessary
        """
        # Optional: normalize via response_normalizer
# if (
# self.response_normalizer
# and self.response_normalizer.should_normalize(response)
# ):
# return self.response_normalizer.normalize_gemini_response(
# response_data=response,
# request_id=self.request_id,
# strict=False,
# )
return response
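# Standalone sketch of the format-sniffing heuristic used by _convert_request
# above. The helper name is hypothetical and exists for illustration only;
# instantiating the handler itself requires the full service stack.
def _sniff_request_format(body: dict) -> str:
    """Classify a raw request dict as 'gemini', 'claude', or 'openai'."""
    if "contents" in body:
        return "gemini"
    if "messages" in body and "choices" not in body:
        messages = body.get("messages", [])
        # Claude's messages[].content is typically a list of blocks
        if messages and isinstance(messages[0].get("content"), list):
            return "claude"
        return "openai"
    return "gemini"
if __name__ == "__main__":
    assert _sniff_request_format({"contents": []}) == "gemini"
    assert _sniff_request_format(
        {"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]}
    ) == "claude"
    assert _sniff_request_format({"messages": [{"role": "user", "content": "hi"}]}) == "openai"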

View File

@@ -0,0 +1,307 @@
"""
Gemini SSE/JSON stream parser
Gemini's streaming response format differs from Claude's and OpenAI's:
- it is a JSON array, not SSE
- each chunk is a complete JSON object
- the response starts with [, ends with ], and chunks are separated by ,
Reference: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
"""
import json
from typing import Any, Dict, List, Optional
class GeminiStreamParser:
"""
    Gemini stream parser
    Parses the response stream of the Gemini streamGenerateContent API.
    Traits of the Gemini stream:
    - JSON array format: [{chunk1}, {chunk2}, ...]
    - each chunk carries candidates, usageMetadata, and similar fields
    - possible finish_reason values: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
"""
    # Finish reasons
FINISH_REASON_STOP = "STOP"
FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
FINISH_REASON_SAFETY = "SAFETY"
FINISH_REASON_RECITATION = "RECITATION"
FINISH_REASON_OTHER = "OTHER"
    def __init__(self):
        self._buffer = ""
        self._in_array = False
        self._brace_depth = 0
        self._in_string = False
        self._escaped = False
    def reset(self):
        """Reset the parser state"""
        self._buffer = ""
        self._in_array = False
        self._brace_depth = 0
        self._in_string = False
        self._escaped = False
    def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
        """
        Parse a streamed data chunk
        Args:
            chunk: raw data (bytes or str)
        Returns:
            the list of parsed events
        """
        if isinstance(chunk, bytes):
            text = chunk.decode("utf-8")
        else:
            text = chunk
        events: List[Dict[str, Any]] = []
        for char in text:
            # Track string-literal state so braces inside string values
            # (e.g. generated text containing "{" or "}") do not corrupt
            # the brace-depth counter.
            if self._in_array and self._in_string:
                self._buffer += char
                if self._escaped:
                    self._escaped = False
                elif char == "\\":
                    self._escaped = True
                elif char == '"':
                    self._in_string = False
                continue
            if char == "[" and not self._in_array:
                self._in_array = True
                continue
            if char == "]" and self._in_array and self._brace_depth == 0:
                # End of the top-level array
                self._in_array = False
                if self._buffer.strip():
                    try:
                        events.append(json.loads(self._buffer))
                    except json.JSONDecodeError:
                        pass
                self._buffer = ""
                continue
            if self._in_array:
                # Skip separators (commas, whitespace) between top-level objects
                if self._brace_depth == 0 and char != "{":
                    continue
                if char == '"':
                    self._in_string = True
                elif char == "{":
                    self._brace_depth += 1
                elif char == "}":
                    self._brace_depth -= 1
                self._buffer += char
                # When the brace depth returns to 0, one complete JSON object has ended
                if self._brace_depth == 0 and self._buffer.strip():
                    try:
                        events.append(json.loads(self._buffer))
                        self._buffer = ""
                    except json.JSONDecodeError:
                        # Possibly incomplete; keep accumulating
                        pass
        return events
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
        Parse a single line of JSON data
        Args:
            line: a line of JSON data
        Returns:
            the parsed event dict, or None if the line cannot be parsed
"""
if not line or line.strip() in ["[", "]", ","]:
return None
try:
return json.loads(line.strip().rstrip(","))
except json.JSONDecodeError:
return None
def is_done_event(self, event: Dict[str, Any]) -> bool:
"""
        Check whether this is a terminal event
        Args:
            event: event dict
        Returns:
            True if the event is terminal
"""
candidates = event.get("candidates", [])
if candidates:
for candidate in candidates:
finish_reason = candidate.get("finishReason")
if finish_reason in (
self.FINISH_REASON_STOP,
self.FINISH_REASON_MAX_TOKENS,
self.FINISH_REASON_SAFETY,
self.FINISH_REASON_RECITATION,
self.FINISH_REASON_OTHER,
):
return True
return False
def is_error_event(self, event: Dict[str, Any]) -> bool:
"""
        Check whether this is an error event
        Several Gemini error shapes are detected:
        1. top-level error: {"error": {...}}
        2. error nested in chunks: {"chunks": [{"error": {...}}]}
        3. error states inside candidates
        Args:
            event: event dict
        Returns:
            True if the event is an error event
        """
        # Top-level error
        if "error" in event:
            return True
        # Error nested in chunks (seen in some Gemini response shapes)
chunks = event.get("chunks", [])
if chunks and isinstance(chunks, list):
for chunk in chunks:
if isinstance(chunk, dict) and "error" in chunk:
return True
return False
def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
        Extract error information from an event
        Args:
            event: event dict
        Returns:
            an error dict {"code": int, "message": str, "status": str}, or None if there is no error
        """
        # Top-level error
if "error" in event:
error = event["error"]
if isinstance(error, dict):
return {
"code": error.get("code"),
"message": error.get("message", str(error)),
"status": error.get("status"),
}
return {"code": None, "message": str(error), "status": None}
        # Error nested in chunks
chunks = event.get("chunks", [])
if chunks and isinstance(chunks, list):
for chunk in chunks:
if isinstance(chunk, dict) and "error" in chunk:
error = chunk["error"]
if isinstance(error, dict):
return {
"code": error.get("code"),
"message": error.get("message", str(error)),
"status": error.get("status"),
}
return {"code": None, "message": str(error), "status": None}
return None
def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Get the finish reason
        Args:
            event: event dict
        Returns:
            the finish reason string
"""
candidates = event.get("candidates", [])
if candidates:
return candidates[0].get("finishReason")
return None
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Extract text content from a response event
        Args:
            event: event dict
        Returns:
            the text content, or None if there is none
"""
candidates = event.get("candidates", [])
if not candidates:
return None
content = candidates[0].get("content", {})
parts = content.get("parts", [])
text_parts = []
for part in parts:
if "text" in part:
text_parts.append(part["text"])
return "".join(text_parts) if text_parts else None
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
"""
        Extract token usage from an event
        This is the single source of truth for Gemini token extraction;
        every other call site should delegate to this method.
        Args:
            event: event dict (carrying usageMetadata)
        Returns:
            a usage dict, or None if complete usage information is missing
        Notes:
            - usage is only extracted when totalTokenCount is present (i.e. the data is complete)
            - output tokens = thoughtsTokenCount + candidatesTokenCount
"""
usage_metadata = event.get("usageMetadata", {})
if not usage_metadata or "totalTokenCount" not in usage_metadata:
return None
        # Output tokens = thoughtsTokenCount + candidatesTokenCount
thoughts_tokens = usage_metadata.get("thoughtsTokenCount", 0)
candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
output_tokens = thoughts_tokens + candidates_tokens
return {
"input_tokens": usage_metadata.get("promptTokenCount", 0),
"output_tokens": output_tokens,
"total_tokens": usage_metadata.get("totalTokenCount", 0),
"cached_tokens": usage_metadata.get("cachedContentTokenCount", 0),
}
def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
"""
        Extract the model version from a response event
        Args:
            event: event dict
        Returns:
            the model version, or None if absent
"""
return event.get("modelVersion")
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
"""
        Extract safety ratings from a response event
        Args:
            event: event dict
        Returns:
            the list of safety ratings, or None if absent
"""
candidates = event.get("candidates", [])
if not candidates:
return None
return candidates[0].get("safetyRatings")
__all__ = ["GeminiStreamParser"]

View File

@@ -0,0 +1,12 @@
"""
Gemini CLI passthrough handlers
"""
from src.api.handlers.gemini_cli.adapter import GeminiCliAdapter, build_gemini_cli_adapter
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
__all__ = [
"GeminiCliAdapter",
"GeminiCliMessageHandler",
"build_gemini_cli_adapter",
]

View File

@@ -0,0 +1,112 @@
"""
Gemini CLI Adapter - implementation on top of the generic CLI adapter base class
Subclasses CliAdapterBase to handle requests in the Gemini CLI format.
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
@register_cli_adapter
class GeminiCliAdapter(CliAdapterBase):
"""
    Gemini CLI API adapter
    Handles requests in the Gemini CLI format (passthrough mode, minimal validation).
"""
FORMAT_ID = "GEMINI_CLI"
name = "gemini.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
return GeminiCliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["GEMINI_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (x-goog-api-key)"""
return request.headers.get("x-goog-api-key")
def _merge_path_params(
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
) -> Dict[str, Any]:
"""
        Merge URL path parameters into the request body - Gemini CLI specialization
        Gemini API specifics:
        - model is not merged into the body (native Gemini bodies carry no model; it travels in the URL path)
        - stream is not merged into the body (the Gemini API distinguishes streaming via the URL endpoint)
        The base class already reads model and stream from path_params for logging and routing.
        Args:
            original_request_body: the original request body dict
            path_params: URL path parameter dict (contains model, stream, etc.)
        Returns:
            the original request body (no path_params merged in)
        """
        # Gemini: merge none of the path_params into the body
return original_request_body.copy()
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
"""Gemini CLI 使用 contents 字段"""
contents = payload.get("contents", [])
return len(contents) if isinstance(contents, list) else 0
def _build_audit_metadata(
self,
payload: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Gemini CLI 特定的审计元数据"""
# 从 path_params 获取 modelGemini 请求体不含 model
model = path_params.get("model", "unknown") if path_params else "unknown"
contents = payload.get("contents", [])
generation_config = payload.get("generation_config", {}) or {}
role_counts: Dict[str, int] = {}
for content in contents:
role = content.get("role", "unknown") if isinstance(content, dict) else "unknown"
role_counts[role] = role_counts.get(role, 0) + 1
return {
"action": "gemini_cli_request",
"model": model,
"stream": bool(payload.get("stream", False)),
"max_output_tokens": generation_config.get("max_output_tokens"),
"contents_count": len(contents),
"content_roles": role_counts,
"temperature": generation_config.get("temperature"),
"top_p": generation_config.get("top_p"),
"top_k": generation_config.get("top_k"),
"tools_count": len(payload.get("tools") or []),
"system_instruction_present": bool(payload.get("system_instruction")),
"safety_settings_count": len(payload.get("safety_settings") or []),
}
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
"""
    Build a Gemini CLI adapter
    Args:
        x_app_header: value of the X-App request header (reserved for future use)
    Returns:
        a GeminiCliAdapter instance
"""
return GeminiCliAdapter()
__all__ = ["GeminiCliAdapter", "build_gemini_cli_adapter"]

View File

@@ -0,0 +1,210 @@
"""
Gemini CLI Message Handler - implementation on top of the generic CLI handler base class
Subclasses CliMessageHandlerBase to handle requests in the Gemini CLI API format.
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class GeminiCliMessageHandler(CliMessageHandlerBase):
"""
    Gemini CLI Message Handler - handles the Gemini CLI API format
    Uses the new three-layer architecture (Provider -> ProviderEndpoint -> ProviderAPIKey)
    and relies on FallbackOrchestrator for automatic failover, health monitoring, and concurrency control.
    Response format traits:
    - Gemini streams responses as a JSON array (not SSE)
    - each chunk carries candidates, usageMetadata, and similar fields
    - finish_reason: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
    - token usage: promptTokenCount (input), thoughtsTokenCount + candidatesTokenCount (output), cachedContentTokenCount (cache)
    Gemini API peculiarities:
    - the model lives in the URL path rather than the body, e.g. /v1beta/models/{model}:generateContent
    - any model field in the body is used for internal routing only and is not sent to the API
"""
FORMAT_ID = "GEMINI_CLI"
def extract_model_from_request(
self,
        request_body: Dict[str, Any],  # noqa: ARG002 - required by the base-class signature
path_params: Optional[Dict[str, Any]] = None,
) -> str:
"""
        Extract the model name from the request - Gemini implementation
        The Gemini API carries the model in the URL path, not the body:
        /v1beta/models/{model}:generateContent
        Args:
            request_body: request body (contains no model for Gemini)
            path_params: URL path parameters (contain the model)
        Returns:
            the model name, or "unknown" if it cannot be extracted
        """
        # Gemini: the model comes from the URL path parameters
if path_params and "model" in path_params:
return str(path_params["model"])
return "unknown"
def prepare_provider_request_body(
self,
request_body: Dict[str, Any],
) -> Dict[str, Any]:
"""
        Prepare the request body sent to the Gemini API - strip the model field
        The Gemini API expects the model only in the URL path; a model field
        in the body makes some proxies return 404.
        Args:
            request_body: request body
        Returns:
            the request body without a model field
"""
result = dict(request_body)
result.pop("model", None)
return result
def get_model_for_url(
self,
request_body: Dict[str, Any],
mapped_model: Optional[str],
) -> Optional[str]:
"""
        Gemini needs the model placed into the URL path
        Args:
            request_body: request body
            mapped_model: the mapped model name, if any
        Returns:
            the model name to use in the URL path
        """
        # Prefer the mapped model name, falling back to the one in the body
return mapped_model or request_body.get("model")
def _extract_usage_from_event(self, event: Dict[str, Any]) -> Dict[str, int]:
"""
        Extract token usage from a Gemini event
        Delegates to GeminiStreamParser.extract_usage as the single source of truth
        Args:
            event: a Gemini streaming response event
        Returns:
            a dict with input_tokens, output_tokens, and cached_tokens
"""
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
usage = GeminiStreamParser().extract_usage(event)
if not usage:
return {
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
}
return {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"cached_tokens": usage.get("cached_tokens", 0),
}
def _process_event_data(
self,
ctx: StreamContext,
_event_type: str,
data: Dict[str, Any],
) -> None:
"""
        Process a streaming event in the Gemini CLI format
        Gemini streams a JSON array whose elements look like:
        {
            "candidates": [{
                "content": {"parts": [{"text": "..."}], "role": "model"},
                "finishReason": "STOP",
                "safetyRatings": [...]
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30,
                "cachedContentTokenCount": 5
            },
            "modelVersion": "gemini-1.5-pro"
        }
        Note: the Gemini stream parser hands each JSON object over as one "event";
        event_type here may be empty or a custom marker.
"""
        # Extract the candidate response
        candidates = data.get("candidates", [])
        if candidates:
            candidate = candidates[0]
            content = candidate.get("content", {})
            # Extract text content
            parts = content.get("parts", [])
            for part in parts:
                if "text" in part:
                    ctx.collected_text += part["text"]
            # Check the finish reason
            finish_reason = candidate.get("finishReason")
            if finish_reason in ("STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "OTHER"):
                ctx.has_completion = True
                ctx.final_response = data
        # Extract usage (delegates to GeminiStreamParser.extract_usage)
        usage = self._extract_usage_from_event(data)
        if usage["input_tokens"] > 0 or usage["output_tokens"] > 0:
            ctx.input_tokens = usage["input_tokens"]
            ctx.output_tokens = usage["output_tokens"]
            ctx.cached_tokens = usage["cached_tokens"]
        # Use the model version as the response ID
        model_version = data.get("modelVersion")
        if model_version:
            if not ctx.response_id:
                ctx.response_id = f"gemini-{model_version}"
            # Stash it in response_metadata for usage recording
            ctx.response_metadata["model_version"] = model_version
        # Check for errors
        if "error" in data:
            ctx.has_completion = True
            ctx.final_response = data
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
        Extract metadata from a Gemini response
        Reads the modelVersion field to record the model version actually used.
        Args:
            response: the Gemini API response
        Returns:
            a metadata dict containing model_version
"""
metadata: Dict[str, Any] = {}
model_version = response.get("modelVersion")
if model_version:
metadata["model_version"] = model_version
return metadata

View File

@@ -0,0 +1,11 @@
"""
OpenAI Chat API handlers
"""
from src.api.handlers.openai.adapter import OpenAIChatAdapter
from src.api.handlers.openai.handler import OpenAIChatHandler
__all__ = [
"OpenAIChatAdapter",
"OpenAIChatHandler",
]

View File

@@ -0,0 +1,109 @@
"""
OpenAI Chat Adapter - OpenAI Chat API adapter built on ChatAdapterBase
Handles OpenAI Chat format requests on the /v1/chat/completions endpoint.
"""
from typing import Any, Dict, Optional, Type
from fastapi import Request
from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.models.openai import OpenAIRequest
@register_adapter
class OpenAIChatAdapter(ChatAdapterBase):
"""
    OpenAI Chat Completions API adapter
    Handles requests in the OpenAI Chat format (the /v1/chat/completions endpoint).
"""
FORMAT_ID = "OPENAI"
name = "openai.chat"
@property
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.openai.handler import OpenAIChatHandler
return OpenAIChatHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["OPENAI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
    def _validate_request_body(self, original_request_body: dict, path_params: Optional[dict] = None):
        """Validate the request body"""
if not isinstance(original_request_body, dict):
return self._error_response(
400, "Request body must be a JSON object", "invalid_request_error"
)
required_fields = ["model", "messages"]
missing = [f for f in required_fields if f not in original_request_body]
if missing:
return self._error_response(
400,
f"Missing required fields: {', '.join(missing)}",
"invalid_request_error",
)
try:
return OpenAIRequest.model_validate(original_request_body, strict=False)
except ValueError as e:
return self._error_response(400, str(e), "invalid_request_error")
except Exception as e:
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
return OpenAIRequest.model_construct(
model=original_request_body.get("model"),
messages=original_request_body.get("messages", []),
stream=original_request_body.get("stream", False),
max_tokens=original_request_body.get("max_tokens"),
)
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
"""构建 OpenAI Chat 特定的审计元数据"""
role_counts = {}
for message in request_obj.messages:
role_counts[message.role] = role_counts.get(message.role, 0) + 1
return {
"action": "openai_chat_completion",
"model": request_obj.model,
"stream": bool(request_obj.stream),
"max_tokens": request_obj.max_tokens,
"temperature": request_obj.temperature,
"top_p": request_obj.top_p,
"messages_count": len(request_obj.messages),
"message_roles": role_counts,
"tools_count": len(request_obj.tools or []),
"response_format": bool(request_obj.response_format),
"user_identifier": request_obj.user,
}
def _error_response(self, status_code: int, message: str, error_type: str) -> JSONResponse:
"""生成 OpenAI 格式的错误响应"""
return JSONResponse(
status_code=status_code,
content={
"error": {
"type": error_type,
"message": message,
"code": status_code,
}
},
)
__all__ = ["OpenAIChatAdapter"]

View File

@@ -0,0 +1,424 @@
"""
Claude -> OpenAI format converter
Converts the Claude Messages API format into the OpenAI Chat Completions API format.
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
class ClaudeToOpenAIConverter:
"""
    Claude -> OpenAI format converter
    Supports:
    - request conversion: Claude Request -> OpenAI Chat Request
    - response conversion: Claude Response -> OpenAI Chat Response
    - stream conversion: Claude SSE -> OpenAI SSE
"""
    # Content type constants
CONTENT_TYPE_TEXT = "text"
CONTENT_TYPE_IMAGE = "image"
CONTENT_TYPE_TOOL_USE = "tool_use"
CONTENT_TYPE_TOOL_RESULT = "tool_result"
    # Stop reason mapping
STOP_REASON_MAP = {
"end_turn": "stop",
"max_tokens": "length",
"stop_sequence": "stop",
"tool_use": "tool_calls",
}
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
"""
Args:
            model_mapping: mapping from Claude model names to OpenAI model names
"""
self._model_mapping = model_mapping or {}
    # ==================== Request conversion ====================
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
"""
        Convert a Claude request into OpenAI format
        Args:
            request: the Claude request (dict or Pydantic model)
        Returns:
            request dict in OpenAI format
"""
if hasattr(request, "model_dump"):
data = request.model_dump(exclude_none=True)
else:
data = dict(request)
        # Map the model name
        model = data.get("model", "")
        openai_model = self._model_mapping.get(model, model)
        # Build the message list
        messages: List[Dict[str, Any]] = []
        # Handle the system prompt
        system_content = self._extract_text_content(data.get("system"))
        if system_content:
            messages.append({"role": "system", "content": system_content})
        # Handle conversation messages
        for message in data.get("messages", []):
            converted = self._convert_message(message)
            if converted:
                messages.append(converted)
        # Build the OpenAI request
        result: Dict[str, Any] = {
            "model": openai_model,
            "messages": messages,
        }
        # Optional parameters
if data.get("max_tokens"):
result["max_tokens"] = data["max_tokens"]
if data.get("temperature") is not None:
result["temperature"] = data["temperature"]
if data.get("top_p") is not None:
result["top_p"] = data["top_p"]
if data.get("stream"):
result["stream"] = data["stream"]
if data.get("stop_sequences"):
result["stop"] = data["stop_sequences"]
        # Convert tools
tools = self._convert_tools(data.get("tools"))
if tools:
result["tools"] = tools
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
if tool_choice:
result["tool_choice"] = tool_choice
return result
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""转换单条消息"""
role = message.get("role")
if role == "user":
return self._convert_user_message(message)
if role == "assistant":
return self._convert_assistant_message(message)
return None
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换用户消息"""
content = message.get("content")
if isinstance(content, str):
return {"role": "user", "content": content}
openai_content: List[Dict[str, Any]] = []
for block in content or []:
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
openai_content.append({"type": "text", "text": block.get("text", "")})
elif block_type == self.CONTENT_TYPE_IMAGE:
source = block.get("source", {})
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
openai_content.append(
{"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}}
)
elif block_type == self.CONTENT_TYPE_TOOL_RESULT:
tool_content = block.get("content", "")
rendered = self._render_tool_content(tool_content)
openai_content.append({"type": "text", "text": f"Tool result: {rendered}"})
        # Collapse a single text block to a plain string
if len(openai_content) == 1 and openai_content[0]["type"] == "text":
return {"role": "user", "content": openai_content[0]["text"]}
return {"role": "user", "content": openai_content or ""}
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""转换助手消息"""
content = message.get("content")
text_parts: List[str] = []
tool_calls: List[Dict[str, Any]] = []
if isinstance(content, str):
text_parts.append(content)
else:
for idx, block in enumerate(content or []):
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
text_parts.append(block.get("text", ""))
elif block_type == self.CONTENT_TYPE_TOOL_USE:
tool_calls.append(
{
"id": block.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": block.get("name", ""),
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
},
}
)
result: Dict[str, Any] = {"role": "assistant"}
message_content = "\n".join([p for p in text_parts if p]) or None
if message_content:
result["content"] = message_content
if tool_calls:
result["tool_calls"] = tool_calls
return result
def _convert_tools(
self, tools: Optional[List[Dict[str, Any]]]
) -> Optional[List[Dict[str, Any]]]:
"""转换工具定义"""
if not tools:
return None
result: List[Dict[str, Any]] = []
for tool in tools:
result.append(
{
"type": "function",
"function": {
"name": tool.get("name", ""),
"description": tool.get("description"),
"parameters": tool.get("input_schema", {}),
},
}
)
return result
def _convert_tool_choice(
self, tool_choice: Optional[Dict[str, Any]]
) -> Optional[Union[str, Dict[str, Any]]]:
"""转换工具选择"""
if tool_choice is None:
return None
choice_type = tool_choice.get("type")
if choice_type in ("tool", "tool_use"):
return {"type": "function", "function": {"name": tool_choice.get("name", "")}}
if choice_type == "any":
return "required"
if choice_type == "auto":
return "auto"
return tool_choice
    # ==================== Response conversion ====================
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""
        Convert a Claude response into OpenAI format
        Args:
            response: the Claude response dict
        Returns:
            response dict in OpenAI format
"""
        # Extract content
content_parts: List[str] = []
tool_calls: List[Dict[str, Any]] = []
for idx, block in enumerate(response.get("content", [])):
block_type = block.get("type")
if block_type == self.CONTENT_TYPE_TEXT:
content_parts.append(block.get("text", ""))
elif block_type == self.CONTENT_TYPE_TOOL_USE:
tool_calls.append(
{
"id": block.get("id", f"call_{idx}"),
"type": "function",
"function": {
"name": block.get("name", ""),
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
},
}
)
        # Build the message
message: Dict[str, Any] = {"role": "assistant"}
text_content = "\n".join([p for p in content_parts if p]) or None
if text_content:
message["content"] = text_content
if tool_calls:
message["tool_calls"] = tool_calls
        # Convert the stop reason
        stop_reason = response.get("stop_reason")
        finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
        # Convert usage
usage = response.get("usage", {})
openai_usage = {
"prompt_tokens": usage.get("input_tokens", 0),
"completion_tokens": usage.get("output_tokens", 0),
"total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
}
return {
"id": f"chatcmpl-{response.get('id', uuid.uuid4().hex[:8])}",
"object": "chat.completion",
"created": int(time.time()),
"model": response.get("model", ""),
"choices": [
{
"index": 0,
"message": message,
"finish_reason": finish_reason,
}
],
"usage": openai_usage,
}
    # ==================== Stream conversion ====================
def convert_stream_event(
self,
event: Dict[str, Any],
model: str = "",
message_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""
        Convert a Claude SSE event into OpenAI format
        Args:
            event: the Claude SSE event
            model: model name
            message_id: message ID
        Returns:
            an OpenAI-format SSE chunk, or None if the event cannot be converted
"""
event_type = event.get("type")
chunk_id = f"chatcmpl-{(message_id or 'stream')[-8:]}"
if event_type == "message_start":
message = event.get("message", {})
return self._base_chunk(
chunk_id,
model or message.get("model", ""),
{"role": "assistant"},
)
if event_type == "content_block_start":
content_block = event.get("content_block", {})
if content_block.get("type") == self.CONTENT_TYPE_TOOL_USE:
delta = {
"tool_calls": [
{
"index": event.get("index", 0),
"id": content_block.get("id", ""),
"type": "function",
"function": {
"name": content_block.get("name", ""),
"arguments": "",
},
}
]
}
return self._base_chunk(chunk_id, model, delta)
return None
if event_type == "content_block_delta":
delta_payload = event.get("delta", {})
delta_type = delta_payload.get("type")
if delta_type == "text_delta":
delta = {"content": delta_payload.get("text", "")}
return self._base_chunk(chunk_id, model, delta)
if delta_type == "input_json_delta":
delta = {
"tool_calls": [
{
"index": event.get("index", 0),
"function": {"arguments": delta_payload.get("partial_json", "")},
}
]
}
return self._base_chunk(chunk_id, model, delta)
return None
if event_type == "message_delta":
delta = event.get("delta", {})
stop_reason = delta.get("stop_reason")
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
return self._base_chunk(chunk_id, model, {}, finish_reason=finish_reason)
if event_type == "message_stop":
return self._base_chunk(chunk_id, model, {}, finish_reason="stop")
return None
def _base_chunk(
self,
chunk_id: str,
model: str,
delta: Dict[str, Any],
finish_reason: Optional[str] = None,
) -> Dict[str, Any]:
"""构建基础 OpenAI chunk"""
return {
"id": chunk_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
    # ==================== Utilities ====================
def _extract_text_content(
self, content: Optional[Union[str, List[Dict[str, Any]]]]
) -> Optional[str]:
"""提取文本内容"""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
block.get("text", "")
for block in content
if block.get("type") == self.CONTENT_TYPE_TEXT
]
return "\n\n".join(filter(None, parts)) or None
return None
def _render_tool_content(self, tool_content: Any) -> str:
"""渲染工具内容"""
if isinstance(tool_content, list):
return json.dumps(tool_content, ensure_ascii=False)
return str(tool_content)
__all__ = ["ClaudeToOpenAIConverter"]

View File

@@ -0,0 +1,137 @@
"""
OpenAI Chat Handler - simplified implementation on top of the generic chat handler base class
Subclasses ChatHandlerBase and only overrides the format-specific methods.
Cuts the code down from roughly 1315 lines to about 100.
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
class OpenAIChatHandler(ChatHandlerBase):
"""
    OpenAI Chat Handler - handles requests in the OpenAI Chat Completions API format
    Format traits:
    - uses prompt_tokens / completion_tokens
    - no cache token support
    - request format: OpenAIRequest
"""
FORMAT_ID = "OPENAI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
        Extract the model name from the request - OpenAI implementation
        The OpenAI API carries the model as a top-level body field.
        Args:
            request_body: request body
            path_params: URL path parameters (unused for OpenAI)
        Returns:
            the model name
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
        Apply the mapped model name to the request body
        The OpenAI API carries the model as a top-level body field.
        Args:
            request_body: the original request body
            mapped_model: the mapped model name
        Returns:
            the request body with its model field updated
"""
result = dict(request_body)
result["model"] = mapped_model
return result
async def _convert_request(self, request):
"""
        Convert the request into OpenAI format
        Args:
            request: the original request object
        Returns:
            an OpenAIRequest object
"""
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
from src.models.claude import ClaudeMessagesRequest
from src.models.openai import OpenAIRequest
        # Already OpenAI format: return as-is
        if isinstance(request, OpenAIRequest):
            return request
        # Claude format: convert to OpenAI
        if isinstance(request, ClaudeMessagesRequest):
            converter = ClaudeToOpenAIConverter()
            openai_dict = converter.convert_request(request.model_dump())
            return OpenAIRequest(**openai_dict)
        # Plain dict: try OpenAI first, then fall back to a Claude conversion
        if isinstance(request, dict):
            try:
                return OpenAIRequest(**request)
            except Exception:
                try:
                    converter = ClaudeToOpenAIConverter()
                    openai_dict = converter.convert_request(request)
                    return OpenAIRequest(**openai_dict)
                except Exception:
                    return OpenAIRequest(**request)
        return request
def _extract_usage(self, response: Dict) -> Dict[str, int]:
"""
        Extract token usage from an OpenAI response
        The OpenAI format:
        - uses prompt_tokens / completion_tokens
        - has no cache token support
"""
usage = response.get("usage", {})
return {
"input_tokens": usage.get("prompt_tokens", 0),
"output_tokens": usage.get("completion_tokens", 0),
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
}
def _normalize_response(self, response: Dict) -> Dict:
"""
        Normalize an OpenAI response
        Args:
            response: the raw response
        Returns:
            the normalized response
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_openai_response(
response_data=response,
request_id=self.request_id,
strict=False,
)
return response

View File

@@ -0,0 +1,181 @@
"""
OpenAI SSE stream parser
Parses Server-Sent Events streams from the OpenAI Chat Completions API.
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
class OpenAIStreamParser:
"""
    OpenAI SSE stream parser
    Parses the SSE event stream of the OpenAI Chat Completions API.
    OpenAI stream format:
    - each chunk is a JSON object carrying a choices array
    - choices[0].delta holds the incremental content
    - choices[0].finish_reason marks the end of the stream
    - the stream terminates with data: [DONE]
"""
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
"""
        Parse an SSE data chunk
        Args:
            chunk: raw SSE data (bytes or str)
        Returns:
            the list of parsed chunks
"""
if isinstance(chunk, bytes):
text = chunk.decode("utf-8")
else:
text = chunk
chunks: List[Dict[str, Any]] = []
lines = text.strip().split("\n")
for line in lines:
line = line.strip()
if not line:
continue
            # Parse a data line
            if line.startswith("data: "):
                data_str = line[6:]
                # Handle the [DONE] sentinel
                if data_str == "[DONE]":
                    chunks.append({"__done__": True})
                    continue
                try:
                    data = json.loads(data_str)
                    chunks.append(data)
                except json.JSONDecodeError:
                    # Skip data that cannot be parsed
                    pass
return chunks
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
"""
        Parse a single SSE data line
        Args:
            line: the SSE data line (with the "data: " prefix already removed)
        Returns:
            the parsed chunk dict, or None if the line cannot be parsed
"""
if not line or line == "[DONE]":
return None
try:
return json.loads(line)
except json.JSONDecodeError:
return None
def is_done_chunk(self, chunk: Dict[str, Any]) -> bool:
"""
        Check whether this is a terminal chunk
        Args:
            chunk: chunk dict
        Returns:
            True if the chunk is terminal
"""
        # Internal sentinel
        if chunk.get("__done__"):
            return True
        # Check finish_reason
choices = chunk.get("choices", [])
if choices:
finish_reason = choices[0].get("finish_reason")
return finish_reason is not None
return False
def get_finish_reason(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
        Get the finish reason
        Args:
            chunk: chunk dict
        Returns:
            the finish reason string
"""
choices = chunk.get("choices", [])
if choices:
return choices[0].get("finish_reason")
return None
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
        Extract the text delta from a chunk
        Args:
            chunk: chunk dict
        Returns:
            the text delta, or None if there is none
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
content = delta.get("content")
if isinstance(content, str):
return content
return None
def extract_tool_calls_delta(self, chunk: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
"""
        Extract the tool call delta from a chunk
        Args:
            chunk: chunk dict
        Returns:
            the list of tool calls, or None if there are none
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
return delta.get("tool_calls")
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
"""
        Extract the role from a chunk
        It usually appears only in the first chunk.
        Args:
            chunk: chunk dict
        Returns:
            the role string
"""
choices = chunk.get("choices", [])
if not choices:
return None
delta = choices[0].get("delta", {})
return delta.get("role")
__all__ = ["OpenAIStreamParser"]

View File

@@ -0,0 +1,11 @@
"""
OpenAI CLI passthrough handlers
"""
from src.api.handlers.openai_cli.adapter import OpenAICliAdapter
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
__all__ = [
"OpenAICliAdapter",
"OpenAICliMessageHandler",
]

View File

@@ -0,0 +1,44 @@
"""
OpenAI CLI Adapter - simplified implementation on top of the generic CLI adapter base class
Subclasses CliAdapterBase; only FORMAT_ID and HANDLER_CLASS need configuring.
"""
from typing import Optional, Type
from fastapi import Request
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
@register_cli_adapter
class OpenAICliAdapter(CliAdapterBase):
"""
    OpenAI CLI API adapter
    Handles requests on the /v1/responses endpoint.
"""
FORMAT_ID = "OPENAI_CLI"
name = "openai.cli"
@property
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
"""延迟导入 Handler 类避免循环依赖"""
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
return OpenAICliMessageHandler
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
super().__init__(allowed_api_formats or ["OPENAI_CLI"])
def extract_api_key(self, request: Request) -> Optional[str]:
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
authorization = request.headers.get("authorization")
if authorization and authorization.startswith("Bearer "):
return authorization.replace("Bearer ", "")
return None
__all__ = ["OpenAICliAdapter"]

View File

@@ -0,0 +1,211 @@
"""
OpenAI CLI Message Handler - simplified implementation on top of the generic CLI handler base class
Subclasses CliMessageHandlerBase and only overrides format-specific configuration and event handling.
Cuts the code down from more than 900 lines to about 100.
"""
from typing import Any, Dict, Optional
from src.api.handlers.base.cli_handler_base import (
CliMessageHandlerBase,
StreamContext,
)
class OpenAICliMessageHandler(CliMessageHandlerBase):
"""
    OpenAI CLI Message Handler - handles the OpenAI CLI Responses API format
    Uses the new three-layer architecture (Provider -> ProviderEndpoint -> ProviderAPIKey)
    and relies on FallbackOrchestrator for automatic failover, health monitoring, and concurrency control.
    Response format traits:
    - uses an output[] array instead of content[]
    - uses the output_text type instead of plain text
    - streaming events: response.output_text.delta, response.output_text.done
    Model field: top-level model field in the request body
"""
FORMAT_ID = "OPENAI_CLI"
def extract_model_from_request(
self,
request_body: Dict[str, Any],
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
) -> str:
"""
        Extract the model name from the request - OpenAI implementation
        The OpenAI API carries the model as a top-level body field.
        Args:
            request_body: request body
            path_params: URL path parameters (unused for OpenAI)
        Returns:
            the model name
"""
model = request_body.get("model")
return str(model) if model else "unknown"
def apply_mapped_model(
self,
request_body: Dict[str, Any],
mapped_model: str,
) -> Dict[str, Any]:
"""
        For OpenAI CLI (the Responses API) the model is a top-level body field.
        Args:
            request_body: the original request body
            mapped_model: the mapped model name
        Returns:
            the request body with its model field updated
"""
result = dict(request_body)
result["model"] = mapped_model
return result
def _process_event_data(
self,
ctx: StreamContext,
event_type: str,
data: Dict[str, Any],
) -> None:
"""
        Process an SSE event in the OpenAI CLI format
        Event types:
        - response.output_text.delta: text delta
        - response.completed: response finished (carries usage)
"""
        # Extract the response_id
        if not ctx.response_id:
            response_obj = data.get("response")
            if isinstance(response_obj, dict) and response_obj.get("id"):
                ctx.response_id = response_obj["id"]
            elif "id" in data:
                ctx.response_id = data["id"]
        # Handle text deltas
        if event_type in ["response.output_text.delta", "response.outtext.delta"]:
            delta = data.get("delta")
            if isinstance(delta, str):
                ctx.collected_text += delta
            elif isinstance(delta, dict) and "text" in delta:
                ctx.collected_text += delta["text"]
        # Handle the completion event
        elif event_type == "response.completed":
            ctx.has_completion = True
            response_obj = data.get("response")
            if isinstance(response_obj, dict):
                ctx.final_response = response_obj
                usage_obj = response_obj.get("usage")
                if isinstance(usage_obj, dict):
                    ctx.final_usage = usage_obj
                    ctx.input_tokens = usage_obj.get("input_tokens", 0)
                    ctx.output_tokens = usage_obj.get("output_tokens", 0)
                    details = usage_obj.get("input_tokens_details")
                    if isinstance(details, dict):
                        ctx.cached_tokens = details.get("cached_tokens", 0)
                # If no text was collected, pull it from output
                if not ctx.collected_text and "output" in response_obj:
                    for output_item in response_obj.get("output", []):
                        if output_item.get("type") != "message":
                            continue
                        for content_item in output_item.get("content", []):
                            if content_item.get("type") == "output_text":
                                text = content_item.get("text", "")
                                if text:
                                    ctx.collected_text += text
            # Fallback: take usage from the top level
            usage_obj = data.get("usage")
            if isinstance(usage_obj, dict) and not ctx.final_usage:
                ctx.final_usage = usage_obj
                ctx.input_tokens = usage_obj.get("input_tokens", 0)
                ctx.output_tokens = usage_obj.get("output_tokens", 0)
                details = usage_obj.get("input_tokens_details")
                if isinstance(details, dict):
                    ctx.cached_tokens = details.get("cached_tokens", 0)
            # Fallback: take the final response from the response field
            response_obj = data.get("response")
            if isinstance(response_obj, dict) and not ctx.final_response:
                ctx.final_response = response_obj
def _extract_response_metadata(
self,
response: Dict[str, Any],
) -> Dict[str, Any]:
"""
        Extract metadata from an OpenAI response
        Reads fields such as model, status, and response_id as metadata.
        Args:
            response: the OpenAI API response
        Returns:
            the extracted metadata dict
"""
        metadata: Dict[str, Any] = {}
        # The model actually used
        if "model" in response:
            metadata["model"] = response["model"]
        # Response ID
        if "id" in response:
            metadata["response_id"] = response["id"]
        # Status
        if "status" in response:
            metadata["status"] = response["status"]
        # Object type
        if "object" in response:
            metadata["object"] = response["object"]
        # System fingerprint, if present
        if "system_fingerprint" in response:
            metadata["system_fingerprint"] = response["system_fingerprint"]
return metadata
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
"""
        Extract final metadata from the stream context
        Called after streaming finishes to pull metadata out of the collected events.
        Args:
            ctx: the stream context
"""
        # Record the response ID
        if ctx.response_id:
            ctx.response_metadata["response_id"] = ctx.response_id
        # Pull additional metadata from final_response
        if ctx.final_response and isinstance(ctx.final_response, dict):
            if "model" in ctx.final_response:
                ctx.response_metadata["model"] = ctx.final_response["model"]
            if "status" in ctx.final_response:
                ctx.response_metadata["status"] = ctx.final_response["status"]
            if "object" in ctx.final_response:
                ctx.response_metadata["object"] = ctx.final_response["object"]
            if "system_fingerprint" in ctx.final_response:
                ctx.response_metadata["system_fingerprint"] = ctx.final_response["system_fingerprint"]
        # If the response did not yield a model, fall back to the one on the context
if "model" not in ctx.response_metadata and ctx.model:
ctx.response_metadata["model"] = ctx.model