From 4d1d8639167b090e3ef07326651a7eb7ae367f0a Mon Sep 17 00:00:00 2001 From: fawney19 Date: Thu, 18 Dec 2025 01:09:22 +0800 Subject: [PATCH] refactor: improve authentication and user data handling - Replace user cache queries with direct database queries to ensure data consistency - Fix token_type parameter in verify_token calls (access token verification) - Fix role-based permission check using dictionary ranking instead of string comparison - Fix logout operation to use correct JWT claim name (user_id instead of sub) - Simplify user authentication flow by removing unnecessary cache layer - Optimize session initialization in main.py using create_session helper - Remove unused imports and exception variables --- src/api/announcements/routes.py | 2 +- src/api/base/pipeline.py | 13 ++++++------- src/main.py | 9 +++------ src/plugins/auth/jwt.py | 2 +- src/services/auth/service.py | 22 +++++++++++----------- src/utils/auth_utils.py | 4 ++-- 6 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/api/announcements/routes.py b/src/api/announcements/routes.py index a787356..786727d 100644 --- a/src/api/announcements/routes.py +++ b/src/api/announcements/routes.py @@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter): token = authorization.replace("Bearer ", "").strip() try: - payload = await AuthService.verify_token(token) + payload = await AuthService.verify_token(token, token_type="access") user_id = payload.get("user_id") if not user_id: return None diff --git a/src/api/base/pipeline.py b/src/api/base/pipeline.py index 779f0fa..f49501b 100644 --- a/src/api/base/pipeline.py +++ b/src/api/base/pipeline.py @@ -11,7 +11,6 @@ from src.core.exceptions import QuotaExceededException from src.core.logger import logger from src.models.database import ApiKey, AuditEventType, User, UserRole from src.services.auth.service import AuthService -from src.services.cache.user_cache import UserCacheService from src.services.system.audit import AuditService from src.services.usage.service import UsageService @@ -180,7 +179,7 @@ class ApiRequestPipeline: token = authorization.replace("Bearer ", "").strip() try: - payload = await self.auth_service.verify_token(token) + payload = await self.auth_service.verify_token(token, token_type="access") except HTTPException: raise except Exception as exc: @@ -191,8 +190,8 @@ class ApiRequestPipeline: if not user_id: raise HTTPException(status_code=401, detail="无效的管理员令牌") - # 使用缓存查询用户 - user = await UserCacheService.get_user_by_id(db, user_id) + # 直接查询数据库,确保返回的是当前 Session 绑定的对象 + user = db.query(User).filter(User.id == user_id).first() if not user or not user.is_active: raise HTTPException(status_code=403, detail="用户不存在或已禁用") @@ -207,7 +206,7 @@ class ApiRequestPipeline: token = authorization.replace("Bearer ", "").strip() try: - payload = await self.auth_service.verify_token(token) + payload = await self.auth_service.verify_token(token, token_type="access") except HTTPException: raise except Exception as exc: @@ -218,8 +217,8 @@ class ApiRequestPipeline: if not user_id: raise HTTPException(status_code=401, detail="无效的用户令牌") - # 使用缓存查询用户 - user = await UserCacheService.get_user_by_id(db, user_id) + # 直接查询数据库,确保返回的是当前 Session 绑定的对象 + user = db.query(User).filter(User.id == user_id).first() if not user or not user.is_active: raise HTTPException(status_code=403, detail="用户不存在或已禁用") diff --git a/src/main.py b/src/main.py index 6cde586..87f956e 100644 --- a/src/main.py +++ b/src/main.py @@ -3,7 +3,6 @@ 采用模块化架构设计 """ -import asyncio from contextlib import asynccontextmanager from pathlib import Path @@ -39,14 +38,12 @@ async def initialize_providers(): """从数据库初始化提供商(仅用于日志记录)""" from sqlalchemy.orm import Session - from src.core.enums import APIFormat - from src.database import get_db + from src.database.database import create_session from src.models.database import Provider try: # 创建数据库会话 - db_gen = get_db() - db: Session = next(db_gen) + db: Session = create_session() try: # 从数据库加载所有活跃的提供商 @@ -75,7 +72,7 @@ async def initialize_providers(): finally: db.close() - except Exception as e: + except Exception: logger.exception("从数据库初始化提供商失败") diff --git a/src/plugins/auth/jwt.py b/src/plugins/auth/jwt.py index 5ec66f9..f04eefc 100644 --- a/src/plugins/auth/jwt.py +++ b/src/plugins/auth/jwt.py @@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin): try: # 验证JWT token - payload = AuthService.verify_token(token) + payload = await AuthService.verify_token(token, token_type="access") logger.debug(f"JWT token验证成功, payload: {payload}") # 从payload中提取用户信息 diff --git a/src/services/auth/service.py b/src/services/auth/service.py index 81391d0..ca6ea05 100644 --- a/src/services/auth/service.py +++ b/src/services/auth/service.py @@ -93,8 +93,8 @@ class AuthService: @staticmethod async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: """用户登录认证""" - # 使用缓存查询用户 - user = await UserCacheService.get_user_by_email(db, email) + # 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象 + user = db.query(User).filter(User.email == email).first() if not user: logger.warning(f"登录失败 - 用户不存在: {email}") @@ -109,13 +109,10 @@ class AuthService: return None # 更新最后登录时间 - # 需要重新从数据库获取以便更新(缓存的对象是分离的) - db_user = db.query(User).filter(User.id == user.id).first() - if db_user: - db_user.last_login_at = datetime.now(timezone.utc) - db.commit() # 立即提交事务,释放数据库锁 - # 清除缓存,因为用户信息已更新 - await UserCacheService.invalidate_user_cache(user.id, user.email) + user.last_login_at = datetime.now(timezone.utc) + db.commit() # 立即提交事务,释放数据库锁 + # 清除缓存,因为用户信息已更新 + await UserCacheService.invalidate_user_cache(user.id, user.email) logger.info(f"用户登录成功: {email} (ID: {user.id})") return user @@ -198,7 +195,10 @@ class AuthService: if user.role == UserRole.ADMIN: return True - if user.role.value >= required_role.value: + # 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin') + role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1} + # 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝) + if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999): return True logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}") @@ -230,7 +230,7 @@ class AuthService: ) if success: - user_id = payload.get("sub") + user_id = payload.get("user_id") logger.info(f"用户登出成功: user_id={user_id}") return success diff --git a/src/utils/auth_utils.py b/src/utils/auth_utils.py index 23670c2..57c07c0 100644 --- a/src/utils/auth_utils.py +++ b/src/utils/auth_utils.py @@ -41,7 +41,7 @@ async def get_current_user( try: # 验证Token格式和签名 try: - payload = await AuthService.verify_token(token) + payload = await AuthService.verify_token(token, token_type="access") except HTTPException as token_error: # 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403 logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...") @@ -144,7 +144,7 @@ async def get_current_user_from_header( token = authorization.replace("Bearer ", "") try: - payload = await AuthService.verify_token(token) + payload = await AuthService.verify_token(token, token_type="access") user_id = payload.get("user_id") if not user_id: