diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py index d8fcba7..1df1594 100644 --- a/src/api/base/models_service.py +++ b/src/api/base/models_service.py @@ -114,6 +114,7 @@ class AccessRestrictions: allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表 allowed_models: Optional[list[str]] = None # 允许的模型名称列表 + allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表 @classmethod def from_api_key_and_user( @@ -130,6 +131,7 @@ class AccessRestrictions: """ allowed_providers: Optional[list[str]] = None allowed_models: Optional[list[str]] = None + allowed_api_formats: Optional[list[str]] = None # 优先使用 API Key 的限制 if api_key: @@ -137,15 +139,36 @@ class AccessRestrictions: allowed_providers = api_key.allowed_providers if api_key.allowed_models is not None: allowed_models = api_key.allowed_models + if api_key.allowed_api_formats is not None: + allowed_api_formats = api_key.allowed_api_formats # 如果 API Key 没有限制,检查 User 的限制 + # 注意: User 没有 allowed_api_formats 字段 if user: if allowed_providers is None and user.allowed_providers is not None: allowed_providers = user.allowed_providers if allowed_models is None and user.allowed_models is not None: allowed_models = user.allowed_models - return cls(allowed_providers=allowed_providers, allowed_models=allowed_models) + return cls( + allowed_providers=allowed_providers, + allowed_models=allowed_models, + allowed_api_formats=allowed_api_formats, + ) + + def is_api_format_allowed(self, api_format: str) -> bool: + """ + 检查 API 格式是否被允许 + + Args: + api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI") + + Returns: + True 如果格式被允许,False 否则 + """ + if self.allowed_api_formats is None: + return True + return api_format in self.allowed_api_formats def is_model_allowed(self, model_id: str, provider_id: str) -> bool: """ diff --git a/src/api/public/models.py b/src/api/public/models.py index f92b648..adc488a 100644 --- a/src/api/public/models.py +++ b/src/api/public/models.py @@ -104,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]: return _OPENAI_FORMATS +def _build_empty_list_response(api_format: str) -> dict: + """根据 API 格式构建空列表响应""" + if api_format == "claude": + return {"data": [], "has_more": False, "first_id": None, "last_id": None} + elif api_format == "gemini": + return {"models": []} + else: + return {"object": "list", "data": []} + + +def _filter_formats_by_restrictions( + formats: list[str], restrictions: AccessRestrictions, api_format: str +) -> Tuple[list[str], Optional[dict]]: + """ + 根据访问限制过滤 API 格式 + + Returns: + (过滤后的格式列表, 空响应或None) + 如果过滤后为空,返回对应格式的空响应 + """ + if restrictions.allowed_api_formats is None: + return formats, None + filtered = [f for f in formats if f in restrictions.allowed_api_formats] + if not filtered: + logger.info(f"[Models] API Key 不允许访问格式 {api_format}") + return [], _build_empty_list_response(api_format) + return filtered, None + + def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]: """ 认证 API Key @@ -383,16 +412,15 @@ async def list_models( # 构建访问限制 restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + # 检查 API 格式限制 formats = _get_formats_for_api(api_format) + formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format) + if empty_response is not None: + return empty_response available_provider_ids = get_available_provider_ids(db, formats) if not available_provider_ids: - if api_format == "claude": - return {"data": [], "has_more": False, "first_id": None, "last_id": None} - elif api_format == "gemini": - return {"models": []} - else: - return {"object": "list", "data": []} + return _build_empty_list_response(api_format) models = await list_available_models(db, available_provider_ids, formats, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") @@ -430,7 +458,11 @@ async def retrieve_model( # 构建访问限制 restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + # 检查 API 格式限制 formats = _get_formats_for_api(api_format) + formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format) + if not formats: + return _build_404_response(model_id, api_format) available_provider_ids = get_available_provider_ids(db, formats) model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions) @@ -469,11 +501,18 @@ async def list_models_gemini( # 构建访问限制 restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) - available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) + # 检查 API 格式限制 + formats, empty_response = _filter_formats_by_restrictions( + _GEMINI_FORMATS, restrictions, "gemini" + ) + if empty_response is not None: + return empty_response + + available_provider_ids = get_available_provider_ids(db, formats) if not available_provider_ids: return {"models": []} - models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions) + models = await list_available_models(db, available_provider_ids, formats, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") response = _build_gemini_list_response(models, page_size, page_token) logger.debug(f"[Models] Gemini 响应: {response}") @@ -503,9 +542,14 @@ async def get_model_gemini( # 构建访问限制 restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) - available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) + # 检查 API 格式限制 + formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini") + if not formats: + return _build_404_response(model_id, "gemini") + + available_provider_ids = get_available_provider_ids(db, formats) model_info = find_model_by_id( - db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions + db, model_id, available_provider_ids, formats, restrictions ) if not model_info: