Enhance LDAP auth config handling

This commit is contained in:
RWDai
2026-01-04 16:27:02 +08:00
parent 414f45aa71
commit 3e4309eba3
9 changed files with 231 additions and 59 deletions

View File

@@ -33,6 +33,7 @@ class LDAPConfigResponse(BaseModel):
is_enabled: bool
is_exclusive: bool
use_starttls: bool
connect_timeout: int
class LDAPConfigUpdate(BaseModel):
@@ -49,6 +50,7 @@ class LDAPConfigUpdate(BaseModel):
is_enabled: bool = False
is_exclusive: bool = False
use_starttls: bool = False
connect_timeout: int = Field(default=10, ge=1, le=60)
class LDAPTestResponse(BaseModel):
@@ -58,6 +60,23 @@ class LDAPTestResponse(BaseModel):
message: str
class LDAPConfigTest(BaseModel):
"""LDAP配置测试请求全部可选用于临时覆盖"""
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
bind_password: Optional[str] = Field(None, min_length=1)
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
user_search_filter: Optional[str] = Field(None, max_length=500)
username_attr: Optional[str] = Field(None, max_length=50)
email_attr: Optional[str] = Field(None, max_length=50)
display_name_attr: Optional[str] = Field(None, max_length=50)
is_enabled: Optional[bool] = None
is_exclusive: Optional[bool] = None
use_starttls: Optional[bool] = None
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
# ========== API Endpoints ==========
@@ -102,6 +121,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter):
is_enabled=False,
is_exclusive=False,
use_starttls=False,
connect_timeout=10,
).model_dump()
return LDAPConfigResponse(
@@ -115,6 +135,7 @@ class AdminGetLDAPConfigAdapter(AdminApiAdapter):
is_enabled=config.is_enabled,
is_exclusive=config.is_exclusive,
use_starttls=config.use_starttls,
connect_timeout=config.connect_timeout,
).model_dump()
@@ -151,6 +172,7 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
config.is_enabled = config_update.is_enabled
config.is_exclusive = config_update.is_exclusive
config.use_starttls = config_update.use_starttls
config.connect_timeout = config_update.connect_timeout
if config_update.bind_password:
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
@@ -162,33 +184,77 @@ class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
db = context.db
config = db.query(LDAPConfig).first()
from src.services.auth.ldap import LDAPService
if not config:
return LDAPTestResponse(success=False, message="LDAP配置不存在").model_dump()
db = context.db
if context.json_body is not None:
payload = context.json_body
elif not context.raw_body:
payload = {}
else:
payload = context.ensure_json_body()
saved_config = db.query(LDAPConfig).first()
try:
import ldap3
overrides = LDAPConfigTest.model_validate(payload)
except ValidationError as e:
errors = e.errors()
if errors:
raise InvalidRequestException(translate_pydantic_error(errors[0]))
raise InvalidRequestException("请求数据验证失败")
bind_password = crypto_service.decrypt(config.bind_password_encrypted)
config_data: Dict[str, Any] = {}
use_ssl = config.server_url.startswith("ldaps://")
server = ldap3.Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL)
conn = ldap3.Connection(server, user=config.bind_dn, password=bind_password)
if config.use_starttls and not use_ssl:
conn.start_tls()
if not conn.bind():
if saved_config:
config_data = {
"server_url": saved_config.server_url,
"bind_dn": saved_config.bind_dn,
"base_dn": saved_config.base_dn,
"user_search_filter": saved_config.user_search_filter,
"username_attr": saved_config.username_attr,
"email_attr": saved_config.email_attr,
"display_name_attr": saved_config.display_name_attr,
"use_starttls": saved_config.use_starttls,
"connect_timeout": saved_config.connect_timeout,
}
try:
config_data["bind_password"] = crypto_service.decrypt(
saved_config.bind_password_encrypted
)
except Exception as e:
return LDAPTestResponse(
success=False, message=f"绑定失败: {conn.result}"
success=False, message=f"绑定密码解密失败: {str(e)}"
).model_dump()
conn.unbind()
return LDAPTestResponse(success=True, message="LDAP连接测试成功").model_dump()
# 应用前端传入的覆盖值
for field in [
"server_url",
"bind_dn",
"base_dn",
"user_search_filter",
"username_attr",
"email_attr",
"display_name_attr",
"use_starttls",
"is_enabled",
"is_exclusive",
"connect_timeout",
]:
value = getattr(overrides, field)
if value is not None:
config_data[field] = value
except ImportError:
return LDAPTestResponse(success=False, message="ldap3库未安装").model_dump()
except Exception as e:
return LDAPTestResponse(success=False, message=f"连接失败: {str(e)}").model_dump()
if overrides.bind_password:
config_data["bind_password"] = overrides.bind_password
# 必填字段检查
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
missing = [f for f in required_fields if not config_data.get(f)]
if missing:
return LDAPTestResponse(
success=False, message=f"缺少必要字段: {', '.join(missing)}"
).model_dump()
success, message = LDAPService.test_connection_with_config(config_data)
return LDAPTestResponse(success=success, message=message).model_dump()