""" 认证服务测试 测试 AuthService 的核心功能: - JWT Token 创建和验证 - 用户登录认证 - API Key 认证 """ import pytest from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch import jwt from src.services.auth.service import ( AuthService, JWT_SECRET_KEY, JWT_ALGORITHM, JWT_EXPIRATION_HOURS, ) class TestJWTTokenCreation: """测试 JWT Token 创建""" def test_create_access_token_contains_required_fields(self) -> None: """测试访问令牌包含必要字段""" data = {"sub": "user123", "email": "test@example.com"} token = AuthService.create_access_token(data) # 解码验证 payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) assert payload["sub"] == "user123" assert payload["email"] == "test@example.com" assert payload["type"] == "access" assert "exp" in payload def test_create_access_token_expiration(self) -> None: """测试访问令牌过期时间正确""" data = {"sub": "user123"} token = AuthService.create_access_token(data) payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) # 验证过期时间在预期范围内(允许1分钟误差) exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS) assert abs((exp_time - expected_exp).total_seconds()) < 60 def test_create_refresh_token_type(self) -> None: """测试刷新令牌类型正确""" data = {"sub": "user123"} token = AuthService.create_refresh_token(data) payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) assert payload["type"] == "refresh" def test_create_refresh_token_longer_expiration(self) -> None: """测试刷新令牌过期时间更长""" data = {"sub": "user123"} access_token = AuthService.create_access_token(data) refresh_token = AuthService.create_refresh_token(data) access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) # 刷新令牌应该比访问令牌过期时间更长 assert refresh_payload["exp"] > access_payload["exp"] class TestJWTTokenVerification: """测试 JWT Token 验证""" @pytest.mark.asyncio async def test_verify_valid_access_token(self) -> None: """测试验证有效的访问令牌""" data = {"sub": "user123", "email": "test@example.com"} token = AuthService.create_access_token(data) with patch( "src.services.auth.service.JWTBlacklistService.is_blacklisted", new_callable=AsyncMock, return_value=False, ): payload = await AuthService.verify_token(token, token_type="access") assert payload["sub"] == "user123" assert payload["type"] == "access" @pytest.mark.asyncio async def test_verify_expired_token_raises_error(self) -> None: """测试验证过期令牌抛出异常""" # 创建一个已过期的 token data = {"sub": "user123", "type": "access"} expire = datetime.now(timezone.utc) - timedelta(hours=1) data["exp"] = expire expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: await AuthService.verify_token(expired_token) assert exc_info.value.status_code == 401 assert "过期" in exc_info.value.detail @pytest.mark.asyncio async def test_verify_invalid_token_raises_error(self) -> None: """测试验证无效令牌抛出异常""" from fastapi import HTTPException with pytest.raises(HTTPException) as exc_info: await AuthService.verify_token("invalid.token.here") assert exc_info.value.status_code == 401 @pytest.mark.asyncio async def test_verify_wrong_token_type_raises_error(self) -> None: """测试令牌类型不匹配抛出异常""" data = {"sub": "user123"} refresh_token = AuthService.create_refresh_token(data) from fastapi import HTTPException with patch( "src.services.auth.service.JWTBlacklistService.is_blacklisted", new_callable=AsyncMock, return_value=False, ): with pytest.raises(HTTPException) as exc_info: await AuthService.verify_token(refresh_token, token_type="access") assert exc_info.value.status_code == 401 assert "类型错误" in exc_info.value.detail @pytest.mark.asyncio async def test_verify_blacklisted_token_raises_error(self) -> None: """测试已撤销的令牌抛出异常""" data = {"sub": "user123"} token = AuthService.create_access_token(data) from fastapi import HTTPException with patch( "src.services.auth.service.JWTBlacklistService.is_blacklisted", new_callable=AsyncMock, return_value=True, ): with pytest.raises(HTTPException) as exc_info: await AuthService.verify_token(token) assert exc_info.value.status_code == 401 assert "撤销" in exc_info.value.detail class TestUserAuthentication: """测试用户登录认证""" @pytest.mark.asyncio async def test_authenticate_user_success(self) -> None: """测试用户登录成功""" # Mock 数据库和用户对象 mock_user = MagicMock() mock_user.id = "user-123" mock_user.email = "test@example.com" mock_user.is_active = True mock_user.verify_password.return_value = True mock_db = MagicMock() mock_db.query.return_value.filter.return_value.first.return_value = mock_user with patch( "src.services.auth.service.UserCacheService.invalidate_user_cache", new_callable=AsyncMock, ): result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123") assert result == mock_user mock_user.verify_password.assert_called_once_with("password123") mock_db.commit.assert_called_once() @pytest.mark.asyncio async def test_authenticate_user_not_found(self) -> None: """测试用户不存在""" mock_db = MagicMock() mock_db.query.return_value.filter.return_value.first.return_value = None result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password") assert result is None @pytest.mark.asyncio async def test_authenticate_user_wrong_password(self) -> None: """测试密码错误""" mock_user = MagicMock() mock_user.email = "test@example.com" mock_user.verify_password.return_value = False mock_db = MagicMock() mock_db.query.return_value.filter.return_value.first.return_value = mock_user result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword") assert result is None @pytest.mark.asyncio async def test_authenticate_user_inactive(self) -> None: """测试用户已禁用""" mock_user = MagicMock() mock_user.email = "test@example.com" mock_user.is_active = False mock_user.verify_password.return_value = True mock_db = MagicMock() mock_db.query.return_value.filter.return_value.first.return_value = mock_user result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123") assert result is None class TestAPIKeyAuthentication: """测试 API Key 认证""" def test_authenticate_api_key_success(self) -> None: """测试 API Key 认证成功""" mock_user = MagicMock() mock_user.id = "user-123" mock_user.email = "test@example.com" mock_user.is_active = True mock_api_key = MagicMock() mock_api_key.is_active = True mock_api_key.expires_at = None mock_api_key.user = mock_user mock_api_key.balance_used_usd = 0.0 mock_db = MagicMock() mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( mock_api_key ) with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): with patch( "src.services.auth.service.ApiKeyService.check_balance", return_value=(True, 100.0), ): result = AuthService.authenticate_api_key(mock_db, "sk-test-key") assert result is not None assert result[0] == mock_user assert result[1] == mock_api_key def test_authenticate_api_key_not_found(self) -> None: """测试 API Key 不存在""" mock_db = MagicMock() mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( None ) with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key") assert result is None def test_authenticate_api_key_inactive(self) -> None: """测试 API Key 已禁用""" mock_api_key = MagicMock() mock_api_key.is_active = False mock_db = MagicMock() mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( mock_api_key ) with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key") assert result is None def test_authenticate_api_key_expired(self) -> None: """测试 API Key 已过期""" mock_api_key = MagicMock() mock_api_key.is_active = True mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1) mock_db = MagicMock() mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( mock_api_key ) with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): result = AuthService.authenticate_api_key(mock_db, "sk-expired-key") assert result is None