Files
Aether/src/services/auth/jwt_blacklist.py
fawney19 7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00

197 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
JWT Token 黑名单服务
使用 Redis 存储被撤销的 JWT Token防止已登出或被撤销的 Token 继续使用
"""
import hashlib
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from src.clients.redis_client import get_redis_client
from src.core.logger import logger
# 安全策略配置:当 Redis 不可用时的行为
# True = fail-closed安全优先拒绝访问
# False = fail-open可用性优先允许访问
BLACKLIST_FAIL_CLOSED = os.getenv("JWT_BLACKLIST_FAIL_CLOSED", "true").lower() == "true"
class JWTBlacklistService:
"""JWT Token 黑名单服务"""
# Redis key 前缀
BLACKLIST_PREFIX = "jwt:blacklist:"
@staticmethod
def _get_token_hash(token: str) -> str:
"""
获取 Token 的哈希值(用于 Redis key
使用 SHA256 哈希避免直接存储完整 Token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
async def add_to_blacklist(token: str, exp_timestamp: int, reason: str = "logout") -> bool:
"""
将 Token 添加到黑名单
Args:
token: JWT token 字符串
exp_timestamp: Token 的过期时间戳Unix timestamp
reason: 添加到黑名单的原因logout, revoked, security
Returns:
是否成功添加到黑名单
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法将 Token 添加到黑名单(降级模式)")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
# 计算 TTLToken 过期前的剩余时间)
now = datetime.now(timezone.utc).timestamp()
ttl_seconds = max(int(exp_timestamp - now), 0)
if ttl_seconds <= 0:
# Token 已经过期,不需要加入黑名单
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
return True
# 存储到 Redis设置 TTL 为 Token 过期时间
# 值存储为原因字符串
await redis_client.setex(redis_key, ttl_seconds, reason)
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
return True
except Exception as e:
logger.error(f"添加 Token 到黑名单失败: {e}")
return False
@staticmethod
async def is_blacklisted(token: str) -> bool:
"""
检查 Token 是否在黑名单中
Args:
token: JWT token 字符串
Returns:
Token 是否在黑名单中
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
# Redis 不可用时,根据安全策略决定行为
if BLACKLIST_FAIL_CLOSED:
logger.warning("Redis 不可用,采用 fail-closed 策略拒绝访问(可通过 JWT_BLACKLIST_FAIL_CLOSED=false 改变)")
return True # 返回 True 表示在黑名单中,拒绝访问
else:
logger.debug("Redis 不可用,采用 fail-open 策略允许访问")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
# 检查 key 是否存在
exists = await redis_client.exists(redis_key)
if exists:
# 获取黑名单原因(可选)
reason = await redis_client.get(redis_key)
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
return True
return False
except Exception as e:
logger.error(f"检查 Token 黑名单状态失败: {e}")
# 发生错误时,根据安全策略决定行为
if BLACKLIST_FAIL_CLOSED:
logger.warning("黑名单检查失败,采用 fail-closed 策略拒绝访问")
return True # 安全优先,拒绝访问
else:
logger.warning("黑名单检查失败,采用 fail-open 策略允许访问")
return False
@staticmethod
async def remove_from_blacklist(token: str) -> bool:
"""
从黑名单中移除 Token用于测试或特殊情况
Args:
token: JWT token 字符串
Returns:
是否成功移除
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法从黑名单中移除 Token")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
deleted = await redis_client.delete(redis_key)
if deleted:
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
else:
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
return bool(deleted)
except Exception as e:
logger.error(f"从黑名单移除 Token 失败: {e}")
return False
@staticmethod
async def get_blacklist_stats() -> dict:
"""
获取黑名单统计信息
Returns:
包含统计信息的字典
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return {"available": False, "total_blacklisted": 0, "error": "Redis 不可用"}
try:
# 扫描黑名单 key
pattern = f"{JWTBlacklistService.BLACKLIST_PREFIX}*"
cursor = 0
total = 0
while True:
cursor, keys = await redis_client.scan(cursor=cursor, match=pattern, count=100)
total += len(keys)
if cursor == 0:
break
return {"available": True, "total_blacklisted": total}
except Exception as e:
logger.error(f"获取黑名单统计失败: {e}")
return {"available": False, "total_blacklisted": 0, "error": str(e)}