mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 12:38:31 +08:00
feat: 添加 LDAP 认证支持
- 新增 LDAP 服务和 API 接口 - 添加 LDAP 配置管理页面 - 登录页面支持 LDAP/本地认证切换 - 数据库迁移支持 LDAP 相关字段
This commit is contained in:
363
src/services/auth/ldap.py
Normal file
363
src/services/auth/ldap.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
# LDAP 连接默认超时时间(秒)
|
||||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||||
|
||||
|
||||
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
|
||||
"""
|
||||
解析 LDAP 服务器地址,支持:
|
||||
- ldap://host:389
|
||||
- ldaps://host:636
|
||||
- host:389(无 scheme 时默认 ldap)
|
||||
|
||||
Returns:
|
||||
(host, port, use_ssl)
|
||||
"""
|
||||
raw = (server_url or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("LDAP server_url is required")
|
||||
|
||||
parsed = urlparse(raw)
|
||||
if parsed.scheme in {"ldap", "ldaps"}:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
use_ssl = parsed.scheme == "ldaps"
|
||||
port = parsed.port or (636 if use_ssl else 389)
|
||||
return host, port, use_ssl
|
||||
|
||||
# 兼容无 scheme:按 ldap:// 解析
|
||||
parsed = urlparse(f"ldap://{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
port = parsed.port or 389
|
||||
return host, port, False
|
||||
|
||||
|
||||
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
|
||||
"""
|
||||
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
|
||||
|
||||
Args:
|
||||
value: 需要转义的字符串
|
||||
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
|
||||
|
||||
Returns:
|
||||
转义后的安全字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 输入值过长
|
||||
"""
|
||||
import unicodedata
|
||||
|
||||
# 先检查原始长度,防止 DoS 攻击
|
||||
# 128 字符足够覆盖大多数企业用户名和邮箱地址
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
|
||||
|
||||
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
|
||||
# 再次检查规范化后的长度(防止规范化后长度突增)
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
|
||||
|
||||
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
|
||||
# 使用显式顺序处理,确保反斜杠首先转义
|
||||
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
|
||||
value = value.replace("*", r"\2a")
|
||||
value = value.replace("(", r"\28")
|
||||
value = value.replace(")", r"\29")
|
||||
value = value.replace("\x00", r"\00") # NUL
|
||||
value = value.replace("&", r"\26")
|
||||
value = value.replace("|", r"\7c")
|
||||
value = value.replace("=", r"\3d")
|
||||
value = value.replace(">", r"\3e")
|
||||
value = value.replace("<", r"\3c")
|
||||
value = value.replace("~", r"\7e")
|
||||
value = value.replace("!", r"\21")
|
||||
return value
|
||||
|
||||
|
||||
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
|
||||
"""
|
||||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||||
"""
|
||||
attr = getattr(entry, attr_name, None)
|
||||
if not attr:
|
||||
return default
|
||||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||||
val = getattr(attr, "value", None)
|
||||
if isinstance(val, list):
|
||||
val = val[0] if val else default
|
||||
if val is None:
|
||||
return default
|
||||
return str(val)
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||||
"""获取 LDAP 配置"""
|
||||
return db.query(LDAPConfig).first()
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_enabled(db: Session) -> bool:
|
||||
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_exclusive(db: Session) -> bool:
|
||||
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_exclusive is not True:
|
||||
return False
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||||
"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_enabled is not True:
|
||||
return None
|
||||
|
||||
try:
|
||||
bind_password = config.get_bind_password()
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||||
return None
|
||||
|
||||
# 绑定密码为空时无法进行 LDAP 认证
|
||||
if not bind_password:
|
||||
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
|
||||
return None
|
||||
|
||||
return {
|
||||
"server_url": config.server_url,
|
||||
"bind_dn": config.bind_dn,
|
||||
"bind_password": bind_password,
|
||||
"base_dn": config.base_dn,
|
||||
"user_search_filter": config.user_search_filter,
|
||||
"username_attr": config.username_attr,
|
||||
"email_attr": config.email_attr,
|
||||
"display_name_attr": config.display_name_attr,
|
||||
"use_starttls": config.use_starttls,
|
||||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
config: 已解密的 LDAP 配置
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户属性 dict {username, email, display_name} 或 None
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection, SUBTREE
|
||||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||||
except ImportError:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
if not config:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
admin_conn = None
|
||||
user_conn = None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config["bind_password"]
|
||||
admin_conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||||
return None
|
||||
|
||||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||||
safe_username = escape_ldap_filter(username)
|
||||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||||
admin_conn.search(
|
||||
search_base=config["base_dn"],
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
size_limit=2, # 防止过滤器误配导致匹配多用户
|
||||
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
|
||||
attributes=[
|
||||
config["username_attr"],
|
||||
config["email_attr"],
|
||||
config["display_name_attr"],
|
||||
],
|
||||
)
|
||||
|
||||
if len(admin_conn.entries) != 1:
|
||||
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
|
||||
logger.warning(
|
||||
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
|
||||
)
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(
|
||||
server,
|
||||
user=user_dn,
|
||||
password=password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
|
||||
bind_result = user_conn.result.get("description", "unknown")
|
||||
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
|
||||
return None
|
||||
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
|
||||
email = _get_attr_value(
|
||||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||||
)
|
||||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
"username": ldap_username,
|
||||
"ldap_username": ldap_username,
|
||||
"ldap_dn": user_dn,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
}
|
||||
|
||||
except LDAPSocketOpenError as e:
|
||||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||||
return None
|
||||
except LDAPBindError as e:
|
||||
logger.error(f"LDAP 绑定失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
finally:
|
||||
# 确保连接关闭,避免失败路径泄漏
|
||||
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
|
||||
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
Returns:
|
||||
(success, message)
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
conn = None
|
||||
try:
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
bind_password = config["bind_password"]
|
||||
conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return False, f"绑定失败: {conn.result}"
|
||||
|
||||
return True, "连接成功"
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
|
||||
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
|
||||
return False, "连接失败,请检查服务器地址、端口和凭据"
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP 测试连接关闭失败: {e}")
|
||||
|
||||
# 兼容旧接口:如果其他代码直接调用
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在或未启用"
|
||||
return LDAPService.test_connection_with_config(config)
|
||||
@@ -2,21 +2,25 @@
|
||||
认证服务
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config import config
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.core.enums import AuthSource
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
@@ -92,15 +96,86 @@ class AuthService:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
async def authenticate_user(
|
||||
db: Session, email: str, password: str, auth_type: str = "local"
|
||||
) -> Optional[User]:
|
||||
"""用户登录认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱/用户名
|
||||
password: 密码
|
||||
auth_type: 认证类型 ("local" 或 "ldap")
|
||||
"""
|
||||
if auth_type == "ldap":
|
||||
# LDAP 认证
|
||||
# 预取配置,避免将 Session 传递到线程池
|
||||
config_data = LDAPService.get_config_data(db)
|
||||
if not config_data:
|
||||
logger.warning("登录失败 - LDAP 未启用或配置无效")
|
||||
return None
|
||||
|
||||
# 计算总体超时:LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
|
||||
# 超时策略:
|
||||
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
|
||||
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
|
||||
# - 公式:单次超时 × 4(覆盖 4 次主要网络操作)+ 10% 缓冲
|
||||
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
|
||||
single_timeout = config_data.get("connect_timeout", 10)
|
||||
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
|
||||
|
||||
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
|
||||
# 添加总体超时保护,防止异常场景下请求堆积
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
ldap_user = await asyncio.wait_for(
|
||||
run_in_threadpool(
|
||||
LDAPService.authenticate_with_config, config_data, email, password
|
||||
),
|
||||
timeout=total_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
|
||||
return None
|
||||
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
# 获取或创建本地用户
|
||||
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
|
||||
if not user:
|
||||
# 已有本地账号但来源不匹配等情况
|
||||
return None
|
||||
if not user.is_active:
|
||||
logger.warning(f"登录失败 - 用户已禁用: {email}")
|
||||
return None
|
||||
return user
|
||||
|
||||
# 本地认证
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
# 支持邮箱或用户名登录
|
||||
from sqlalchemy import or_
|
||||
user = db.query(User).filter(
|
||||
or_(User.email == email, User.username == email)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
return None
|
||||
|
||||
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
|
||||
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
|
||||
return None
|
||||
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
|
||||
|
||||
# 检查用户认证来源
|
||||
if user.auth_source == AuthSource.LDAP:
|
||||
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
|
||||
return None
|
||||
|
||||
if not user.verify_password(password):
|
||||
logger.warning(f"登录失败 - 密码错误: {email}")
|
||||
return None
|
||||
@@ -118,6 +193,127 @@ class AuthService:
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
|
||||
"""获取或创建 LDAP 用户
|
||||
|
||||
Args:
|
||||
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
|
||||
|
||||
注意:使用 with_for_update() 防止并发首次登录创建重复用户
|
||||
"""
|
||||
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
|
||||
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
|
||||
email = ldap_user["email"]
|
||||
|
||||
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
|
||||
# 使用 with_for_update() 锁定行,防止并发创建
|
||||
user: Optional[User] = None
|
||||
if ldap_dn:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user and ldap_username:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
|
||||
user = db.query(User).filter(User.email == email).with_for_update().first()
|
||||
|
||||
if user:
|
||||
if user.auth_source != AuthSource.LDAP:
|
||||
# 避免覆盖已有本地账户(不同来源时拒绝登录)
|
||||
logger.warning(
|
||||
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 同步邮箱(LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
|
||||
if user.email != email:
|
||||
email_taken = (
|
||||
db.query(User)
|
||||
.filter(User.email == email, User.id != user.id)
|
||||
.first()
|
||||
)
|
||||
if email_taken:
|
||||
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
|
||||
return None
|
||||
user.email = email
|
||||
|
||||
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
|
||||
if ldap_dn and user.ldap_dn != ldap_dn:
|
||||
user.ldap_dn = ldap_dn
|
||||
if ldap_username and user.ldap_username != ldap_username:
|
||||
user.ldap_username = ldap_username
|
||||
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
|
||||
base_username = ldap_username or ldap_user["username"]
|
||||
username = base_username
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# 检查用户名是否已存在
|
||||
existing_user_with_username = db.query(User).filter(User.username == username).first()
|
||||
if existing_user_with_username:
|
||||
# 如果 username 已存在,使用时间戳+随机数确保唯一性
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
|
||||
|
||||
# 创建新用户
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash="", # LDAP 用户无本地密码
|
||||
auth_source=AuthSource.LDAP,
|
||||
ldap_dn=ldap_dn,
|
||||
ldap_username=ldap_username,
|
||||
role=UserRole.USER,
|
||||
is_active=True,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_str = str(e.orig).lower() if e.orig else str(e).lower()
|
||||
|
||||
# 解析具体冲突类型
|
||||
if "email" in error_str or "ix_users_email" in error_str:
|
||||
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
|
||||
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
|
||||
return None
|
||||
elif "username" in error_str or "ix_users_username" in error_str:
|
||||
# 用户名冲突,重试时会生成新用户名
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
|
||||
return None
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
|
||||
else:
|
||||
# 其他约束冲突,不重试
|
||||
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
||||
"""API密钥认证"""
|
||||
|
||||
Reference in New Issue
Block a user