mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 20:18:30 +08:00
364 lines
13 KiB
Python
364 lines
13 KiB
Python
"""LDAP 认证服务"""
|
||
|
||
from typing import Any, Dict, Optional, Tuple
|
||
from urllib.parse import urlparse
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.core.logger import logger
|
||
from src.models.database import LDAPConfig
|
||
|
||
# LDAP 连接默认超时时间(秒)
|
||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||
|
||
|
||
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
|
||
"""
|
||
解析 LDAP 服务器地址,支持:
|
||
- ldap://host:389
|
||
- ldaps://host:636
|
||
- host:389(无 scheme 时默认 ldap)
|
||
|
||
Returns:
|
||
(host, port, use_ssl)
|
||
"""
|
||
raw = (server_url or "").strip()
|
||
if not raw:
|
||
raise ValueError("LDAP server_url is required")
|
||
|
||
parsed = urlparse(raw)
|
||
if parsed.scheme in {"ldap", "ldaps"}:
|
||
host = parsed.hostname
|
||
if not host:
|
||
raise ValueError("Invalid LDAP server_url")
|
||
use_ssl = parsed.scheme == "ldaps"
|
||
port = parsed.port or (636 if use_ssl else 389)
|
||
return host, port, use_ssl
|
||
|
||
# 兼容无 scheme:按 ldap:// 解析
|
||
parsed = urlparse(f"ldap://{raw}")
|
||
host = parsed.hostname
|
||
if not host:
|
||
raise ValueError("Invalid LDAP server_url")
|
||
port = parsed.port or 389
|
||
return host, port, False
|
||
|
||
|
||
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
|
||
"""
|
||
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
|
||
|
||
Args:
|
||
value: 需要转义的字符串
|
||
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
|
||
|
||
Returns:
|
||
转义后的安全字符串
|
||
|
||
Raises:
|
||
ValueError: 输入值过长
|
||
"""
|
||
import unicodedata
|
||
|
||
# 先检查原始长度,防止 DoS 攻击
|
||
# 128 字符足够覆盖大多数企业用户名和邮箱地址
|
||
if len(value) > max_length:
|
||
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
|
||
|
||
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
|
||
value = unicodedata.normalize("NFC", value)
|
||
|
||
# 再次检查规范化后的长度(防止规范化后长度突增)
|
||
if len(value) > max_length:
|
||
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
|
||
|
||
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
|
||
# 使用显式顺序处理,确保反斜杠首先转义
|
||
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
|
||
value = value.replace("*", r"\2a")
|
||
value = value.replace("(", r"\28")
|
||
value = value.replace(")", r"\29")
|
||
value = value.replace("\x00", r"\00") # NUL
|
||
value = value.replace("&", r"\26")
|
||
value = value.replace("|", r"\7c")
|
||
value = value.replace("=", r"\3d")
|
||
value = value.replace(">", r"\3e")
|
||
value = value.replace("<", r"\3c")
|
||
value = value.replace("~", r"\7e")
|
||
value = value.replace("!", r"\21")
|
||
return value
|
||
|
||
|
||
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
|
||
"""
|
||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||
"""
|
||
attr = getattr(entry, attr_name, None)
|
||
if not attr:
|
||
return default
|
||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||
val = getattr(attr, "value", None)
|
||
if isinstance(val, list):
|
||
val = val[0] if val else default
|
||
if val is None:
|
||
return default
|
||
return str(val)
|
||
|
||
|
||
class LDAPService:
|
||
"""LDAP 认证服务"""
|
||
|
||
@staticmethod
|
||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||
"""获取 LDAP 配置"""
|
||
return db.query(LDAPConfig).first()
|
||
|
||
@staticmethod
|
||
def is_ldap_enabled(db: Session) -> bool:
|
||
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
|
||
return LDAPService.get_config_data(db) is not None
|
||
|
||
@staticmethod
|
||
def is_ldap_exclusive(db: Session) -> bool:
|
||
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
|
||
config = LDAPService.get_config(db)
|
||
if not config or config.is_exclusive is not True:
|
||
return False
|
||
return LDAPService.get_config_data(db) is not None
|
||
|
||
@staticmethod
|
||
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||
"""
|
||
config = LDAPService.get_config(db)
|
||
if not config or config.is_enabled is not True:
|
||
return None
|
||
|
||
try:
|
||
bind_password = config.get_bind_password()
|
||
except Exception as e:
|
||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||
return None
|
||
|
||
# 绑定密码为空时无法进行 LDAP 认证
|
||
if not bind_password:
|
||
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
|
||
return None
|
||
|
||
return {
|
||
"server_url": config.server_url,
|
||
"bind_dn": config.bind_dn,
|
||
"bind_password": bind_password,
|
||
"base_dn": config.base_dn,
|
||
"user_search_filter": config.user_search_filter,
|
||
"username_attr": config.username_attr,
|
||
"email_attr": config.email_attr,
|
||
"display_name_attr": config.display_name_attr,
|
||
"use_starttls": config.use_starttls,
|
||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||
}
|
||
|
||
@staticmethod
|
||
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
|
||
"""
|
||
LDAP bind 验证
|
||
|
||
Args:
|
||
config: 已解密的 LDAP 配置
|
||
username: 用户名
|
||
password: 密码
|
||
|
||
Returns:
|
||
用户属性 dict {username, email, display_name} 或 None
|
||
"""
|
||
try:
|
||
import ldap3
|
||
from ldap3 import Server, Connection, SUBTREE
|
||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||
except ImportError:
|
||
logger.error("ldap3 库未安装")
|
||
return None
|
||
|
||
if not config:
|
||
logger.warning("LDAP 未配置或未启用")
|
||
return None
|
||
|
||
admin_conn = None
|
||
user_conn = None
|
||
|
||
try:
|
||
# 创建服务器连接
|
||
server_url = config["server_url"]
|
||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||
server = Server(
|
||
server_host,
|
||
port=server_port,
|
||
use_ssl=use_ssl,
|
||
get_info=ldap3.ALL,
|
||
connect_timeout=timeout,
|
||
)
|
||
|
||
# 使用管理员账号连接
|
||
bind_password = config["bind_password"]
|
||
admin_conn = Connection(
|
||
server,
|
||
user=config["bind_dn"],
|
||
password=bind_password,
|
||
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
|
||
)
|
||
|
||
if config.get("use_starttls") and not use_ssl:
|
||
admin_conn.start_tls()
|
||
|
||
if not admin_conn.bind():
|
||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||
return None
|
||
|
||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||
safe_username = escape_ldap_filter(username)
|
||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||
admin_conn.search(
|
||
search_base=config["base_dn"],
|
||
search_filter=search_filter,
|
||
search_scope=SUBTREE,
|
||
size_limit=2, # 防止过滤器误配导致匹配多用户
|
||
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
|
||
attributes=[
|
||
config["username_attr"],
|
||
config["email_attr"],
|
||
config["display_name_attr"],
|
||
],
|
||
)
|
||
|
||
if len(admin_conn.entries) != 1:
|
||
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
|
||
logger.warning(
|
||
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
|
||
)
|
||
return None
|
||
|
||
user_entry = admin_conn.entries[0]
|
||
user_dn = user_entry.entry_dn
|
||
|
||
# 用户密码验证
|
||
user_conn = Connection(
|
||
server,
|
||
user=user_dn,
|
||
password=password,
|
||
receive_timeout=timeout, # 添加读取超时
|
||
)
|
||
if config.get("use_starttls") and not use_ssl:
|
||
user_conn.start_tls()
|
||
|
||
if not user_conn.bind():
|
||
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
|
||
bind_result = user_conn.result.get("description", "unknown")
|
||
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
|
||
return None
|
||
|
||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
|
||
email = _get_attr_value(
|
||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||
)
|
||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||
|
||
logger.info(f"LDAP 认证成功: {username}")
|
||
return {
|
||
"username": ldap_username,
|
||
"ldap_username": ldap_username,
|
||
"ldap_dn": user_dn,
|
||
"email": email,
|
||
"display_name": display_name,
|
||
}
|
||
|
||
except LDAPSocketOpenError as e:
|
||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||
return None
|
||
except LDAPBindError as e:
|
||
logger.error(f"LDAP 绑定失败: {e}")
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"LDAP 认证异常: {e}")
|
||
return None
|
||
finally:
|
||
# 确保连接关闭,避免失败路径泄漏
|
||
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
|
||
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
|
||
if conn:
|
||
try:
|
||
conn.unbind()
|
||
except Exception as e:
|
||
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
|
||
|
||
@staticmethod
|
||
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
|
||
"""
|
||
测试 LDAP 连接
|
||
|
||
Returns:
|
||
(success, message)
|
||
"""
|
||
try:
|
||
import ldap3
|
||
from ldap3 import Server, Connection
|
||
except ImportError:
|
||
return False, "ldap3 库未安装"
|
||
|
||
if not config:
|
||
return False, "LDAP 配置不存在"
|
||
|
||
conn = None
|
||
try:
|
||
server_url = config["server_url"]
|
||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||
server = Server(
|
||
server_host,
|
||
port=server_port,
|
||
use_ssl=use_ssl,
|
||
get_info=ldap3.ALL,
|
||
connect_timeout=timeout,
|
||
)
|
||
bind_password = config["bind_password"]
|
||
conn = Connection(
|
||
server,
|
||
user=config["bind_dn"],
|
||
password=bind_password,
|
||
receive_timeout=timeout, # 添加读取超时
|
||
)
|
||
|
||
if config.get("use_starttls") and not use_ssl:
|
||
conn.start_tls()
|
||
|
||
if not conn.bind():
|
||
return False, f"绑定失败: {conn.result}"
|
||
|
||
return True, "连接成功"
|
||
|
||
except Exception as e:
|
||
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
|
||
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
|
||
return False, "连接失败,请检查服务器地址、端口和凭据"
|
||
finally:
|
||
if conn:
|
||
try:
|
||
conn.unbind()
|
||
except Exception as e:
|
||
logger.warning(f"LDAP 测试连接关闭失败: {e}")
|
||
|
||
# 兼容旧接口:如果其他代码直接调用
|
||
@staticmethod
|
||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||
config = LDAPService.get_config_data(db)
|
||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||
|
||
@staticmethod
|
||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||
config = LDAPService.get_config_data(db)
|
||
if not config:
|
||
return False, "LDAP 配置不存在或未启用"
|
||
return LDAPService.test_connection_with_config(config)
|