mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor: migrate Pydantic Config to v2 ConfigDict
This commit is contained in:
363
tests/api/test_pipeline.py
Normal file
363
tests/api/test_pipeline.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
API Pipeline 测试
|
||||
|
||||
测试 ApiRequestPipeline 的核心功能:
|
||||
- 认证流程(API Key、JWT Token)
|
||||
- 配额计算
|
||||
- 审计日志记录
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
|
||||
|
||||
class TestPipelineQuotaCalculation:
|
||||
"""测试 Pipeline 配额计算"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试有配额限制时计算剩余配额"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 30.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining == 70.0
|
||||
|
||||
def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无配额限制时返回 None"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = None
|
||||
mock_user.used_usd = 30.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试负配额时返回 None"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = -1
|
||||
mock_user.used_usd = 0.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试配额已超时返回 0"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 150.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining == 0.0
|
||||
|
||||
def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试用户为 None 时返回 None"""
|
||||
remaining = pipeline._calculate_quota_remaining(None)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
|
||||
class TestPipelineAuditLogging:
|
||||
"""测试 Pipeline 审计日志"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试记录成功的审计事件"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
mock_adapter.audit_failure_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
with patch("time.time", return_value=1001.0):
|
||||
pipeline._record_audit_event(
|
||||
mock_context, mock_adapter, success=True, status_code=200
|
||||
)
|
||||
|
||||
mock_log.assert_called_once()
|
||||
call_kwargs = mock_log.call_args[1]
|
||||
assert call_kwargs["user_id"] == "user-123"
|
||||
assert call_kwargs["status_code"] == 200
|
||||
|
||||
def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试记录失败的审计事件"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
mock_adapter.audit_failure_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
with patch("time.time", return_value=1001.0):
|
||||
pipeline._record_audit_event(
|
||||
mock_context, mock_adapter, success=False, status_code=500, error="Internal error"
|
||||
)
|
||||
|
||||
mock_log.assert_called_once()
|
||||
call_kwargs = mock_log.call_args[1]
|
||||
assert call_kwargs["status_code"] == 500
|
||||
assert call_kwargs["error_message"] == "Internal error"
|
||||
|
||||
def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试没有数据库会话时跳过审计"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = None
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.audit_log_enabled = True
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
# 不应该抛出异常
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
# 不应该调用 log_event
|
||||
mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试审计日志被禁用时跳过"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.audit_log_enabled = False
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试审计日志异常不影响主流程"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
with patch("time.time", return_value=1001.0):
|
||||
# 不应该抛出异常
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
|
||||
class TestPipelineAuthentication:
|
||||
"""测试 Pipeline 认证相关逻辑"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试缺少 API Key 时抛出异常"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value=None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "API密钥" in exc_info.value.detail
|
||||
|
||||
def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无效的 API Key"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"Authorization": "Bearer sk-invalid"}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value="sk-invalid")
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"authenticate_api_key",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试配额超限时抛出异常"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 100.0
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.id = "key-123"
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"Authorization": "Bearer sk-test"}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value="sk-test")
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"authenticate_api_key",
|
||||
return_value=(mock_user, mock_api_key),
|
||||
):
|
||||
with patch.object(
|
||||
pipeline.usage_service,
|
||||
"check_user_quota",
|
||||
return_value=(False, "配额不足"),
|
||||
):
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
|
||||
with pytest.raises(QuotaExceededException):
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
|
||||
class TestPipelineAdminAuth:
|
||||
"""测试管理员认证"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试缺少管理员令牌"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "管理员凭证" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无效的管理员令牌"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "Bearer invalid-token"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
side_effect=HTTPException(status_code=401, detail="Invalid token"),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试管理员认证成功"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "admin-123"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "Bearer valid-token"}
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"user_id": "admin-123"},
|
||||
):
|
||||
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert result == mock_user
|
||||
assert mock_request.state.user_id == "admin-123"
|
||||
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""服务层测试"""
|
||||
299
tests/services/test_auth.py
Normal file
299
tests/services/test_auth.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
认证服务测试
|
||||
|
||||
测试 AuthService 的核心功能:
|
||||
- JWT Token 创建和验证
|
||||
- 用户登录认证
|
||||
- API Key 认证
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from src.services.auth.service import (
|
||||
AuthService,
|
||||
JWT_SECRET_KEY,
|
||||
JWT_ALGORITHM,
|
||||
JWT_EXPIRATION_HOURS,
|
||||
)
|
||||
|
||||
|
||||
class TestJWTTokenCreation:
|
||||
"""测试 JWT Token 创建"""
|
||||
|
||||
def test_create_access_token_contains_required_fields(self) -> None:
|
||||
"""测试访问令牌包含必要字段"""
|
||||
data = {"sub": "user123", "email": "test@example.com"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
# 解码验证
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "user123"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_create_access_token_expiration(self) -> None:
|
||||
"""测试访问令牌过期时间正确"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证过期时间在预期范围内(允许1分钟误差)
|
||||
exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
|
||||
|
||||
assert abs((exp_time - expected_exp).total_seconds()) < 60
|
||||
|
||||
def test_create_refresh_token_type(self) -> None:
|
||||
"""测试刷新令牌类型正确"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_refresh_token(data)
|
||||
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_create_refresh_token_longer_expiration(self) -> None:
|
||||
"""测试刷新令牌过期时间更长"""
|
||||
data = {"sub": "user123"}
|
||||
access_token = AuthService.create_access_token(data)
|
||||
refresh_token = AuthService.create_refresh_token(data)
|
||||
|
||||
access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 刷新令牌应该比访问令牌过期时间更长
|
||||
assert refresh_payload["exp"] > access_payload["exp"]
|
||||
|
||||
|
||||
class TestJWTTokenVerification:
|
||||
"""测试 JWT Token 验证"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_valid_access_token(self) -> None:
|
||||
"""测试验证有效的访问令牌"""
|
||||
data = {"sub": "user123", "email": "test@example.com"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
|
||||
assert payload["sub"] == "user123"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_expired_token_raises_error(self) -> None:
|
||||
"""测试验证过期令牌抛出异常"""
|
||||
# 创建一个已过期的 token
|
||||
data = {"sub": "user123", "type": "access"}
|
||||
expire = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
data["exp"] = expire
|
||||
expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(expired_token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "过期" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_token_raises_error(self) -> None:
|
||||
"""测试验证无效令牌抛出异常"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token("invalid.token.here")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_wrong_token_type_raises_error(self) -> None:
|
||||
"""测试令牌类型不匹配抛出异常"""
|
||||
data = {"sub": "user123"}
|
||||
refresh_token = AuthService.create_refresh_token(data)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(refresh_token, token_type="access")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "类型错误" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_blacklisted_token_raises_error(self) -> None:
|
||||
"""测试已撤销的令牌抛出异常"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "撤销" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestUserAuthentication:
|
||||
"""测试用户登录认证"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_success(self) -> None:
|
||||
"""测试用户登录成功"""
|
||||
# Mock 数据库和用户对象
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = True
|
||||
mock_user.verify_password.return_value = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.UserCacheService.invalidate_user_cache",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
|
||||
assert result == mock_user
|
||||
mock_user.verify_password.assert_called_once_with("password123")
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_not_found(self) -> None:
|
||||
"""测试用户不存在"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_wrong_password(self) -> None:
|
||||
"""测试密码错误"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.verify_password.return_value = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_inactive(self) -> None:
|
||||
"""测试用户已禁用"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = False
|
||||
mock_user.verify_password.return_value = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAPIKeyAuthentication:
|
||||
"""测试 API Key 认证"""
|
||||
|
||||
def test_authenticate_api_key_success(self) -> None:
|
||||
"""测试 API Key 认证成功"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = True
|
||||
mock_api_key.expires_at = None
|
||||
mock_api_key.user = mock_user
|
||||
mock_api_key.balance_used_usd = 0.0
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
with patch(
|
||||
"src.services.auth.service.ApiKeyService.check_balance",
|
||||
return_value=(True, 100.0),
|
||||
):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-test-key")
|
||||
|
||||
assert result is not None
|
||||
assert result[0] == mock_user
|
||||
assert result[1] == mock_api_key
|
||||
|
||||
def test_authenticate_api_key_not_found(self) -> None:
|
||||
"""测试 API Key 不存在"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_authenticate_api_key_inactive(self) -> None:
|
||||
"""测试 API Key 已禁用"""
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_authenticate_api_key_expired(self) -> None:
|
||||
"""测试 API Key 已过期"""
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = True
|
||||
mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-expired-key")
|
||||
|
||||
assert result is None
|
||||
292
tests/services/test_usage_service.py
Normal file
292
tests/services/test_usage_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
UsageService 测试
|
||||
|
||||
测试用量统计服务的核心功能:
|
||||
- 成本计算
|
||||
- 配额检查
|
||||
- 用量统计查询
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
class TestCostCalculation:
|
||||
"""测试成本计算"""
|
||||
|
||||
def test_calculate_cost_basic(self) -> None:
|
||||
"""测试基础成本计算"""
|
||||
# 价格:输入 $3/1M, 输出 $15/1M
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
)
|
||||
|
||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result
|
||||
|
||||
# 1000 tokens * $3 / 1M = $0.003
|
||||
assert abs(input_cost - 0.003) < 0.0001
|
||||
# 500 tokens * $15 / 1M = $0.0075
|
||||
assert abs(output_cost - 0.0075) < 0.0001
|
||||
# Total = $0.003 + $0.0075 = $0.0105
|
||||
assert abs(total_cost - 0.0105) < 0.0001
|
||||
|
||||
def test_calculate_cost_with_cache(self) -> None:
|
||||
"""测试带缓存的成本计算"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
cache_creation_input_tokens=200,
|
||||
cache_read_input_tokens=300,
|
||||
cache_creation_price_per_1m=3.75, # 1.25x input price
|
||||
cache_read_price_per_1m=0.3, # 0.1x input price
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
# 验证缓存成本被计算
|
||||
assert cache_creation_cost > 0
|
||||
assert cache_read_cost > 0
|
||||
assert cache_cost == cache_creation_cost + cache_read_cost
|
||||
|
||||
def test_calculate_cost_with_request_price(self) -> None:
|
||||
"""测试按次计费"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
price_per_request=0.01,
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
assert request_cost == 0.01
|
||||
# Total 包含 request_cost
|
||||
assert total_cost == input_cost + output_cost + request_cost
|
||||
|
||||
def test_calculate_cost_zero_tokens(self) -> None:
|
||||
"""测试零 token 的成本计算"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
assert input_cost == 0
|
||||
assert output_cost == 0
|
||||
assert total_cost == 0
|
||||
|
||||
|
||||
class TestQuotaCheck:
|
||||
"""测试配额检查"""
|
||||
|
||||
def test_check_user_quota_sufficient(self) -> None:
|
||||
"""测试配额充足"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 30.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_user_quota_exceeded(self) -> None:
|
||||
"""测试配额超限(当有预估成本时)"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 99.0 # 接近配额上限
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 当预估成本超过剩余配额时应该返回 False
|
||||
is_ok, message = UsageService.check_user_quota(
|
||||
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||
)
|
||||
|
||||
assert is_ok is False
|
||||
assert "配额" in message
|
||||
|
||||
def test_check_user_quota_no_limit(self) -> None:
|
||||
"""测试无配额限制(None)"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = None
|
||||
mock_user.used_usd = 1000.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_user_quota_admin_bypass(self) -> None:
|
||||
"""测试管理员绕过配额检查"""
|
||||
from src.models.database import UserRole
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 0.0
|
||||
mock_user.used_usd = 1000.0
|
||||
mock_user.role = UserRole.ADMIN
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_standalone_api_key_balance(self) -> None:
|
||||
"""测试独立 API Key 余额检查"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 0.0
|
||||
mock_user.used_usd = 0.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = True
|
||||
mock_api_key.current_balance_usd = 50.0
|
||||
mock_api_key.balance_used_usd = 10.0
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_standalone_api_key_insufficient_balance(self) -> None:
|
||||
"""测试独立 API Key 余额不足"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 0.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = True
|
||||
mock_api_key.current_balance_usd = 10.0
|
||||
mock_api_key.balance_used_usd = 9.0 # 剩余 $1
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 需要 mock ApiKeyService.get_remaining_balance
|
||||
with patch(
|
||||
"src.services.user.apikey.ApiKeyService.get_remaining_balance",
|
||||
return_value=1.0,
|
||||
):
|
||||
# 预估成本 $5 超过剩余余额 $1
|
||||
is_ok, message = UsageService.check_user_quota(
|
||||
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||
)
|
||||
|
||||
assert is_ok is False
|
||||
|
||||
|
||||
class TestUsageStatistics:
|
||||
"""测试用量统计查询
|
||||
|
||||
注意:get_usage_summary 方法内部使用了数据库方言特定的日期函数,
|
||||
需要真实数据库或更复杂的 mock。这里只测试方法存在性。
|
||||
"""
|
||||
|
||||
def test_get_usage_summary_exists(self) -> None:
|
||||
"""测试 get_usage_summary 方法存在"""
|
||||
assert hasattr(UsageService, "get_usage_summary")
|
||||
assert callable(getattr(UsageService, "get_usage_summary"))
|
||||
|
||||
|
||||
class TestHelperMethods:
|
||||
"""测试辅助方法"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
|
||||
"""测试默认费率倍数"""
|
||||
mock_db = MagicMock()
|
||||
# 模拟未找到 provider_api_key
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||
mock_db, provider_api_key_id=None, provider_id=None
|
||||
)
|
||||
|
||||
assert rate_multiplier == 1.0
|
||||
assert is_free_tier is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
|
||||
"""测试从 ProviderAPIKey 获取费率倍数"""
|
||||
mock_provider_api_key = MagicMock()
|
||||
mock_provider_api_key.rate_multiplier = 0.8
|
||||
|
||||
mock_endpoint = MagicMock()
|
||||
mock_endpoint.provider_id = "provider-123"
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.billing_type = "standard"
|
||||
|
||||
mock_db = MagicMock()
|
||||
# 第一次查询返回 provider_api_key
|
||||
mock_db.query.return_value.filter.return_value.first.side_effect = [
|
||||
mock_provider_api_key,
|
||||
mock_endpoint,
|
||||
mock_provider,
|
||||
]
|
||||
|
||||
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||
mock_db, provider_api_key_id="pak-123", provider_id=None
|
||||
)
|
||||
|
||||
assert rate_multiplier == 0.8
|
||||
assert is_free_tier is False
|
||||
Reference in New Issue
Block a user