diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py
new file mode 100644
index 0000000..93c00c9
--- /dev/null
+++ b/src/api/base/models_service.py
@@ -0,0 +1,350 @@
+"""
+Shared model-query service.
+
+Provides one query path for the Claude/OpenAI/Gemini /models endpoints.
+
+Query logic:
+1. Find active endpoints with the requested api_format
+2. The endpoint has at least one active key
+3. The provider is linked to the model (Model table)
+4. The key's allowed_models permits the model (null = all models allowed)
+"""
+
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
+
+from sqlalchemy.orm import Session, joinedload
+
+from src.config.constants import CacheTTL
+from src.core.cache_service import CacheService
+from src.core.logger import logger
+from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint
+
+# Cache key prefix
+_CACHE_KEY_PREFIX = "models:list"
+_CACHE_TTL = CacheTTL.MODEL  # 300 seconds
+
+
+def _get_cache_key(api_formats: list[str]) -> str:
+    """Build the cache key; formats are sorted so their order does not matter."""
+    formats_str = ",".join(sorted(api_formats))
+    return f"{_CACHE_KEY_PREFIX}:{formats_str}"
+
+
+async def _get_cached_models(api_formats: list[str]) -> Optional[list["ModelInfo"]]:
+    """Return the cached model list, or None on a miss or a cache read error."""
+    cache_key = _get_cache_key(api_formats)
+    try:
+        cached = await CacheService.get(cache_key)
+        if cached is not None:  # an empty cached list is still a hit
+            logger.debug(f"[ModelsService] 缓存命中: {cache_key}, {len(cached)} 个模型")
+            return [ModelInfo(**item) for item in cached]
+    except Exception as e:
+        logger.warning(f"[ModelsService] 缓存读取失败: {e}")
+    return None
+
+
+async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"]) -> None:
+    """Write the model list to the cache; failures are logged, never raised."""
+    cache_key = _get_cache_key(api_formats)
+    try:
+        data = [asdict(m) for m in models]
+        await CacheService.set(cache_key, data, ttl_seconds=_CACHE_TTL)
+        logger.debug(f"[ModelsService] 已缓存: {cache_key}, {len(models)} 个模型, TTL={_CACHE_TTL}s")
+    except Exception as e:
+        logger.warning(f"[ModelsService] 缓存写入失败: {e}")
+
+
+@dataclass
+class ModelInfo:
+    """Format-independent description of one model."""
+
+    id: str  # Model ID (GlobalModel.name or provider_model_name)
+    display_name: str
+    description: Optional[str]
+    created_at: Optional[str]  # ISO format
+    created_timestamp: int  # Unix timestamp
+    provider_name: str
+
+
+def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
+    """
+    Return the IDs of providers that have a usable endpoint.
+
+    Conditions:
+    - the endpoint's api_format matches
+    - the endpoint is active
+    - the endpoint has an active key
+    """
+    rows = (
+        db.query(ProviderEndpoint.provider_id)
+        .join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
+        .filter(
+            ProviderEndpoint.api_format.in_(api_formats),
+            ProviderEndpoint.is_active.is_(True),
+            ProviderAPIKey.is_active.is_(True),
+        )
+        .distinct()
+        .all()
+    )
+    return {row[0] for row in rows}
+
+
+def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
+    """
+    Return the set of model IDs that are really usable for the given formats.
+
+    A model is usable when:
+    1. an endpoint matches the api_format and is active
+    2. that endpoint has an active key
+    3. **the endpoint's provider is linked to the model**
+    4. the key's allowed_models permits it (null = every model of that provider)
+    """
+    # Fetch every active key of every matching active endpoint (with endpoint_id)
+    endpoint_keys = (
+        db.query(
+            ProviderEndpoint.id.label("endpoint_id"),
+            ProviderEndpoint.provider_id,
+            ProviderAPIKey.allowed_models,
+        )
+        .join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
+        .filter(
+            ProviderEndpoint.api_format.in_(api_formats),
+            ProviderEndpoint.is_active.is_(True),
+            ProviderAPIKey.is_active.is_(True),
+        )
+        .all()
+    )
+
+    if not endpoint_keys:
+        return set()
+
+    # Collect the allowed_models value of every active key, one entry per key
+    # Keyed by provider_id because models are linked to providers, not endpoints
+    provider_allowed_models: dict[str, list[Optional[list[str]]]] = {}
+    provider_ids_with_format: set[str] = set()
+
+    for _endpoint_id, provider_id, allowed_models in endpoint_keys:
+        provider_ids_with_format.add(provider_id)
+        if provider_id not in provider_allowed_models:
+            provider_allowed_models[provider_id] = []
+        provider_allowed_models[provider_id].append(allowed_models)
+
+    # Only load models of providers that own an endpoint with a matching format
+    models = (
+        db.query(Model)
+        .options(joinedload(Model.global_model))
+        .join(Provider)
+        .filter(
+            Model.provider_id.in_(provider_ids_with_format),
+            Model.is_active.is_(True),
+            Provider.is_active.is_(True),
+        )
+        .all()
+    )
+
+    available_model_ids: set[str] = set()
+
+    for model in models:
+        model_provider_id = model.provider_id
+        global_model = model.global_model
+        model_id = global_model.name if global_model else model.provider_model_name  # type: ignore
+
+        if not model_provider_id or not model_id:
+            continue
+
+        # Defensive re-check; the query above already filters on these providers
+        if model_provider_id not in provider_ids_with_format:
+            continue
+
+        # The model is available as soon as one key of its provider permits it
+        allowed_lists = provider_allowed_models.get(model_provider_id, [])
+        for allowed_models in allowed_lists:
+            if allowed_models is None:
+                # null = every model linked to this provider (query-limited above)
+                available_model_ids.add(model_id)
+                break
+            elif model_id in allowed_models:
+                # Explicitly listed under its public ID
+                available_model_ids.add(model_id)
+                break
+            elif global_model and model.provider_model_name in allowed_models:
+                # Also accept a match on provider_model_name
+                available_model_ids.add(model_id)
+                break
+
+    return available_model_ids
+
+
+def _extract_model_info(model: Any) -> ModelInfo:
+    """Build a ModelInfo from a Model row, preferring GlobalModel fields when linked."""
+    global_model = model.global_model
+    model_id: str = global_model.name if global_model else model.provider_model_name
+    display_name: str = global_model.display_name if global_model else model.provider_model_name
+    description: Optional[str] = global_model.description if global_model else None
+    created_at: Optional[str] = (  # NOTE(review): "Z" suffix assumes created_at is UTC - confirm
+        model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
+    )
+    created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
+    provider_name: str = model.provider.name if model.provider else "unknown"
+
+    return ModelInfo(
+        id=model_id,
+        display_name=display_name,
+        description=description,
+        created_at=created_at,
+        created_timestamp=created_timestamp,
+        provider_name=provider_name,
+    )
+
+
+async def list_available_models(
+    db: Session,
+    available_provider_ids: set[str],
+    api_formats: Optional[list[str]] = None,
+) -> list[ModelInfo]:
+    """
+    Return the available models (de-duplicated, cached per api_formats).
+
+    Args:
+        db: database session
+        available_provider_ids: IDs of providers that have a usable endpoint
+        api_formats: API formats used to check each key's allowed_models
+
+    Returns:
+        De-duplicated ModelInfo list, newest created_at first
+    """
+    if not available_provider_ids:
+        return []
+
+    # Try the cache first (only keyed requests, i.e. api_formats given, are cached)
+    if api_formats:
+        cached = await _get_cached_models(api_formats)
+        if cached is not None:
+            return cached
+
+    # With api_formats given, restrict the result to the really usable model IDs
+    available_model_ids: Optional[set[str]] = None
+    if api_formats:
+        available_model_ids = _get_available_model_ids_for_format(db, api_formats)
+        if not available_model_ids:
+            return []
+
+    query = (
+        db.query(Model)
+        .options(joinedload(Model.global_model), joinedload(Model.provider))
+        .join(Provider)
+        .filter(
+            Model.is_active.is_(True),
+            Provider.is_active.is_(True),
+            Model.provider_id.in_(available_provider_ids),
+        )
+        .order_by(Model.created_at.desc())
+    )
+    all_models = query.all()
+
+    result: list[ModelInfo] = []
+    seen_model_ids: set[str] = set()
+
+    for model in all_models:
+        info = _extract_model_info(model)
+
+        # Skip models outside the allowed set when a restriction is in force
+        if available_model_ids is not None and info.id not in available_model_ids:
+            continue
+
+        if info.id in seen_model_ids:
+            continue
+        seen_model_ids.add(info.id)
+
+        result.append(info)
+
+    # Store the result in the cache
+    if api_formats:
+        await _set_cached_models(api_formats, result)
+
+    return result
+
+
+def find_model_by_id(
+    db: Session,
+    model_id: str,
+    available_provider_ids: set[str],
+    api_formats: Optional[list[str]] = None,
+) -> Optional[ModelInfo]:
+    """
+    Look a model up by ID.
+
+    Lookup order:
+    1. Search by GlobalModel.name first
+    2. If that yields no candidate at all, search by provider_model_name
+    3. If candidates exist but none is available, return None (no fallback)
+
+    Args:
+        db: database session
+        model_id: model ID
+        available_provider_ids: IDs of providers that have a usable endpoint
+        api_formats: API formats used to check each key's allowed_models
+
+    Returns:
+        ModelInfo, or None when nothing matches
+    """
+    if not available_provider_ids:
+        return None
+
+    # With api_formats given, compute the really usable model IDs
+    available_model_ids: Optional[set[str]] = None
+    if api_formats:
+        available_model_ids = _get_available_model_ids_for_format(db, api_formats)
+        # Fast path (NOTE(review): set holds public IDs, so aliases of global models miss)
+        if available_model_ids is not None and model_id not in available_model_ids:
+            return None
+
+    # Search by GlobalModel.name first
+    models_by_global = (
+        db.query(Model)
+        .options(joinedload(Model.global_model), joinedload(Model.provider))
+        .join(Provider)
+        .join(GlobalModel, Model.global_model_id == GlobalModel.id)
+        .filter(
+            GlobalModel.name == model_id,
+            Model.is_active.is_(True),
+            Provider.is_active.is_(True),
+        )
+        .order_by(Model.created_at.desc())
+        .all()
+    )
+
+    model = next(
+        (m for m in models_by_global if m.provider_id in available_provider_ids),
+        None,
+    )
+
+    # Candidates exist but none is available: stop, do not fall back
+    if not model and models_by_global:
+        return None
+
+    # No candidate at all: search by provider_model_name
+    if not model:
+        models_by_provider_name = (
+            db.query(Model)
+            .options(joinedload(Model.global_model), joinedload(Model.provider))
+            .join(Provider)
+            .filter(
+                Model.provider_model_name == model_id,
+                Model.is_active.is_(True),
+                Provider.is_active.is_(True),
+            )
+            .order_by(Model.created_at.desc())
+            .all()
+        )
+
+        model = next(
+            (m for m in models_by_provider_name if m.provider_id in available_provider_ids),
+            None,
+        )
+
+    if not model:
+        return None
+
+    return _extract_model_info(model)
diff --git a/src/api/public/__init__.py b/src/api/public/__init__.py
index 1b04da2..44f90bd 100644
--- a/src/api/public/__init__.py
+++ b/src/api/public/__init__.py
@@ -6,10 +6,13 @@ from .capabilities import router as capabilities_router
 from .catalog import router as catalog_router
 from .claude import router as claude_router
 from .gemini import router as gemini_router
