diff --git a/src/api/base/models_service.py b/src/api/base/models_service.py index ee9cdfa..1df1594 100644 --- a/src/api/base/models_service.py +++ b/src/api/base/models_service.py @@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload from src.config.constants import CacheTTL from src.core.cache_service import CacheService from src.core.logger import logger -from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint +from src.models.database import ( + ApiKey, + GlobalModel, + Model, + Provider, + ProviderAPIKey, + ProviderEndpoint, + User, +) # 缓存 key 前缀 _CACHE_KEY_PREFIX = "models:list" @@ -82,6 +90,7 @@ class ModelInfo: created_at: Optional[str] # ISO 格式 created_timestamp: int # Unix 时间戳 provider_name: str + provider_id: str = "" # Provider ID,用于权限过滤 # 能力配置 streaming: bool = True vision: bool = False @@ -99,6 +108,92 @@ class ModelInfo: output_modalities: Optional[list[str]] = None +@dataclass +class AccessRestrictions: + """API Key 或 User 的访问限制""" + + allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表 + allowed_models: Optional[list[str]] = None # 允许的模型名称列表 + allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表 + + @classmethod + def from_api_key_and_user( + cls, api_key: Optional[ApiKey], user: Optional[User] + ) -> "AccessRestrictions": + """ + 从 API Key 和 User 合并访问限制 + + 限制逻辑: + - API Key 的限制优先于 User 的限制 + - 如果 API Key 有限制,使用 API Key 的限制 + - 如果 API Key 无限制但 User 有限制,使用 User 的限制 + - 两者都无限制则返回空限制 + """ + allowed_providers: Optional[list[str]] = None + allowed_models: Optional[list[str]] = None + allowed_api_formats: Optional[list[str]] = None + + # 优先使用 API Key 的限制 + if api_key: + if api_key.allowed_providers is not None: + allowed_providers = api_key.allowed_providers + if api_key.allowed_models is not None: + allowed_models = api_key.allowed_models + if api_key.allowed_api_formats is not None: + allowed_api_formats = api_key.allowed_api_formats + + # 如果 API Key 没有限制,检查 User 的限制 + # 注意: User 没有 allowed_api_formats 字段 + if user: + if allowed_providers is None and user.allowed_providers is not None: + allowed_providers = user.allowed_providers + if allowed_models is None and user.allowed_models is not None: + allowed_models = user.allowed_models + + return cls( + allowed_providers=allowed_providers, + allowed_models=allowed_models, + allowed_api_formats=allowed_api_formats, + ) + + def is_api_format_allowed(self, api_format: str) -> bool: + """ + 检查 API 格式是否被允许 + + Args: + api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI") + + Returns: + True 如果格式被允许,False 否则 + """ + if self.allowed_api_formats is None: + return True + return api_format in self.allowed_api_formats + + def is_model_allowed(self, model_id: str, provider_id: str) -> bool: + """ + 检查模型是否被允许访问 + + Args: + model_id: 模型 ID + provider_id: Provider ID + + Returns: + True 如果模型被允许,False 否则 + """ + # 检查 Provider 限制 + if self.allowed_providers is not None: + if provider_id not in self.allowed_providers: + return False + + # 检查模型限制 + if self.allowed_models is not None: + if model_id not in self.allowed_models: + return False + + return True + + def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]: """ 返回有可用端点的 Provider IDs @@ -218,6 +313,7 @@ def _extract_model_info(model: Any) -> ModelInfo: ) created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0 provider_name: str = model.provider.name if model.provider else "unknown" + provider_id: str = model.provider_id or "" # 从 GlobalModel.config 提取配置信息 config: dict = {} @@ -233,6 +329,7 @@ def _extract_model_info(model: Any) -> ModelInfo: created_at=created_at, created_timestamp=created_timestamp, provider_name=provider_name, + provider_id=provider_id, # 能力配置 streaming=config.get("streaming", True), vision=config.get("vision", False), @@ -255,6 +352,7 @@ async def list_available_models( db: Session, available_provider_ids: set[str], api_formats: Optional[list[str]] = None, + restrictions: Optional[AccessRestrictions] = None, ) -> list[ModelInfo]: """ 获取可用模型列表(已去重,带缓存) @@ -263,6 +361,7 @@ async def list_available_models( db: 数据库会话 available_provider_ids: 有可用端点的 Provider ID 集合 api_formats: API 格式列表,用于检查 Key 的 allowed_models + restrictions: API Key/User 的访问限制 Returns: 去重后的 ModelInfo 列表,按创建时间倒序 @@ -270,8 +369,16 @@ async def list_available_models( if not available_provider_ids: return [] + # 缓存策略:只有完全无访问限制时才使用缓存 + # - restrictions is None: 未传入限制对象 + # - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制 + # 以上两种情况返回的结果相同,可以共享全局缓存 + use_cache = restrictions is None or ( + restrictions.allowed_providers is None and restrictions.allowed_models is None + ) + # 尝试从缓存获取 - if api_formats: + if api_formats and use_cache: cached = await _get_cached_models(api_formats) if cached is not None: return cached @@ -306,14 +413,19 @@ async def list_available_models( if available_model_ids is not None and info.id not in available_model_ids: continue + # 检查 API Key/User 访问限制 + if restrictions is not None: + if not restrictions.is_model_allowed(info.id, info.provider_id): + continue + if info.id in seen_model_ids: continue seen_model_ids.add(info.id) result.append(info) - # 写入缓存 - if api_formats: + # 只有无限制的情况才写入缓存 + if api_formats and use_cache: await _set_cached_models(api_formats, result) return result @@ -324,6 +436,7 @@ def find_model_by_id( model_id: str, available_provider_ids: set[str], api_formats: Optional[list[str]] = None, + restrictions: Optional[AccessRestrictions] = None, ) -> Optional[ModelInfo]: """ 按 ID 查找模型 @@ -338,6 +451,7 @@ def find_model_by_id( model_id: 模型 ID available_provider_ids: 有可用端点的 Provider ID 集合 api_formats: API 格式列表,用于检查 Key 的 allowed_models + restrictions: API Key/User 的访问限制 Returns: ModelInfo 或 None @@ -353,6 +467,11 @@ def find_model_by_id( if available_model_ids is not None and model_id not in available_model_ids: return None + # 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None + if restrictions is not None and restrictions.allowed_models is not None: + if model_id not in restrictions.allowed_models: + return None + # 先按 GlobalModel.name 查找 models_by_global = ( db.query(Model) @@ -368,8 +487,19 @@ def find_model_by_id( .all() ) + def is_model_accessible(m: Model) -> bool: + """检查模型是否可访问""" + if m.provider_id not in available_provider_ids: + return False + # 检查 API Key/User 访问限制 + if restrictions is not None: + provider_id = m.provider_id or "" + if not restrictions.is_model_allowed(model_id, provider_id): + return False + return True + model = next( - (m for m in models_by_global if m.provider_id in available_provider_ids), + (m for m in models_by_global if is_model_accessible(m)), None, ) @@ -393,7 +523,7 @@ def find_model_by_id( ) model = next( - (m for m in models_by_provider_name if m.provider_id in available_provider_ids), + (m for m in models_by_provider_name if is_model_accessible(m)), None, ) diff --git a/src/api/public/models.py b/src/api/public/models.py index 829559d..adc488a 100644 --- a/src/api/public/models.py +++ b/src/api/public/models.py @@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse from sqlalchemy.orm import Session from src.api.base.models_service import ( + AccessRestrictions, ModelInfo, find_model_by_id, get_available_provider_ids, @@ -103,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]: return _OPENAI_FORMATS +def _build_empty_list_response(api_format: str) -> dict: + """根据 API 格式构建空列表响应""" + if api_format == "claude": + return {"data": [], "has_more": False, "first_id": None, "last_id": None} + elif api_format == "gemini": + return {"models": []} + else: + return {"object": "list", "data": []} + + +def _filter_formats_by_restrictions( + formats: list[str], restrictions: AccessRestrictions, api_format: str +) -> Tuple[list[str], Optional[dict]]: + """ + 根据访问限制过滤 API 格式 + + Returns: + (过滤后的格式列表, 空响应或None) + 如果过滤后为空,返回对应格式的空响应 + """ + if restrictions.allowed_api_formats is None: + return formats, None + filtered = [f for f in formats if f in restrictions.allowed_api_formats] + if not filtered: + logger.info(f"[Models] API Key 不允许访问格式 {api_format}") + return [], _build_empty_list_response(api_format) + return filtered, None + + def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]: """ 认证 API Key @@ -375,22 +405,24 @@ async def list_models( logger.info(f"[Models] GET /v1/models | format={api_format}") # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response(api_format) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + + # 检查 API 格式限制 formats = _get_formats_for_api(api_format) + formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format) + if empty_response is not None: + return empty_response available_provider_ids = get_available_provider_ids(db, formats) if not available_provider_ids: - if api_format == "claude": - return {"data": [], "has_more": False, "first_id": None, "last_id": None} - elif api_format == "gemini": - return {"models": []} - else: - return {"object": "list", "data": []} + return _build_empty_list_response(api_format) - models = await list_available_models(db, available_provider_ids, formats) + models = await list_available_models(db, available_provider_ids, formats, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") if api_format == "claude": @@ -419,14 +451,21 @@ async def retrieve_model( logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}") # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response(api_format) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + + # 检查 API 格式限制 formats = _get_formats_for_api(api_format) + formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format) + if not formats: + return _build_404_response(model_id, api_format) available_provider_ids = get_available_provider_ids(db, formats) - model_info = find_model_by_id(db, model_id, available_provider_ids, formats) + model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions) if not model_info: return _build_404_response(model_id, api_format) @@ -455,15 +494,25 @@ async def list_models_gemini( api_key = _extract_api_key_from_request(request, gemini_def) # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response("gemini") - available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + + # 检查 API 格式限制 + formats, empty_response = _filter_formats_by_restrictions( + _GEMINI_FORMATS, restrictions, "gemini" + ) + if empty_response is not None: + return empty_response + + available_provider_ids = get_available_provider_ids(db, formats) if not available_provider_ids: return {"models": []} - models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS) + models = await list_available_models(db, available_provider_ids, formats, restrictions) logger.debug(f"[Models] 返回 {len(models)} 个模型") response = _build_gemini_list_response(models, page_size, page_token) logger.debug(f"[Models] Gemini 响应: {response}") @@ -486,12 +535,22 @@ async def get_model_gemini( api_key = _extract_api_key_from_request(request, gemini_def) # 认证 - user, _ = _authenticate(db, api_key) + user, key_record = _authenticate(db, api_key) if not user: return _build_auth_error_response("gemini") - available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) - model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS) + # 构建访问限制 + restrictions = AccessRestrictions.from_api_key_and_user(key_record, user) + + # 检查 API 格式限制 + formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini") + if not formats: + return _build_404_response(model_id, "gemini") + + available_provider_ids = get_available_provider_ids(db, formats) + model_info = find_model_by_id( + db, model_id, available_provider_ids, formats, restrictions + ) if not model_info: return _build_404_response(model_id, "gemini") diff --git a/src/core/logger.py b/src/core/logger.py index 0d48f8f..82d0acb 100644 --- a/src/core/logger.py +++ b/src/core/logger.py @@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG: log_dir.mkdir(exist_ok=True) # 文件日志通用配置 + # 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏 + # 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽 file_log_config = { "format": FILE_FORMAT, "filter": _log_filter, "rotation": "100 MB", "retention": "30 days", "compression": "gz", - "enqueue": True, + "enqueue": False, "encoding": "utf-8", "catch": True, }