diff --git a/alembic/versions/20260101_1400_add_ldap_authentication_support.py b/alembic/versions/20260101_1400_add_ldap_authentication_support.py index 76caeaa..9d93022 100644 --- a/alembic/versions/20260101_1400_add_ldap_authentication_support.py +++ b/alembic/versions/20260101_1400_add_ldap_authentication_support.py @@ -46,6 +46,7 @@ def upgrade() -> None: sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'), sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'), sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('connect_timeout', sa.Integer(), nullable=False, server_default='10'), sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), sa.PrimaryKeyConstraint('id') diff --git a/frontend/src/api/admin.ts b/frontend/src/api/admin.ts index bdebd01..d59b404 100644 --- a/frontend/src/api/admin.ts +++ b/frontend/src/api/admin.ts @@ -167,6 +167,7 @@ export interface LdapConfigResponse { is_enabled: boolean is_exclusive: boolean use_starttls: boolean + connect_timeout: number } // LDAP 配置更新请求 @@ -182,6 +183,7 @@ export interface LdapConfigUpdateRequest { is_enabled?: boolean is_exclusive?: boolean use_starttls?: boolean + connect_timeout?: number } // LDAP 连接测试响应 @@ -527,11 +529,8 @@ export const adminApi = { }, // 测试 LDAP 连接 - async testLdapConnection(): Promise { - const response = await apiClient.post( - '/api/admin/ldap/test', - {} - ) + async testLdapConnection(config: LdapConfigUpdateRequest): Promise { + const response = await apiClient.post('/api/admin/ldap/test', config) return response.data } } diff --git a/frontend/src/features/auth/components/LoginDialog.vue b/frontend/src/features/auth/components/LoginDialog.vue index 9a0ba2b..0a2f990 100644 --- a/frontend/src/features/auth/components/LoginDialog.vue +++ b/frontend/src/features/auth/components/LoginDialog.vue @@ -292,6 +292,10 @@ onMounted(async () => { localEnabled.value = authSettings.local_enabled ldapEnabled.value = authSettings.ldap_enabled ldapExclusive.value = authSettings.ldap_exclusive + // 若仅允许 LDAP 登录,则禁用本地注册入口 + if (ldapExclusive.value) { + allowRegistration.value = false + } // Set default auth type based on settings if (authSettings.ldap_exclusive) { diff --git a/frontend/src/views/admin/LdapSettings.vue b/frontend/src/views/admin/LdapSettings.vue index 13108ad..501dd42 100644 --- a/frontend/src/views/admin/LdapSettings.vue +++ b/frontend/src/views/admin/LdapSettings.vue @@ -153,6 +153,24 @@ class="mt-1" /> + +
+ + +

+ LDAP 服务器连接超时时间 (1-60秒) +

