feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。

This commit is contained in:
hoping
2025-12-25 00:02:56 +08:00
parent 9dad194130
commit 26b4a37323
23 changed files with 2282 additions and 119 deletions

View File

@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base"
mode = ApiMode.STANDARD
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建端点URL子类可以覆盖以自定义URL构建逻辑"""
# 默认实现在base_url后添加特定路径
return base_url
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建基础请求头,子类可以覆盖以自定义认证头"""
# 默认实现Bearer token认证
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回不应被extra_headers覆盖的头部key子类可以覆盖"""
# 默认保护认证相关头部
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
# 默认实现:直接使用请求数据
return request_data.copy()
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数(现在强制记录)
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API Key ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 使用子类配置方法构建请求组件
url = cls.build_endpoint_url(base_url)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用通用的endpoint checker执行请求
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=model_name or request_data.get("model"),
)
# =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例

View File

@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
通用的CLI endpoint测试方法使用配置方法模式
- build_endpoint_url(): 构建请求URL
- build_base_headers(): 构建基础认证头
- get_protected_header_keys(): 获取受保护的头部key
- build_request_body(): 构建请求体
- get_cli_user_agent(): 获取CLI User-Agent子类可覆盖
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API密钥ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 构建请求组件
url = cls.build_endpoint_url(base_url, request_data, model_name)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
# 添加CLI User-Agent
cli_user_agent = cls.get_cli_user_agent()
if cli_user_agent:
base_headers["User-Agent"] = cli_user_agent
protected_keys = tuple(list(protected_keys) + ["user-agent"])
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 获取有效的模型名称
effective_model_name = model_name or request_data.get("model")
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
# =========================================================================
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
# =========================================================================
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""
构建CLI API端点URL - 子类应覆盖
Args:
base_url: API基础URL
request_data: 请求数据
model_name: 模型名称某些API需要如Gemini
Returns:
完整的端点URL
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""
构建CLI API认证头 - 子类应覆盖
Args:
api_key: API密钥
Returns:
基础认证头部字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""
返回CLI API的保护头部key - 子类应覆盖
Returns:
保护头部key的元组
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
构建CLI API请求体 - 子类应覆盖
Args:
request_data: 请求数据
Returns:
请求体字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""
获取CLI User-Agent - 子类可覆盖
Returns:
CLI User-Agent字符串如果不需要则为None
"""
return None
# =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例

File diff suppressed because it is too large Load Diff