diff --git a/README.md b/README.md index 05086bf..80d6933 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Aether 是一个自托管的 AI API 网关,为团队和个人提供多租户 ```bash # 1. 克隆代码 git clone https://github.com/fawney19/Aether.git -cd aether +cd Aether # 2. 配置环境变量 cp .env.example .env @@ -72,7 +72,7 @@ docker-compose pull && docker-compose up -d && ./migrate.sh ```bash # 1. 克隆代码 git clone https://github.com/fawney19/Aether.git -cd aether +cd Aether # 2. 配置环境变量 cp .env.example .env diff --git a/alembic/versions/20260101_1400_add_ldap_authentication_support.py b/alembic/versions/20260101_1400_add_ldap_authentication_support.py new file mode 100644 index 0000000..76caeaa --- /dev/null +++ b/alembic/versions/20260101_1400_add_ldap_authentication_support.py @@ -0,0 +1,65 @@ +"""add ldap authentication support + +Revision ID: c3d4e5f6g7h8 +Revises: b2c3d4e5f6g7 +Create Date: 2026-01-01 14:00:00.000000+00:00 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy import text + +# revision identifiers, used by Alembic. +revision = 'c3d4e5f6g7h8' +down_revision = 'b2c3d4e5f6g7' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """添加 LDAP 认证支持 + + 1. 创建 authsource 枚举类型 + 2. 在 users 表添加 auth_source 字段 + 3. 创建 ldap_configs 表 + """ + conn = op.get_bind() + + # 1. 创建 authsource 枚举类型 + conn.execute(text("CREATE TYPE authsource AS ENUM ('local', 'ldap')")) + + # 2. 在 users 表添加 auth_source 字段 + op.add_column('users', sa.Column('auth_source', sa.Enum('local', 'ldap', name='authsource', create_type=False), nullable=False, server_default='local')) + + # 3. 创建 ldap_configs 表 + op.create_table( + 'ldap_configs', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('server_url', sa.String(length=255), nullable=False), + sa.Column('bind_dn', sa.String(length=255), nullable=False), + sa.Column('bind_password_encrypted', sa.Text(), nullable=False), + sa.Column('base_dn', sa.String(length=255), nullable=False), + sa.Column('user_search_filter', sa.String(length=500), nullable=False, server_default='(uid={username})'), + sa.Column('username_attr', sa.String(length=50), nullable=False, server_default='uid'), + sa.Column('email_attr', sa.String(length=50), nullable=False, server_default='mail'), + sa.Column('display_name_attr', sa.String(length=50), nullable=False, server_default='cn'), + sa.Column('is_enabled', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('is_exclusive', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('use_starttls', sa.Boolean(), nullable=False, server_default='false'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.PrimaryKeyConstraint('id') + ) + + +def downgrade() -> None: + """回滚 LDAP 认证支持""" + # 1. 删除 ldap_configs 表 + op.drop_table('ldap_configs') + + # 2. 删除 users 表的 auth_source 字段 + op.drop_column('users', 'auth_source') + + # 3. 删除 authsource 枚举类型 + conn = op.get_bind() + conn.execute(text("DROP TYPE authsource")) diff --git a/frontend/src/api/admin.ts b/frontend/src/api/admin.ts index d19a5fb..fd5c1d1 100644 --- a/frontend/src/api/admin.ts +++ b/frontend/src/api/admin.ts @@ -473,5 +473,30 @@ export const adminApi = { `/api/admin/system/email/templates/${templateType}/reset` ) return response.data + }, + + // LDAP 配置相关 + // 获取 LDAP 配置 + async getLdapConfig(): Promise { + const response = await apiClient.get('/api/admin/ldap/config') + return response.data + }, + + // 更新 LDAP 配置 + async updateLdapConfig(config: any): Promise<{ message: string }> { + const response = await apiClient.put<{ message: string }>( + '/api/admin/ldap/config', + config + ) + return response.data + }, + + // 测试 LDAP 连接 + async testLdapConnection(config?: any): Promise<{ success: boolean; message: string }> { + const response = await apiClient.post<{ success: boolean; message: string }>( + '/api/admin/ldap/test', + config || {} + ) + return response.data } } diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 17355f0..43408bb 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -4,6 +4,7 @@ import { log } from '@/utils/logger' export interface LoginRequest { email: string password: string + auth_type?: 'local' | 'ldap' } export interface LoginResponse { @@ -81,6 +82,12 @@ export interface RegistrationSettingsResponse { require_email_verification: boolean } +export interface AuthSettingsResponse { + local_enabled: boolean + ldap_enabled: boolean + ldap_exclusive: boolean +} + export interface User { id: string // UUID username: string @@ -173,5 +180,10 @@ export const authApi = { { email } ) return response.data + }, + + async getAuthSettings(): Promise { + const response = await apiClient.get('/api/auth/settings') + return response.data } } diff --git a/frontend/src/features/auth/components/LoginDialog.vue b/frontend/src/features/auth/components/LoginDialog.vue index a2f6d83..9a0ba2b 100644 --- a/frontend/src/features/auth/components/LoginDialog.vue +++ b/frontend/src/features/auth/components/LoginDialog.vue @@ -66,19 +66,35 @@ + + + + + 本地登录 + + + LDAP 登录 + + + +
- +
@@ -156,6 +172,9 @@ import { Dialog } from '@/components/ui' import Button from '@/components/ui/button.vue' import Input from '@/components/ui/input.vue' import Label from '@/components/ui/label.vue' +import Tabs from '@/components/ui/tabs.vue' +import TabsList from '@/components/ui/tabs-list.vue' +import TabsTrigger from '@/components/ui/tabs-trigger.vue' import { useAuthStore } from '@/stores/auth' import { useToast } from '@/composables/useToast' import { isDemoMode, DEMO_ACCOUNTS } from '@/config/demo' @@ -180,6 +199,20 @@ const showRegisterDialog = ref(false) const requireEmailVerification = ref(false) const allowRegistration = ref(false) // 由系统配置控制,默认关闭 +// LDAP authentication settings +const authType = ref<'local' | 'ldap'>('local') +const localEnabled = ref(true) +const ldapEnabled = ref(false) +const ldapExclusive = ref(false) + +const showAuthTypeTabs = computed(() => { + return localEnabled.value && ldapEnabled.value && !ldapExclusive.value +}) + +const emailLabel = computed(() => { + return authType.value === 'ldap' ? '用户名/邮箱' : '邮箱' +}) + watch(() => props.modelValue, (val) => { isOpen.value = val // 打开对话框时重置表单 @@ -212,7 +245,7 @@ async function handleLogin() { return } - const success = await authStore.login(form.value.email, form.value.password) + const success = await authStore.login(form.value.email, form.value.password, authType.value) if (success) { showSuccess('登录成功,正在跳转...') @@ -246,16 +279,36 @@ function handleSwitchToLogin() { isOpen.value = true } -// Load registration settings on mount +// Load authentication and registration settings on mount onMounted(async () => { try { - const settings = await authApi.getRegistrationSettings() - allowRegistration.value = !!settings.enable_registration - requireEmailVerification.value = !!settings.require_email_verification + // Load registration settings + const regSettings = await authApi.getRegistrationSettings() + allowRegistration.value = !!regSettings.enable_registration + requireEmailVerification.value = !!regSettings.require_email_verification + + // Load authentication settings + const authSettings = await authApi.getAuthSettings() + localEnabled.value = authSettings.local_enabled + ldapEnabled.value = authSettings.ldap_enabled + ldapExclusive.value = authSettings.ldap_exclusive + + // Set default auth type based on settings + if (authSettings.ldap_exclusive) { + authType.value = 'ldap' + } else if (!authSettings.local_enabled && authSettings.ldap_enabled) { + authType.value = 'ldap' + } else { + authType.value = 'local' + } } catch (error) { - // If获取失败,保持默认:关闭注册 & 关闭邮箱验证 + // If获取失败,保持默认:关闭注册 & 关闭邮箱验证 & 使用本地认证 allowRegistration.value = false requireEmailVerification.value = false + localEnabled.value = true + ldapEnabled.value = false + ldapExclusive.value = false + authType.value = 'local' } }) diff --git a/frontend/src/layouts/MainLayout.vue b/frontend/src/layouts/MainLayout.vue index 6229553..5d82508 100644 --- a/frontend/src/layouts/MainLayout.vue +++ b/frontend/src/layouts/MainLayout.vue @@ -423,6 +423,7 @@ const navigation = computed(() => { { name: 'IP 安全', href: '/admin/ip-security', icon: Shield }, { name: '审计日志', href: '/admin/audit-logs', icon: AlertTriangle }, { name: '邮件配置', href: '/admin/email', icon: Mail }, + { name: 'LDAP 配置', href: '/admin/ldap', icon: Shield }, { name: '系统设置', href: '/admin/system', icon: Cog }, ] } diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 3cf085c..e39c88c 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -111,6 +111,11 @@ const routes: RouteRecordRaw[] = [ name: 'EmailSettings', component: () => importWithRetry(() => import('@/views/admin/EmailSettings.vue')) }, + { + path: 'ldap', + name: 'LdapSettings', + component: () => importWithRetry(() => import('@/views/admin/LdapSettings.vue')) + }, { path: 'audit-logs', name: 'AuditLogs', diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index da0c551..608235d 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -31,12 +31,12 @@ export const useAuthStore = defineStore('auth', () => { } const isAdmin = computed(() => user.value?.role === 'admin') - async function login(email: string, password: string) { + async function login(email: string, password: string, authType: 'local' | 'ldap' = 'local') { loading.value = true error.value = null try { - const response = await authApi.login({ email, password }) + const response = await authApi.login({ email, password, auth_type: authType }) token.value = response.access_token // 获取用户信息 diff --git a/frontend/src/views/admin/LdapSettings.vue b/frontend/src/views/admin/LdapSettings.vue new file mode 100644 index 0000000..5581ff9 --- /dev/null +++ b/frontend/src/views/admin/LdapSettings.vue @@ -0,0 +1,303 @@ + + + diff --git a/pyproject.toml b/pyproject.toml index 3746e20..0977d70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "redis>=5.0.0", "prometheus-client>=0.20.0", "apscheduler>=3.10.0", + "ldap3>=2.9.1", ] [project.optional-dependencies] diff --git a/src/api/admin/__init__.py b/src/api/admin/__init__.py index 835bfda..83233cd 100644 --- a/src/api/admin/__init__.py +++ b/src/api/admin/__init__.py @@ -5,6 +5,7 @@ from fastapi import APIRouter from .adaptive import router as adaptive_router from .api_keys import router as api_keys_router from .endpoints import router as endpoints_router +from .ldap import router as ldap_router from .models import router as models_router from .monitoring import router as monitoring_router from .provider_query import router as provider_query_router @@ -28,5 +29,6 @@ router.include_router(adaptive_router) router.include_router(models_router) router.include_router(security_router) router.include_router(provider_query_router) +router.include_router(ldap_router) __all__ = ["router"] diff --git a/src/api/admin/ldap.py b/src/api/admin/ldap.py new file mode 100644 index 0000000..317d892 --- /dev/null +++ b/src/api/admin/ldap.py @@ -0,0 +1,190 @@ +"""LDAP配置管理API端点。""" + +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Depends, Request +from pydantic import BaseModel, Field, ValidationError +from sqlalchemy.orm import Session + +from src.api.base.admin_adapter import AdminApiAdapter +from src.api.base.pipeline import ApiRequestPipeline +from src.core.crypto import crypto_service +from src.core.exceptions import InvalidRequestException, translate_pydantic_error +from src.database import get_db +from src.models.database import LDAPConfig + +router = APIRouter(prefix="/api/admin/ldap", tags=["Admin - LDAP"]) +pipeline = ApiRequestPipeline() + + +# ========== Request/Response Models ========== + + +class LDAPConfigResponse(BaseModel): + """LDAP配置响应(不返回密码)""" + + server_url: Optional[str] = None + bind_dn: Optional[str] = None + base_dn: Optional[str] = None + user_search_filter: str + username_attr: str + email_attr: str + display_name_attr: str + is_enabled: bool + is_exclusive: bool + use_starttls: bool + + +class LDAPConfigUpdate(BaseModel): + """LDAP配置更新请求""" + + server_url: str = Field(..., min_length=1, max_length=255) + bind_dn: str = Field(..., min_length=1, max_length=255) + bind_password: Optional[str] = Field(None, min_length=1) + base_dn: str = Field(..., min_length=1, max_length=255) + user_search_filter: str = Field(default="(uid={username})", max_length=500) + username_attr: str = Field(default="uid", max_length=50) + email_attr: str = Field(default="mail", max_length=50) + display_name_attr: str = Field(default="cn", max_length=50) + is_enabled: bool = False + is_exclusive: bool = False + use_starttls: bool = False + + +class LDAPTestResponse(BaseModel): + """LDAP连接测试响应""" + + success: bool + message: str + + +# ========== API Endpoints ========== + + +@router.get("/config") +async def get_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any: + """获取LDAP配置(管理员)""" + adapter = AdminGetLDAPConfigAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.put("/config") +async def update_ldap_config(request: Request, db: Session = Depends(get_db)) -> Any: + """更新LDAP配置(管理员)""" + adapter = AdminUpdateLDAPConfigAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +@router.post("/test") +async def test_ldap_connection(request: Request, db: Session = Depends(get_db)) -> Any: + """测试LDAP连接(管理员)""" + adapter = AdminTestLDAPConnectionAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + +# ========== Adapters ========== + + +class AdminGetLDAPConfigAdapter(AdminApiAdapter): + async def handle(self, context) -> Dict[str, Any]: # type: ignore[override] + db = context.db + config = db.query(LDAPConfig).first() + + if not config: + return LDAPConfigResponse( + server_url=None, + bind_dn=None, + base_dn=None, + user_search_filter="(uid={username})", + username_attr="uid", + email_attr="mail", + display_name_attr="cn", + is_enabled=False, + is_exclusive=False, + use_starttls=False, + ).model_dump() + + return LDAPConfigResponse( + server_url=config.server_url, + bind_dn=config.bind_dn, + base_dn=config.base_dn, + user_search_filter=config.user_search_filter, + username_attr=config.username_attr, + email_attr=config.email_attr, + display_name_attr=config.display_name_attr, + is_enabled=config.is_enabled, + is_exclusive=config.is_exclusive, + use_starttls=config.use_starttls, + ).model_dump() + + +class AdminUpdateLDAPConfigAdapter(AdminApiAdapter): + async def handle(self, context) -> Dict[str, str]: # type: ignore[override] + db = context.db + payload = context.ensure_json_body() + + try: + config_update = LDAPConfigUpdate.model_validate(payload) + except ValidationError as e: + errors = e.errors() + if errors: + raise InvalidRequestException(translate_pydantic_error(errors[0])) + raise InvalidRequestException("请求数据验证失败") + + config = db.query(LDAPConfig).first() + + if not config: + config = LDAPConfig() + db.add(config) + + config.server_url = config_update.server_url + config.bind_dn = config_update.bind_dn + config.base_dn = config_update.base_dn + config.user_search_filter = config_update.user_search_filter + config.username_attr = config_update.username_attr + config.email_attr = config_update.email_attr + config.display_name_attr = config_update.display_name_attr + config.is_enabled = config_update.is_enabled + config.is_exclusive = config_update.is_exclusive + config.use_starttls = config_update.use_starttls + + if config_update.bind_password: + config.bind_password_encrypted = crypto_service.encrypt(config_update.bind_password) + + db.commit() + + return {"message": "LDAP配置更新成功"} + + +class AdminTestLDAPConnectionAdapter(AdminApiAdapter): + async def handle(self, context) -> Dict[str, Any]: # type: ignore[override] + db = context.db + config = db.query(LDAPConfig).first() + + if not config: + return LDAPTestResponse(success=False, message="LDAP配置不存在").model_dump() + + try: + import ldap3 + + bind_password = crypto_service.decrypt(config.bind_password_encrypted) + + use_ssl = config.server_url.startswith("ldaps://") + server = ldap3.Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL) + conn = ldap3.Connection(server, user=config.bind_dn, password=bind_password) + + if config.use_starttls and not use_ssl: + conn.start_tls() + + if not conn.bind(): + return LDAPTestResponse( + success=False, message=f"绑定失败: {conn.result}" + ).model_dump() + + conn.unbind() + return LDAPTestResponse(success=True, message="LDAP连接测试成功").model_dump() + + except ImportError: + return LDAPTestResponse(success=False, message="ldap3库未安装").model_dump() + except Exception as e: + return LDAPTestResponse(success=False, message=f"连接失败: {str(e)}").model_dump() diff --git a/src/api/auth/routes.py b/src/api/auth/routes.py index 9a614d5..c8b7ddb 100644 --- a/src/api/auth/routes.py +++ b/src/api/auth/routes.py @@ -33,6 +33,7 @@ from src.models.api import ( ) from src.models.database import AuditEventType, User, UserRole from src.services.auth.service import AuthService +from src.services.auth.ldap import LDAPService from src.services.rate_limit.ip_limiter import IPRateLimiter from src.services.system.audit import AuditService from src.services.system.config import SystemConfigService @@ -99,6 +100,13 @@ async def registration_settings(request: Request, db: Session = Depends(get_db)) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) +@router.get("/settings") +async def auth_settings(request: Request, db: Session = Depends(get_db)): + """公开获取认证设置(用于前端判断显示哪些登录选项)""" + adapter = AuthSettingsAdapter() + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + @router.post("/login", response_model=LoginResponse) async def login(request: Request, db: Session = Depends(get_db)): adapter = AuthLoginAdapter() @@ -193,7 +201,9 @@ class AuthLoginAdapter(AuthPublicAdapter): detail=f"登录请求过于频繁,请在 {reset_after} 秒后重试", ) - user = await AuthService.authenticate_user(db, login_request.email, login_request.password) + user = await AuthService.authenticate_user( + db, login_request.email, login_request.password, login_request.auth_type + ) if not user: AuditService.log_login_attempt( db=db, @@ -305,6 +315,21 @@ class AuthRegistrationSettingsAdapter(AuthPublicAdapter): ).model_dump() +class AuthSettingsAdapter(AuthPublicAdapter): + async def handle(self, context): # type: ignore[override] + """公开返回认证设置""" + db = context.db + + ldap_enabled = LDAPService.is_ldap_enabled(db) + ldap_exclusive = LDAPService.is_ldap_exclusive(db) + + return { + "local_enabled": not ldap_exclusive, + "ldap_enabled": ldap_enabled, + "ldap_exclusive": ldap_exclusive, + } + + class AuthRegisterAdapter(AuthPublicAdapter): async def handle(self, context): # type: ignore[override] from src.models.database import SystemConfig diff --git a/src/core/enums.py b/src/core/enums.py index 169e2ec..cdad32b 100644 --- a/src/core/enums.py +++ b/src/core/enums.py @@ -30,3 +30,10 @@ class ProviderBillingType(Enum): MONTHLY_QUOTA = "monthly_quota" # 月卡额度 PAY_AS_YOU_GO = "pay_as_you_go" # 按量付费 FREE_TIER = "free_tier" # 免费额度 + + +class AuthSource(str, Enum): + """认证来源枚举""" + + LOCAL = "local" # 本地认证 + LDAP = "ldap" # LDAP 认证 diff --git a/src/models/api.py b/src/models/api.py index 202d814..58ed6a5 100644 --- a/src/models/api.py +++ b/src/models/api.py @@ -4,7 +4,7 @@ API端点请求/响应模型定义 import re from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -17,6 +17,7 @@ class LoginRequest(BaseModel): email: str = Field(..., min_length=3, max_length=255, description="邮箱地址") password: str = Field(..., min_length=1, max_length=128, description="密码") + auth_type: Literal["local", "ldap"] = Field(default="local", description="认证类型") @classmethod @field_validator("email") diff --git a/src/models/database.py b/src/models/database.py index 85e4a0b..3cae081 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -30,7 +30,7 @@ from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import declarative_base, relationship from ..config import config -from ..core.enums import ProviderBillingType, UserRole +from ..core.enums import AuthSource, ProviderBillingType, UserRole Base = declarative_base() @@ -54,6 +54,16 @@ class User(Base): default=UserRole.USER, nullable=False, ) + auth_source = Column( + Enum( + AuthSource, + name="authsource", + create_type=False, + values_callable=lambda x: [e.value for e in x], + ), + default=AuthSource.LOCAL, + nullable=False, + ) # 访问限制(NULL 表示不限制,允许访问所有资源) allowed_providers = Column(JSON, nullable=True) # 允许使用的提供商 ID 列表 @@ -428,6 +438,67 @@ class SystemConfig(Base): ) +class LDAPConfig(Base): + """LDAP认证配置表 - 单行配置""" + + __tablename__ = "ldap_configs" + + id = Column(Integer, primary_key=True, autoincrement=True) + server_url = Column(String(255), nullable=False) # ldap://host:389 或 ldaps://host:636 + bind_dn = Column(String(255), nullable=False) # 绑定账号 DN + bind_password_encrypted = Column(Text, nullable=False) # 加密的绑定密码 + base_dn = Column(String(255), nullable=False) # 用户搜索基础 DN + user_search_filter = Column( + String(500), default="(uid={username})", nullable=False + ) # 用户搜索过滤器 + username_attr = Column(String(50), default="uid", nullable=False) # 用户名属性 (uid/sAMAccountName) + email_attr = Column(String(50), default="mail", nullable=False) # 邮箱属性 + display_name_attr = Column(String(50), default="cn", nullable=False) # 显示名称属性 + is_enabled = Column(Boolean, default=False, nullable=False) # 是否启用 LDAP 认证 + is_exclusive = Column( + Boolean, default=False, nullable=False + ) # 是否仅允许 LDAP 登录(禁用本地认证) + use_starttls = Column(Boolean, default=False, nullable=False) # 是否使用 STARTTLS + + # 时间戳 + created_at = Column( + DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False + ) + updated_at = Column( + DateTime(timezone=True), + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + nullable=False, + ) + + def set_bind_password(self, password: str) -> None: + """ + 设置并加密绑定密码 + + Args: + password: 明文密码 + """ + from src.core.crypto import crypto_service + + self.bind_password_encrypted = crypto_service.encrypt(password) + + def get_bind_password(self) -> str: + """ + 获取解密后的绑定密码 + + Returns: + str: 解密后的明文密码 + + Raises: + DecryptionException: 解密失败时抛出异常 + """ + from src.core.crypto import crypto_service + + if not self.bind_password_encrypted: + return "" + return crypto_service.decrypt(self.bind_password_encrypted) + + class Provider(Base): """提供商配置表""" diff --git a/src/services/auth/ldap.py b/src/services/auth/ldap.py new file mode 100644 index 0000000..d5345b4 --- /dev/null +++ b/src/services/auth/ldap.py @@ -0,0 +1,157 @@ +"""LDAP 认证服务""" + +from typing import Optional, Tuple + +from sqlalchemy.orm import Session + +from src.core.logger import logger +from src.models.database import LDAPConfig + + +class LDAPService: + """LDAP 认证服务""" + + @staticmethod + def get_config(db: Session) -> Optional[LDAPConfig]: + """获取 LDAP 配置""" + return db.query(LDAPConfig).first() + + @staticmethod + def is_ldap_enabled(db: Session) -> bool: + """检查 LDAP 是否启用""" + config = LDAPService.get_config(db) + return config.is_enabled if config else False + + @staticmethod + def is_ldap_exclusive(db: Session) -> bool: + """检查是否仅允许 LDAP 登录""" + config = LDAPService.get_config(db) + return config.is_exclusive if config and config.is_enabled else False + + @staticmethod + def authenticate(db: Session, username: str, password: str) -> Optional[dict]: + """ + LDAP bind 验证 + + Args: + db: 数据库会话 + username: 用户名 + password: 密码 + + Returns: + 用户属性 dict {username, email, display_name} 或 None + """ + try: + import ldap3 + from ldap3 import Server, Connection, SUBTREE + from ldap3.core.exceptions import LDAPBindError, LDAPSocketOpenError + except ImportError: + logger.error("ldap3 库未安装") + return None + + config = LDAPService.get_config(db) + if not config or not config.is_enabled: + logger.warning("LDAP 未配置或未启用") + return None + + try: + # 创建服务器连接 + use_ssl = config.server_url.startswith("ldaps://") + server = Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL) + + # 使用管理员账号连接 + bind_password = config.get_bind_password() + admin_conn = Connection(server, user=config.bind_dn, password=bind_password) + + if config.use_starttls and not use_ssl: + admin_conn.start_tls() + + if not admin_conn.bind(): + logger.error(f"LDAP 管理员绑定失败: {admin_conn.result}") + return None + + # 搜索用户 + search_filter = config.user_search_filter.replace("{username}", username) + admin_conn.search( + search_base=config.base_dn, + search_filter=search_filter, + search_scope=SUBTREE, + attributes=[config.username_attr, config.email_attr, config.display_name_attr], + ) + + if not admin_conn.entries: + logger.warning(f"LDAP 用户未找到: {username}") + admin_conn.unbind() + return None + + user_entry = admin_conn.entries[0] + user_dn = user_entry.entry_dn + admin_conn.unbind() + + # 用户密码验证 + user_conn = Connection(server, user=user_dn, password=password) + if config.use_starttls and not use_ssl: + user_conn.start_tls() + + if not user_conn.bind(): + logger.warning(f"LDAP 密码验证失败: {username}") + return None + + user_conn.unbind() + + # 提取用户属性 + email = str(getattr(user_entry, config.email_attr, "")) or f"{username}@ldap.local" + display_name = str(getattr(user_entry, config.display_name_attr, "")) or username + + logger.info(f"LDAP 认证成功: {username}") + return { + "username": username, + "email": email, + "display_name": display_name, + } + + except LDAPSocketOpenError as e: + logger.error(f"LDAP 服务器连接失败: {e}") + return None + except LDAPBindError as e: + logger.error(f"LDAP 绑定失败: {e}") + return None + except Exception as e: + logger.error(f"LDAP 认证异常: {e}") + return None + + @staticmethod + def test_connection(db: Session) -> Tuple[bool, str]: + """ + 测试 LDAP 连接 + + Returns: + (success, message) + """ + try: + import ldap3 + from ldap3 import Server, Connection + except ImportError: + return False, "ldap3 库未安装" + + config = LDAPService.get_config(db) + if not config: + return False, "LDAP 配置不存在" + + try: + use_ssl = config.server_url.startswith("ldaps://") + server = Server(config.server_url, use_ssl=use_ssl, get_info=ldap3.ALL) + bind_password = config.get_bind_password() + conn = Connection(server, user=config.bind_dn, password=bind_password) + + if config.use_starttls and not use_ssl: + conn.start_tls() + + if not conn.bind(): + return False, f"绑定失败: {conn.result}" + + conn.unbind() + return True, "连接成功" + + except Exception as e: + return False, f"连接失败: {str(e)}" diff --git a/src/services/auth/service.py b/src/services/auth/service.py index ecdf57b..b06621e 100644 --- a/src/services/auth/service.py +++ b/src/services/auth/service.py @@ -15,8 +15,10 @@ from sqlalchemy.orm import Session, joinedload from src.config import config from src.core.crypto import crypto_service from src.core.logger import logger +from src.core.enums import AuthSource from src.models.database import ApiKey, User, UserRole from src.services.auth.jwt_blacklist import JWTBlacklistService +from src.services.auth.ldap import LDAPService from src.services.cache.user_cache import UserCacheService from src.services.user.apikey import ApiKeyService @@ -92,8 +94,36 @@ class AuthService: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token") @staticmethod - async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: - """用户登录认证""" + async def authenticate_user( + db: Session, email: str, password: str, auth_type: str = "local" + ) -> Optional[User]: + """用户登录认证 + + Args: + db: 数据库会话 + email: 邮箱/用户名 + password: 密码 + auth_type: 认证类型 ("local" 或 "ldap") + """ + if auth_type == "ldap": + # LDAP 认证 + if not LDAPService.is_ldap_enabled(db): + logger.warning("登录失败 - LDAP 认证未启用") + return None + + ldap_user = LDAPService.authenticate(db, email, password) + if not ldap_user: + return None + + # 获取或创建本地用户 + user = await AuthService._get_or_create_ldap_user(db, ldap_user) + return user + + # 本地认证 + if LDAPService.is_ldap_exclusive(db): + logger.warning("登录失败 - 仅允许 LDAP 登录") + return None + # 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象 user = db.query(User).filter(User.email == email).first() @@ -101,6 +131,11 @@ class AuthService: logger.warning(f"登录失败 - 用户不存在: {email}") return None + # 检查用户认证来源 + if user.auth_source == AuthSource.LDAP: + logger.warning(f"登录失败 - 该用户使用 LDAP 认证: {email}") + return None + if not user.verify_password(password): logger.warning(f"登录失败 - 密码错误: {email}") return None @@ -118,6 +153,42 @@ class AuthService: logger.info(f"用户登录成功: {email} (ID: {user.id})") return user + @staticmethod + async def _get_or_create_ldap_user(db: Session, ldap_user: dict) -> User: + """获取或创建 LDAP 用户 + + Args: + ldap_user: LDAP 用户信息 {username, email, display_name} + """ + # 先按 email 查找 + user = db.query(User).filter(User.email == ldap_user["email"]).first() + + if user: + # 更新 auth_source(如果是首次 LDAP 登录) + if user.auth_source != AuthSource.LDAP: + user.auth_source = AuthSource.LDAP + user.last_login_at = datetime.now(timezone.utc) + db.commit() + await UserCacheService.invalidate_user_cache(user.id, user.email) + logger.info(f"LDAP 用户登录成功: {ldap_user['email']} (ID: {user.id})") + return user + + # 创建新用户 + user = User( + email=ldap_user["email"], + username=ldap_user["username"], + password_hash="", # LDAP 用户无本地密码 + auth_source=AuthSource.LDAP, + role=UserRole.USER, + is_active=True, + last_login_at=datetime.now(timezone.utc), + ) + db.add(user) + db.commit() + db.refresh(user) + logger.info(f"LDAP 用户创建成功: {ldap_user['email']} (ID: {user.id})") + return user + @staticmethod def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]: """API密钥认证"""