From e20a09f15a328c3d19e97d3915f4e62478113c45 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Tue, 30 Dec 2025 16:57:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E5=88=97=E8=A1=A8=E8=AE=BF=E9=97=AE=E9=99=90=E5=88=B6=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实现 API Key 和 User 级别的模型访问权限控制,支持按 Provider 和模型名称限制。 - 新增 AccessRestrictions 类处理访问限制合并逻辑(API Key 优先于 User) - models_service 支持根据限制过滤模型列表 - models.py 在列表查询时构建并应用访问限制 - 优化缓存策略:仅无限制请求使用缓存,有限制的请求旁路缓存 - 修复 logger 配置:enqueue 改为 False 避免 macOS 信号量泄漏 --- src/api/base/models_service.py | 119 +++++++++++++++++++++++++++++++-- src/api/public/models.py | 31 ++++++--- src/core/logger.py | 4 +- 3 files changed, 139 insertions(+), 15 deletions(-) diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py index ee9cdfa..d8fcba7 100644 --- a/src/api/base/models_service.py +++ b/src/api/base/models_service.py @@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload from src.config.constants import CacheTTL from src.core.cache_service import CacheService from src.core.logger import logger -from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint +from src.models.database import ( + ApiKey, + GlobalModel, + Model, + Provider, + ProviderAPIKey, + ProviderEndpoint, + User, +) # 缓存 key 前缀 _CACHE_KEY_PREFIX = "models:list" @@ -82,6 +90,7 @@ class ModelInfo: created_at: Optional[str] # ISO 格式 created_timestamp: int # Unix 时间戳 provider_name: str + provider_id: str = "" # Provider ID,用于权限过滤 # 能力配置 streaming: bool = True vision: bool = False @@ -99,6 +108,69 @@ class ModelInfo: output_modalities: Optional[list[str]] = None +@dataclass +class AccessRestrictions: + """API Key 或 User 的访问限制""" + + allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表 + allowed_models: Optional[list[str]] = None # 允许的模型名称列表 + + @classmethod + def from_api_key_and_user( + cls, api_key: Optional[ApiKey], user: Optional[User] + ) -> "AccessRestrictions": + """ + 从 API Key 和 User 合并访问限制 + + 限制逻辑: + - API Key 的限制优先于 User 的限制 + - 如果 API Key 有限制,使用 API Key 的限制 + - 如果 API Key 无限制但 User 有限制,使用 User 的限制 + - 两者都无限制则返回空限制 + """ + allowed_providers: Optional[list[str]] = None + allowed_models: Optional[list[str]] = None + + # 优先使用 API Key 的限制 + if api_key: + if api_key.allowed_providers is not None: + allowed_providers = api_key.allowed_providers + if api_key.allowed_models is not None: + allowed_models = api_key.allowed_models + + # 如果 API Key 没有限制,检查 User 的限制 + if user: + if allowed_providers is None and user.allowed_providers is not None: + allowed_providers = user.allowed_providers + if allowed_models is None and user.allowed_models is not None: + allowed_models = user.allowed_models + + return cls(allowed_providers=allowed_providers, allowed_models=allowed_models) + + def is_model_allowed(self, model_id: str, provider_id: str) -> bool: + """ + 检查模型是否被允许访问 + + Args: + model_id: 模型 ID + provider_id: Provider ID + + Returns: + True 如果模型被允许,False 否则 + """ + # 检查 Provider 限制 + if self.allowed_providers is not None: + if provider_id not in self.allowed_providers: + return False + + # 检查模型限制 + if self.allowed_models is not None: + if model_id not in self.allowed_models: + return False + + return True + + def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]: """ 返回有可用端点的 Provider IDs @@ -218,6 +290,7 @@ def _extract_model_info(model: Any) -> ModelInfo: ) created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0 provider_name: str = model.provider.name if model.provider else "unknown" + provider_id: str = model.provider_id or "" # 从 GlobalModel.config 提取配置信息 config: dict = {} @@ -233,6 +306,7 @@ def _extract_model_info(model: Any) -> ModelInfo: created_at=created_at, created_timestamp=created_timestamp, provider_name=provider_name, + provider_id=provider_id, # 能力配置 streaming=config.get("streaming", True), vision=config.get("vision", False), @@ -255,6 +329,7 @@ async def list_available_models( db: Session, available_provider_ids: set[str], api_formats: Optional[list[str]] = None, + restrictions: Optional[AccessRestrictions] = None, ) -> list[ModelInfo]: """ 获取可用模型列表(已去重,带缓存) @@ -263,6 +338,7 @@ async def list_available_models( db: 数据库会话 available_provider_ids: 有可用端点的 Provider ID 集合 api_formats: API 格式列表,用于检查 Key 的 allowed_models + restrictions: API Key/User 的访问限制 Returns: 去重后的 ModelInfo 列表,按创建时间倒序 @@ -270,8 +346,16 @@ async def list_available_models( if not available_provider_ids: return [] + # 缓存策略:只有完全无访问限制时才使用缓存 + # - restrictions is None: 未传入限制对象 + # - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制 + # 以上两种情况返回的结果相同,可以共享全局缓存 + use_cache = restrictions is None or ( + restrictions.allowed_providers is None and restrictions.allowed_models is None + ) + # 尝试从缓存获取 - if api_formats: + if api_formats and use_cache: cached = await _get_cached_models(api_formats) if cached is not None: return cached @@ -306,14 +390,19 @@ async def list_available_models( if available_model_ids is not None and info.id not in available_model_ids: continue + # 检查 API Key/User 访问限制 + if restrictions is not None: + if not restrictions.is_model_allowed(info.id, info.provider_id): + continue + if info.id in seen_model_ids: continue seen_model_ids.add(info.id) result.append(info) - # 写入缓存 - if api_formats: + # 只有无限制的情况才写入缓存 + if api_formats and use_cache: await _set_cached_models(api_formats, result) return result @@ -324,6 +413,7 @@ def find_model_by_id( model_id: str, available_provider_ids: set[str], api_formats: Optional[list[str]] = None, + restrictions: Optional[AccessRestrictions] = None, ) -> Optional[ModelInfo]: """ 按 ID 查找模型 @@ -338,6 +428,7 @@ def find_model_by_id( model_id: 模型 ID available_provider_ids: 有可用端点的 Provider ID 集合 api_formats: API 格式列表,用于检查 Key 的 allowed_models + restrictions: API Key/User 的访问限制 Returns: ModelInfo 或 None @@ -353,6 +444,11 @@ def find_model_by_id( if available_model_ids is not None and model_id not in available_model_ids: return None + # 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None + if restrictions is not None and restrictions.allowed_models is not None: + if model_id not in restrictions.allowed_models: + return None + # 先按 GlobalModel.name 查找 models_by_global = ( db.query(Model) @@ -368,8 +464,19 @@ def find_model_by_id( .all() ) + def is_model_accessible(m: Model) -> bool: + """检查模型是否可访问""" + if m.provider_id not in available_provider_ids: + return False + # 检查 API Key/User 访问限制 + if restrictions is not None: + provider_id = m.provider_id or "" + if not restrictions.is_model_allowed(model_id, provider_id): + return False + return True + model = next( - (m for m in models_by_global if m.provider_id in available_provider_ids), + (m for m in models_by_global if is_model_accessible(m)), None, ) @@ -393,7 +500,7 @@ def find_model_by_id( ) model = next( - (m for m in models_by_provider_name if m.provider_id in available_provider_ids), + (m for m in models_by_provider_name if is_model_accessible(m)), None, ) diff --git a/src/api/public/models.py b/src/api/public/models.py index 829559d..f92b648 100644 --- a/src/api/public/models.py +++ b/src/api/public/models.py @@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse from sqlalchemy.orm import Session from src.api.base.models_service import ( + AccessRestrictions, ModelInfo, find_model_by_id, get_available_provider_ids, @@ -375,10 +376,13 @@ async def list_models( logger.info(f"[Models] GET /v1/models | format={api_format}") # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response(api_format) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + formats = _get_formats_for_api(api_format) available_provider_ids = get_available_provider_ids(db, formats) @@ -390,7 +394,7 @@ async def list_models( else: return {"object": "list", "data": []} - models = await list_available_models(db, available_provider_ids, formats) + models = await list_available_models(db, available_provider_ids, formats, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") if api_format == "claude": @@ -419,14 +423,17 @@ async def retrieve_model( logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}") # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response(api_format) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + formats = _get_formats_for_api(api_format) available_provider_ids = get_available_provider_ids(db, formats) - model_info = find_model_by_id(db, model_id, available_provider_ids, formats) + model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions) if not model_info: return _build_404_response(model_id, api_format) @@ -455,15 +462,18 @@ async def list_models_gemini( api_key = _extract_api_key_from_request(request, gemini_def) # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response("gemini") + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) if not available_provider_ids: return {"models": []} - models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS) + models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") response = _build_gemini_list_response(models, page_size, page_token) logger.debug(f"[Models] Gemini 响应: {response}") @@ -486,12 +496,17 @@ async def get_model_gemini( api_key = _extract_api_key_from_request(request, gemini_def) # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response("gemini") + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) - model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS) + model_info = find_model_by_id( + db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions + ) if not model_info: return _build_404_response(model_id, "gemini") diff --git a/src/core/logger.py b/src/core/logger.py index 0d48f8f..82d0acb 100644 --- a/src/core/logger.py +++ b/src/core/logger.py @@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG: log_dir.mkdir(exist_ok=True) # 文件日志通用配置 + # 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏 + # 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽 file_log_config = { "format": FILE_FORMAT, "filter": _log_filter, "rotation": "100 MB", "retention": "30 days", "compression": "gz", - "enqueue": True, + "enqueue": False, "encoding": "utf-8", "catch": True, }