mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 12:08:30 +08:00
Enhance LDAP auth config handling
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
# LDAP 连接超时时间(秒)
|
||||
LDAP_CONNECT_TIMEOUT = 10
|
||||
# LDAP 连接默认超时时间(秒)
|
||||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||||
|
||||
|
||||
def escape_ldap_filter(value: str) -> str:
|
||||
@@ -71,12 +71,40 @@ class LDAPService:
|
||||
return config.is_exclusive if config and config.is_enabled else False
|
||||
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
def get_config_data(db: Session) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||||
"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or not config.is_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
bind_password = config.get_bind_password()
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||||
return None
|
||||
|
||||
return {
|
||||
"server_url": config.server_url,
|
||||
"bind_dn": config.bind_dn,
|
||||
"bind_password": bind_password,
|
||||
"base_dn": config.base_dn,
|
||||
"user_search_filter": config.user_search_filter,
|
||||
"username_attr": config.username_attr,
|
||||
"email_attr": config.email_attr,
|
||||
"display_name_attr": config.display_name_attr,
|
||||
"use_starttls": config.use_starttls,
|
||||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_config(config: Dict[str, str], username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config: 已解密的 LDAP 配置
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
@@ -91,8 +119,7 @@ class LDAPService:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or not config.is_enabled:
|
||||
if not config:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
@@ -101,19 +128,21 @@ class LDAPService:
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
server_url = config["server_url"]
|
||||
use_ssl = server_url.startswith("ldaps://")
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
config.server_url,
|
||||
server_url,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=LDAP_CONNECT_TIMEOUT,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config.get_bind_password()
|
||||
admin_conn = Connection(server, user=config.bind_dn, password=bind_password)
|
||||
bind_password = config["bind_password"]
|
||||
admin_conn = Connection(server, user=config["bind_dn"], password=bind_password)
|
||||
|
||||
if config.use_starttls and not use_ssl:
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
@@ -122,12 +151,16 @@ class LDAPService:
|
||||
|
||||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||||
safe_username = escape_ldap_filter(username)
|
||||
search_filter = config.user_search_filter.replace("{username}", safe_username)
|
||||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||||
admin_conn.search(
|
||||
search_base=config.base_dn,
|
||||
search_base=config["base_dn"],
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
attributes=[config.username_attr, config.email_attr, config.display_name_attr],
|
||||
attributes=[
|
||||
config["username_attr"],
|
||||
config["email_attr"],
|
||||
config["display_name_attr"],
|
||||
],
|
||||
)
|
||||
|
||||
if not admin_conn.entries:
|
||||
@@ -139,7 +172,7 @@ class LDAPService:
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(server, user=user_dn, password=password)
|
||||
if config.use_starttls and not use_ssl:
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
@@ -147,8 +180,10 @@ class LDAPService:
|
||||
return None
|
||||
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
email = _get_attr_value(user_entry, config.email_attr, f"{username}@ldap.local")
|
||||
display_name = _get_attr_value(user_entry, config.display_name_attr, username)
|
||||
email = _get_attr_value(
|
||||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||||
)
|
||||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
@@ -180,7 +215,7 @@ class LDAPService:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
def test_connection_with_config(config: Dict[str, str]) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
@@ -193,22 +228,23 @@ class LDAPService:
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
config = LDAPService.get_config(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
try:
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
server_url = config["server_url"]
|
||||
use_ssl = server_url.startswith("ldaps://")
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
config.server_url,
|
||||
server_url,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=LDAP_CONNECT_TIMEOUT,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
bind_password = config.get_bind_password()
|
||||
conn = Connection(server, user=config.bind_dn, password=bind_password)
|
||||
bind_password = config["bind_password"]
|
||||
conn = Connection(server, user=config["bind_dn"], password=bind_password)
|
||||
|
||||
if config.use_starttls and not use_ssl:
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
@@ -219,3 +255,16 @@ class LDAPService:
|
||||
|
||||
except Exception as e:
|
||||
return False, f"连接失败: {str(e)}"
|
||||
|
||||
# 兼容旧接口:如果其他代码直接调用
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在或未启用"
|
||||
return LDAPService.test_connection_with_config(config)
|
||||
|
||||
Reference in New Issue
Block a user