mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-10 11:42:27 +08:00
feat: 添加 LDAP 认证支持
- 新增 LDAP 服务和 API 接口 - 添加 LDAP 配置管理页面 - 登录页面支持 LDAP/本地认证切换 - 数据库迁移支持 LDAP 相关字段
This commit is contained in:
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.1.1.dev0+g393d4d13f.d20251213'
|
||||
__version_tuple__ = version_tuple = (0, 1, 1, 'dev0', 'g393d4d13f.d20251213')
|
||||
__version__ = version = '0.2.3.dev0+g0f78d5cbf.d20260105'
|
||||
__version_tuple__ = version_tuple = (0, 2, 3, 'dev0', 'g0f78d5cbf.d20260105')
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import APIRouter
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .ldap import router as ldap_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
@@ -28,5 +29,6 @@ router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
router.include_router(ldap_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
427
src/api/admin/ldap.py
Normal file
427
src/api/admin/ldap.py
Normal file
@@ -0,0 +1,427 @@
|
||||
"""LDAP配置管理API端点。"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.enums import AuthSource
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, LDAPConfig, User, UserRole
|
||||
from src.services.system.audit import AuditService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
# bcrypt 哈希格式正则:$2a$, $2b$, $2y$ + 2位cost + $ + 53字符(22位salt + 31位hash)
|
||||
BCRYPT_HASH_PATTERN = re.compile(r"^\$2[aby]\$\d{2}\$.{53}$")
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class LDAPConfigResponse(BaseModel):
|
||||
"""LDAP配置响应(不返回密码)"""
|
||||
|
||||
server_url: Optional[str] = None
|
||||
bind_dn: Optional[str] = None
|
||||
base_dn: Optional[str] = None
|
||||
has_bind_password: bool = False
|
||||
user_search_filter: str
|
||||
username_attr: str
|
||||
email_attr: str
|
||||
display_name_attr: str
|
||||
is_enabled: bool
|
||||
is_exclusive: bool
|
||||
use_starttls: bool
|
||||
connect_timeout: int
|
||||
|
||||
|
||||
class LDAPConfigUpdate(BaseModel):
|
||||
"""LDAP配置更新请求"""
|
||||
|
||||
server_url: str = Field(..., min_length=1, max_length=255)
|
||||
bind_dn: str = Field(..., min_length=1, max_length=255)
|
||||
# 允许空字符串表示"清除密码";非空时自动 strip 并校验不能为空
|
||||
bind_password: Optional[str] = Field(None, max_length=1024)
|
||||
base_dn: str = Field(..., min_length=1, max_length=255)
|
||||
user_search_filter: str = Field(default="(uid={username})", max_length=500)
|
||||
username_attr: str = Field(default="uid", max_length=50)
|
||||
email_attr: str = Field(default="mail", max_length=50)
|
||||
display_name_attr: str = Field(default="cn", max_length=50)
|
||||
is_enabled: bool = False
|
||||
is_exclusive: bool = False
|
||||
use_starttls: bool = False
|
||||
connect_timeout: int = Field(default=10, ge=1, le=60) # 单次操作超时,跨国网络建议 15-30 秒
|
||||
|
||||
@field_validator("bind_password")
|
||||
@classmethod
|
||||
def validate_bind_password(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None or v == "":
|
||||
return v
|
||||
v = v.strip()
|
||||
if not v:
|
||||
raise ValueError("绑定密码不能为空")
|
||||
return v
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: str) -> str:
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度,防止构造复杂查询
|
||||
# 检查嵌套层数而非括号总数
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
class LDAPTestResponse(BaseModel):
|
||||
"""LDAP连接测试响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class LDAPConfigTest(BaseModel):
|
||||
"""LDAP配置测试请求(全部可选,用于临时覆盖)"""
|
||||
|
||||
server_url: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
bind_password: Optional[str] = Field(None, min_length=1)
|
||||
base_dn: Optional[str] = Field(None, min_length=1, max_length=255)
|
||||
user_search_filter: Optional[str] = Field(None, max_length=500)
|
||||
username_attr: Optional[str] = Field(None, max_length=50)
|
||||
email_attr: Optional[str] = Field(None, max_length=50)
|
||||
display_name_attr: Optional[str] = Field(None, max_length=50)
|
||||
is_enabled: Optional[bool] = None
|
||||
is_exclusive: Optional[bool] = None
|
||||
use_starttls: Optional[bool] = None
|
||||
connect_timeout: Optional[int] = Field(None, ge=1, le=60)
|
||||
|
||||
@field_validator("user_search_filter")
|
||||
@classmethod
|
||||
def validate_search_filter(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is None:
|
||||
return v
|
||||
if "{username}" not in v:
|
||||
raise ValueError("搜索过滤器必须包含 {username} 占位符")
|
||||
# 验证括号匹配和嵌套正确性
|
||||
depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
if depth != 0:
|
||||
raise ValueError("搜索过滤器括号不匹配")
|
||||
# 限制过滤器复杂度(检查嵌套层数而非括号总数)
|
||||
depth = 0
|
||||
max_depth = 0
|
||||
for char in v:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
max_depth = max(max_depth, depth)
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
if max_depth > 5:
|
||||
raise ValueError("搜索过滤器嵌套层数过深(最多5层)")
|
||||
if len(v) > 200:
|
||||
raise ValueError("搜索过滤器过长(最多200字符)")
|
||||
return v
|
||||
|
||||
|
||||
# ========== API Endpoints ==========
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""获取LDAP配置(管理员)"""
|
||||
adapter = AdminGetLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/config")
|
||||
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""更新LDAP配置(管理员)"""
|
||||
adapter = AdminUpdateLDAPConfigAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/test")
|
||||
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
|
||||
"""测试LDAP连接(管理员)"""
|
||||
adapter = AdminTestLDAPConnectionAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
config = db.query(LDAPConfig).first()
|
||||
|
||||
if not config:
|
||||
return LDAPConfigResponse(
|
||||
server_url=None,
|
||||
bind_dn=None,
|
||||
base_dn=None,
|
||||
has_bind_password=False,
|
||||
user_search_filter="(uid={username})",
|
||||
username_attr="uid",
|
||||
email_attr="mail",
|
||||
display_name_attr="cn",
|
||||
is_enabled=False,
|
||||
is_exclusive=False,
|
||||
use_starttls=False,
|
||||
connect_timeout=10,
|
||||
).model_dump()
|
||||
|
||||
return LDAPConfigResponse(
|
||||
server_url=config.server_url,
|
||||
bind_dn=config.bind_dn,
|
||||
base_dn=config.base_dn,
|
||||
has_bind_password=bool(config.bind_password_encrypted),
|
||||
user_search_filter=config.user_search_filter,
|
||||
username_attr=config.username_attr,
|
||||
email_attr=config.email_attr,
|
||||
display_name_attr=config.display_name_attr,
|
||||
is_enabled=config.is_enabled,
|
||||
is_exclusive=config.is_exclusive,
|
||||
use_starttls=config.use_starttls,
|
||||
connect_timeout=config.connect_timeout,
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, str]: # type: ignore[override]
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
try:
|
||||
config_update = LDAPConfigUpdate.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
# 使用行级锁防止并发修改导致的竞态条件
|
||||
config = db.query(LDAPConfig).with_for_update().first()
|
||||
is_new_config = config is None
|
||||
|
||||
if is_new_config:
|
||||
# 首次创建配置时必须提供密码
|
||||
if not config_update.bind_password:
|
||||
raise InvalidRequestException("首次配置 LDAP 时必须设置绑定密码")
|
||||
config = LDAPConfig()
|
||||
db.add(config)
|
||||
|
||||
# 需要启用 LDAP 且未提交新密码时,验证已保存密码可解密(避免开启后不可用)
|
||||
if config_update.is_enabled and config_update.bind_password is None:
|
||||
try:
|
||||
if not config.get_bind_password():
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
except InvalidRequestException:
|
||||
raise
|
||||
except Exception:
|
||||
raise InvalidRequestException("绑定密码解密失败,请重新设置绑定密码")
|
||||
|
||||
# 计算更新后的密码状态(用于校验是否可启用/独占)
|
||||
if config_update.bind_password is None:
|
||||
will_have_password = bool(config.bind_password_encrypted)
|
||||
elif config_update.bind_password == "":
|
||||
will_have_password = False
|
||||
else:
|
||||
will_have_password = True
|
||||
|
||||
# 独占模式必须启用 LDAP 且必须有绑定密码(防止误锁定)
|
||||
if config_update.is_exclusive and not config_update.is_enabled:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先启用 LDAP 认证")
|
||||
if config_update.is_enabled and not will_have_password:
|
||||
raise InvalidRequestException("启用 LDAP 认证 需要先设置绑定密码")
|
||||
if config_update.is_exclusive and not will_have_password:
|
||||
raise InvalidRequestException("仅允许 LDAP 登录 需要先设置绑定密码")
|
||||
|
||||
config.server_url = config_update.server_url
|
||||
config.bind_dn = config_update.bind_dn
|
||||
config.base_dn = config_update.base_dn
|
||||
config.user_search_filter = config_update.user_search_filter
|
||||
config.username_attr = config_update.username_attr
|
||||
config.email_attr = config_update.email_attr
|
||||
config.display_name_attr = config_update.display_name_attr
|
||||
config.is_enabled = config_update.is_enabled
|
||||
config.is_exclusive = config_update.is_exclusive
|
||||
config.use_starttls = config_update.use_starttls
|
||||
config.connect_timeout = config_update.connect_timeout
|
||||
|
||||
# 启用独占模式前检查是否有足够的本地管理员(防止锁定)
|
||||
# 使用 with_for_update() 阻塞锁防止竞态条件(移除 nowait 确保并发安全)
|
||||
if config_update.is_enabled and config_update.is_exclusive:
|
||||
local_admins = (
|
||||
db.query(User)
|
||||
.filter(
|
||||
User.role == UserRole.ADMIN,
|
||||
User.auth_source == AuthSource.LOCAL,
|
||||
User.is_active.is_(True),
|
||||
User.is_deleted.is_(False),
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
# 验证至少有一个管理员有有效的密码哈希(可以登录)
|
||||
# 使用严格的 bcrypt 格式校验:$2a$/$2b$/$2y$ + 2位cost + $ + 53字符
|
||||
valid_admin_count = sum(
|
||||
1
|
||||
for admin in local_admins
|
||||
if admin.password_hash
|
||||
and isinstance(admin.password_hash, str)
|
||||
and BCRYPT_HASH_PATTERN.match(admin.password_hash)
|
||||
)
|
||||
if valid_admin_count < 1:
|
||||
raise InvalidRequestException(
|
||||
"启用 LDAP 独占模式前,必须至少保留 1 个有效的本地管理员账户(含有效密码)作为紧急恢复通道"
|
||||
)
|
||||
|
||||
if config_update.bind_password is not None:
|
||||
if config_update.bind_password == "":
|
||||
# 显式清除密码(设置为 NULL)
|
||||
config.bind_password_encrypted = None
|
||||
password_changed = "cleared"
|
||||
else:
|
||||
config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password)
|
||||
password_changed = "updated"
|
||||
else:
|
||||
password_changed = None
|
||||
|
||||
db.commit()
|
||||
|
||||
# 记录审计日志
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=AuditEventType.CONFIG_CHANGED,
|
||||
description=f"LDAP 配置已更新 (enabled={config_update.is_enabled}, exclusive={config_update.is_exclusive})",
|
||||
user_id=str(context.user.id) if context.user else None,
|
||||
metadata={
|
||||
"server_url": config_update.server_url,
|
||||
"is_enabled": config_update.is_enabled,
|
||||
"is_exclusive": config_update.is_exclusive,
|
||||
"password_changed": password_changed,
|
||||
"is_new_config": is_new_config,
|
||||
},
|
||||
)
|
||||
|
||||
return {"message": "LDAP配置更新成功"}
|
||||
|
||||
|
||||
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
|
||||
async def handle(self, context) -> Dict[str, Any]: # type: ignore[override]
|
||||
from src.services.auth.ldap import LDAPService
|
||||
|
||||
db = context.db
|
||||
if context.json_body is not None:
|
||||
payload = context.json_body
|
||||
elif not context.raw_body:
|
||||
payload = {}
|
||||
else:
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
saved_config = db.query(LDAPConfig).first()
|
||||
|
||||
try:
|
||||
overrides = LDAPConfigTest.model_validate(payload)
|
||||
except ValidationError as e:
|
||||
errors = e.errors()
|
||||
if errors:
|
||||
raise InvalidRequestException(translate_pydantic_error(errors[0]))
|
||||
raise InvalidRequestException("请求数据验证失败")
|
||||
|
||||
config_data: Dict[str, Any] = {}
|
||||
|
||||
if saved_config:
|
||||
config_data = {
|
||||
"server_url": saved_config.server_url,
|
||||
"bind_dn": saved_config.bind_dn,
|
||||
"base_dn": saved_config.base_dn,
|
||||
"user_search_filter": saved_config.user_search_filter,
|
||||
"username_attr": saved_config.username_attr,
|
||||
"email_attr": saved_config.email_attr,
|
||||
"display_name_attr": saved_config.display_name_attr,
|
||||
"use_starttls": saved_config.use_starttls,
|
||||
"connect_timeout": saved_config.connect_timeout,
|
||||
}
|
||||
|
||||
# 应用前端传入的覆盖值
|
||||
for field in [
|
||||
"server_url",
|
||||
"bind_dn",
|
||||
"base_dn",
|
||||
"user_search_filter",
|
||||
"username_attr",
|
||||
"email_attr",
|
||||
"display_name_attr",
|
||||
"use_starttls",
|
||||
"is_enabled",
|
||||
"is_exclusive",
|
||||
"connect_timeout",
|
||||
]:
|
||||
value = getattr(overrides, field)
|
||||
if value is not None:
|
||||
config_data[field] = value
|
||||
|
||||
# bind_password 优先使用 overrides;否则使用已保存的密码(允许保存密码无法解密时依然用 overrides 测试)
|
||||
if overrides.bind_password is not None:
|
||||
config_data["bind_password"] = overrides.bind_password
|
||||
elif saved_config and saved_config.bind_password_encrypted:
|
||||
try:
|
||||
config_data["bind_password"] = crypto_service.decrypt(
|
||||
saved_config.bind_password_encrypted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"绑定密码解密失败: {type(e).__name__}: {e}")
|
||||
return LDAPTestResponse(
|
||||
success=False, message="绑定密码解密失败,请检查配置或重新设置密码"
|
||||
).model_dump()
|
||||
|
||||
# 必填字段检查
|
||||
required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
|
||||
missing = [f for f in required_fields if not config_data.get(f)]
|
||||
if missing:
|
||||
return LDAPTestResponse(
|
||||
success=False, message=f"缺少必要字段: {', '.join(missing)}"
|
||||
).model_dump()
|
||||
|
||||
success, message = LDAPService.test_connection_with_config(config_data)
|
||||
return LDAPTestResponse(success=success, message=message).model_dump()
|
||||
@@ -33,6 +33,7 @@ from src.models.api import (
|
||||
)
|
||||
from src.models.database import AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.system.config import SystemConfigService
|
||||
@@ -99,6 +100,13 @@ async def registration_settings(request: Request, db: Session = Depends(get_db))
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/settings")
|
||||
async def auth_settings(request: Request, db: Session = Depends(get_db)):
|
||||
"""公开获取认证设置(用于前端判断显示哪些登录选项)"""
|
||||
adapter = AuthSettingsAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/login", response_model=LoginResponse)
|
||||
async def login(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AuthLoginAdapter()
|
||||
@@ -193,7 +201,9 @@ class AuthLoginAdapter(AuthPublicAdapter):
|
||||
detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
user = await AuthService.authenticate_user(db, login_request.email, login_request.password)
|
||||
user = await AuthService.authenticate_user(
|
||||
db, login_request.email, login_request.password, login_request.auth_type
|
||||
)
|
||||
if not user:
|
||||
AuditService.log_login_attempt(
|
||||
db=db,
|
||||
@@ -305,6 +315,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter):
|
||||
).model_dump()
|
||||
|
||||
|
||||
class AuthSettingsAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
"""公开返回认证设置"""
|
||||
db = context.db
|
||||
|
||||
ldap_enabled = LDAPService.is_ldap_enabled(db)
|
||||
ldap_exclusive = LDAPService.is_ldap_exclusive(db)
|
||||
|
||||
return {
|
||||
"local_enabled": not ldap_exclusive,
|
||||
"ldap_enabled": ldap_enabled,
|
||||
"ldap_exclusive": ldap_exclusive,
|
||||
}
|
||||
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from src.models.database import SystemConfig
|
||||
@@ -324,6 +349,12 @@ class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
detail=f"注册请求过于频繁,请在 {reset_after} 秒后重试",
|
||||
)
|
||||
|
||||
# 仅允许 LDAP 登录时拒绝本地注册
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="系统已启用 LDAP 专属登录,禁止本地注册"
|
||||
)
|
||||
|
||||
allow_registration = db.query(SystemConfig).filter_by(key="enable_registration").first()
|
||||
if allow_registration and not allow_registration.value:
|
||||
AuditService.log_event(
|
||||
|
||||
@@ -30,3 +30,10 @@ class ProviderBillingType(Enum):
|
||||
MONTHLY_QUOTA = "monthly_quota" # 月卡额度
|
||||
PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费
|
||||
FREE_TIER = "free_tier" # 免费额度
|
||||
|
||||
|
||||
class AuthSource(str, Enum):
|
||||
"""认证来源枚举"""
|
||||
|
||||
LOCAL = "local" # 本地认证
|
||||
LDAP = "ldap" # LDAP 认证
|
||||
|
||||
@@ -4,9 +4,9 @@ API端点请求/响应模型定义
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
@@ -15,17 +15,9 @@ from ..core.enums import UserRole
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
email: str = Field(..., min_length=3, max_length=255, description="邮箱地址")
|
||||
email: str = Field(..., min_length=1, max_length=255, description="邮箱/用户名")
|
||||
password: str = Field(..., min_length=1, max_length=128, description="密码")
|
||||
|
||||
@classmethod
|
||||
@field_validator("email")
|
||||
def validate_email(cls, v):
|
||||
"""验证邮箱格式"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if not re.match(email_pattern, v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型")
|
||||
|
||||
@classmethod
|
||||
@field_validator("password")
|
||||
@@ -36,6 +28,24 @@ class LoginRequest(BaseModel):
|
||||
raise ValueError("密码不能为空")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_login(self):
|
||||
"""根据认证类型校验并规范化登录标识"""
|
||||
identifier = self.email.strip()
|
||||
|
||||
if not identifier:
|
||||
raise ValueError("用户名/邮箱不能为空")
|
||||
|
||||
# 本地和 LDAP 登录都支持用户名或邮箱
|
||||
# 如果是邮箱格式,转换为小写
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
if re.match(email_pattern, identifier):
|
||||
self.email = identifier.lower()
|
||||
else:
|
||||
self.email = identifier
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
"""登录响应"""
|
||||
|
||||
@@ -30,7 +30,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
from ..core.enums import AuthSource, ProviderBillingType, UserRole
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -54,6 +54,20 @@ class User(Base):
|
||||
default=UserRole.USER,
|
||||
nullable=False,
|
||||
)
|
||||
auth_source = Column(
|
||||
Enum(
|
||||
AuthSource,
|
||||
name="authsource",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
default=AuthSource.LOCAL,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# LDAP 标识(仅 auth_source=ldap 时使用,用于在邮箱变更/用户名冲突时稳定关联本地账户)
|
||||
ldap_dn = Column(String(512), nullable=True, index=True)
|
||||
ldap_username = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# 访问限制(NULL 表示不限制,允许访问所有资源)
|
||||
allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表
|
||||
@@ -428,6 +442,68 @@ class SystemConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class LDAPConfig(Base):
|
||||
"""LDAP认证配置表 - 单行配置"""
|
||||
|
||||
__tablename__ = "ldap_configs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636
|
||||
bind_dn = Column(String(255), nullable=False) # 绑定账号 DN
|
||||
bind_password_encrypted = Column(Text, nullable=True) # 加密的绑定密码(允许 NULL 表示已清除)
|
||||
base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN
|
||||
user_search_filter = Column(
|
||||
String(500), default="(uid={username})", nullable=False
|
||||
) # 用户搜索过滤器
|
||||
username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName)
|
||||
email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性
|
||||
display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性
|
||||
is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证
|
||||
is_exclusive = Column(
|
||||
Boolean, default=False, nullable=False
|
||||
) # 是否仅允许 LDAP 登录(禁用本地认证)
|
||||
use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS
|
||||
connect_timeout = Column(Integer, default=10, nullable=False) # 连接超时时间(秒)
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def set_bind_password(self, password: str) -> None:
|
||||
"""
|
||||
设置并加密绑定密码
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
self.bind_password_encrypted = crypto_service.encrypt(password)
|
||||
|
||||
def get_bind_password(self) -> str:
|
||||
"""
|
||||
获取解密后的绑定密码
|
||||
|
||||
Returns:
|
||||
str: 解密后的明文密码
|
||||
|
||||
Raises:
|
||||
DecryptionException: 解密失败时抛出异常
|
||||
"""
|
||||
from src.core.crypto import crypto_service
|
||||
|
||||
if not self.bind_password_encrypted:
|
||||
return ""
|
||||
return crypto_service.decrypt(self.bind_password_encrypted)
|
||||
|
||||
|
||||
class Provider(Base):
|
||||
"""提供商配置表"""
|
||||
|
||||
|
||||
363
src/services/auth/ldap.py
Normal file
363
src/services/auth/ldap.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import LDAPConfig
|
||||
|
||||
# LDAP 连接默认超时时间(秒)
|
||||
DEFAULT_LDAP_CONNECT_TIMEOUT = 10
|
||||
|
||||
|
||||
def parse_ldap_server_url(server_url: str) -> tuple[str, int, bool]:
|
||||
"""
|
||||
解析 LDAP 服务器地址,支持:
|
||||
- ldap://host:389
|
||||
- ldaps://host:636
|
||||
- host:389(无 scheme 时默认 ldap)
|
||||
|
||||
Returns:
|
||||
(host, port, use_ssl)
|
||||
"""
|
||||
raw = (server_url or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("LDAP server_url is required")
|
||||
|
||||
parsed = urlparse(raw)
|
||||
if parsed.scheme in {"ldap", "ldaps"}:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
use_ssl = parsed.scheme == "ldaps"
|
||||
port = parsed.port or (636 if use_ssl else 389)
|
||||
return host, port, use_ssl
|
||||
|
||||
# 兼容无 scheme:按 ldap:// 解析
|
||||
parsed = urlparse(f"ldap://{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise ValueError("Invalid LDAP server_url")
|
||||
port = parsed.port or 389
|
||||
return host, port, False
|
||||
|
||||
|
||||
def escape_ldap_filter(value: str, max_length: int = 128) -> str:
|
||||
"""
|
||||
转义 LDAP 过滤器中的特殊字符,防止 LDAP 注入攻击(RFC 4515)
|
||||
|
||||
Args:
|
||||
value: 需要转义的字符串
|
||||
max_length: 最大允许长度,默认 128 字符(覆盖大多数企业邮箱用户名)
|
||||
|
||||
Returns:
|
||||
转义后的安全字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 输入值过长
|
||||
"""
|
||||
import unicodedata
|
||||
|
||||
# 先检查原始长度,防止 DoS 攻击
|
||||
# 128 字符足够覆盖大多数企业用户名和邮箱地址
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long (max {max_length} characters)")
|
||||
|
||||
# Unicode 规范化(使用 NFC 而非 NFKC,避免兼容性字符转换导致安全问题)
|
||||
value = unicodedata.normalize("NFC", value)
|
||||
|
||||
# 再次检查规范化后的长度(防止规范化后长度突增)
|
||||
if len(value) > max_length:
|
||||
raise ValueError(f"LDAP filter value too long after normalization (max {max_length})")
|
||||
|
||||
# LDAP 过滤器特殊字符(RFC 4515 + 扩展)
|
||||
# 使用显式顺序处理,确保反斜杠首先转义
|
||||
value = value.replace("\\", r"\5c") # 反斜杠必须首先转义
|
||||
value = value.replace("*", r"\2a")
|
||||
value = value.replace("(", r"\28")
|
||||
value = value.replace(")", r"\29")
|
||||
value = value.replace("\x00", r"\00") # NUL
|
||||
value = value.replace("&", r"\26")
|
||||
value = value.replace("|", r"\7c")
|
||||
value = value.replace("=", r"\3d")
|
||||
value = value.replace(">", r"\3e")
|
||||
value = value.replace("<", r"\3c")
|
||||
value = value.replace("~", r"\7e")
|
||||
value = value.replace("!", r"\21")
|
||||
return value
|
||||
|
||||
|
||||
def _get_attr_value(entry: Any, attr_name: str, default: str = "") -> str:
|
||||
"""
|
||||
提取 LDAP 条目属性的首个值,避免返回字符串化的列表表示。
|
||||
"""
|
||||
attr = getattr(entry, attr_name, None)
|
||||
if not attr:
|
||||
return default
|
||||
# ldap3 的 EntryAttribute.value 已经是单值或列表,根据类型取首个
|
||||
val = getattr(attr, "value", None)
|
||||
if isinstance(val, list):
|
||||
val = val[0] if val else default
|
||||
if val is None:
|
||||
return default
|
||||
return str(val)
|
||||
|
||||
|
||||
class LDAPService:
|
||||
"""LDAP 认证服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_config(db: Session) -> Optional[LDAPConfig]:
|
||||
"""获取 LDAP 配置"""
|
||||
return db.query(LDAPConfig).first()
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_enabled(db: Session) -> bool:
|
||||
"""检查 LDAP 是否可用(已启用且绑定密码可解密)"""
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_ldap_exclusive(db: Session) -> bool:
|
||||
"""检查是否仅允许 LDAP 登录(仅在 LDAP 可用时生效,避免误锁定)"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_exclusive is not True:
|
||||
return False
|
||||
return LDAPService.get_config_data(db) is not None
|
||||
|
||||
@staticmethod
|
||||
def get_config_data(db: Session) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
提前获取并解密配置,供线程池使用,避免跨线程共享 Session。
|
||||
"""
|
||||
config = LDAPService.get_config(db)
|
||||
if not config or config.is_enabled is not True:
|
||||
return None
|
||||
|
||||
try:
|
||||
bind_password = config.get_bind_password()
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 绑定密码解密失败: {e}")
|
||||
return None
|
||||
|
||||
# 绑定密码为空时无法进行 LDAP 认证
|
||||
if not bind_password:
|
||||
logger.warning("LDAP 绑定密码未配置,无法进行 LDAP 认证")
|
||||
return None
|
||||
|
||||
return {
|
||||
"server_url": config.server_url,
|
||||
"bind_dn": config.bind_dn,
|
||||
"bind_password": bind_password,
|
||||
"base_dn": config.base_dn,
|
||||
"user_search_filter": config.user_search_filter,
|
||||
"username_attr": config.username_attr,
|
||||
"email_attr": config.email_attr,
|
||||
"display_name_attr": config.display_name_attr,
|
||||
"use_starttls": config.use_starttls,
|
||||
"connect_timeout": config.connect_timeout or DEFAULT_LDAP_CONNECT_TIMEOUT,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def authenticate_with_config(config: Dict[str, Any], username: str, password: str) -> Optional[dict]:
|
||||
"""
|
||||
LDAP bind 验证
|
||||
|
||||
Args:
|
||||
config: 已解密的 LDAP 配置
|
||||
username: 用户名
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
用户属性 dict {username, email, display_name} 或 None
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection, SUBTREE
|
||||
from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError
|
||||
except ImportError:
|
||||
logger.error("ldap3 库未安装")
|
||||
return None
|
||||
|
||||
if not config:
|
||||
logger.warning("LDAP 未配置或未启用")
|
||||
return None
|
||||
|
||||
admin_conn = None
|
||||
user_conn = None
|
||||
|
||||
try:
|
||||
# 创建服务器连接
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
|
||||
# 使用管理员账号连接
|
||||
bind_password = config["bind_password"]
|
||||
admin_conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时,避免服务器响应缓慢时阻塞
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
admin_conn.start_tls()
|
||||
|
||||
if not admin_conn.bind():
|
||||
logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}")
|
||||
return None
|
||||
|
||||
# 搜索用户(转义用户名防止 LDAP 注入)
|
||||
safe_username = escape_ldap_filter(username)
|
||||
search_filter = config["user_search_filter"].replace("{username}", safe_username)
|
||||
admin_conn.search(
|
||||
search_base=config["base_dn"],
|
||||
search_filter=search_filter,
|
||||
search_scope=SUBTREE,
|
||||
size_limit=2, # 防止过滤器误配导致匹配多用户
|
||||
time_limit=timeout, # 添加搜索超时,防止大型目录搜索阻塞
|
||||
attributes=[
|
||||
config["username_attr"],
|
||||
config["email_attr"],
|
||||
config["display_name_attr"],
|
||||
],
|
||||
)
|
||||
|
||||
if len(admin_conn.entries) != 1:
|
||||
# 统一错误信息,避免泄露用户是否存在;日志仅记录结果数量,不泄露敏感信息
|
||||
logger.warning(
|
||||
f"LDAP 认证失败(用户查找阶段): 搜索返回 {len(admin_conn.entries)} 条结果"
|
||||
)
|
||||
return None
|
||||
|
||||
user_entry = admin_conn.entries[0]
|
||||
user_dn = user_entry.entry_dn
|
||||
|
||||
# 用户密码验证
|
||||
user_conn = Connection(
|
||||
server,
|
||||
user=user_dn,
|
||||
password=password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
user_conn.start_tls()
|
||||
|
||||
if not user_conn.bind():
|
||||
# 统一错误信息,避免泄露密码是否正确;日志仅记录错误码,不泄露用户 DN
|
||||
bind_result = user_conn.result.get("description", "unknown")
|
||||
logger.warning(f"LDAP 认证失败(密码验证阶段): {bind_result}")
|
||||
return None
|
||||
|
||||
# 提取用户属性(优先用 LDAP 提供的值,不合法则回退默认)
|
||||
ldap_username = _get_attr_value(user_entry, config["username_attr"], username)
|
||||
email = _get_attr_value(
|
||||
user_entry, config["email_attr"], f"{username}@ldap.local"
|
||||
)
|
||||
display_name = _get_attr_value(user_entry, config["display_name_attr"], username)
|
||||
|
||||
logger.info(f"LDAP 认证成功: {username}")
|
||||
return {
|
||||
"username": ldap_username,
|
||||
"ldap_username": ldap_username,
|
||||
"ldap_dn": user_dn,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
}
|
||||
|
||||
except LDAPSocketOpenError as e:
|
||||
logger.error(f"LDAP 服务器连接失败: {e}")
|
||||
return None
|
||||
except LDAPBindError as e:
|
||||
logger.error(f"LDAP 绑定失败: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"LDAP 认证异常: {e}")
|
||||
return None
|
||||
finally:
|
||||
# 确保连接关闭,避免失败路径泄漏
|
||||
# 使用循环确保即使第一个 unbind 失败,后续连接仍会尝试关闭
|
||||
for conn, name in [(admin_conn, "admin"), (user_conn, "user")]:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP {name} 连接关闭失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def test_connection_with_config(config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""
|
||||
测试 LDAP 连接
|
||||
|
||||
Returns:
|
||||
(success, message)
|
||||
"""
|
||||
try:
|
||||
import ldap3
|
||||
from ldap3 import Server, Connection
|
||||
except ImportError:
|
||||
return False, "ldap3 库未安装"
|
||||
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在"
|
||||
|
||||
conn = None
|
||||
try:
|
||||
server_url = config["server_url"]
|
||||
server_host, server_port, use_ssl = parse_ldap_server_url(server_url)
|
||||
timeout = config.get("connect_timeout", DEFAULT_LDAP_CONNECT_TIMEOUT)
|
||||
server = Server(
|
||||
server_host,
|
||||
port=server_port,
|
||||
use_ssl=use_ssl,
|
||||
get_info=ldap3.ALL,
|
||||
connect_timeout=timeout,
|
||||
)
|
||||
bind_password = config["bind_password"]
|
||||
conn = Connection(
|
||||
server,
|
||||
user=config["bind_dn"],
|
||||
password=bind_password,
|
||||
receive_timeout=timeout, # 添加读取超时
|
||||
)
|
||||
|
||||
if config.get("use_starttls") and not use_ssl:
|
||||
conn.start_tls()
|
||||
|
||||
if not conn.bind():
|
||||
return False, f"绑定失败: {conn.result}"
|
||||
|
||||
return True, "连接成功"
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细错误到日志,但只返回通用信息给前端,避免泄露敏感信息
|
||||
logger.error(f"LDAP 测试连接失败: {type(e).__name__}: {e}")
|
||||
return False, "连接失败,请检查服务器地址、端口和凭据"
|
||||
finally:
|
||||
if conn:
|
||||
try:
|
||||
conn.unbind()
|
||||
except Exception as e:
|
||||
logger.warning(f"LDAP 测试连接关闭失败: {e}")
|
||||
|
||||
# 兼容旧接口:如果其他代码直接调用
|
||||
@staticmethod
|
||||
def authenticate(db: Session, username: str, password: str) -> Optional[dict]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
return LDAPService.authenticate_with_config(config, username, password) if config else None
|
||||
|
||||
@staticmethod
|
||||
def test_connection(db: Session) -> Tuple[bool, str]:
|
||||
config = LDAPService.get_config_data(db)
|
||||
if not config:
|
||||
return False, "LDAP 配置不存在或未启用"
|
||||
return LDAPService.test_connection_with_config(config)
|
||||
@@ -2,21 +2,25 @@
|
||||
认证服务
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config import config
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.core.enums import AuthSource
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
||||
from src.services.auth.ldap import LDAPService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
@@ -92,15 +96,86 @@ class AuthService:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
||||
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
async def authenticate_user(
|
||||
db: Session, email: str, password: str, auth_type: str = "local"
|
||||
) -> Optional[User]:
|
||||
"""用户登录认证
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 邮箱/用户名
|
||||
password: 密码
|
||||
auth_type: 认证类型 ("local" 或 "ldap")
|
||||
"""
|
||||
if auth_type == "ldap":
|
||||
# LDAP 认证
|
||||
# 预取配置,避免将 Session 传递到线程池
|
||||
config_data = LDAPService.get_config_data(db)
|
||||
if not config_data:
|
||||
logger.warning("登录失败 - LDAP 未启用或配置无效")
|
||||
return None
|
||||
|
||||
# 计算总体超时:LDAP 认证包含多次网络操作(连接、管理员绑定、搜索、用户绑定)
|
||||
# 超时策略:
|
||||
# - 单次操作超时(connect_timeout):控制每次网络操作的最大等待时间
|
||||
# - 总体超时:防止异常场景(如服务器响应缓慢但未超时)导致请求堆积
|
||||
# - 公式:单次超时 × 4(覆盖 4 次主要网络操作)+ 10% 缓冲
|
||||
# - 最小 20 秒(保证基本操作),最大 60 秒(避免用户等待过长)
|
||||
single_timeout = config_data.get("connect_timeout", 10)
|
||||
total_timeout = max(20, min(int(single_timeout * 4 * 1.1), 60))
|
||||
|
||||
# 在线程池中执行阻塞的 LDAP 网络请求,避免阻塞事件循环
|
||||
# 添加总体超时保护,防止异常场景下请求堆积
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
ldap_user = await asyncio.wait_for(
|
||||
run_in_threadpool(
|
||||
LDAPService.authenticate_with_config, config_data, email, password
|
||||
),
|
||||
timeout=total_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"LDAP 认证总体超时({total_timeout}秒): {email}")
|
||||
return None
|
||||
|
||||
if not ldap_user:
|
||||
return None
|
||||
|
||||
# 获取或创建本地用户
|
||||
user = await AuthService._get_or_create_ldap_user(db, ldap_user)
|
||||
if not user:
|
||||
# 已有本地账号但来源不匹配等情况
|
||||
return None
|
||||
if not user.is_active:
|
||||
logger.warning(f"登录失败 - 用户已禁用: {email}")
|
||||
return None
|
||||
return user
|
||||
|
||||
# 本地认证
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
# 支持邮箱或用户名登录
|
||||
from sqlalchemy import or_
|
||||
user = db.query(User).filter(
|
||||
or_(User.email == email, User.username == email)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
return None
|
||||
|
||||
# 检查 LDAP exclusive 模式:仅允许本地管理员登录(紧急恢复通道)
|
||||
if LDAPService.is_ldap_exclusive(db):
|
||||
if user.role != UserRole.ADMIN or user.auth_source != AuthSource.LOCAL:
|
||||
logger.warning(f"登录失败 - 仅允许 LDAP 登录(管理员除外): {email}")
|
||||
return None
|
||||
logger.warning(f"[LDAP-EXCLUSIVE] 紧急恢复通道:本地管理员登录: {email}")
|
||||
|
||||
# 检查用户认证来源
|
||||
if user.auth_source == AuthSource.LDAP:
|
||||
logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}")
|
||||
return None
|
||||
|
||||
if not user.verify_password(password):
|
||||
logger.warning(f"登录失败 - 密码错误: {email}")
|
||||
return None
|
||||
@@ -118,6 +193,127 @@ class AuthService:
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> Optional[User]:
|
||||
"""获取或创建 LDAP 用户
|
||||
|
||||
Args:
|
||||
ldap_user: LDAP 用户信息 {username, email, display_name, ldap_dn, ldap_username}
|
||||
|
||||
注意:使用 with_for_update() 防止并发首次登录创建重复用户
|
||||
"""
|
||||
ldap_dn = (ldap_user.get("ldap_dn") or "").strip() or None
|
||||
ldap_username = (ldap_user.get("ldap_username") or ldap_user.get("username") or "").strip() or None
|
||||
email = ldap_user["email"]
|
||||
|
||||
# 优先用稳定标识查找,避免邮箱变更/用户名冲突导致重复建号
|
||||
# 使用 with_for_update() 锁定行,防止并发创建
|
||||
user: Optional[User] = None
|
||||
if ldap_dn:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_dn == ldap_dn)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user and ldap_username:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(User.auth_source == AuthSource.LDAP, User.ldap_username == ldap_username)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
# 最后回退按 email 查找:如果存在同邮箱的本地账号,需要拒绝以避免接管
|
||||
user = db.query(User).filter(User.email == email).with_for_update().first()
|
||||
|
||||
if user:
|
||||
if user.auth_source != AuthSource.LDAP:
|
||||
# 避免覆盖已有本地账户(不同来源时拒绝登录)
|
||||
logger.warning(
|
||||
f"LDAP 登录拒绝 - 账户来源不匹配(现有:{user.auth_source}, 请求:LDAP): {email}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 同步邮箱(LDAP 侧邮箱变更时更新;若新邮箱已被占用则拒绝)
|
||||
if user.email != email:
|
||||
email_taken = (
|
||||
db.query(User)
|
||||
.filter(User.email == email, User.id != user.id)
|
||||
.first()
|
||||
)
|
||||
if email_taken:
|
||||
logger.warning(f"LDAP 登录拒绝 - 新邮箱已被占用: {email}")
|
||||
return None
|
||||
user.email = email
|
||||
|
||||
# 同步 LDAP 标识(首次填充或 LDAP 侧发生变化)
|
||||
if ldap_dn and user.ldap_dn != ldap_dn:
|
||||
user.ldap_dn = ldap_dn
|
||||
if ldap_username and user.ldap_username != ldap_username:
|
||||
user.ldap_username = ldap_username
|
||||
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
# 检查 username 是否已被占用,使用时间戳+随机数确保唯一性
|
||||
base_username = ldap_username or ldap_user["username"]
|
||||
username = base_username
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# 检查用户名是否已存在
|
||||
existing_user_with_username = db.query(User).filter(User.username == username).first()
|
||||
if existing_user_with_username:
|
||||
# 如果 username 已存在,使用时间戳+随机数确保唯一性
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.info(f"LDAP 用户名冲突,使用新用户名: {ldap_user['username']} -> {username}")
|
||||
|
||||
# 创建新用户
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
password_hash="", # LDAP 用户无本地密码
|
||||
auth_source=AuthSource.LDAP,
|
||||
ldap_dn=ldap_dn,
|
||||
ldap_username=ldap_username,
|
||||
role=UserRole.USER,
|
||||
is_active=True,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
try:
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})")
|
||||
return user
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
error_str = str(e.orig).lower() if e.orig else str(e).lower()
|
||||
|
||||
# 解析具体冲突类型
|
||||
if "email" in error_str or "ix_users_email" in error_str:
|
||||
# 邮箱冲突不应重试(前面已检查过,说明是并发创建)
|
||||
logger.error(f"LDAP 用户创建失败 - 邮箱并发冲突: {email}")
|
||||
return None
|
||||
elif "username" in error_str or "ix_users_username" in error_str:
|
||||
# 用户名冲突,重试时会生成新用户名
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"LDAP 用户创建失败(用户名冲突重试耗尽): {username}")
|
||||
return None
|
||||
username = f"{base_username}_ldap_{int(time.time())}{uuid.uuid4().hex[:4]}"
|
||||
logger.warning(f"LDAP 用户创建用户名冲突,重试 ({attempt + 1}/{max_retries}): {username}")
|
||||
else:
|
||||
# 其他约束冲突,不重试
|
||||
logger.error(f"LDAP 用户创建失败 - 未知数据库约束冲突: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
||||
"""API密钥认证"""
|
||||
|
||||
Reference in New Issue
Block a user