+from .models import router as models_router
 from .openai import router as openai_router
 from .system_catalog import router as system_catalog_router
 
 router = APIRouter()
+# Register the Models API first so other routers' path parameters cannot capture it
+router.include_router(models_router)
 router.include_router(claude_router, tags=["Claude API"])
 router.include_router(openai_router)
 router.include_router(gemini_router, tags=["Gemini API"])
diff --git a/src/api/public/claude.py b/src/api/public/claude.py
index 97c1a43..7fa945c 100644
--- a/src/api/public/claude.py
+++ b/src/api/public/claude.py
@@ -3,6 +3,8 @@ Claude API 端点
 
 - /v1/messages - Claude Messages API
 - /v1/messages/count_tokens - Token Count API
+
+Note: /v1/models is served by models.py, which picks the format from the request headers
 """
 
 from fastapi import APIRouter, Depends, Request
diff --git a/src/api/public/gemini.py b/src/api/public/gemini.py
index 6264a47..566a405 100644
--- a/src/api/public/gemini.py
+++ b/src/api/public/gemini.py
@@ -5,11 +5,9 @@ Gemini API 专属端点
 - /v1beta/models/{model}:generateContent
 - /v1beta/models/{model}:streamGenerateContent
 
-注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
-
-路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
-- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
-- default_path: 标准 API 路径模板
+Notes:
+- The Gemini API carries the model in the URL path, not in the request body
+- /v1beta/models (list) and /v1beta/models/{model} (detail) are served by models.py
 """
 
 from fastapi import APIRouter, Depends, Request
