feat: 添加模型列表访问限制功能

实现 API Key 和 User 级别的模型访问权限控制,支持按 Provider 和模型名称限制。

- 新增 AccessRestrictions 类处理访问限制合并逻辑(API Key 优先于 User)
- models_service 支持根据限制过滤模型列表
- models.py 在列表查询时构建并应用访问限制
- 优化缓存策略:仅无限制请求使用缓存,有限制的请求旁路缓存
- 修复 logger 配置:enqueue 改为 False 避免 macOS 信号量泄漏
This commit is contained in:
fawney19
2025-12-30 16:57:59 +08:00
parent b89a4af0cf
commit e20a09f15a
3 changed files with 139 additions and 15 deletions

View File

@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from src.api.base.models_service import (
AccessRestrictions,
ModelInfo,
find_model_by_id,
get_available_provider_ids,
@@ -375,10 +376,13 @@ async def list_models(
logger.info(f"[Models] GET /v1/models | format={api_format}")
# 认证
user, _ = _authenticate(db, api_key)
user, key_record = _authenticate(db, api_key)
if not user:
return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
formats = _get_formats_for_api(api_format)
available_provider_ids = get_available_provider_ids(db, formats)
@@ -390,7 +394,7 @@ async def list_models(
else:
return {"object": "list", "data": []}
models = await list_available_models(db, available_provider_ids, formats)
models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型")
if api_format == "claude":
@@ -419,14 +423,17 @@ async def retrieve_model(
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
# 认证
user, _ = _authenticate(db, api_key)
user, key_record = _authenticate(db, api_key)
if not user:
return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
formats = _get_formats_for_api(api_format)
available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
if not model_info:
return _build_404_response(model_id, api_format)
@@ -455,15 +462,18 @@ async def list_models_gemini(
api_key = _extract_api_key_from_request(request, gemini_def)
# 认证
user, _ = _authenticate(db, api_key)
user, key_record = _authenticate(db, api_key)
if not user:
return _build_auth_error_response("gemini")
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
if not available_provider_ids:
return {"models": []}
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS)
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型")
response = _build_gemini_list_response(models, page_size, page_token)
logger.debug(f"[Models] Gemini 响应: {response}")
@@ -486,12 +496,17 @@ async def get_model_gemini(
api_key = _extract_api_key_from_request(request, gemini_def)
# 认证
user, _ = _authenticate(db, api_key)
user, key_record = _authenticate(db, api_key)
if not user:
return _build_auth_error_response("gemini")
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS)
model_info = find_model_by_id(
db, model_id, available_provider_ids, _GEMINI_FORMATS, restrictions
)
if not model_info:
return _build_404_response(model_id, "gemini")