# Mirror of https://github.com/fawney19/Aether.git
# Synced 2026-01-09 03:02:26 +08:00
# 261 lines, 9.4 KiB, Python
"""LDAP configuration management API endpoints."""
|
||
|
||
from typing import Any, Dict, Optional
|
||
|
||
from fastapi import APIRouter, Depends, Request
|
||
from pydantic import BaseModel, Field, ValidationError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.api.base.admin_adapter import AdminApiAdapter
|
||
from src.api.base.pipeline import ApiRequestPipeline
|
||
from src.core.crypto import crypto_service
|
||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||
from src.database import get_db
|
||
from src.models.database import LDAPConfig
|
||
|
||
# Router for the LDAP admin endpoints declared in this module; its routes are
# registered under the /api/admin/ldap prefix.
router = APIRouter(
    prefix="/api/admin/ldap",
    tags=["Admin - LDAP"],
)

# Request pipeline shared by the endpoints below: each one builds an adapter
# and hands it to ``pipeline.run``.
pipeline = ApiRequestPipeline()
|
||
|
||
|
||
# ========== Request/Response Models ==========
|
||
|
||
|
||
class LDAPConfigResponse(BaseModel):
    """LDAP configuration response (the bind password is never included)."""

    # Connection settings; None when nothing has been configured yet.
    server_url: Optional[str] = Field(default=None)
    bind_dn: Optional[str] = Field(default=None)
    base_dn: Optional[str] = Field(default=None)

    # User lookup filter and attribute names.
    user_search_filter: str
    username_attr: str
    email_attr: str
    display_name_attr: str

    # Behaviour flags.
    is_enabled: bool
    is_exclusive: bool
    use_starttls: bool

    # Connection timeout value as stored in the configuration.
    connect_timeout: int
|
||
|
||
|
||
class LDAPConfigUpdate(BaseModel):
    """LDAP configuration update request."""

    # Required connection settings (non-empty, at most 255 characters each).
    server_url: str = Field(default=..., min_length=1, max_length=255)
    bind_dn: str = Field(default=..., min_length=1, max_length=255)
    # Optional at the model level; when present it must be non-empty.
    bind_password: Optional[str] = Field(default=None, min_length=1)
    base_dn: str = Field(default=..., min_length=1, max_length=255)

    # User lookup filter and attribute names, each with a default.
    user_search_filter: str = Field("(uid={username})", max_length=500)
    username_attr: str = Field("uid", max_length=50)
    email_attr: str = Field("mail", max_length=50)
    display_name_attr: str = Field("cn", max_length=50)

    # Behaviour flags, all off by default.
    is_enabled: bool = Field(default=False)
    is_exclusive: bool = Field(default=False)
    use_starttls: bool = Field(default=False)

    # Connection timeout, constrained to the range 1..60 (default 10).
    connect_timeout: int = Field(10, ge=1, le=60)
|
||
|
||
|
||
class LDAPTestResponse(BaseModel):
    """LDAP connection test response."""

    # Whether the connection test succeeded.
    success: bool = Field(default=...)
    # Message describing the outcome of the test.
    message: str = Field(default=...)
|
||
|
||
|
||
class LDAPConfigTest(BaseModel):
    """LDAP configuration test request.

    Every field is optional; supplied values temporarily override the
    configuration used for the test.
    """

    # Connection settings; when given they must be non-empty.
    server_url: Optional[str] = Field(default=None, min_length=1, max_length=255)
    bind_dn: Optional[str] = Field(default=None, min_length=1, max_length=255)
    bind_password: Optional[str] = Field(default=None, min_length=1)
    base_dn: Optional[str] = Field(default=None, min_length=1, max_length=255)

    # User lookup filter and attribute names.
    user_search_filter: Optional[str] = Field(default=None, max_length=500)
    username_attr: Optional[str] = Field(default=None, max_length=50)
    email_attr: Optional[str] = Field(default=None, max_length=50)
    display_name_attr: Optional[str] = Field(default=None, max_length=50)

    # Behaviour flags; None means "not overridden".
    is_enabled: Optional[bool] = Field(default=None)
    is_exclusive: Optional[bool] = Field(default=None)
    use_starttls: Optional[bool] = Field(default=None)

    # Connection timeout override, constrained to the range 1..60.
    connect_timeout: Optional[int] = Field(default=None, ge=1, le=60)