+
@@ -222,6 +240,7 @@ const ldapConfig = ref({ is_enabled: false, is_exclusive: false, use_starttls: false, + connect_timeout: 10, }) onMounted(async () => { @@ -244,6 +263,7 @@ async function loadConfig() { is_enabled: response.is_enabled || false, is_exclusive: response.is_exclusive || false, use_starttls: response.use_starttls || false, + connect_timeout: response.connect_timeout || 10, } hasPassword.value = !!response.server_url } catch (err) { @@ -268,6 +288,7 @@ async function handleSave() { is_enabled: ldapConfig.value.is_enabled, is_exclusive: ldapConfig.value.is_exclusive, use_starttls: ldapConfig.value.use_starttls, + connect_timeout: ldapConfig.value.connect_timeout, ...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }), } await adminApi.updateLdapConfig(payload) @@ -285,7 +306,21 @@ async function handleSave() { async function handleTestConnection() { testLoading.value = true try { - const response = await adminApi.testLdapConnection() + const payload: LdapConfigUpdateRequest = { + server_url: ldapConfig.value.server_url, + bind_dn: ldapConfig.value.bind_dn, + base_dn: ldapConfig.value.base_dn, + user_search_filter: ldapConfig.value.user_search_filter, + username_attr: ldapConfig.value.username_attr, + email_attr: ldapConfig.value.email_attr, + display_name_attr: ldapConfig.value.display_name_attr, + is_enabled: ldapConfig.value.is_enabled, + is_exclusive: ldapConfig.value.is_exclusive, + use_starttls: ldapConfig.value.use_starttls, + connect_timeout: ldapConfig.value.connect_timeout, + ...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }), + } + const response = await adminApi.testLdapConnection(payload) if (response.success) { success('LDAP 连接测试成功') } else { diff --git a/src/api/admin/ldap.py b/src/api/admin/ldap.py index 6406ea6..15b1a24 100644 --- a/src/api/admin/ldap.py +++ b/src/api/admin/ldap.py @@ -33,6 +33,7 @@ class LDAPConfigResponse(BaseModel): is_enabled: bool is_exclusive: bool use_starttls: bool + connect_timeout: int class LDAPConfigUpdate(BaseModel): @@ -49,6 +50,7 @@ class LDAPConfigUpdate(BaseModel): is_enabled: bool = False is_exclusive: bool = False use_starttls: bool = False + connect_timeout: int = Field(default=10, ge=1, le=60) class LDAPTestResponse(BaseModel): @@ -58,6 +60,23 @@ class LDAPTestResponse(BaseModel): message: str +class LDAPConfigTest(BaseModel): + """LDAP配置测试请求(全部可选,用于临时覆盖)""" + + server_url: Optional[str] = Field(None, min_length=1, max_length=255) + bind_dn: Optional[str] = Field(None, min_length=1, max_length=255) + bind_password: Optional[str] = Field(None, min_length=1) + base_dn: Optional[str] = Field(None, min_length=1, max_length=255) + user_search_filter: Optional[str] = Field(None, max_length=500) + username_attr: Optional[str] = Field(None, max_length=50) + email_attr: Optional[str] = Field(None, max_length=50) + display_name_attr: Optional[str] = Field(None, max_length=50) + is_enabled: Optional[bool] = None + is_exclusive: Optional[bool] = None + use_starttls: Optional[bool] = None + connect_timeout: Optional[int] = Field(None, ge=1, le=60) + + # ========== API Endpoints ========== @@ -102,6 +121,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter): is_enabled=False, is_exclusive=False, use_starttls=False, + connect_timeout=10, ).model_dump() return LDAPConfigResponse( @@ -115,6 +135,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter): is_enabled=config.is_enabled, is_exclusive=config.is_exclusive, use_starttls=config.use_starttls, + connect_timeout=config.connect_timeout, ).model_dump() @@ -151,6 +172,7 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter): config.is_enabled = config_update.is_enabled config.is_exclusive = config_update.is_exclusive config.use_starttls = config_update.use_starttls + config.connect_timeout = config_update.connect_timeout if config_update.bind_password: config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password) @@ -162,33 +184,77 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter): class AdminTestLDAPConnectionAdapter(AdminApiAdapter): async def handle(self, context) -> Dict[str, Any]: # type: ignore[override] - db = context.db - config = db.query(LDAPConfig).first() + from src.services.auth.ldap import LDAPService - if not config: - return LDAPTestResponse(success=False, message="LDAP配置不存在").model_dump() + db = context.db + if context.json_body is not None: + payload = context.json_body + elif not context.raw_body: + payload = {} + else: + payload = context.ensure_json_body() + + saved_config = db.query(LDAPConfig).first() try: - import ldap3 + overrides = LDAPConfigTest.model_validate(payload) + except ValidationError as e: + errors = e.errors() + if errors: + raise InvalidRequestException(translate_pydantic_error(errors[0])) + raise InvalidRequestException("请求数据验证失败") - bind_password = crypto_service.decrypt(config.bind_password_encrypted) + config_data: Dict[str, Any] = {} - use_ssl = config.server_url.startswith("ldaps://") - server = ldap3.Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL) - conn = ldap3.Connection(server, user=config.bind_dn, password=bind_password) - - if config.use_starttls and not use_ssl: - conn.start_tls() - - if not conn.bind(): + if saved_config: + config_data = { + "server_url": saved_config.server_url, + "bind_dn": saved_config.bind_dn, + "base_dn": saved_config.base_dn, + "user_search_filter": saved_config.user_search_filter, + "username_attr": saved_config.username_attr, + "email_attr": saved_config.email_attr, + "display_name_attr": saved_config.display_name_attr, + "use_starttls": saved_config.use_starttls, + "connect_timeout": saved_config.connect_timeout, + } + try: + config_data["bind_password"] = crypto_service.decrypt( + saved_config.bind_password_encrypted + ) + except Exception as e: return LDAPTestResponse( - success=False, message=f"绑定失败: {conn.result}" + success=False, message=f"绑定密码解密失败: {str(e)}" ).model_dump() - conn.unbind() - return LDAPTestResponse(success=True, message="LDAP连接测试成功").model_dump() + # 应用前端传入的覆盖值 + for field in [ + "server_url", + "bind_dn", + "base_dn", + "user_search_filter", + "username_attr", + "email_attr", + "display_name_attr", + "use_starttls", + "is_enabled", + "is_exclusive", + "connect_timeout", + ]: + value = getattr(overrides, field) + if value is not None: + config_data[field] = value - except ImportError: - return LDAPTestResponse(success=False, message="ldap3库未安装").model_dump() - except Exception as e: - return LDAPTestResponse(success=False, message=f"连接失败: {str(e)}").model_dump() + if overrides.bind_password: + config_data["bind_password"] = overrides.bind_password + + # 必填字段检查 + required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"] + missing = [f for f in required_fields if not config_data.get(f)] + if missing: + return LDAPTestResponse( + success=False, message=f"缺少必要字段: {', '.join(missing)}" + ).model_dump() + + success, message = LDAPService.test_connection_with_config(config_data) + return LDAPTestResponse(success=success, message=message).model_dump() diff --git a/src/api/auth/routes.py b/src/api/auth/routes.py index c8b7ddb..335d7a3 100644 --- a/src/api/auth/routes.py +++ b/src/api/auth/routes.py @@ -349,6 +349,12 @@ class AuthRegisterAdapter(AuthPublicAdapter): detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试", ) + # 仅允许 LDAP 登录时拒绝本地注册 + if LDAPService.is_ldap_exclusive(db): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="系统已启用 LDAP 专属登录,禁止本地注册" + ) + allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first() if allow_registration and not allow_registration.value: AuditService.log_event( diff --git a/src/models/database.py b/src/models/database.py index 3cae081..6c6f8da 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -459,6 +459,7 @@ class LDAPConfig(Base): Boolean, default=False, nullable=False ) # 是否仅允许 LDAP 登录(禁用本地认证) use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS + connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒) # 时间戳 created_at = Column( diff --git a/src/services/auth/ldap.py b/src/services/auth/ldap.py index bb6b0e4..cfb65cc 100644 --- a/src/services/auth/ldap.py +++ b/src/services/auth/ldap.py @@ -1,14 +1,14 @@ """LDAP 认证服务""" -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple from sqlalchemy.orm import Session from src.core.logger import logger from src.models.database import LDAPConfig -# LDAP 连接超时时间(秒) -LDAP_CONNECT_TIMEOUT = 10 +# LDAP 连接默认超时时间(秒) +DEFAULT_LDAP_CONNECT_TIMEOUT = 10 def escape_ldap_filter(value: str) -> str: @@ -71,12 +71,40 @@ class LDAPService: return config.is_exclusive if config and config.is_enabled else False @staticmethod - def authenticate(db: Session, username: str, password: str) -> Optional[dict]: + def get_config_data(db: Session) -> Optional[Dict[str, str]]: + """ + 提前获取并解密配置,供线程池使用,避免跨线程共享 Session。 + """ + config = LDAPService.get_config(db) + if not config or not config.is_enabled: + return None + + try: + bind_password = config.get_bind_password() + except Exception as e: + logger.error(f"LDAP 绑定密码解密失败: {e}") + return None + + return { + "server_url": config.server_url, + "bind_dn": config.bind_dn, + "bind_password": bind_password, + "base_dn": config.base_dn, + "user_search_filter": config.user_search_filter, + "username_attr": config.username_attr, + "email_attr": config.email_attr, + "display_name_attr": config.display_name_attr, + "use_starttls": config.use_starttls, + "connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT, + } + + @staticmethod + def authenticate_with_config(config: Dict[str, str], username: str, password: str) -> Optional[dict]: """ LDAP bind 验证 Args: - db: 数据库会话 + config: 已解密的 LDAP 配置 username: 用户名 password: 密码 @@ -91,8 +119,7 @@ class LDAPService: logger.error("ldap3 库未安装") return None - config = LDAPService.get_config(db) - if not config or not config.is_enabled: + if not config: logger.warning("LDAP 未配置或未启用") return None @@ -101,19 +128,21 @@ class LDAPService: try: # 创建服务器连接 - use_ssl = config.server_url.startswith("ldaps://") + server_url = config["server_url"] + use_ssl = server_url.startswith("ldaps://") + timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT) server = Server( - config.server_url, + server_url, use_ssl=use_ssl, get_info=ldap3.ALL, - connect_timeout=LDAP_CONNECT_TIMEOUT, + connect_timeout=timeout, ) # 使用管理员账号连接 - bind_password = config.get_bind_password() - admin_conn = Connection(server, user=config.bind_dn, password=bind_password) + bind_password = config["bind_password"] + admin_conn = Connection(server, user=config["bind_dn"], password=bind_password) - if config.use_starttls and not use_ssl: + if config.get("use_starttls") and not use_ssl: admin_conn.start_tls() if not admin_conn.bind(): @@ -122,12 +151,16 @@ class LDAPService: # 搜索用户(转义用户名防止 LDAP 注入) safe_username = escape_ldap_filter(username) - search_filter = config.user_search_filter.replace("{username}", safe_username) + search_filter = config["user_search_filter"].replace("{username}", safe_username) admin_conn.search( - search_base=config.base_dn, + search_base=config["base_dn"], search_filter=search_filter, search_scope=SUBTREE, - attributes=[config.username_attr, config.email_attr, config.display_name_attr], + attributes=[ + config["username_attr"], + config["email_attr"], + config["display_name_attr"], + ], ) if not admin_conn.entries: @@ -139,7 +172,7 @@ class LDAPService: # 用户密码验证 user_conn = Connection(server, user=user_dn, password=password) - if config.use_starttls and not use_ssl: + if config.get("use_starttls") and not use_ssl: user_conn.start_tls() if not user_conn.bind(): @@ -147,8 +180,10 @@ class LDAPService: return None # 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认) - email = _get_attr_value(user_entry, config.email_attr, f"{username}@ldap.local") - display_name = _get_attr_value(user_entry, config.display_name_attr, username) + email = _get_attr_value( + user_entry, config["email_attr"], f"{username}@ldap.local" + ) + display_name = _get_attr_value(user_entry, config["display_name_attr"], username) logger.info(f"LDAP 认证成功: {username}") return { @@ -180,7 +215,7 @@ class LDAPService: pass @staticmethod - def test_connection(db: Session) -> Tuple[bool, str]: + def test_connection_with_config(config: Dict[str, str]) -> Tuple[bool, str]: """ 测试 LDAP 连接 @@ -193,22 +228,23 @@ class LDAPService: except ImportError: return False, "ldap3 库未安装" - config = LDAPService.get_config(db) if not config: return False, "LDAP 配置不存在" try: - use_ssl = config.server_url.startswith("ldaps://") + server_url = config["server_url"] + use_ssl = server_url.startswith("ldaps://") + timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT) server = Server( - config.server_url, + server_url, use_ssl=use_ssl, get_info=ldap3.ALL, - connect_timeout=LDAP_CONNECT_TIMEOUT, + connect_timeout=timeout, ) - bind_password = config.get_bind_password() - conn = Connection(server, user=config.bind_dn, password=bind_password) + bind_password = config["bind_password"] + conn = Connection(server, user=config["bind_dn"], password=bind_password) - if config.use_starttls and not use_ssl: + if config.get("use_starttls") and not use_ssl: conn.start_tls() if not conn.bind(): @@ -219,3 +255,16 @@ class LDAPService: except Exception as e: return False, f"连接失败: {str(e)}" + + # 兼容旧接口:如果其他代码直接调用 + @staticmethod + def authenticate(db: Session, username: str, password: str) -> Optional[dict]: + config = LDAPService.get_config_data(db) + return LDAPService.authenticate_with_config(config, username, password) if config else None + + @staticmethod + def test_connection(db: Session) -> Tuple[bool, str]: + config = LDAPService.get_config_data(db) + if not config: + return False, "LDAP 配置不存在或未启用" + return LDAPService.test_connection_with_config(config) diff --git a/src/services/auth/service.py b/src/services/auth/service.py index 1ead8ec..26cc2ea 100644 --- a/src/services/auth/service.py +++ b/src/services/auth/service.py @@ -112,8 +112,16 @@ class AuthService: logger.warning("登录失败 - LDAP 认证未启用") return None + # 预取配置,避免将 Session 传递到线程池 + config_data = LDAPService.get_config_data(db) + if not config_data: + logger.warning("登录失败 - 无法获取 LDAP 配置或解密失败") + return None + # 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环 - ldap_user = await run_in_threadpool(LDAPService.authenticate, db, email, password) + ldap_user = await run_in_threadpool( + LDAPService.authenticate_with_config, config_data, email, password + ) if not ldap_user: return None @@ -128,10 +136,6 @@ class AuthService: return user # 本地认证 - if LDAPService.is_ldap_exclusive(db): - logger.warning("登录失败 - 仅允许 LDAP 登录") - return None - # 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象 user = db.query(User).filter(User.email == email).first() @@ -139,6 +143,13 @@ class AuthService: logger.warning(f"登录失败 - 用户不存在: {email}") return None + # 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道) + if LDAPService.is_ldap_exclusive(db): + if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL: + logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}") + return None + logger.info(f"LDAP exclusive 模式下允许本地管理员登录: {email}") + # 检查用户认证来源 if user.auth_source == AuthSource.LDAP: logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")