feat: 添加 LDAP 认证支持

- 新增 LDAP 服务和 API 接口
- 添加 LDAP 配置管理页面
- 登录页面支持 LDAP/本地认证切换
- 数据库迁移支持 LDAP 相关字段
This commit is contained in:
fawney19
2026-01-06 14:38:42 +08:00
21 changed files with 3947 additions and 2037 deletions

363
src/services/auth/ldap.py Normal file
View File

@@ -0,0 +1,363 @@
"""LDAP 认证服务"""
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import LDAPConfig
# LDAP 连接默认超时时间(秒)
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
"""
解析 LDAP 服务器地址,支持:
- ldap://host:389
- ldaps://host:636
- host:389无 scheme 时默认 ldap
Returns:
(host, port, use_ssl)
"""
raw = (server_url or "").strip()
if not raw:
raise ValueError("LDAP server_url is required")
parsed = urlparse(raw)
if parsed.scheme in {"ldap", "ldaps"}:
host = parsed.hostname
if not host:
raise ValueError("Invalid LDAP server_url")
use_ssl = parsed.scheme == "ldaps"
port = parsed.port or (636 if use_ssl else 389)
return host, port, use_ssl
# 兼容无 scheme按 ldap:// 解析
parsed = urlparse(f"ldap://{raw}")
host = parsed.hostname
if not host:
raise ValueError("Invalid LDAP server_url")
port = parsed.port or 389
return host, port, False
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
"""
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击RFC 4515
Args:
value: 需要转义的字符串
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
Returns:
转义后的安全字符串
Raises:
ValueError: 输入值过长
"""
import unicodedata
# 先检查原始长度,防止 DoS 攻击
# 128 字符足够覆盖大多数企业用户名和邮箱地址
if len(value) > max_length:
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
# Unicode 规范化(使用 NFC 而非 NFKC避免兼容性字符转换导致安全问题
value = unicodedata.normalize("NFC", value)
# 再次检查规范化后的长度(防止规范化后长度突增)
if len(value) > max_length:
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
# LDAP 过滤器特殊字符RFC 4515 + 扩展)
# 使用显式顺序处理,确保反斜杠首先转义
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
value = value.replace("*", r"\2a")
value = value.replace("(", r"\28")
value = value.replace(")", r"\29")
value = value.replace("\x00", r"\00") # NUL
value = value.replace("&", r"\26")
value = value.replace("|", r"\7c")
value = value.replace("=", r"\3d")
value = value.replace(">", r"\3e")
value = value.replace("<", r"\3c")
value = value.replace("~", r"\7e")
value = value.replace("!", r"\21")
return value
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
"""
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
"""
attr = getattr(entry, attr_name, None)
if not attr:
return default
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
val = getattr(attr, "value", None)
if isinstance(val, list):
val = val[0] if val else default
if val is None:
return default
return str(val)
class LDAPService:
"""LDAP 认证服务"""
@staticmethod
def get_config(db: Session) -> Optional[LDAPConfig]:
"""获取 LDAP 配置"""
return db.query(LDAPConfig).first()
@staticmethod
def is_ldap_enabled(db: Session) -> bool:
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
return LDAPService.get_config_data(db) is not None
@staticmethod
def is_ldap_exclusive(db: Session) -> bool:
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
config = LDAPService.get_config(db)
if not config or config.is_exclusive is not True:
return False
return LDAPService.get_config_data(db) is not None
@staticmethod
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
"""
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
"""
config = LDAPService.get_config(db)
if not config or config.is_enabled is not True:
return None
try:
bind_password = config.get_bind_password()
except Exception as e:
logger.error(f"LDAP 绑定密码解密失败: {e}")
return None
# 绑定密码为空时无法进行 LDAP 认证
if not bind_password:
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
return None
return {
"server_url": config.server_url,
"bind_dn": config.bind_dn,
"bind_password": bind_password,
"base_dn": config.base_dn,
"user_search_filter": config.user_search_filter,
"username_attr": config.username_attr,
"email_attr": config.email_attr,
"display_name_attr": config.display_name_attr,
"use_starttls": config.use_starttls,
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
}
@staticmethod
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
"""
LDAP bind 验证
Args:
config: 已解密的 LDAP 配置
username: 用户名
password: 密码
Returns:
用户属性 dict {username, email, display_name} 或 None
"""
try:
import ldap3
from ldap3 import Server, Connection, SUBTREE
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
except ImportError:
logger.error("ldap3 库未安装")
return None
if not config:
logger.warning("LDAP 未配置或未启用")
return None
admin_conn = None
user_conn = None
try:
# 创建服务器连接
server_url = config["server_url"]
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server(
server_host,
port=server_port,
use_ssl=use_ssl,
get_info=ldap3.ALL,
connect_timeout=timeout,
)
# 使用管理员账号连接
bind_password = config["bind_password"]
admin_conn = Connection(
server,
user=config["bind_dn"],
password=bind_password,
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
)
if config.get("use_starttls") and not use_ssl:
admin_conn.start_tls()
if not admin_conn.bind():
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
return None
# 搜索用户(转义用户名防止 LDAP 注入)
safe_username = escape_ldap_filter(username)
search_filter = config["user_search_filter"].replace("{username}", safe_username)
admin_conn.search(
search_base=config["base_dn"],
search_filter=search_filter,
search_scope=SUBTREE,
size_limit=2, # 防止过滤器误配导致匹配多用户
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
attributes=[
config["username_attr"],
config["email_attr"],
config["display_name_attr"],
],
)
if len(admin_conn.entries) != 1:
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
logger.warning(
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
)
return None
user_entry = admin_conn.entries[0]
user_dn = user_entry.entry_dn
# 用户密码验证
user_conn = Connection(
server,
user=user_dn,
password=password,
receive_timeout=timeout, # 添加读取超时
)
if config.get("use_starttls") and not use_ssl:
user_conn.start_tls()
if not user_conn.bind():
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
bind_result = user_conn.result.get("description", "unknown")
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
return None
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
email = _get_attr_value(
user_entry, config["email_attr"], f"{username}@ldap.local"
)
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
logger.info(f"LDAP 认证成功: {username}")
return {
"username": ldap_username,
"ldap_username": ldap_username,
"ldap_dn": user_dn,
"email": email,
"display_name": display_name,
}
except LDAPSocketOpenError as e:
logger.error(f"LDAP 服务器连接失败: {e}")
return None
except LDAPBindError as e:
logger.error(f"LDAP 绑定失败: {e}")
return None
except Exception as e:
logger.error(f"LDAP 认证异常: {e}")
return None
finally:
# 确保连接关闭,避免失败路径泄漏
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
if conn:
try:
conn.unbind()
except Exception as e:
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
@staticmethod
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
"""
测试 LDAP 连接
Returns:
(success, message)
"""
try:
import ldap3
from ldap3 import Server, Connection
except ImportError:
return False, "ldap3 库未安装"
if not config:
return False, "LDAP 配置不存在"
conn = None
try:
server_url = config["server_url"]
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
server = Server(
server_host,
port=server_port,
use_ssl=use_ssl,
get_info=ldap3.ALL,
connect_timeout=timeout,
)
bind_password = config["bind_password"]
conn = Connection(
server,
user=config["bind_dn"],
password=bind_password,
receive_timeout=timeout, # 添加读取超时
)
if config.get("use_starttls") and not use_ssl:
conn.start_tls()
if not conn.bind():
return False, f"绑定失败: {conn.result}"
return True, "连接成功"
except Exception as e:
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
return False, "连接失败,请检查服务器地址、端口和凭据"
finally:
if conn:
try:
conn.unbind()
except Exception as e:
logger.warning(f"LDAP 测试连接关闭失败: {e}")
# 兼容旧接口:如果其他代码直接调用
@staticmethod
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
config = LDAPService.get_config_data(db)
return LDAPService.authenticate_with_config(config, username, password) if config else None
@staticmethod
def test_connection(db: Session) -> Tuple[bool, str]:
config = LDAPService.get_config_data(db)
if not config:
return False, "LDAP 配置不存在或未启用"
return LDAPService.test_connection_with_config(config)

