refactor: improve authentication and user data handling

- Replace user cache queries with direct database queries to ensure data consistency
- Fix token_type parameter in verify_token calls (access token verification)
- Fix role-based permission check using dictionary ranking instead of string comparison
- Fix logout operation to use correct JWT claim name (user_id instead of sub)
- Simplify user authentication flow by removing unnecessary cache layer
- Optimize session initialization in main.py using create_session helper
- Remove unused imports and exception variables
This commit is contained in:
fawney19
2025-12-18 01:09:22 +08:00
parent b579420690
commit 4d1d863916
6 changed files with 24 additions and 28 deletions

View File

@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
token = authorization.replace("Bearer ", "").strip()
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:
return None

View File

@@ -11,7 +11,6 @@ from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")

View File

@@ -3,7 +3,6 @@
采用模块化架构设计
"""
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
@@ -39,14 +38,12 @@ async def initialize_providers():
"""从数据库初始化提供商(仅用于日志记录)"""
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.database import get_db
from src.database.database import create_session
from src.models.database import Provider
try:
# 创建数据库会话
db_gen = get_db()
db: Session = next(db_gen)
db: Session = create_session()
try:
# 从数据库加载所有活跃的提供商
@@ -75,7 +72,7 @@ async def initialize_providers():
finally:
db.close()
except Exception as e:
except Exception:
logger.exception("从数据库初始化提供商失败")

View File

@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
try:
# 验证JWT token
payload = AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
logger.debug(f"JWT token验证成功, payload: {payload}")
# 从payload中提取用户信息

View File

@@ -93,8 +93,8 @@ class AuthService:
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证"""
# 使用缓存查询用户
user = await UserCacheService.get_user_by_email(db, email)
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
@@ -109,13 +109,10 @@ class AuthService:
return None
# 更新最后登录时间
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
db_user = db.query(User).filter(User.id == user.id).first()
if db_user:
db_user.last_login_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
user.last_login_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@@ -198,7 +195,10 @@ class AuthService:
if user.role == UserRole.ADMIN:
return True
if user.role.value >= required_role.value:
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin'
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
# 未知用户角色默认 -1拒绝未知要求角色默认 999拒绝
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
return True
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
@@ -230,7 +230,7 @@ class AuthService:
)
if success:
user_id = payload.get("sub")
user_id = payload.get("user_id")
logger.info(f"用户登出成功: user_id={user_id}")
return success

View File

@@ -41,7 +41,7 @@ async def get_current_user(
try:
# 验证Token格式和签名
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
except HTTPException as token_error:
# 保持原始的HTTP状态码如401 Unauthorized不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
@@ -144,7 +144,7 @@ async def get_current_user_from_header(
token = authorization.replace("Bearer ", "")
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id: