feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。

This commit is contained in:
hoping
2025-12-25 00:02:56 +08:00
parent 9dad194130
commit 26b4a37323
23 changed files with 2282 additions and 119 deletions

View File

@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
api_key_id: Optional[str] = None
class TestModelRequest(BaseModel):
"""模型测试请求"""
provider_id: str
model_name: str
api_key_id: Optional[str] = None
stream: bool = False
message: Optional[str] = "你好"
api_format: Optional[str] = None # 指定使用的API格式如果不指定则使用端点的默认格式
# ============ API Endpoints ============
@@ -206,3 +217,228 @@ async def query_available_models(
"display_name": provider.display_name,
},
}
@router.post("/test-model")
async def test_model(
request: TestModelRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
"""
# 获取提供商及其端点
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key
endpoint_config = None
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
continue
for key in ep.api_keys:
if key.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
if not endpoint or not api_key:
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
}
try:
# 获取对应的 Adapter 类
adapter_class = _get_adapter_for_format(endpoint.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {endpoint.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
# 如果请求指定了 api_format优先使用它
target_api_format = request.api_format or endpoint.api_format
if request.api_format and request.api_format != endpoint.api_format:
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
# 重新获取适配器
adapter_class = _get_adapter_for_format(request.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {request.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
# 准备测试请求数据
check_request = {
"model": request.model_name,
"messages": [
{"role": "user", "content": request.message or "Hello! This is a test message."}
],
"max_tokens": 30,
"temperature": 0.7,
}
# 发送测试请求
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
# 非流式测试
logger.debug(f"[test-model] 开始非流式测试...")
response = await adapter_class.check_endpoint(
client,
endpoint_config["base_url"],
endpoint_config["api_key"],
check_request,
endpoint_config.get("extra_headers"),
# 用量计算参数(现在强制记录)
db=db,
user=current_user,
provider_name=provider.name,
provider_id=provider.id,
api_key_id=endpoint_config.get("api_key_id"),
model_name=request.model_name,
)
# 记录提供商返回信息
logger.debug(f"[test-model] 非流式测试结果:")
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
response_data = response.get('response', {})
response_body = response_data.get('response_body', {})
logger.debug(f"[test-model] Response Data: {response_data}")
logger.debug(f"[test-model] Response Body: {response_body}")
# 尝试解析 response_body (通常是 JSON 字符串)
parsed_body = response_body
import json
if isinstance(response_body, str):
try:
parsed_body = json.loads(response_body)
except json.JSONDecodeError:
pass
if isinstance(parsed_body, dict) and 'error' in parsed_body:
error_obj = parsed_body['error']
# 兼容 error 可能是字典或字符串的情况
if isinstance(error_obj, dict):
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
raise HTTPException(status_code=500, detail=error_obj.get('message'))
else:
logger.debug(f"[test-model] Error: {error_obj}")
raise HTTPException(status_code=500, detail=error_obj)
elif 'error' in response:
logger.debug(f"[test-model] Error: {response['error']}")
raise HTTPException(status_code=500, detail=response['error'])
else:
# 如果有选择或消息,记录内容预览
if isinstance(response_data, dict):
if 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice:
content = choice['message'].get('content', '')
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
elif 'content' in response_data and response_data['content']:
content = str(response_data['content'])
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
# 检查测试是否成功基于HTTP状态码
status_code = response.get('status_code', 0)
is_success = status_code == 200 and 'error' not in response
return {
"success": is_success,
"data": {
"stream": False,
"response": response,
},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
},
}
except Exception as e:
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
return {
"success": False,
"error": str(e),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
} if endpoint else None,
}