Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
"""
认证服务模块
包含认证服务、JWT 黑名单等功能。
"""
from src.services.auth.jwt_blacklist import JWTBlacklistService
from src.services.auth.service import AuthService
__all__ = [
"AuthService",
"JWTBlacklistService",
]

View File

@@ -0,0 +1,191 @@
"""
JWT Token 黑名单服务
使用 Redis 存储被撤销的 JWT Token防止已登出或被撤销的 Token 继续使用
"""
import hashlib
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from src.clients.redis_client import get_redis_client
from src.core.logger import logger
# 安全策略配置:当 Redis 不可用时的行为
# True = fail-closed安全优先拒绝访问
# False = fail-open可用性优先允许访问
BLACKLIST_FAIL_CLOSED = os.getenv("JWT_BLACKLIST_FAIL_CLOSED", "true").lower() == "true"
class JWTBlacklistService:
"""JWT Token 黑名单服务"""
# Redis key 前缀
BLACKLIST_PREFIX = "jwt:blacklist:"
@staticmethod
def _get_token_hash(token: str) -> str:
"""
获取 Token 的哈希值(用于 Redis key
使用 SHA256 哈希避免直接存储完整 Token
"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
async def add_to_blacklist(token: str, exp_timestamp: int, reason: str = "logout") -> bool:
"""
将 Token 添加到黑名单
Args:
token: JWT token 字符串
exp_timestamp: Token 的过期时间戳Unix timestamp
reason: 添加到黑名单的原因logout, revoked, security
Returns:
是否成功添加到黑名单
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法将 Token 添加到黑名单(降级模式)")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
# 计算 TTLToken 过期前的剩余时间)
now = datetime.now(timezone.utc).timestamp()
ttl_seconds = max(int(exp_timestamp - now), 0)
if ttl_seconds <= 0:
# Token 已经过期,不需要加入黑名单
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
return True
# 存储到 Redis设置 TTL 为 Token 过期时间
# 值存储为原因字符串
await redis_client.setex(redis_key, ttl_seconds, reason)
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
return True
except Exception as e:
logger.error(f"添加 Token 到黑名单失败: {e}")
return False
@staticmethod
async def is_blacklisted(token: str) -> bool:
"""
检查 Token 是否在黑名单中
Args:
token: JWT token 字符串
Returns:
Token 是否在黑名单中
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
# Redis 不可用时,根据安全策略决定行为
if BLACKLIST_FAIL_CLOSED:
logger.warning("Redis 不可用,采用 fail-closed 策略拒绝访问(可通过 JWT_BLACKLIST_FAIL_CLOSED=false 改变)")
return True # 返回 True 表示在黑名单中,拒绝访问
else:
logger.debug("Redis 不可用,采用 fail-open 策略允许访问")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
# 检查 key 是否存在
exists = await redis_client.exists(redis_key)
if exists:
# 获取黑名单原因(可选)
reason = await redis_client.get(redis_key)
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
return True
return False
except Exception as e:
logger.error(f"检查 Token 黑名单状态失败: {e}")
# 发生错误时,根据安全策略决定行为
if BLACKLIST_FAIL_CLOSED:
logger.warning("黑名单检查失败,采用 fail-closed 策略拒绝访问")
return True # 安全优先,拒绝访问
else:
logger.warning("黑名单检查失败,采用 fail-open 策略允许访问")
return False
@staticmethod
async def remove_from_blacklist(token: str) -> bool:
"""
从黑名单中移除 Token用于测试或特殊情况
Args:
token: JWT token 字符串
Returns:
是否成功移除
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法从黑名单中移除 Token")
return False
try:
token_hash = JWTBlacklistService._get_token_hash(token)
redis_key = f"{JWTBlacklistService.BLACKLIST_PREFIX}{token_hash}"
deleted = await redis_client.delete(redis_key)
if deleted:
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
else:
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
return bool(deleted)
except Exception as e:
logger.error(f"从黑名单移除 Token 失败: {e}")
return False
@staticmethod
async def get_blacklist_stats() -> dict:
"""
获取黑名单统计信息
Returns:
包含统计信息的字典
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return {"available": False, "total_blacklisted": 0, "error": "Redis 不可用"}
try:
# 扫描黑名单 key
pattern = f"{JWTBlacklistService.BLACKLIST_PREFIX}*"
cursor = 0
total = 0
while True:
cursor, keys = await redis_client.scan(cursor=cursor, match=pattern, count=100)
total += len(keys)
if cursor == 0:
break
return {"available": True, "total_blacklisted": total}
except Exception as e:
logger.error(f"获取黑名单统计失败: {e}")
return {"available": False, "total_blacklisted": 0, "error": str(e)}

View File

@@ -0,0 +1,282 @@
"""
认证服务
"""
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
import jwt
from fastapi import HTTPException, status
from sqlalchemy.orm import Session, joinedload
from src.config import config
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.models.database import ApiKey, User, UserRole
from src.services.auth.jwt_blacklist import JWTBlacklistService
from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService
# JWT配置从config读取
if not config.jwt_secret_key:
# 如果没有配置,生成一个随机密钥并警告
if config.environment == "production":
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
config.jwt_secret_key = secrets.token_urlsafe(32)
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
JWT_SECRET_KEY = config.jwt_secret_key
JWT_ALGORITHM = config.jwt_algorithm
JWT_EXPIRATION_HOURS = config.jwt_expiration_hours
# Refresh token 有效期设为7天
REFRESH_TOKEN_EXPIRATION_DAYS = 7
class AuthService:
"""认证服务"""
@staticmethod
def create_access_token(data: dict) -> str:
"""创建JWT访问令牌"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
to_encode.update({"exp": expire, "type": "access"})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
return encoded_jwt
@staticmethod
def create_refresh_token(data: dict) -> str:
"""创建JWT刷新令牌"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRATION_DAYS)
to_encode.update({"exp": expire, "type": "refresh"})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
return encoded_jwt
@staticmethod
async def verify_token(token: str, token_type: Optional[str] = None) -> Dict[str, Any]:
"""验证JWT令牌
Args:
token: JWT token字符串
token_type: 期望的token类型 ('access''refresh')None表示不验证类型
"""
try:
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
# 验证token类型如果指定
if token_type:
actual_type = payload.get("type")
if actual_type != token_type:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Token类型错误: 期望 {token_type}, 实际 {actual_type}",
)
# 检查 Token 是否在黑名单中
is_blacklisted = await JWTBlacklistService.is_blacklisted(token)
if is_blacklisted:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token已被撤销"
)
return payload
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token已过期")
except jwt.InvalidTokenError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证"""
# 使用缓存查询用户
user = await UserCacheService.get_user_by_email(db, email)
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
return None
if not user.verify_password(password):
logger.warning(f"登录失败 - 密码错误: {email}")
return None
if not user.is_active:
logger.warning(f"登录失败 - 用户已禁用: {email}")
return None
# 更新最后登录时间
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
db_user = db.query(User).filter(User.id == user.id).first()
if db_user:
db_user.last_login_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@staticmethod
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
"""API密钥认证"""
# 对API密钥进行哈希查找预加载 user 关系以支持后续访问限制检查
key_hash = ApiKey.hash_key(api_key)
key_record = (
db.query(ApiKey)
.options(joinedload(ApiKey.user))
.filter(ApiKey.key_hash == key_hash)
.first()
)
if not key_record:
# 只记录认证失败事件,不记录任何 key 信息以防止信息泄露
logger.warning("API认证失败 - 密钥不存在或无效")
return None
if not key_record.is_active:
logger.warning("API认证失败 - 密钥已禁用")
return None
# 检查过期时间
if key_record.expires_at:
# 确保 expires_at 是 aware datetime
expires_at = key_record.expires_at
if expires_at.tzinfo is None:
# 如果没有时区信息,假定为 UTC
expires_at = expires_at.replace(tzinfo=timezone.utc)
if expires_at < datetime.now(timezone.utc):
logger.warning("API认证失败 - 密钥已过期")
return None
# 检查余额限制仅独立Key
is_balance_ok, remaining = ApiKeyService.check_balance(key_record)
if not is_balance_ok:
# 获取剩余余额用于日志
remaining_balance = ApiKeyService.get_remaining_balance(key_record)
logger.warning(f"API认证失败 - 余额不足 "
f"(已用: ${key_record.balance_used_usd:.4f}, 剩余: ${remaining_balance:.4f})")
return None
# 获取用户
user = key_record.user
if not user.is_active:
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
return None
# 更新最后使用时间
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
return user, key_record
@staticmethod
def check_user_quota(user: User, estimated_cost: float = 0) -> bool:
"""检查用户配额"""
if user.role == UserRole.ADMIN:
return True # 管理员无限制
# NULL 表示无限制
if user.quota_usd is None:
return True
# 检查美元配额
if user.used_usd + estimated_cost > user.quota_usd:
logger.warning(f"用户配额不足: {user.email} (已用: ${user.used_usd:.2f}, 配额: ${user.quota_usd:.2f})")
return False
return True
@staticmethod
def check_permission(user: User, required_role: UserRole = UserRole.USER) -> bool:
"""检查用户权限"""
if user.role == UserRole.ADMIN:
return True
if user.role.value >= required_role.value:
return True
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
return False
@staticmethod
async def logout(token: str) -> bool:
"""
用户登出,将 Token 加入黑名单
Args:
token: JWT token字符串
Returns:
是否成功登出
"""
try:
# 解码 Token 获取过期时间(不验证黑名单)
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
exp_timestamp = payload.get("exp")
if not exp_timestamp:
logger.warning("Token 缺少过期时间,无法加入黑名单")
return False
# 将 Token 加入黑名单
success = await JWTBlacklistService.add_to_blacklist(
token=token, exp_timestamp=exp_timestamp, reason="logout"
)
if success:
user_id = payload.get("sub")
logger.info(f"用户登出成功: user_id={user_id}")
return success
except jwt.InvalidTokenError as e:
logger.warning(f"登出失败 - 无效的 Token: {e}")
return False
except Exception as e:
logger.error(f"登出失败: {e}")
return False
@staticmethod
async def revoke_token(token: str, reason: str = "revoked") -> bool:
"""
撤销 Token管理员操作
Args:
token: JWT token字符串
reason: 撤销原因
Returns:
是否成功撤销
"""
try:
# 解码 Token 获取过期时间
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
exp_timestamp = payload.get("exp")
if not exp_timestamp:
logger.warning("Token 缺少过期时间,无法撤销")
return False
# 将 Token 加入黑名单
success = await JWTBlacklistService.add_to_blacklist(
token=token, exp_timestamp=exp_timestamp, reason=reason
)
if success:
user_id = payload.get("sub")
logger.warning(f"Token 已被撤销: user_id={user_id}, reason={reason}")
return success
except jwt.InvalidTokenError as e:
logger.warning(f"撤销失败 - 无效的 Token: {e}")
return False
except Exception as e:
logger.error(f"撤销 Token 失败: {e}")
return False