Files
Aether/src/api/handlers/base/cli_adapter_base.py

823 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
CLI Adapter 通用基类
提供 CLI 格式(直接透传请求)的通用适配器逻辑:
- 请求解析和验证
- 审计日志记录
- 错误处理和响应格式化
- Handler 创建和调用
- 计费策略(支持不同 API 格式的差异化计费)
子类只需提供:
- FORMAT_ID: API 格式标识
- HANDLER_CLASS: 对应的 MessageHandler 类
- 可选覆盖 _extract_message_count() 自定义消息计数逻辑
- 可选覆盖 compute_total_input_context() 自定义总输入上下文计算
"""
import time
import traceback
from typing import Any, Dict, Optional, Tuple, Type
import httpx
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from src.api.base.adapter import ApiAdapter, ApiMode
from src.api.base.context import ApiRequestContext
from src.api.handlers.base.cli_handler_base import CliMessageHandlerBase
from src.core.exceptions import (
InvalidRequestException,
ModelNotSupportedException,
ProviderAuthException,
ProviderNotAvailableException,
ProviderRateLimitException,
ProviderTimeoutException,
ProxyException,
QuotaExceededException,
UpstreamClientException,
)
from src.core.logger import logger
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
class CliAdapterBase(ApiAdapter):
    """
    Common base class for CLI adapters.

    Provides the shared adapter logic for CLI (pass-through) formats.
    Subclasses only need to configure:
    - FORMAT_ID: API format identifier
    - HANDLER_CLASS: the MessageHandler class that processes requests
    - name: adapter name
    Optional overrides: ``_extract_message_count`` (message counting) and the
    billing hooks (``compute_total_input_context``, ``get_cache_read_price_for_ttl``,
    ``compute_cost``).
    """

    # Must be overridden by subclasses.
    FORMAT_ID: str = "UNKNOWN"
    HANDLER_CLASS: Type[CliMessageHandlerBase]

    # Adapter configuration.
    name: str = "cli.base"
    mode = ApiMode.PROXY

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        # A missing (or empty) list falls back to this adapter's own format.
        self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]

    async def handle(self, context: ApiRequestContext):
        """
        Handle a CLI API request.

        Flow: parse the JSON body (merging URL path params into it), record
        audit metadata, log the request, then dispatch to the handler's
        streaming or synchronous path. Known client/provider exceptions are
        turned into JSON error responses; any other exception is recorded as a
        failure and answered with HTTP 500.

        Args:
            context: Per-request API context.

        Returns:
            The result of ``process_stream`` / ``process_sync``, or a
            ``JSONResponse`` describing the error.

        Raises:
            HTTPException: Re-raised unchanged, including the 499 raised when
                the client disconnected before processing started.
        """
        http_request = context.request
        user = context.user
        api_key = context.api_key
        db = context.db
        request_id = context.request_id
        quota_remaining_value = context.quota_remaining
        start_time = context.start_time
        client_ip = context.client_ip
        user_agent = context.user_agent
        original_headers = context.original_headers
        query_params = context.query_params
        path_params = context.path_params

        original_request_body = context.ensure_json_body()
        # Merge URL path params into the body (e.g. Gemini carries the model in the URL path).
        if path_params:
            original_request_body = self._merge_path_params(original_request_body, path_params)

        stream = self._resolve_stream(original_request_body, path_params)
        model = self._resolve_model(original_request_body, path_params)

        context.add_audit_metadata(
            **self._build_audit_metadata(original_request_body, path_params)
        )

        quota_display = (
            "unlimited" if quota_remaining_value is None else f"${quota_remaining_value:.2f}"
        )
        logger.info(
            f"[REQ] {request_id[:8]} | {self.FORMAT_ID} | {getattr(api_key, 'name', 'unknown')} | "
            f"{model} | {'stream' if stream else 'sync'} | quota:{quota_display}"
        )

        # Everything the failure handlers need to record a failed request.
        failure_kwargs: Dict[str, Any] = {
            "db": db,
            "user": user,
            "api_key": api_key,
            "model": model,
            "stream": stream,
            "start_time": start_time,
            "original_headers": original_headers,
            "original_request_body": original_request_body,
            "client_ip": client_ip,
            "request_id": request_id,
        }

        try:
            if await http_request.is_disconnected():
                logger.warning("客户端连接断开")
                raise HTTPException(status_code=499, detail="Client disconnected")

            handler = self.HANDLER_CLASS(
                db=db,
                user=user,
                api_key=api_key,
                request_id=request_id,
                client_ip=client_ip,
                user_agent=user_agent,
                start_time=start_time,
                allowed_api_formats=self.allowed_api_formats,
                adapter_detector=self.detect_capability_requirements,
            )
            process = handler.process_stream if stream else handler.process_sync
            return await process(
                original_request_body=original_request_body,
                original_headers=original_headers,
                query_params=query_params,
                path_params=path_params,
            )
        except HTTPException:
            raise
        except (
            ModelNotSupportedException,
            QuotaExceededException,
            InvalidRequestException,
        ) as e:
            logger.debug(f"客户端请求错误: {e.error_type}")
            # Classify by exception type. Keying on ``status_code == 400``
            # labelled any non-400 client error (possibly a
            # ModelNotSupportedException) as "quota_exceeded".
            error_type = (
                "quota_exceeded"
                if isinstance(e, QuotaExceededException)
                else "invalid_request_error"
            )
            return self._error_response(
                status_code=e.status_code,
                error_type=error_type,
                message=e.message,
            )
        except (
            ProviderAuthException,
            ProviderRateLimitException,
            ProviderNotAvailableException,
            ProviderTimeoutException,
            UpstreamClientException,
        ) as e:
            return await self._handle_provider_exception(e, **failure_kwargs)
        except Exception as e:
            return await self._handle_unexpected_exception(e, **failure_kwargs)

    @staticmethod
    def _resolve_stream(
        payload: Dict[str, Any], path_params: Optional[Dict[str, Any]]
    ) -> bool:
        """Return the stream flag: body first, then path params (e.g. Gemini's endpoint), else False."""
        stream = payload.get("stream")
        if stream is None and path_params:
            stream = path_params.get("stream", False)
        return bool(stream)

    @staticmethod
    def _resolve_model(
        payload: Dict[str, Any], path_params: Optional[Dict[str, Any]]
    ) -> str:
        """Return the model name: body first, then path params, falling back to "unknown"."""
        model = payload.get("model")
        if model is None and path_params:
            model = path_params.get("model", "unknown")
        return model or "unknown"

    def _merge_path_params(
        self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Merge URL path params into the request body. Subclasses may override.

        Default: copy the body and add each path param whose key is not
        already present; existing body fields are never overwritten.

        Args:
            original_request_body: Original request body.
            path_params: URL path params.

        Returns:
            A new merged dict (the input body is not modified).
        """
        merged = original_request_body.copy()
        for key, value in path_params.items():
            merged.setdefault(key, value)
        return merged

    def _extract_message_count(self, payload: Dict[str, Any]) -> int:
        """
        Count the messages in the request. Subclasses may override.

        Default: read the ``input`` field. A list counts its items; a dict
        counts its ``messages`` entry (a missing or ``None`` value counts as 0).
        Anything else counts as 0.
        """
        input_data = payload.get("input")
        if isinstance(input_data, list):
            return len(input_data)
        if isinstance(input_data, dict):
            # ``or []`` guards against an explicit ``"messages": null``.
            return len(input_data.get("messages") or [])
        return 0

    def _build_audit_metadata(
        self,
        payload: Dict[str, Any],
        path_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Build audit-log metadata for the request. Subclasses may override.

        Model and stream are resolved the same way ``handle`` resolves them, so
        the audit record matches how the request was actually dispatched.

        Args:
            payload: Request body.
            path_params: URL path params (used for model/stream fallback).
        """
        return {
            "action": f"{self.FORMAT_ID.lower()}_request",
            "model": self._resolve_model(payload, path_params),
            "stream": self._resolve_stream(payload, path_params),
            "max_tokens": payload.get("max_tokens"),
            "messages_count": self._extract_message_count(payload),
            "temperature": payload.get("temperature"),
            "top_p": payload.get("top_p"),
            "tool_count": len(payload.get("tools") or []),
            "instructions_present": bool(payload.get("instructions")),
        }

    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Milliseconds elapsed since ``start_time`` (a ``time.time()`` timestamp)."""
        return int((time.time() - start_time) * 1000)

    def _build_failure_result(
        self,
        e: Exception,
        *,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
    ) -> RequestResult:
        """Create the failure ``RequestResult`` for ``e``, tagged with this adapter's FORMAT_ID."""
        # api_format comes from FORMAT_ID so it always has a value.
        result = RequestResult.from_exception(
            exception=e,
            api_format=self.FORMAT_ID,
            model=model,
            response_time_ms=self._elapsed_ms(start_time),
            is_stream=stream,
        )
        result.request_headers = original_headers
        result.request_body = original_request_body
        return result

    @staticmethod
    async def _record_failure(
        result: RequestResult,
        *,
        db,
        user,
        api_key,
        client_ip: str,
        request_id: str,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
    ) -> None:
        """Persist a failed request through ``UsageRecorder``."""
        recorder = UsageRecorder(
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
        )
        await recorder.record_failure(result, original_headers, original_request_body)

    async def _handle_provider_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """
        Handle provider-related exceptions.

        Records the failure, then returns an error response whose type is:
        ``invalid_request_error`` for upstream client errors,
        ``rate_limit_exceeded`` for rate-limit errors (or status 429), and
        ``internal_server_error`` for everything else.
        """
        logger.debug(f"Caught provider exception: {type(e).__name__}")

        result = self._build_failure_result(
            e,
            model=model,
            stream=stream,
            start_time=start_time,
            original_headers=original_headers,
            original_request_body=original_request_body,
        )

        if isinstance(e, ProviderAuthException):
            # With no resolved provider, the auth failure means none was usable.
            result.error_message = (
                f"提供商认证失败: {str(e)}"
                if result.metadata.provider != "unknown"
                else "服务端错误: 无可用提供商"
            )

        if isinstance(e, UpstreamClientException):
            # Upstream rejected the request itself (e.g. image processing
            # failed): surface its status code and message to the client.
            result.status_code = e.status_code
            result.error_message = e.message

        await self._record_failure(
            result,
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
            original_headers=original_headers,
            original_request_body=original_request_body,
        )

        # Classify by exception type. Previously every provider error other than
        # 503 (timeouts, auth failures, ...) was reported as "rate_limit_exceeded".
        if isinstance(e, UpstreamClientException):
            error_type = "invalid_request_error"
        elif isinstance(e, ProviderRateLimitException) or result.status_code == 429:
            error_type = "rate_limit_exceeded"
        else:
            error_type = "internal_server_error"

        return self._error_response(
            status_code=result.status_code,
            error_type=error_type,
            message=result.error_message or str(e),
        )

    async def _handle_unexpected_exception(
        self,
        e: Exception,
        *,
        db,
        user,
        api_key,
        model: str,
        stream: bool,
        start_time: float,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        client_ip: str,
        request_id: str,
    ) -> JSONResponse:
        """
        Handle exceptions not covered by the specific handlers.

        Logs the error (with a traceback preview for non-``ProxyException``
        errors), records the failure with status 500 / ``internal_error``, and
        returns a generic 500 response that does not leak the exception text.
        """
        if isinstance(e, ProxyException):
            logger.error(f"{self.FORMAT_ID} 请求处理业务异常: {type(e).__name__}")
        else:
            logger.error(
                f"{self.FORMAT_ID} 请求处理意外异常",
                exception=e,
                extra_data={
                    "exception_class": e.__class__.__name__,
                    "processing_stage": "request_processing",
                    "model": model,
                    "stream": stream,
                    # Format from the exception itself so the preview does not
                    # depend on being called inside an active ``except`` block.
                    "traceback_preview": "".join(
                        traceback.format_exception(type(e), e, e.__traceback__)
                    )[:500],
                },
            )

        result = self._build_failure_result(
            e,
            model=model,
            stream=stream,
            start_time=start_time,
            original_headers=original_headers,
            original_request_body=original_request_body,
        )
        # Unexpected exceptions are always reported as 500.
        result.status_code = 500
        result.error_type = "internal_error"

        await self._record_failure(
            result,
            db=db,
            user=user,
            api_key=api_key,
            client_ip=client_ip,
            request_id=request_id,
            original_headers=original_headers,
            original_request_body=original_request_body,
        )

        return self._error_response(
            status_code=500,
            error_type="internal_server_error",
            message="处理请求时发生内部错误",
        )

    def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
        """Build a JSON error response of the form ``{"error": {"type", "message"}}``."""
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "type": error_type,
                    "message": message,
                }
            },
        )

    # =========================================================================
    # Billing hooks - subclasses may override to bill API formats differently
    # =========================================================================

    def compute_total_input_context(
        self,
        input_tokens: int,
        cache_read_input_tokens: int,
        cache_creation_input_tokens: int = 0,
    ) -> int:
        """
        Compute the total input context used to select a pricing tier.

        Default: ``input_tokens + cache_read_input_tokens``
        (``cache_creation_input_tokens`` is accepted for subclasses but unused here).

        Args:
            input_tokens: Input token count.
            cache_read_input_tokens: Cache-read token count.
            cache_creation_input_tokens: Cache-creation token count.

        Returns:
            Total input context token count.
        """
        return input_tokens + cache_read_input_tokens

    def get_cache_read_price_for_ttl(
        self,
        tier: dict,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Optional[float]:
        """
        Return the cache-read price (per 1M tokens) for the given cache TTL.

        If the tier has ``cache_ttl_pricing``, the first entry (in list order)
        whose ``ttl_minutes`` is >= ``cache_ttl_minutes`` and has a price wins.
        Otherwise (no TTL given, TTL beyond every entry, or matched entry
        without a price) the last entry's price is used. Without
        ``cache_ttl_pricing`` the tier's ``cache_read_price_per_1m`` is returned.

        Args:
            tier: Current tier configuration.
            cache_ttl_minutes: Cache lifetime in minutes.

        Returns:
            Cache-read price per 1M tokens, or None if not configured.
        """
        ttl_pricing = tier.get("cache_ttl_pricing")
        if not ttl_pricing:
            return tier.get("cache_read_price_per_1m")

        if cache_ttl_minutes is not None:
            for ttl_config in ttl_pricing:
                if cache_ttl_minutes <= ttl_config.get("ttl_minutes", 0):
                    matched_price = ttl_config.get("cache_read_price_per_1m")
                    if matched_price is not None:
                        return matched_price
                    break

        return ttl_pricing[-1].get("cache_read_price_per_1m")

    @staticmethod
    def _tier_price(tier: dict, key: str, default: Optional[float]) -> Optional[float]:
        """Return ``tier[key]``, falling back to ``default`` when missing or None."""
        value = tier.get(key)
        return default if value is None else value

    @staticmethod
    def _index_of_tier(tiers: list, tier: dict) -> int:
        """Index of ``tier`` in ``tiers``, by identity first (duplicate-safe), then equality."""
        for index, candidate in enumerate(tiers):
            if candidate is tier:
                return index
        return tiers.index(tier)

    def compute_cost(
        self,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int,
        cache_read_input_tokens: int,
        input_price_per_1m: float,
        output_price_per_1m: float,
        cache_creation_price_per_1m: Optional[float],
        cache_read_price_per_1m: Optional[float],
        price_per_request: Optional[float],
        tiered_pricing: Optional[dict] = None,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Compute the cost of a request.

        Default: flat per-token prices, replaced by the matching tier's prices
        when ``tiered_pricing`` has tiers. A tier price that is missing or None
        falls back to the corresponding base price argument. Subclasses may
        override to implement completely different billing.

        Args:
            input_tokens: Input token count.
            output_tokens: Output token count.
            cache_creation_input_tokens: Cache-creation token count.
            cache_read_input_tokens: Cache-read token count.
            input_price_per_1m: Input price per 1M tokens.
            output_price_per_1m: Output price per 1M tokens.
            cache_creation_price_per_1m: Cache-creation price per 1M tokens.
            cache_read_price_per_1m: Cache-read price per 1M tokens.
            price_per_request: Flat per-request price.
            tiered_pricing: Tiered pricing configuration.
            cache_ttl_minutes: Cache lifetime in minutes.

        Returns:
            Dict with ``input_cost``, ``output_cost``, ``cache_creation_cost``,
            ``cache_read_cost``, ``cache_cost`` (creation + read),
            ``request_cost``, ``total_cost`` and ``tier_index`` (index of the
            applied tier, or None when no tier applied).
        """
        tier_index = None
        effective_input_price = input_price_per_1m
        effective_output_price = output_price_per_1m
        effective_cache_creation_price = cache_creation_price_per_1m
        effective_cache_read_price = cache_read_price_per_1m

        if tiered_pricing and tiered_pricing.get("tiers"):
            total_input_context = self.compute_total_input_context(
                input_tokens, cache_read_input_tokens, cache_creation_input_tokens
            )
            tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
            if tier:
                tier_index = self._index_of_tier(tiered_pricing["tiers"], tier)
                # A None tier price would crash the arithmetic below (input/output)
                # or silently zero the charge (cache creation); inherit the base
                # price instead, as the cache-read price already does.
                effective_input_price = self._tier_price(
                    tier, "input_price_per_1m", input_price_per_1m
                )
                effective_output_price = self._tier_price(
                    tier, "output_price_per_1m", output_price_per_1m
                )
                effective_cache_creation_price = self._tier_price(
                    tier, "cache_creation_price_per_1m", cache_creation_price_per_1m
                )
                ttl_read_price = self.get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
                effective_cache_read_price = (
                    cache_read_price_per_1m if ttl_read_price is None else ttl_read_price
                )

        input_cost = (input_tokens / 1_000_000) * effective_input_price
        output_cost = (output_tokens / 1_000_000) * effective_output_price

        cache_creation_cost = 0.0
        if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
            cache_creation_cost = (
                cache_creation_input_tokens / 1_000_000
            ) * effective_cache_creation_price

        cache_read_cost = 0.0
        if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
            cache_read_cost = (
                cache_read_input_tokens / 1_000_000
            ) * effective_cache_read_price

        cache_cost = cache_creation_cost + cache_read_cost
        request_cost = price_per_request or 0.0
        total_cost = input_cost + output_cost + cache_cost + request_cost

        return {
            "input_cost": input_cost,
            "output_cost": output_cost,
            "cache_creation_cost": cache_creation_cost,
            "cache_read_cost": cache_read_cost,
            "cache_cost": cache_cost,
            "request_cost": request_cost,
            "total_cost": total_cost,
            "tier_index": tier_index,
        }

    @staticmethod
    def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
        """
        Pick the pricing tier for ``total_input_tokens``.

        Tiers are scanned in list order; the first whose ``up_to`` is None
        (unbounded) or >= ``total_input_tokens`` wins. If none match, the last
        tier is used. Returns None when there are no tiers.
        """
        tiers = (tiered_pricing or {}).get("tiers")
        if not tiers:
            return None
        for tier in tiers:
            up_to = tier.get("up_to")
            if up_to is None or total_input_tokens <= up_to:
                return tier
        return tiers[-1]

    # =========================================================================
    # Model listing - subclasses should override
    # =========================================================================

    @classmethod
    async def fetch_models(
        cls,
        client: httpx.AsyncClient,
        base_url: str,
        api_key: str,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[list, Optional[str]]:
        """
        List the models supported by the upstream API.

        This is a request issued by Aether itself (not a user pass-through),
        used for admin-side model discovery.

        Args:
            client: httpx async client.
            base_url: API base URL.
            api_key: API key (already decrypted).
            extra_headers: Extra headers configured on the endpoint.

        Returns:
            ``(models, error)``: each model has at least an ``id`` field;
            ``error`` is None on success. The default implementation returns
            an empty list and an error string.
        """
        return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"

    @classmethod
    async def check_endpoint(
        cls,
        client: httpx.AsyncClient,
        base_url: str,
        api_key: str,
        request_data: Dict[str, Any],
        extra_headers: Optional[Dict[str, str]] = None,
        # Usage-accounting parameters
        db: Optional[Any] = None,
        user: Optional[Any] = None,
        provider_name: Optional[str] = None,
        provider_id: Optional[str] = None,
        api_key_id: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Test model connectivity (non-streaming).

        Generic CLI endpoint check assembled from the configuration hooks:
        ``build_endpoint_url``, ``build_base_headers``,
        ``get_protected_header_keys``, ``build_request_body`` and
        ``get_cli_user_agent``.

        Args:
            client: httpx async client.
            base_url: API base URL.
            api_key: API key (already decrypted).
            request_data: Request data.
            extra_headers: Extra headers configured on the endpoint.
            db: Database session.
            user: User object.
            provider_name: Provider name.
            provider_id: Provider ID.
            api_key_id: API key ID.
            model_name: Model name (falls back to ``request_data["model"]``).

        Returns:
            The result of ``run_endpoint_check``.
        """
        from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check

        url = cls.build_endpoint_url(base_url, request_data, model_name)
        # Copy so adding the User-Agent never mutates a dict the subclass may reuse.
        base_headers = dict(cls.build_base_headers(api_key))
        protected_keys = cls.get_protected_header_keys()

        cli_user_agent = cls.get_cli_user_agent()
        if cli_user_agent:
            base_headers["User-Agent"] = cli_user_agent
            # Presumably protected so endpoint extra_headers cannot override it.
            protected_keys = (*protected_keys, "user-agent")

        headers = build_safe_headers(base_headers, extra_headers, protected_keys)
        body = cls.build_request_body(request_data)
        effective_model_name = model_name or request_data.get("model")

        return await run_endpoint_check(
            client=client,
            url=url,
            headers=headers,
            json_body=body,
            # NOTE(review): the failure paths in this class report api_format as
            # FORMAT_ID; confirm run_endpoint_check really expects the adapter name.
            api_format=cls.name,
            db=db,
            user=user,
            provider_name=provider_name,
            provider_id=provider_id,
            api_key_id=api_key_id,
            model_name=effective_model_name,
        )

    # =========================================================================
    # CLI adapter configuration hooks - override these, not check_endpoint
    # =========================================================================

    @classmethod
    def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
        """
        Build the CLI API endpoint URL. Subclasses must override.

        Args:
            base_url: API base URL.
            request_data: Request data.
            model_name: Model name (needed by some APIs, e.g. Gemini).

        Returns:
            The full endpoint URL.
        """
        raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")

    @classmethod
    def build_base_headers(cls, api_key: str) -> Dict[str, str]:
        """
        Build the CLI API authentication headers. Subclasses must override.

        Args:
            api_key: API key.

        Returns:
            Base header dict.
        """
        raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")

    @classmethod
    def get_protected_header_keys(cls) -> tuple:
        """
        Return the header keys that extra headers must not override. Subclasses must override.

        Returns:
            Tuple of protected header keys.
        """
        raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")

    @classmethod
    def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build the CLI API request body. Subclasses must override.

        Args:
            request_data: Request data.

        Returns:
            Request body dict.
        """
        raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")

    @classmethod
    def get_cli_user_agent(cls) -> Optional[str]:
        """
        Return the CLI User-Agent. Subclasses may override.

        Returns:
            The User-Agent string, or None when no override is needed.
        """
        return None
# =========================================================================
# CLI adapter registry - looks up a CLI adapter class by its API format
# =========================================================================
# Maps upper-cased FORMAT_ID -> adapter class; filled by register_cli_adapter.
_CLI_ADAPTER_REGISTRY: Dict[str, Type["CliAdapterBase"]] = {}
# Becomes True after the built-in adapter modules have been imported once.
_CLI_ADAPTERS_LOADED: bool = False
def register_cli_adapter(adapter_class: Type["CliAdapterBase"]) -> Type["CliAdapterBase"]:
    """
    Register a CLI adapter class in the module registry; usable as a decorator.

    The class is stored under its upper-cased ``FORMAT_ID``. Classes whose
    ``FORMAT_ID`` is empty or still the ``"UNKNOWN"`` placeholder are skipped.
    The class itself is always returned unchanged.

    Usage::

        @register_cli_adapter
        class ClaudeCliAdapter(CliAdapterBase):
            FORMAT_ID = "CLAUDE_CLI"
            ...
    """
    fmt = adapter_class.FORMAT_ID
    is_placeholder = not fmt or fmt == "UNKNOWN"
    if not is_placeholder:
        _CLI_ADAPTER_REGISTRY[fmt.upper()] = adapter_class
    return adapter_class
def _ensure_cli_adapters_loaded():
    """
    Import the built-in CLI adapter modules once, so their
    ``@register_cli_adapter`` decorators populate the registry.

    Best effort: a module that fails to import is skipped and logged instead of
    aborting the lookup. The loaded flag is set even when some imports fail, so
    the imports are only attempted on the first call.
    """
    global _CLI_ADAPTERS_LOADED
    if _CLI_ADAPTERS_LOADED:
        return

    # Static imports (rather than importlib) keep these modules visible to tooling.
    # Failures were previously swallowed silently, which made a broken adapter
    # module look like a merely unsupported format.
    try:
        from src.api.handlers.claude_cli import adapter as _  # noqa: F401
    except ImportError as exc:
        logger.warning(f"Failed to import CLI adapter module claude_cli: {exc}")
    try:
        from src.api.handlers.openai_cli import adapter as _  # noqa: F401
    except ImportError as exc:
        logger.warning(f"Failed to import CLI adapter module openai_cli: {exc}")
    try:
        from src.api.handlers.gemini_cli import adapter as _  # noqa: F401
    except ImportError as exc:
        logger.warning(f"Failed to import CLI adapter module gemini_cli: {exc}")

    _CLI_ADAPTERS_LOADED = True
def get_cli_adapter_class(api_format: str) -> Optional[Type["CliAdapterBase"]]:
    """
    Return the CLI adapter class registered for ``api_format``.

    The lookup is case-insensitive (keys are upper-cased). Returns None for an
    empty format or an unregistered one. Built-in adapters are loaded first.
    """
    _ensure_cli_adapters_loaded()
    if not api_format:
        return None
    return _CLI_ADAPTER_REGISTRY.get(api_format.upper())
def get_cli_adapter_instance(api_format: str) -> Optional["CliAdapterBase"]:
    """Instantiate (with no arguments) the CLI adapter for ``api_format``, or return None if none is registered."""
    adapter_class = get_cli_adapter_class(api_format)
    return adapter_class() if adapter_class else None
def list_registered_cli_formats() -> list[str]:
    """Return every registered CLI API format (upper-cased FORMAT_ID), in registry insertion order."""
    _ensure_cli_adapters_loaded()
    return [*_CLI_ADAPTER_REGISTRY]