feat: 添加 API 格式访问限制

扩展访问限制功能,支持 API Key 级别的 API 格式限制(OPENAI、CLAUDE、GEMINI)。

- AccessRestrictions 新增 allowed_api_formats 字段
- 新增 is_api_format_allowed() 方法检查格式权限
- models.py 添加 _filter_formats_by_restrictions() 函数过滤 API 格式
- 在所有模型列表和查询端点应用格式限制检查
- 添加 _build_empty_list_response() 统一空响应构建逻辑
This commit is contained in:
fawney19
2025-12-30 17:50:39 +08:00
parent e20a09f15a
commit 394cc536a9
2 changed files with 78 additions and 11 deletions

View File

@@ -114,6 +114,7 @@ class AccessRestrictions:
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表 allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
allowed_models: Optional[list[str]] = None # 允许的模型名称列表 allowed_models: Optional[list[str]] = None # 允许的模型名称列表
allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表
@classmethod @classmethod
def from_api_key_and_user( def from_api_key_and_user(
@@ -130,6 +131,7 @@ class AccessRestrictions:
""" """
allowed_providers: Optional[list[str]] = None allowed_providers: Optional[list[str]] = None
allowed_models: Optional[list[str]] = None allowed_models: Optional[list[str]] = None
allowed_api_formats: Optional[list[str]] = None
# 优先使用 API Key 的限制 # 优先使用 API Key 的限制
if api_key: if api_key:
@@ -137,15 +139,36 @@ class AccessRestrictions:
allowed_providers = api_key.allowed_providers allowed_providers = api_key.allowed_providers
if api_key.allowed_models is not None: if api_key.allowed_models is not None:
allowed_models = api_key.allowed_models allowed_models = api_key.allowed_models
if api_key.allowed_api_formats is not None:
allowed_api_formats = api_key.allowed_api_formats
# 如果 API Key 没有限制,检查 User 的限制 # 如果 API Key 没有限制,检查 User 的限制
# 注意: User 没有 allowed_api_formats 字段
if user: if user:
if allowed_providers is None and user.allowed_providers is not None: if allowed_providers is None and user.allowed_providers is not None:
allowed_providers = user.allowed_providers allowed_providers = user.allowed_providers
if allowed_models is None and user.allowed_models is not None: if allowed_models is None and user.allowed_models is not None:
allowed_models = user.allowed_models allowed_models = user.allowed_models
return cls(allowed_providers=allowed_providers, allowed_models=allowed_models) return cls(
allowed_providers=allowed_providers,
allowed_models=allowed_models,
allowed_api_formats=allowed_api_formats,
)
def is_api_format_allowed(self, api_format: str) -> bool:
"""
检查 API 格式是否被允许
Args:
api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI")
Returns:
True 如果格式被允许False 否则
"""
if self.allowed_api_formats is None:
return True
return api_format in self.allowed_api_formats
def is_model_allowed(self, model_id: str, provider_id: str) -> bool: def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
""" """

View File

@@ -104,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]:
return _OPENAI_FORMATS return _OPENAI_FORMATS
def _build_empty_list_response(api_format: str) -> dict:
"""根据 API 格式构建空列表响应"""
if api_format == "claude":
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
elif api_format == "gemini":
return {"models": []}
else:
return {"object": "list", "data": []}
def _filter_formats_by_restrictions(
formats: list[str], restrictions: AccessRestrictions, api_format: str
) -> Tuple[list[str], Optional[dict]]:
"""
根据访问限制过滤 API 格式
Returns:
(过滤后的格式列表, 空响应或None)
如果过滤后为空,返回对应格式的空响应
"""
if restrictions.allowed_api_formats is None:
return formats, None
filtered = [f for f in formats if f in restrictions.allowed_api_formats]
if not filtered:
logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
return [], _build_empty_list_response(api_format)
return filtered, None
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]: def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
""" """
认证 API Key 认证 API Key
@@ -383,16 +412,15 @@ async def list_models(
# 构建访问限制 # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
if api_format == "claude": return _build_empty_list_response(api_format)
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
elif api_format == "gemini":
return {"models": []}
else:
return {"object": "list", "data": []}
models = await list_available_models(db, available_provider_ids, formats, restrictions) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
@@ -430,7 +458,11 @@ async def retrieve_model(
# 构建访问限制 # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format)
if not formats:
return _build_404_response(model_id, api_format)
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions) model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
@@ -469,11 +501,18 @@ async def list_models_gemini(
# 构建访问限制 # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 检查 API 格式限制
formats, empty_response = _filter_formats_by_restrictions(
_GEMINI_FORMATS, restrictions, "gemini"
)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
return {"models": []} return {"models": []}
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
response = _build_gemini_list_response(models, page_size, page_token) response = _build_gemini_list_response(models, page_size, page_token)
logger.debug(f"[Models] Gemini 响应: {response}") logger.debug(f"[Models] Gemini 响应: {response}")
@@ -503,9 +542,14 @@ async def get_model_gemini(
# 构建访问限制 # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 检查 API 格式限制
formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini")
if not formats:
return _build_404_response(model_id, "gemini")
available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id( model_info = find_model_by_id(
db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions db, model_id, available_provider_ids, formats, restrictions
) )
if not model_info: if not model_info: