mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
Initial commit
This commit is contained in:
99
src/api/handlers/__init__.py
Normal file
99
src/api/handlers/__init__.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
API Handlers - 请求处理器
|
||||
|
||||
按 API 格式组织的 Adapter 和 Handler:
|
||||
- Adapter: 请求验证、格式转换、错误处理
|
||||
- Handler: 业务逻辑、调用 Provider、记录用量
|
||||
|
||||
支持的格式:
|
||||
- claude: Claude Chat API (/v1/messages)
|
||||
- claude_cli: Claude CLI 透传模式
|
||||
- openai: OpenAI Chat API (/v1/chat/completions)
|
||||
- openai_cli: OpenAI CLI 透传模式
|
||||
|
||||
注意:Handler 基类和具体 Handler 使用延迟导入以避免循环依赖。
|
||||
"""
|
||||
|
||||
# Adapter 基类(不会引起循环导入,可以直接导入)
|
||||
from src.api.handlers.base import (
|
||||
ChatAdapterBase,
|
||||
CliAdapterBase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Adapter 基类
|
||||
"ChatAdapterBase",
|
||||
"CliAdapterBase",
|
||||
# Handler 基类(延迟导入)
|
||||
"ChatHandlerBase",
|
||||
"CliMessageHandlerBase",
|
||||
"BaseMessageHandler",
|
||||
"MessageHandlerProtocol",
|
||||
"MessageTelemetry",
|
||||
"StreamContext",
|
||||
# Claude
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
"ClaudeChatHandler",
|
||||
# Claude CLI
|
||||
"ClaudeCliAdapter",
|
||||
"ClaudeCliMessageHandler",
|
||||
# OpenAI
|
||||
"OpenAIChatAdapter",
|
||||
"OpenAIChatHandler",
|
||||
# OpenAI CLI
|
||||
"OpenAICliAdapter",
|
||||
"OpenAICliMessageHandler",
|
||||
]
|
||||
|
||||
# 延迟导入映射表
|
||||
_LAZY_IMPORTS = {
|
||||
# Handler 基类
|
||||
"ChatHandlerBase": ("src.api.handlers.base.chat_handler_base", "ChatHandlerBase"),
|
||||
"CliMessageHandlerBase": (
|
||||
"src.api.handlers.base.cli_handler_base",
|
||||
"CliMessageHandlerBase",
|
||||
),
|
||||
"StreamContext": ("src.api.handlers.base.cli_handler_base", "StreamContext"),
|
||||
"BaseMessageHandler": ("src.api.handlers.base.base_handler", "BaseMessageHandler"),
|
||||
"MessageHandlerProtocol": (
|
||||
"src.api.handlers.base.base_handler",
|
||||
"MessageHandlerProtocol",
|
||||
),
|
||||
"MessageTelemetry": ("src.api.handlers.base.base_handler", "MessageTelemetry"),
|
||||
# Claude
|
||||
"ClaudeChatAdapter": ("src.api.handlers.claude.adapter", "ClaudeChatAdapter"),
|
||||
"ClaudeTokenCountAdapter": (
|
||||
"src.api.handlers.claude.adapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
),
|
||||
"build_claude_adapter": ("src.api.handlers.claude.adapter", "build_claude_adapter"),
|
||||
"ClaudeChatHandler": ("src.api.handlers.claude.handler", "ClaudeChatHandler"),
|
||||
# Claude CLI
|
||||
"ClaudeCliAdapter": ("src.api.handlers.claude_cli.adapter", "ClaudeCliAdapter"),
|
||||
"ClaudeCliMessageHandler": (
|
||||
"src.api.handlers.claude_cli.handler",
|
||||
"ClaudeCliMessageHandler",
|
||||
),
|
||||
# OpenAI
|
||||
"OpenAIChatAdapter": ("src.api.handlers.openai.adapter", "OpenAIChatAdapter"),
|
||||
"OpenAIChatHandler": ("src.api.handlers.openai.handler", "OpenAIChatHandler"),
|
||||
# OpenAI CLI
|
||||
"OpenAICliAdapter": ("src.api.handlers.openai_cli.adapter", "OpenAICliAdapter"),
|
||||
"OpenAICliMessageHandler": (
|
||||
"src.api.handlers.openai_cli.handler",
|
||||
"OpenAICliMessageHandler",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""延迟导入以避免循环依赖"""
|
||||
if name in _LAZY_IMPORTS:
|
||||
module_path, attr_name = _LAZY_IMPORTS[name]
|
||||
import importlib
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, attr_name)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
68
src/api/handlers/base/__init__.py
Normal file
68
src/api/handlers/base/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Handler 基类模块
|
||||
|
||||
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
|
||||
|
||||
注意:Handler 基类(ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
|
||||
因为它们依赖 services.usage.stream,而后者又需要导入 response_parser,
|
||||
会形成循环导入。请直接从具体模块导入 Handler 基类。
|
||||
"""
|
||||
|
||||
# Chat Adapter 基类(不会引起循环导入)
|
||||
from src.api.handlers.base.chat_adapter_base import (
|
||||
ChatAdapterBase,
|
||||
get_adapter_class,
|
||||
get_adapter_instance,
|
||||
list_registered_formats,
|
||||
register_adapter,
|
||||
)
|
||||
|
||||
# CLI Adapter 基类
|
||||
from src.api.handlers.base.cli_adapter_base import (
|
||||
CliAdapterBase,
|
||||
get_cli_adapter_class,
|
||||
get_cli_adapter_instance,
|
||||
list_registered_cli_formats,
|
||||
register_cli_adapter,
|
||||
)
|
||||
|
||||
# 请求构建器
|
||||
from src.api.handlers.base.request_builder import (
|
||||
SENSITIVE_HEADERS,
|
||||
PassthroughRequestBuilder,
|
||||
RequestBuilder,
|
||||
build_passthrough_request,
|
||||
)
|
||||
|
||||
# 响应解析器
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
ParsedResponse,
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Chat Adapter
|
||||
"ChatAdapterBase",
|
||||
"register_adapter",
|
||||
"get_adapter_class",
|
||||
"get_adapter_instance",
|
||||
"list_registered_formats",
|
||||
# CLI Adapter
|
||||
"CliAdapterBase",
|
||||
"register_cli_adapter",
|
||||
"get_cli_adapter_class",
|
||||
"get_cli_adapter_instance",
|
||||
"list_registered_cli_formats",
|
||||
# 请求构建器
|
||||
"RequestBuilder",
|
||||
"PassthroughRequestBuilder",
|
||||
"build_passthrough_request",
|
||||
"SENSITIVE_HEADERS",
|
||||
# 响应解析器
|
||||
"ResponseParser",
|
||||
"ParsedChunk",
|
||||
"ParsedResponse",
|
||||
"StreamStats",
|
||||
]
|
||||
363
src/api/handlers/base/base_handler.py
Normal file
363
src/api/handlers/base/base_handler.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
基础消息处理器,封装通用的编排、转换、遥测逻辑。
|
||||
|
||||
接口约定:
|
||||
- process_stream: 处理流式请求,返回 StreamingResponse
|
||||
- process_sync: 处理非流式请求,返回 JSONResponse
|
||||
|
||||
签名规范(推荐):
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any, # 解析后的请求模型
|
||||
http_request: Request, # FastAPI Request 对象
|
||||
original_headers: Dict[str, str], # 原始请求头
|
||||
original_request_body: Dict[str, Any], # 原始请求体
|
||||
query_params: Optional[Dict[str, str]] = None, # 查询参数
|
||||
) -> StreamingResponse: ...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse: ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.api_format_metadata import resolve_api_format
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.system.audit import audit_service
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
"""
|
||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
|
||||
async def calculate_cost(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
) -> float:
|
||||
input_price, output_price = await UsageService.get_model_price_async(
|
||||
self.db, provider, model
|
||||
)
|
||||
_, _, _, _, _, _, total_cost = UsageService.calculate_cost(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
input_price,
|
||||
output_price,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
*await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
|
||||
)
|
||||
return total_cost
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
response_body: Any,
|
||||
response_headers: Dict[str, Any],
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
is_stream: bool = False,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
provider_api_key_id: Optional[str] = None,
|
||||
api_format: Optional[str] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
# Provider 响应元数据(如 Gemini 的 modelVersion)
|
||||
response_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> float:
|
||||
total_cost = await self.calculate_cost(
|
||||
provider,
|
||||
model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
)
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers=response_headers,
|
||||
response_body=response_body,
|
||||
request_id=self.request_id,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id=provider_id,
|
||||
provider_endpoint_id=provider_endpoint_id,
|
||||
provider_api_key_id=provider_api_key_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
# Provider 响应元数据
|
||||
metadata=response_metadata,
|
||||
)
|
||||
|
||||
if self.user and self.api_key:
|
||||
audit_service.log_api_request(
|
||||
db=self.db,
|
||||
user_id=self.user.id,
|
||||
api_key_id=self.api_key.id,
|
||||
request_id=self.request_id,
|
||||
model=model,
|
||||
provider=provider,
|
||||
success=True,
|
||||
ip_address=self.client_ip,
|
||||
status_code=status_code,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost_usd=total_cost,
|
||||
)
|
||||
|
||||
return total_cost
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
error_message: str,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
is_stream: bool,
|
||||
api_format: Optional[str] = None,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# 预估 token 信息(来自 message_start 事件,用于中断请求的成本估算)
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
response_body: Optional[Dict[str, Any]] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
记录失败请求
|
||||
|
||||
注意:Provider 链路信息(provider_id, endpoint_id, key_id)不在此处记录,
|
||||
因为 RequestCandidate 表已经记录了完整的请求链路追踪信息。
|
||||
|
||||
Args:
|
||||
input_tokens: 预估输入 tokens(来自 message_start,用于中断请求的成本估算)
|
||||
output_tokens: 预估输出 tokens(来自已收到的内容)
|
||||
cache_creation_tokens: 缓存创建 tokens
|
||||
cache_read_tokens: 缓存读取 tokens
|
||||
response_body: 响应体(如果有部分响应)
|
||||
target_model: 映射后的目标模型名(如果发生了映射)
|
||||
"""
|
||||
provider_name = provider or "unknown"
|
||||
if provider_name == "unknown":
|
||||
logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers={},
|
||||
response_body=response_body or {"error": error_message},
|
||||
request_id=self.request_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageHandlerProtocol(Protocol):
|
||||
"""
|
||||
消息处理器协议 - 定义标准接口
|
||||
|
||||
ChatHandlerBase 使用完整签名(含 request, http_request)。
|
||||
CliMessageHandlerBase 使用简化签名(仅 original_request_body, original_headers)。
|
||||
"""
|
||||
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> StreamingResponse:
|
||||
"""处理流式请求"""
|
||||
...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse:
|
||||
"""处理非流式请求"""
|
||||
...
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""
|
||||
消息处理器基类,所有具体格式的 handler 可以继承它。
|
||||
|
||||
子类需要实现:
|
||||
- process_stream: 处理流式请求
|
||||
- process_sync: 处理非流式请求
|
||||
|
||||
推荐使用 MessageHandlerProtocol 中定义的签名。
|
||||
"""
|
||||
|
||||
# Adapter 检测器类型
|
||||
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db: Session,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list[str]] = None,
|
||||
adapter_detector: Optional[AdapterDetectorType] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
self.user_agent = user_agent
|
||||
self.start_time = start_time
|
||||
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
|
||||
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
|
||||
self.adapter_detector = adapter_detector
|
||||
|
||||
redis_client = get_redis_client_sync()
|
||||
self.orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
||||
|
||||
def elapsed_ms(self) -> int:
|
||||
return int((time.time() - self.start_time) * 1000)
|
||||
|
||||
def _resolve_capability_requirements(
|
||||
self,
|
||||
model_name: str,
|
||||
request_headers: Optional[Dict[str, str]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
解析请求的能力需求
|
||||
|
||||
来源:
|
||||
1. 用户模型级配置 (User.model_capability_settings)
|
||||
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
|
||||
3. 请求头 X-Require-Capability
|
||||
4. Adapter 的 detect_capability_requirements(如 Claude 的 anthropic-beta)
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
request_headers: 请求头
|
||||
request_body: 请求体(可选)
|
||||
|
||||
Returns:
|
||||
能力需求字典
|
||||
"""
|
||||
from src.services.capability.resolver import CapabilityResolver
|
||||
|
||||
return CapabilityResolver.resolve_requirements(
|
||||
user=self.user,
|
||||
user_api_key=self.api_key,
|
||||
model_name=model_name,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
adapter_detector=self.adapter_detector,
|
||||
)
|
||||
|
||||
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
||||
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
||||
if provider_type:
|
||||
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||
return self.primary_api_format
|
||||
|
||||
def build_provider_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给 Provider 的请求体,替换 model 名称"""
|
||||
payload = dict(original_body)
|
||||
if mapped_model:
|
||||
payload["model"] = mapped_model
|
||||
return payload
|
||||
724
src/api/handlers/base/chat_adapter_base.py
Normal file
724
src/api/handlers/base/chat_adapter_base.py
Normal file
@@ -0,0 +1,724 @@
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
|
||||
- _validate_request_body(): 可选覆盖请求验证逻辑
|
||||
- _build_audit_metadata(): 可选覆盖审计元数据构建
|
||||
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: ChatHandlerBase 子类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[ChatHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
self.response_normalizer = None
|
||||
# 可选启用响应规范化
|
||||
self._init_response_normalizer()
|
||||
|
||||
def _init_response_normalizer(self):
|
||||
"""初始化响应规范化器 - 子类可覆盖"""
|
||||
try:
|
||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
||||
|
||||
self.response_normalizer = ResponseNormalizer()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 Chat API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 验证和解析请求
|
||||
request_obj = self._validate_request_body(original_request_body, context.path_params)
|
||||
if isinstance(request_obj, JSONResponse):
|
||||
return request_obj
|
||||
|
||||
stream = getattr(request_obj, "stream", False)
|
||||
model = getattr(request_obj, "model", "unknown")
|
||||
|
||||
# 添加审计元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self._create_handler(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.info(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_handler(
|
||||
self,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
):
|
||||
"""创建 Handler 实例 - 子类可覆盖"""
|
||||
return self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
response_normalizer=self.response_normalizer,
|
||||
enable_response_normalization=self.response_normalizer is not None,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
@abstractmethod
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""
|
||||
验证请求体 - 子类必须实现
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数(如 Gemini 的 stream 通过 URL 端点传入)
|
||||
|
||||
Returns:
|
||||
验证后的请求对象,或 JSONResponse 错误响应
|
||||
"""
|
||||
pass
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 messages 字段提取
|
||||
"""
|
||||
messages = payload.get("messages", [])
|
||||
if hasattr(request_obj, "messages"):
|
||||
messages = request_obj.messages
|
||||
return len(messages) if isinstance(messages, list) else 0
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
"""
|
||||
model = getattr(request_obj, "model", payload.get("model", "unknown"))
|
||||
stream = getattr(request_obj, "stream", payload.get("stream", False))
|
||||
messages_count = self._extract_message_count(payload, request_obj)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
|
||||
"messages_count": messages_count,
|
||||
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
|
||||
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
try:
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
except Exception as record_error:
|
||||
logger.error(f"记录失败请求时出错: {record_error}")
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应 - 子类可覆盖以自定义格式"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典:
|
||||
{
|
||||
"input_cost": float,
|
||||
"output_cost": float,
|
||||
"cache_creation_cost": float,
|
||||
"cache_read_cost": float,
|
||||
"cache_cost": float,
|
||||
"request_cost": float,
|
||||
"total_cost": float,
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
|
||||
_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
|
||||
"""
|
||||
注册 Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
FORMAT_ID = "CLAUDE"
|
||||
...
|
||||
|
||||
Args:
|
||||
adapter_class: Adapter 类
|
||||
|
||||
Returns:
|
||||
注册的 Adapter 类(支持作为装饰器使用)
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_adapters_loaded():
|
||||
"""确保所有 Adapter 已被加载(触发注册)"""
|
||||
global _ADAPTERS_LOADED
|
||||
if _ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 Adapter 模块以触发 @register_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 类
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI")
|
||||
|
||||
Returns:
|
||||
对应的 Adapter 类,如果未找到返回 None
|
||||
"""
|
||||
_ensure_adapters_loaded()
|
||||
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 实例
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识
|
||||
|
||||
Returns:
|
||||
Adapter 实例,如果未找到返回 None
|
||||
"""
|
||||
adapter_class = get_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_formats() -> list[str]:
|
||||
"""返回所有已注册的 API 格式"""
|
||||
_ensure_adapters_loaded()
|
||||
return list(_ADAPTER_REGISTRY.keys())
|
||||
1257
src/api/handlers/base/chat_handler_base.py
Normal file
1257
src/api/handlers/base/chat_handler_base.py
Normal file
File diff suppressed because it is too large
Load Diff
648
src/api/handlers/base/cli_adapter_base.py
Normal file
648
src/api/handlers/base/cli_adapter_base.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 MessageHandler 类
|
||||
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
|
||||
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: MessageHandler 类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[CliMessageHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 CLI API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params # 获取查询参数
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 获取 stream:优先从请求体,其次从 path_params(如 Gemini 通过 URL 端点区分)
|
||||
stream = original_request_body.get("stream")
|
||||
if stream is None and context.path_params:
|
||||
stream = context.path_params.get("stream", False)
|
||||
stream = bool(stream)
|
||||
|
||||
# 获取 model:优先从请求体,其次从 path_params(如 Gemini 的 model 在 URL 路径中)
|
||||
model = original_request_body.get("model")
|
||||
if model is None and context.path_params:
|
||||
model = context.path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
# 提取请求元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.debug(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 input 字段提取
|
||||
"""
|
||||
if "input" not in payload:
|
||||
return 0
|
||||
input_data = payload["input"]
|
||||
if isinstance(input_data, list):
|
||||
return len(input_data)
|
||||
if isinstance(input_data, dict) and "messages" in input_data:
|
||||
return len(input_data.get("messages", []))
|
||||
return 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
|
||||
Args:
|
||||
payload: 请求体
|
||||
path_params: URL 路径参数(用于获取 model 等)
|
||||
"""
|
||||
# 优先从请求体获取 model,其次从 path_params
|
||||
model = payload.get("model")
|
||||
if model is None and path_params:
|
||||
model = path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
stream = payload.get("stream", False)
|
||||
messages_count = self._extract_message_count(payload)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
"messages_count": messages_count,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"tool_count": len(payload.get("tools") or []),
|
||||
"instructions_present": bool(payload.get("instructions")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
|
||||
_CLI_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
|
||||
"""
|
||||
注册 CLI Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_cli_adapter
|
||||
class ClaudeCliAdapter(CliAdapterBase):
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
...
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_cli_adapters_loaded():
|
||||
"""确保所有 CLI Adapter 已被加载(触发注册)"""
|
||||
global _CLI_ADAPTERS_LOADED
|
||||
if _CLI_ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 CLI Adapter 模块以触发 @register_cli_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_CLI_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
|
||||
"""根据 API format 获取 CLI Adapter 类"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return _CLI_ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
|
||||
"""根据 API format 获取 CLI Adapter 实例"""
|
||||
adapter_class = get_cli_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_cli_formats() -> list[str]:
|
||||
"""返回所有已注册的 CLI API 格式"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return list(_CLI_ADAPTER_REGISTRY.keys())
|
||||
1614
src/api/handlers/base/cli_handler_base.py
Normal file
1614
src/api/handlers/base/cli_handler_base.py
Normal file
File diff suppressed because it is too large
Load Diff
279
src/api/handlers/base/format_converter_registry.py
Normal file
279
src/api/handlers/base/format_converter_registry.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
格式转换器注册表
|
||||
|
||||
自动管理不同 API 格式之间的转换器,支持:
|
||||
- 请求转换:客户端格式 → Provider 格式
|
||||
- 响应转换:Provider 格式 → 客户端格式
|
||||
|
||||
使用方法:
|
||||
1. 实现 Converter 类(需要有 convert_request 和/或 convert_response 方法)
|
||||
2. 调用 registry.register() 注册转换器
|
||||
3. 在 Handler 中调用 registry.convert_request/convert_response
|
||||
|
||||
示例:
|
||||
from src.api.handlers.base.format_converter_registry import converter_registry
|
||||
|
||||
# 注册转换器
|
||||
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
|
||||
|
||||
# 使用转换器
|
||||
gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
|
||||
claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Protocol, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class RequestConverter(Protocol):
|
||||
"""请求转换器协议"""
|
||||
|
||||
def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class ResponseConverter(Protocol):
|
||||
"""响应转换器协议"""
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class StreamChunkConverter(Protocol):
|
||||
"""流式响应块转换器协议"""
|
||||
|
||||
def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...
|
||||
|
||||
|
||||
class FormatConverterRegistry:
|
||||
"""
|
||||
格式转换器注册表
|
||||
|
||||
管理不同 API 格式之间的双向转换器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# key: (source_format, target_format), value: converter instance
|
||||
self._converters: Dict[Tuple[str, str], Any] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
converter: Any,
|
||||
) -> None:
|
||||
"""
|
||||
注册格式转换器
|
||||
|
||||
Args:
|
||||
source_format: 源格式(如 "CLAUDE", "OPENAI", "GEMINI")
|
||||
target_format: 目标格式
|
||||
converter: 转换器实例(需要有 convert_request/convert_response 方法)
|
||||
"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
self._converters[key] = converter
|
||||
logger.info(f"[ConverterRegistry] 注册转换器: {source_format} -> {target_format}")
|
||||
|
||||
def get_converter(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
获取转换器
|
||||
|
||||
Args:
|
||||
source_format: 源格式
|
||||
target_format: 目标格式
|
||||
|
||||
Returns:
|
||||
转换器实例,如果不存在返回 None
|
||||
"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
return self._converters.get(key)
|
||||
|
||||
def has_converter(
|
||||
self,
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> bool:
|
||||
"""检查是否存在转换器"""
|
||||
key = (source_format.upper(), target_format.upper())
|
||||
return key in self._converters
|
||||
|
||||
def convert_request(
|
||||
self,
|
||||
request: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换请求
|
||||
|
||||
Args:
|
||||
request: 原始请求字典
|
||||
source_format: 源格式(客户端格式)
|
||||
target_format: 目标格式(Provider 格式)
|
||||
|
||||
Returns:
|
||||
转换后的请求字典,如果无需转换或没有转换器则返回原始请求
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return request
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
logger.warning(f"[ConverterRegistry] 未找到请求转换器: {source_format} -> {target_format},返回原始请求")
|
||||
return request
|
||||
|
||||
if not hasattr(converter, "convert_request"):
|
||||
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_request 方法: {source_format} -> {target_format}")
|
||||
return request
|
||||
|
||||
try:
|
||||
converted = converter.convert_request(request)
|
||||
logger.debug(f"[ConverterRegistry] 请求转换成功: {source_format} -> {target_format}")
|
||||
return converted
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 请求转换失败: {source_format} -> {target_format}: {e}")
|
||||
return request
|
||||
|
||||
def convert_response(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换响应
|
||||
|
||||
Args:
|
||||
response: 原始响应字典
|
||||
source_format: 源格式(Provider 格式)
|
||||
target_format: 目标格式(客户端格式)
|
||||
|
||||
Returns:
|
||||
转换后的响应字典,如果无需转换或没有转换器则返回原始响应
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return response
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
logger.warning(f"[ConverterRegistry] 未找到响应转换器: {source_format} -> {target_format},返回原始响应")
|
||||
return response
|
||||
|
||||
if not hasattr(converter, "convert_response"):
|
||||
logger.warning(f"[ConverterRegistry] 转换器缺少 convert_response 方法: {source_format} -> {target_format}")
|
||||
return response
|
||||
|
||||
try:
|
||||
converted = converter.convert_response(response)
|
||||
logger.debug(f"[ConverterRegistry] 响应转换成功: {source_format} -> {target_format}")
|
||||
return converted
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 响应转换失败: {source_format} -> {target_format}: {e}")
|
||||
return response
|
||||
|
||||
def convert_stream_chunk(
|
||||
self,
|
||||
chunk: Dict[str, Any],
|
||||
source_format: str,
|
||||
target_format: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
转换流式响应块
|
||||
|
||||
Args:
|
||||
chunk: 原始流式响应块
|
||||
source_format: 源格式(Provider 格式)
|
||||
target_format: 目标格式(客户端格式)
|
||||
|
||||
Returns:
|
||||
转换后的流式响应块
|
||||
"""
|
||||
# 同格式无需转换
|
||||
if source_format.upper() == target_format.upper():
|
||||
return chunk
|
||||
|
||||
converter = self.get_converter(source_format, target_format)
|
||||
if converter is None:
|
||||
return chunk
|
||||
|
||||
# 优先使用专门的流式转换方法
|
||||
if hasattr(converter, "convert_stream_chunk"):
|
||||
try:
|
||||
return converter.convert_stream_chunk(chunk)
|
||||
except Exception as e:
|
||||
logger.error(f"[ConverterRegistry] 流式块转换失败: {source_format} -> {target_format}: {e}")
|
||||
return chunk
|
||||
|
||||
# 降级到普通响应转换
|
||||
if hasattr(converter, "convert_response"):
|
||||
try:
|
||||
return converter.convert_response(chunk)
|
||||
except Exception:
|
||||
return chunk
|
||||
|
||||
return chunk
|
||||
|
||||
def list_converters(self) -> list[Tuple[str, str]]:
|
||||
"""列出所有已注册的转换器"""
|
||||
return list(self._converters.keys())
|
||||
|
||||
|
||||
# 全局单例
|
||||
converter_registry = FormatConverterRegistry()
|
||||
|
||||
|
||||
def register_all_converters():
|
||||
"""
|
||||
注册所有内置的格式转换器
|
||||
|
||||
在应用启动时调用此函数
|
||||
"""
|
||||
# Claude <-> OpenAI
|
||||
try:
|
||||
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
|
||||
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
|
||||
|
||||
converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
|
||||
converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 Claude/OpenAI 转换器: {e}")
|
||||
|
||||
# Claude <-> Gemini
|
||||
try:
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
GeminiToClaudeConverter,
|
||||
)
|
||||
|
||||
converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 Claude/Gemini 转换器: {e}")
|
||||
|
||||
# OpenAI <-> Gemini
|
||||
try:
|
||||
from src.api.handlers.gemini.converter import (
|
||||
GeminiToOpenAIConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
|
||||
converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
|
||||
converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
|
||||
except ImportError as e:
|
||||
logger.warning(f"[ConverterRegistry] 无法加载 OpenAI/Gemini 转换器: {e}")
|
||||
|
||||
logger.info(f"[ConverterRegistry] 已注册 {len(converter_registry.list_converters())} 个格式转换器")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FormatConverterRegistry",
|
||||
"converter_registry",
|
||||
"register_all_converters",
|
||||
]
|
||||
465
src/api/handlers/base/parsers.py
Normal file
465
src/api/handlers/base/parsers.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
响应解析器工厂
|
||||
|
||||
直接根据格式 ID 创建对应的 ResponseParser 实现,
|
||||
不再经过 Protocol 抽象层。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
ParsedResponse,
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIResponseParser(ResponseParser):
|
||||
"""OpenAI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||
|
||||
self._parser = OpenAIStreamParser()
|
||||
self.name = "OPENAI"
|
||||
self.api_format = "OPENAI"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type=None,
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_chunk(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
result.text_content = content
|
||||
|
||||
result.response_id = response.get("id")
|
||||
|
||||
# 提取 usage
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("prompt_tokens", 0)
|
||||
result.output_tokens = usage.get("completion_tokens", 0)
|
||||
|
||||
# 检查错误
|
||||
if "error" in response:
|
||||
result.is_error = True
|
||||
error = response.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
result.error_type = error.get("type")
|
||||
result.error_message = error.get("message")
|
||||
else:
|
||||
result.error_message = str(error)
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
usage = response.get("usage", {})
|
||||
return {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
return content
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
return "error" in response
|
||||
|
||||
|
||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
"""OpenAI CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "OPENAI_CLI"
|
||||
self.api_format = "OPENAI_CLI"
|
||||
|
||||
|
||||
class ClaudeResponseParser(ResponseParser):
|
||||
"""Claude 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||
|
||||
self._parser = ClaudeStreamParser()
|
||||
self.name = "CLAUDE"
|
||||
self.api_format = "CLAUDE"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type=self._parser.get_event_type(parsed),
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_event(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
# 提取 usage
|
||||
usage = self._parser.extract_usage(parsed)
|
||||
if usage:
|
||||
chunk.input_tokens = usage.get("input_tokens", 0)
|
||||
chunk.output_tokens = usage.get("output_tokens", 0)
|
||||
chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
|
||||
chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)
|
||||
|
||||
stats.input_tokens = chunk.input_tokens
|
||||
stats.output_tokens = chunk.output_tokens
|
||||
stats.cache_creation_tokens = chunk.cache_creation_tokens
|
||||
stats.cache_read_tokens = chunk.cache_read_tokens
|
||||
|
||||
# 检查错误
|
||||
if self._parser.is_error_event(parsed):
|
||||
chunk.is_error = True
|
||||
error = parsed.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
chunk.error_message = error.get("message", str(error))
|
||||
else:
|
||||
chunk.error_message = str(error)
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
content = response.get("content", [])
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
result.text_content = "".join(text_parts)
|
||||
|
||||
result.response_id = response.get("id")
|
||||
|
||||
# 提取 usage
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 检查错误
|
||||
if "error" in response or response.get("type") == "error":
|
||||
result.is_error = True
|
||||
error = response.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
result.error_type = error.get("type")
|
||||
result.error_message = error.get("message")
|
||||
else:
|
||||
result.error_message = str(error)
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
usage = response.get("usage", {})
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
content = response.get("content", [])
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
return "error" in response or response.get("type") == "error"
|
||||
|
||||
|
||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
"""Claude CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "CLAUDE_CLI"
|
||||
self.api_format = "CLAUDE_CLI"
|
||||
|
||||
|
||||
class GeminiResponseParser(ResponseParser):
|
||||
"""Gemini 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
self._parser = GeminiStreamParser()
|
||||
self.name = "GEMINI"
|
||||
self.api_format = "GEMINI"
|
||||
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
"""
|
||||
解析 Gemini SSE 行
|
||||
|
||||
Gemini 的流式响应使用 SSE 格式 (data: {...})
|
||||
"""
|
||||
if not line or not line.strip():
|
||||
return None
|
||||
|
||||
# Gemini SSE 格式: data: {...}
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
else:
|
||||
data_str = line
|
||||
|
||||
parsed = self._parser.parse_line(data_str)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
chunk = ParsedChunk(
|
||||
raw_line=line,
|
||||
event_type="content",
|
||||
data=parsed,
|
||||
)
|
||||
|
||||
# 提取文本增量
|
||||
text_delta = self._parser.extract_text_delta(parsed)
|
||||
if text_delta:
|
||||
chunk.text_delta = text_delta
|
||||
stats.collected_text += text_delta
|
||||
|
||||
# 检查是否结束
|
||||
if self._parser.is_done_event(parsed):
|
||||
chunk.is_done = True
|
||||
stats.has_completion = True
|
||||
|
||||
# 提取 usage
|
||||
usage = self._parser.extract_usage(parsed)
|
||||
if usage:
|
||||
chunk.input_tokens = usage.get("input_tokens", 0)
|
||||
chunk.output_tokens = usage.get("output_tokens", 0)
|
||||
chunk.cache_read_tokens = usage.get("cached_tokens", 0)
|
||||
|
||||
stats.input_tokens = chunk.input_tokens
|
||||
stats.output_tokens = chunk.output_tokens
|
||||
stats.cache_read_tokens = chunk.cache_read_tokens
|
||||
|
||||
# 检查错误
|
||||
if self._parser.is_error_event(parsed):
|
||||
chunk.is_error = True
|
||||
error = parsed.get("error", {})
|
||||
if isinstance(error, dict):
|
||||
chunk.error_message = error.get("message", str(error))
|
||||
else:
|
||||
chunk.error_message = str(error)
|
||||
|
||||
stats.chunk_count += 1
|
||||
stats.data_count += 1
|
||||
|
||||
return chunk
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
result = ParsedResponse(
|
||||
raw_response=response,
|
||||
status_code=status_code,
|
||||
)
|
||||
|
||||
# 提取文本内容
|
||||
candidates = response.get("candidates", [])
|
||||
if candidates:
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
result.text_content = "".join(text_parts)
|
||||
|
||||
result.response_id = response.get("modelVersion")
|
||||
|
||||
# 提取 usage(调用 GeminiStreamParser.extract_usage 作为单一实现源)
|
||||
usage = self._parser.extract_usage(response)
|
||||
if usage:
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_read_tokens = usage.get("cached_tokens", 0)
|
||||
|
||||
# 检查错误(使用增强的错误检测)
|
||||
error_info = self._parser.extract_error_info(response)
|
||||
if error_info:
|
||||
result.is_error = True
|
||||
result.error_type = error_info.get("status")
|
||||
result.error_message = error_info.get("message")
|
||||
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 响应中提取 token 使用量
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
"""
|
||||
usage = self._parser.extract_usage(response)
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
candidates = response.get("candidates", [])
|
||||
if candidates:
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
return "".join(text_parts)
|
||||
return ""
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断响应是否为错误响应
|
||||
|
||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||
"""
|
||||
return self._parser.is_error_event(response)
|
||||
|
||||
|
||||
class GeminiCliResponseParser(GeminiResponseParser):
|
||||
"""Gemini CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.name = "GEMINI_CLI"
|
||||
self.api_format = "GEMINI_CLI"
|
||||
|
||||
|
||||
# 解析器注册表
|
||||
_PARSERS = {
|
||||
"CLAUDE": ClaudeResponseParser,
|
||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||
"OPENAI": OpenAIResponseParser,
|
||||
"OPENAI_CLI": OpenAICliResponseParser,
|
||||
"GEMINI": GeminiResponseParser,
|
||||
"GEMINI_CLI": GeminiCliResponseParser,
|
||||
}
|
||||
|
||||
|
||||
def get_parser_for_format(format_id: str) -> ResponseParser:
|
||||
"""
|
||||
根据格式 ID 获取 ResponseParser
|
||||
|
||||
Args:
|
||||
format_id: 格式 ID,如 "CLAUDE", "OPENAI", "CLAUDE_CLI", "OPENAI_CLI"
|
||||
|
||||
Returns:
|
||||
ResponseParser 实例
|
||||
|
||||
Raises:
|
||||
KeyError: 格式不存在
|
||||
"""
|
||||
format_id = format_id.upper()
|
||||
if format_id not in _PARSERS:
|
||||
raise KeyError(f"Unknown format: {format_id}")
|
||||
return _PARSERS[format_id]()
|
||||
|
||||
|
||||
def is_cli_format(format_id: str) -> bool:
|
||||
"""判断是否为 CLI 格式"""
|
||||
return format_id.upper().endswith("_CLI")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"OpenAIResponseParser",
|
||||
"OpenAICliResponseParser",
|
||||
"ClaudeResponseParser",
|
||||
"ClaudeCliResponseParser",
|
||||
"GeminiResponseParser",
|
||||
"GeminiCliResponseParser",
|
||||
"get_parser_for_format",
|
||||
"get_parser_from_protocol",
|
||||
"is_cli_format",
|
||||
]
|
||||
207
src/api/handlers/base/request_builder.py
Normal file
207
src/api/handlers/base/request_builder.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
请求构建器 - 透传模式
|
||||
|
||||
透传模式 (Passthrough): CLI 和 Chat 等场景,原样转发请求体和头部
|
||||
- 清理敏感头部:authorization, x-api-key, host, content-length 等
|
||||
- 保留所有其他头部和请求体字段
|
||||
- 适用于:Claude CLI、OpenAI CLI、Chat API 等场景
|
||||
|
||||
使用方式:
|
||||
builder = PassthroughRequestBuilder()
|
||||
payload, headers = builder.build(original_body, original_headers, endpoint, key)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, FrozenSet, Optional, Tuple
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
# ==============================================================================
|
||||
# 统一的头部配置常量
|
||||
# ==============================================================================
|
||||
|
||||
# 敏感头部 - 透传时需要清理(黑名单)
|
||||
# 这些头部要么包含认证信息,要么由代理层重新生成
|
||||
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
|
||||
{
|
||||
"authorization",
|
||||
"x-api-key",
|
||||
"x-goog-api-key", # Gemini API 认证头
|
||||
"host",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
# 不透传 accept-encoding,让 httpx 自己协商压缩格式
|
||||
# 避免客户端请求 brotli/zstd 但 httpx 不支持解压的问题
|
||||
"accept-encoding",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 请求构建器
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RequestBuilder(ABC):
|
||||
"""请求构建器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def build_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建请求体"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_headers(
|
||||
self,
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
def build(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
构建完整的请求(请求体 + 请求头)
|
||||
|
||||
Returns:
|
||||
Tuple[payload, headers]
|
||||
"""
|
||||
payload = self.build_payload(
|
||||
original_body,
|
||||
mapped_model=mapped_model,
|
||||
is_stream=is_stream,
|
||||
)
|
||||
headers = self.build_headers(
|
||||
original_headers,
|
||||
endpoint,
|
||||
key,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
return payload, headers
|
||||
|
||||
|
||||
class PassthroughRequestBuilder(RequestBuilder):
|
||||
"""
|
||||
透传模式请求构建器
|
||||
|
||||
适用于 CLI 等场景,尽量保持请求原样:
|
||||
- 请求体:直接复制,只修改必要字段(model, stream)
|
||||
- 请求头:清理敏感头部(黑名单),透传其他所有头部
|
||||
"""
|
||||
|
||||
def build_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None, # noqa: ARG002 - 由 apply_mapped_model 处理
|
||||
is_stream: bool = False, # noqa: ARG002 - 保留原始值,不自动添加
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
透传请求体 - 原样复制,不做任何修改
|
||||
|
||||
透传模式下:
|
||||
- model: 由各 handler 的 apply_mapped_model 方法处理
|
||||
- stream: 保留客户端原始值(不同 API 处理方式不同)
|
||||
"""
|
||||
return dict(original_body)
|
||||
|
||||
def build_headers(
|
||||
self,
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
透传请求头 - 清理敏感头部(黑名单),透传其他所有头部
|
||||
"""
|
||||
from src.core.api_format_metadata import get_auth_config, resolve_api_format
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
|
||||
# 1. 根据 API 格式自动设置认证头
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
api_format = getattr(endpoint, "api_format", None)
|
||||
resolved_format = resolve_api_format(api_format)
|
||||
auth_header, auth_type = (
|
||||
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
|
||||
)
|
||||
|
||||
if auth_type == "bearer":
|
||||
headers[auth_header] = f"Bearer {decrypted_key}"
|
||||
else:
|
||||
headers[auth_header] = decrypted_key
|
||||
|
||||
# 2. 添加 endpoint 配置的额外头部
|
||||
if endpoint.headers:
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
# 3. 透传原始头部(排除敏感头部 - 黑名单模式)
|
||||
if original_headers:
|
||||
for name, value in original_headers.items():
|
||||
lower_name = name.lower()
|
||||
|
||||
# 跳过敏感头部
|
||||
if lower_name in SENSITIVE_HEADERS:
|
||||
continue
|
||||
|
||||
headers[name] = value
|
||||
|
||||
# 4. 添加额外头部
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# 5. 确保有 Content-Type
|
||||
if "Content-Type" not in headers and "content-type" not in headers:
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 便捷函数
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
def build_passthrough_request(
|
||||
original_body: Dict[str, Any],
|
||||
original_headers: Dict[str, str],
|
||||
endpoint: Any,
|
||||
key: Any,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
构建透传模式的请求
|
||||
|
||||
纯透传:原样复制请求体,只处理请求头(认证等)。
|
||||
model mapping 和 stream 由调用方自行处理(不同 API 格式处理方式不同)。
|
||||
"""
|
||||
builder = PassthroughRequestBuilder()
|
||||
return builder.build(
|
||||
original_body,
|
||||
original_headers,
|
||||
endpoint,
|
||||
key,
|
||||
)
|
||||
174
src/api/handlers/base/response_parser.py
Normal file
174
src/api/handlers/base/response_parser.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
响应解析器基类 - 定义统一的响应解析接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedChunk:
|
||||
"""解析后的流式数据块"""
|
||||
|
||||
# 原始数据
|
||||
raw_line: str
|
||||
event_type: Optional[str] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 提取的内容
|
||||
text_delta: str = ""
|
||||
is_done: bool = False
|
||||
is_error: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# 使用量信息(通常在最后一个 chunk 中)
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 响应 ID
|
||||
response_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamStats:
|
||||
"""流式响应统计信息"""
|
||||
|
||||
# 计数
|
||||
chunk_count: int = 0
|
||||
data_count: int = 0
|
||||
|
||||
# Token 使用量
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 内容
|
||||
collected_text: str = ""
|
||||
response_id: Optional[str] = None
|
||||
|
||||
# 状态
|
||||
has_completion: bool = False
|
||||
status_code: int = 200
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# Provider 信息
|
||||
provider_name: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
|
||||
# 响应头和完整响应
|
||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedResponse:
|
||||
"""解析后的非流式响应"""
|
||||
|
||||
# 原始响应
|
||||
raw_response: Dict[str, Any]
|
||||
status_code: int
|
||||
|
||||
# 提取的内容
|
||||
text_content: str = ""
|
||||
response_id: Optional[str] = None
|
||||
|
||||
# 使用量
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
|
||||
# 错误信息
|
||||
is_error: bool = False
|
||||
error_type: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ResponseParser(ABC):
|
||||
"""
|
||||
响应解析器基类
|
||||
|
||||
定义统一的接口来解析不同 API 格式的响应。
|
||||
子类需要实现具体的解析逻辑。
|
||||
"""
|
||||
|
||||
# 解析器名称(用于日志)
|
||||
name: str = "base"
|
||||
|
||||
# 支持的 API 格式
|
||||
api_format: str = "UNKNOWN"
|
||||
|
||||
@abstractmethod
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 行数据
|
||||
stats: 流统计对象(会被更新)
|
||||
|
||||
Returns:
|
||||
解析后的数据块,如果行不包含有效数据则返回 None
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
"""
|
||||
解析非流式响应
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
status_code: HTTP 状态码
|
||||
|
||||
Returns:
|
||||
解析后的响应对象
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从响应中提取 token 使用量
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
包含 input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens 的字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
"""
|
||||
从响应中提取文本内容
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_error_response(self, response: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断响应是否为错误响应
|
||||
|
||||
Args:
|
||||
response: 响应 JSON
|
||||
|
||||
Returns:
|
||||
是否为错误响应
|
||||
"""
|
||||
return "error" in response
|
||||
|
||||
def create_stats(self) -> StreamStats:
|
||||
"""创建新的流统计对象"""
|
||||
return StreamStats()
|
||||
17
src/api/handlers/claude/__init__.py
Normal file
17
src/api/handlers/claude/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Claude Chat API 处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.claude.adapter import (
|
||||
ClaudeChatAdapter,
|
||||
ClaudeTokenCountAdapter,
|
||||
build_claude_adapter,
|
||||
)
|
||||
from src.api.handlers.claude.handler import ClaudeChatHandler
|
||||
|
||||
__all__ = [
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
"ClaudeChatHandler",
|
||||
]
|
||||
228
src/api/handlers/claude/adapter.py
Normal file
228
src/api/handlers/claude/adapter.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Claude Chat Adapter - 基于 ChatAdapterBase 的 Claude Chat API 适配器
|
||||
|
||||
处理 /v1/messages 端点的 Claude Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.core.optimization_utils import TokenCounter
|
||||
from src.models.claude import ClaudeMessagesRequest, ClaudeTokenCountRequest
|
||||
|
||||
|
||||
class ClaudeCapabilityDetector:
|
||||
"""Claude API 能力检测器"""
|
||||
|
||||
@staticmethod
|
||||
def detect_from_headers(
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
从 Claude 请求头检测能力需求
|
||||
|
||||
检测规则:
|
||||
- anthropic-beta: context-1m-xxx -> context_1m: True
|
||||
|
||||
Args:
|
||||
headers: 请求头字典
|
||||
request_body: 请求体(Claude 不使用,保留用于接口统一)
|
||||
"""
|
||||
requirements: Dict[str, bool] = {}
|
||||
|
||||
# 检查 anthropic-beta 请求头(大小写不敏感)
|
||||
beta_header = None
|
||||
for key, value in headers.items():
|
||||
if key.lower() == "anthropic-beta":
|
||||
beta_header = value
|
||||
break
|
||||
|
||||
if beta_header:
|
||||
# 检查是否包含 context-1m 标识
|
||||
if "context-1m" in beta_header.lower():
|
||||
requirements["context_1m"] = True
|
||||
|
||||
return requirements
|
||||
|
||||
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
Claude Chat API 适配器
|
||||
|
||||
处理 Claude Chat 格式的请求(/v1/messages 端点,进行格式验证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.claude.handler import ClaudeChatHandler
|
||||
|
||||
return ClaudeChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["CLAUDE"])
|
||||
logger.info(f"[{self.name}] 初始化Chat模式适配器 | API格式: {self.allowed_api_formats}")
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key)"""
|
||||
return request.headers.get("x-api-key")
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""检测 Claude 请求中隐含的能力需求"""
|
||||
return ClaudeCapabilityDetector.detect_from_headers(headers)
|
||||
|
||||
# =========================================================================
|
||||
# Claude 特定的计费逻辑
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算 Claude 的总输入上下文(用于阶梯计费判定)
|
||||
|
||||
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
"""
|
||||
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
try:
|
||||
if not isinstance(original_request_body, dict):
|
||||
raise ValueError("Request body must be a JSON object")
|
||||
|
||||
required_fields = ["model", "messages", "max_tokens"]
|
||||
missing_fields = [f for f in required_fields if f not in original_request_body]
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields: {', '.join(missing_fields)}")
|
||||
|
||||
request = ClaudeMessagesRequest.model_validate(
|
||||
original_request_body,
|
||||
strict=False,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"请求体基本验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
request = ClaudeMessagesRequest.model_construct(
|
||||
model=original_request_body.get("model"),
|
||||
max_tokens=original_request_body.get("max_tokens"),
|
||||
messages=original_request_body.get("messages", []),
|
||||
stream=original_request_body.get("stream", False),
|
||||
)
|
||||
return request
|
||||
|
||||
def _build_audit_metadata(self, _payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 Claude Chat 特定的审计元数据"""
|
||||
role_counts: dict[str, int] = {}
|
||||
for message in request_obj.messages:
|
||||
role_counts[message.role] = role_counts.get(message.role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "claude_messages",
|
||||
"model": request_obj.model,
|
||||
"stream": bool(request_obj.stream),
|
||||
"max_tokens": request_obj.max_tokens,
|
||||
"temperature": getattr(request_obj, "temperature", None),
|
||||
"top_p": getattr(request_obj, "top_p", None),
|
||||
"top_k": getattr(request_obj, "top_k", None),
|
||||
"messages_count": len(request_obj.messages),
|
||||
"message_roles": role_counts,
|
||||
"stop_sequences": len(request_obj.stop_sequences or []),
|
||||
"tools_count": len(request_obj.tools or []),
|
||||
"system_present": bool(request_obj.system),
|
||||
"metadata_present": bool(request_obj.metadata),
|
||||
"thinking_enabled": bool(request_obj.thinking),
|
||||
}
|
||||
|
||||
|
||||
def build_claude_adapter(x_app_header: Optional[str]):
|
||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||
if x_app_header and x_app_header.lower() == "cli":
|
||||
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
|
||||
|
||||
return ClaudeCliAdapter()
|
||||
return ClaudeChatAdapter()
|
||||
|
||||
|
||||
class ClaudeTokenCountAdapter(ApiAdapter):
|
||||
"""计算 Claude 请求 Token 数的轻量适配器。"""
|
||||
|
||||
name = "claude.token_count"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-api-key 或 Authorization: Bearer)"""
|
||||
# 优先检查 x-api-key
|
||||
api_key = request.headers.get("x-api-key")
|
||||
if api_key:
|
||||
return api_key
|
||||
# 降级到 Authorization: Bearer
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
request = ClaudeTokenCountRequest.model_validate(payload, strict=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Token count payload invalid: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid token count payload") from e
|
||||
|
||||
token_counter = TokenCounter()
|
||||
total_tokens = 0
|
||||
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
total_tokens += token_counter.count_tokens(request.system, request.model)
|
||||
elif isinstance(request.system, list):
|
||||
for block in request.system:
|
||||
if hasattr(block, "text"):
|
||||
total_tokens += token_counter.count_tokens(block.text, request.model)
|
||||
|
||||
messages_dict = [
|
||||
msg.model_dump() if hasattr(msg, "model_dump") else msg for msg in request.messages
|
||||
]
|
||||
total_tokens += token_counter.count_messages_tokens(messages_dict, request.model)
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="claude_token_count",
|
||||
model=request.model,
|
||||
messages_count=len(request.messages),
|
||||
system_present=bool(request.system),
|
||||
tools_count=len(request.tools or []),
|
||||
thinking_enabled=bool(request.thinking),
|
||||
input_tokens=total_tokens,
|
||||
)
|
||||
|
||||
return JSONResponse({"input_tokens": total_tokens})
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeChatAdapter",
|
||||
"ClaudeTokenCountAdapter",
|
||||
"build_claude_adapter",
|
||||
]
|
||||
490
src/api/handlers/claude/converter.py
Normal file
490
src/api/handlers/claude/converter.py
Normal file
@@ -0,0 +1,490 @@
|
||||
"""
|
||||
OpenAI -> Claude 格式转换器
|
||||
|
||||
将 OpenAI Chat Completions API 格式转换为 Claude Messages API 格式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenAIToClaudeConverter:
|
||||
"""
|
||||
OpenAI -> Claude 格式转换器
|
||||
|
||||
支持:
|
||||
- 请求转换:OpenAI Chat Request -> Claude Request
|
||||
- 响应转换:OpenAI Chat Response -> Claude Response
|
||||
- 流式转换:OpenAI SSE -> Claude SSE
|
||||
"""
|
||||
|
||||
# 内容类型常量
|
||||
CONTENT_TYPE_TEXT = "text"
|
||||
CONTENT_TYPE_IMAGE = "image"
|
||||
CONTENT_TYPE_TOOL_USE = "tool_use"
|
||||
CONTENT_TYPE_TOOL_RESULT = "tool_result"
|
||||
|
||||
# 停止原因映射(OpenAI -> Claude)
|
||||
FINISH_REASON_MAP = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"function_call": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
}
|
||||
|
||||
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Args:
|
||||
model_mapping: OpenAI 模型到 Claude 模型的映射
|
||||
"""
|
||||
self._model_mapping = model_mapping or {}
|
||||
|
||||
# ==================== 请求转换 ====================
|
||||
|
||||
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 请求转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
request: OpenAI 请求(Dict 或 Pydantic 模型)
|
||||
|
||||
Returns:
|
||||
Claude 格式的请求字典
|
||||
"""
|
||||
if hasattr(request, "model_dump"):
|
||||
data = request.model_dump(exclude_none=True)
|
||||
else:
|
||||
data = dict(request)
|
||||
|
||||
# 模型映射
|
||||
model = data.get("model", "")
|
||||
claude_model = self._model_mapping.get(model, model)
|
||||
|
||||
# 处理消息
|
||||
system_content: Optional[str] = None
|
||||
claude_messages: List[Dict[str, Any]] = []
|
||||
|
||||
for message in data.get("messages", []):
|
||||
role = message.get("role")
|
||||
|
||||
# 提取 system 消息
|
||||
if role == "system":
|
||||
system_content = self._collapse_content(message.get("content"))
|
||||
continue
|
||||
|
||||
# 转换其他消息
|
||||
converted = self._convert_message(message)
|
||||
if converted:
|
||||
claude_messages.append(converted)
|
||||
|
||||
# 构建 Claude 请求
|
||||
result: Dict[str, Any] = {
|
||||
"model": claude_model,
|
||||
"messages": claude_messages,
|
||||
"max_tokens": data.get("max_tokens") or 4096,
|
||||
}
|
||||
|
||||
# 可选参数
|
||||
if data.get("temperature") is not None:
|
||||
result["temperature"] = data["temperature"]
|
||||
if data.get("top_p") is not None:
|
||||
result["top_p"] = data["top_p"]
|
||||
if data.get("stream"):
|
||||
result["stream"] = data["stream"]
|
||||
if data.get("stop"):
|
||||
result["stop_sequences"] = self._convert_stop(data["stop"])
|
||||
if system_content:
|
||||
result["system"] = system_content
|
||||
|
||||
# 工具转换
|
||||
tools = self._convert_tools(data.get("tools"))
|
||||
if tools:
|
||||
result["tools"] = tools
|
||||
|
||||
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
|
||||
if tool_choice:
|
||||
result["tool_choice"] = tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""转换单条消息"""
|
||||
role = message.get("role")
|
||||
|
||||
if role == "user":
|
||||
return self._convert_user_message(message)
|
||||
if role == "assistant":
|
||||
return self._convert_assistant_message(message)
|
||||
if role == "tool":
|
||||
return self._convert_tool_message(message)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换用户消息"""
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, str) or content is None:
|
||||
return {"role": "user", "content": content or ""}
|
||||
|
||||
# 转换内容数组
|
||||
claude_content: List[Dict[str, Any]] = []
|
||||
for item in content:
|
||||
item_type = item.get("type")
|
||||
|
||||
if item_type == "text":
|
||||
claude_content.append(
|
||||
{"type": self.CONTENT_TYPE_TEXT, "text": item.get("text", "")}
|
||||
)
|
||||
elif item_type == "image_url":
|
||||
image_url = (item.get("image_url") or {}).get("url", "")
|
||||
claude_content.append(self._convert_image_url(image_url))
|
||||
|
||||
return {"role": "user", "content": claude_content}
|
||||
|
||||
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换助手消息"""
|
||||
content_blocks: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理文本内容
|
||||
content = message.get("content")
|
||||
if isinstance(content, str):
|
||||
content_blocks.append({"type": self.CONTENT_TYPE_TEXT, "text": content})
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if part.get("type") == "text":
|
||||
content_blocks.append(
|
||||
{"type": self.CONTENT_TYPE_TEXT, "text": part.get("text", "")}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if tool_call.get("type") == "function":
|
||||
function = tool_call.get("function", {})
|
||||
arguments = function.get("arguments", "{}")
|
||||
try:
|
||||
input_data = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
input_data = {"raw": arguments}
|
||||
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function.get("name", ""),
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
|
||||
# 简化单文本内容
|
||||
if not content_blocks:
|
||||
return {"role": "assistant", "content": ""}
|
||||
if len(content_blocks) == 1 and content_blocks[0]["type"] == self.CONTENT_TYPE_TEXT:
|
||||
return {"role": "assistant", "content": content_blocks[0]["text"]}
|
||||
|
||||
return {"role": "assistant", "content": content_blocks}
|
||||
|
||||
def _convert_tool_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换工具结果消息"""
|
||||
tool_content = message.get("content", "")
|
||||
|
||||
# 尝试解析 JSON
|
||||
parsed_content = tool_content
|
||||
if isinstance(tool_content, str):
|
||||
try:
|
||||
parsed_content = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
tool_block = {
|
||||
"type": self.CONTENT_TYPE_TOOL_RESULT,
|
||||
"tool_use_id": message.get("tool_call_id", ""),
|
||||
"content": parsed_content,
|
||||
}
|
||||
|
||||
return {"role": "user", "content": [tool_block]}
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[List[Dict[str, Any]]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""转换工具定义"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
if tool.get("type") != "function":
|
||||
continue
|
||||
|
||||
function = tool.get("function", {})
|
||||
result.append(
|
||||
{
|
||||
"name": function.get("name", ""),
|
||||
"description": function.get("description"),
|
||||
"input_schema": function.get("parameters") or {},
|
||||
}
|
||||
)
|
||||
|
||||
return result if result else None
|
||||
|
||||
def _convert_tool_choice(
|
||||
self, tool_choice: Optional[Union[str, Dict[str, Any]]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""转换工具选择"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if tool_choice == "none":
|
||||
return {"type": "none"}
|
||||
if tool_choice == "auto":
|
||||
return {"type": "auto"}
|
||||
if tool_choice == "required":
|
||||
return {"type": "any"}
|
||||
if isinstance(tool_choice, dict) and tool_choice.get("type") == "function":
|
||||
function = tool_choice.get("function", {})
|
||||
return {"type": "tool_use", "name": function.get("name", "")}
|
||||
|
||||
return {"type": "auto"}
|
||||
|
||||
def _convert_image_url(self, image_url: str) -> Dict[str, Any]:
|
||||
"""转换图片 URL"""
|
||||
if image_url.startswith("data:"):
|
||||
header, _, data = image_url.partition(",")
|
||||
media_type = "image/jpeg"
|
||||
if ";" in header:
|
||||
media_type = header.split(";")[0].split(":")[-1]
|
||||
|
||||
return {
|
||||
"type": self.CONTENT_TYPE_IMAGE,
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
|
||||
return {"type": self.CONTENT_TYPE_TEXT, "text": f"[Image: {image_url}]"}
|
||||
|
||||
def _convert_stop(self, stop: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
|
||||
"""转换停止序列"""
|
||||
if stop is None:
|
||||
return None
|
||||
if isinstance(stop, str):
|
||||
return [stop]
|
||||
return stop
|
||||
|
||||
# ==================== 响应转换 ====================
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 响应转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
response: OpenAI 响应字典
|
||||
|
||||
Returns:
|
||||
Claude 格式的响应字典
|
||||
"""
|
||||
choices = response.get("choices", [])
|
||||
if not choices:
|
||||
return self._empty_claude_response(response)
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# 构建 content 数组
|
||||
content: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理文本
|
||||
text_content = message.get("content")
|
||||
if text_content:
|
||||
content.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TEXT,
|
||||
"text": text_content,
|
||||
}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if tool_call.get("type") == "function":
|
||||
function = tool_call.get("function", {})
|
||||
arguments = function.get("arguments", "{}")
|
||||
try:
|
||||
input_data = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
input_data = {"raw": arguments}
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call.get("id", ""),
|
||||
"name": function.get("name", ""),
|
||||
"input": input_data,
|
||||
}
|
||||
)
|
||||
|
||||
# 转换 finish_reason
|
||||
finish_reason = choice.get("finish_reason")
|
||||
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
|
||||
|
||||
# 转换 usage
|
||||
usage = response.get("usage", {})
|
||||
claude_usage = {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": response.get("model", ""),
|
||||
"content": content,
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
"usage": claude_usage,
|
||||
}
|
||||
|
||||
def _empty_claude_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构建空的 Claude 响应"""
|
||||
return {
|
||||
"id": f"msg_{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": response.get("model", ""),
|
||||
"content": [],
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
}
|
||||
|
||||
# ==================== 流式转换 ====================
|
||||
|
||||
def convert_stream_chunk(
|
||||
self,
|
||||
chunk: Dict[str, Any],
|
||||
model: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
message_started: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将 OpenAI SSE chunk 转换为 Claude SSE 事件
|
||||
|
||||
Args:
|
||||
chunk: OpenAI SSE chunk
|
||||
model: 模型名称
|
||||
message_id: 消息 ID
|
||||
message_started: 是否已发送 message_start
|
||||
|
||||
Returns:
|
||||
Claude SSE 事件列表
|
||||
"""
|
||||
events: List[Dict[str, Any]] = []
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return events
|
||||
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta", {})
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
# 处理角色(第一个 chunk)
|
||||
role = delta.get("role")
|
||||
if role and not message_started:
|
||||
msg_id = message_id or f"msg_{uuid.uuid4().hex[:8]}"
|
||||
events.append(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": msg_id,
|
||||
"type": "message",
|
||||
"role": role,
|
||||
"model": model,
|
||||
"content": [],
|
||||
"stop_reason": None,
|
||||
"stop_sequence": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理文本内容
|
||||
content_delta = delta.get("content")
|
||||
if isinstance(content_delta, str):
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": content_delta},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理工具调用
|
||||
tool_calls = delta.get("tool_calls", [])
|
||||
for tool_call in tool_calls:
|
||||
index = tool_call.get("index", 0)
|
||||
|
||||
# 工具调用开始
|
||||
if "id" in tool_call:
|
||||
function = tool_call.get("function", {})
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": {
|
||||
"type": self.CONTENT_TYPE_TOOL_USE,
|
||||
"id": tool_call["id"],
|
||||
"name": function.get("name", ""),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 工具调用参数增量
|
||||
function = tool_call.get("function", {})
|
||||
if "arguments" in function:
|
||||
events.append(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": {
|
||||
"type": "input_json_delta",
|
||||
"partial_json": function.get("arguments", ""),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 处理结束
|
||||
if finish_reason:
|
||||
stop_reason = self.FINISH_REASON_MAP.get(finish_reason, "end_turn")
|
||||
events.append(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": stop_reason},
|
||||
}
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def _collapse_content(
|
||||
self, content: Optional[Union[str, List[Dict[str, Any]]]]
|
||||
) -> Optional[str]:
|
||||
"""折叠内容为字符串"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not content:
|
||||
return None
|
||||
|
||||
text_parts = [part.get("text", "") for part in content if part.get("type") == "text"]
|
||||
return "\n\n".join(filter(None, text_parts)) or None
|
||||
|
||||
|
||||
__all__ = ["OpenAIToClaudeConverter"]
|
||||
150
src/api/handlers/claude/handler.py
Normal file
150
src/api/handlers/claude/handler.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
|
||||
继承 ChatHandlerBase,只需覆盖格式特定的方法。
|
||||
代码量从原来的 ~1470 行减少到 ~120 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class ClaudeChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
Claude Chat Handler - 处理 Claude Chat/CLI API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 input_tokens/output_tokens
|
||||
- 支持 cache_creation_input_tokens/cache_read_input_tokens
|
||||
- 请求格式:ClaudeMessagesRequest
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Claude 格式实现
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Claude 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将映射后的模型名应用到请求体
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 Claude 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象
|
||||
|
||||
Returns:
|
||||
ClaudeMessagesRequest 对象
|
||||
"""
|
||||
from src.api.handlers.claude.converter import OpenAIToClaudeConverter
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 Claude 格式,直接返回
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
return request
|
||||
|
||||
# 如果是 OpenAI 格式,转换为 Claude 格式
|
||||
if isinstance(request, OpenAIRequest):
|
||||
converter = OpenAIToClaudeConverter()
|
||||
claude_dict = converter.convert_request(request.dict())
|
||||
return ClaudeMessagesRequest(**claude_dict)
|
||||
|
||||
# 如果是字典,根据内容判断格式
|
||||
if isinstance(request, dict):
|
||||
if "messages" in request and len(request["messages"]) > 0:
|
||||
first_msg = request["messages"][0]
|
||||
if "role" in first_msg and "content" in first_msg:
|
||||
# 可能是 OpenAI 格式
|
||||
converter = OpenAIToClaudeConverter()
|
||||
claude_dict = converter.convert_request(request)
|
||||
return ClaudeMessagesRequest(**claude_dict)
|
||||
|
||||
# 否则假设已经是 Claude 格式
|
||||
return ClaudeMessagesRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 Claude 响应中提取 token 使用情况
|
||||
|
||||
Claude 格式使用:
|
||||
- input_tokens / output_tokens
|
||||
- cache_creation_input_tokens / cache_read_input_tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 处理新的 cache_creation 格式
|
||||
if "cache_creation" in usage:
|
||||
cache_creation_data = usage.get("cache_creation", {})
|
||||
if not cache_creation_input_tokens:
|
||||
cache_creation_input_tokens = cache_creation_data.get(
|
||||
"ephemeral_5m_input_tokens", 0
|
||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
||||
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 Claude 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return response
|
||||
241
src/api/handlers/claude/stream_parser.py
Normal file
241
src/api/handlers/claude/stream_parser.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Claude SSE 流解析器
|
||||
|
||||
解析 Claude Messages API 的 Server-Sent Events 流。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ClaudeStreamParser:
|
||||
"""
|
||||
Claude SSE 流解析器
|
||||
|
||||
解析 Claude Messages API 的 SSE 事件流。
|
||||
|
||||
事件类型:
|
||||
- message_start: 消息开始,包含初始 message 对象
|
||||
- content_block_start: 内容块开始
|
||||
- content_block_delta: 内容块增量(文本、工具输入等)
|
||||
- content_block_stop: 内容块结束
|
||||
- message_delta: 消息增量,包含 stop_reason 和最终 usage
|
||||
- message_stop: 消息结束
|
||||
- ping: 心跳事件
|
||||
- error: 错误事件
|
||||
"""
|
||||
|
||||
# Claude SSE 事件类型
|
||||
EVENT_MESSAGE_START = "message_start"
|
||||
EVENT_MESSAGE_STOP = "message_stop"
|
||||
EVENT_MESSAGE_DELTA = "message_delta"
|
||||
EVENT_CONTENT_BLOCK_START = "content_block_start"
|
||||
EVENT_CONTENT_BLOCK_STOP = "content_block_stop"
|
||||
EVENT_CONTENT_BLOCK_DELTA = "content_block_delta"
|
||||
EVENT_PING = "ping"
|
||||
EVENT_ERROR = "error"
|
||||
|
||||
# Delta 类型
|
||||
DELTA_TEXT = "text_delta"
|
||||
DELTA_INPUT_JSON = "input_json_delta"
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 SSE 数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始 SSE 数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的事件列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
events: List[Dict[str, Any]] = []
|
||||
lines = text.strip().split("\n")
|
||||
|
||||
current_event_type: Optional[str] = None
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析事件类型行
|
||||
if line.startswith("event: "):
|
||||
current_event_type = line[7:]
|
||||
continue
|
||||
|
||||
# 解析数据行
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
# 处理 [DONE] 标记
|
||||
if data_str == "[DONE]":
|
||||
events.append({"type": "__done__", "raw": "[DONE]"})
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
# 如果数据中没有 type,使用事件行的类型
|
||||
if "type" not in data and current_event_type:
|
||||
data["type"] = current_event_type
|
||||
events.append(data)
|
||||
except json.JSONDecodeError:
|
||||
# 无法解析的数据,跳过
|
||||
pass
|
||||
|
||||
current_event_type = None
|
||||
|
||||
return events
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 数据行(已去除 "data: " 前缀)
|
||||
|
||||
Returns:
|
||||
解析后的事件字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束事件
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
return event_type in (self.EVENT_MESSAGE_STOP, "__done__")
|
||||
|
||||
def is_error_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为错误事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是错误事件
|
||||
"""
|
||||
return event.get("type") == self.EVENT_ERROR
|
||||
|
||||
def get_event_type(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取事件类型
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
事件类型字符串
|
||||
"""
|
||||
return event.get("type")
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 content_block_delta 事件中提取文本增量
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
文本增量,如果不是文本 delta 返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_CONTENT_BLOCK_DELTA:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == self.DELTA_TEXT:
|
||||
return delta.get("text")
|
||||
|
||||
return None
|
||||
|
||||
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
|
||||
"""
|
||||
从事件中提取 token 使用量
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
使用量字典,如果没有使用量信息返回 None
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
|
||||
# message_start 事件包含初始 usage
|
||||
if event_type == self.EVENT_MESSAGE_START:
|
||||
message = event.get("message", {})
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
# message_delta 事件包含最终 usage
|
||||
if event_type == self.EVENT_MESSAGE_DELTA:
|
||||
usage = event.get("usage", {})
|
||||
if usage:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def extract_message_id(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 message_start 事件中提取消息 ID
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
消息 ID,如果不是 message_start 返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_MESSAGE_START:
|
||||
return None
|
||||
|
||||
message = event.get("message", {})
|
||||
return message.get("id")
|
||||
|
||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 message_delta 事件中提取停止原因
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
停止原因,如果没有返回 None
|
||||
"""
|
||||
if event.get("type") != self.EVENT_MESSAGE_DELTA:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
return delta.get("stop_reason")
|
||||
|
||||
|
||||
__all__ = ["ClaudeStreamParser"]
|
||||
11
src/api/handlers/claude_cli/__init__.py
Normal file
11
src/api/handlers/claude_cli/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Claude CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.claude_cli.adapter import ClaudeCliAdapter
|
||||
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"ClaudeCliAdapter",
|
||||
"ClaudeCliMessageHandler",
|
||||
]
|
||||
103
src/api/handlers/claude_cli/adapter.py
Normal file
103
src/api/handlers/claude_cli/adapter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.api.handlers.claude.adapter import ClaudeCapabilityDetector
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class ClaudeCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
Claude CLI API 适配器
|
||||
|
||||
处理 Claude CLI 格式的请求(/v1/messages 端点,使用 Bearer 认证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
name = "claude.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.claude_cli.handler import ClaudeCliMessageHandler
|
||||
|
||||
return ClaudeCliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["CLAUDE_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
def detect_capability_requirements(
|
||||
self,
|
||||
headers: Dict[str, str],
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""检测 Claude CLI 请求中隐含的能力需求"""
|
||||
return ClaudeCapabilityDetector.detect_from_headers(headers)
|
||||
|
||||
# =========================================================================
|
||||
# Claude CLI 特定的计费逻辑
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算 Claude CLI 的总输入上下文(用于阶梯计费判定)
|
||||
|
||||
Claude 的总输入 = input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
"""
|
||||
return input_tokens + cache_creation_input_tokens + cache_read_input_tokens
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""Claude CLI 使用 messages 字段"""
|
||||
messages = payload.get("messages", [])
|
||||
return len(messages) if isinstance(messages, list) else 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""Claude CLI 特定的审计元数据"""
|
||||
model = payload.get("model", "unknown")
|
||||
stream = payload.get("stream", False)
|
||||
messages = payload.get("messages", [])
|
||||
|
||||
role_counts = {}
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "claude_cli_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
"messages_count": len(messages),
|
||||
"message_roles": role_counts,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"tool_count": len(payload.get("tools") or []),
|
||||
"system_present": bool(payload.get("system")),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["ClaudeCliAdapter"]
|
||||
195
src/api/handlers/claude_cli/handler.py
Normal file
195
src/api/handlers/claude_cli/handler.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
||||
|
||||
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
||||
验证新架构的有效性:代码量从数百行减少到 ~80 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
Claude CLI Message Handler - 处理 Claude CLI API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- 使用 content[] 数组
|
||||
- 使用 text 类型
|
||||
- 流式事件:message_start, content_block_delta, message_delta, message_stop
|
||||
- 支持 cache_creation_input_tokens 和 cache_read_input_tokens
|
||||
|
||||
模型字段:请求体顶级 model 字段
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Claude 格式实现
|
||||
|
||||
Claude API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Claude 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Claude API 的 model 在请求体顶级
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 Claude CLI 格式的 SSE 事件
|
||||
|
||||
事件类型:
|
||||
- message_start: 消息开始,包含初始 usage(含缓存 tokens)
|
||||
- content_block_delta: 文本增量
|
||||
- message_delta: 消息增量,包含最终 usage
|
||||
- message_stop: 消息结束
|
||||
"""
|
||||
# 处理 message_start 事件
|
||||
if event_type == "message_start":
|
||||
message = data.get("message", {})
|
||||
if message.get("id"):
|
||||
ctx.response_id = message["id"]
|
||||
|
||||
# 提取初始 usage(包含缓存 tokens)
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||
# Claude 的缓存 tokens 使用不同的字段名
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
ctx.cached_tokens = cache_read
|
||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
||||
if cache_creation:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
# 处理文本增量
|
||||
elif event_type == "content_block_delta":
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
|
||||
# 处理消息增量(包含最终 usage)
|
||||
elif event_type == "message_delta":
|
||||
usage = data.get("usage", {})
|
||||
if usage:
|
||||
if "input_tokens" in usage:
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
# 更新缓存 tokens
|
||||
if "cache_read_input_tokens" in usage:
|
||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||
if "cache_creation_input_tokens" in usage:
|
||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
||||
|
||||
# 检查是否结束
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("stop_reason"):
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
# 处理消息结束
|
||||
elif event_type == "message_stop":
|
||||
ctx.has_completion = True
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 Claude 响应中提取元数据
|
||||
|
||||
提取 model、stop_reason 等字段作为元数据。
|
||||
|
||||
Args:
|
||||
response: Claude API 响应
|
||||
|
||||
Returns:
|
||||
提取的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
# 提取模型名称(实际使用的模型)
|
||||
if "model" in response:
|
||||
metadata["model"] = response["model"]
|
||||
|
||||
# 提取停止原因
|
||||
if "stop_reason" in response:
|
||||
metadata["stop_reason"] = response["stop_reason"]
|
||||
|
||||
# 提取消息 ID
|
||||
if "id" in response:
|
||||
metadata["message_id"] = response["id"]
|
||||
|
||||
# 提取消息类型
|
||||
if "type" in response:
|
||||
metadata["type"] = response["type"]
|
||||
|
||||
return metadata
|
||||
|
||||
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
|
||||
"""
|
||||
从流上下文中提取最终元数据
|
||||
|
||||
在流传输完成后调用,从收集的事件中提取元数据。
|
||||
|
||||
Args:
|
||||
ctx: 流上下文
|
||||
"""
|
||||
# 从 response_id 提取消息 ID
|
||||
if ctx.response_id:
|
||||
ctx.response_metadata["message_id"] = ctx.response_id
|
||||
|
||||
# 从 final_response 提取停止原因(message_delta 事件中的 delta.stop_reason)
|
||||
if ctx.final_response:
|
||||
delta = ctx.final_response.get("delta", {})
|
||||
if "stop_reason" in delta:
|
||||
ctx.response_metadata["stop_reason"] = delta["stop_reason"]
|
||||
|
||||
# 记录模型名称
|
||||
if ctx.model:
|
||||
ctx.response_metadata["model"] = ctx.model
|
||||
|
||||
26
src/api/handlers/gemini/__init__.py
Normal file
26
src/api/handlers/gemini/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Gemini API Handler 模块
|
||||
|
||||
提供 Gemini API 格式的请求处理
|
||||
"""
|
||||
|
||||
from src.api.handlers.gemini.adapter import GeminiChatAdapter, build_gemini_adapter
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
GeminiToClaudeConverter,
|
||||
GeminiToOpenAIConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
from src.api.handlers.gemini.handler import GeminiChatHandler
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
__all__ = [
|
||||
"GeminiChatAdapter",
|
||||
"GeminiChatHandler",
|
||||
"GeminiStreamParser",
|
||||
"ClaudeToGeminiConverter",
|
||||
"GeminiToClaudeConverter",
|
||||
"OpenAIToGeminiConverter",
|
||||
"GeminiToOpenAIConverter",
|
||||
"build_gemini_adapter",
|
||||
]
|
||||
170
src/api/handlers/gemini/adapter.py
Normal file
170
src/api/handlers/gemini/adapter.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Gemini Chat Adapter
|
||||
|
||||
处理 Gemini API 格式的请求适配
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.models.gemini import GeminiRequest
|
||||
|
||||
|
||||
@register_adapter
|
||||
class GeminiChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
Gemini Chat API 适配器
|
||||
|
||||
处理 Gemini Chat 格式的请求
|
||||
端点: /v1beta/models/{model}:generateContent
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
name = "gemini.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.gemini.handler import GeminiChatHandler
|
||||
|
||||
return GeminiChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["GEMINI"])
|
||||
logger.info(f"[{self.name}] 初始化 Gemini Chat 适配器 | API格式: {self.allowed_api_formats}")
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-goog-api-key)"""
|
||||
return request.headers.get("x-goog-api-key")
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - Gemini 特化版本
|
||||
|
||||
Gemini API 特点:
|
||||
- model 不合并到请求体(通过 extract_model_from_request 从 path_params 获取)
|
||||
- stream 不合并到请求体(Gemini API 通过 URL 端点区分流式/非流式)
|
||||
|
||||
Handler 层的 extract_model_from_request 会从 path_params 获取 model,
|
||||
prepare_provider_request_body 会确保发送给 Gemini API 的请求体不含 model。
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典(不使用)
|
||||
|
||||
Returns:
|
||||
原始请求体(不合并任何 path_params)
|
||||
"""
|
||||
return original_request_body.copy()
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
path_params = path_params or {}
|
||||
is_stream = path_params.get("stream", False)
|
||||
model = path_params.get("model", "unknown")
|
||||
|
||||
try:
|
||||
if not isinstance(original_request_body, dict):
|
||||
raise ValueError("Request body must be a JSON object")
|
||||
|
||||
# Gemini 必需字段: contents
|
||||
if "contents" not in original_request_body:
|
||||
raise ValueError("Missing required field: contents")
|
||||
|
||||
request = GeminiRequest.model_validate(
|
||||
original_request_body,
|
||||
strict=False,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"请求体基本验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
request = GeminiRequest.model_construct(
|
||||
contents=original_request_body.get("contents", []),
|
||||
)
|
||||
|
||||
# 设置 model(从 path_params 获取,用于日志和审计)
|
||||
request.model = model
|
||||
# 设置 stream 属性(用于 ChatAdapterBase 判断流式模式)
|
||||
request.stream = is_stream
|
||||
return request
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
||||
"""提取消息数量"""
|
||||
contents = payload.get("contents", [])
|
||||
if hasattr(request_obj, "contents"):
|
||||
contents = request_obj.contents
|
||||
return len(contents) if isinstance(contents, list) else 0
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 Gemini Chat 特定的审计元数据"""
|
||||
role_counts: dict[str, int] = {}
|
||||
|
||||
contents = getattr(request_obj, "contents", []) or []
|
||||
for content in contents:
|
||||
role = getattr(content, "role", None) or content.get("role", "unknown")
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
generation_config = getattr(request_obj, "generation_config", None) or {}
|
||||
if hasattr(generation_config, "dict"):
|
||||
generation_config = generation_config.dict()
|
||||
elif not isinstance(generation_config, dict):
|
||||
generation_config = {}
|
||||
|
||||
# 判断流式模式
|
||||
stream = getattr(request_obj, "stream", False)
|
||||
|
||||
return {
|
||||
"action": "gemini_generate_content",
|
||||
"model": getattr(request_obj, "model", payload.get("model", "unknown")),
|
||||
"stream": bool(stream),
|
||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||
"temperature": generation_config.get("temperature"),
|
||||
"top_p": generation_config.get("top_p"),
|
||||
"top_k": generation_config.get("top_k"),
|
||||
"contents_count": len(contents),
|
||||
"content_roles": role_counts,
|
||||
"tools_count": len(getattr(request_obj, "tools", None) or []),
|
||||
"system_instruction_present": bool(getattr(request_obj, "system_instruction", None)),
|
||||
"safety_settings_count": len(getattr(request_obj, "safety_settings", None) or []),
|
||||
}
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成 Gemini 格式的错误响应"""
|
||||
# Gemini 错误响应格式
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"code": status_code,
|
||||
"message": message,
|
||||
"status": error_type.upper(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||
"""
|
||||
根据请求头构建适当的 Gemini 适配器
|
||||
|
||||
Args:
|
||||
x_app_header: X-App 请求头值
|
||||
|
||||
Returns:
|
||||
GeminiChatAdapter 实例
|
||||
"""
|
||||
# 目前只有一种 Gemini 适配器
|
||||
# 未来可以根据 x_app_header 返回不同的适配器(如 CLI 模式)
|
||||
return GeminiChatAdapter()
|
||||
|
||||
|
||||
__all__ = ["GeminiChatAdapter", "build_gemini_adapter"]
|
||||
544
src/api/handlers/gemini/converter.py
Normal file
544
src/api/handlers/gemini/converter.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""
|
||||
Gemini 格式转换器
|
||||
|
||||
提供 Gemini 与其他 API 格式(Claude、OpenAI)之间的转换
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ClaudeToGeminiConverter:
|
||||
"""
|
||||
Claude -> Gemini 请求转换器
|
||||
|
||||
将 Claude Messages API 格式转换为 Gemini generateContent 格式
|
||||
"""
|
||||
|
||||
def convert_request(self, claude_request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 请求转换为 Gemini 请求
|
||||
|
||||
Args:
|
||||
claude_request: Claude 格式的请求字典
|
||||
|
||||
Returns:
|
||||
Gemini 格式的请求字典
|
||||
"""
|
||||
gemini_request: Dict[str, Any] = {
|
||||
"contents": self._convert_messages(claude_request.get("messages", [])),
|
||||
}
|
||||
|
||||
# 转换 system prompt
|
||||
system = claude_request.get("system")
|
||||
if system:
|
||||
gemini_request["system_instruction"] = self._convert_system(system)
|
||||
|
||||
# 转换生成配置
|
||||
generation_config = self._build_generation_config(claude_request)
|
||||
if generation_config:
|
||||
gemini_request["generation_config"] = generation_config
|
||||
|
||||
# 转换工具
|
||||
tools = claude_request.get("tools")
|
||||
if tools:
|
||||
gemini_request["tools"] = self._convert_tools(tools)
|
||||
|
||||
return gemini_request
|
||||
|
||||
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换消息列表"""
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
# Gemini 使用 "model" 而不是 "assistant"
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
content = msg.get("content", "")
|
||||
parts = self._convert_content_to_parts(content)
|
||||
|
||||
contents.append(
|
||||
{
|
||||
"role": gemini_role,
|
||||
"parts": parts,
|
||||
}
|
||||
)
|
||||
return contents
|
||||
|
||||
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
|
||||
"""将 Claude 内容转换为 Gemini parts"""
|
||||
if isinstance(content, str):
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append({"text": block})
|
||||
elif isinstance(block, dict):
|
||||
block_type = block.get("type")
|
||||
if block_type == "text":
|
||||
parts.append({"text": block.get("text", "")})
|
||||
elif block_type == "image":
|
||||
# 转换图片
|
||||
source = block.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": source.get("media_type", "image/png"),
|
||||
"data": source.get("data", ""),
|
||||
}
|
||||
}
|
||||
)
|
||||
elif block_type == "tool_use":
|
||||
# 转换工具调用
|
||||
parts.append(
|
||||
{
|
||||
"function_call": {
|
||||
"name": block.get("name", ""),
|
||||
"args": block.get("input", {}),
|
||||
}
|
||||
}
|
||||
)
|
||||
elif block_type == "tool_result":
|
||||
# 转换工具结果
|
||||
parts.append(
|
||||
{
|
||||
"function_response": {
|
||||
"name": block.get("tool_use_id", ""),
|
||||
"response": {"result": block.get("content", "")},
|
||||
}
|
||||
}
|
||||
)
|
||||
return parts
|
||||
|
||||
return [{"text": str(content)}]
|
||||
|
||||
def _convert_system(self, system: Any) -> Dict[str, Any]:
|
||||
"""转换 system prompt"""
|
||||
if isinstance(system, str):
|
||||
return {"parts": [{"text": system}]}
|
||||
|
||||
if isinstance(system, list):
|
||||
parts = []
|
||||
for item in system:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append({"text": item.get("text", "")})
|
||||
return {"parts": parts}
|
||||
|
||||
return {"parts": [{"text": str(system)}]}
|
||||
|
||||
def _build_generation_config(self, claude_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建生成配置"""
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
if "max_tokens" in claude_request:
|
||||
config["max_output_tokens"] = claude_request["max_tokens"]
|
||||
if "temperature" in claude_request:
|
||||
config["temperature"] = claude_request["temperature"]
|
||||
if "top_p" in claude_request:
|
||||
config["top_p"] = claude_request["top_p"]
|
||||
if "top_k" in claude_request:
|
||||
config["top_k"] = claude_request["top_k"]
|
||||
if "stop_sequences" in claude_request:
|
||||
config["stop_sequences"] = claude_request["stop_sequences"]
|
||||
|
||||
return config if config else None
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换工具定义"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
func_decl = {
|
||||
"name": tool.get("name", ""),
|
||||
}
|
||||
if "description" in tool:
|
||||
func_decl["description"] = tool["description"]
|
||||
if "input_schema" in tool:
|
||||
func_decl["parameters"] = tool["input_schema"]
|
||||
function_declarations.append(func_decl)
|
||||
|
||||
return [{"function_declarations": function_declarations}]
|
||||
|
||||
|
||||
class GeminiToClaudeConverter:
|
||||
"""
|
||||
Gemini -> Claude 响应转换器
|
||||
|
||||
将 Gemini generateContent 响应转换为 Claude Messages API 格式
|
||||
"""
|
||||
|
||||
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Gemini 响应转换为 Claude 响应
|
||||
|
||||
Args:
|
||||
gemini_response: Gemini 格式的响应字典
|
||||
|
||||
Returns:
|
||||
Claude 格式的响应字典
|
||||
"""
|
||||
candidates = gemini_response.get("candidates", [])
|
||||
if not candidates:
|
||||
return self._create_empty_response()
|
||||
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
# 转换内容块
|
||||
claude_content = self._convert_parts_to_content(parts)
|
||||
|
||||
# 转换使用量
|
||||
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
|
||||
|
||||
# 转换停止原因
|
||||
stop_reason = self._convert_finish_reason(candidate.get("finishReason"))
|
||||
|
||||
return {
|
||||
"id": f"msg_{gemini_response.get('modelVersion', 'gemini')}",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": claude_content,
|
||||
"model": gemini_response.get("modelVersion", "gemini"),
|
||||
"stop_reason": stop_reason,
|
||||
"stop_sequence": None,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
def _convert_parts_to_content(self, parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 Gemini parts 转换为 Claude content blocks"""
|
||||
content = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": part["text"],
|
||||
}
|
||||
)
|
||||
elif "functionCall" in part:
|
||||
func_call = part["functionCall"]
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": f"toolu_{func_call.get('name', '')}",
|
||||
"name": func_call.get("name", ""),
|
||||
"input": func_call.get("args", {}),
|
||||
}
|
||||
)
|
||||
return content
|
||||
|
||||
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""转换使用量信息"""
|
||||
return {
|
||||
"input_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"output_tokens": usage_metadata.get("candidatesTokenCount", 0),
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": usage_metadata.get("cachedContentTokenCount", 0),
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "end_turn",
|
||||
"MAX_TOKENS": "max_tokens",
|
||||
"SAFETY": "content_filtered",
|
||||
"RECITATION": "content_filtered",
|
||||
"OTHER": "stop_sequence",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
def _create_empty_response(self) -> Dict[str, Any]:
|
||||
"""创建空响应"""
|
||||
return {
|
||||
"id": "msg_empty",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [],
|
||||
"model": "gemini",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAIToGeminiConverter:
|
||||
"""
|
||||
OpenAI -> Gemini 请求转换器
|
||||
|
||||
将 OpenAI Chat Completions API 格式转换为 Gemini generateContent 格式
|
||||
"""
|
||||
|
||||
def convert_request(self, openai_request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 OpenAI 请求转换为 Gemini 请求
|
||||
|
||||
Args:
|
||||
openai_request: OpenAI 格式的请求字典
|
||||
|
||||
Returns:
|
||||
Gemini 格式的请求字典
|
||||
"""
|
||||
messages = openai_request.get("messages", [])
|
||||
|
||||
# 分离 system 消息和其他消息
|
||||
system_messages = []
|
||||
other_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
system_messages.append(msg)
|
||||
else:
|
||||
other_messages.append(msg)
|
||||
|
||||
gemini_request: Dict[str, Any] = {
|
||||
"contents": self._convert_messages(other_messages),
|
||||
}
|
||||
|
||||
# 转换 system messages
|
||||
if system_messages:
|
||||
system_text = "\n".join(msg.get("content", "") for msg in system_messages)
|
||||
gemini_request["system_instruction"] = {"parts": [{"text": system_text}]}
|
||||
|
||||
# 转换生成配置
|
||||
generation_config = self._build_generation_config(openai_request)
|
||||
if generation_config:
|
||||
gemini_request["generation_config"] = generation_config
|
||||
|
||||
# 转换工具
|
||||
tools = openai_request.get("tools")
|
||||
if tools:
|
||||
gemini_request["tools"] = self._convert_tools(tools)
|
||||
|
||||
return gemini_request
|
||||
|
||||
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换消息列表"""
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
content = msg.get("content", "")
|
||||
parts = self._convert_content_to_parts(content)
|
||||
|
||||
# 处理工具调用
|
||||
tool_calls = msg.get("tool_calls", [])
|
||||
for tc in tool_calls:
|
||||
if tc.get("type") == "function":
|
||||
func = tc.get("function", {})
|
||||
import json
|
||||
|
||||
try:
|
||||
args = json.loads(func.get("arguments", "{}"))
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
parts.append(
|
||||
{
|
||||
"function_call": {
|
||||
"name": func.get("name", ""),
|
||||
"args": args,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if parts:
|
||||
contents.append(
|
||||
{
|
||||
"role": gemini_role,
|
||||
"parts": parts,
|
||||
}
|
||||
)
|
||||
return contents
|
||||
|
||||
def _convert_content_to_parts(self, content: Any) -> List[Dict[str, Any]]:
|
||||
"""将 OpenAI 内容转换为 Gemini parts"""
|
||||
if content is None:
|
||||
return []
|
||||
|
||||
if isinstance(content, str):
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
parts.append({"text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# OpenAI 图片 URL 格式
|
||||
image_url = item.get("image_url", {})
|
||||
url = image_url.get("url", "")
|
||||
if url.startswith("data:"):
|
||||
# base64 数据 URL
|
||||
# 格式: 
|
||||
try:
|
||||
header, data = url.split(",", 1)
|
||||
mime_type = header.split(":")[1].split(";")[0]
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": data,
|
||||
}
|
||||
}
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return parts
|
||||
|
||||
return [{"text": str(content)}]
|
||||
|
||||
def _build_generation_config(self, openai_request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建生成配置"""
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
if "max_tokens" in openai_request:
|
||||
config["max_output_tokens"] = openai_request["max_tokens"]
|
||||
if "temperature" in openai_request:
|
||||
config["temperature"] = openai_request["temperature"]
|
||||
if "top_p" in openai_request:
|
||||
config["top_p"] = openai_request["top_p"]
|
||||
if "stop" in openai_request:
|
||||
stop = openai_request["stop"]
|
||||
if isinstance(stop, str):
|
||||
config["stop_sequences"] = [stop]
|
||||
elif isinstance(stop, list):
|
||||
config["stop_sequences"] = stop
|
||||
if "n" in openai_request:
|
||||
config["candidate_count"] = openai_request["n"]
|
||||
|
||||
return config if config else None
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""转换工具定义"""
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool.get("function", {})
|
||||
func_decl = {
|
||||
"name": func.get("name", ""),
|
||||
}
|
||||
if "description" in func:
|
||||
func_decl["description"] = func["description"]
|
||||
if "parameters" in func:
|
||||
func_decl["parameters"] = func["parameters"]
|
||||
function_declarations.append(func_decl)
|
||||
|
||||
return [{"function_declarations": function_declarations}]
|
||||
|
||||
|
||||
class GeminiToOpenAIConverter:
|
||||
"""
|
||||
Gemini -> OpenAI 响应转换器
|
||||
|
||||
将 Gemini generateContent 响应转换为 OpenAI Chat Completions API 格式
|
||||
"""
|
||||
|
||||
def convert_response(self, gemini_response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Gemini 响应转换为 OpenAI 响应
|
||||
|
||||
Args:
|
||||
gemini_response: Gemini 格式的响应字典
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的响应字典
|
||||
"""
|
||||
import time
|
||||
|
||||
candidates = gemini_response.get("candidates", [])
|
||||
choices = []
|
||||
|
||||
for i, candidate in enumerate(candidates):
|
||||
content = candidate.get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
# 提取文本内容
|
||||
text_parts = []
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
elif "functionCall" in part:
|
||||
func_call = part["functionCall"]
|
||||
import json
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{func_call.get('name', '')}_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": func_call.get("name", ""),
|
||||
"arguments": json.dumps(func_call.get("args", {})),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
message: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": "".join(text_parts) if text_parts else None,
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
|
||||
finish_reason = self._convert_finish_reason(candidate.get("finishReason"))
|
||||
|
||||
choices.append(
|
||||
{
|
||||
"index": i,
|
||||
"message": message,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
)
|
||||
|
||||
# 转换使用量
|
||||
usage = self._convert_usage(gemini_response.get("usageMetadata", {}))
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{gemini_response.get('modelVersion', 'gemini')}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": gemini_response.get("modelVersion", "gemini"),
|
||||
"choices": choices,
|
||||
"usage": usage,
|
||||
}
|
||||
|
||||
def _convert_usage(self, usage_metadata: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""转换使用量信息"""
|
||||
prompt_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||
completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
return {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "stop",
|
||||
"MAX_TOKENS": "length",
|
||||
"SAFETY": "content_filter",
|
||||
"RECITATION": "content_filter",
|
||||
"OTHER": "stop",
|
||||
}
|
||||
return mapping.get(finish_reason, "stop")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ClaudeToGeminiConverter",
|
||||
"GeminiToClaudeConverter",
|
||||
"OpenAIToGeminiConverter",
|
||||
"GeminiToOpenAIConverter",
|
||||
]
|
||||
164
src/api/handlers/gemini/handler.py
Normal file
164
src/api/handlers/gemini/handler.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Gemini Chat Handler
|
||||
|
||||
处理 Gemini API 格式的请求
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class GeminiChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
Gemini Chat Handler - 处理 Google Gemini API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 promptTokenCount / candidatesTokenCount
|
||||
- 支持 cachedContentTokenCount
|
||||
- 请求格式: GeminiRequest
|
||||
- 响应格式: JSON 数组流(非 SSE)
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Gemini Chat 格式实现
|
||||
|
||||
Gemini Chat 模式下,model 在请求体中(经过转换后的 GeminiRequest)。
|
||||
与 Gemini CLI 不同,CLI 模式的 model 在 URL 路径中。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(Chat 模式通常不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
# 优先从请求体获取,其次从 path_params
|
||||
model = request_body.get("model")
|
||||
if model:
|
||||
return str(model)
|
||||
if path_params and "model" in path_params:
|
||||
return str(path_params["model"])
|
||||
return "unknown"
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 Gemini 格式
|
||||
|
||||
支持自动转换:
|
||||
- Claude 格式 → Gemini 格式
|
||||
- OpenAI 格式 → Gemini 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象(可能是 Gemini/Claude/OpenAI 格式)
|
||||
|
||||
Returns:
|
||||
GeminiRequest 对象
|
||||
"""
|
||||
from src.api.handlers.gemini.converter import (
|
||||
ClaudeToGeminiConverter,
|
||||
OpenAIToGeminiConverter,
|
||||
)
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.gemini import GeminiRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 Gemini 格式,直接返回
|
||||
if isinstance(request, GeminiRequest):
|
||||
return request
|
||||
|
||||
# 如果是 Claude 格式,转换为 Gemini 格式
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
converter = ClaudeToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request.model_dump())
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 如果是 OpenAI 格式,转换为 Gemini 格式
|
||||
if isinstance(request, OpenAIRequest):
|
||||
converter = OpenAIToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request.model_dump())
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 如果是字典,根据内容判断格式并转换
|
||||
if isinstance(request, dict):
|
||||
# 检测 Gemini 格式特征: contents 字段
|
||||
if "contents" in request:
|
||||
return GeminiRequest(**request)
|
||||
|
||||
# 检测 Claude 格式特征: messages + 没有 choices
|
||||
if "messages" in request and "choices" not in request:
|
||||
# 进一步区分 Claude 和 OpenAI
|
||||
# Claude 使用 max_tokens,OpenAI 也可能有
|
||||
# Claude 的 messages[].content 可以是数组,OpenAI 通常是字符串
|
||||
messages = request.get("messages", [])
|
||||
if messages and isinstance(messages[0].get("content"), list):
|
||||
# 可能是 Claude 格式
|
||||
converter = ClaudeToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request)
|
||||
return GeminiRequest(**gemini_dict)
|
||||
else:
|
||||
# 可能是 OpenAI 格式
|
||||
converter = OpenAIToGeminiConverter()
|
||||
gemini_dict = converter.convert_request(request)
|
||||
return GeminiRequest(**gemini_dict)
|
||||
|
||||
# 默认尝试作为 Gemini 格式
|
||||
return GeminiRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 响应中提取 token 使用情况
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
"""
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
usage = GeminiStreamParser().extract_usage(response)
|
||||
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_input_tokens": 0, # Gemini 不区分缓存创建
|
||||
"cache_read_input_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 Gemini 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
|
||||
TODO: 如果需要,实现响应规范化逻辑
|
||||
"""
|
||||
# 可选:使用 response_normalizer 进行规范化
|
||||
# if (
|
||||
# self.response_normalizer
|
||||
# and self.response_normalizer.should_normalize(response)
|
||||
# ):
|
||||
# return self.response_normalizer.normalize_gemini_response(
|
||||
# response_data=response,
|
||||
# request_id=self.request_id,
|
||||
# strict=False,
|
||||
# )
|
||||
return response
|
||||
307
src/api/handlers/gemini/stream_parser.py
Normal file
307
src/api/handlers/gemini/stream_parser.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Gemini SSE/JSON 流解析器
|
||||
|
||||
Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
||||
- 使用 JSON 数组格式 (不是 SSE)
|
||||
- 每个块是一个完整的 JSON 对象
|
||||
- 响应以 [ 开始,以 ] 结束,块之间用 , 分隔
|
||||
|
||||
参考: https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class GeminiStreamParser:
|
||||
"""
|
||||
Gemini 流解析器
|
||||
|
||||
解析 Gemini streamGenerateContent API 的响应流。
|
||||
|
||||
Gemini 流式响应特点:
|
||||
- 返回 JSON 数组格式: [{chunk1}, {chunk2}, ...]
|
||||
- 每个 chunk 包含 candidates、usageMetadata 等字段
|
||||
- finish_reason 可能值: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
"""
|
||||
|
||||
# 停止原因
|
||||
FINISH_REASON_STOP = "STOP"
|
||||
FINISH_REASON_MAX_TOKENS = "MAX_TOKENS"
|
||||
FINISH_REASON_SAFETY = "SAFETY"
|
||||
FINISH_REASON_RECITATION = "RECITATION"
|
||||
FINISH_REASON_OTHER = "OTHER"
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def reset(self):
|
||||
"""重置解析器状态"""
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析流式数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的事件列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
events: List[Dict[str, Any]] = []
|
||||
|
||||
for char in text:
|
||||
if char == "[" and not self._in_array:
|
||||
self._in_array = True
|
||||
continue
|
||||
|
||||
if char == "]" and self._in_array and self._brace_depth == 0:
|
||||
# 数组结束
|
||||
self._in_array = False
|
||||
if self._buffer.strip():
|
||||
try:
|
||||
obj = json.loads(self._buffer.strip().rstrip(","))
|
||||
events.append(obj)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
self._buffer = ""
|
||||
continue
|
||||
|
||||
if self._in_array:
|
||||
if char == "{":
|
||||
self._brace_depth += 1
|
||||
elif char == "}":
|
||||
self._brace_depth -= 1
|
||||
|
||||
self._buffer += char
|
||||
|
||||
# 当 brace_depth 回到 0 时,说明一个完整的 JSON 对象结束
|
||||
if self._brace_depth == 0 and self._buffer.strip():
|
||||
try:
|
||||
obj = json.loads(self._buffer.strip().rstrip(","))
|
||||
events.append(obj)
|
||||
self._buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# 可能还不完整,继续累积
|
||||
pass
|
||||
|
||||
return events
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 JSON 数据
|
||||
|
||||
Args:
|
||||
line: JSON 数据行
|
||||
|
||||
Returns:
|
||||
解析后的事件字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line.strip() in ["[", "]", ","]:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line.strip().rstrip(","))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束事件
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束事件
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
for candidate in candidates:
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if finish_reason in (
|
||||
self.FINISH_REASON_STOP,
|
||||
self.FINISH_REASON_MAX_TOKENS,
|
||||
self.FINISH_REASON_SAFETY,
|
||||
self.FINISH_REASON_RECITATION,
|
||||
self.FINISH_REASON_OTHER,
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_error_event(self, event: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为错误事件
|
||||
|
||||
检测多种 Gemini 错误格式:
|
||||
1. 顶层 error: {"error": {...}}
|
||||
2. chunks 内嵌套 error: {"chunks": [{"error": {...}}]}
|
||||
3. candidates 内的错误状态
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
True 如果是错误事件
|
||||
"""
|
||||
# 顶层 error
|
||||
if "error" in event:
|
||||
return True
|
||||
|
||||
# chunks 内嵌套 error (某些 Gemini 响应格式)
|
||||
chunks = event.get("chunks", [])
|
||||
if chunks and isinstance(chunks, list):
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def extract_error_info(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
从事件中提取错误信息
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
错误信息字典 {"code": int, "message": str, "status": str},无错误返回 None
|
||||
"""
|
||||
# 顶层 error
|
||||
if "error" in event:
|
||||
error = event["error"]
|
||||
if isinstance(error, dict):
|
||||
return {
|
||||
"code": error.get("code"),
|
||||
"message": error.get("message", str(error)),
|
||||
"status": error.get("status"),
|
||||
}
|
||||
return {"code": None, "message": str(error), "status": None}
|
||||
|
||||
# chunks 内嵌套 error
|
||||
chunks = event.get("chunks", [])
|
||||
if chunks and isinstance(chunks, list):
|
||||
for chunk in chunks:
|
||||
if isinstance(chunk, dict) and "error" in chunk:
|
||||
error = chunk["error"]
|
||||
if isinstance(error, dict):
|
||||
return {
|
||||
"code": error.get("code"),
|
||||
"message": error.get("message", str(error)),
|
||||
"status": error.get("status"),
|
||||
}
|
||||
return {"code": None, "message": str(error), "status": None}
|
||||
|
||||
return None
|
||||
|
||||
def get_finish_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取结束原因
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
结束原因字符串
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
return candidates[0].get("finishReason")
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从响应中提取文本内容
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
文本内容,如果没有文本返回 None
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
content = candidates[0].get("content", {})
|
||||
parts = content.get("parts", [])
|
||||
|
||||
text_parts = []
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
|
||||
return "".join(text_parts) if text_parts else None
|
||||
|
||||
def extract_usage(self, event: Dict[str, Any]) -> Optional[Dict[str, int]]:
|
||||
"""
|
||||
从事件中提取 token 使用量
|
||||
|
||||
这是 Gemini token 提取的单一实现源,其他地方都应该调用此方法。
|
||||
|
||||
Args:
|
||||
event: 事件字典(包含 usageMetadata)
|
||||
|
||||
Returns:
|
||||
使用量字典,如果没有完整的使用量信息返回 None
|
||||
|
||||
注意:
|
||||
- 只有当 totalTokenCount 存在时才提取(确保是完整的 usage 数据)
|
||||
- 输出 token = thoughtsTokenCount + candidatesTokenCount
|
||||
"""
|
||||
usage_metadata = event.get("usageMetadata", {})
|
||||
if not usage_metadata or "totalTokenCount" not in usage_metadata:
|
||||
return None
|
||||
|
||||
# 输出 token = thoughtsTokenCount + candidatesTokenCount
|
||||
thoughts_tokens = usage_metadata.get("thoughtsTokenCount", 0)
|
||||
candidates_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
output_tokens = thoughts_tokens + candidates_tokens
|
||||
|
||||
return {
|
||||
"input_tokens": usage_metadata.get("promptTokenCount", 0),
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": usage_metadata.get("totalTokenCount", 0),
|
||||
"cached_tokens": usage_metadata.get("cachedContentTokenCount", 0),
|
||||
}
|
||||
|
||||
def extract_model_version(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从响应中提取模型版本
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
模型版本,如果没有返回 None
|
||||
"""
|
||||
return event.get("modelVersion")
|
||||
|
||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从响应中提取安全评级
|
||||
|
||||
Args:
|
||||
event: 事件字典
|
||||
|
||||
Returns:
|
||||
安全评级列表,如果没有返回 None
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0].get("safetyRatings")
|
||||
|
||||
|
||||
__all__ = ["GeminiStreamParser"]
|
||||
12
src/api/handlers/gemini_cli/__init__.py
Normal file
12
src/api/handlers/gemini_cli/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Gemini CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.gemini_cli.adapter import GeminiCliAdapter, build_gemini_cli_adapter
|
||||
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"GeminiCliAdapter",
|
||||
"GeminiCliMessageHandler",
|
||||
"build_gemini_cli_adapter",
|
||||
]
|
||||
112
src/api/handlers/gemini_cli/adapter.py
Normal file
112
src/api/handlers/gemini_cli/adapter.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
||||
|
||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class GeminiCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
Gemini CLI API 适配器
|
||||
|
||||
处理 Gemini CLI 格式的请求(透传模式,最小验证)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
name = "gemini.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.gemini_cli.handler import GeminiCliMessageHandler
|
||||
|
||||
return GeminiCliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["GEMINI_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (x-goog-api-key)"""
|
||||
return request.headers.get("x-goog-api-key")
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any] # noqa: ARG002
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - Gemini CLI 特化版本
|
||||
|
||||
Gemini API 特点:
|
||||
- model 不合并到请求体(Gemini 原生请求体不含 model,通过 URL 路径传递)
|
||||
- stream 不合并到请求体(Gemini API 通过 URL 端点区分流式/非流式)
|
||||
|
||||
基类已经从 path_params 获取 model 和 stream 用于日志和路由判断。
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典(包含 model、stream 等)
|
||||
|
||||
Returns:
|
||||
原始请求体(不合并任何 path_params)
|
||||
"""
|
||||
# Gemini: 不合并任何 path_params 到请求体
|
||||
return original_request_body.copy()
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""Gemini CLI 使用 contents 字段"""
|
||||
contents = payload.get("contents", [])
|
||||
return len(contents) if isinstance(contents, list) else 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gemini CLI 特定的审计元数据"""
|
||||
# 从 path_params 获取 model(Gemini 请求体不含 model)
|
||||
model = path_params.get("model", "unknown") if path_params else "unknown"
|
||||
contents = payload.get("contents", [])
|
||||
generation_config = payload.get("generation_config", {}) or {}
|
||||
|
||||
role_counts: Dict[str, int] = {}
|
||||
for content in contents:
|
||||
role = content.get("role", "unknown") if isinstance(content, dict) else "unknown"
|
||||
role_counts[role] = role_counts.get(role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "gemini_cli_request",
|
||||
"model": model,
|
||||
"stream": bool(payload.get("stream", False)),
|
||||
"max_output_tokens": generation_config.get("max_output_tokens"),
|
||||
"contents_count": len(contents),
|
||||
"content_roles": role_counts,
|
||||
"temperature": generation_config.get("temperature"),
|
||||
"top_p": generation_config.get("top_p"),
|
||||
"top_k": generation_config.get("top_k"),
|
||||
"tools_count": len(payload.get("tools") or []),
|
||||
"system_instruction_present": bool(payload.get("system_instruction")),
|
||||
"safety_settings_count": len(payload.get("safety_settings") or []),
|
||||
}
|
||||
|
||||
|
||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||
"""
|
||||
构建 Gemini CLI 适配器
|
||||
|
||||
Args:
|
||||
x_app_header: X-App 请求头值(预留扩展)
|
||||
|
||||
Returns:
|
||||
GeminiCliAdapter 实例
|
||||
"""
|
||||
return GeminiCliAdapter()
|
||||
|
||||
|
||||
__all__ = ["GeminiCliAdapter", "build_gemini_cli_adapter"]
|
||||
210
src/api/handlers/gemini_cli/handler.py
Normal file
210
src/api/handlers/gemini_cli/handler.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Gemini CLI Message Handler - 基于通用 CLI Handler 基类的实现
|
||||
|
||||
继承 CliMessageHandlerBase,处理 Gemini CLI API 格式的请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class GeminiCliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
Gemini CLI Message Handler - 处理 Gemini CLI API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- Gemini 使用 JSON 数组格式流式响应(非 SSE)
|
||||
- 每个 chunk 包含 candidates、usageMetadata 等字段
|
||||
- finish_reason: STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
|
||||
- Token 使用: promptTokenCount (输入), thoughtsTokenCount + candidatesTokenCount (输出), cachedContentTokenCount (缓存)
|
||||
|
||||
Gemini API 特殊处理:
|
||||
- model 在 URL 路径中而非请求体,如 /v1beta/models/{model}:generateContent
|
||||
- 请求体中的 model 字段用于内部路由,不发送给 API
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any], # noqa: ARG002 - 基类签名要求
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - Gemini 格式实现
|
||||
|
||||
Gemini API 的 model 在 URL 路径中而非请求体:
|
||||
/v1beta/models/{model}:generateContent
|
||||
|
||||
Args:
|
||||
request_body: 请求体(Gemini 不包含 model)
|
||||
path_params: URL 路径参数(包含 model)
|
||||
|
||||
Returns:
|
||||
模型名,如果无法提取则返回 "unknown"
|
||||
"""
|
||||
# Gemini: model 从 URL 路径参数获取
|
||||
if path_params and "model" in path_params:
|
||||
return str(path_params["model"])
|
||||
return "unknown"
|
||||
|
||||
def prepare_provider_request_body(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
准备发送给 Gemini API 的请求体 - 移除 model 字段
|
||||
|
||||
Gemini API 要求 model 只在 URL 路径中,请求体中的 model 字段
|
||||
会导致某些代理返回 404 错误。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
|
||||
Returns:
|
||||
不含 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result.pop("model", None)
|
||||
return result
|
||||
|
||||
def get_model_for_url(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Gemini 需要将 model 放入 URL 路径中
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
mapped_model: 映射后的模型名(如果有)
|
||||
|
||||
Returns:
|
||||
用于 URL 路径的模型名
|
||||
"""
|
||||
# 优先使用映射后的模型名,否则使用请求体中的
|
||||
return mapped_model or request_body.get("model")
|
||||
|
||||
def _extract_usage_from_event(self, event: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""
|
||||
从 Gemini 事件中提取 token 使用情况
|
||||
|
||||
调用 GeminiStreamParser.extract_usage 作为单一实现源
|
||||
|
||||
Args:
|
||||
event: Gemini 流式响应事件
|
||||
|
||||
Returns:
|
||||
包含 input_tokens, output_tokens, cached_tokens 的字典
|
||||
"""
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
usage = GeminiStreamParser().extract_usage(event)
|
||||
|
||||
if not usage:
|
||||
return {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cached_tokens": 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cached_tokens": usage.get("cached_tokens", 0),
|
||||
}
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
_event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 Gemini CLI 格式的流式事件
|
||||
|
||||
Gemini 的流式响应是 JSON 数组格式,每个元素结构如下:
|
||||
{
|
||||
"candidates": [{
|
||||
"content": {"parts": [{"text": "..."}], "role": "model"},
|
||||
"finishReason": "STOP",
|
||||
"safetyRatings": [...]
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 10,
|
||||
"candidatesTokenCount": 20,
|
||||
"totalTokenCount": 30,
|
||||
"cachedContentTokenCount": 5
|
||||
},
|
||||
"modelVersion": "gemini-1.5-pro"
|
||||
}
|
||||
|
||||
注意: Gemini 流解析器会将每个 JSON 对象作为一个"事件"传递
|
||||
event_type 在这里可能为空或是自定义的标记
|
||||
"""
|
||||
# 提取候选响应
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates:
|
||||
candidate = candidates[0]
|
||||
content = candidate.get("content", {})
|
||||
|
||||
# 提取文本内容
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
ctx.collected_text += part["text"]
|
||||
|
||||
# 检查结束原因
|
||||
finish_reason = candidate.get("finishReason")
|
||||
if finish_reason in ("STOP", "MAX_TOKENS", "SAFETY", "RECITATION", "OTHER"):
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
# 提取使用量信息(复用 GeminiStreamParser.extract_usage)
|
||||
usage = self._extract_usage_from_event(data)
|
||||
if usage["input_tokens"] > 0 or usage["output_tokens"] > 0:
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
ctx.cached_tokens = usage["cached_tokens"]
|
||||
|
||||
# 提取模型版本作为响应 ID
|
||||
model_version = data.get("modelVersion")
|
||||
if model_version:
|
||||
if not ctx.response_id:
|
||||
ctx.response_id = f"gemini-{model_version}"
|
||||
# 存储到 response_metadata 供 Usage 记录使用
|
||||
ctx.response_metadata["model_version"] = model_version
|
||||
|
||||
# 检查错误
|
||||
if "error" in data:
|
||||
ctx.has_completion = True
|
||||
ctx.final_response = data
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 Gemini 响应中提取元数据
|
||||
|
||||
提取 modelVersion 字段,记录实际使用的模型版本。
|
||||
|
||||
Args:
|
||||
response: Gemini API 响应
|
||||
|
||||
Returns:
|
||||
包含 model_version 的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
model_version = response.get("modelVersion")
|
||||
if model_version:
|
||||
metadata["model_version"] = model_version
|
||||
return metadata
|
||||
11
src/api/handlers/openai/__init__.py
Normal file
11
src/api/handlers/openai/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
OpenAI Chat API 处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.openai.adapter import OpenAIChatAdapter
|
||||
from src.api.handlers.openai.handler import OpenAIChatHandler
|
||||
|
||||
__all__ = [
|
||||
"OpenAIChatAdapter",
|
||||
"OpenAIChatHandler",
|
||||
]
|
||||
109
src/api/handlers/openai/adapter.py
Normal file
109
src/api/handlers/openai/adapter.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
||||
|
||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.logger import logger
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
|
||||
@register_adapter
|
||||
class OpenAIChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
OpenAI Chat Completions API 适配器
|
||||
|
||||
处理 OpenAI Chat 格式的请求(/v1/chat/completions 端点)。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
name = "openai.chat"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.openai.handler import OpenAIChatHandler
|
||||
|
||||
return OpenAIChatHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["OPENAI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""验证请求体"""
|
||||
if not isinstance(original_request_body, dict):
|
||||
return self._error_response(
|
||||
400, "Request body must be a JSON object", "invalid_request_error"
|
||||
)
|
||||
|
||||
required_fields = ["model", "messages"]
|
||||
missing = [f for f in required_fields if f not in original_request_body]
|
||||
if missing:
|
||||
return self._error_response(
|
||||
400,
|
||||
f"Missing required fields: {', '.join(missing)}",
|
||||
"invalid_request_error",
|
||||
)
|
||||
|
||||
try:
|
||||
return OpenAIRequest.model_validate(original_request_body, strict=False)
|
||||
except ValueError as e:
|
||||
return self._error_response(400, str(e), "invalid_request_error")
|
||||
except Exception as e:
|
||||
logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
|
||||
return OpenAIRequest.model_construct(
|
||||
model=original_request_body.get("model"),
|
||||
messages=original_request_body.get("messages", []),
|
||||
stream=original_request_body.get("stream", False),
|
||||
max_tokens=original_request_body.get("max_tokens"),
|
||||
)
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""构建 OpenAI Chat 特定的审计元数据"""
|
||||
role_counts = {}
|
||||
for message in request_obj.messages:
|
||||
role_counts[message.role] = role_counts.get(message.role, 0) + 1
|
||||
|
||||
return {
|
||||
"action": "openai_chat_completion",
|
||||
"model": request_obj.model,
|
||||
"stream": bool(request_obj.stream),
|
||||
"max_tokens": request_obj.max_tokens,
|
||||
"temperature": request_obj.temperature,
|
||||
"top_p": request_obj.top_p,
|
||||
"messages_count": len(request_obj.messages),
|
||||
"message_roles": role_counts,
|
||||
"tools_count": len(request_obj.tools or []),
|
||||
"response_format": bool(request_obj.response_format),
|
||||
"user_identifier": request_obj.user,
|
||||
}
|
||||
|
||||
def _error_response(self, status_code: int, message: str, error_type: str) -> JSONResponse:
|
||||
"""生成 OpenAI 格式的错误响应"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
"code": status_code,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["OpenAIChatAdapter"]
|
||||
424
src/api/handlers/openai/converter.py
Normal file
424
src/api/handlers/openai/converter.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Claude -> OpenAI 格式转换器
|
||||
|
||||
将 Claude Messages API 格式转换为 OpenAI Chat Completions API 格式。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ClaudeToOpenAIConverter:
|
||||
"""
|
||||
Claude -> OpenAI 格式转换器
|
||||
|
||||
支持:
|
||||
- 请求转换:Claude Request -> OpenAI Chat Request
|
||||
- 响应转换:Claude Response -> OpenAI Chat Response
|
||||
- 流式转换:Claude SSE -> OpenAI SSE
|
||||
"""
|
||||
|
||||
# 内容类型常量
|
||||
CONTENT_TYPE_TEXT = "text"
|
||||
CONTENT_TYPE_IMAGE = "image"
|
||||
CONTENT_TYPE_TOOL_USE = "tool_use"
|
||||
CONTENT_TYPE_TOOL_RESULT = "tool_result"
|
||||
|
||||
# 停止原因映射
|
||||
STOP_REASON_MAP = {
|
||||
"end_turn": "stop",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
}
|
||||
|
||||
def __init__(self, model_mapping: Optional[Dict[str, str]] = None):
|
||||
"""
|
||||
Args:
|
||||
model_mapping: Claude 模型到 OpenAI 模型的映射
|
||||
"""
|
||||
self._model_mapping = model_mapping or {}
|
||||
|
||||
# ==================== 请求转换 ====================
|
||||
|
||||
def convert_request(self, request: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 请求转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
request: Claude 请求(Dict 或 Pydantic 模型)
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的请求字典
|
||||
"""
|
||||
if hasattr(request, "model_dump"):
|
||||
data = request.model_dump(exclude_none=True)
|
||||
else:
|
||||
data = dict(request)
|
||||
|
||||
# 模型映射
|
||||
model = data.get("model", "")
|
||||
openai_model = self._model_mapping.get(model, model)
|
||||
|
||||
# 构建消息列表
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
# 处理 system 消息
|
||||
system_content = self._extract_text_content(data.get("system"))
|
||||
if system_content:
|
||||
messages.append({"role": "system", "content": system_content})
|
||||
|
||||
# 处理对话消息
|
||||
for message in data.get("messages", []):
|
||||
converted = self._convert_message(message)
|
||||
if converted:
|
||||
messages.append(converted)
|
||||
|
||||
# 构建 OpenAI 请求
|
||||
result: Dict[str, Any] = {
|
||||
"model": openai_model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
# 可选参数
|
||||
if data.get("max_tokens"):
|
||||
result["max_tokens"] = data["max_tokens"]
|
||||
if data.get("temperature") is not None:
|
||||
result["temperature"] = data["temperature"]
|
||||
if data.get("top_p") is not None:
|
||||
result["top_p"] = data["top_p"]
|
||||
if data.get("stream"):
|
||||
result["stream"] = data["stream"]
|
||||
if data.get("stop_sequences"):
|
||||
result["stop"] = data["stop_sequences"]
|
||||
|
||||
# 工具转换
|
||||
tools = self._convert_tools(data.get("tools"))
|
||||
if tools:
|
||||
result["tools"] = tools
|
||||
|
||||
tool_choice = self._convert_tool_choice(data.get("tool_choice"))
|
||||
if tool_choice:
|
||||
result["tool_choice"] = tool_choice
|
||||
|
||||
return result
|
||||
|
||||
def _convert_message(self, message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""转换单条消息"""
|
||||
role = message.get("role")
|
||||
|
||||
if role == "user":
|
||||
return self._convert_user_message(message)
|
||||
if role == "assistant":
|
||||
return self._convert_assistant_message(message)
|
||||
|
||||
return None
|
||||
|
||||
def _convert_user_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换用户消息"""
|
||||
content = message.get("content")
|
||||
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
openai_content: List[Dict[str, Any]] = []
|
||||
for block in content or []:
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
openai_content.append({"type": "text", "text": block.get("text", "")})
|
||||
elif block_type == self.CONTENT_TYPE_IMAGE:
|
||||
source = block.get("source", {})
|
||||
media_type = source.get("media_type", "image/jpeg")
|
||||
data = source.get("data", "")
|
||||
openai_content.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}}
|
||||
)
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_RESULT:
|
||||
tool_content = block.get("content", "")
|
||||
rendered = self._render_tool_content(tool_content)
|
||||
openai_content.append({"type": "text", "text": f"Tool result: {rendered}"})
|
||||
|
||||
# 简化单文本内容
|
||||
if len(openai_content) == 1 and openai_content[0]["type"] == "text":
|
||||
return {"role": "user", "content": openai_content[0]["text"]}
|
||||
|
||||
return {"role": "user", "content": openai_content or ""}
|
||||
|
||||
def _convert_assistant_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""转换助手消息"""
|
||||
content = message.get("content")
|
||||
text_parts: List[str] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
|
||||
if isinstance(content, str):
|
||||
text_parts.append(content)
|
||||
else:
|
||||
for idx, block in enumerate(content or []):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
text_parts.append(block.get("text", ""))
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_USE:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.get("id", f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.get("name", ""),
|
||||
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result: Dict[str, Any] = {"role": "assistant"}
|
||||
|
||||
message_content = "\n".join([p for p in text_parts if p]) or None
|
||||
if message_content:
|
||||
result["content"] = message_content
|
||||
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tools(
|
||||
self, tools: Optional[List[Dict[str, Any]]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""转换工具定义"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
result.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description"),
|
||||
"parameters": tool.get("input_schema", {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _convert_tool_choice(
|
||||
self, tool_choice: Optional[Dict[str, Any]]
|
||||
) -> Optional[Union[str, Dict[str, Any]]]:
|
||||
"""转换工具选择"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
choice_type = tool_choice.get("type")
|
||||
if choice_type in ("tool", "tool_use"):
|
||||
return {"type": "function", "function": {"name": tool_choice.get("name", "")}}
|
||||
if choice_type == "any":
|
||||
return "required"
|
||||
if choice_type == "auto":
|
||||
return "auto"
|
||||
|
||||
return tool_choice
|
||||
|
||||
# ==================== 响应转换 ====================
|
||||
|
||||
def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Claude 响应转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
response: Claude 响应字典
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的响应字典
|
||||
"""
|
||||
# 提取内容
|
||||
content_parts: List[str] = []
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
|
||||
for idx, block in enumerate(response.get("content", [])):
|
||||
block_type = block.get("type")
|
||||
|
||||
if block_type == self.CONTENT_TYPE_TEXT:
|
||||
content_parts.append(block.get("text", ""))
|
||||
elif block_type == self.CONTENT_TYPE_TOOL_USE:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.get("id", f"call_{idx}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.get("name", ""),
|
||||
"arguments": json.dumps(block.get("input", {}), ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 构建消息
|
||||
message: Dict[str, Any] = {"role": "assistant"}
|
||||
text_content = "\n".join([p for p in content_parts if p]) or None
|
||||
if text_content:
|
||||
message["content"] = text_content
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
|
||||
# 转换停止原因
|
||||
stop_reason = response.get("stop_reason")
|
||||
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
|
||||
|
||||
# 转换 usage
|
||||
usage = response.get("usage", {})
|
||||
openai_usage = {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": (usage.get("input_tokens", 0) + usage.get("output_tokens", 0)),
|
||||
}
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{response.get('id', uuid.uuid4().hex[:8])}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": response.get("model", ""),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": message,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
"usage": openai_usage,
|
||||
}
|
||||
|
||||
# ==================== 流式转换 ====================
|
||||
|
||||
def convert_stream_event(
|
||||
self,
|
||||
event: Dict[str, Any],
|
||||
model: str = "",
|
||||
message_id: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
将 Claude SSE 事件转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
event: Claude SSE 事件
|
||||
model: 模型名称
|
||||
message_id: 消息 ID
|
||||
|
||||
Returns:
|
||||
OpenAI 格式的 SSE chunk,如果无法转换返回 None
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
chunk_id = f"chatcmpl-{(message_id or 'stream')[-8:]}"
|
||||
|
||||
if event_type == "message_start":
|
||||
message = event.get("message", {})
|
||||
return self._base_chunk(
|
||||
chunk_id,
|
||||
model or message.get("model", ""),
|
||||
{"role": "assistant"},
|
||||
)
|
||||
|
||||
if event_type == "content_block_start":
|
||||
content_block = event.get("content_block", {})
|
||||
if content_block.get("type") == self.CONTENT_TYPE_TOOL_USE:
|
||||
delta = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": event.get("index", 0),
|
||||
"id": content_block.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": content_block.get("name", ""),
|
||||
"arguments": "",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
return None
|
||||
|
||||
if event_type == "content_block_delta":
|
||||
delta_payload = event.get("delta", {})
|
||||
delta_type = delta_payload.get("type")
|
||||
|
||||
if delta_type == "text_delta":
|
||||
delta = {"content": delta_payload.get("text", "")}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
|
||||
if delta_type == "input_json_delta":
|
||||
delta = {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": event.get("index", 0),
|
||||
"function": {"arguments": delta_payload.get("partial_json", "")},
|
||||
}
|
||||
]
|
||||
}
|
||||
return self._base_chunk(chunk_id, model, delta)
|
||||
return None
|
||||
|
||||
if event_type == "message_delta":
|
||||
delta = event.get("delta", {})
|
||||
stop_reason = delta.get("stop_reason")
|
||||
finish_reason = self.STOP_REASON_MAP.get(stop_reason, stop_reason)
|
||||
return self._base_chunk(chunk_id, model, {}, finish_reason=finish_reason)
|
||||
|
||||
if event_type == "message_stop":
|
||||
return self._base_chunk(chunk_id, model, {}, finish_reason="stop")
|
||||
|
||||
return None
|
||||
|
||||
def _base_chunk(
|
||||
self,
|
||||
chunk_id: str,
|
||||
model: str,
|
||||
delta: Dict[str, Any],
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建基础 OpenAI chunk"""
|
||||
return {
|
||||
"id": chunk_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"system_fingerprint": None,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def _extract_text_content(
|
||||
self, content: Optional[Union[str, List[Dict[str, Any]]]]
|
||||
) -> Optional[str]:
|
||||
"""提取文本内容"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
block.get("text", "")
|
||||
for block in content
|
||||
if block.get("type") == self.CONTENT_TYPE_TEXT
|
||||
]
|
||||
return "\n\n".join(filter(None, parts)) or None
|
||||
return None
|
||||
|
||||
def _render_tool_content(self, tool_content: Any) -> str:
|
||||
"""渲染工具内容"""
|
||||
if isinstance(tool_content, list):
|
||||
return json.dumps(tool_content, ensure_ascii=False)
|
||||
return str(tool_content)
|
||||
|
||||
|
||||
__all__ = ["ClaudeToOpenAIConverter"]
|
||||
137
src/api/handlers/openai/handler.py
Normal file
137
src/api/handlers/openai/handler.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
OpenAI Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
|
||||
继承 ChatHandlerBase,只需覆盖格式特定的方法。
|
||||
代码量从原来的 ~1315 行减少到 ~100 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
|
||||
|
||||
class OpenAIChatHandler(ChatHandlerBase):
|
||||
"""
|
||||
OpenAI Chat Handler - 处理 OpenAI Chat Completions API 格式的请求
|
||||
|
||||
格式特点:
|
||||
- 使用 prompt_tokens/completion_tokens
|
||||
- 不支持 cache tokens
|
||||
- 请求格式:OpenAIRequest
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - OpenAI 格式实现
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(OpenAI 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将映射后的模型名应用到请求体
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
"""
|
||||
将请求转换为 OpenAI 格式
|
||||
|
||||
Args:
|
||||
request: 原始请求对象
|
||||
|
||||
Returns:
|
||||
OpenAIRequest 对象
|
||||
"""
|
||||
from src.api.handlers.openai.converter import ClaudeToOpenAIConverter
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.openai import OpenAIRequest
|
||||
|
||||
# 如果已经是 OpenAI 格式,直接返回
|
||||
if isinstance(request, OpenAIRequest):
|
||||
return request
|
||||
|
||||
# 如果是 Claude 格式,转换为 OpenAI 格式
|
||||
if isinstance(request, ClaudeMessagesRequest):
|
||||
converter = ClaudeToOpenAIConverter()
|
||||
openai_dict = converter.convert_request(request.dict())
|
||||
return OpenAIRequest(**openai_dict)
|
||||
|
||||
# 如果是字典,尝试判断格式
|
||||
if isinstance(request, dict):
|
||||
try:
|
||||
return OpenAIRequest(**request)
|
||||
except Exception:
|
||||
try:
|
||||
converter = ClaudeToOpenAIConverter()
|
||||
openai_dict = converter.convert_request(request)
|
||||
return OpenAIRequest(**openai_dict)
|
||||
except Exception:
|
||||
return OpenAIRequest(**request)
|
||||
|
||||
return request
|
||||
|
||||
def _extract_usage(self, response: Dict) -> Dict[str, int]:
|
||||
"""
|
||||
从 OpenAI 响应中提取 token 使用情况
|
||||
|
||||
OpenAI 格式使用:
|
||||
- prompt_tokens / completion_tokens
|
||||
- 不支持 cache tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("prompt_tokens", 0),
|
||||
"output_tokens": usage.get("completion_tokens", 0),
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
"""
|
||||
规范化 OpenAI 响应
|
||||
|
||||
Args:
|
||||
response: 原始响应
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_openai_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
strict=False,
|
||||
)
|
||||
return response
|
||||
181
src/api/handlers/openai/stream_parser.py
Normal file
181
src/api/handlers/openai/stream_parser.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
OpenAI SSE 流解析器
|
||||
|
||||
解析 OpenAI Chat Completions API 的 Server-Sent Events 流。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class OpenAIStreamParser:
|
||||
"""
|
||||
OpenAI SSE 流解析器
|
||||
|
||||
解析 OpenAI Chat Completions API 的 SSE 事件流。
|
||||
|
||||
OpenAI 流格式:
|
||||
- 每个 chunk 是一个 JSON 对象,包含 choices 数组
|
||||
- choices[0].delta 包含增量内容
|
||||
- choices[0].finish_reason 表示结束原因
|
||||
- 流结束时发送 data: [DONE]
|
||||
"""
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析 SSE 数据块
|
||||
|
||||
Args:
|
||||
chunk: 原始 SSE 数据(bytes 或 str)
|
||||
|
||||
Returns:
|
||||
解析后的 chunk 列表
|
||||
"""
|
||||
if isinstance(chunk, bytes):
|
||||
text = chunk.decode("utf-8")
|
||||
else:
|
||||
text = chunk
|
||||
|
||||
chunks: List[Dict[str, Any]] = []
|
||||
lines = text.strip().split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 解析数据行
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
# 处理 [DONE] 标记
|
||||
if data_str == "[DONE]":
|
||||
chunks.append({"__done__": True})
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
chunks.append(data)
|
||||
except json.JSONDecodeError:
|
||||
# 无法解析的数据,跳过
|
||||
pass
|
||||
|
||||
return chunks
|
||||
|
||||
def parse_line(self, line: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析单行 SSE 数据
|
||||
|
||||
Args:
|
||||
line: SSE 数据行(已去除 "data: " 前缀)
|
||||
|
||||
Returns:
|
||||
解析后的 chunk 字典,如果无法解析返回 None
|
||||
"""
|
||||
if not line or line == "[DONE]":
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def is_done_chunk(self, chunk: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否为结束 chunk
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
True 如果是结束 chunk
|
||||
"""
|
||||
# 内部标记
|
||||
if chunk.get("__done__"):
|
||||
return True
|
||||
|
||||
# 检查 finish_reason
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
finish_reason = choices[0].get("finish_reason")
|
||||
return finish_reason is not None
|
||||
|
||||
return False
|
||||
|
||||
def get_finish_reason(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
获取结束原因
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
结束原因字符串
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("finish_reason")
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 chunk 中提取文本增量
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
文本增量,如果没有返回 None
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content")
|
||||
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
def extract_tool_calls_delta(self, chunk: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
从 chunk 中提取工具调用增量
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
工具调用列表,如果没有返回 None
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("tool_calls")
|
||||
|
||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
从 chunk 中提取角色
|
||||
|
||||
通常只在第一个 chunk 中出现。
|
||||
|
||||
Args:
|
||||
chunk: chunk 字典
|
||||
|
||||
Returns:
|
||||
角色字符串
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("role")
|
||||
|
||||
|
||||
__all__ = ["OpenAIStreamParser"]
|
||||
11
src/api/handlers/openai_cli/__init__.py
Normal file
11
src/api/handlers/openai_cli/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
OpenAI CLI 透传处理器
|
||||
"""
|
||||
|
||||
from src.api.handlers.openai_cli.adapter import OpenAICliAdapter
|
||||
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
|
||||
|
||||
__all__ = [
|
||||
"OpenAICliAdapter",
|
||||
"OpenAICliMessageHandler",
|
||||
]
|
||||
44
src/api/handlers/openai_cli/adapter.py
Normal file
44
src/api/handlers/openai_cli/adapter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
||||
|
||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||
"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.api.handlers.base.cli_adapter_base import CliAdapterBase, register_cli_adapter
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
|
||||
|
||||
@register_cli_adapter
|
||||
class OpenAICliAdapter(CliAdapterBase):
|
||||
"""
|
||||
OpenAI CLI API 适配器
|
||||
|
||||
处理 /v1/responses 端点的请求。
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
name = "openai.cli"
|
||||
|
||||
@property
|
||||
def HANDLER_CLASS(self) -> Type[CliMessageHandlerBase]:
|
||||
"""延迟导入 Handler 类避免循环依赖"""
|
||||
from src.api.handlers.openai_cli.handler import OpenAICliMessageHandler
|
||||
|
||||
return OpenAICliMessageHandler
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
super().__init__(allowed_api_formats or ["OPENAI_CLI"])
|
||||
|
||||
def extract_api_key(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 API 密钥 (Authorization: Bearer)"""
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["OpenAICliAdapter"]
|
||||
211
src/api/handlers/openai_cli/handler.py
Normal file
211
src/api/handlers/openai_cli/handler.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
OpenAI CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
|
||||
|
||||
继承 CliMessageHandlerBase,只需覆盖格式特定的配置和事件处理逻辑。
|
||||
代码量从原来的 900+ 行减少到 ~100 行。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
"""
|
||||
OpenAI CLI Message Handler - 处理 OpenAI CLI Responses API 格式
|
||||
|
||||
使用新三层架构 (Provider -> ProviderEndpoint -> ProviderAPIKey)
|
||||
通过 FallbackOrchestrator 实现自动故障转移、健康监控和并发控制
|
||||
|
||||
响应格式特点:
|
||||
- 使用 output[] 数组而非 content[]
|
||||
- 使用 output_text 类型而非普通 text
|
||||
- 流式事件:response.output_text.delta, response.output_text.done
|
||||
|
||||
模型字段:请求体顶级 model 字段
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
|
||||
def extract_model_from_request(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None, # noqa: ARG002
|
||||
) -> str:
|
||||
"""
|
||||
从请求中提取模型名 - OpenAI 格式实现
|
||||
|
||||
OpenAI API 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 请求体
|
||||
path_params: URL 路径参数(OpenAI 不使用)
|
||||
|
||||
Returns:
|
||||
模型名
|
||||
"""
|
||||
model = request_body.get("model")
|
||||
return str(model) if model else "unknown"
|
||||
|
||||
def apply_mapped_model(
|
||||
self,
|
||||
request_body: Dict[str, Any],
|
||||
mapped_model: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI CLI (Responses API) 的 model 在请求体顶级字段。
|
||||
|
||||
Args:
|
||||
request_body: 原始请求体
|
||||
mapped_model: 映射后的模型名
|
||||
|
||||
Returns:
|
||||
更新了 model 字段的请求体
|
||||
"""
|
||||
result = dict(request_body)
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
def _process_event_data(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
event_type: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
处理 OpenAI CLI 格式的 SSE 事件
|
||||
|
||||
事件类型:
|
||||
- response.output_text.delta: 文本增量
|
||||
- response.completed: 响应完成(包含 usage)
|
||||
"""
|
||||
# 提取 response_id
|
||||
if not ctx.response_id:
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict) and response_obj.get("id"):
|
||||
ctx.response_id = response_obj["id"]
|
||||
elif "id" in data:
|
||||
ctx.response_id = data["id"]
|
||||
|
||||
# 处理文本增量
|
||||
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
ctx.collected_text += delta
|
||||
elif isinstance(delta, dict) and "text" in delta:
|
||||
ctx.collected_text += delta["text"]
|
||||
|
||||
# 处理完成事件
|
||||
elif event_type == "response.completed":
|
||||
ctx.has_completion = True
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict):
|
||||
ctx.final_response = response_obj
|
||||
|
||||
usage_obj = response_obj.get("usage")
|
||||
if isinstance(usage_obj, dict):
|
||||
ctx.final_usage = usage_obj
|
||||
ctx.input_tokens = usage_obj.get("input_tokens", 0)
|
||||
ctx.output_tokens = usage_obj.get("output_tokens", 0)
|
||||
|
||||
details = usage_obj.get("input_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
ctx.cached_tokens = details.get("cached_tokens", 0)
|
||||
|
||||
# 如果没有收集到文本,从 output 中提取
|
||||
if not ctx.collected_text and "output" in response_obj:
|
||||
for output_item in response_obj.get("output", []):
|
||||
if output_item.get("type") != "message":
|
||||
continue
|
||||
for content_item in output_item.get("content", []):
|
||||
if content_item.get("type") == "output_text":
|
||||
text = content_item.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
|
||||
# 备用:从顶层 usage 提取
|
||||
usage_obj = data.get("usage")
|
||||
if isinstance(usage_obj, dict) and not ctx.final_usage:
|
||||
ctx.final_usage = usage_obj
|
||||
ctx.input_tokens = usage_obj.get("input_tokens", 0)
|
||||
ctx.output_tokens = usage_obj.get("output_tokens", 0)
|
||||
|
||||
details = usage_obj.get("input_tokens_details")
|
||||
if isinstance(details, dict):
|
||||
ctx.cached_tokens = details.get("cached_tokens", 0)
|
||||
|
||||
# 备用:从 response 字段提取
|
||||
response_obj = data.get("response")
|
||||
if isinstance(response_obj, dict) and not ctx.final_response:
|
||||
ctx.final_response = response_obj
|
||||
|
||||
def _extract_response_metadata(
|
||||
self,
|
||||
response: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从 OpenAI 响应中提取元数据
|
||||
|
||||
提取 model、status、response_id 等字段作为元数据。
|
||||
|
||||
Args:
|
||||
response: OpenAI API 响应
|
||||
|
||||
Returns:
|
||||
提取的元数据字典
|
||||
"""
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
# 提取模型名称(实际使用的模型)
|
||||
if "model" in response:
|
||||
metadata["model"] = response["model"]
|
||||
|
||||
# 提取响应 ID
|
||||
if "id" in response:
|
||||
metadata["response_id"] = response["id"]
|
||||
|
||||
# 提取状态
|
||||
if "status" in response:
|
||||
metadata["status"] = response["status"]
|
||||
|
||||
# 提取对象类型
|
||||
if "object" in response:
|
||||
metadata["object"] = response["object"]
|
||||
|
||||
# 提取系统指纹(如果存在)
|
||||
if "system_fingerprint" in response:
|
||||
metadata["system_fingerprint"] = response["system_fingerprint"]
|
||||
|
||||
return metadata
|
||||
|
||||
def _finalize_stream_metadata(self, ctx: StreamContext) -> None:
|
||||
"""
|
||||
从流上下文中提取最终元数据
|
||||
|
||||
在流传输完成后调用,从收集的事件中提取元数据。
|
||||
|
||||
Args:
|
||||
ctx: 流上下文
|
||||
"""
|
||||
# 从 response_id 提取响应 ID
|
||||
if ctx.response_id:
|
||||
ctx.response_metadata["response_id"] = ctx.response_id
|
||||
|
||||
# 从 final_response 提取更多元数据
|
||||
if ctx.final_response and isinstance(ctx.final_response, dict):
|
||||
if "model" in ctx.final_response:
|
||||
ctx.response_metadata["model"] = ctx.final_response["model"]
|
||||
if "status" in ctx.final_response:
|
||||
ctx.response_metadata["status"] = ctx.final_response["status"]
|
||||
if "object" in ctx.final_response:
|
||||
ctx.response_metadata["object"] = ctx.final_response["object"]
|
||||
if "system_fingerprint" in ctx.final_response:
|
||||
ctx.response_metadata["system_fingerprint"] = ctx.final_response["system_fingerprint"]
|
||||
|
||||
# 如果没有从响应中获取到 model,使用上下文中的
|
||||
if "model" not in ctx.response_metadata and ctx.model:
|
||||
ctx.response_metadata["model"] = ctx.model
|
||||
|
||||
Reference in New Issue
Block a user