mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-10 11:42:27 +08:00
feat: add ldap login
This commit is contained in:
@@ -5,6 +5,7 @@ from fastapi import APIRouter
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .ldap import router as ldap_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
@@ -28,5 +29,6 @@ router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
router.include_router(ldap_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
190
src/api/admin/ldap.py
Normal file
190
src/api/admin/ldap.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""LDAP配置管理API端点。"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class LDAPConfigResponse(BaseModel):
|
||||
"""LDAP配置响应(不返回密码)"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
bind_dn: Optional[str] = None
|
||||
base_dn: Optional[str] = None
|
||||
user_search_filter: str
|
||||
username_attr: str
|
||||
email_attr: str
|
||||
display_name_attr: str
|
||||
is_enabled: bool
|
||||
is_exclusive: bool
|
||||
use_starttls: bool
|
||||
|
||||
|
||||
class LDAPConfigUpdate(BaseModel):
|
||||
"""LDAP配置更新请求"""
|
||||
|
||||
server_url: str = Field(..., min_length=1, max_length=255)
|
||||
bind_dn: str = Field(..., min_length=1, max_length=255)
|
||||
bind_password: Optional[str] = Field(None, min_length=1)
|
||||
base_dn: str = Field(..., min_length=1, max_length=255)
|
||||
user_search_filter: str = Field(default="(uid={username})", max_length=500)
|
||||
username_attr: str = Field(default="uid", max_length=50)
|
||||
email_attr: str = Field(default="mail", max_length=50)
|
||||
display_name_attr: str = Field(default="cn", max_length=50)
|
||||
is_enabled: bool = False
|
||||
is_exclusive: bool = False
|
||||
use_starttls: bool = False
|
||||
|
||||
|
||||
class LDAPTestResponse(BaseModel):
|
||||
"""LDAP连接测试响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
# ========== API Endpoints ==========
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""获取LDAP配置(管理员)"""
|
||||
adapter = AdminGetLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""更新LDAP配置(管理员)"""
|
||||
adapter = AdminUpdateLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""测试LDAP连接(管理员)"""
|
||||
adapter = AdminTestLDAPConnectionAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
return LDAPConfigResponse(
|
||||
server_url=None,
|
||||
bind_dn=None,
|
||||
base_dn=None,
|
||||
user_search_filter="(uid={username})",
|
||||
username_attr="uid",
|
||||
email_attr="mail",
|
||||
display_name_attr="cn",
|
||||
is_enabled=False,
|
||||
is_exclusive=False,
|
||||
use_starttls=False,
|
||||
).model_dump()
|
||||
|
||||
return LDAPConfigResponse(
|
||||
server_url=config.server_url,
|
||||
bind_dn=config.bind_dn,
|
||||
base_dn=config.base_dn,
|
||||
user_search_filter=config.user_search_filter,
|
||||
username_attr=config.username_attr,
|
||||
email_attr=config.email_attr,
|
||||
display_name_attr=config.display_name_attr,
|
||||
is_enabled=config.is_enabled,
|
||||
is_exclusive=config.is_exclusive,
|
||||
use_starttls=config.use_starttls,
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, str]: # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
config_update = LDAPConfigUpdate.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
config = LDAPConfig()
|
||||
db.add(config)
|
||||
|
||||
config.server_url = config_update.server_url
|
||||
config.bind_dn = config_update.bind_dn
|
||||
config.base_dn = config_update.base_dn
|
||||
config.user_search_filter = config_update.user_search_filter
|
||||
config.username_attr = config_update.username_attr
|
||||
config.email_attr = config_update.email_attr
|
||||
config.display_name_attr = config_update.display_name_attr
|
||||
config.is_enabled = config_update.is_enabled
|
||||
config.is_exclusive = config_update.is_exclusive
|
||||
config.use_starttls = config_update.use_starttls
|
||||
|
||||
if config_update.bind_password:
|
||||
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {"message": "LDAP配置更新成功"}
|
||||
|
||||
|
||||
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
return LDAPTestResponse(success=False, message="LDAP配置不存在").model_dump()
|
||||
|
||||
try:
|
||||
import ldap3
|
||||
|
||||
bind_password = crypto_service.decrypt(config.bind_password_encrypted)
|
||||
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
server = ldap3.Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL)
|
||||
conn = ldap3.Connection(server, user=config.bind_dn, password=bind_password)
|
||||
|
||||
if config.use_starttls and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return LDAPTestResponse(
|
||||
success=False, message=f"绑定失败: {conn.result}"
|
||||
).model_dump()
|
||||
|
||||
conn.unbind()
|
||||
return LDAPTestResponse(success=True, message="LDAP连接测试成功").model_dump()
|
||||
|
||||
except ImportError:
|
||||
return LDAPTestResponse(success=False, message="ldap3库未安装").model_dump()
|
||||
except Exception as e:
|
||||
return LDAPTestResponse(success=False, message=f"连接失败: {str(e)}").model_dump()
|
||||
@@ -33,6 +33,7 @@ from src.models.api import (
|
||||
)
|
||||
from src.models.database import AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.system.config import SystemConfigService
|
||||
@@ -99,6 +100,13 @@ async def registration_settings(request: Request, db: Session = Depends(get_db))
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def auth_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""公开获取认证设置(用于前端判断显示哪些登录选项)"""
|
||||
adapter = AuthSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthLoginAdapter()
|
||||
@@ -193,7 +201,9 @@ class AuthLoginAdapter(AuthPublicAdapter):
|
||||
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_request.email, login_request.password, login_request.auth_type
|
||||
)
|
||||
if not user:
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
@@ -305,6 +315,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter):
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AuthSettingsAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""公开返回认证设置"""
|
||||
db = context.db
|
||||
|
||||
ldap_enabled = LDAPService.is_ldap_enabled(db)
|
||||
ldap_exclusive = LDAPService.is_ldap_exclusive(db)
|
||||
|
||||
return {
|
||||
"local_enabled": not ldap_exclusive,
|
||||
"ldap_enabled": ldap_enabled,
|
||||
"ldap_exclusive": ldap_exclusive,
|
||||
}
|
||||
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.models.database import SystemConfig
|
||||
|
||||
@@ -30,3 +30,10 @@ class ProviderBillingType(Enum):
|
||||
MONTHLY_QUOTA = "monthly_quota" # 月卡额度
|
||||
PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费
|
||||
FREE_TIER = "free_tier" # 免费额度
|
||||
|
||||
|
||||
class AuthSource(str, Enum):
|
||||
"""认证来源枚举"""
|
||||
|
||||
LOCAL = "local" # 本地认证
|
||||
LDAP = "ldap" # LDAP 认证
|
||||
|
||||
@@ -4,7 +4,7 @@ API端点请求/响应模型定义
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
@@ -17,6 +17,7 @@ class LoginRequest(BaseModel):
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
password: str = Field(..., min_length=1, max_length=128, description="密码")
|
||||
auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
|
||||
@@ -30,7 +30,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
from ..core.enums import AuthSource, ProviderBillingType, UserRole
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -54,6 +54,16 @@ class User(Base):
|
||||
default=UserRole.USER,
|
||||
nullable=False,
|
||||
)
|
||||
auth_source = Column(
|
||||
Enum(
|
||||
AuthSource,
|
||||
name="authsource",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AuthSource.LOCAL,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# 访问限制(NULL 表示不限制,允许访问所有资源)
|
||||
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
|
||||
@@ -428,6 +438,67 @@ class SystemConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class LDAPConfig(Base):
|
||||
"""LDAP认证配置表 - 单行配置"""
|
||||
|
||||
__tablename__ = "ldap_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636
|
||||
bind_dn = Column(String(255), nullable=False) # 绑定账号 DN
|
||||
bind_password_encrypted = Column(Text, nullable=False) # 加密的绑定密码
|
||||
base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN
|
||||
user_search_filter = Column(
|
||||
String(500), default="(uid={username})", nullable=False
|
||||
) # 用户搜索过滤器
|
||||
username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName)
|
||||
email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性
|
||||
display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性
|
||||
is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证
|
||||
is_exclusive = Column(
|
||||
Boolean, default=False, nullable=False
|
||||
) # 是否仅允许 LDAP 登录(禁用本地认证)
|
||||
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def set_bind_password(self, password: str) -> None:
|
||||
"""
|
||||
设置并加密绑定密码
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
self.bind_password_encrypted = crypto_service.encrypt(password)
|
||||
|
||||
def get_bind_password(self) -> str:
|
||||
"""
|
||||
获取解密后的绑定密码
|
||||
|
||||
Returns:
|
||||
str: 解密后的明文密码
|
||||
|
||||
Raises:
|
||||
DecryptionException: 解密失败时抛出异常
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
if not self.bind_password_encrypted:
|
||||
return ""
|
||||
return crypto_service.decrypt(self.bind_password_encrypted)
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""提供商配置表"""
|
||||
|
||||
|
||||
157
src/services/auth/ldap.py
Normal file
157
src/services/auth/ldap.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||||
"""获取 LDAP 配置"""
|
||||
return db.query(LDAPConfig).first()
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_enabled(db: Session) -> bool:
|
||||
"""检查 LDAP 是否启用"""
|
||||
config = LDAPService.get_config(db)
|
||||
return config.is_enabled if config else False
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_exclusive(db: Session) -> bool:
|
||||
"""检查是否仅允许 LDAP 登录"""
|
||||
config = LDAPService.get_config(db)
|
||||
return config.is_exclusive if config and config.is_enabled else False
|
||||
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户属性 dict {username, email, display_name} 或 None
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection, SUBTREE
|
||||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||||
except ImportError:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or not config.is_enabled:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
server = Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config.get_bind_password()
|
||||
admin_conn = Connection(server, user=config.bind_dn, password=bind_password)
|
||||
|
||||
if config.use_starttls and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||||
return None
|
||||
|
||||
# 搜索用户
|
||||
search_filter = config.user_search_filter.replace("{username}", username)
|
||||
admin_conn.search(
|
||||
search_base=config.base_dn,
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
attributes=[config.username_attr, config.email_attr, config.display_name_attr],
|
||||
)
|
||||
|
||||
if not admin_conn.entries:
|
||||
logger.warning(f"LDAP 用户未找到: {username}")
|
||||
admin_conn.unbind()
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
admin_conn.unbind()
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(server, user=user_dn, password=password)
|
||||
if config.use_starttls and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
logger.warning(f"LDAP 密码验证失败: {username}")
|
||||
return None
|
||||
|
||||
user_conn.unbind()
|
||||
|
||||
# 提取用户属性
|
||||
email = str(getattr(user_entry, config.email_attr, "")) or f"{username}@ldap.local"
|
||||
display_name = str(getattr(user_entry, config.display_name_attr, "")) or username
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
"username": username,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
}
|
||||
|
||||
except LDAPSocketOpenError as e:
|
||||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||||
return None
|
||||
except LDAPBindError as e:
|
||||
logger.error(f"LDAP 绑定失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
Returns:
|
||||
(success, message)
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
config = LDAPService.get_config(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
try:
|
||||
use_ssl = config.server_url.startswith("ldaps://")
|
||||
server = Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL)
|
||||
bind_password = config.get_bind_password()
|
||||
conn = Connection(server, user=config.bind_dn, password=bind_password)
|
||||
|
||||
if config.use_starttls and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return False, f"绑定失败: {conn.result}"
|
||||
|
||||
conn.unbind()
|
||||
return True, "连接成功"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"连接失败: {str(e)}"
|
||||
@@ -15,8 +15,10 @@ from sqlalchemy.orm import Session, joinedload
|
||||
from src.config import config
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.core.enums import AuthSource
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
@@ -92,8 +94,36 @@ class AuthService:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
async def authenticate_user(
|
||||
db: Session, email: str, password: str, auth_type: str = "local"
|
||||
) -> Optional[User]:
|
||||
"""用户登录认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱/用户名
|
||||
password: 密码
|
||||
auth_type: 认证类型 ("local" 或 "ldap")
|
||||
"""
|
||||
if auth_type == "ldap":
|
||||
# LDAP 认证
|
||||
if not LDAPService.is_ldap_enabled(db):
|
||||
logger.warning("登录失败 - LDAP 认证未启用")
|
||||
return None
|
||||
|
||||
ldap_user = LDAPService.authenticate(db, email, password)
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
# 获取或创建本地用户
|
||||
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
|
||||
return user
|
||||
|
||||
# 本地认证
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
logger.warning("登录失败 - 仅允许 LDAP 登录")
|
||||
return None
|
||||
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
@@ -101,6 +131,11 @@ class AuthService:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
return None
|
||||
|
||||
# 检查用户认证来源
|
||||
if user.auth_source == AuthSource.LDAP:
|
||||
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
|
||||
return None
|
||||
|
||||
if not user.verify_password(password):
|
||||
logger.warning(f"登录失败 - 密码错误: {email}")
|
||||
return None
|
||||
@@ -118,6 +153,42 @@ class AuthService:
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> User:
|
||||
"""获取或创建 LDAP 用户
|
||||
|
||||
Args:
|
||||
ldap_user: LDAP 用户信息 {username, email, display_name}
|
||||
"""
|
||||
# 先按 email 查找
|
||||
user = db.query(User).filter(User.email == ldap_user["email"]).first()
|
||||
|
||||
if user:
|
||||
# 更新 auth_source(如果是首次 LDAP 登录)
|
||||
if user.auth_source != AuthSource.LDAP:
|
||||
user.auth_source = AuthSource.LDAP
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# 创建新用户
|
||||
user = User(
|
||||
email=ldap_user["email"],
|
||||
username=ldap_user["username"],
|
||||
password_hash="", # LDAP 用户无本地密码
|
||||
auth_source=AuthSource.LDAP,
|
||||
role=UserRole.USER,
|
||||
is_active=True,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
||||
"""API密钥认证"""
|
||||
|
||||
Reference in New Issue
Block a user