@@ -109,7 +107,7 @@ async def stream_generate_content(
     )
 
 
-# 兼容 v1 路径(部分 SDK 可能使用)
+# Compatibility v1 path (some SDKs may call generateContent here)
 @router.post("/v1/models/{model}:generateContent")
 async def generate_content_v1(
     model: str,
diff --git a/src/api/public/models.py b/src/api/public/models.py
new file mode 100644
index 0000000..cfffc47
--- /dev/null
+++ b/src/api/public/models.py
@@ -0,0 +1,499 @@
+"""
+Unified Models API endpoints.
+
+The response format is chosen from the way the request authenticates:
+- x-api-key + anthropic-version -> Claude format
+- x-goog-api-key (header) or the ?key= parameter -> Gemini format
+- Authorization: Bearer (bearer) -> OpenAI format
+"""
+
+from typing import Optional, Tuple, Union
+
+from fastapi import APIRouter, Depends, Query, Request
+from fastapi.responses import JSONResponse
+from sqlalchemy.orm import Session
+
+from src.api.base.models_service import (
+    ModelInfo,
+    find_model_by_id,
+    get_available_provider_ids,
+    list_available_models,
+)
+from src.core.api_format_metadata import API_FORMAT_DEFINITIONS, ApiFormatDefinition
+from src.core.enums import APIFormat
+from src.core.logger import logger
+from src.database import get_db
+from src.models.database import ApiKey, User
+from src.services.auth.service import AuthService
+
+router = APIRouter(tags=["Models API"])
+
+# Endpoint api_format values queried for each response format
+# Note: CLI formats are pass-through; the Models API only lists non-CLI endpoints' models
+_CLAUDE_FORMATS = [APIFormat.CLAUDE.value]
+_OPENAI_FORMATS = [APIFormat.OPENAI.value]
+_GEMINI_FORMATS = [APIFormat.GEMINI.value]
+
+
+def _extract_api_key_from_request(
+    request: Request, definition: ApiFormatDefinition
+) -> Optional[str]:
+    """Extract the API key from the request as described by the format definition."""
+    auth_header = definition.auth_header.lower()
+    auth_type = definition.auth_type
+
+    header_value = request.headers.get(auth_header)
+    if not header_value:
+        # Gemini additionally accepts the ?key= query parameter
+        if definition.api_format in (APIFormat.GEMINI, APIFormat.GEMINI_CLI):
+            return request.query_params.get("key")
+        return None
+
+    if auth_type == "bearer":
+        # Bearer token: "Bearer xxx"
+        if header_value.lower().startswith("bearer "):
+            return header_value[7:].strip()
+        return None
+    else:
+        # "header" type: the header value is the key itself
+        return header_value
+
+
+def _detect_api_format_and_key(request: Request) -> Tuple[str, Optional[str]]:
+    """
+    Detect the API format from the request headers and extract the API key.
+
+    Detection order:
+    1. x-api-key + anthropic-version -> Claude
+    2. x-goog-api-key (header) or ?key= -> Gemini
+    3. Authorization: Bearer -> OpenAI (default)
+
+    Returns:
+        (api_format, api_key) tuple
+    """
+    # Claude: x-api-key + anthropic-version (both must be present)
+    claude_def = API_FORMAT_DEFINITIONS[APIFormat.CLAUDE]
+    claude_key = _extract_api_key_from_request(request, claude_def)
+    if claude_key and request.headers.get("anthropic-version"):
+        return "claude", claude_key
+
+    # Gemini: x-goog-api-key ("header" type) or ?key=
+    gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
+    gemini_key = _extract_api_key_from_request(request, gemini_def)
+    if gemini_key:
+        return "gemini", gemini_key
+
+    # OpenAI: Authorization: Bearer (default)
+    # Note: x-api-key without anthropic-version also ends up in the OpenAI format
+    openai_def = API_FORMAT_DEFINITIONS[APIFormat.OPENAI]
+    openai_key = _extract_api_key_from_request(request, openai_def)
+    # No bearer key but an x-api-key is present: use it (compatibility)
+    if not openai_key and claude_key:
+        openai_key = claude_key
+    return "openai", openai_key
+
+
+def _get_formats_for_api(api_format: str) -> list[str]:
+    """Return the endpoint api_format list for a response format (OpenAI by default)."""
+    if api_format == "claude":
+        return _CLAUDE_FORMATS
+    elif api_format == "gemini":
+        return _GEMINI_FORMATS
+    else:
+        return _OPENAI_FORMATS
+
+
+def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
+    """
+    Authenticate an API key.
+
+    Returns:
+        (user, api_key_record) tuple; (None, None) when authentication fails
+    """
+    if not api_key:
+        logger.debug("[Models] 认证失败: 未提供 API Key")
+        return None, None
+
+    result = AuthService.authenticate_api_key(db, api_key)
+    if not result:
+        logger.debug("[Models] 认证失败: API Key 无效")
+        return None, None
+
+    user, key_record = result
+    logger.debug(f"[Models] 认证成功: {user.email} (Key: {key_record.name})")
+    return result
+
+
+def _build_auth_error_response(api_format: str) -> JSONResponse:
+    """Build the 401 response in the error shape of the given API format."""
+    if api_format == "claude":
+        return JSONResponse(
+            status_code=401,
+            content={
+                "type": "error",
+                "error": {
+                    "type": "authentication_error",
+                    "message": "Invalid API key provided",
+                },
+            },
+        )
+    elif api_format == "gemini":
+        return JSONResponse(
+            status_code=401,
+            content={
+                "error": {
+                    "code": 401,
+                    "message": "API key not valid. Please pass a valid API key.",
+                    "status": "UNAUTHENTICATED",
+                }
+            },
+        )
+    else:
+        return JSONResponse(
+            status_code=401,
+            content={
+                "error": {
+                    "message": "Incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.",
+                    "type": "invalid_request_error",
+                    "param": None,
+                    "code": "invalid_api_key",
+                }
+            },
+        )
+
+
+# ============================================================================
+# Response builders
+# ============================================================================
+
+
+def _build_claude_list_response(
+    models: list[ModelInfo],
+    before_id: Optional[str],
+    after_id: Optional[str],
+    limit: int,
+) -> dict:
+    """Build the Claude-format list response with before_id/after_id/limit paging."""
+    model_data_list = [
+        {
+            "id": m.id,
+            "type": "model",
+            "display_name": m.display_name,
+            "created_at": m.created_at,
+        }
+        for m in models
+    ]
+
+    # Paging: an unknown after_id/before_id leaves the corresponding bound untouched
+    start_idx = 0
+    if after_id:
+        for i, m in enumerate(model_data_list):
+            if m["id"] == after_id:
+                start_idx = i + 1
+                break
+
+    end_idx = len(model_data_list)
+    if before_id:
+        for i, m in enumerate(model_data_list):
+            if m["id"] == before_id:
+                end_idx = i
+                break
+
+    paginated = model_data_list[start_idx:end_idx][:limit]
+
+    first_id = paginated[0]["id"] if paginated else None
+    last_id = paginated[-1]["id"] if paginated else None
+    has_more = len(model_data_list[start_idx:end_idx]) > limit
+
+    return {
+        "data": paginated,
+        "has_more": has_more,
+        "first_id": first_id,
+        "last_id": last_id,
+    }
+
+
+def _build_openai_list_response(models: list[ModelInfo]) -> dict:
+    """Build the OpenAI-format list response (no paging)."""
+    data = [
+        {
+            "id": m.id,
+            "object": "model",
+            "created": m.created_timestamp,
+            "owned_by": m.provider_name,
+        }
+        for m in models
+    ]
+    return {"object": "list", "data": data}
+
+
+def _build_gemini_list_response(
+    models: list[ModelInfo],
+    page_size: int,
+    page_token: Optional[str],
+) -> dict:
+    """Build the Gemini-format list response; page_token is the next start offset."""
+    # Paging: a malformed or negative token restarts from the first model
+    start_idx = 0
+    if page_token:
+        try:
+            start_idx = max(0, int(page_token))  # negative offsets would slice from the end
+        except ValueError:
+            start_idx = 0
+
+    end_idx = start_idx + page_size
+    paginated_models = models[start_idx:end_idx]
+
+    models_data = [
+        {
+            "name": f"models/{m.id}",
+            "baseModelId": m.id,
+            "version": "001",
+            "displayName": m.display_name,
+            "description": m.description or f"Model {m.id}",
+            "inputTokenLimit": 128000,
+            "outputTokenLimit": 8192,
+            "supportedGenerationMethods": ["generateContent", "countTokens"],
+            "temperature": 1.0,
+            "maxTemperature": 2.0,
+            "topP": 0.95,
+            "topK": 64,
+        }
+        for m in paginated_models
+    ]
+
+    response: dict = {"models": models_data}
+    if end_idx < len(models):
+        response["nextPageToken"] = str(end_idx)
+
+    return response
+
+
+def _build_claude_model_response(model_info: ModelInfo) -> dict:
+    """Build the Claude-format model detail response."""
+    return {
+        "id": model_info.id,
+        "type": "model",
+        "display_name": model_info.display_name,
+        "created_at": model_info.created_at,
+    }
+
+
+def _build_openai_model_response(model_info: ModelInfo) -> dict:
+    """Build the OpenAI-format model detail response."""
+    return {
+        "id": model_info.id,
+        "object": "model",
+        "created": model_info.created_timestamp,
+        "owned_by": model_info.provider_name,
+    }
+
+
+def _build_gemini_model_response(model_info: ModelInfo) -> dict:
+    """Build the Gemini-format model detail response (limits are fixed placeholders)."""
+    return {
+        "name": f"models/{model_info.id}",
+        "baseModelId": model_info.id,
+        "version": "001",
+        "displayName": model_info.display_name,
+        "description": model_info.description or f"Model {model_info.id}",
+        "inputTokenLimit": 128000,
+        "outputTokenLimit": 8192,
+        "supportedGenerationMethods": ["generateContent", "countTokens"],
+        "temperature": 1.0,
+        "maxTemperature": 2.0,
+        "topP": 0.95,
+        "topK": 64,
+    }
+
+
+# ============================================================================
+# 404 responses
+# ============================================================================
+
+
+def _build_404_response(model_id: str, api_format: str) -> JSONResponse:
+    """Build the 404 response in the error shape of the given API format."""
+    if api_format == "claude":
+        return JSONResponse(
+            status_code=404,
+            content={
+                "type": "error",
+                "error": {"type": "not_found_error", "message": f"Model '{model_id}' not found"},
+            },
+        )
+    elif api_format == "gemini":
+        return JSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "code": 404,
+                    "message": f"models/{model_id} is not found",
+                    "status": "NOT_FOUND",
+                }
+            },
+        )
+    else:
+        return JSONResponse(
+            status_code=404,
+            content={
+                "error": {
+                    "message": f"The model '{model_id}' does not exist",
+                    "type": "invalid_request_error",
+                    "param": "model",
+                    "code": "model_not_found",
+                }
+            },
+        )
+
+
+# ============================================================================
+# Route handlers
+# ============================================================================
+
+
+@router.get("/v1/models", response_model=None)
+async def list_models(
+    request: Request,
+    # Claude paging parameters
+    before_id: Optional[str] = Query(None, description="返回此 ID 之前的结果 (Claude)"),
+    after_id: Optional[str] = Query(None, description="返回此 ID 之后的结果 (Claude)"),
+    limit: int = Query(20, ge=1, le=1000, description="返回数量限制 (Claude)"),
+    # Gemini paging parameters
+    page_size: int = Query(50, alias="pageSize", ge=1, le=1000, description="每页数量 (Gemini)"),
+    page_token: Optional[str] = Query(None, alias="pageToken", description="分页 token (Gemini)"),
+    db: Session = Depends(get_db),
+) -> Union[dict, JSONResponse]:
+    """
+    List models - the response format follows the request's authentication style
+
+    - x-api-key + anthropic-version -> Claude format
+    - x-goog-api-key or ?key= -> Gemini format
+    - Authorization: Bearer -> OpenAI format
+    """
+    api_format, api_key = _detect_api_format_and_key(request)
+    logger.info(f"[Models] GET /v1/models | format={api_format}")
+
+    # Authenticate
+    user, _ = _authenticate(db, api_key)
+    if not user:
+        return _build_auth_error_response(api_format)
+
+    formats = _get_formats_for_api(api_format)
+
+    available_provider_ids = get_available_provider_ids(db, formats)
+    if not available_provider_ids:
+        if api_format == "claude":
+            return {"data": [], "has_more": False, "first_id": None, "last_id": None}
+        elif api_format == "gemini":
+            return {"models": []}
+        else:
+            return {"object": "list", "data": []}
+
+    models = await list_available_models(db, available_provider_ids, formats)
+    logger.debug(f"[Models] 返回 {len(models)} 个模型")
+
+    if api_format == "claude":
+        return _build_claude_list_response(models, before_id, after_id, limit)
+    elif api_format == "gemini":
+        return _build_gemini_list_response(models, page_size, page_token)
+    else:
+        return _build_openai_list_response(models)
+
+
+@router.get("/v1/models/{model_id:path}", response_model=None)
+async def retrieve_model(
+    model_id: str,
+    request: Request,
+    db: Session = Depends(get_db),
+) -> Union[dict, JSONResponse]:
+    """
+    Retrieve model - the response format follows the request's authentication style
+    """
+    api_format, api_key = _detect_api_format_and_key(request)
+
+    # Gemini names carry a "models/" prefix that must be stripped
+    if api_format == "gemini":
+        model_id = model_id.removeprefix("models/")
+
+    logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
+
+    # Authenticate
+    user, _ = _authenticate(db, api_key)
+    if not user:
+        return _build_auth_error_response(api_format)
+
+    formats = _get_formats_for_api(api_format)
+
+    available_provider_ids = get_available_provider_ids(db, formats)
+    model_info = find_model_by_id(db, model_id, available_provider_ids, formats)
+
+    if not model_info:
+        return _build_404_response(model_id, api_format)
+
+    if api_format == "claude":
+        return _build_claude_model_response(model_info)
+    elif api_format == "gemini":
+        return _build_gemini_model_response(model_info)
+    else:
+        return _build_openai_model_response(model_info)
+
+
+# Gemini-only path /v1beta/models
+@router.get("/v1beta/models", response_model=None)
+async def list_models_gemini(
+    request: Request,
+    page_size: int = Query(50, alias="pageSize", ge=1, le=1000),
+    page_token: Optional[str] = Query(None, alias="pageToken"),
+    db: Session = Depends(get_db),
+) -> Union[dict, JSONResponse]:
+    """List models (Gemini v1beta endpoint)"""
+    logger.info("[Models] GET /v1beta/models | format=gemini")
+
+    # Take the API key from x-goog-api-key or ?key=
+    gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
+    api_key = _extract_api_key_from_request(request, gemini_def)
+
+    # Authenticate
+    user, _ = _authenticate(db, api_key)
+    if not user:
+        return _build_auth_error_response("gemini")
+
+    available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
+    if not available_provider_ids:
+        return {"models": []}
+
+    models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS)
+    logger.debug(f"[Models] 返回 {len(models)} 个模型")
+    response = _build_gemini_list_response(models, page_size, page_token)
+    logger.debug(f"[Models] Gemini 响应: {response}")
+    return response
+
+
+@router.get("/v1beta/models/{model_name:path}", response_model=None)
+async def get_model_gemini(
+    request: Request,
+    model_name: str,
+    db: Session = Depends(get_db),
+) -> Union[dict, JSONResponse]:
+    """Get model (Gemini v1beta endpoint)"""
+    # Strip the "models/" prefix when present
+    model_id = model_name.removeprefix("models/")
+    logger.info(f"[Models] GET /v1beta/models/{model_id} | format=gemini")
+
+    # Take the API key from x-goog-api-key or ?key=
+    gemini_def = API_FORMAT_DEFINITIONS[APIFormat.GEMINI]
+    api_key = _extract_api_key_from_request(request, gemini_def)
+
+    # Authenticate
+    user, _ = _authenticate(db, api_key)
+    if not user:
+        return _build_auth_error_response("gemini")
+
+    available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
+    model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS)
+
+    if not model_info:
+        return _build_404_response(model_id, "gemini")
+
+    return _build_gemini_model_response(model_info)
diff --git a/src/api/public/openai.py b/src/api/public/openai.py
index 0e5e3dc..5681b3d 100644
--- a/src/api/public/openai.py
+++ b/src/api/public/openai.py
@@ -3,6 +3,8 @@ OpenAI API 端点
 
 - /v1/chat/completions - OpenAI Chat API
 - /v1/responses - OpenAI Responses API (CLI)