View File

@@ -2,21 +2,25 @@
认证服务
"""
import os
import hashlib
import secrets
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
import jwt
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from src.config import config
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.core.enums import AuthSource
from src.models.database import ApiKey, User, UserRole
from src.services.auth.jwt_blacklist import JWTBlacklistService
from src.services.auth.ldap import LDAPService
from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService
@@ -92,15 +96,86 @@ class AuthService:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证"""
async def authenticate_user(
db: Session, email: str, password: str, auth_type: str = "local"
) -> Optional[User]:
"""用户登录认证
Args:
db: 数据库会话
email: 邮箱/用户名
password: 密码
auth_type: 认证类型 ("local""ldap")
"""
if auth_type == "ldap":
# LDAP 认证
# 预取配置,避免将 Session 传递到线程池
config_data = LDAPService.get_config_data(db)
if not config_data:
logger.warning("登录失败 - LDAP 未启用或配置无效")
return None
# 计算总体超时LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
# 超时策略:
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
# - 公式:单次超时 × 4覆盖 4 次主要网络操作)+ 10% 缓冲
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
single_timeout = config_data.get("connect_timeout", 10)
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
# 添加总体超时保护,防止异常场景下请求堆积
import asyncio
try:
ldap_user = await asyncio.wait_for(
run_in_threadpool(
LDAPService.authenticate_with_config, config_data, email, password
),
timeout=total_timeout,
)
except asyncio.TimeoutError:
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
return None
if not ldap_user:
return None
# 获取或创建本地用户
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
if not user:
# 已有本地账号但来源不匹配等情况
return None
if not user.is_active:
logger.warning(f"登录失败 - 用户已禁用: {email}")
return None
return user
# 本地认证
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = db.query(User).filter(User.email == email).first()
# 支持邮箱或用户名登录
from sqlalchemy import or_
user = db.query(User).filter(
or_(User.email == email, User.username == email)
).first()
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
return None
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
if LDAPService.is_ldap_exclusive(db):
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
return None
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
# 检查用户认证来源
if user.auth_source == AuthSource.LDAP:
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
return None
if not user.verify_password(password):
logger.warning(f"登录失败 - 密码错误: {email}")
return None
@@ -118,6 +193,127 @@ class AuthService:
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@staticmethod
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
"""获取或创建 LDAP 用户
Args:
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
注意:使用 with_for_update() 防止并发首次登录创建重复用户
"""
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
email = ldap_user["email"]
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
# 使用 with_for_update() 锁定行,防止并发创建
user: Optional[User] = None
if ldap_dn:
user = (
db.query(User)
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
.with_for_update()
.first()
)
if not user and ldap_username:
user = (
db.query(User)
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
.with_for_update()
.first()
)
if not user:
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
user = db.query(User).filter(User.email == email).with_for_update().first()
if user:
if user.auth_source != AuthSource.LDAP:
# 避免覆盖已有本地账户(不同来源时拒绝登录)
logger.warning(
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
)
return None
# 同步邮箱LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
if user.email != email:
email_taken = (
db.query(User)
.filter(User.email == email, User.id != user.id)
.first()
)
if email_taken:
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
return None
user.email = email
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
if ldap_dn and user.ldap_dn != ldap_dn:
user.ldap_dn = ldap_dn
if ldap_username and user.ldap_username != ldap_username:
user.ldap_username = ldap_username
user.last_login_at = datetime.now(timezone.utc)
db.commit()
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
return user
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
base_username = ldap_username or ldap_user["username"]
username = base_username
max_retries = 3
for attempt in range(max_retries):
# 检查用户名是否已存在
existing_user_with_username = db.query(User).filter(User.username == username).first()
if existing_user_with_username:
# 如果 username 已存在,使用时间戳+随机数确保唯一性
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
# 创建新用户
user = User(
email=email,
username=username,
password_hash="", # LDAP 用户无本地密码
auth_source=AuthSource.LDAP,
ldap_dn=ldap_dn,
ldap_username=ldap_username,
role=UserRole.USER,
is_active=True,
last_login_at=datetime.now(timezone.utc),
)
try:
db.add(user)
db.commit()
db.refresh(user)
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
return user
except IntegrityError as e:
db.rollback()
error_str = str(e.orig).lower() if e.orig else str(e).lower()
# 解析具体冲突类型
if "email" in error_str or "ix_users_email" in error_str:
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
return None
elif "username" in error_str or "ix_users_username" in error_str:
# 用户名冲突,重试时会生成新用户名
if attempt == max_retries - 1:
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
return None
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
else:
# 其他约束冲突,不重试
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
return None
return None
@staticmethod
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
"""API密钥认证"""