refactor: migrate Pydantic Config to v2 ConfigDict

This commit is contained in:
fawney19
2025-12-18 02:20:53 +08:00
parent b2a857c164
commit 3d0ab353d3
10 changed files with 970 additions and 25 deletions

View File

@@ -0,0 +1 @@
"""服务层测试"""

299
tests/services/test_auth.py Normal file
View File

@@ -0,0 +1,299 @@
"""
认证服务测试
测试 AuthService 的核心功能:
- JWT Token 创建和验证
- 用户登录认证
- API Key 认证
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import jwt
from src.services.auth.service import (
AuthService,
JWT_SECRET_KEY,
JWT_ALGORITHM,
JWT_EXPIRATION_HOURS,
)
class TestJWTTokenCreation:
"""测试 JWT Token 创建"""
def test_create_access_token_contains_required_fields(self) -> None:
"""测试访问令牌包含必要字段"""
data = {"sub": "user123", "email": "test@example.com"}
token = AuthService.create_access_token(data)
# 解码验证
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
assert payload["sub"] == "user123"
assert payload["email"] == "test@example.com"
assert payload["type"] == "access"
assert "exp" in payload
def test_create_access_token_expiration(self) -> None:
"""测试访问令牌过期时间正确"""
data = {"sub": "user123"}
token = AuthService.create_access_token(data)
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
# 验证过期时间在预期范围内允许1分钟误差
exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
assert abs((exp_time - expected_exp).total_seconds()) < 60
def test_create_refresh_token_type(self) -> None:
"""测试刷新令牌类型正确"""
data = {"sub": "user123"}
token = AuthService.create_refresh_token(data)
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
assert payload["type"] == "refresh"
def test_create_refresh_token_longer_expiration(self) -> None:
"""测试刷新令牌过期时间更长"""
data = {"sub": "user123"}
access_token = AuthService.create_access_token(data)
refresh_token = AuthService.create_refresh_token(data)
access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
# 刷新令牌应该比访问令牌过期时间更长
assert refresh_payload["exp"] > access_payload["exp"]
class TestJWTTokenVerification:
"""测试 JWT Token 验证"""
@pytest.mark.asyncio
async def test_verify_valid_access_token(self) -> None:
"""测试验证有效的访问令牌"""
data = {"sub": "user123", "email": "test@example.com"}
token = AuthService.create_access_token(data)
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=False,
):
payload = await AuthService.verify_token(token, token_type="access")
assert payload["sub"] == "user123"
assert payload["type"] == "access"
@pytest.mark.asyncio
async def test_verify_expired_token_raises_error(self) -> None:
"""测试验证过期令牌抛出异常"""
# 创建一个已过期的 token
data = {"sub": "user123", "type": "access"}
expire = datetime.now(timezone.utc) - timedelta(hours=1)
data["exp"] = expire
expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(expired_token)
assert exc_info.value.status_code == 401
assert "过期" in exc_info.value.detail
@pytest.mark.asyncio
async def test_verify_invalid_token_raises_error(self) -> None:
"""测试验证无效令牌抛出异常"""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token("invalid.token.here")
assert exc_info.value.status_code == 401
@pytest.mark.asyncio
async def test_verify_wrong_token_type_raises_error(self) -> None:
"""测试令牌类型不匹配抛出异常"""
data = {"sub": "user123"}
refresh_token = AuthService.create_refresh_token(data)
from fastapi import HTTPException
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=False,
):
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(refresh_token, token_type="access")
assert exc_info.value.status_code == 401
assert "类型错误" in exc_info.value.detail
@pytest.mark.asyncio
async def test_verify_blacklisted_token_raises_error(self) -> None:
"""测试已撤销的令牌抛出异常"""
data = {"sub": "user123"}
token = AuthService.create_access_token(data)
from fastapi import HTTPException
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=True,
):
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(token)
assert exc_info.value.status_code == 401
assert "撤销" in exc_info.value.detail
class TestUserAuthentication:
"""测试用户登录认证"""
@pytest.mark.asyncio
async def test_authenticate_user_success(self) -> None:
"""测试用户登录成功"""
# Mock 数据库和用户对象
mock_user = MagicMock()
mock_user.id = "user-123"
mock_user.email = "test@example.com"
mock_user.is_active = True
mock_user.verify_password.return_value = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
with patch(
"src.services.auth.service.UserCacheService.invalidate_user_cache",
new_callable=AsyncMock,
):
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
assert result == mock_user
mock_user.verify_password.assert_called_once_with("password123")
mock_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_authenticate_user_not_found(self) -> None:
"""测试用户不存在"""
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = None
result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password")
assert result is None
@pytest.mark.asyncio
async def test_authenticate_user_wrong_password(self) -> None:
"""测试密码错误"""
mock_user = MagicMock()
mock_user.email = "test@example.com"
mock_user.verify_password.return_value = False
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword")
assert result is None
@pytest.mark.asyncio
async def test_authenticate_user_inactive(self) -> None:
"""测试用户已禁用"""
mock_user = MagicMock()
mock_user.email = "test@example.com"
mock_user.is_active = False
mock_user.verify_password.return_value = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
assert result is None
class TestAPIKeyAuthentication:
"""测试 API Key 认证"""
def test_authenticate_api_key_success(self) -> None:
"""测试 API Key 认证成功"""
mock_user = MagicMock()
mock_user.id = "user-123"
mock_user.email = "test@example.com"
mock_user.is_active = True
mock_api_key = MagicMock()
mock_api_key.is_active = True
mock_api_key.expires_at = None
mock_api_key.user = mock_user
mock_api_key.balance_used_usd = 0.0
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
with patch(
"src.services.auth.service.ApiKeyService.check_balance",
return_value=(True, 100.0),
):
result = AuthService.authenticate_api_key(mock_db, "sk-test-key")
assert result is not None
assert result[0] == mock_user
assert result[1] == mock_api_key
def test_authenticate_api_key_not_found(self) -> None:
"""测试 API Key 不存在"""
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
None
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key")
assert result is None
def test_authenticate_api_key_inactive(self) -> None:
"""测试 API Key 已禁用"""
mock_api_key = MagicMock()
mock_api_key.is_active = False
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key")
assert result is None
def test_authenticate_api_key_expired(self) -> None:
"""测试 API Key 已过期"""
mock_api_key = MagicMock()
mock_api_key.is_active = True
mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-expired-key")
assert result is None

View File

@@ -0,0 +1,292 @@
"""
UsageService 测试
测试用量统计服务的核心功能:
- 成本计算
- 配额检查
- 用量统计查询
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from src.services.usage.service import UsageService
class TestCostCalculation:
"""测试成本计算"""
def test_calculate_cost_basic(self) -> None:
"""测试基础成本计算"""
# 价格:输入 $3/1M, 输出 $15/1M
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
)
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result
# 1000 tokens * $3 / 1M = $0.003
assert abs(input_cost - 0.003) < 0.0001
# 500 tokens * $15 / 1M = $0.0075
assert abs(output_cost - 0.0075) < 0.0001
# Total = $0.003 + $0.0075 = $0.0105
assert abs(total_cost - 0.0105) < 0.0001
def test_calculate_cost_with_cache(self) -> None:
"""测试带缓存的成本计算"""
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_input_tokens=200,
cache_read_input_tokens=300,
cache_creation_price_per_1m=3.75, # 1.25x input price
cache_read_price_per_1m=0.3, # 0.1x input price
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
# 验证缓存成本被计算
assert cache_creation_cost > 0
assert cache_read_cost > 0
assert cache_cost == cache_creation_cost + cache_read_cost
def test_calculate_cost_with_request_price(self) -> None:
"""测试按次计费"""
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
price_per_request=0.01,
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
assert request_cost == 0.01
# Total 包含 request_cost
assert total_cost == input_cost + output_cost + request_cost
def test_calculate_cost_zero_tokens(self) -> None:
"""测试零 token 的成本计算"""
result = UsageService.calculate_cost(
input_tokens=0,
output_tokens=0,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
assert input_cost == 0
assert output_cost == 0
assert total_cost == 0
class TestQuotaCheck:
"""测试配额检查"""
def test_check_user_quota_sufficient(self) -> None:
"""测试配额充足"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 30.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_user_quota_exceeded(self) -> None:
"""测试配额超限(当有预估成本时)"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 99.0 # 接近配额上限
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
# 当预估成本超过剩余配额时应该返回 False
is_ok, message = UsageService.check_user_quota(
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
)
assert is_ok is False
assert "配额" in message
def test_check_user_quota_no_limit(self) -> None:
"""测试无配额限制None"""
mock_user = MagicMock()
mock_user.quota_usd = None
mock_user.used_usd = 1000.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_user_quota_admin_bypass(self) -> None:
"""测试管理员绕过配额检查"""
from src.models.database import UserRole
mock_user = MagicMock()
mock_user.quota_usd = 0.0
mock_user.used_usd = 1000.0
mock_user.role = UserRole.ADMIN
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_standalone_api_key_balance(self) -> None:
"""测试独立 API Key 余额检查"""
mock_user = MagicMock()
mock_user.quota_usd = 0.0
mock_user.used_usd = 0.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = True
mock_api_key.current_balance_usd = 50.0
mock_api_key.balance_used_usd = 10.0
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_standalone_api_key_insufficient_balance(self) -> None:
"""测试独立 API Key 余额不足"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 0.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = True
mock_api_key.current_balance_usd = 10.0
mock_api_key.balance_used_usd = 9.0 # 剩余 $1
mock_db = MagicMock()
# 需要 mock ApiKeyService.get_remaining_balance
with patch(
"src.services.user.apikey.ApiKeyService.get_remaining_balance",
return_value=1.0,
):
# 预估成本 $5 超过剩余余额 $1
is_ok, message = UsageService.check_user_quota(
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
)
assert is_ok is False
class TestUsageStatistics:
"""测试用量统计查询
注意get_usage_summary 方法内部使用了数据库方言特定的日期函数,
需要真实数据库或更复杂的 mock。这里只测试方法存在性。
"""
def test_get_usage_summary_exists(self) -> None:
"""测试 get_usage_summary 方法存在"""
assert hasattr(UsageService, "get_usage_summary")
assert callable(getattr(UsageService, "get_usage_summary"))
class TestHelperMethods:
"""测试辅助方法"""
@pytest.mark.asyncio
async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
"""测试默认费率倍数"""
mock_db = MagicMock()
# 模拟未找到 provider_api_key
mock_db.query.return_value.filter.return_value.first.return_value = None
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
mock_db, provider_api_key_id=None, provider_id=None
)
assert rate_multiplier == 1.0
assert is_free_tier is False
@pytest.mark.asyncio
async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
"""测试从 ProviderAPIKey 获取费率倍数"""
mock_provider_api_key = MagicMock()
mock_provider_api_key.rate_multiplier = 0.8
mock_endpoint = MagicMock()
mock_endpoint.provider_id = "provider-123"
mock_provider = MagicMock()
mock_provider.billing_type = "standard"
mock_db = MagicMock()
# 第一次查询返回 provider_api_key
mock_db.query.return_value.filter.return_value.first.side_effect = [
mock_provider_api_key,
mock_endpoint,
mock_provider,
]
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
mock_db, provider_api_key_id="pak-123", provider_id=None
)
assert rate_multiplier == 0.8
assert is_free_tier is False