+
+Note: /v1/models is served by models.py, which picks the format from the request headers
 """
 
 from fastapi import APIRouter, Depends, Request
diff --git a/src/core/api_format_metadata.py b/src/core/api_format_metadata.py
index dce2496..9f913d3 100644
--- a/src/core/api_format_metadata.py
+++ b/src/core/api_format_metadata.py
@@ -59,7 +59,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
         api_format=APIFormat.CLAUDE,
         aliases=("claude", "anthropic", "claude_compatible"),
         default_path="/v1/messages",
-        path_prefix="",  # 本站路径前缀,可配置如 "/claude"
+        path_prefix="",  # Format is chosen from request headers; no path prefix is used
         auth_header="x-api-key",
         auth_type="header",
     ),
@@ -85,7 +85,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
             "openai_compatible",
         ),
         default_path="/v1/chat/completions",
-        path_prefix="",  # 本站路径前缀,可配置如 "/openai"
+        path_prefix="",  # Default format
         auth_header="Authorization",
         auth_type="bearer",
     ),
@@ -93,7 +93,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
         api_format=APIFormat.OPENAI_CLI,
         aliases=("openai_cli", "responses"),
         default_path="/responses",
-        path_prefix="",
+        path_prefix="",  # Shares its entry point with OPENAI
         auth_header="Authorization",
         auth_type="bearer",
     ),
@@ -101,7 +101,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
         api_format=APIFormat.GEMINI,
         aliases=("gemini", "google", "vertex"),
         default_path="/v1beta/models/{model}:{action}",
-        path_prefix="",  # 本站路径前缀,可配置如 "/gemini"
+        path_prefix="",  # Format is chosen from request headers
         auth_header="x-goog-api-key",
         auth_type="header",
     ),