Enhance LDAP auth config handling

This commit is contained in:
RWDai
2026-01-04 16:27:02 +08:00
parent 414f45aa71
commit 3e4309eba3
9 changed files with 231 additions and 59 deletions

View File

@@ -46,6 +46,7 @@ def upgrade() -> None:
sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'), sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'), sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'), sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('connect_timeout', sa.Integer(), nullable=False, server_default='10'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')

View File

@@ -167,6 +167,7 @@ export interface LdapConfigResponse {
is_enabled: boolean is_enabled: boolean
is_exclusive: boolean is_exclusive: boolean
use_starttls: boolean use_starttls: boolean
connect_timeout: number
} }
// LDAP 配置更新请求 // LDAP 配置更新请求
@@ -182,6 +183,7 @@ export interface LdapConfigUpdateRequest {
is_enabled?: boolean is_enabled?: boolean
is_exclusive?: boolean is_exclusive?: boolean
use_starttls?: boolean use_starttls?: boolean
connect_timeout?: number
} }
// LDAP 连接测试响应 // LDAP 连接测试响应
@@ -527,11 +529,8 @@ export const adminApi = {
}, },
// 测试 LDAP 连接 // 测试 LDAP 连接
async testLdapConnection(): Promise<LdapTestResponse> { async testLdapConnection(config: LdapConfigUpdateRequest): Promise<LdapTestResponse> {
const response = await apiClient.post<LdapTestResponse>( const response = await apiClient.post<LdapTestResponse>('/api/admin/ldap/test', config)
'/api/admin/ldap/test',
{}
)
return response.data return response.data
} }
} }

View File