|
||
|
||
|
||
# ========== API Endpoints ==========
|
||
|
||
|
||
@router.get("/config")
async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
    """Get the LDAP configuration (administrator)."""
    handler = AdminGetLDAPConfigAdapter()
    # The adapter supplies its own pipeline mode.
    result = await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
    return result
|
||
|
||
|
||
@router.put("/config")
async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any:
    """Update the LDAP configuration (administrator)."""
    handler = AdminUpdateLDAPConfigAdapter()
    # The adapter supplies its own pipeline mode.
    result = await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
    return result
|
||
|
||
|
||
@router.post("/test")
async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any:
    """Test the LDAP connection (administrator)."""
    handler = AdminTestLDAPConnectionAdapter()
    # The adapter supplies its own pipeline mode.
    result = await pipeline.run(
        adapter=handler,
        http_request=request,
        db=db,
        mode=handler.mode,
    )
    return result
|
||
|
||
|
||
# ========== Adapters ==========
|
||
|
||
|
||
class AdminGetLDAPConfigAdapter(AdminApiAdapter):
    """Adapter that returns the LDAP configuration without the bind password."""

    async def handle(self, context) -> Dict[str, Any]:  # type: ignore[override]
        """Return the first ``LDAPConfig`` row as a plain dict.

        When no row exists, a response filled with default values is
        returned instead. The result is always the ``model_dump()`` of an
        ``LDAPConfigResponse``.
        """
        stored = context.db.query(LDAPConfig).first()

        if stored:
            values: Dict[str, Any] = {
                "server_url": stored.server_url,
                "bind_dn": stored.bind_dn,
                "base_dn": stored.base_dn,
                "user_search_filter": stored.user_search_filter,
                "username_attr": stored.username_attr,
                "email_attr": stored.email_attr,
                "display_name_attr": stored.display_name_attr,
                "is_enabled": stored.is_enabled,
                "is_exclusive": stored.is_exclusive,
                "use_starttls": stored.use_starttls,
                "connect_timeout": stored.connect_timeout,
            }
        else:
            # Nothing configured yet: report the defaults.
            values = {
                "server_url": None,
                "bind_dn": None,
                "base_dn": None,
                "user_search_filter": "(uid={username})",
                "username_attr": "uid",
                "email_attr": "mail",
                "display_name_attr": "cn",
                "is_enabled": False,
                "is_exclusive": False,
                "use_starttls": False,
                "connect_timeout": 10,
            }

        return LDAPConfigResponse(**values).model_dump()
|
||
|
||
|
||
class AdminUpdateLDAPConfigAdapter(AdminApiAdapter):
    """Adapter that creates or updates the single LDAP configuration row."""

    async def handle(self, context) -> Dict[str, str]:  # type: ignore[override]
        """Validate the JSON body and persist it to ``LDAPConfig``.

        Raises:
            InvalidRequestException: if the body fails ``LDAPConfigUpdate``
                validation, or if no configuration exists yet and no bind
                password was supplied.
        """
        session = context.db
        body = context.ensure_json_body()

        try:
            update = LDAPConfigUpdate.model_validate(body)
        except ValidationError as err:
            details = err.errors()
            # Report the first validation error; fall back to a generic message.
            reason = (
                translate_pydantic_error(details[0])
                if details
                else "请求数据验证失败"
            )
            raise InvalidRequestException(reason)

        record = session.query(LDAPConfig).first()

        if record is None:
            # A bind password is mandatory when the configuration is first created.
            if not update.bind_password:
                raise InvalidRequestException("首次配置 LDAP 时必须设置绑定密码")
            record = LDAPConfig()
            session.add(record)

        # Copy every plain field from the validated request onto the row.
        for attr in (
            "server_url",
            "bind_dn",
            "base_dn",
            "user_search_filter",
            "username_attr",
            "email_attr",
            "display_name_attr",
            "is_enabled",
            "is_exclusive",
            "use_starttls",
            "connect_timeout",
        ):
            setattr(record, attr, getattr(update, attr))

        # The password is only replaced when a new one is supplied, and it is
        # stored encrypted.
        if update.bind_password:
            record.bind_password_encrypted = crypto_service.encrypt(update.bind_password)

        session.commit()

        return {"message": "LDAP配置更新成功"}
