mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 12:38:31 +08:00
Initial commit
This commit is contained in:
272
src/core/api_format_metadata.py
Normal file
272
src/core/api_format_metadata.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
集中维护 API 格式的元数据,避免新增格式时到处修改常量。
|
||||
|
||||
此模块与 src/formats/ 的 FormatProtocol 系统配合使用:
|
||||
- api_format_metadata: 定义格式的元数据(别名、默认路径)
|
||||
- src/formats/: 定义格式的协议实现(解析、转换、验证)
|
||||
|
||||
使用方式:
|
||||
# 解析格式别名
|
||||
from src.core.api_format_metadata import resolve_api_format
|
||||
api_format = resolve_api_format("claude") # -> APIFormat.CLAUDE
|
||||
|
||||
# 获取格式协议
|
||||
from src.core.api_format_metadata import get_format_protocol
|
||||
protocol = get_format_protocol(APIFormat.CLAUDE) # -> ClaudeProtocol
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
|
||||
|
||||
from .enums import APIFormat
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ApiFormatDefinition:
|
||||
"""
|
||||
描述一个 API 格式的所有通用信息。
|
||||
|
||||
- aliases: 用于 detect_api_format 的 provider 别名或快捷名称
|
||||
- default_path: 上游默认请求路径(如 /v1/messages),可通过 Endpoint.custom_path 覆盖
|
||||
- path_prefix: 本站路径前缀(如 /claude, /openai),为空表示无前缀
|
||||
- auth_header: 认证头名称 (如 "x-api-key", "x-goog-api-key")
|
||||
- auth_type: 认证类型 ("header" 直接放值, "bearer" 加 Bearer 前缀)
|
||||
"""
|
||||
|
||||
api_format: APIFormat
|
||||
aliases: Sequence[str] = field(default_factory=tuple)
|
||||
default_path: str = "/" # 上游默认请求路径
|
||||
path_prefix: str = "" # 本站路径前缀,为空表示无前缀
|
||||
auth_header: str = "Authorization"
|
||||
auth_type: str = "bearer" # "bearer" or "header"
|
||||
|
||||
def iter_aliases(self) -> Iterable[str]:
|
||||
"""返回大小写统一后的别名集合,包含枚举名本身。"""
|
||||
yield normalize_alias_value(self.api_format.value)
|
||||
for alias in self.aliases:
|
||||
normalized = normalize_alias_value(alias)
|
||||
if normalized:
|
||||
yield normalized
|
||||
|
||||
|
||||
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
APIFormat.CLAUDE: ApiFormatDefinition(
|
||||
api_format=APIFormat.CLAUDE,
|
||||
aliases=("claude", "anthropic", "claude_compatible"),
|
||||
default_path="/v1/messages",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/claude"
|
||||
auth_header="x-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
APIFormat.CLAUDE_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.CLAUDE_CLI,
|
||||
aliases=("claude_cli", "claude-cli"),
|
||||
default_path="/v1/messages",
|
||||
path_prefix="", # 与 CLAUDE 共享入口,通过 header 区分
|
||||
auth_header="authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.OPENAI: ApiFormatDefinition(
|
||||
api_format=APIFormat.OPENAI,
|
||||
aliases=(
|
||||
"openai",
|
||||
"deepseek",
|
||||
"grok",
|
||||
"moonshot",
|
||||
"zhipu",
|
||||
"qwen",
|
||||
"baichuan",
|
||||
"minimax",
|
||||
"openai_compatible",
|
||||
),
|
||||
default_path="/v1/chat/completions",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/openai"
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.OPENAI_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.OPENAI_CLI,
|
||||
aliases=("openai_cli", "responses"),
|
||||
default_path="/responses",
|
||||
path_prefix="",
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
APIFormat.GEMINI: ApiFormatDefinition(
|
||||
api_format=APIFormat.GEMINI,
|
||||
aliases=("gemini", "google", "vertex"),
|
||||
default_path="/v1beta/models/{model}:{action}",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/gemini"
|
||||
auth_header="x-goog-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
APIFormat.GEMINI_CLI: ApiFormatDefinition(
|
||||
api_format=APIFormat.GEMINI_CLI,
|
||||
aliases=("gemini_cli", "gemini-cli"),
|
||||
default_path="/v1beta/models/{model}:{action}",
|
||||
path_prefix="", # 与 GEMINI 共享入口
|
||||
auth_header="x-goog-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
}
|
||||
|
||||
# 对外只暴露只读视图,避免被随意修改
|
||||
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
|
||||
|
||||
|
||||
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
|
||||
"""获取指定格式的定义,不存在时抛出 KeyError。"""
|
||||
return API_FORMAT_DEFINITIONS[api_format]
|
||||
|
||||
|
||||
def list_api_format_definitions() -> List[ApiFormatDefinition]:
|
||||
"""返回所有定义的浅拷贝列表,供遍历使用。"""
|
||||
return list(API_FORMAT_DEFINITIONS.values())
|
||||
|
||||
|
||||
def build_alias_lookup() -> Dict[str, APIFormat]:
|
||||
"""
|
||||
构建 alias -> APIFormat 的查找表。
|
||||
每次调用都会返回新的 dict,避免可变全局引发并发问题。
|
||||
"""
|
||||
lookup: MutableMapping[str, APIFormat] = {}
|
||||
for definition in API_FORMAT_DEFINITIONS.values():
|
||||
for alias in definition.iter_aliases():
|
||||
lookup.setdefault(alias, definition.api_format)
|
||||
return dict(lookup)
|
||||
|
||||
|
||||
def get_default_path(api_format: APIFormat) -> str:
|
||||
"""
|
||||
获取该格式的上游默认请求路径。
|
||||
|
||||
可通过 Endpoint.custom_path 覆盖。
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
return definition.default_path if definition else "/"
|
||||
|
||||
|
||||
def get_local_path(api_format: APIFormat) -> str:
|
||||
"""
|
||||
获取该格式的本站入口路径。
|
||||
|
||||
本站入口路径 = path_prefix + default_path
|
||||
例如:path_prefix="/openai" + default_path="/v1/chat/completions" -> "/openai/v1/chat/completions"
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
if definition:
|
||||
prefix = definition.path_prefix or ""
|
||||
return prefix + definition.default_path
|
||||
return "/"
|
||||
|
||||
|
||||
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
|
||||
"""
|
||||
获取该格式的认证配置。
|
||||
|
||||
Returns:
|
||||
(auth_header, auth_type) 元组
|
||||
- auth_header: 认证头名称
|
||||
- auth_type: "bearer" 或 "header"
|
||||
"""
|
||||
definition = API_FORMAT_DEFINITIONS.get(api_format)
|
||||
if definition:
|
||||
return definition.auth_header, definition.auth_type
|
||||
return "Authorization", "bearer"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _alias_lookup_cache() -> Dict[str, APIFormat]:
|
||||
"""缓存 alias -> APIFormat 查找表,减少重复构建。"""
|
||||
return build_alias_lookup()
|
||||
|
||||
|
||||
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
|
||||
"""根据别名查找 APIFormat,找不到时返回 None。"""
|
||||
if not value:
|
||||
return None
|
||||
normalized = normalize_alias_value(value)
|
||||
if not normalized:
|
||||
return None
|
||||
return _alias_lookup_cache().get(normalized)
|
||||
|
||||
|
||||
def resolve_api_format(
|
||||
value: Union[str, APIFormat, None],
|
||||
default: Optional[APIFormat] = None,
|
||||
) -> Optional[APIFormat]:
|
||||
"""
|
||||
将任意字符串/枚举值解析为 APIFormat。
|
||||
|
||||
Args:
|
||||
value: 可以是 APIFormat 或任意字符串/别名
|
||||
default: 未解析成功时返回的默认值
|
||||
"""
|
||||
if isinstance(value, APIFormat):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
return default
|
||||
upper = stripped.upper()
|
||||
if upper in APIFormat.__members__:
|
||||
return APIFormat[upper]
|
||||
alias = resolve_api_format_alias(stripped)
|
||||
if alias:
|
||||
return alias
|
||||
return default
|
||||
|
||||
|
||||
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
|
||||
"""
|
||||
注册或覆盖 API 格式定义,允许运行时扩展。
|
||||
|
||||
Args:
|
||||
definition: 要注册的定义
|
||||
override: 若目标枚举已存在,是否允许覆盖
|
||||
"""
|
||||
existing = _DEFINITIONS.get(definition.api_format)
|
||||
if existing and not override:
|
||||
raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")
|
||||
_DEFINITIONS[definition.api_format] = definition
|
||||
_refresh_metadata_cache()
|
||||
|
||||
|
||||
def _refresh_metadata_cache():
|
||||
"""更新别名缓存,供注册函数调用。"""
|
||||
_alias_lookup_cache.cache_clear()
|
||||
|
||||
|
||||
def normalize_alias_value(value: str) -> str:
|
||||
"""统一别名格式:去空白、转小写,并将非字母数字转为单个下划线。"""
|
||||
if value is None:
|
||||
return ""
|
||||
text = value.strip().lower()
|
||||
# 将所有非字母数字字符替换为下划线,并折叠连续的下划线
|
||||
text = re.sub(r"[^a-z0-9]+", "_", text)
|
||||
return text.strip("_")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 格式判断工具
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def is_cli_api_format(api_format: APIFormat) -> bool:
|
||||
"""
|
||||
判断是否为 CLI 透传格式。
|
||||
|
||||
Args:
|
||||
api_format: APIFormat 枚举值
|
||||
|
||||
Returns:
|
||||
True 如果是 CLI 格式
|
||||
"""
|
||||
from src.api.handlers.base.parsers import is_cli_format
|
||||
|
||||
return is_cli_format(api_format.value)
|
||||
Reference in New Issue
Block a user