mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。
This commit is contained in:
@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
|
||||
# 默认实现返回空列表,子类应覆盖
|
||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||
|
||||
@classmethod
|
||||
async def check_endpoint(
|
||||
cls,
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
request_data: Dict[str, Any],
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
# 用量计算参数
|
||||
db: Optional[Any] = None,
|
||||
user: Optional[Any] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
api_key_id: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试模型连接性(非流式)
|
||||
|
||||
通用的CLI endpoint测试方法,使用配置方法模式:
|
||||
- build_endpoint_url(): 构建请求URL
|
||||
- build_base_headers(): 构建基础认证头
|
||||
- get_protected_header_keys(): 获取受保护的头部key
|
||||
- build_request_body(): 构建请求体
|
||||
- get_cli_user_agent(): 获取CLI User-Agent(子类可覆盖)
|
||||
|
||||
Args:
|
||||
client: httpx 异步客户端
|
||||
base_url: API 基础 URL
|
||||
api_key: API 密钥(已解密)
|
||||
request_data: 请求数据
|
||||
extra_headers: 端点配置的额外请求头
|
||||
db: 数据库会话
|
||||
user: 用户对象
|
||||
provider_name: 提供商名称
|
||||
provider_id: 提供商ID
|
||||
api_key_id: API密钥ID
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
测试响应数据
|
||||
"""
|
||||
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||
|
||||
# 构建请求组件
|
||||
url = cls.build_endpoint_url(base_url, request_data, model_name)
|
||||
base_headers = cls.build_base_headers(api_key)
|
||||
protected_keys = cls.get_protected_header_keys()
|
||||
|
||||
# 添加CLI User-Agent
|
||||
cli_user_agent = cls.get_cli_user_agent()
|
||||
if cli_user_agent:
|
||||
base_headers["User-Agent"] = cli_user_agent
|
||||
protected_keys = tuple(list(protected_keys) + ["user-agent"])
|
||||
|
||||
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||
body = cls.build_request_body(request_data)
|
||||
|
||||
# 获取有效的模型名称
|
||||
effective_model_name = model_name or request_data.get("model")
|
||||
|
||||
return await run_endpoint_check(
|
||||
client=client,
|
||||
url=url,
|
||||
headers=headers,
|
||||
json_body=body,
|
||||
api_format=cls.name,
|
||||
# 用量计算参数(现在强制记录)
|
||||
db=db,
|
||||
user=user,
|
||||
provider_name=provider_name,
|
||||
provider_id=provider_id,
|
||||
api_key_id=api_key_id,
|
||||
model_name=effective_model_name,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
|
||||
# =========================================================================
|
||||
|
||||
@classmethod
|
||||
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||
"""
|
||||
构建CLI API端点URL - 子类应覆盖
|
||||
|
||||
Args:
|
||||
base_url: API基础URL
|
||||
request_data: 请求数据
|
||||
model_name: 模型名称(某些API需要,如Gemini)
|
||||
|
||||
Returns:
|
||||
完整的端点URL
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
|
||||
|
||||
@classmethod
|
||||
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||
"""
|
||||
构建CLI API认证头 - 子类应覆盖
|
||||
|
||||
Args:
|
||||
api_key: API密钥
|
||||
|
||||
Returns:
|
||||
基础认证头部字典
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
|
||||
|
||||
@classmethod
|
||||
def get_protected_header_keys(cls) -> tuple:
|
||||
"""
|
||||
返回CLI API的保护头部key - 子类应覆盖
|
||||
|
||||
Returns:
|
||||
保护头部key的元组
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
|
||||
|
||||
@classmethod
|
||||
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
构建CLI API请求体 - 子类应覆盖
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
|
||||
Returns:
|
||||
请求体字典
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
|
||||
|
||||
@classmethod
|
||||
def get_cli_user_agent(cls) -> Optional[str]:
|
||||
"""
|
||||
获取CLI User-Agent - 子类可覆盖
|
||||
|
||||
Returns:
|
||||
CLI User-Agent字符串,如果不需要则为None
|
||||
"""
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||
|
||||
Reference in New Issue
Block a user