mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-06 17:52:29 +08:00
176 lines
5.6 KiB
Python
176 lines
5.6 KiB
Python
"""
|
||
统一的 Provider 请求构建工具。
|
||
|
||
负责:
|
||
- 根据 endpoint/key 构建标准请求头
|
||
- 根据 API 格式或端点配置生成请求 URL
|
||
"""
|
||
|
||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||
from urllib.parse import urlencode
|
||
|
||
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
|
||
from src.core.crypto import crypto_service
|
||
from src.core.enums import APIFormat
|
||
from src.core.logger import logger
|
||
|
||
if TYPE_CHECKING:
|
||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||
|
||
|
||
|
||
def build_provider_headers(
|
||
endpoint: "ProviderEndpoint",
|
||
key: "ProviderAPIKey",
|
||
original_headers: Optional[Dict[str, str]] = None,
|
||
*,
|
||
extra_headers: Optional[Dict[str, str]] = None,
|
||
) -> Dict[str, str]:
|
||
"""
|
||
根据 endpoint/key 构建请求头,并透传客户端自定义头。
|
||
"""
|
||
headers: Dict[str, str] = {}
|
||
|
||
# api_key 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
|
||
|
||
# 根据 API 格式自动选择认证头
|
||
api_format = getattr(endpoint, "api_format", None)
|
||
resolved_format = resolve_api_format(api_format)
|
||
auth_header, auth_type = (
|
||
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
|
||
)
|
||
|
||
if auth_type == "bearer":
|
||
headers[auth_header] = f"Bearer {decrypted_key}"
|
||
else:
|
||
headers[auth_header] = decrypted_key
|
||
|
||
if endpoint.headers:
|
||
headers.update(endpoint.headers)
|
||
|
||
excluded_headers = {
|
||
"host",
|
||
"authorization",
|
||
"x-api-key",
|
||
"x-goog-api-key",
|
||
"content-length",
|
||
"transfer-encoding",
|
||
}
|
||
|
||
if original_headers:
|
||
for name, value in original_headers.items():
|
||
if name.lower() not in excluded_headers:
|
||
headers[name] = value
|
||
|
||
if extra_headers:
|
||
headers.update(extra_headers)
|
||
|
||
if "Content-Type" not in headers and "content-type" not in headers:
|
||
headers["Content-Type"] = "application/json"
|
||
|
||
return headers
|
||
|
||
|
||
def _normalize_base_url(base_url: str, path: str) -> str:
|
||
"""
|
||
规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。
|
||
|
||
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
|
||
避免拼接出 /v1/v1/messages 这样的重复路径。
|
||
|
||
兼容用户填写的各种格式:
|
||
- https://api.example.com
|
||
- https://api.example.com/
|
||
- https://api.example.com/v1
|
||
- https://api.example.com/v1/
|
||
"""
|
||
base = base_url.rstrip("/")
|
||
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
|
||
# 例如:base="/v1", path="/v1/messages" -> 去除 /v1
|
||
# 例如:base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
|
||
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
|
||
if base.endswith(suffix) and path.startswith(suffix):
|
||
base = base[: -len(suffix)]
|
||
break
|
||
return base
|
||
|
||
|
||
def build_provider_url(
|
||
endpoint: "ProviderEndpoint",
|
||
*,
|
||
query_params: Optional[Dict[str, Any]] = None,
|
||
path_params: Optional[Dict[str, Any]] = None,
|
||
is_stream: bool = False,
|
||
) -> str:
|
||
"""
|
||
根据 endpoint 配置生成请求 URL
|
||
|
||
优先级:
|
||
1. endpoint.custom_path - 自定义路径(支持模板变量如 {model})
|
||
2. API 格式默认路径 - 根据 api_format 自动选择
|
||
|
||
Args:
|
||
endpoint: 端点配置
|
||
query_params: 查询参数
|
||
path_params: 路径模板参数 (如 {model})
|
||
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
|
||
"""
|
||
# 准备路径参数,添加 Gemini API 所需的 action 参数
|
||
effective_path_params = dict(path_params) if path_params else {}
|
||
|
||
# 为 Gemini API 格式自动添加 action 参数
|
||
resolved_format = resolve_api_format(endpoint.api_format)
|
||
if resolved_format in (APIFormat.GEMINI, APIFormat.GEMINI_CLI):
|
||
if "action" not in effective_path_params:
|
||
effective_path_params["action"] = (
|
||
"streamGenerateContent" if is_stream else "generateContent"
|
||
)
|
||
|
||
# 优先使用 custom_path 字段
|
||
if endpoint.custom_path:
|
||
path = endpoint.custom_path
|
||
if effective_path_params:
|
||
try:
|
||
path = path.format(**effective_path_params)
|
||
except KeyError:
|
||
# 如果模板变量不匹配,保持原路径
|
||
pass
|
||
else:
|
||
# 使用 API 格式的默认路径
|
||
path = _resolve_default_path(endpoint.api_format)
|
||
if effective_path_params:
|
||
try:
|
||
path = path.format(**effective_path_params)
|
||
except KeyError:
|
||
# 如果模板变量不匹配,保持原路径
|
||
pass
|
||
|
||
if not path.startswith("/"):
|
||
path = f"/{path}"
|
||
|
||
# 先确定 path,再根据 path 规范化 base_url
|
||
# base_url 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
|
||
url = f"{base}{path}"
|
||
|
||
# 添加查询参数
|
||
if query_params:
|
||
query_string = urlencode(query_params, doseq=True)
|
||
if query_string:
|
||
url = f"{url}?{query_string}"
|
||
|
||
return url
|
||
|
||
|
||
def _resolve_default_path(api_format: Optional[str]) -> str:
|
||
"""
|
||
根据 API 格式返回默认路径
|
||
"""
|
||
resolved = resolve_api_format(api_format)
|
||
if resolved:
|
||
return get_default_path(resolved)
|
||
|
||
logger.warning(f"Unknown api_format '{api_format}' for endpoint, fallback to '/'")
|
||
return "/"
|