2025-12-10 20:52:44 +08:00
|
|
|
|
"""
|
|
|
|
|
|
认证服务
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
import secrets
|
|
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
|
|
|
|
|
|
import jwt
|
|
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
|
|
from sqlalchemy.orm import Session, joinedload
|
|
|
|
|
|
|
|
|
|
|
|
from src.config import config
|
|
|
|
|
|
from src.core.crypto import crypto_service
|
|
|
|
|
|
from src.core.logger import logger
|
|
|
|
|
|
from src.models.database import ApiKey, User, UserRole
|
|
|
|
|
|
from src.services.auth.jwt_blacklist import JWTBlacklistService
|
|
|
|
|
|
from src.services.cache.user_cache import UserCacheService
|
|
|
|
|
|
from src.services.user.apikey import ApiKeyService
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# JWT配置从config读取
|
|
|
|
|
|
if not config.jwt_secret_key:
|
|
|
|
|
|
# 如果没有配置,生成一个随机密钥并警告
|
|
|
|
|
|
if config.environment == "production":
|
|
|
|
|
|
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
|
|
|
|
|
config.jwt_secret_key = secrets.token_urlsafe(32)
|
|
|
|
|
|
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
|
|
|
|
|
|
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
|
|
|
|
|
|
|
|
|
|
|
JWT_SECRET_KEY = config.jwt_secret_key
|
|
|
|
|
|
JWT_ALGORITHM = config.jwt_algorithm
|
|
|
|
|
|
JWT_EXPIRATION_HOURS = config.jwt_expiration_hours
|
|
|
|
|
|
# Refresh token 有效期设为7天
|
|
|
|
|
|
REFRESH_TOKEN_EXPIRATION_DAYS = 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AuthService:
|
|
|
|
|
|
"""认证服务"""
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def create_access_token(data: dict) -> str:
|
|
|
|
|
|
"""创建JWT访问令牌"""
|
|
|
|
|
|
to_encode = data.copy()
|
|
|
|
|
|
expire = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
|
|
|
|
|
|
to_encode.update({"exp": expire, "type": "access"})
|
|
|
|
|
|
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
return encoded_jwt
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def create_refresh_token(data: dict) -> str:
|
|
|
|
|
|
"""创建JWT刷新令牌"""
|
|
|
|
|
|
to_encode = data.copy()
|
|
|
|
|
|
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRATION_DAYS)
|
|
|
|
|
|
to_encode.update({"exp": expire, "type": "refresh"})
|
|
|
|
|
|
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
return encoded_jwt
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
async def verify_token(token: str, token_type: Optional[str] = None) -> Dict[str, Any]:
|
|
|
|
|
|
"""验证JWT令牌
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
token: JWT token字符串
|
|
|
|
|
|
token_type: 期望的token类型 ('access' 或 'refresh'),None表示不验证类型
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
|
|
|
|
|
|
|
|
|
|
|
# 验证token类型(如果指定)
|
|
|
|
|
|
if token_type:
|
|
|
|
|
|
actual_type = payload.get("type")
|
|
|
|
|
|
if actual_type != token_type:
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
|
detail=f"Token类型错误: 期望 {token_type}, 实际 {actual_type}",
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# 检查 Token 是否在黑名单中
|
|
|
|
|
|
is_blacklisted = await JWTBlacklistService.is_blacklisted(token)
|
|
|
|
|
|
if is_blacklisted:
|
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token已被撤销"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
return payload
|
|
|
|
|
|
except jwt.ExpiredSignatureError:
|
|
|
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token已过期")
|
|
|
|
|
|
except jwt.InvalidTokenError:
|
|
|
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的Token")
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
|
|
|
|
|
"""用户登录认证"""
|
2025-12-18 01:09:22 +08:00
|
|
|
|
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
|
|
|
|
|
user = db.query(User).filter(User.email == email).first()
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
if not user:
|
|
|
|
|
|
logger.warning(f"登录失败 - 用户不存在: {email}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
if not user.verify_password(password):
|
|
|
|
|
|
logger.warning(f"登录失败 - 密码错误: {email}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
if not user.is_active:
|
|
|
|
|
|
logger.warning(f"登录失败 - 用户已禁用: {email}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# 更新最后登录时间
|
2025-12-18 01:09:22 +08:00
|
|
|
|
user.last_login_at = datetime.now(timezone.utc)
|
|
|
|
|
|
db.commit() # 立即提交事务,释放数据库锁
|
|
|
|
|
|
# 清除缓存,因为用户信息已更新
|
|
|
|
|
|
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
|
|
|
|
|
return user
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def authenticate_api_key(db: Session, api_key: str) -> Optional[tuple[User, ApiKey]]:
|
|
|
|
|
|
"""API密钥认证"""
|
|
|
|
|
|
# 对API密钥进行哈希查找,预加载 user 关系以支持后续访问限制检查
|
|
|
|
|
|
key_hash = ApiKey.hash_key(api_key)
|
|
|
|
|
|
key_record = (
|
|
|
|
|
|
db.query(ApiKey)
|
|
|
|
|
|
.options(joinedload(ApiKey.user))
|
|
|
|
|
|
.filter(ApiKey.key_hash == key_hash)
|
|
|
|
|
|
.first()
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not key_record:
|
|
|
|
|
|
# 只记录认证失败事件,不记录任何 key 信息以防止信息泄露
|
|
|
|
|
|
logger.warning("API认证失败 - 密钥不存在或无效")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
if not key_record.is_active:
|
|
|
|
|
|
logger.warning("API认证失败 - 密钥已禁用")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# 检查过期时间
|
|
|
|
|
|
if key_record.expires_at:
|
|
|
|
|
|
# 确保 expires_at 是 aware datetime
|
|
|
|
|
|
expires_at = key_record.expires_at
|
|
|
|
|
|
if expires_at.tzinfo is None:
|
|
|
|
|
|
# 如果没有时区信息,假定为 UTC
|
|
|
|
|
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
|
|
|
|
|
|
|
|
|
|
|
if expires_at < datetime.now(timezone.utc):
|
|
|
|
|
|
logger.warning("API认证失败 - 密钥已过期")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# 检查余额限制(仅独立Key)
|
|
|
|
|
|
is_balance_ok, remaining = ApiKeyService.check_balance(key_record)
|
|
|
|
|
|
if not is_balance_ok:
|
|
|
|
|
|
# 获取剩余余额用于日志
|
|
|
|
|
|
remaining_balance = ApiKeyService.get_remaining_balance(key_record)
|
|
|
|
|
|
logger.warning(f"API认证失败 - 余额不足 "
|
|
|
|
|
|
f"(已用: ${key_record.balance_used_usd:.4f}, 剩余: ${remaining_balance:.4f})")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# 获取用户
|
|
|
|
|
|
user = key_record.user
|
|
|
|
|
|
if not user.is_active:
|
|
|
|
|
|
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
# 更新最后使用时间
|
|
|
|
|
|
key_record.last_used_at = datetime.now(timezone.utc)
|
|
|
|
|
|
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
|
|
|
|
|
return user, key_record
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def check_user_quota(user: User, estimated_cost: float = 0) -> bool:
|
|
|
|
|
|
"""检查用户配额"""
|
|
|
|
|
|
if user.role == UserRole.ADMIN:
|
|
|
|
|
|
return True # 管理员无限制
|
|
|
|
|
|
|
|
|
|
|
|
# NULL 表示无限制
|
|
|
|
|
|
if user.quota_usd is None:
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
# 检查美元配额
|
|
|
|
|
|
if user.used_usd + estimated_cost > user.quota_usd:
|
|
|
|
|
|
logger.warning(f"用户配额不足: {user.email} (已用: ${user.used_usd:.2f}, 配额: ${user.quota_usd:.2f})")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def check_permission(user: User, required_role: UserRole = UserRole.USER) -> bool:
|
|
|
|
|
|
"""检查用户权限"""
|
|
|
|
|
|
if user.role == UserRole.ADMIN:
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
2025-12-18 01:09:22 +08:00
|
|
|
|
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin')
|
|
|
|
|
|
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
|
|
|
|
|
|
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
|
|
|
|
|
|
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
|
2025-12-10 20:52:44 +08:00
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
async def logout(token: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
用户登出,将 Token 加入黑名单
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
token: JWT token字符串
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
是否成功登出
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 解码 Token 获取过期时间(不验证黑名单)
|
|
|
|
|
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
|
|
|
|
|
exp_timestamp = payload.get("exp")
|
|
|
|
|
|
|
|
|
|
|
|
if not exp_timestamp:
|
|
|
|
|
|
logger.warning("Token 缺少过期时间,无法加入黑名单")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# 将 Token 加入黑名单
|
|
|
|
|
|
success = await JWTBlacklistService.add_to_blacklist(
|
|
|
|
|
|
token=token, exp_timestamp=exp_timestamp, reason="logout"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if success:
|
2025-12-18 01:09:22 +08:00
|
|
|
|
user_id = payload.get("user_id")
|
2025-12-10 20:52:44 +08:00
|
|
|
|
logger.info(f"用户登出成功: user_id={user_id}")
|
|
|
|
|
|
|
|
|
|
|
|
return success
|
|
|
|
|
|
|
|
|
|
|
|
except jwt.InvalidTokenError as e:
|
|
|
|
|
|
logger.warning(f"登出失败 - 无效的 Token: {e}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"登出失败: {e}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
async def revoke_token(token: str, reason: str = "revoked") -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
撤销 Token(管理员操作)
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
token: JWT token字符串
|
|
|
|
|
|
reason: 撤销原因
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
是否成功撤销
|
|
|
|
|
|
"""
|
|
|
|
|
|
try:
|
|
|
|
|
|
# 解码 Token 获取过期时间
|
|
|
|
|
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
|
|
|
|
|
exp_timestamp = payload.get("exp")
|
|
|
|
|
|
|
|
|
|
|
|
if not exp_timestamp:
|
|
|
|
|
|
logger.warning("Token 缺少过期时间,无法撤销")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# 将 Token 加入黑名单
|
|
|
|
|
|
success = await JWTBlacklistService.add_to_blacklist(
|
|
|
|
|
|
token=token, exp_timestamp=exp_timestamp, reason=reason
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if success:
|
|
|
|
|
|
user_id = payload.get("sub")
|
|
|
|
|
|
logger.warning(f"Token 已被撤销: user_id={user_id}, reason={reason}")
|
|
|
|
|
|
|
|
|
|
|
|
return success
|
|
|
|
|
|
|
|
|
|
|
|
except jwt.InvalidTokenError as e:
|
|
|
|
|
|
logger.warning(f"撤销失败 - 无效的 Token: {e}")
|
|
|
|
|
|
return False
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
logger.error(f"撤销 Token 失败: {e}")
|
|
|
|
|
|
return False
|