@@ -292,6 +292,10 @@ onMounted(async () => {
localEnabled.value = authSettings.local_enabled localEnabled.value = authSettings.local_enabled
ldapEnabled.value = authSettings.ldap_enabled ldapEnabled.value = authSettings.ldap_enabled
ldapExclusive.value = authSettings.ldap_exclusive ldapExclusive.value = authSettings.ldap_exclusive
// 若仅允许 LDAP 登录,则禁用本地注册入口
if (ldapExclusive.value) {
allowRegistration.value = false
}
// Set default auth type based on settings // Set default auth type based on settings
if (authSettings.ldap_exclusive) { if (authSettings.ldap_exclusive) {

View File

@@ -153,6 +153,24 @@
class="mt-1" class="mt-1"
/> />
</div> </div>
<div>
<Label for="connect-timeout" class="block text-sm font-medium">
连接超时 ()
</Label>
<Input
id="connect-timeout"
v-model.number="ldapConfig.connect_timeout"
type="number"
min="1"
max="60"
placeholder="10"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
LDAP 服务器连接超时时间 (1-60)
</p>
</div>
</div> </div>
<div class="mt-6 space-y-4"> <div class="mt-6 space-y-4">
@@ -222,6 +240,7 @@ const ldapConfig = ref({
is_enabled: false, is_enabled: false,
is_exclusive: false, is_exclusive: false,
use_starttls: false, use_starttls: false,
connect_timeout: 10,
}) })
onMounted(async () => { onMounted(async () => {
@@ -244,6 +263,7 @@ async function loadConfig() {
is_enabled: response.is_enabled || false, is_enabled: response.is_enabled || false,
is_exclusive: response.is_exclusive || false, is_exclusive: response.is_exclusive || false,
use_starttls: response.use_starttls || false, use_starttls: response.use_starttls || false,
connect_timeout: response.connect_timeout || 10,
} }
hasPassword.value = !!response.server_url hasPassword.value = !!response.server_url
} catch (err) { } catch (err) {
@@ -268,6 +288,7 @@ async function handleSave() {
is_enabled: ldapConfig.value.is_enabled, is_enabled: ldapConfig.value.is_enabled,
is_exclusive: ldapConfig.value.is_exclusive, is_exclusive: ldapConfig.value.is_exclusive,
use_starttls: ldapConfig.value.use_starttls, use_starttls: ldapConfig.value.use_starttls,
connect_timeout: ldapConfig.value.connect_timeout,
...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }), ...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }),
} }
await adminApi.updateLdapConfig(payload) await adminApi.updateLdapConfig(payload)
@@ -285,7 +306,21 @@ async function handleSave() {
async function handleTestConnection() { async function handleTestConnection() {
testLoading.value = true testLoading.value = true
try { try {
const response = await adminApi.testLdapConnection() const payload: LdapConfigUpdateRequest = {
server_url: ldapConfig.value.server_url,
bind_dn: ldapConfig.value.bind_dn,
base_dn: ldapConfig.value.base_dn,
user_search_filter: ldapConfig.value.user_search_filter,
username_attr: ldapConfig.value.username_attr,
email_attr: ldapConfig.value.email_attr,
display_name_attr: ldapConfig.value.display_name_attr,
is_enabled: ldapConfig.value.is_enabled,
is_exclusive: ldapConfig.value.is_exclusive,
use_starttls: ldapConfig.value.use_starttls,
connect_timeout: ldapConfig.value.connect_timeout,
...(ldapConfig.value.bind_password && { bind_password: ldapConfig.value.bind_password }),
}
const response = await adminApi.testLdapConnection(payload)
if (response.success) { if (response.success) {
success('LDAP 连接测试成功') success('LDAP 连接测试成功')
} else { } else {

View File

@@ -33,6 +33,7 @@ class LDAPConfigResponse(BaseModel):
is_enabled: bool is_enabled: bool
is_exclusive: bool is_exclusive: bool
use_starttls: bool use_starttls: bool
connect_timeout: int
class LDAPConfigUpdate(BaseModel): class LDAPConfigUpdate(BaseModel):
@@ -49,6 +50,7 @@ class LDAPConfigUpdate(BaseModel):
is_enabled: bool = False is_enabled: bool = False
is_exclusive: bool = False is_exclusive: bool = False
use_starttls: bool = False use_starttls: bool = False
connect_timeout: int = Field(default=10, ge=1, le=60)
class LDAPTestResponse(BaseModel): class LDAPTestResponse(BaseModel):
@@ -58,6 +60,23 @@ class LDAPTestResponse(BaseModel):
message: str message: str
class LDAPConfigTest(BaseModel):
"""LDAP配置测试请求全部可选用于临时覆盖"""
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
bind_password: Optional[str] = Field(None, min_length=1)
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
user_search_filter: Optional[str] = Field(None, max_length=500)
username_attr: Optional[str] = Field(None, max_length=50)
email_attr: Optional[str] = Field(None, max_length=50)
display_name_attr: Optional[str] = Field(None, max_length=50)
is_enabled: Optional[bool] = None
is_exclusive: Optional[bool] = None
use_starttls: Optional[bool] = None
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
# ========== API Endpoints ========== # ========== API Endpoints ==========
@@ -102,6 +121,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter):
is_enabled=False, is_enabled=False,
is_exclusive=False, is_exclusive=False,
use_starttls=False, use_starttls=False,
connect_timeout=10,
).model_dump() ).model_dump()
return LDAPConfigResponse( return LDAPConfigResponse(
@@ -115,6 +135,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter):
is_enabled=config.is_enabled, is_enabled=config.is_enabled,
is_exclusive=config.is_exclusive, is_exclusive=config.is_exclusive,
use_starttls=config.use_starttls, use_starttls=config.use_starttls,
connect_timeout=config.connect_timeout,
).model_dump() ).model_dump()
@@ -151,6 +172,7 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
config.is_enabled = config_update.is_enabled config.is_enabled = config_update.is_enabled
config.is_exclusive = config_update.is_exclusive config.is_exclusive = config_update.is_exclusive
config.use_starttls = config_update.use_starttls config.use_starttls = config_update.use_starttls
config.connect_timeout = config_update.connect_timeout
if config_update.bind_password: if config_update.bind_password:
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password) config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
@@ -162,33 +184,77 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
class AdminTestLDAPConnectionAdapter(AdminApiAdapter): class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override] async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
db = context.db from src.services.auth.ldap import LDAPService
config = db.query(LDAPConfig).first()
if not config: db = context.db
return LDAPTestResponse(success=False, message="LDAP配置不存在").model_dump() if context.json_body is not None:
payload = context.json_body
elif not context.raw_body:
payload = {}
else:
payload = context.ensure_json_body()
saved_config = db.query(LDAPConfig).first()
try: try:
import ldap3 overrides = LDAPConfigTest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
bind_password = crypto_service.decrypt(config.bind_password_encrypted) config_data: Dict[str, Any] = {}
use_ssl = config.server_url.startswith("ldaps://") if saved_config:
server = ldap3.Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL) config_data = {
conn = ldap3.Connection(server, user=config.bind_dn, password=bind_password) "server_url": saved_config.server_url,
"bind_dn": saved_config.bind_dn,
if config.use_starttls and not use_ssl: "base_dn": saved_config.base_dn,
conn.start_tls() "user_search_filter": saved_config.user_search_filter,
"username_attr": saved_config.username_attr,
if not conn.bind(): "email_attr": saved_config.email_attr,
"display_name_attr": saved_config.display_name_attr,
"use_starttls": saved_config.use_starttls,
"connect_timeout": saved_config.connect_timeout,
}
try:
config_data["bind_password"] = crypto_service.decrypt(
saved_config.bind_password_encrypted
)
except Exception as e:
return LDAPTestResponse( return LDAPTestResponse(
success=False, message=f"绑定失败: {conn.result}" success=False, message=f"绑定密码解密失败: {str(e)}"
).model_dump() ).model_dump()
conn.unbind() # 应用前端传入的覆盖值
return LDAPTestResponse(success=True, message="LDAP连接测试成功").model_dump() for field in [
"server_url",
"bind_dn",
"base_dn",
"user_search_filter",
"username_attr",
"email_attr",
"display_name_attr",
"use_starttls",
"is_enabled",
"is_exclusive",
"connect_timeout",
]:
value = getattr(overrides, field)
if value is not None:
config_data[field] = value
except ImportError: if overrides.bind_password:
return LDAPTestResponse(success=False, message="ldap3库未安装").model_dump() config_data["bind_password"] = overrides.bind_password
except Exception as e:
return LDAPTestResponse(success=False, message=f"连接失败: {str(e)}").model_dump() # 必填字段检查
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
missing = [f for f in required_fields if not config_data.get(f)]
if missing:
return LDAPTestResponse(
success=False, message=f"缺少必要字段: {', '.join(missing)}"
).model_dump()
success, message = LDAPService.test_connection_with_config(config_data)
return LDAPTestResponse(success=success, message=message).model_dump()

