refactor: improve authentication and user data handling

- Replace user cache queries with direct database queries to ensure data consistency
- Fix token_type parameter in verify_token calls (access token verification)
- Fix role-based permission check using dictionary ranking instead of string comparison
- Fix logout operation to use correct JWT claim name (user_id instead of sub)
- Simplify user authentication flow by removing unnecessary cache layer
- Optimize session initialization in main.py using create_session helper
- Remove unused imports and exception variables
This commit is contained in:
fawney19
2025-12-18 01:09:22 +08:00
parent b579420690
commit 4d1d863916
6 changed files with 24 additions and 28 deletions

View File

@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
token = authorization.replace("Bearer ", "").strip() token = authorization.replace("Bearer ", "").strip()
try: try:
payload = await AuthService.verify_token(token) payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id") user_id = payload.get("user_id")
if not user_id: if not user_id:
return None return None

View File

@@ -11,7 +11,6 @@ from src.core.exceptions import QuotaExceededException
from src.core.logger import logger from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService from src.services.system.audit import AuditService
from src.services.usage.service import UsageService from src.services.usage.service import UsageService
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip() token = authorization.replace("Bearer ", "").strip()
try: try:
payload = await self.auth_service.verify_token(token) payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException: except HTTPException:
raise raise
except Exception as exc: except Exception as exc:
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌") raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户 # 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = await UserCacheService.get_user_by_id(db, user_id) user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active: if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用") raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip() token = authorization.replace("Bearer ", "").strip()
try: try:
payload = await self.auth_service.verify_token(token) payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException: except HTTPException:
raise raise
except Exception as exc: except Exception as exc:
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
if not user_id: if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌") raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户 # 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = await UserCacheService.get_user_by_id(db, user_id) user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active: if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用") raise HTTPException(status_code=403, detail="用户不存在或已禁用")

View File

@@ -3,7 +3,6 @@
采用模块化架构设计 采用模块化架构设计
""" """
import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
@@ -39,14 +38,12 @@ async def initialize_providers():
"""从数据库初始化提供商(仅用于日志记录)""" """从数据库初始化提供商(仅用于日志记录)"""
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.core.enums import APIFormat from src.database.database import create_session
from src.database import get_db
from src.models.database import Provider from src.models.database import Provider
try: try:
# 创建数据库会话 # 创建数据库会话
db_gen = get_db() db: Session = create_session()
db: Session = next(db_gen)
try: try:
# 从数据库加载所有活跃的提供商 # 从数据库加载所有活跃的提供商
@@ -75,7 +72,7 @@ async def initialize_providers():
finally: finally:
db.close() db.close()
except Exception as e: except Exception:
logger.exception("从数据库初始化提供商失败") logger.exception("从数据库初始化提供商失败")

View File

@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
try: try:
# 验证JWT token # 验证JWT token
payload = AuthService.verify_token(token) payload = await AuthService.verify_token(token, token_type="access")
logger.debug(f"JWT token验证成功, payload: {payload}") logger.debug(f"JWT token验证成功, payload: {payload}")
# 从payload中提取用户信息 # 从payload中提取用户信息

View File

@@ -93,8 +93,8 @@ class AuthService:
@staticmethod @staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]: async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证""" """用户登录认证"""
# 使用缓存查询用户 # 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = await UserCacheService.get_user_by_email(db, email) user = db.query(User).filter(User.email == email).first()
if not user: if not user:
logger.warning(f"登录失败 - 用户不存在: {email}") logger.warning(f"登录失败 - 用户不存在: {email}")
@@ -109,13 +109,10 @@ class AuthService:
return None return None
# 更新最后登录时间 # 更新最后登录时间
# 需要重新从数据库获取以便更新(缓存的对象是分离的) user.last_login_at = datetime.now(timezone.utc)
db_user = db.query(User).filter(User.id == user.id).first() db.commit() # 立即提交事务,释放数据库锁
if db_user: # 清除缓存,因为用户信息已更新
db_user.last_login_at = datetime.now(timezone.utc) await UserCacheService.invalidate_user_cache(user.id, user.email)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"用户登录成功: {email} (ID: {user.id})") logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user return user
@@ -198,7 +195,10 @@ class AuthService:
if user.role == UserRole.ADMIN: if user.role == UserRole.ADMIN:
return True return True
if user.role.value >= required_role.value: # 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin'
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
# 未知用户角色默认 -1拒绝未知要求角色默认 999拒绝
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
return True return True
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}") logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
@@ -230,7 +230,7 @@ class AuthService:
) )
if success: if success:
user_id = payload.get("sub") user_id = payload.get("user_id")
logger.info(f"用户登出成功: user_id={user_id}") logger.info(f"用户登出成功: user_id={user_id}")
return success return success

View File

@@ -41,7 +41,7 @@ async def get_current_user(
try: try:
# 验证Token格式和签名 # 验证Token格式和签名
try: try:
payload = await AuthService.verify_token(token) payload = await AuthService.verify_token(token, token_type="access")
except HTTPException as token_error: except HTTPException as token_error:
# 保持原始的HTTP状态码如401 Unauthorized不要转换为403 # 保持原始的HTTP状态码如401 Unauthorized不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...") logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
@@ -144,7 +144,7 @@ async def get_current_user_from_header(
token = authorization.replace("Bearer ", "") token = authorization.replace("Bearer ", "")
try: try:
payload = await AuthService.verify_token(token) payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id") user_id = payload.get("user_id")
if not user_id: if not user_id: