Files
Aether/src/services/provider/transport.py

176 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一的 Provider 请求构建工具。
负责:
- 根据 endpoint/key 构建标准请求头
- 根据 API 格式或端点配置生成请求 URL
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from urllib.parse import urlencode
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
from src.core.crypto import crypto_service
from src.core.enums import APIFormat
from src.core.logger import logger
if TYPE_CHECKING:
from src.models.database import ProviderAPIKey, ProviderEndpoint
def build_provider_headers(
endpoint: "ProviderEndpoint",
key: "ProviderAPIKey",
original_headers: Optional[Dict[str, str]] = None,
*,
extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""
根据 endpoint/key 构建请求头,并透传客户端自定义头。
"""
headers: Dict[str, str] = {}
# api_key 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
# 根据 API 格式自动选择认证头
api_format = getattr(endpoint, "api_format", None)
resolved_format = resolve_api_format(api_format)
auth_header, auth_type = (
get_auth_config(resolved_format) if resolved_format else ("Authorization", "bearer")
)
if auth_type == "bearer":
headers[auth_header] = f"Bearer {decrypted_key}"
else:
headers[auth_header] = decrypted_key
if endpoint.headers:
headers.update(endpoint.headers)
excluded_headers = {
"host",
"authorization",
"x-api-key",
"x-goog-api-key",
"content-length",
"transfer-encoding",
}
if original_headers:
for name, value in original_headers.items():
if name.lower() not in excluded_headers:
headers[name] = value
if extra_headers:
headers.update(extra_headers)
if "Content-Type" not in headers and "content-type" not in headers:
headers["Content-Type"] = "application/json"
return headers
def _normalize_base_url(base_url: str, path: str) -> str:
"""
规范化 base_url去除末尾的斜杠和可能与 path 重复的版本前缀。
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
避免拼接出 /v1/v1/messages 这样的重复路径。
兼容用户填写的各种格式:
- https://api.example.com
- https://api.example.com/
- https://api.example.com/v1
- https://api.example.com/v1/
"""
base = base_url.rstrip("/")
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
# 例如base="/v1", path="/v1/messages" -> 去除 /v1
# 例如base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
if base.endswith(suffix) and path.startswith(suffix):
base = base[: -len(suffix)]
break
return base
def build_provider_url(
endpoint: "ProviderEndpoint",
*,
query_params: Optional[Dict[str, Any]] = None,
path_params: Optional[Dict[str, Any]] = None,
is_stream: bool = False,
) -> str:
"""
根据 endpoint 配置生成请求 URL
优先级:
1. endpoint.custom_path - 自定义路径(支持模板变量如 {model}
2. API 格式默认路径 - 根据 api_format 自动选择
Args:
endpoint: 端点配置
query_params: 查询参数
path_params: 路径模板参数 (如 {model})
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
"""
# 准备路径参数,添加 Gemini API 所需的 action 参数
effective_path_params = dict(path_params) if path_params else {}
# 为 Gemini API 格式自动添加 action 参数
resolved_format = resolve_api_format(endpoint.api_format)
if resolved_format in (APIFormat.GEMINI, APIFormat.GEMINI_CLI):
if "action" not in effective_path_params:
effective_path_params["action"] = (
"streamGenerateContent" if is_stream else "generateContent"
)
# 优先使用 custom_path 字段
if endpoint.custom_path:
path = endpoint.custom_path
if effective_path_params:
try:
path = path.format(**effective_path_params)
except KeyError:
# 如果模板变量不匹配,保持原路径
pass
else:
# 使用 API 格式的默认路径
path = _resolve_default_path(endpoint.api_format)
if effective_path_params:
try:
path = path.format(**effective_path_params)
except KeyError:
# 如果模板变量不匹配,保持原路径
pass
if not path.startswith("/"):
path = f"/{path}"
# 先确定 path再根据 path 规范化 base_url
# base_url 在数据库中是 NOT NULL类型标注为 Optional 是 SQLAlchemy 限制
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
url = f"{base}{path}"
# 添加查询参数
if query_params:
query_string = urlencode(query_params, doseq=True)
if query_string:
url = f"{url}?{query_string}"
return url
def _resolve_default_path(api_format: Optional[str]) -> str:
"""
根据 API 格式返回默认路径
"""
resolved = resolve_api_format(api_format)
if resolved:
return get_default_path(resolved)
logger.warning(f"Unknown api_format '{api_format}' for endpoint, fallback to '/'")
return "/"