View File

@@ -349,6 +349,12 @@ class AuthRegisterAdapter(AuthPublicAdapter):
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试", detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
) )
# 仅允许 LDAP 登录时拒绝本地注册
if LDAPService.is_ldap_exclusive(db):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="系统已启用 LDAP 专属登录,禁止本地注册"
)
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first() allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
if allow_registration and not allow_registration.value: if allow_registration and not allow_registration.value:
AuditService.log_event( AuditService.log_event(

View File

@@ -459,6 +459,7 @@ class LDAPConfig(Base):
Boolean, default=False, nullable=False Boolean, default=False, nullable=False
) # 是否仅允许 LDAP 登录(禁用本地认证) ) # 是否仅允许 LDAP 登录(禁用本地认证)
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒)
# 时间戳 # 时间戳
created_at = Column( created_at = Column(

View File

@@ -1,14 +1,14 @@
"""LDAP 认证服务""" """LDAP 认证服务"""
from typing import Optional, Tuple from typing import Dict, Optional, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.core.logger import logger from src.core.logger import logger
from src.models.database import LDAPConfig from src.models.database import LDAPConfig
# LDAP 连接超时时间(秒) # LDAP 连接默认超时时间(秒)
LDAP_CONNECT_TIMEOUT = 10 DEFAULT_LDAP_CONNECT_TIMEOUT = 10
def escape_ldap_filter(value: str) -> str: def escape_ldap_filter(value: str) -> str:
@@ -71,12 +71,40 @@ class LDAPService:
return config.is_exclusive if config and config.is_enabled else False return config.is_exclusive if config and config.is_enabled else False
@staticmethod @staticmethod
def authenticate(db: Session, username: str, password: str) -> Optional[dict]: def get_config_data(db: Session) -> Optional[Dict[str, str]]:
"""
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
"""
config = LDAPService.get_config(db)
if not config or not config.is_enabled:
return None
try:
bind_password = config.get_bind_password()
except Exception as e:
logger.error(f"LDAP 绑定密码解密失败: {e}")
return None
return {
"server_url": config.server_url,
"bind_dn": config.bind_dn,
"bind_password": bind_password,
"base_dn": config.base_dn,
"user_search_filter": config.user_search_filter,
"username_attr": config.username_attr,
"email_attr": config.email_attr,
"display_name_attr": config.display_name_attr,
"use_starttls": config.use_starttls,
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
}
@staticmethod
def authenticate_with_config(config: Dict[str, str], username: str, password: str) -> Optional[dict]:
""" """
LDAP bind 验证 LDAP bind 验证
Args: Args:
db: 数据库会话 config: 已解密的 LDAP 配置
username: 用户名 username: 用户名
password: 密码 password: 密码
@@ -91,8 +119,7 @@ class LDAPService:
logger.error("ldap3 库未安装") logger.error("ldap3 库未安装")
return None return None
config = LDAPService.get_config(db) if not config:
if not config or not config.is_enabled:
logger.warning("LDAP 未配置或未启用") logger.warning("LDAP 未配置或未启用")
return None return None
@@ -101,19 +128,21 @@ class LDAPService:
try: try:
# 创建服务器连接 # 创建服务器连接
use_ssl = config.server_url.startswith("ldaps://") server_url = config["server_url"]
use_ssl = server_url.startswith("ldaps://")
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server( server = Server(
config.server_url, server_url,
use_ssl=use_ssl, use_ssl=use_ssl,
get_info=ldap3.ALL, get_info=ldap3.ALL,
connect_timeout=LDAP_CONNECT_TIMEOUT, connect_timeout=timeout,
) )
# 使用管理员账号连接 # 使用管理员账号连接
bind_password = config.get_bind_password() bind_password = config["bind_password"]
admin_conn = Connection(server, user=config.bind_dn, password=bind_password) admin_conn = Connection(server, user=config["bind_dn"], password=bind_password)
if config.use_starttls and not use_ssl: if config.get("use_starttls") and not use_ssl:
admin_conn.start_tls() admin_conn.start_tls()
if not admin_conn.bind(): if not admin_conn.bind():
@@ -122,12 +151,16 @@ class LDAPService:
# 搜索用户(转义用户名防止 LDAP 注入) # 搜索用户(转义用户名防止 LDAP 注入)
safe_username = escape_ldap_filter(username) safe_username = escape_ldap_filter(username)
search_filter = config.user_search_filter.replace("{username}", safe_username) search_filter = config["user_search_filter"].replace("{username}", safe_username)
admin_conn.search( admin_conn.search(
search_base=config.base_dn, search_base=config["base_dn"],
search_filter=search_filter, search_filter=search_filter,
search_scope=SUBTREE, search_scope=SUBTREE,
attributes=[config.username_attr, config.email_attr, config.display_name_attr], attributes=[
config["username_attr"],
config["email_attr"],
config["display_name_attr"],
],
) )
if not admin_conn.entries: if not admin_conn.entries:
@@ -139,7 +172,7 @@ class LDAPService:
# 用户密码验证 # 用户密码验证
user_conn = Connection(server, user=user_dn, password=password) user_conn = Connection(server, user=user_dn, password=password)
if config.use_starttls and not use_ssl: if config.get("use_starttls") and not use_ssl:
user_conn.start_tls() user_conn.start_tls()
if not user_conn.bind(): if not user_conn.bind():
@@ -147,8 +180,10 @@ class LDAPService:
return None return None
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认) # 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
email = _get_attr_value(user_entry, config.email_attr, f"{username}@ldap.local") email = _get_attr_value(
display_name = _get_attr_value(user_entry, config.display_name_attr, username) user_entry, config["email_attr"], f"{username}@ldap.local"
)
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
logger.info(f"LDAP 认证成功: {username}") logger.info(f"LDAP 认证成功: {username}")
return { return {
@@ -180,7 +215,7 @@ class LDAPService:
pass pass
@staticmethod @staticmethod
def test_connection(db: Session) -> Tuple[bool, str]: def test_connection_with_config(config: Dict[str, str]) -> Tuple[bool, str]:
""" """
测试 LDAP 连接 测试 LDAP 连接
@@ -193,22 +228,23 @@ class LDAPService:
except ImportError: except ImportError:
return False, "ldap3 库未安装" return False, "ldap3 库未安装"
config = LDAPService.get_config(db)
if not config: if not config:
return False, "LDAP 配置不存在" return False, "LDAP 配置不存在"
try: try:
use_ssl = config.server_url.startswith("ldaps://") server_url = config["server_url"]
use_ssl = server_url.startswith("ldaps://")
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server( server = Server(
config.server_url, server_url,
use_ssl=use_ssl, use_ssl=use_ssl,
get_info=ldap3.ALL, get_info=ldap3.ALL,
connect_timeout=LDAP_CONNECT_TIMEOUT, connect_timeout=timeout,
) )
bind_password = config.get_bind_password() bind_password = config["bind_password"]
conn = Connection(server, user=config.bind_dn, password=bind_password) conn = Connection(server, user=config["bind_dn"], password=bind_password)
if config.use_starttls and not use_ssl: if config.get("use_starttls") and not use_ssl:
conn.start_tls() conn.start_tls()
if not conn.bind(): if not conn.bind():
@@ -219,3 +255,16 @@ class LDAPService:
except Exception as e: except Exception as e:
return False, f"连接失败: {str(e)}" return False, f"连接失败: {str(e)}"
# 兼容旧接口:如果其他代码直接调用
@staticmethod
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
config = LDAPService.get_config_data(db)
return LDAPService.authenticate_with_config(config, username, password) if config else None
@staticmethod
def test_connection(db: Session) -> Tuple[bool, str]:
config = LDAPService.get_config_data(db)
if not config:
return False, "LDAP 配置不存在或未启用"
return LDAPService.test_connection_with_config(config)

View File

@@ -112,8 +112,16 @@ class AuthService:
logger.warning("登录失败 - LDAP 认证未启用") logger.warning("登录失败 - LDAP 认证未启用")
return None return None
# 预取配置,避免将 Session 传递到线程池
config_data = LDAPService.get_config_data(db)
if not config_data:
logger.warning("登录失败 - 无法获取 LDAP 配置或解密失败")
return None
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环 # 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
ldap_user = await run_in_threadpool(LDAPService.authenticate, db, email, password) ldap_user = await run_in_threadpool(
LDAPService.authenticate_with_config, config_data, email, password
)
if not ldap_user: if not ldap_user:
return None return None
@@ -128,10 +136,6 @@ class AuthService:
return user return user
# 本地认证 # 本地认证
if LDAPService.is_ldap_exclusive(db):
logger.warning("登录失败 - 仅允许 LDAP 登录")
return None
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象 # 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = db.query(User).filter(User.email == email).first() user = db.query(User).filter(User.email == email).first()
@@ -139,6 +143,13 @@ class AuthService:
logger.warning(f"登录失败 - 用户不存在: {email}") logger.warning(f"登录失败 - 用户不存在: {email}")
return None return None
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
if LDAPService.is_ldap_exclusive(db):
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
return None
logger.info(f"LDAP exclusive 模式下允许本地管理员登录: {email}")
# 检查用户认证来源 # 检查用户认证来源
if user.auth_source == AuthSource.LDAP: if user.auth_source == AuthSource.LDAP:
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}") logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")