mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
Merge branch 'dev'
This commit is contained in:
350
src/api/base/models_service.py
Normal file
350
src/api/base/models_service.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""
|
||||||
|
公共模型查询服务
|
||||||
|
|
||||||
|
为 Claude/OpenAI/Gemini 的 /models 端点提供统一的查询逻辑
|
||||||
|
|
||||||
|
查询逻辑:
|
||||||
|
1. 找到指定 api_format 的活跃端点
|
||||||
|
2. 端点下有活跃的 Key
|
||||||
|
3. Provider 关联了该模型(Model 表)
|
||||||
|
4. Key 的 allowed_models 允许该模型(null = 允许所有)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from src.config.constants import CacheTTL
|
||||||
|
from src.core.cache_service import CacheService
|
||||||
|
from src.core.logger import logger
|
||||||
|
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint
|
||||||
|
|
||||||
|
# 缓存 key 前缀
|
||||||
|
_CACHE_KEY_PREFIX = "models:list"
|
||||||
|
_CACHE_TTL = CacheTTL.MODEL # 300 秒
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_key(api_formats: list[str]) -> str:
|
||||||
|
"""生成缓存 key"""
|
||||||
|
formats_str = ",".join(sorted(api_formats))
|
||||||
|
return f"{_CACHE_KEY_PREFIX}:{formats_str}"
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cached_models(api_formats: list[str]) -> Optional[list["ModelInfo"]]:
|
||||||
|
"""从缓存获取模型列表"""
|
||||||
|
cache_key = _get_cache_key(api_formats)
|
||||||
|
try:
|
||||||
|
cached = await CacheService.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"[ModelsService] 缓存命中: {cache_key}, {len(cached)} 个模型")
|
||||||
|
return [ModelInfo(**item) for item in cached]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[ModelsService] 缓存读取失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"]) -> None:
|
||||||
|
"""将模型列表写入缓存"""
|
||||||
|
cache_key = _get_cache_key(api_formats)
|
||||||
|
try:
|
||||||
|
data = [asdict(m) for m in models]
|
||||||
|
await CacheService.set(cache_key, data, ttl_seconds=_CACHE_TTL)
|
||||||
|
logger.debug(f"[ModelsService] 已缓存: {cache_key}, {len(models)} 个模型, TTL={_CACHE_TTL}s")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
|
"""统一的模型信息结构"""
|
||||||
|
|
||||||
|
id: str # 模型 ID (GlobalModel.name 或 provider_model_name)
|
||||||
|
display_name: str
|
||||||
|
description: Optional[str]
|
||||||
|
created_at: Optional[str] # ISO 格式
|
||||||
|
created_timestamp: int # Unix 时间戳
|
||||||
|
provider_name: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
||||||
|
"""
|
||||||
|
返回有可用端点的 Provider IDs
|
||||||
|
|
||||||
|
条件:
|
||||||
|
- 端点 api_format 匹配
|
||||||
|
- 端点是活跃的
|
||||||
|
- 端点下有活跃的 Key
|
||||||
|
"""
|
||||||
|
rows = (
|
||||||
|
db.query(ProviderEndpoint.provider_id)
|
||||||
|
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
||||||
|
.filter(
|
||||||
|
ProviderEndpoint.api_format.in_(api_formats),
|
||||||
|
ProviderEndpoint.is_active.is_(True),
|
||||||
|
ProviderAPIKey.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.distinct()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return {row[0] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
|
||||||
|
"""
|
||||||
|
获取指定格式下真正可用的模型 ID 集合
|
||||||
|
|
||||||
|
一个模型可用需满足:
|
||||||
|
1. 端点 api_format 匹配且活跃
|
||||||
|
2. 端点下有活跃的 Key
|
||||||
|
3. **该端点的 Provider 关联了该模型**
|
||||||
|
4. Key 的 allowed_models 允许该模型(null = 允许该 Provider 关联的所有模型)
|
||||||
|
"""
|
||||||
|
# 查询所有匹配格式的活跃端点及其活跃 Key,同时获取 endpoint_id
|
||||||
|
endpoint_keys = (
|
||||||
|
db.query(
|
||||||
|
ProviderEndpoint.id.label("endpoint_id"),
|
||||||
|
ProviderEndpoint.provider_id,
|
||||||
|
ProviderAPIKey.allowed_models,
|
||||||
|
)
|
||||||
|
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
|
||||||
|
.filter(
|
||||||
|
ProviderEndpoint.api_format.in_(api_formats),
|
||||||
|
ProviderEndpoint.is_active.is_(True),
|
||||||
|
ProviderAPIKey.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not endpoint_keys:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
# 收集每个 (provider_id, endpoint_id) 对应的 allowed_models
|
||||||
|
# 使用 provider_id 作为 key,因为模型是关联到 Provider 的
|
||||||
|
provider_allowed_models: dict[str, list[Optional[list[str]]]] = {}
|
||||||
|
provider_ids_with_format: set[str] = set()
|
||||||
|
|
||||||
|
for endpoint_id, provider_id, allowed_models in endpoint_keys:
|
||||||
|
provider_ids_with_format.add(provider_id)
|
||||||
|
if provider_id not in provider_allowed_models:
|
||||||
|
provider_allowed_models[provider_id] = []
|
||||||
|
provider_allowed_models[provider_id].append(allowed_models)
|
||||||
|
|
||||||
|
# 只查询那些有匹配格式端点的 Provider 下的模型
|
||||||
|
models = (
|
||||||
|
db.query(Model)
|
||||||
|
.options(joinedload(Model.global_model))
|
||||||
|
.join(Provider)
|
||||||
|
.filter(
|
||||||
|
Model.provider_id.in_(provider_ids_with_format),
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
available_model_ids: set[str] = set()
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
model_provider_id = model.provider_id
|
||||||
|
global_model = model.global_model
|
||||||
|
model_id = global_model.name if global_model else model.provider_model_name # type: ignore
|
||||||
|
|
||||||
|
if not model_provider_id or not model_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 该模型的 Provider 必须有匹配格式的端点
|
||||||
|
if model_provider_id not in provider_ids_with_format:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查该 provider 下是否有 Key 允许这个模型
|
||||||
|
allowed_lists = provider_allowed_models.get(model_provider_id, [])
|
||||||
|
for allowed_models in allowed_lists:
|
||||||
|
if allowed_models is None:
|
||||||
|
# null = 允许该 Provider 关联的所有模型(已通过上面的查询限制)
|
||||||
|
available_model_ids.add(model_id)
|
||||||
|
break
|
||||||
|
elif model_id in allowed_models:
|
||||||
|
# 明确在允许列表中
|
||||||
|
available_model_ids.add(model_id)
|
||||||
|
break
|
||||||
|
elif global_model and model.provider_model_name in allowed_models:
|
||||||
|
# 也检查 provider_model_name
|
||||||
|
available_model_ids.add(model_id)
|
||||||
|
break
|
||||||
|
|
||||||
|
return available_model_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_model_info(model: Any) -> ModelInfo:
|
||||||
|
"""从 Model 对象提取 ModelInfo"""
|
||||||
|
global_model = model.global_model
|
||||||
|
model_id: str = global_model.name if global_model else model.provider_model_name
|
||||||
|
display_name: str = global_model.display_name if global_model else model.provider_model_name
|
||||||
|
description: Optional[str] = global_model.description if global_model else None
|
||||||
|
created_at: Optional[str] = (
|
||||||
|
model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
|
||||||
|
)
|
||||||
|
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
|
||||||
|
provider_name: str = model.provider.name if model.provider else "unknown"
|
||||||
|
|
||||||
|
return ModelInfo(
|
||||||
|
id=model_id,
|
||||||
|
display_name=display_name,
|
||||||
|
description=description,
|
||||||
|
created_at=created_at,
|
||||||
|
created_timestamp=created_timestamp,
|
||||||
|
provider_name=provider_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_available_models(
|
||||||
|
db: Session,
|
||||||
|
available_provider_ids: set[str],
|
||||||
|
api_formats: Optional[list[str]] = None,
|
||||||
|
) -> list[ModelInfo]:
|
||||||
|
"""
|
||||||
|
获取可用模型列表(已去重,带缓存)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
available_provider_ids: 有可用端点的 Provider ID 集合
|
||||||
|
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
去重后的 ModelInfo 列表,按创建时间倒序
|
||||||
|
"""
|
||||||
|
if not available_provider_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 尝试从缓存获取
|
||||||
|
if api_formats:
|
||||||
|
cached = await _get_cached_models(api_formats)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# 如果提供了 api_formats,获取真正可用的模型 ID
|
||||||
|
available_model_ids: Optional[set[str]] = None
|
||||||
|
if api_formats:
|
||||||
|
available_model_ids = _get_available_model_ids_for_format(db, api_formats)
|
||||||
|
if not available_model_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query = (
|
||||||
|
db.query(Model)
|
||||||
|
.options(joinedload(Model.global_model), joinedload(Model.provider))
|
||||||
|
.join(Provider)
|
||||||
|
.filter(
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
Model.provider_id.in_(available_provider_ids),
|
||||||
|
)
|
||||||
|
.order_by(Model.created_at.desc())
|
||||||
|
)
|
||||||
|
all_models = query.all()
|
||||||
|
|
||||||
|
result: list[ModelInfo] = []
|
||||||
|
seen_model_ids: set[str] = set()
|
||||||
|
|
||||||
|
for model in all_models:
|
||||||
|
info = _extract_model_info(model)
|
||||||
|
|
||||||
|
# 如果有 available_model_ids 限制,检查是否在其中
|
||||||
|
if available_model_ids is not None and info.id not in available_model_ids:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if info.id in seen_model_ids:
|
||||||
|
continue
|
||||||
|
seen_model_ids.add(info.id)
|
||||||
|
|
||||||
|
result.append(info)
|
||||||
|
|
||||||
|
# 写入缓存
|
||||||
|
if api_formats:
|
||||||
|
await _set_cached_models(api_formats, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def find_model_by_id(
|
||||||
|
db: Session,
|
||||||
|
model_id: str,
|
||||||
|
available_provider_ids: set[str],
|
||||||
|
api_formats: Optional[list[str]] = None,
|
||||||
|
) -> Optional[ModelInfo]:
|
||||||
|
"""
|
||||||
|
按 ID 查找模型
|
||||||
|
|
||||||
|
查找顺序:
|
||||||
|
1. 先按 GlobalModel.name 查找
|
||||||
|
2. 如果没找到任何候选,再按 provider_model_name 查找
|
||||||
|
3. 如果有候选但都不可用,返回 None(不回退)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
model_id: 模型 ID
|
||||||
|
available_provider_ids: 有可用端点的 Provider ID 集合
|
||||||
|
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelInfo 或 None
|
||||||
|
"""
|
||||||
|
if not available_provider_ids:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果提供了 api_formats,获取真正可用的模型 ID
|
||||||
|
available_model_ids: Optional[set[str]] = None
|
||||||
|
if api_formats:
|
||||||
|
available_model_ids = _get_available_model_ids_for_format(db, api_formats)
|
||||||
|
# 快速检查:如果目标模型不在可用列表中,直接返回 None
|
||||||
|
if available_model_ids is not None and model_id not in available_model_ids:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 先按 GlobalModel.name 查找
|
||||||
|
models_by_global = (
|
||||||
|
db.query(Model)
|
||||||
|
.options(joinedload(Model.global_model), joinedload(Model.provider))
|
||||||
|
.join(Provider)
|
||||||
|
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||||
|
.filter(
|
||||||
|
GlobalModel.name == model_id,
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.order_by(Model.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
model = next(
|
||||||
|
(m for m in models_by_global if m.provider_id in available_provider_ids),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果有候选但都不可用,直接返回 None(不回退 provider_model_name)
|
||||||
|
if not model and models_by_global:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 如果找不到任何候选,按 provider_model_name 查找
|
||||||
|
if not model:
|
||||||
|
models_by_provider_name = (
|
||||||
|
db.query(Model)
|
||||||
|
.options(joinedload(Model.global_model), joinedload(Model.provider))
|
||||||
|
.join(Provider)
|
||||||
|
.filter(
|
||||||
|
Model.provider_model_name == model_id,
|
||||||
|
Model.is_active.is_(True),
|
||||||
|
Provider.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.order_by(Model.created_at.desc())
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
model = next(
|
||||||
|
(m for m in models_by_provider_name if m.provider_id in available_provider_ids),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _extract_model_info(model)
|
||||||
@@ -6,10 +6,13 @@ from .capabilities import router as capabilities_router
|
|||||||
from .catalog import router as catalog_router
|
from .catalog import router as catalog_router
|
||||||
from .claude import router as claude_router
|
from .claude import router as claude_router
|
||||||
from .gemini import router as gemini_router
|
from .gemini import router as gemini_router
|
||||||
|
from .models import router as models_router
|
||||||
from .openai import router as openai_router
|
from .openai import router as openai_router
|
||||||
from .system_catalog import router as system_catalog_router
|
from .system_catalog import router as system_catalog_router
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
# Models API 需要在最前面注册,避免被其他路由的 path 参数捕获
|
||||||
|
router.include_router(models_router)
|
||||||
router.include_router(claude_router, tags=["Claude API"])
|
router.include_router(claude_router, tags=["Claude API"])
|
||||||
router.include_router(openai_router)
|
router.include_router(openai_router)
|
||||||
router.include_router(gemini_router, tags=["Gemini API"])
|
router.include_router(gemini_router, tags=["Gemini API"])
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ Claude API 端点
|
|||||||
|
|
||||||
- /v1/messages - Claude Messages API
|
- /v1/messages - Claude Messages API
|
||||||
- /v1/messages/count_tokens - Token Count API
|
- /v1/messages/count_tokens - Token Count API
|
||||||
|
|
||||||
|
注意: /v1/models 端点由 models.py 统一处理,根据请求头返回对应格式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|||||||
@@ -5,11 +5,9 @@ Gemini API 专属端点
|
|||||||
- /v1beta/models/{model}:generateContent
|
- /v1beta/models/{model}:generateContent
|
||||||
- /v1beta/models/{model}:streamGenerateContent
|
- /v1beta/models/{model}:streamGenerateContent
|
||||||
|
|
||||||
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
|
注意:
|
||||||
|
- Gemini API 的 model 在 URL 路径中,而不是请求体中
|
||||||
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
|
- /v1beta/models (列表) 和 /v1beta/models/{model} (详情) 由 models.py 统一处理
|
||||||
- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
|
|
||||||
- default_path: 标准 API 路径模板
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
@@ -109,7 +107,7 @@ async def stream_generate_content(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 兼容 v1 路径(部分 SDK 可能使用)
|
# 兼容 v1 路径(部分 SDK 可能使用 generateContent)
|
||||||
@router.post("/v1/models/{model}:generateContent")
|
@router.post("/v1/models/{model}:generateContent")
|
||||||
async def generate_content_v1(
|
async def generate_content_v1(
|
||||||
model: str,
|
model: str,
|
||||||
|
|||||||
499
src/api/public/models.py
Normal file
499
src/api/public/models.py
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
"""
|
||||||
|
统一的 Models API 端点
|
||||||
|
|
||||||
|
根据请求头认证方式自动返回对应格式:
|
||||||
|
- x-api-key + anthropic-version -> Claude 格式
|
||||||
|
- x-goog-api-key (header) 或 ?key= 参数 -> Gemini 格式
|
||||||
|
- Authorization: Bearer (bearer) -> OpenAI 格式
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from src.api.base.models_service import (
|
||||||
|
ModelInfo,
|
||||||
|
find_model_by_id,
|
||||||
|
get_available_provider_ids,
|
||||||
|
list_available_models,
|
||||||
|
)
|
||||||
|
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS, ApiFormatDefinition
|
||||||
|
from src.core.enums import APIFormat
|
||||||
|
from src.core.logger import logger
|
||||||
|
from src.database import get_db
|
||||||
|
from src.models.database import ApiKey, User
|
||||||
|
from src.services.auth.service import AuthService
|
||||||
|
|
||||||
|
router = APIRouter(tags=["Models API"])
|
||||||
|
|
||||||
|
# 各格式对应的 API 格式列表
|
||||||
|
# 注意: CLI 格式是透传格式,Models API 只返回非 CLI 格式的端点支持的模型
|
||||||
|
_CLAUDE_FORMATS = [APIFormat.CLAUDE.value]
|
||||||
|
_OPENAI_FORMATS = [APIFormat.OPENAI.value]
|
||||||
|
_GEMINI_FORMATS = [APIFormat.GEMINI.value]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_api_key_from_request(
|
||||||
|
request: Request, definition: ApiFormatDefinition
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""根据格式定义从请求中提取 API Key"""
|
||||||
|
auth_header = definition.auth_header.lower()
|
||||||
|
auth_type = definition.auth_type
|
||||||
|
|
||||||
|
header_value = request.headers.get(auth_header)
|
||||||
|
if not header_value:
|
||||||
|
# Gemini 还支持 ?key= 参数
|
||||||
|
if definition.api_format in (APIFormat.GEMINI, APIFormat.GEMINI_CLI):
|
||||||
|
return request.query_params.get("key")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if auth_type == "bearer":
|
||||||
|
# Bearer token: "Bearer xxx"
|
||||||
|
if header_value.lower().startswith("bearer "):
|
||||||
|
return header_value[7:].strip()
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# header 类型: 直接使用值
|
||||||
|
return header_value
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_api_format_and_key(request: Request) -> Tuple[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
根据请求头检测 API 格式并提取 API Key
|
||||||
|
|
||||||
|
检测顺序:
|
||||||
|
1. x-api-key + anthropic-version -> Claude
|
||||||
|
2. x-goog-api-key (header) 或 ?key= -> Gemini
|
||||||
|
3. Authorization: Bearer -> OpenAI (默认)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(api_format, api_key) 元组
|
||||||
|
"""
|
||||||
|
# Claude: x-api-key + anthropic-version (必须同时存在)
|
||||||
|
claude_def = API_FORMAT_DEFINITIONS[APIFormat.CLAUDE]
|
||||||
|
claude_key = _extract_api_key_from_request(request, claude_def)
|
||||||
|
if claude_key and request.headers.get("anthropic-version"):
|
||||||
|
return "claude", claude_key
|
||||||
|
|
||||||
|
# Gemini: x-goog-api-key (header 类型) 或 ?key=
|
||||||
|
gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
|
||||||
|
gemini_key = _extract_api_key_from_request(request, gemini_def)
|
||||||
|
if gemini_key:
|
||||||
|
return "gemini", gemini_key
|
||||||
|
|
||||||
|
# OpenAI: Authorization: Bearer (默认)
|
||||||
|
# 注意: 如果只有 x-api-key 但没有 anthropic-version,也走 OpenAI 格式
|
||||||
|
openai_def = API_FORMAT_DEFINITIONS[APIFormat.OPENAI]
|
||||||
|
openai_key = _extract_api_key_from_request(request, openai_def)
|
||||||
|
# 如果 OpenAI 格式没有 key,但有 x-api-key,也用它(兼容)
|
||||||
|
if not openai_key and claude_key:
|
||||||
|
openai_key = claude_key
|
||||||
|
return "openai", openai_key
|
||||||
|
|
||||||
|
|
||||||
|
def _get_formats_for_api(api_format: str) -> list[str]:
|
||||||
|
"""获取对应 API 格式的端点格式列表"""
|
||||||
|
if api_format == "claude":
|
||||||
|
return _CLAUDE_FORMATS
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return _GEMINI_FORMATS
|
||||||
|
else:
|
||||||
|
return _OPENAI_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
|
||||||
|
"""
|
||||||
|
认证 API Key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(user, api_key_record) 元组,认证失败返回 (None, None)
|
||||||
|
"""
|
||||||
|
if not api_key:
|
||||||
|
logger.debug("[Models] 认证失败: 未提供 API Key")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
result = AuthService.authenticate_api_key(db, api_key)
|
||||||
|
if not result:
|
||||||
|
logger.debug("[Models] 认证失败: API Key 无效")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
user, key_record = result
|
||||||
|
logger.debug(f"[Models] 认证成功: {user.email} (Key: {key_record.name})")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _build_auth_error_response(api_format: str) -> JSONResponse:
|
||||||
|
"""根据 API 格式构建认证错误响应"""
|
||||||
|
if api_format == "claude":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "authentication_error",
|
||||||
|
"message": "Invalid API key provided",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": 401,
|
||||||
|
"message": "API key not valid. Please pass a valid API key.",
|
||||||
|
"status": "UNAUTHENTICATED",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": "Incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": None,
|
||||||
|
"code": "invalid_api_key",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 响应构建函数
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _build_claude_list_response(
|
||||||
|
models: list[ModelInfo],
|
||||||
|
before_id: Optional[str],
|
||||||
|
after_id: Optional[str],
|
||||||
|
limit: int,
|
||||||
|
) -> dict:
|
||||||
|
"""构建 Claude 格式的列表响应"""
|
||||||
|
model_data_list = [
|
||||||
|
{
|
||||||
|
"id": m.id,
|
||||||
|
"type": "model",
|
||||||
|
"display_name": m.display_name,
|
||||||
|
"created_at": m.created_at,
|
||||||
|
}
|
||||||
|
for m in models
|
||||||
|
]
|
||||||
|
|
||||||
|
# 处理分页
|
||||||
|
start_idx = 0
|
||||||
|
if after_id:
|
||||||
|
for i, m in enumerate(model_data_list):
|
||||||
|
if m["id"] == after_id:
|
||||||
|
start_idx = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
end_idx = len(model_data_list)
|
||||||
|
if before_id:
|
||||||
|
for i, m in enumerate(model_data_list):
|
||||||
|
if m["id"] == before_id:
|
||||||
|
end_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
paginated = model_data_list[start_idx:end_idx][:limit]
|
||||||
|
|
||||||
|
first_id = paginated[0]["id"] if paginated else None
|
||||||
|
last_id = paginated[-1]["id"] if paginated else None
|
||||||
|
has_more = len(model_data_list[start_idx:end_idx]) > limit
|
||||||
|
|
||||||
|
return {
|
||||||
|
"data": paginated,
|
||||||
|
"has_more": has_more,
|
||||||
|
"first_id": first_id,
|
||||||
|
"last_id": last_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openai_list_response(models: list[ModelInfo]) -> dict:
|
||||||
|
"""构建 OpenAI 格式的列表响应"""
|
||||||
|
data = [
|
||||||
|
{
|
||||||
|
"id": m.id,
|
||||||
|
"object": "model",
|
||||||
|
"created": m.created_timestamp,
|
||||||
|
"owned_by": m.provider_name,
|
||||||
|
}
|
||||||
|
for m in models
|
||||||
|
]
|
||||||
|
return {"object": "list", "data": data}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gemini_list_response(
|
||||||
|
models: list[ModelInfo],
|
||||||
|
page_size: int,
|
||||||
|
page_token: Optional[str],
|
||||||
|
) -> dict:
|
||||||
|
"""构建 Gemini 格式的列表响应"""
|
||||||
|
# 处理分页
|
||||||
|
start_idx = 0
|
||||||
|
if page_token:
|
||||||
|
try:
|
||||||
|
start_idx = int(page_token)
|
||||||
|
except ValueError:
|
||||||
|
start_idx = 0
|
||||||
|
|
||||||
|
end_idx = start_idx + page_size
|
||||||
|
paginated_models = models[start_idx:end_idx]
|
||||||
|
|
||||||
|
models_data = [
|
||||||
|
{
|
||||||
|
"name": f"models/{m.id}",
|
||||||
|
"baseModelId": m.id,
|
||||||
|
"version": "001",
|
||||||
|
"displayName": m.display_name,
|
||||||
|
"description": m.description or f"Model {m.id}",
|
||||||
|
"inputTokenLimit": 128000,
|
||||||
|
"outputTokenLimit": 8192,
|
||||||
|
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||||
|
"temperature": 1.0,
|
||||||
|
"maxTemperature": 2.0,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
}
|
||||||
|
for m in paginated_models
|
||||||
|
]
|
||||||
|
|
||||||
|
response: dict = {"models": models_data}
|
||||||
|
if end_idx < len(models):
|
||||||
|
response["nextPageToken"] = str(end_idx)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _build_claude_model_response(model_info: ModelInfo) -> dict:
|
||||||
|
"""构建 Claude 格式的模型详情响应"""
|
||||||
|
return {
|
||||||
|
"id": model_info.id,
|
||||||
|
"type": "model",
|
||||||
|
"display_name": model_info.display_name,
|
||||||
|
"created_at": model_info.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openai_model_response(model_info: ModelInfo) -> dict:
|
||||||
|
"""构建 OpenAI 格式的模型详情响应"""
|
||||||
|
return {
|
||||||
|
"id": model_info.id,
|
||||||
|
"object": "model",
|
||||||
|
"created": model_info.created_timestamp,
|
||||||
|
"owned_by": model_info.provider_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gemini_model_response(model_info: ModelInfo) -> dict:
|
||||||
|
"""构建 Gemini 格式的模型详情响应"""
|
||||||
|
return {
|
||||||
|
"name": f"models/{model_info.id}",
|
||||||
|
"baseModelId": model_info.id,
|
||||||
|
"version": "001",
|
||||||
|
"displayName": model_info.display_name,
|
||||||
|
"description": model_info.description or f"Model {model_info.id}",
|
||||||
|
"inputTokenLimit": 128000,
|
||||||
|
"outputTokenLimit": 8192,
|
||||||
|
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||||
|
"temperature": 1.0,
|
||||||
|
"maxTemperature": 2.0,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 404 响应
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def _build_404_response(model_id: str, api_format: str) -> JSONResponse:
|
||||||
|
"""根据 API 格式构建 404 响应"""
|
||||||
|
if api_format == "claude":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"type": "error",
|
||||||
|
"error": {"type": "not_found_error", "message": f"Model '{model_id}' not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"code": 404,
|
||||||
|
"message": f"models/{model_id} is not found",
|
||||||
|
"status": "NOT_FOUND",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content={
|
||||||
|
"error": {
|
||||||
|
"message": f"The model '{model_id}' does not exist",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": "model",
|
||||||
|
"code": "model_not_found",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 路由端点
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/models", response_model=None)
|
||||||
|
async def list_models(
|
||||||
|
request: Request,
|
||||||
|
# Claude 分页参数
|
||||||
|
before_id: Optional[str] = Query(None, description="返回此 ID 之前的结果 (Claude)"),
|
||||||
|
after_id: Optional[str] = Query(None, description="返回此 ID 之后的结果 (Claude)"),
|
||||||
|
limit: int = Query(20, ge=1, le=1000, description="返回数量限制 (Claude)"),
|
||||||
|
# Gemini 分页参数
|
||||||
|
page_size: int = Query(50, alias="pageSize", ge=1, le=1000, description="每页数量 (Gemini)"),
|
||||||
|
page_token: Optional[str] = Query(None, alias="pageToken", description="分页 token (Gemini)"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> Union[dict, JSONResponse]:
|
||||||
|
"""
|
||||||
|
List models - 根据请求头认证方式返回对应格式
|
||||||
|
|
||||||
|
- x-api-key -> Claude 格式
|
||||||
|
- x-goog-api-key 或 ?key= -> Gemini 格式
|
||||||
|
- Authorization: Bearer -> OpenAI 格式
|
||||||
|
"""
|
||||||
|
api_format, api_key = _detect_api_format_and_key(request)
|
||||||
|
logger.info(f"[Models] GET /v1/models | format={api_format}")
|
||||||
|
|
||||||
|
# 认证
|
||||||
|
user, _ = _authenticate(db, api_key)
|
||||||
|
if not user:
|
||||||
|
return _build_auth_error_response(api_format)
|
||||||
|
|
||||||
|
formats = _get_formats_for_api(api_format)
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
|
if not available_provider_ids:
|
||||||
|
if api_format == "claude":
|
||||||
|
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return {"models": []}
|
||||||
|
else:
|
||||||
|
return {"object": "list", "data": []}
|
||||||
|
|
||||||
|
models = await list_available_models(db, available_provider_ids, formats)
|
||||||
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
|
|
||||||
|
if api_format == "claude":
|
||||||
|
return _build_claude_list_response(models, before_id, after_id, limit)
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return _build_gemini_list_response(models, page_size, page_token)
|
||||||
|
else:
|
||||||
|
return _build_openai_list_response(models)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1/models/{model_id:path}", response_model=None)
|
||||||
|
async def retrieve_model(
|
||||||
|
model_id: str,
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> Union[dict, JSONResponse]:
|
||||||
|
"""
|
||||||
|
Retrieve model - 根据请求头认证方式返回对应格式
|
||||||
|
"""
|
||||||
|
api_format, api_key = _detect_api_format_and_key(request)
|
||||||
|
|
||||||
|
# Gemini 格式的 name 带 "models/" 前缀,需要移除
|
||||||
|
if api_format == "gemini" and model_id.startswith("models/"):
|
||||||
|
model_id = model_id[7:]
|
||||||
|
|
||||||
|
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
|
||||||
|
|
||||||
|
# 认证
|
||||||
|
user, _ = _authenticate(db, api_key)
|
||||||
|
if not user:
|
||||||
|
return _build_auth_error_response(api_format)
|
||||||
|
|
||||||
|
formats = _get_formats_for_api(api_format)
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
|
model_info = find_model_by_id(db, model_id, available_provider_ids, formats)
|
||||||
|
|
||||||
|
if not model_info:
|
||||||
|
return _build_404_response(model_id, api_format)
|
||||||
|
|
||||||
|
if api_format == "claude":
|
||||||
|
return _build_claude_model_response(model_info)
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return _build_gemini_model_response(model_info)
|
||||||
|
else:
|
||||||
|
return _build_openai_model_response(model_info)
|
||||||
|
|
||||||
|
|
||||||
|
# Gemini 专用路径 /v1beta/models
|
||||||
|
@router.get("/v1beta/models", response_model=None)
|
||||||
|
async def list_models_gemini(
|
||||||
|
request: Request,
|
||||||
|
page_size: int = Query(50, alias="pageSize", ge=1, le=1000),
|
||||||
|
page_token: Optional[str] = Query(None, alias="pageToken"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> Union[dict, JSONResponse]:
|
||||||
|
"""List models (Gemini v1beta 端点)"""
|
||||||
|
logger.info("[Models] GET /v1beta/models | format=gemini")
|
||||||
|
|
||||||
|
# 从 x-goog-api-key 或 ?key= 提取 API Key
|
||||||
|
gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
|
||||||
|
api_key = _extract_api_key_from_request(request, gemini_def)
|
||||||
|
|
||||||
|
# 认证
|
||||||
|
user, _ = _authenticate(db, api_key)
|
||||||
|
if not user:
|
||||||
|
return _build_auth_error_response("gemini")
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
||||||
|
if not available_provider_ids:
|
||||||
|
return {"models": []}
|
||||||
|
|
||||||
|
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS)
|
||||||
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
|
response = _build_gemini_list_response(models, page_size, page_token)
|
||||||
|
logger.debug(f"[Models] Gemini 响应: {response}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/v1beta/models/{model_name:path}", response_model=None)
|
||||||
|
async def get_model_gemini(
|
||||||
|
request: Request,
|
||||||
|
model_name: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> Union[dict, JSONResponse]:
|
||||||
|
"""Get model (Gemini v1beta 端点)"""
|
||||||
|
# 移除 "models/" 前缀(如果有)
|
||||||
|
model_id = model_name[7:] if model_name.startswith("models/") else model_name
|
||||||
|
logger.info(f"[Models] GET /v1beta/models/{model_id} | format=gemini")
|
||||||
|
|
||||||
|
# 从 x-goog-api-key 或 ?key= 提取 API Key
|
||||||
|
gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
|
||||||
|
api_key = _extract_api_key_from_request(request, gemini_def)
|
||||||
|
|
||||||
|
# 认证
|
||||||
|
user, _ = _authenticate(db, api_key)
|
||||||
|
if not user:
|
||||||
|
return _build_auth_error_response("gemini")
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
||||||
|
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS)
|
||||||
|
|
||||||
|
if not model_info:
|
||||||
|
return _build_404_response(model_id, "gemini")
|
||||||
|
|
||||||
|
return _build_gemini_model_response(model_info)
|
||||||
@@ -3,6 +3,8 @@ OpenAI API 端点
|
|||||||
|
|
||||||
- /v1/chat/completions - OpenAI Chat API
|
- /v1/chat/completions - OpenAI Chat API
|
||||||
- /v1/responses - OpenAI Responses API (CLI)
|
- /v1/responses - OpenAI Responses API (CLI)
|
||||||
|
|
||||||
|
注意: /v1/models 端点由 models.py 统一处理,根据请求头返回对应格式
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
|||||||
api_format=APIFormat.CLAUDE,
|
api_format=APIFormat.CLAUDE,
|
||||||
aliases=("claude", "anthropic", "claude_compatible"),
|
aliases=("claude", "anthropic", "claude_compatible"),
|
||||||
default_path="/v1/messages",
|
default_path="/v1/messages",
|
||||||
path_prefix="", # 本站路径前缀,可配置如 "/claude"
|
path_prefix="", # 通过请求头区分格式,不使用路径前缀
|
||||||
auth_header="x-api-key",
|
auth_header="x-api-key",
|
||||||
auth_type="header",
|
auth_type="header",
|
||||||
),
|
),
|
||||||
@@ -85,7 +85,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
|||||||
"openai_compatible",
|
"openai_compatible",
|
||||||
),
|
),
|
||||||
default_path="/v1/chat/completions",
|
default_path="/v1/chat/completions",
|
||||||
path_prefix="", # 本站路径前缀,可配置如 "/openai"
|
path_prefix="", # 默认格式
|
||||||
auth_header="Authorization",
|
auth_header="Authorization",
|
||||||
auth_type="bearer",
|
auth_type="bearer",
|
||||||
),
|
),
|
||||||
@@ -93,7 +93,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
|||||||
api_format=APIFormat.OPENAI_CLI,
|
api_format=APIFormat.OPENAI_CLI,
|
||||||
aliases=("openai_cli", "responses"),
|
aliases=("openai_cli", "responses"),
|
||||||
default_path="/responses",
|
default_path="/responses",
|
||||||
path_prefix="",
|
path_prefix="", # 与 OPENAI 共享入口
|
||||||
auth_header="Authorization",
|
auth_header="Authorization",
|
||||||
auth_type="bearer",
|
auth_type="bearer",
|
||||||
),
|
),
|
||||||
@@ -101,7 +101,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
|||||||
api_format=APIFormat.GEMINI,
|
api_format=APIFormat.GEMINI,
|
||||||
aliases=("gemini", "google", "vertex"),
|
aliases=("gemini", "google", "vertex"),
|
||||||
default_path="/v1beta/models/{model}:{action}",
|
default_path="/v1beta/models/{model}:{action}",
|
||||||
path_prefix="", # 本站路径前缀,可配置如 "/gemini"
|
path_prefix="", # 通过请求头区分格式
|
||||||
auth_header="x-goog-api-key",
|
auth_header="x-goog-api-key",
|
||||||
auth_type="header",
|
auth_type="header",
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user