|
||
|
||
|
||
class AdminTestLDAPConnectionAdapter(AdminApiAdapter):
    """Adapter that tests an LDAP connection.

    The configuration used for the test is the saved ``LDAPConfig`` row (if
    any) overlaid with whatever fields the request body supplies.
    """

    async def handle(self, context) -> Dict[str, Any]:  # type: ignore[override]
        """Run a connection test and return an ``LDAPTestResponse`` dict.

        The request body is optional: an empty body tests the saved
        configuration as-is.

        Returns:
            The ``model_dump()`` of an ``LDAPTestResponse``. Missing required
            fields and password decryption failures are reported as
            ``success=False`` responses rather than raised.

        Raises:
            InvalidRequestException: if the body fails ``LDAPConfigTest``
                validation.
        """
        # Imported inside the handler, as in the original code.
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from src.services.auth.ldap import LDAPService

        if context.json_body is not None:
            payload = context.json_body
        elif not context.raw_body:
            # No body at all: test the saved configuration without overrides.
            payload = {}
        else:
            payload = context.ensure_json_body()

        # Validate before touching the database so a bad request costs no query.
        try:
            overrides = LDAPConfigTest.model_validate(payload)
        except ValidationError as exc:
            errors = exc.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from exc
            raise InvalidRequestException("请求数据验证失败") from exc

        saved_config = context.db.query(LDAPConfig).first()

        config_data: Dict[str, Any] = {}
        if saved_config:
            config_data = {
                "server_url": saved_config.server_url,
                "bind_dn": saved_config.bind_dn,
                "base_dn": saved_config.base_dn,
                "user_search_filter": saved_config.user_search_filter,
                "username_attr": saved_config.username_attr,
                "email_attr": saved_config.email_attr,
                "display_name_attr": saved_config.display_name_attr,
                "use_starttls": saved_config.use_starttls,
                "connect_timeout": saved_config.connect_timeout,
            }

        # Apply the override values sent by the frontend.
        for name in (
            "server_url",
            "bind_dn",
            "base_dn",
            "user_search_filter",
            "username_attr",
            "email_attr",
            "display_name_attr",
            "use_starttls",
            "is_enabled",
            "is_exclusive",
            "connect_timeout",
        ):
            value = getattr(overrides, name)
            if value is not None:
                config_data[name] = value

        if overrides.bind_password:
            # A password supplied with the request wins. The stored one is not
            # decrypted at all, so an undecryptable stored password cannot
            # block testing a replacement.
            config_data["bind_password"] = overrides.bind_password
        elif saved_config and saved_config.bind_password_encrypted:
            try:
                config_data["bind_password"] = crypto_service.decrypt(
                    saved_config.bind_password_encrypted
                )
            except Exception as exc:
                # Any decryption failure is reported to the caller as a failed
                # test instead of surfacing as a server error.
                return LDAPTestResponse(
                    success=False, message=f"绑定密码解密失败: {exc}"
                ).model_dump()

        # Required-field check on the merged configuration.
        required_fields = ["server_url", "bind_dn", "base_dn", "bind_password"]
        missing = [f for f in required_fields if not config_data.get(f)]
        if missing:
            return LDAPTestResponse(
                success=False, message=f"缺少必要字段: {', '.join(missing)}"
            ).model_dump()

        success, message = LDAPService.test_connection_with_config(config_data)
        return LDAPTestResponse(success=success, message=message).model_dump()
|