mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
refactor(cache): optimize cache service architecture and provider transport
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
- 根据 API 格式或端点配置生成请求 URL
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
|
||||
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
def build_provider_headers(
|
||||
endpoint,
|
||||
key,
|
||||
endpoint: "ProviderEndpoint",
|
||||
key: "ProviderAPIKey",
|
||||
original_headers: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
@@ -28,7 +31,8 @@ def build_provider_headers(
|
||||
"""
|
||||
headers: Dict[str, str] = {}
|
||||
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
# api_key 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
|
||||
|
||||
# 根据 API 格式自动选择认证头
|
||||
api_format = getattr(endpoint, "api_format", None)
|
||||
@@ -68,8 +72,32 @@ def build_provider_headers(
|
||||
return headers
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str, path: str) -> str:
|
||||
"""
|
||||
规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。
|
||||
|
||||
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
|
||||
避免拼接出 /v1/v1/messages 这样的重复路径。
|
||||
|
||||
兼容用户填写的各种格式:
|
||||
- https://api.example.com
|
||||
- https://api.example.com/
|
||||
- https://api.example.com/v1
|
||||
- https://api.example.com/v1/
|
||||
"""
|
||||
base = base_url.rstrip("/")
|
||||
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
|
||||
# 例如:base="/v1", path="/v1/messages" -> 去除 /v1
|
||||
# 例如:base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
|
||||
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
|
||||
if base.endswith(suffix) and path.startswith(suffix):
|
||||
base = base[: -len(suffix)]
|
||||
break
|
||||
return base
|
||||
|
||||
|
||||
def build_provider_url(
|
||||
endpoint,
|
||||
endpoint: "ProviderEndpoint",
|
||||
*,
|
||||
query_params: Optional[Dict[str, Any]] = None,
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
@@ -88,8 +116,6 @@ def build_provider_url(
|
||||
path_params: 路径模板参数 (如 {model})
|
||||
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
|
||||
"""
|
||||
base = endpoint.base_url.rstrip("/")
|
||||
|
||||
# 准备路径参数,添加 Gemini API 所需的 action 参数
|
||||
effective_path_params = dict(path_params) if path_params else {}
|
||||
|
||||
@@ -123,6 +149,9 @@ def build_provider_url(
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
# 先确定 path,再根据 path 规范化 base_url
|
||||
# base_url 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
|
||||
url = f"{base}{path}"
|
||||
|
||||
# 添加查询参数
|
||||
@@ -134,7 +163,7 @@ def build_provider_url(
|
||||
return url
|
||||
|
||||
|
||||
def _resolve_default_path(api_format) -> str:
|
||||
def _resolve_default_path(api_format: Optional[str]) -> str:
|
||||
"""
|
||||
根据 API 格式返回默认路径
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user