Files
Aether/src/core/api_format_metadata.py
fawney19 beae7a2616 feat(api): add unified Models API endpoint
- Add models_service.py with model query logic and caching
- Add models.py unified endpoint supporting Claude/OpenAI/Gemini formats
- Auto-detect API format based on request headers
- Support /v1/models and /v1beta/models (Gemini) paths
- Update route registration and comments
2025-12-14 20:01:19 +08:00

273 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
集中维护 API 格式的元数据,避免新增格式时到处修改常量。
此模块与 src/formats/ 的 FormatProtocol 系统配合使用:
- api_format_metadata: 定义格式的元数据(别名、默认路径)
- src/formats/: 定义格式的协议实现(解析、转换、验证)
使用方式:
# 解析格式别名
from src.core.api_format_metadata import resolve_api_format
api_format = resolve_api_format("claude") # -> APIFormat.CLAUDE
# 获取格式协议
from src.core.api_format_metadata import get_format_protocol
protocol = get_format_protocol(APIFormat.CLAUDE) # -> ClaudeProtocol
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from functools import lru_cache
from types import MappingProxyType
from typing import Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Union
from .enums import APIFormat
@dataclass(frozen=True)
class ApiFormatDefinition:
"""
描述一个 API 格式的所有通用信息。
- aliases: 用于 detect_api_format 的 provider 别名或快捷名称
- default_path: 上游默认请求路径(如 /v1/messages可通过 Endpoint.custom_path 覆盖
- path_prefix: 本站路径前缀(如 /claude, /openai为空表示无前缀
- auth_header: 认证头名称 (如 "x-api-key", "x-goog-api-key")
- auth_type: 认证类型 ("header" 直接放值, "bearer" 加 Bearer 前缀)
"""
api_format: APIFormat
aliases: Sequence[str] = field(default_factory=tuple)
default_path: str = "/" # 上游默认请求路径
path_prefix: str = "" # 本站路径前缀,为空表示无前缀
auth_header: str = "Authorization"
auth_type: str = "bearer" # "bearer" or "header"
def iter_aliases(self) -> Iterable[str]:
"""返回大小写统一后的别名集合,包含枚举名本身。"""
yield normalize_alias_value(self.api_format.value)
for alias in self.aliases:
normalized = normalize_alias_value(alias)
if normalized:
yield normalized
_DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
APIFormat.CLAUDE: ApiFormatDefinition(
api_format=APIFormat.CLAUDE,
aliases=("claude", "anthropic", "claude_compatible"),
default_path="/v1/messages",
path_prefix="", # 通过请求头区分格式,不使用路径前缀
auth_header="x-api-key",
auth_type="header",
),
APIFormat.CLAUDE_CLI: ApiFormatDefinition(
api_format=APIFormat.CLAUDE_CLI,
aliases=("claude_cli", "claude-cli"),
default_path="/v1/messages",
path_prefix="", # 与 CLAUDE 共享入口,通过 header 区分
auth_header="authorization",
auth_type="bearer",
),
APIFormat.OPENAI: ApiFormatDefinition(
api_format=APIFormat.OPENAI,
aliases=(
"openai",
"deepseek",
"grok",
"moonshot",
"zhipu",
"qwen",
"baichuan",
"minimax",
"openai_compatible",
),
default_path="/v1/chat/completions",
path_prefix="", # 默认格式
auth_header="Authorization",
auth_type="bearer",
),
APIFormat.OPENAI_CLI: ApiFormatDefinition(
api_format=APIFormat.OPENAI_CLI,
aliases=("openai_cli", "responses"),
default_path="/responses",
path_prefix="", # 与 OPENAI 共享入口
auth_header="Authorization",
auth_type="bearer",
),
APIFormat.GEMINI: ApiFormatDefinition(
api_format=APIFormat.GEMINI,
aliases=("gemini", "google", "vertex"),
default_path="/v1beta/models/{model}:{action}",
path_prefix="", # 通过请求头区分格式
auth_header="x-goog-api-key",
auth_type="header",
),
APIFormat.GEMINI_CLI: ApiFormatDefinition(
api_format=APIFormat.GEMINI_CLI,
aliases=("gemini_cli", "gemini-cli"),
default_path="/v1beta/models/{model}:{action}",
path_prefix="", # 与 GEMINI 共享入口
auth_header="x-goog-api-key",
auth_type="header",
),
}
# 对外只暴露只读视图,避免被随意修改
API_FORMAT_DEFINITIONS: Mapping[APIFormat, ApiFormatDefinition] = MappingProxyType(_DEFINITIONS)
def get_api_format_definition(api_format: APIFormat) -> ApiFormatDefinition:
"""获取指定格式的定义,不存在时抛出 KeyError。"""
return API_FORMAT_DEFINITIONS[api_format]
def list_api_format_definitions() -> List[ApiFormatDefinition]:
"""返回所有定义的浅拷贝列表,供遍历使用。"""
return list(API_FORMAT_DEFINITIONS.values())
def build_alias_lookup() -> Dict[str, APIFormat]:
"""
构建 alias -> APIFormat 的查找表。
每次调用都会返回新的 dict避免可变全局引发并发问题。
"""
lookup: MutableMapping[str, APIFormat] = {}
for definition in API_FORMAT_DEFINITIONS.values():
for alias in definition.iter_aliases():
lookup.setdefault(alias, definition.api_format)
return dict(lookup)
def get_default_path(api_format: APIFormat) -> str:
"""
获取该格式的上游默认请求路径。
可通过 Endpoint.custom_path 覆盖。
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
return definition.default_path if definition else "/"
def get_local_path(api_format: APIFormat) -> str:
"""
获取该格式的本站入口路径。
本站入口路径 = path_prefix + default_path
例如path_prefix="/openai" + default_path="/v1/chat/completions" -> "/openai/v1/chat/completions"
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
if definition:
prefix = definition.path_prefix or ""
return prefix + definition.default_path
return "/"
def get_auth_config(api_format: APIFormat) -> tuple[str, str]:
"""
获取该格式的认证配置。
Returns:
(auth_header, auth_type) 元组
- auth_header: 认证头名称
- auth_type: "bearer""header"
"""
definition = API_FORMAT_DEFINITIONS.get(api_format)
if definition:
return definition.auth_header, definition.auth_type
return "Authorization", "bearer"
@lru_cache(maxsize=1)
def _alias_lookup_cache() -> Dict[str, APIFormat]:
"""缓存 alias -> APIFormat 查找表,减少重复构建。"""
return build_alias_lookup()
def resolve_api_format_alias(value: str) -> Optional[APIFormat]:
"""根据别名查找 APIFormat找不到时返回 None。"""
if not value:
return None
normalized = normalize_alias_value(value)
if not normalized:
return None
return _alias_lookup_cache().get(normalized)
def resolve_api_format(
value: Union[str, APIFormat, None],
default: Optional[APIFormat] = None,
) -> Optional[APIFormat]:
"""
将任意字符串/枚举值解析为 APIFormat。
Args:
value: 可以是 APIFormat 或任意字符串/别名
default: 未解析成功时返回的默认值
"""
if isinstance(value, APIFormat):
return value
if isinstance(value, str):
stripped = value.strip()
if not stripped:
return default
upper = stripped.upper()
if upper in APIFormat.__members__:
return APIFormat[upper]
alias = resolve_api_format_alias(stripped)
if alias:
return alias
return default
def register_api_format_definition(definition: ApiFormatDefinition, *, override: bool = False):
"""
注册或覆盖 API 格式定义,允许运行时扩展。
Args:
definition: 要注册的定义
override: 若目标枚举已存在,是否允许覆盖
"""
existing = _DEFINITIONS.get(definition.api_format)
if existing and not override:
raise ValueError(f"{definition.api_format.value} 已存在,如需覆盖请设置 override=True")
_DEFINITIONS[definition.api_format] = definition
_refresh_metadata_cache()
def _refresh_metadata_cache():
"""更新别名缓存,供注册函数调用。"""
_alias_lookup_cache.cache_clear()
def normalize_alias_value(value: str) -> str:
"""统一别名格式:去空白、转小写,并将非字母数字转为单个下划线。"""
if value is None:
return ""
text = value.strip().lower()
# 将所有非字母数字字符替换为下划线,并折叠连续的下划线
text = re.sub(r"[^a-z0-9]+", "_", text)
return text.strip("_")
# =============================================================================
# 格式判断工具
# =============================================================================
def is_cli_api_format(api_format: APIFormat) -> bool:
"""
判断是否为 CLI 透传格式。
Args:
api_format: APIFormat 枚举值
Returns:
True 如果是 CLI 格式
"""
from src.api.handlers.base.parsers import is_cli_format
return is_cli_format(api_format.value)