mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
feat: 添加 API 格式访问限制
扩展访问限制功能,支持 API Key 级别的 API 格式限制(OPENAI、CLAUDE、GEMINI)。 - AccessRestrictions 新增 allowed_api_formats 字段 - 新增 is_api_format_allowed() 方法检查格式权限 - models.py 添加 _filter_formats_by_restrictions() 函数过滤 API 格式 - 在所有模型列表和查询端点应用格式限制检查 - 添加 _build_empty_list_response() 统一空响应构建逻辑
This commit is contained in:
@@ -114,6 +114,7 @@ class AccessRestrictions:
|
|||||||
|
|
||||||
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
|
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
|
||||||
allowed_models: Optional[list[str]] = None # 允许的模型名称列表
|
allowed_models: Optional[list[str]] = None # 允许的模型名称列表
|
||||||
|
allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_api_key_and_user(
|
def from_api_key_and_user(
|
||||||
@@ -130,6 +131,7 @@ class AccessRestrictions:
|
|||||||
"""
|
"""
|
||||||
allowed_providers: Optional[list[str]] = None
|
allowed_providers: Optional[list[str]] = None
|
||||||
allowed_models: Optional[list[str]] = None
|
allowed_models: Optional[list[str]] = None
|
||||||
|
allowed_api_formats: Optional[list[str]] = None
|
||||||
|
|
||||||
# 优先使用 API Key 的限制
|
# 优先使用 API Key 的限制
|
||||||
if api_key:
|
if api_key:
|
||||||
@@ -137,15 +139,36 @@ class AccessRestrictions:
|
|||||||
allowed_providers = api_key.allowed_providers
|
allowed_providers = api_key.allowed_providers
|
||||||
if api_key.allowed_models is not None:
|
if api_key.allowed_models is not None:
|
||||||
allowed_models = api_key.allowed_models
|
allowed_models = api_key.allowed_models
|
||||||
|
if api_key.allowed_api_formats is not None:
|
||||||
|
allowed_api_formats = api_key.allowed_api_formats
|
||||||
|
|
||||||
# 如果 API Key 没有限制,检查 User 的限制
|
# 如果 API Key 没有限制,检查 User 的限制
|
||||||
|
# 注意: User 没有 allowed_api_formats 字段
|
||||||
if user:
|
if user:
|
||||||
if allowed_providers is None and user.allowed_providers is not None:
|
if allowed_providers is None and user.allowed_providers is not None:
|
||||||
allowed_providers = user.allowed_providers
|
allowed_providers = user.allowed_providers
|
||||||
if allowed_models is None and user.allowed_models is not None:
|
if allowed_models is None and user.allowed_models is not None:
|
||||||
allowed_models = user.allowed_models
|
allowed_models = user.allowed_models
|
||||||
|
|
||||||
return cls(allowed_providers=allowed_providers, allowed_models=allowed_models)
|
return cls(
|
||||||
|
allowed_providers=allowed_providers,
|
||||||
|
allowed_models=allowed_models,
|
||||||
|
allowed_api_formats=allowed_api_formats,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_api_format_allowed(self, api_format: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查 API 格式是否被允许
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 如果格式被允许,False 否则
|
||||||
|
"""
|
||||||
|
if self.allowed_api_formats is None:
|
||||||
|
return True
|
||||||
|
return api_format in self.allowed_api_formats
|
||||||
|
|
||||||
def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
|
def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -104,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]:
|
|||||||
return _OPENAI_FORMATS
|
return _OPENAI_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
def _build_empty_list_response(api_format: str) -> dict:
|
||||||
|
"""根据 API 格式构建空列表响应"""
|
||||||
|
if api_format == "claude":
|
||||||
|
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return {"models": []}
|
||||||
|
else:
|
||||||
|
return {"object": "list", "data": []}
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_formats_by_restrictions(
|
||||||
|
formats: list[str], restrictions: AccessRestrictions, api_format: str
|
||||||
|
) -> Tuple[list[str], Optional[dict]]:
|
||||||
|
"""
|
||||||
|
根据访问限制过滤 API 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(过滤后的格式列表, 空响应或None)
|
||||||
|
如果过滤后为空,返回对应格式的空响应
|
||||||
|
"""
|
||||||
|
if restrictions.allowed_api_formats is None:
|
||||||
|
return formats, None
|
||||||
|
filtered = [f for f in formats if f in restrictions.allowed_api_formats]
|
||||||
|
if not filtered:
|
||||||
|
logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
|
||||||
|
return [], _build_empty_list_response(api_format)
|
||||||
|
return filtered, None
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
|
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
|
||||||
"""
|
"""
|
||||||
认证 API Key
|
认证 API Key
|
||||||
@@ -383,16 +412,15 @@ async def list_models(
|
|||||||
# 构建访问限制
|
# 构建访问限制
|
||||||
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
formats = _get_formats_for_api(api_format)
|
formats = _get_formats_for_api(api_format)
|
||||||
|
formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format)
|
||||||
|
if empty_response is not None:
|
||||||
|
return empty_response
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, formats)
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
if not available_provider_ids:
|
if not available_provider_ids:
|
||||||
if api_format == "claude":
|
return _build_empty_list_response(api_format)
|
||||||
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
|
|
||||||
elif api_format == "gemini":
|
|
||||||
return {"models": []}
|
|
||||||
else:
|
|
||||||
return {"object": "list", "data": []}
|
|
||||||
|
|
||||||
models = await list_available_models(db, available_provider_ids, formats, restrictions)
|
models = await list_available_models(db, available_provider_ids, formats, restrictions)
|
||||||
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
@@ -430,7 +458,11 @@ async def retrieve_model(
|
|||||||
# 构建访问限制
|
# 构建访问限制
|
||||||
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
formats = _get_formats_for_api(api_format)
|
formats = _get_formats_for_api(api_format)
|
||||||
|
formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format)
|
||||||
|
if not formats:
|
||||||
|
return _build_404_response(model_id, api_format)
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, formats)
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
|
model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
|
||||||
@@ -469,11 +501,18 @@ async def list_models_gemini(
|
|||||||
# 构建访问限制
|
# 构建访问限制
|
||||||
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
# 检查 API 格式限制
|
||||||
|
formats, empty_response = _filter_formats_by_restrictions(
|
||||||
|
_GEMINI_FORMATS, restrictions, "gemini"
|
||||||
|
)
|
||||||
|
if empty_response is not None:
|
||||||
|
return empty_response
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
if not available_provider_ids:
|
if not available_provider_ids:
|
||||||
return {"models": []}
|
return {"models": []}
|
||||||
|
|
||||||
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions)
|
models = await list_available_models(db, available_provider_ids, formats, restrictions)
|
||||||
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
response = _build_gemini_list_response(models, page_size, page_token)
|
response = _build_gemini_list_response(models, page_size, page_token)
|
||||||
logger.debug(f"[Models] Gemini 响应: {response}")
|
logger.debug(f"[Models] Gemini 响应: {response}")
|
||||||
@@ -503,9 +542,14 @@ async def get_model_gemini(
|
|||||||
# 构建访问限制
|
# 构建访问限制
|
||||||
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
# 检查 API 格式限制
|
||||||
|
formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini")
|
||||||
|
if not formats:
|
||||||
|
return _build_404_response(model_id, "gemini")
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
model_info = find_model_by_id(
|
model_info = find_model_by_id(
|
||||||
db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions
|
db, model_id, available_provider_ids, formats, restrictions
|
||||||
)
|
)
|
||||||
|
|
||||||
if not model_info:
|
if not model_info:
|
||||||
|
|||||||
Reference in New Issue
Block a user