mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
Initial commit
This commit is contained in:
68
src/api/handlers/base/__init__.py
Normal file
68
src/api/handlers/base/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Handler 基类模块
|
||||
|
||||
提供 Adapter、Handler 的抽象基类,以及请求构建器和响应解析器。
|
||||
|
||||
注意:Handler 基类(ChatHandlerBase, CliMessageHandlerBase 等)不在这里导出,
|
||||
因为它们依赖 services.usage.stream,而后者又需要导入 response_parser,
|
||||
会形成循环导入。请直接从具体模块导入 Handler 基类。
|
||||
"""
|
||||
|
||||
# Chat Adapter 基类(不会引起循环导入)
|
||||
from src.api.handlers.base.chat_adapter_base import (
|
||||
ChatAdapterBase,
|
||||
get_adapter_class,
|
||||
get_adapter_instance,
|
||||
list_registered_formats,
|
||||
register_adapter,
|
||||
)
|
||||
|
||||
# CLI Adapter 基类
|
||||
from src.api.handlers.base.cli_adapter_base import (
|
||||
CliAdapterBase,
|
||||
get_cli_adapter_class,
|
||||
get_cli_adapter_instance,
|
||||
list_registered_cli_formats,
|
||||
register_cli_adapter,
|
||||
)
|
||||
|
||||
# 请求构建器
|
||||
from src.api.handlers.base.request_builder import (
|
||||
SENSITIVE_HEADERS,
|
||||
PassthroughRequestBuilder,
|
||||
RequestBuilder,
|
||||
build_passthrough_request,
|
||||
)
|
||||
|
||||
# 响应解析器
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
ParsedResponse,
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Chat Adapter
|
||||
"ChatAdapterBase",
|
||||
"register_adapter",
|
||||
"get_adapter_class",
|
||||
"get_adapter_instance",
|
||||
"list_registered_formats",
|
||||
# CLI Adapter
|
||||
"CliAdapterBase",
|
||||
"register_cli_adapter",
|
||||
"get_cli_adapter_class",
|
||||
"get_cli_adapter_instance",
|
||||
"list_registered_cli_formats",
|
||||
# 请求构建器
|
||||
"RequestBuilder",
|
||||
"PassthroughRequestBuilder",
|
||||
"build_passthrough_request",
|
||||
"SENSITIVE_HEADERS",
|
||||
# 响应解析器
|
||||
"ResponseParser",
|
||||
"ParsedChunk",
|
||||
"ParsedResponse",
|
||||
"StreamStats",
|
||||
]
|
||||
363
src/api/handlers/base/base_handler.py
Normal file
363
src/api/handlers/base/base_handler.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
基础消息处理器,封装通用的编排、转换、遥测逻辑。
|
||||
|
||||
接口约定:
|
||||
- process_stream: 处理流式请求,返回 StreamingResponse
|
||||
- process_sync: 处理非流式请求,返回 JSONResponse
|
||||
|
||||
签名规范(推荐):
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any, # 解析后的请求模型
|
||||
http_request: Request, # FastAPI Request 对象
|
||||
original_headers: Dict[str, str], # 原始请求头
|
||||
original_request_body: Dict[str, Any], # 原始请求体
|
||||
query_params: Optional[Dict[str, str]] = None, # 查询参数
|
||||
) -> StreamingResponse: ...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse: ...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.api_format_metadata import resolve_api_format
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
from src.services.orchestration.fallback_orchestrator import FallbackOrchestrator
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.system.audit import audit_service
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
"""
|
||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, user, api_key, request_id: str, client_ip: str):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
|
||||
async def calculate_cost(
|
||||
self,
|
||||
provider: str,
|
||||
model: str,
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
) -> float:
|
||||
input_price, output_price = await UsageService.get_model_price_async(
|
||||
self.db, provider, model
|
||||
)
|
||||
_, _, _, _, _, _, total_cost = UsageService.calculate_cost(
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
input_price,
|
||||
output_price,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
*await UsageService.get_cache_prices_async(self.db, provider, model, input_price),
|
||||
)
|
||||
return total_cost
|
||||
|
||||
async def record_success(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
response_body: Any,
|
||||
response_headers: Dict[str, Any],
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
is_stream: bool = False,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
provider_api_key_id: Optional[str] = None,
|
||||
api_format: Optional[str] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
# Provider 响应元数据(如 Gemini 的 modelVersion)
|
||||
response_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> float:
|
||||
total_cost = await self.calculate_cost(
|
||||
provider,
|
||||
model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_tokens=cache_creation_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
)
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers=response_headers,
|
||||
response_body=response_body,
|
||||
request_id=self.request_id,
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id=provider_id,
|
||||
provider_endpoint_id=provider_endpoint_id,
|
||||
provider_api_key_id=provider_api_key_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
# Provider 响应元数据
|
||||
metadata=response_metadata,
|
||||
)
|
||||
|
||||
if self.user and self.api_key:
|
||||
audit_service.log_api_request(
|
||||
db=self.db,
|
||||
user_id=self.user.id,
|
||||
api_key_id=self.api_key.id,
|
||||
request_id=self.request_id,
|
||||
model=model,
|
||||
provider=provider,
|
||||
success=True,
|
||||
ip_address=self.client_ip,
|
||||
status_code=status_code,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost_usd=total_cost,
|
||||
)
|
||||
|
||||
return total_cost
|
||||
|
||||
async def record_failure(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
response_time_ms: int,
|
||||
status_code: int,
|
||||
error_message: str,
|
||||
request_body: Dict[str, Any],
|
||||
request_headers: Dict[str, Any],
|
||||
is_stream: bool,
|
||||
api_format: Optional[str] = None,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# 预估 token 信息(来自 message_start 事件,用于中断请求的成本估算)
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_creation_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
response_body: Optional[Dict[str, Any]] = None,
|
||||
# 模型映射信息
|
||||
target_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
记录失败请求
|
||||
|
||||
注意:Provider 链路信息(provider_id, endpoint_id, key_id)不在此处记录,
|
||||
因为 RequestCandidate 表已经记录了完整的请求链路追踪信息。
|
||||
|
||||
Args:
|
||||
input_tokens: 预估输入 tokens(来自 message_start,用于中断请求的成本估算)
|
||||
output_tokens: 预估输出 tokens(来自已收到的内容)
|
||||
cache_creation_tokens: 缓存创建 tokens
|
||||
cache_read_tokens: 缓存读取 tokens
|
||||
response_body: 响应体(如果有部分响应)
|
||||
target_model: 映射后的目标模型名(如果发生了映射)
|
||||
"""
|
||||
provider_name = provider or "unknown"
|
||||
if provider_name == "unknown":
|
||||
logger.warning(f"[Telemetry] Recording failure with unknown provider (request_id={self.request_id})")
|
||||
|
||||
await UsageService.record_usage(
|
||||
db=self.db,
|
||||
user=self.user,
|
||||
api_key=self.api_key,
|
||||
provider=provider_name,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_tokens,
|
||||
cache_read_input_tokens=cache_read_tokens,
|
||||
request_type="chat",
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
provider_request_headers=provider_request_headers or {},
|
||||
response_headers={},
|
||||
response_body=response_body or {"error": error_message},
|
||||
request_id=self.request_id,
|
||||
# 模型映射信息
|
||||
target_model=target_model,
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageHandlerProtocol(Protocol):
|
||||
"""
|
||||
消息处理器协议 - 定义标准接口
|
||||
|
||||
ChatHandlerBase 使用完整签名(含 request, http_request)。
|
||||
CliMessageHandlerBase 使用简化签名(仅 original_request_body, original_headers)。
|
||||
"""
|
||||
|
||||
async def process_stream(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> StreamingResponse:
|
||||
"""处理流式请求"""
|
||||
...
|
||||
|
||||
async def process_sync(
|
||||
self,
|
||||
request: Any,
|
||||
http_request: Request,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> JSONResponse:
|
||||
"""处理非流式请求"""
|
||||
...
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""
|
||||
消息处理器基类,所有具体格式的 handler 可以继承它。
|
||||
|
||||
子类需要实现:
|
||||
- process_stream: 处理流式请求
|
||||
- process_sync: 处理非流式请求
|
||||
|
||||
推荐使用 MessageHandlerProtocol 中定义的签名。
|
||||
"""
|
||||
|
||||
# Adapter 检测器类型
|
||||
AdapterDetectorType = Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
db: Session,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list[str]] = None,
|
||||
adapter_detector: Optional[AdapterDetectorType] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.user = user
|
||||
self.api_key = api_key
|
||||
self.request_id = request_id
|
||||
self.client_ip = client_ip
|
||||
self.user_agent = user_agent
|
||||
self.start_time = start_time
|
||||
self.allowed_api_formats = allowed_api_formats or [APIFormat.CLAUDE.value]
|
||||
self.primary_api_format = normalize_api_format(self.allowed_api_formats[0])
|
||||
self.adapter_detector = adapter_detector
|
||||
|
||||
redis_client = get_redis_client_sync()
|
||||
self.orchestrator = FallbackOrchestrator(db, redis_client)
|
||||
self.telemetry = MessageTelemetry(db, user, api_key, request_id, client_ip)
|
||||
|
||||
def elapsed_ms(self) -> int:
|
||||
return int((time.time() - self.start_time) * 1000)
|
||||
|
||||
def _resolve_capability_requirements(
|
||||
self,
|
||||
model_name: str,
|
||||
request_headers: Optional[Dict[str, str]] = None,
|
||||
request_body: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
解析请求的能力需求
|
||||
|
||||
来源:
|
||||
1. 用户模型级配置 (User.model_capability_settings)
|
||||
2. 用户 API Key 强制配置 (ApiKey.force_capabilities)
|
||||
3. 请求头 X-Require-Capability
|
||||
4. Adapter 的 detect_capability_requirements(如 Claude 的 anthropic-beta)
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
request_headers: 请求头
|
||||
request_body: 请求体(可选)
|
||||
|
||||
Returns:
|
||||
能力需求字典
|
||||
"""
|
||||
from src.services.capability.resolver import CapabilityResolver
|
||||
|
||||
return CapabilityResolver.resolve_requirements(
|
||||
user=self.user,
|
||||
user_api_key=self.api_key,
|
||||
model_name=model_name,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
adapter_detector=self.adapter_detector,
|
||||
)
|
||||
|
||||
def get_api_format(self, provider_type: Optional[str] = None) -> APIFormat:
|
||||
"""根据 provider_type 解析 API 格式,未知类型默认 OPENAI"""
|
||||
if provider_type:
|
||||
return resolve_api_format(provider_type, default=APIFormat.OPENAI)
|
||||
return self.primary_api_format
|
||||
|
||||
def build_provider_payload(
|
||||
self,
|
||||
original_body: Dict[str, Any],
|
||||
*,
|
||||
mapped_model: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""构建发送给 Provider 的请求体,替换 model 名称"""
|
||||
payload = dict(original_body)
|
||||
if mapped_model:
|
||||
payload["model"] = mapped_model
|
||||
return payload
|
||||
724
src/api/handlers/base/chat_adapter_base.py
Normal file
724
src/api/handlers/base/chat_adapter_base.py
Normal file
@@ -0,0 +1,724 @@
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式(进行请求验证和标准化)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 ChatHandlerBase 子类
|
||||
- _validate_request_body(): 可选覆盖请求验证逻辑
|
||||
- _build_audit_metadata(): 可选覆盖审计元数据构建
|
||||
- compute_total_input_context(): 可选覆盖总输入上下文计算(用于阶梯计费判定)
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
Chat Adapter 通用基类
|
||||
|
||||
提供 Chat 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: ChatHandlerBase 子类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[ChatHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
self.response_normalizer = None
|
||||
# 可选启用响应规范化
|
||||
self._init_response_normalizer()
|
||||
|
||||
def _init_response_normalizer(self):
|
||||
"""初始化响应规范化器 - 子类可覆盖"""
|
||||
try:
|
||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
||||
|
||||
self.response_normalizer = ResponseNormalizer()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 Chat API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 验证和解析请求
|
||||
request_obj = self._validate_request_body(original_request_body, context.path_params)
|
||||
if isinstance(request_obj, JSONResponse):
|
||||
return request_obj
|
||||
|
||||
stream = getattr(request_obj, "stream", False)
|
||||
model = getattr(request_obj, "model", "unknown")
|
||||
|
||||
# 添加审计元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, request_obj)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self._create_handler(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
request=request_obj,
|
||||
http_request=http_request,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
query_params=query_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.info(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_handler(
|
||||
self,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
request_id: str,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
):
|
||||
"""创建 Handler 实例 - 子类可覆盖"""
|
||||
return self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
response_normalizer=self.response_normalizer,
|
||||
enable_response_normalization=self.response_normalizer is not None,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
@abstractmethod
|
||||
def _validate_request_body(self, original_request_body: dict, path_params: dict = None):
|
||||
"""
|
||||
验证请求体 - 子类必须实现
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数(如 Gemini 的 stream 通过 URL 端点传入)
|
||||
|
||||
Returns:
|
||||
验证后的请求对象,或 JSONResponse 错误响应
|
||||
"""
|
||||
pass
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 messages 字段提取
|
||||
"""
|
||||
messages = payload.get("messages", [])
|
||||
if hasattr(request_obj, "messages"):
|
||||
messages = request_obj.messages
|
||||
return len(messages) if isinstance(messages, list) else 0
|
||||
|
||||
def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
"""
|
||||
model = getattr(request_obj, "model", payload.get("model", "unknown"))
|
||||
stream = getattr(request_obj, "stream", payload.get("stream", False))
|
||||
messages_count = self._extract_message_count(payload, request_obj)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": getattr(request_obj, "max_tokens", payload.get("max_tokens")),
|
||||
"messages_count": messages_count,
|
||||
"temperature": getattr(request_obj, "temperature", payload.get("temperature")),
|
||||
"top_p": getattr(request_obj, "top_p", payload.get("top_p")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
try:
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
except Exception as record_error:
|
||||
logger.error(f"记录失败请求时出错: {record_error}")
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应 - 子类可覆盖以自定义格式"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典:
|
||||
{
|
||||
"input_cost": float,
|
||||
"output_cost": float,
|
||||
"cache_creation_cost": float,
|
||||
"cache_read_cost": float,
|
||||
"cache_cost": float,
|
||||
"request_cost": float,
|
||||
"total_cost": float,
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_ADAPTER_REGISTRY: Dict[str, Type["ChatAdapterBase"]] = {}
|
||||
_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_adapter(adapter_class: Type["ChatAdapterBase"]) -> Type["ChatAdapterBase"]:
|
||||
"""
|
||||
注册 Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_adapter
|
||||
class ClaudeChatAdapter(ChatAdapterBase):
|
||||
FORMAT_ID = "CLAUDE"
|
||||
...
|
||||
|
||||
Args:
|
||||
adapter_class: Adapter 类
|
||||
|
||||
Returns:
|
||||
注册的 Adapter 类(支持作为装饰器使用)
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_adapters_loaded():
|
||||
"""确保所有 Adapter 已被加载(触发注册)"""
|
||||
global _ADAPTERS_LOADED
|
||||
if _ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 Adapter 模块以触发 @register_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_adapter_class(api_format: str) -> Optional[Type["ChatAdapterBase"]]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 类
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识(如 "CLAUDE", "OPENAI", "GEMINI")
|
||||
|
||||
Returns:
|
||||
对应的 Adapter 类,如果未找到返回 None
|
||||
"""
|
||||
_ensure_adapters_loaded()
|
||||
return _ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_adapter_instance(api_format: str) -> Optional["ChatAdapterBase"]:
|
||||
"""
|
||||
根据 API format 获取 Adapter 实例
|
||||
|
||||
Args:
|
||||
api_format: API 格式标识
|
||||
|
||||
Returns:
|
||||
Adapter 实例,如果未找到返回 None
|
||||
"""
|
||||
adapter_class = get_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_formats() -> list[str]:
|
||||
"""返回所有已注册的 API 格式"""
|
||||
_ensure_adapters_loaded()
|
||||
return list(_ADAPTER_REGISTRY.keys())
|
||||
1257
src/api/handlers/base/chat_handler_base.py
Normal file
1257
src/api/handlers/base/chat_handler_base.py
Normal file
File diff suppressed because it is too large
Load Diff
648
src/api/handlers/base/cli_adapter_base.py
Normal file
648
src/api/handlers/base/cli_adapter_base.py
Normal file
@@ -0,0 +1,648 @@
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
|
||||
- 请求解析和验证
|
||||
- 审计日志记录
|
||||
- 错误处理和响应格式化
|
||||
- Handler 创建和调用
|
||||
- 计费策略(支持不同 API 格式的差异化计费)
|
||||
|
||||
子类只需提供:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: 对应的 MessageHandler 类
|
||||
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
|
||||
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
|
||||
"""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
|
||||
from src.core.exceptions import (
|
||||
InvalidRequestException,
|
||||
ModelNotSupportedException,
|
||||
ProviderAuthException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderRateLimitException,
|
||||
ProviderTimeoutException,
|
||||
ProxyException,
|
||||
QuotaExceededException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
|
||||
|
||||
class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
CLI Adapter 通用基类
|
||||
|
||||
提供 CLI 格式的通用适配器逻辑,子类只需配置:
|
||||
- FORMAT_ID: API 格式标识
|
||||
- HANDLER_CLASS: MessageHandler 类
|
||||
- name: 适配器名称
|
||||
"""
|
||||
|
||||
# 子类必须覆盖
|
||||
FORMAT_ID: str = "UNKNOWN"
|
||||
HANDLER_CLASS: Type[CliMessageHandlerBase]
|
||||
|
||||
# 适配器配置
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 CLI API 请求"""
|
||||
http_request = context.request
|
||||
user = context.user
|
||||
api_key = context.api_key
|
||||
db = context.db
|
||||
request_id = context.request_id
|
||||
quota_remaining_value = context.quota_remaining
|
||||
start_time = context.start_time
|
||||
client_ip = context.client_ip
|
||||
user_agent = context.user_agent
|
||||
original_headers = context.original_headers
|
||||
query_params = context.query_params # 获取查询参数
|
||||
|
||||
original_request_body = context.ensure_json_body()
|
||||
|
||||
# 合并 path_params 到请求体(如 Gemini API 的 model 在 URL 路径中)
|
||||
if context.path_params:
|
||||
original_request_body = self._merge_path_params(
|
||||
original_request_body, context.path_params
|
||||
)
|
||||
|
||||
# 获取 stream:优先从请求体,其次从 path_params(如 Gemini 通过 URL 端点区分)
|
||||
stream = original_request_body.get("stream")
|
||||
if stream is None and context.path_params:
|
||||
stream = context.path_params.get("stream", False)
|
||||
stream = bool(stream)
|
||||
|
||||
# 获取 model:优先从请求体,其次从 path_params(如 Gemini 的 model 在 URL 路径中)
|
||||
model = original_request_body.get("model")
|
||||
if model is None and context.path_params:
|
||||
model = context.path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
# 提取请求元数据
|
||||
audit_metadata = self._build_audit_metadata(original_request_body, context.path_params)
|
||||
context.add_audit_metadata(**audit_metadata)
|
||||
|
||||
# 格式化额度显示
|
||||
quota_display = (
|
||||
"unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
|
||||
)
|
||||
|
||||
# 请求开始日志
|
||||
logger.info(f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
|
||||
f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}")
|
||||
|
||||
try:
|
||||
# 检查客户端连接
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning("客户端连接断开")
|
||||
raise HTTPException(status_code=499, detail="Client disconnected")
|
||||
|
||||
# 创建 Handler
|
||||
handler = self.HANDLER_CLASS(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
request_id=request_id,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
# 处理请求
|
||||
if stream:
|
||||
return await handler.process_stream(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
return await handler.process_sync(
|
||||
original_request_body=original_request_body,
|
||||
original_headers=original_headers,
|
||||
query_params=query_params,
|
||||
path_params=context.path_params,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except (
|
||||
ModelNotSupportedException,
|
||||
QuotaExceededException,
|
||||
InvalidRequestException,
|
||||
) as e:
|
||||
logger.debug(f"客户端请求错误: {e.error_type}")
|
||||
return self._error_response(
|
||||
status_code=e.status_code,
|
||||
error_type=(
|
||||
"invalid_request_error" if e.status_code == 400 else "quota_exceeded"
|
||||
),
|
||||
message=e.message,
|
||||
)
|
||||
|
||||
except (
|
||||
ProviderAuthException,
|
||||
ProviderRateLimitException,
|
||||
ProviderNotAvailableException,
|
||||
ProviderTimeoutException,
|
||||
UpstreamClientException,
|
||||
) as e:
|
||||
return await self._handle_provider_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return await self._handle_unexpected_exception(
|
||||
e,
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
start_time=start_time,
|
||||
original_headers=original_headers,
|
||||
original_request_body=original_request_body,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _merge_path_params(
|
||||
self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并 URL 路径参数到请求体 - 子类可覆盖
|
||||
|
||||
默认实现:直接将 path_params 中的字段合并到请求体(不覆盖已有字段)
|
||||
|
||||
Args:
|
||||
original_request_body: 原始请求体字典
|
||||
path_params: URL 路径参数字典
|
||||
|
||||
Returns:
|
||||
合并后的请求体字典
|
||||
"""
|
||||
merged = original_request_body.copy()
|
||||
for key, value in path_params.items():
|
||||
if key not in merged:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
def _extract_message_count(self, payload: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取消息数量 - 子类可覆盖
|
||||
|
||||
默认实现:从 input 字段提取
|
||||
"""
|
||||
if "input" not in payload:
|
||||
return 0
|
||||
input_data = payload["input"]
|
||||
if isinstance(input_data, list):
|
||||
return len(input_data)
|
||||
if isinstance(input_data, dict) and "messages" in input_data:
|
||||
return len(input_data.get("messages", []))
|
||||
return 0
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
payload: Dict[str, Any],
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
构建审计日志元数据 - 子类可覆盖
|
||||
|
||||
Args:
|
||||
payload: 请求体
|
||||
path_params: URL 路径参数(用于获取 model 等)
|
||||
"""
|
||||
# 优先从请求体获取 model,其次从 path_params
|
||||
model = payload.get("model")
|
||||
if model is None and path_params:
|
||||
model = path_params.get("model", "unknown")
|
||||
model = model or "unknown"
|
||||
|
||||
stream = payload.get("stream", False)
|
||||
messages_count = self._extract_message_count(payload)
|
||||
|
||||
return {
|
||||
"action": f"{self.FORMAT_ID.lower()}_request",
|
||||
"model": model,
|
||||
"stream": bool(stream),
|
||||
"max_tokens": payload.get("max_tokens"),
|
||||
"messages_count": messages_count,
|
||||
"temperature": payload.get("temperature"),
|
||||
"top_p": payload.get("top_p"),
|
||||
"tool_count": len(payload.get("tools") or []),
|
||||
"instructions_present": bool(payload.get("instructions")),
|
||||
}
|
||||
|
||||
async def _handle_provider_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理 Provider 相关异常"""
|
||||
logger.debug(f"Caught provider exception: {type(e).__name__}")
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 确定错误消息
|
||||
if isinstance(e, ProviderAuthException):
|
||||
error_message = (
|
||||
f"提供商认证失败: {str(e)}"
|
||||
if result.metadata.provider != "unknown"
|
||||
else "服务端错误: 无可用提供商"
|
||||
)
|
||||
result.error_message = error_message
|
||||
|
||||
# 处理上游客户端错误(如图片处理失败)
|
||||
if isinstance(e, UpstreamClientException):
|
||||
# 返回 400 状态码和清晰的错误消息
|
||||
result.status_code = e.status_code
|
||||
result.error_message = e.message
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
# 根据异常类型确定错误类型
|
||||
if isinstance(e, UpstreamClientException):
|
||||
error_type = "invalid_request_error"
|
||||
elif result.status_code == 503:
|
||||
error_type = "internal_server_error"
|
||||
else:
|
||||
error_type = "rate_limit_exceeded"
|
||||
|
||||
return self._error_response(
|
||||
status_code=result.status_code,
|
||||
error_type=error_type,
|
||||
message=result.error_message or str(e),
|
||||
)
|
||||
|
||||
async def _handle_unexpected_exception(
|
||||
self,
|
||||
e: Exception,
|
||||
*,
|
||||
db,
|
||||
user,
|
||||
api_key,
|
||||
model: str,
|
||||
stream: bool,
|
||||
start_time: float,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
client_ip: str,
|
||||
request_id: str,
|
||||
) -> JSONResponse:
|
||||
"""处理未预期的异常"""
|
||||
if isinstance(e, ProxyException):
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
|
||||
else:
|
||||
logger.error(f"{self.FORMAT_ID} 请求处理意外异常",
|
||||
exception=e,
|
||||
extra_data={
|
||||
"exception_class": e.__class__.__name__,
|
||||
"processing_stage": "request_processing",
|
||||
"model": model,
|
||||
"stream": stream,
|
||||
"traceback_preview": str(traceback.format_exc())[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 使用 RequestResult.from_exception 创建统一的失败结果
|
||||
# 关键:api_format 从 FORMAT_ID 获取,确保始终有值
|
||||
result = RequestResult.from_exception(
|
||||
exception=e,
|
||||
api_format=self.FORMAT_ID, # 使用 Adapter 的 FORMAT_ID 作为默认值
|
||||
model=model,
|
||||
response_time_ms=response_time,
|
||||
is_stream=stream,
|
||||
)
|
||||
# 对于未预期的异常,强制设置状态码为 500
|
||||
result.status_code = 500
|
||||
result.error_type = "internal_error"
|
||||
result.request_headers = original_headers
|
||||
result.request_body = original_request_body
|
||||
|
||||
# 使用 UsageRecorder 记录失败
|
||||
recorder = UsageRecorder(
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
client_ip=client_ip,
|
||||
request_id=request_id,
|
||||
)
|
||||
await recorder.record_failure(result, original_headers, original_request_body)
|
||||
|
||||
return self._error_response(
|
||||
status_code=500,
|
||||
error_type="internal_server_error",
|
||||
message="处理请求时发生内部错误")
|
||||
|
||||
def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
|
||||
"""生成错误响应"""
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"error": {
|
||||
"type": error_type,
|
||||
"message": message,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 计费策略相关方法 - 子类可覆盖以实现不同 API 格式的差异化计费
|
||||
# =========================================================================
|
||||
|
||||
def compute_total_input_context(
|
||||
self,
|
||||
input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
计算总输入上下文(用于阶梯计费判定)
|
||||
|
||||
默认实现:input_tokens + cache_read_input_tokens
|
||||
子类可覆盖此方法实现不同的计算逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数(部分格式可能需要)
|
||||
|
||||
Returns:
|
||||
总输入上下文 token 数
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int,
|
||||
cache_read_input_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_price_per_1m: Optional[float],
|
||||
cache_read_price_per_1m: Optional[float],
|
||||
price_per_request: Optional[float],
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
input_price_per_1m: 输入价格(每 1M tokens)
|
||||
output_price_per_1m: 输出价格(每 1M tokens)
|
||||
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||
price_per_request: 按次计费价格
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||
# =========================================================================
|
||||
|
||||
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
|
||||
_CLI_ADAPTERS_LOADED = False
|
||||
|
||||
|
||||
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
|
||||
"""
|
||||
注册 CLI Adapter 类到注册表
|
||||
|
||||
用法:
|
||||
@register_cli_adapter
|
||||
class ClaudeCliAdapter(CliAdapterBase):
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
...
|
||||
"""
|
||||
format_id = adapter_class.FORMAT_ID
|
||||
if format_id and format_id != "UNKNOWN":
|
||||
_CLI_ADAPTER_REGISTRY[format_id.upper()] = adapter_class
|
||||
return adapter_class
|
||||
|
||||
|
||||
def _ensure_cli_adapters_loaded():
|
||||
"""确保所有 CLI Adapter 已被加载(触发注册)"""
|
||||
global _CLI_ADAPTERS_LOADED
|
||||
if _CLI_ADAPTERS_LOADED:
|
||||
return
|
||||
|
||||
# 导入各个 CLI Adapter 模块以触发 @register_cli_adapter 装饰器
|
||||
try:
|
||||
from src.api.handlers.claude_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.openai_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from src.api.handlers.gemini_cli import adapter as _ # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_CLI_ADAPTERS_LOADED = True
|
||||
|
||||
|
||||
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
|
||||
"""根据 API format 获取 CLI Adapter 类"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return _CLI_ADAPTER_REGISTRY.get(api_format.upper()) if api_format else None
|
||||
|
||||
|
||||
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
|
||||
"""根据 API format 获取 CLI Adapter 实例"""
|
||||
adapter_class = get_cli_adapter_class(api_format)
|
||||
if adapter_class:
|
||||
return adapter_class()
|
||||
return None
|
||||
|
||||
|
||||
def list_registered_cli_formats() -> list[str]:
|
||||
"""返回所有已注册的 CLI API 格式"""
|
||||
_ensure_cli_adapters_loaded()
|
||||
return list(_CLI_ADAPTER_REGISTRY.keys())
|
||||
1614
src/api/handlers/base/cli_handler_base.py
Normal file
File diff suppressed because it is too large
279
src/api/handlers/base/format_converter_registry.py
Normal file
@@ -0,0 +1,279 @@
"""
Format converter registry

Automatically manages converters between different API formats, supporting:
- Request conversion: client format → Provider format
- Response conversion: Provider format → client format

How to use:
1. Implement a Converter class (it needs a convert_request and/or convert_response method)
2. Call registry.register() to register the converter
3. Call registry.convert_request/convert_response inside a Handler

Example:
    from src.api.handlers.base.format_converter_registry import converter_registry

    # Register converters
    converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
    converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())

    # Use converters
    gemini_request = converter_registry.convert_request(claude_request, "CLAUDE", "GEMINI")
    claude_response = converter_registry.convert_response(gemini_response, "GEMINI", "CLAUDE")
"""

from typing import Any, Dict, Optional, Protocol, Tuple

from src.core.logger import logger


class RequestConverter(Protocol):
    """Request converter protocol"""

    def convert_request(self, request: Dict[str, Any]) -> Dict[str, Any]: ...


class ResponseConverter(Protocol):
    """Response converter protocol"""

    def convert_response(self, response: Dict[str, Any]) -> Dict[str, Any]: ...


class StreamChunkConverter(Protocol):
    """Stream chunk converter protocol"""

    def convert_stream_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]: ...


class FormatConverterRegistry:
    """
    Format converter registry

    Manages bidirectional converters between different API formats.
    """

    def __init__(self):
        # key: (source_format, target_format), value: converter instance
        self._converters: Dict[Tuple[str, str], Any] = {}

    def register(
        self,
        source_format: str,
        target_format: str,
        converter: Any,
    ) -> None:
        """
        Register a format converter.

        Args:
            source_format: Source format (e.g. "CLAUDE", "OPENAI", "GEMINI")
            target_format: Target format
            converter: Converter instance (needs convert_request/convert_response methods)
        """
        key = (source_format.upper(), target_format.upper())
        self._converters[key] = converter
        logger.info(f"[ConverterRegistry] Registered converter: {source_format} -> {target_format}")

    def get_converter(
        self,
        source_format: str,
        target_format: str,
    ) -> Optional[Any]:
        """
        Get a converter.

        Args:
            source_format: Source format
            target_format: Target format

        Returns:
            The converter instance, or None if none is registered
        """
        key = (source_format.upper(), target_format.upper())
        return self._converters.get(key)

    def has_converter(
        self,
        source_format: str,
        target_format: str,
    ) -> bool:
        """Check whether a converter exists."""
        key = (source_format.upper(), target_format.upper())
        return key in self._converters

    def convert_request(
        self,
        request: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a request.

        Args:
            request: Original request dict
            source_format: Source format (client format)
            target_format: Target format (Provider format)

        Returns:
            The converted request dict; if no conversion is needed or no converter
            is registered, the original request is returned.
        """
        # Same format, nothing to convert
        if source_format.upper() == target_format.upper():
            return request

        converter = self.get_converter(source_format, target_format)
        if converter is None:
            logger.warning(f"[ConverterRegistry] No request converter found: {source_format} -> {target_format}, returning original request")
            return request

        if not hasattr(converter, "convert_request"):
            logger.warning(f"[ConverterRegistry] Converter has no convert_request method: {source_format} -> {target_format}")
            return request

        try:
            converted = converter.convert_request(request)
            logger.debug(f"[ConverterRegistry] Request converted successfully: {source_format} -> {target_format}")
            return converted
        except Exception as e:
            logger.error(f"[ConverterRegistry] Request conversion failed: {source_format} -> {target_format}: {e}")
            return request

    def convert_response(
        self,
        response: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a response.

        Args:
            response: Original response dict
            source_format: Source format (Provider format)
            target_format: Target format (client format)

        Returns:
            The converted response dict; if no conversion is needed or no converter
            is registered, the original response is returned.
        """
        # Same format, nothing to convert
        if source_format.upper() == target_format.upper():
            return response

        converter = self.get_converter(source_format, target_format)
        if converter is None:
            logger.warning(f"[ConverterRegistry] No response converter found: {source_format} -> {target_format}, returning original response")
            return response

        if not hasattr(converter, "convert_response"):
            logger.warning(f"[ConverterRegistry] Converter has no convert_response method: {source_format} -> {target_format}")
            return response

        try:
            converted = converter.convert_response(response)
            logger.debug(f"[ConverterRegistry] Response converted successfully: {source_format} -> {target_format}")
            return converted
        except Exception as e:
            logger.error(f"[ConverterRegistry] Response conversion failed: {source_format} -> {target_format}: {e}")
            return response

    def convert_stream_chunk(
        self,
        chunk: Dict[str, Any],
        source_format: str,
        target_format: str,
    ) -> Dict[str, Any]:
        """
        Convert a streaming response chunk.

        Args:
            chunk: Original streaming response chunk
            source_format: Source format (Provider format)
            target_format: Target format (client format)

        Returns:
            The converted streaming response chunk
        """
        # Same format, nothing to convert
        if source_format.upper() == target_format.upper():
            return chunk

        converter = self.get_converter(source_format, target_format)
        if converter is None:
            return chunk

        # Prefer the dedicated streaming conversion method
        if hasattr(converter, "convert_stream_chunk"):
            try:
                return converter.convert_stream_chunk(chunk)
            except Exception as e:
                logger.error(f"[ConverterRegistry] Stream chunk conversion failed: {source_format} -> {target_format}: {e}")
                return chunk

        # Fall back to the plain response conversion
        if hasattr(converter, "convert_response"):
            try:
                return converter.convert_response(chunk)
            except Exception:
                return chunk

        return chunk

    def list_converters(self) -> list[Tuple[str, str]]:
        """List all registered converters."""
        return list(self._converters.keys())


# Global singleton
converter_registry = FormatConverterRegistry()


def register_all_converters():
    """
    Register all built-in format converters.

    Call this function at application startup.
    """
    # Claude <-> OpenAI
    try:
        from src.api.handlers.claude.converter import OpenAIToClaudeConverter
        from src.api.handlers.openai.converter import ClaudeToOpenAIConverter

        converter_registry.register("OPENAI", "CLAUDE", OpenAIToClaudeConverter())
        converter_registry.register("CLAUDE", "OPENAI", ClaudeToOpenAIConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] Could not load Claude/OpenAI converters: {e}")

    # Claude <-> Gemini
    try:
        from src.api.handlers.gemini.converter import (
            ClaudeToGeminiConverter,
            GeminiToClaudeConverter,
        )

        converter_registry.register("CLAUDE", "GEMINI", ClaudeToGeminiConverter())
        converter_registry.register("GEMINI", "CLAUDE", GeminiToClaudeConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] Could not load Claude/Gemini converters: {e}")

    # OpenAI <-> Gemini
    try:
        from src.api.handlers.gemini.converter import (
            GeminiToOpenAIConverter,
            OpenAIToGeminiConverter,
        )

        converter_registry.register("OPENAI", "GEMINI", OpenAIToGeminiConverter())
        converter_registry.register("GEMINI", "OPENAI", GeminiToOpenAIConverter())
    except ImportError as e:
        logger.warning(f"[ConverterRegistry] Could not load OpenAI/Gemini converters: {e}")

    logger.info(f"[ConverterRegistry] Registered {len(converter_registry.list_converters())} format converters")


__all__ = [
    "FormatConverterRegistry",
    "converter_registry",
    "register_all_converters",
]
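
# Illustrative wiring (editor's sketch, not part of the original module): at
# application startup, register the built-in converters, then convert between
# formats on demand. The request dict below is a made-up minimal OpenAI payload:
#
#     register_all_converters()
#     if converter_registry.has_converter("OPENAI", "CLAUDE"):
#         claude_request = converter_registry.convert_request(
#             {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}]},
#             "OPENAI",
#             "CLAUDE",
#         )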
465
src/api/handlers/base/parsers.py
Normal file
@@ -0,0 +1,465 @@
"""
Response parser factory

Creates the matching ResponseParser implementation directly from a format ID,
without going through the Protocol abstraction layer anymore.
"""

from typing import Any, Dict, Optional

from src.api.handlers.base.response_parser import (
    ParsedChunk,
    ParsedResponse,
    ResponseParser,
    StreamStats,
)


class OpenAIResponseParser(ResponseParser):
    """OpenAI-format response parser"""

    def __init__(self):
        from src.api.handlers.openai.stream_parser import OpenAIStreamParser

        self._parser = OpenAIStreamParser()
        self.name = "OPENAI"
        self.api_format = "OPENAI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        if not line or not line.strip():
            return None

        if line.startswith("data: "):
            data_str = line[6:]
        else:
            data_str = line

        parsed = self._parser.parse_line(data_str)
        if parsed is None:
            return None

        chunk = ParsedChunk(
            raw_line=line,
            event_type=None,
            data=parsed,
        )

        # Extract the text delta
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta

        # Check whether the stream is done
        if self._parser.is_done_chunk(parsed):
            chunk.is_done = True
            stats.has_completion = True

        stats.chunk_count += 1
        stats.data_count += 1

        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        result = ParsedResponse(
            raw_response=response,
            status_code=status_code,
        )

        # Extract the text content
        choices = response.get("choices", [])
        if choices:
            message = choices[0].get("message", {})
            content = message.get("content")
            if content:
                result.text_content = content

        result.response_id = response.get("id")

        # Extract usage
        usage = response.get("usage", {})
        result.input_tokens = usage.get("prompt_tokens", 0)
        result.output_tokens = usage.get("completion_tokens", 0)

        # Check for errors
        if "error" in response:
            result.is_error = True
            error = response.get("error", {})
            if isinstance(error, dict):
                result.error_type = error.get("type")
                result.error_message = error.get("message")
            else:
                result.error_message = str(error)

        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        usage = response.get("usage", {})
        return {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
            "cache_creation_tokens": 0,
            "cache_read_tokens": 0,
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        choices = response.get("choices", [])
        if choices:
            message = choices[0].get("message", {})
            content = message.get("content")
            if content:
                return content
        return ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        return "error" in response


class OpenAICliResponseParser(OpenAIResponseParser):
    """OpenAI CLI-format response parser"""

    def __init__(self):
        super().__init__()
        self.name = "OPENAI_CLI"
        self.api_format = "OPENAI_CLI"


class ClaudeResponseParser(ResponseParser):
    """Claude-format response parser"""

    def __init__(self):
        from src.api.handlers.claude.stream_parser import ClaudeStreamParser

        self._parser = ClaudeStreamParser()
        self.name = "CLAUDE"
        self.api_format = "CLAUDE"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        if not line or not line.strip():
            return None

        if line.startswith("data: "):
            data_str = line[6:]
        else:
            data_str = line

        parsed = self._parser.parse_line(data_str)
        if parsed is None:
            return None

        chunk = ParsedChunk(
            raw_line=line,
            event_type=self._parser.get_event_type(parsed),
            data=parsed,
        )

        # Extract the text delta
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta

        # Check whether the stream is done
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True

        # Extract usage
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_creation_tokens = usage.get("cache_creation_tokens", 0)
            chunk.cache_read_tokens = usage.get("cache_read_tokens", 0)

            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_creation_tokens = chunk.cache_creation_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens

        # Check for errors
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            error = parsed.get("error", {})
            if isinstance(error, dict):
                chunk.error_message = error.get("message", str(error))
            else:
                chunk.error_message = str(error)

        stats.chunk_count += 1
        stats.data_count += 1

        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        result = ParsedResponse(
            raw_response=response,
            status_code=status_code,
        )

        # Extract the text content
        content = response.get("content", [])
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            result.text_content = "".join(text_parts)

        result.response_id = response.get("id")

        # Extract usage
        usage = response.get("usage", {})
        result.input_tokens = usage.get("input_tokens", 0)
        result.output_tokens = usage.get("output_tokens", 0)
        result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
        result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)

        # Check for errors
        if "error" in response or response.get("type") == "error":
            result.is_error = True
            error = response.get("error", {})
            if isinstance(error, dict):
                result.error_type = error.get("type")
                result.error_message = error.get("message")
            else:
                result.error_message = str(error)

        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        usage = response.get("usage", {})
        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
            "cache_read_tokens": usage.get("cache_read_input_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        content = response.get("content", [])
        if isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text_parts.append(block.get("text", ""))
            return "".join(text_parts)
        return ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        return "error" in response or response.get("type") == "error"


class ClaudeCliResponseParser(ClaudeResponseParser):
    """Claude CLI-format response parser"""

    def __init__(self):
        super().__init__()
        self.name = "CLAUDE_CLI"
        self.api_format = "CLAUDE_CLI"


class GeminiResponseParser(ResponseParser):
    """Gemini-format response parser"""

    def __init__(self):
        from src.api.handlers.gemini.stream_parser import GeminiStreamParser

        self._parser = GeminiStreamParser()
        self.name = "GEMINI"
        self.api_format = "GEMINI"

    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """
        Parse a Gemini SSE line.

        Gemini streaming responses use the SSE format (data: {...}).
        """
        if not line or not line.strip():
            return None

        # Gemini SSE format: data: {...}
        if line.startswith("data: "):
            data_str = line[6:]
        else:
            data_str = line

        parsed = self._parser.parse_line(data_str)
        if parsed is None:
            return None

        chunk = ParsedChunk(
            raw_line=line,
            event_type="content",
            data=parsed,
        )

        # Extract the text delta
        text_delta = self._parser.extract_text_delta(parsed)
        if text_delta:
            chunk.text_delta = text_delta
            stats.collected_text += text_delta

        # Check whether the stream is done
        if self._parser.is_done_event(parsed):
            chunk.is_done = True
            stats.has_completion = True

        # Extract usage
        usage = self._parser.extract_usage(parsed)
        if usage:
            chunk.input_tokens = usage.get("input_tokens", 0)
            chunk.output_tokens = usage.get("output_tokens", 0)
            chunk.cache_read_tokens = usage.get("cached_tokens", 0)

            stats.input_tokens = chunk.input_tokens
            stats.output_tokens = chunk.output_tokens
            stats.cache_read_tokens = chunk.cache_read_tokens

        # Check for errors
        if self._parser.is_error_event(parsed):
            chunk.is_error = True
            error = parsed.get("error", {})
            if isinstance(error, dict):
                chunk.error_message = error.get("message", str(error))
            else:
                chunk.error_message = str(error)

        stats.chunk_count += 1
        stats.data_count += 1

        return chunk

    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        result = ParsedResponse(
            raw_response=response,
            status_code=status_code,
        )

        # Extract the text content
        candidates = response.get("candidates", [])
        if candidates:
            content = candidates[0].get("content", {})
            parts = content.get("parts", [])
            text_parts = []
            for part in parts:
                if "text" in part:
                    text_parts.append(part["text"])
            result.text_content = "".join(text_parts)

        result.response_id = response.get("modelVersion")

        # Extract usage (delegates to GeminiStreamParser.extract_usage as the single source of truth)
        usage = self._parser.extract_usage(response)
        if usage:
            result.input_tokens = usage.get("input_tokens", 0)
            result.output_tokens = usage.get("output_tokens", 0)
            result.cache_read_tokens = usage.get("cached_tokens", 0)

        # Check for errors (uses the enhanced error detection)
        error_info = self._parser.extract_error_info(response)
        if error_info:
            result.is_error = True
            result.error_type = error_info.get("status")
            result.error_message = error_info.get("message")

        return result

    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """
        Extract token usage from a Gemini response.

        Delegates to GeminiStreamParser.extract_usage as the single source of truth.
        """
        usage = self._parser.extract_usage(response)
        if not usage:
            return {
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_creation_tokens": 0,
                "cache_read_tokens": 0,
            }

        return {
            "input_tokens": usage.get("input_tokens", 0),
            "output_tokens": usage.get("output_tokens", 0),
            "cache_creation_tokens": 0,
            "cache_read_tokens": usage.get("cached_tokens", 0),
        }

    def extract_text_content(self, response: Dict[str, Any]) -> str:
        candidates = response.get("candidates", [])
        if candidates:
            content = candidates[0].get("content", {})
            parts = content.get("parts", [])
            text_parts = []
            for part in parts:
                if "text" in part:
                    text_parts.append(part["text"])
            return "".join(text_parts)
        return ""

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """
        Determine whether the response is an error response.

        Uses the enhanced error detection logic, which also finds errors nested
        inside chunks.
        """
        return self._parser.is_error_event(response)


class GeminiCliResponseParser(GeminiResponseParser):
    """Gemini CLI-format response parser"""

    def __init__(self):
        super().__init__()
        self.name = "GEMINI_CLI"
        self.api_format = "GEMINI_CLI"


# Parser registry
_PARSERS = {
    "CLAUDE": ClaudeResponseParser,
    "CLAUDE_CLI": ClaudeCliResponseParser,
    "OPENAI": OpenAIResponseParser,
    "OPENAI_CLI": OpenAICliResponseParser,
    "GEMINI": GeminiResponseParser,
    "GEMINI_CLI": GeminiCliResponseParser,
}


def get_parser_for_format(format_id: str) -> ResponseParser:
    """
    Get a ResponseParser for a format ID.

    Args:
        format_id: Format ID, e.g. "CLAUDE", "OPENAI", "CLAUDE_CLI", "OPENAI_CLI"

    Returns:
        A ResponseParser instance

    Raises:
        KeyError: The format does not exist
    """
    format_id = format_id.upper()
    if format_id not in _PARSERS:
        raise KeyError(f"Unknown format: {format_id}")
    return _PARSERS[format_id]()


def is_cli_format(format_id: str) -> bool:
    """Return whether the format is a CLI format."""
    return format_id.upper().endswith("_CLI")


__all__ = [
    "OpenAIResponseParser",
    "OpenAICliResponseParser",
    "ClaudeResponseParser",
    "ClaudeCliResponseParser",
    "GeminiResponseParser",
    "GeminiCliResponseParser",
    "get_parser_for_format",
    "is_cli_format",
]
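
# Illustrative usage (editor's sketch, not part of the original module): parse a
# single hand-written Claude-style SSE line. What exactly gets extracted depends
# on ClaudeStreamParser, which lives outside this module:
#
#     parser = get_parser_for_format("CLAUDE")
#     stats = StreamStats()
#     chunk = parser.parse_sse_line(
#         'data: {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}}',
#         stats,
#     )
#     # chunk.text_delta is expected to be "Hi"; stats.collected_text accumulates
#     # the streamed text across successive calls.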
207
src/api/handlers/base/request_builder.py
Normal file
@@ -0,0 +1,207 @@
"""
Request builder - passthrough mode

Passthrough mode: for CLI and Chat scenarios, forward the request body and
headers as-is.
- Strip sensitive headers: authorization, x-api-key, host, content-length, etc.
- Keep every other header and request-body field
- Suitable for: Claude CLI, OpenAI CLI, Chat API, and similar scenarios

Usage:
    builder = PassthroughRequestBuilder()
    payload, headers = builder.build(original_body, original_headers, endpoint, key)
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, FrozenSet, Optional, Tuple

from src.core.crypto import crypto_service

# ==============================================================================
# Shared header configuration constants
# ==============================================================================

# Sensitive headers - stripped in passthrough mode (blacklist).
# These headers either carry credentials or are regenerated by the proxy layer.
SENSITIVE_HEADERS: FrozenSet[str] = frozenset(
    {
        "authorization",
        "x-api-key",
        "x-goog-api-key",  # Gemini API auth header
        "host",
        "content-length",
        "transfer-encoding",
        "connection",
        # Do not pass accept-encoding through; let httpx negotiate compression
        # itself, so clients cannot request brotli/zstd that httpx cannot decode.
        "accept-encoding",
    }
)


# ==============================================================================
# Request builders
# ==============================================================================


class RequestBuilder(ABC):
    """Abstract base class for request builders."""

    @abstractmethod
    def build_payload(
        self,
        original_body: Dict[str, Any],
        *,
        mapped_model: Optional[str] = None,
        is_stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the request body."""
        pass

    @abstractmethod
    def build_headers(
        self,
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """Build the request headers."""
        pass

    def build(
        self,
        original_body: Dict[str, Any],
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        mapped_model: Optional[str] = None,
        is_stream: bool = False,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """
        Build the complete request (body + headers).

        Returns:
            Tuple[payload, headers]
        """
        payload = self.build_payload(
            original_body,
            mapped_model=mapped_model,
            is_stream=is_stream,
        )
        headers = self.build_headers(
            original_headers,
            endpoint,
            key,
            extra_headers=extra_headers,
        )
        return payload, headers


class PassthroughRequestBuilder(RequestBuilder):
    """
    Passthrough-mode request builder.

    For CLI-style scenarios, keep the request as close to the original as possible:
    - Request body: copied as-is (model/stream adjustments are left to the handlers)
    - Request headers: strip sensitive headers (blacklist) and pass everything else through
    """

    def build_payload(
        self,
        original_body: Dict[str, Any],
        *,
        mapped_model: Optional[str] = None,  # noqa: ARG002 - handled by apply_mapped_model
        is_stream: bool = False,  # noqa: ARG002 - keep the client's original value
    ) -> Dict[str, Any]:
        """
        Pass the request body through - copy it verbatim, without modification.

        In passthrough mode:
        - model: handled by each handler's apply_mapped_model method
        - stream: the client's original value is kept (different APIs handle it differently)
        """
        return dict(original_body)

    def build_headers(
        self,
        original_headers: Dict[str, str],
        endpoint: Any,
        key: Any,
        *,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, str]:
        """
        Pass the request headers through - strip sensitive headers (blacklist) and
        forward all the others.
        """
        from src.core.api_format_metadata import get_auth_config, resolve_api_format

        headers: Dict[str, str] = {}

        # 1. Set the auth header automatically based on the API format
        decrypted_key = crypto_service.decrypt(key.api_key)
        api_format = getattr(endpoint, "api_format", None)
        resolved_format = resolve_api_format(api_format)
        auth_header, auth_type = (
            get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
        )

        if auth_type == "bearer":
            headers[auth_header] = f"Bearer {decrypted_key}"
        else:
            headers[auth_header] = decrypted_key

        # 2. Add extra headers configured on the endpoint
        if endpoint.headers:
            headers.update(endpoint.headers)

        # 3. Pass the original headers through (excluding sensitive ones - blacklist mode)
        if original_headers:
            for name, value in original_headers.items():
                lower_name = name.lower()

                # Skip sensitive headers
                if lower_name in SENSITIVE_HEADERS:
                    continue

                headers[name] = value

        # 4. Add extra headers
        if extra_headers:
            headers.update(extra_headers)

        # 5. Make sure there is a Content-Type
        if "Content-Type" not in headers and "content-type" not in headers:
            headers["Content-Type"] = "application/json"

        return headers


# ==============================================================================
# Convenience functions
# ==============================================================================


def build_passthrough_request(
    original_body: Dict[str, Any],
    original_headers: Dict[str, str],
    endpoint: Any,
    key: Any,
) -> Tuple[Dict[str, Any], Dict[str, str]]:
    """
    Build a passthrough-mode request.

    Pure passthrough: copy the request body verbatim and only process the headers
    (authentication, etc.). Model mapping and stream handling are left to the
    caller, because different API formats handle them differently.
    """
    builder = PassthroughRequestBuilder()
    return builder.build(
        original_body,
        original_headers,
        endpoint,
        key,
    )
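
# Illustrative usage (editor's sketch, not part of the original module):
# `endpoint` and `key` stand in for the ORM objects used elsewhere in the
# project; any objects exposing the attributes accessed above (api_format,
# headers, api_key) would work the same way:
#
#     payload, headers = build_passthrough_request(
#         original_body={"model": "claude-sonnet", "messages": []},
#         original_headers={"user-agent": "claude-cli/1.0", "x-api-key": "client-key"},
#         endpoint=endpoint,
#         key=key,
#     )
#     # The client's "x-api-key" is dropped (it is in SENSITIVE_HEADERS) and the
#     # auth header is rebuilt from the provider credential decrypted from `key`.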
174
src/api/handlers/base/response_parser.py
Normal file
@@ -0,0 +1,174 @@
"""
Response parser base class - defines the unified response-parsing interface
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class ParsedChunk:
    """A parsed streaming data chunk"""

    # Raw data
    raw_line: str
    event_type: Optional[str] = None
    data: Optional[Dict[str, Any]] = None

    # Extracted content
    text_delta: str = ""
    is_done: bool = False
    is_error: bool = False
    error_message: Optional[str] = None

    # Usage information (usually carried by the last chunk)
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0

    # Response ID
    response_id: Optional[str] = None


@dataclass
class StreamStats:
    """Statistics for a streaming response"""

    # Counters
    chunk_count: int = 0
    data_count: int = 0

    # Token usage
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0

    # Content
    collected_text: str = ""
    response_id: Optional[str] = None

    # Status
    has_completion: bool = False
    status_code: int = 200
    error_message: Optional[str] = None

    # Provider information
    provider_name: Optional[str] = None
    endpoint_id: Optional[str] = None
    key_id: Optional[str] = None

    # Response headers and the full response
    response_headers: Dict[str, str] = field(default_factory=dict)
    final_response: Optional[Dict[str, Any]] = None


@dataclass
class ParsedResponse:
    """A parsed non-streaming response"""

    # Raw response
    raw_response: Dict[str, Any]
    status_code: int

    # Extracted content
    text_content: str = ""
    response_id: Optional[str] = None

    # Usage
    input_tokens: int = 0
    output_tokens: int = 0
    cache_creation_tokens: int = 0
    cache_read_tokens: int = 0

    # Error information
    is_error: bool = False
    error_type: Optional[str] = None
    error_message: Optional[str] = None


class ResponseParser(ABC):
    """
    Response parser base class.

    Defines a unified interface for parsing responses in different API formats.
    Subclasses implement the concrete parsing logic.
    """

    # Parser name (used for logging)
    name: str = "base"

    # Supported API format
    api_format: str = "UNKNOWN"

    @abstractmethod
    def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
        """
        Parse a single SSE line.

        Args:
            line: The SSE line
            stats: The stream statistics object (updated in place)

        Returns:
            The parsed chunk, or None if the line carries no usable data
        """
        pass

    @abstractmethod
    def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
        """
        Parse a non-streaming response.

        Args:
            response: The response JSON
            status_code: The HTTP status code

        Returns:
            The parsed response object
        """
        pass

    @abstractmethod
    def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
        """
        Extract token usage from a response.

        Args:
            response: The response JSON

        Returns:
            A dict with input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens
        """
        pass

    @abstractmethod
    def extract_text_content(self, response: Dict[str, Any]) -> str:
        """
        Extract the text content from a response.

        Args:
            response: The response JSON

        Returns:
            The extracted text content
        """
        pass

    def is_error_response(self, response: Dict[str, Any]) -> bool:
        """
        Determine whether the response is an error response.

        Args:
            response: The response JSON

        Returns:
            Whether the response is an error response
        """
        return "error" in response

    def create_stats(self) -> StreamStats:
        """Create a fresh stream statistics object."""
        return StreamStats()
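
# Illustrative driver loop (editor's sketch, not part of the original module),
# assuming `sse_lines` is an async iterator over the upstream SSE lines:
#
#     parser = get_parser_for_format("CLAUDE")  # factory from parsers.py
#     stats = parser.create_stats()
#     async for line in sse_lines:
#         chunk = parser.parse_sse_line(line, stats)
#         if chunk is None:
#             continue
#         if chunk.is_error:
#             break  # surface chunk.error_message to the caller
#         if chunk.is_done:
#             break
#     # stats now carries collected_text and token usage for telemetry.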