Files
Aether/src/api/handlers/gemini/adapter.py

219 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Gemini Chat Adapter
处理 Gemini API 格式的请求适配
"""
from typing import Any, Dict, Optional, Tuple, Type
import httpx
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger
from src.models.gemini import GeminiRequest
@register_adapter
class GeminiChatAdapter(ChatAdapterBase):
    """
    Adapter for the Gemini Chat API.

    Handles requests in the Gemini chat format.
    Endpoint: /v1beta/models/{model}:generateContent
    """

    FORMAT_ID = "GEMINI"
    name = "gemini.chat"

    @property
    def HANDLER_CLASS(self) -> Type[ChatHandlerBase]:
        """Return the handler class, imported lazily to avoid a circular import."""
        from src.api.handlers.gemini.handler import GeminiChatHandler

        return GeminiChatHandler

    def __init__(self, allowed_api_formats: Optional[list[str]] = None):
        """
        Initialise the adapter.

        Args:
            allowed_api_formats: API formats to accept; defaults to ["GEMINI"]
                when omitted or empty.
        """
        super().__init__(allowed_api_formats or ["GEMINI"])
        logger.info(f"[{self.name}] 初始化 Gemini Chat 适配器 | API格式: {self.allowed_api_formats}")

    def extract_api_key(self, request: Request) -> Optional[str]:
        """Return the API key from the ``x-goog-api-key`` header, or None if absent."""
        return request.headers.get("x-goog-api-key")

    def _merge_path_params(
        self, original_request_body: Dict[str, Any], path_params: Dict[str, Any]  # noqa: ARG002
    ) -> Dict[str, Any]:
        """
        Merge URL path parameters into the request body (Gemini-specific variant).

        For Gemini nothing is merged:
        - ``model`` stays out of the body; per the original design notes the
          handler layer reads it from ``path_params`` instead.
        - ``stream`` stays out of the body; the Gemini API distinguishes
          streaming from non-streaming by URL endpoint.

        Args:
            original_request_body: The original request body dict.
            path_params: URL path parameters (unused).

        Returns:
            A shallow copy of the original request body, with no path
            parameters merged in.
        """
        return original_request_body.copy()

    def _validate_request_body(
        self, original_request_body: dict, path_params: Optional[dict] = None
    ):
        """
        Validate the request body and build a ``GeminiRequest``.

        Args:
            original_request_body: The parsed JSON request body.
            path_params: URL path parameters; ``model`` and ``stream`` are read
                from here (defaults: "unknown" and False).

        Returns:
            A ``GeminiRequest`` with ``model`` and ``stream`` set from
            ``path_params``.

        Raises:
            HTTPException: 400 when the body is not a JSON object, lacks
                ``contents``, or validation raises a ``ValueError``.
        """
        path_params = path_params or {}
        is_stream = path_params.get("stream", False)
        model = path_params.get("model", "unknown")

        try:
            if not isinstance(original_request_body, dict):
                raise ValueError("Request body must be a JSON object")
            # Required Gemini field: contents
            if "contents" not in original_request_body:
                raise ValueError("Missing required field: contents")
            request = GeminiRequest.model_validate(
                original_request_body,
                strict=False,
            )
        except ValueError as e:
            # NOTE(review): pydantic's ValidationError derives from ValueError,
            # so model validation failures presumably land here (HTTP 400)
            # rather than in the lenient branch below - confirm this is intended.
            logger.error(f"请求体基本验证失败: {str(e)}")
            raise HTTPException(status_code=400, detail=str(e)) from e
        except Exception as e:
            # Best-effort fallback: keep processing with an unvalidated model.
            logger.warning(f"Pydantic验证警告(将继续处理): {str(e)}")
            request = GeminiRequest.model_construct(
                contents=original_request_body.get("contents", []),
            )

        # model comes from path_params; used for logging and auditing.
        request.model = model
        # stream flag lets ChatAdapterBase decide on streaming mode.
        request.stream = is_stream
        return request

    def _extract_message_count(self, payload: Dict[str, Any], request_obj) -> int:
        """
        Return the number of content entries in the request.

        ``request_obj.contents`` takes precedence over ``payload["contents"]``
        when the attribute exists; a non-list value counts as 0.
        """
        contents = payload.get("contents", [])
        if hasattr(request_obj, "contents"):
            contents = request_obj.contents
        return len(contents) if isinstance(contents, list) else 0

    def _build_audit_metadata(self, payload: Dict[str, Any], request_obj) -> Dict[str, Any]:
        """
        Build the Gemini-chat-specific audit metadata.

        Args:
            payload: The raw request payload (used as a fallback for ``model``).
            request_obj: The validated (or constructed) request object.

        Returns:
            A dict with the action name, model, stream flag, selected
            generation-config values and counts of contents, roles, tools and
            safety settings.
        """
        role_counts: dict[str, int] = {}
        contents = getattr(request_obj, "contents", []) or []
        for content in contents:
            # Content items may be plain dicts (fallback construction) or model
            # objects; a missing or empty role is counted as "unknown" in both
            # cases instead of calling .get() on an object that lacks it.
            if isinstance(content, dict):
                role = content.get("role")
            else:
                role = getattr(content, "role", None)
            role = role or "unknown"
            role_counts[role] = role_counts.get(role, 0) + 1

        generation_config = getattr(request_obj, "generation_config", None) or {}
        if not isinstance(generation_config, dict):
            # Prefer the pydantic v2 dump method; fall back to the legacy .dict().
            dump = getattr(generation_config, "model_dump", None) or getattr(
                generation_config, "dict", None
            )
            generation_config = dump() if callable(dump) else {}

        # Streaming mode flag
        stream = getattr(request_obj, "stream", False)

        return {
            "action": "gemini_generate_content",
            "model": getattr(request_obj, "model", payload.get("model", "unknown")),
            "stream": bool(stream),
            "max_output_tokens": generation_config.get("max_output_tokens"),
            "temperature": generation_config.get("temperature"),
            "top_p": generation_config.get("top_p"),
            "top_k": generation_config.get("top_k"),
            "contents_count": len(contents),
            "content_roles": role_counts,
            "tools_count": len(getattr(request_obj, "tools", None) or []),
            "system_instruction_present": bool(getattr(request_obj, "system_instruction", None)),
            "safety_settings_count": len(getattr(request_obj, "safety_settings", None) or []),
        }

    def _error_response(self, status_code: int, error_type: str, message: str) -> JSONResponse:
        """
        Build an error response in the Gemini error format.

        Args:
            status_code: HTTP status code, also echoed as ``error.code``.
            error_type: Error type; upper-cased into ``error.status``.
            message: Human-readable error message.
        """
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "code": status_code,
                    "message": message,
                    "status": error_type.upper(),
                }
            },
        )

    @classmethod
    async def fetch_models(
        cls,
        client: httpx.AsyncClient,
        base_url: str,
        api_key: str,
        extra_headers: Optional[Dict[str, str]] = None,
    ) -> Tuple[list, Optional[str]]:
        """
        Query the model list from a Gemini API endpoint.

        The API key is sent as the ``key`` query parameter but is kept out of
        the URL string used in log messages, and is redacted from any error
        text that is logged or returned.

        Args:
            client: The HTTP client used for the request.
            base_url: Provider base URL, with or without a trailing ``/v1beta``.
            api_key: API key sent as the ``key`` query parameter.
            extra_headers: Optional additional request headers.

        Returns:
            ``(models, error)``: a list of model dicts (``id``, ``owned_by``,
            ``display_name``, ``api_format``) and None on success, or an empty
            list and an error message on failure.
        """

        def _redact(text: str) -> str:
            # Never let the credential reach logs or returned error messages.
            return text.replace(api_key, "***") if api_key else text

        # Accept a base_url that already includes /v1beta.
        base_url_clean = base_url.rstrip("/")
        if base_url_clean.endswith("/v1beta"):
            models_url = f"{base_url_clean}/models"
        else:
            models_url = f"{base_url_clean}/v1beta/models"

        headers: Dict[str, str] = {}
        if extra_headers:
            headers.update(extra_headers)

        try:
            # Passing the key via params keeps it URL-encoded and out of models_url.
            response = await client.get(models_url, params={"key": api_key}, headers=headers)
            logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
            if response.status_code == 200:
                data = response.json()
                if "models" in data:
                    # Convert to the unified model format.
                    return [
                        {
                            "id": m.get("name", "").replace("models/", ""),
                            "owned_by": "google",
                            "display_name": m.get("displayName", ""),
                            "api_format": cls.FORMAT_ID,
                        }
                        for m in data["models"]
                    ], None
                return [], None
            else:
                error_body = _redact(response.text[:500]) if response.text else "(empty)"
                error_msg = f"HTTP {response.status_code}: {error_body}"
                logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
                return [], error_msg
        except Exception as e:
            # Best-effort: report the failure to the caller instead of raising.
            detail = _redact(str(e))
            error_msg = f"Request error: {detail}"
            logger.warning(f"Failed to fetch Gemini models from {models_url}: {detail}")
            return [], error_msg
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
    """
    Build the Gemini adapter appropriate for the given request header.

    Args:
        x_app_header: Value of the X-App request header (currently unused).

    Returns:
        A GeminiChatAdapter instance.
    """
    # Only one Gemini adapter exists for now. x_app_header is accepted so that
    # a different adapter (e.g. a CLI mode) can be selected here in the future.
    adapter = GeminiChatAdapter()
    return adapter
# Names exported by `from <this module> import *`.
__all__ = ["GeminiChatAdapter", "build_gemini_adapter"]