mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 02:32:27 +08:00
Fix LDAP authentication stability
This commit is contained in:
@@ -34,6 +34,22 @@ def escape_ldap_filter(value: str) -> str:
|
||||
return value
|
||||
|
||||
|
||||
def _get_attr_value(entry, attr_name: str, default: str = "") -> str:
|
||||
"""
|
||||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||||
"""
|
||||
attr = getattr(entry, attr_name, None)
|
||||
if not attr:
|
||||
return default
|
||||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||||
val = getattr(attr, "value", None)
|
||||
if isinstance(val, list):
|
||||
val = val[0] if val else default
|
||||
if val is None:
|
||||
return default
|
||||
return str(val)
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@@ -80,6 +96,9 @@ class LDAPService:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
admin_conn = None
|
||||
user_conn = None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
@@ -113,12 +132,10 @@ class LDAPService:
|
||||
|
||||
if not admin_conn.entries:
|
||||
logger.warning(f"LDAP 用户未找到: {username}")
|
||||
admin_conn.unbind()
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
admin_conn.unbind()
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(server, user=user_dn, password=password)
|
||||
@@ -129,11 +146,9 @@ class LDAPService:
|
||||
logger.warning(f"LDAP 密码验证失败: {username}")
|
||||
return None
|
||||
|
||||
user_conn.unbind()
|
||||
|
||||
# 提取用户属性
|
||||
email = str(getattr(user_entry, config.email_attr, "")) or f"{username}@ldap.local"
|
||||
display_name = str(getattr(user_entry, config.display_name_attr, "")) or username
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
email = _get_attr_value(user_entry, config.email_attr, f"{username}@ldap.local")
|
||||
display_name = _get_attr_value(user_entry, config.display_name_attr, username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
@@ -151,6 +166,18 @@ class LDAPService:
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
finally:
|
||||
# 确保连接关闭,避免失败路径泄漏
|
||||
if admin_conn:
|
||||
try:
|
||||
admin_conn.unbind()
|
||||
except Exception:
|
||||
pass
|
||||
if user_conn:
|
||||
try:
|
||||
user_conn.unbind()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config import config
|
||||
@@ -111,7 +112,8 @@ class AuthService:
|
||||
logger.warning("登录失败 - LDAP 认证未启用")
|
||||
return None
|
||||
|
||||
ldap_user = LDAPService.authenticate(db, email, password)
|
||||
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
|
||||
ldap_user = await run_in_threadpool(LDAPService.authenticate, db, email, password)
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user