mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor: improve authentication and user data handling
- Replace user cache queries with direct database queries to ensure data consistency - Fix token_type parameter in verify_token calls (access token verification) - Fix role-based permission check using dictionary ranking instead of string comparison - Fix logout operation to use correct JWT claim name (user_id instead of sub) - Simplify user authentication flow by removing unnecessary cache layer - Optimize session initialization in main.py using create_session helper - Remove unused imports and exception variables
This commit is contained in:
@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
@@ -11,7 +11,6 @@ from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
采用模块化架构设计
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -39,14 +38,12 @@ async def initialize_providers():
|
||||
"""从数据库初始化提供商(仅用于日志记录)"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
from src.database.database import create_session
|
||||
from src.models.database import Provider
|
||||
|
||||
try:
|
||||
# 创建数据库会话
|
||||
db_gen = get_db()
|
||||
db: Session = next(db_gen)
|
||||
db: Session = create_session()
|
||||
|
||||
try:
|
||||
# 从数据库加载所有活跃的提供商
|
||||
@@ -75,7 +72,7 @@ async def initialize_providers():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("从数据库初始化提供商失败")
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
|
||||
|
||||
try:
|
||||
# 验证JWT token
|
||||
payload = AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
logger.debug(f"JWT token验证成功, payload: {payload}")
|
||||
|
||||
# 从payload中提取用户信息
|
||||
|
||||
@@ -93,8 +93,8 @@ class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_email(db, email)
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
@@ -109,13 +109,10 @@ class AuthService:
|
||||
return None
|
||||
|
||||
# 更新最后登录时间
|
||||
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
|
||||
db_user = db.query(User).filter(User.id == user.id).first()
|
||||
if db_user:
|
||||
db_user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
@@ -198,7 +195,10 @@ class AuthService:
|
||||
if user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
if user.role.value >= required_role.value:
|
||||
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin')
|
||||
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
|
||||
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
|
||||
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
|
||||
return True
|
||||
|
||||
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
||||
@@ -230,7 +230,7 @@ class AuthService:
|
||||
)
|
||||
|
||||
if success:
|
||||
user_id = payload.get("sub")
|
||||
user_id = payload.get("user_id")
|
||||
logger.info(f"用户登出成功: user_id={user_id}")
|
||||
|
||||
return success
|
||||
|
||||
@@ -41,7 +41,7 @@ async def get_current_user(
|
||||
try:
|
||||
# 验证Token格式和签名
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
except HTTPException as token_error:
|
||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
||||
@@ -144,7 +144,7 @@ async def get_current_user_from_header(
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
|
||||
Reference in New Issue
Block a user