feat: 添加模型列表访问限制功能

实现 API Key 和 User 级别的模型访问权限控制,支持按 Provider 和模型名称限制。

- 新增 AccessRestrictions 类处理访问限制合并逻辑(API Key 优先于 User)
- models_service 支持根据限制过滤模型列表
- models.py 在列表查询时构建并应用访问限制
- 优化缓存策略:仅无限制请求使用缓存,有限制的请求旁路缓存
- 修复 logger 配置:enqueue 改为 False 避免 macOS 信号量泄漏
This commit is contained in:
fawney19
2025-12-30 16:57:59 +08:00
parent b89a4af0cf
commit e20a09f15a
3 changed files with 139 additions and 15 deletions

View File

@@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload
from src.config.constants import CacheTTL from src.config.constants import CacheTTL
from src.core.cache_service import CacheService from src.core.cache_service import CacheService
from src.core.logger import logger from src.core.logger import logger
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint from src.models.database import (
ApiKey,
GlobalModel,
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
User,
)
# 缓存 key 前缀 # 缓存 key 前缀
_CACHE_KEY_PREFIX = "models:list" _CACHE_KEY_PREFIX = "models:list"
@@ -82,6 +90,7 @@ class ModelInfo:
created_at: Optional[str] # ISO 格式 created_at: Optional[str] # ISO 格式
created_timestamp: int # Unix 时间戳 created_timestamp: int # Unix 时间戳
provider_name: str provider_name: str
provider_id: str = "" # Provider ID用于权限过滤
# 能力配置 # 能力配置
streaming: bool = True streaming: bool = True
vision: bool = False vision: bool = False
@@ -99,6 +108,69 @@ class ModelInfo:
output_modalities: Optional[list[str]] = None output_modalities: Optional[list[str]] = None
@dataclass
class AccessRestrictions:
"""API Key 或 User 的访问限制"""
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
allowed_models: Optional[list[str]] = None # 允许的模型名称列表
@classmethod
def from_api_key_and_user(
cls, api_key: Optional[ApiKey], user: Optional[User]
) -> "AccessRestrictions":
"""
从 API Key 和 User 合并访问限制
限制逻辑:
- API Key 的限制优先于 User 的限制
- 如果 API Key 有限制,使用 API Key 的限制
- 如果 API Key 无限制但 User 有限制,使用 User 的限制
- 两者都无限制则返回空限制
"""
allowed_providers: Optional[list[str]] = None
allowed_models: Optional[list[str]] = None
# 优先使用 API Key 的限制
if api_key:
if api_key.allowed_providers is not None:
allowed_providers = api_key.allowed_providers
if api_key.allowed_models is not None:
allowed_models = api_key.allowed_models
# 如果 API Key 没有限制,检查 User 的限制
if user:
if allowed_providers is None and user.allowed_providers is not None:
allowed_providers = user.allowed_providers
if allowed_models is None and user.allowed_models is not None:
allowed_models = user.allowed_models
return cls(allowed_providers=allowed_providers, allowed_models=allowed_models)
def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
"""
检查模型是否被允许访问
Args:
model_id: 模型 ID
provider_id: Provider ID
Returns:
True 如果模型被允许False 否则
"""
# 检查 Provider 限制
if self.allowed_providers is not None:
if provider_id not in self.allowed_providers:
return False
# 检查模型限制
if self.allowed_models is not None:
if model_id not in self.allowed_models:
return False
return True
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]: def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
""" """
返回有可用端点的 Provider IDs 返回有可用端点的 Provider IDs
@@ -218,6 +290,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
) )
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0 created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
provider_name: str = model.provider.name if model.provider else "unknown" provider_name: str = model.provider.name if model.provider else "unknown"
provider_id: str = model.provider_id or ""
# 从 GlobalModel.config 提取配置信息 # 从 GlobalModel.config 提取配置信息
config: dict = {} config: dict = {}
@@ -233,6 +306,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
created_at=created_at, created_at=created_at,
created_timestamp=created_timestamp, created_timestamp=created_timestamp,
provider_name=provider_name, provider_name=provider_name,
provider_id=provider_id,
# 能力配置 # 能力配置
streaming=config.get("streaming", True), streaming=config.get("streaming", True),
vision=config.get("vision", False), vision=config.get("vision", False),
@@ -255,6 +329,7 @@ async def list_available_models(
db: Session, db: Session,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> list[ModelInfo]: ) -> list[ModelInfo]:
""" """
获取可用模型列表(已去重,带缓存) 获取可用模型列表(已去重,带缓存)
@@ -263,6 +338,7 @@ async def list_available_models(
db: 数据库会话 db: 数据库会话
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
去重后的 ModelInfo 列表,按创建时间倒序 去重后的 ModelInfo 列表,按创建时间倒序
@@ -270,8 +346,16 @@ async def list_available_models(
if not available_provider_ids: if not available_provider_ids:
return [] return []
# 缓存策略:只有完全无访问限制时才使用缓存
# - restrictions is None: 未传入限制对象
# - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制
# 以上两种情况返回的结果相同,可以共享全局缓存
use_cache = restrictions is None or (
restrictions.allowed_providers is None and restrictions.allowed_models is None
)
# 尝试从缓存获取 # 尝试从缓存获取
if api_formats: if api_formats and use_cache:
cached = await _get_cached_models(api_formats) cached = await _get_cached_models(api_formats)
if cached is not None: if cached is not None:
return cached return cached
@@ -306,14 +390,19 @@ async def list_available_models(
if available_model_ids is not None and info.id not in available_model_ids: if available_model_ids is not None and info.id not in available_model_ids:
continue continue
# 检查 API Key/User 访问限制
if restrictions is not None:
if not restrictions.is_model_allowed(info.id, info.provider_id):
continue
if info.id in seen_model_ids: if info.id in seen_model_ids:
continue continue
seen_model_ids.add(info.id) seen_model_ids.add(info.id)
result.append(info) result.append(info)
# 写入缓存 # 只有无限制的情况才写入缓存
if api_formats: if api_formats and use_cache:
await _set_cached_models(api_formats, result) await _set_cached_models(api_formats, result)
return result return result
@@ -324,6 +413,7 @@ def find_model_by_id(
model_id: str, model_id: str,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> Optional[ModelInfo]: ) -> Optional[ModelInfo]:
""" """
按 ID 查找模型 按 ID 查找模型
@@ -338,6 +428,7 @@ def find_model_by_id(
model_id: 模型 ID model_id: 模型 ID
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
ModelInfo 或 None ModelInfo 或 None
@@ -353,6 +444,11 @@ def find_model_by_id(
if available_model_ids is not None and model_id not in available_model_ids: if available_model_ids is not None and model_id not in available_model_ids:
return None return None
# 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None
if restrictions is not None and restrictions.allowed_models is not None:
if model_id not in restrictions.allowed_models:
return None
# 先按 GlobalModel.name 查找 # 先按 GlobalModel.name 查找
models_by_global = ( models_by_global = (
db.query(Model) db.query(Model)
@@ -368,8 +464,19 @@ def find_model_by_id(
.all() .all()
) )
def is_model_accessible(m: Model) -> bool:
"""检查模型是否可访问"""
if m.provider_id not in available_provider_ids:
return False
# 检查 API Key/User 访问限制
if restrictions is not None:
provider_id = m.provider_id or ""
if not restrictions.is_model_allowed(model_id, provider_id):
return False
return True
model = next( model = next(
(m for m in models_by_global if m.provider_id in available_provider_ids), (m for m in models_by_global if is_model_accessible(m)),
None, None,
) )
@@ -393,7 +500,7 @@ def find_model_by_id(
) )
model = next( model = next(
(m for m in models_by_provider_name if m.provider_id in available_provider_ids), (m for m in models_by_provider_name if is_model_accessible(m)),
None, None,
) )

View File

@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.base.models_service import ( from src.api.base.models_service import (
AccessRestrictions,
ModelInfo, ModelInfo,
find_model_by_id, find_model_by_id,
get_available_provider_ids, get_available_provider_ids,
@@ -375,10 +376,13 @@ async def list_models(
logger.info(f"[Models] GET /v1/models | format={api_format}") logger.info(f"[Models] GET /v1/models | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
@@ -390,7 +394,7 @@ async def list_models(
else: else:
return {"object": "list", "data": []} return {"object": "list", "data": []}
models = await list_available_models(db, available_provider_ids, formats) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
if api_format == "claude": if api_format == "claude":
@@ -419,14 +423,17 @@ async def retrieve_model(
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}") logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats) model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
if not model_info: if not model_info:
return _build_404_response(model_id, api_format) return _build_404_response(model_id, api_format)
@@ -455,15 +462,18 @@ async def list_models_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
if not available_provider_ids: if not available_provider_ids:
return {"models": []} return {"models": []}
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS) models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
response = _build_gemini_list_response(models, page_size, page_token) response = _build_gemini_list_response(models, page_size, page_token)
logger.debug(f"[Models] Gemini 响应: {response}") logger.debug(f"[Models] Gemini 响应: {response}")
@@ -486,12 +496,17 @@ async def get_model_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS) model_info = find_model_by_id(
db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions
)
if not model_info: if not model_info:
return _build_404_response(model_id, "gemini") return _build_404_response(model_id, "gemini")

View File

@@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG:
log_dir.mkdir(exist_ok=True) log_dir.mkdir(exist_ok=True)
# 文件日志通用配置 # 文件日志通用配置
# 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏
# 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽
file_log_config = { file_log_config = {
"format": FILE_FORMAT, "format": FILE_FORMAT,
"filter": _log_filter, "filter": _log_filter,
"rotation": "100 MB", "rotation": "100 MB",
"retention": "30 days", "retention": "30 days",
"compression": "gz", "compression": "gz",
"enqueue": True, "enqueue": False,
"encoding": "utf-8", "encoding": "utf-8",
"catch": True, "catch": True,
} }