Files
Aether/tests/api/test_pipeline.py
fawney19 7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00

422 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
API Pipeline 测试
测试 ApiRequestPipeline 的核心功能:
- 认证流程(API Key、JWT Token)
- 配额计算
- 审计日志记录
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import HTTPException
from src.api.base.pipeline import ApiRequestPipeline
class TestPipelineQuotaCalculation:
    """Tests for ApiRequestPipeline._calculate_quota_remaining."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _user(quota_usd, used_usd) -> MagicMock:
        """Return a mock user carrying the given ``quota_usd`` / ``used_usd`` values."""
        user = MagicMock()
        user.quota_usd = quota_usd
        user.used_usd = used_usd
        return user

    def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
        """With a quota set, the remaining amount is quota minus spend."""
        user = self._user(quota_usd=100.0, used_usd=30.0)
        assert pipeline._calculate_quota_remaining(user) == 70.0

    def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
        """Without a quota limit (``quota_usd`` is None) the result is None."""
        user = self._user(quota_usd=None, used_usd=30.0)
        assert pipeline._calculate_quota_remaining(user) is None

    def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
        """A negative quota also yields None."""
        user = self._user(quota_usd=-1, used_usd=0.0)
        assert pipeline._calculate_quota_remaining(user) is None

    def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """When spend exceeds the quota, the remaining amount is clamped to 0."""
        user = self._user(quota_usd=100.0, used_usd=150.0)
        assert pipeline._calculate_quota_remaining(user) == 0.0

    def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
        """A missing user (None) yields None."""
        assert pipeline._calculate_quota_remaining(None) is None
class TestPipelineAuditLogging:
    """Tests for ApiRequestPipeline._record_audit_event."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _make_context() -> MagicMock:
        """Build the populated request context shared by the audit tests.

        ``start_time`` (1000.0) pairs with the value ``time.time`` is patched
        to (1001.0) in the tests below.
        """
        context = MagicMock()
        context.db = MagicMock()
        context.user = MagicMock()
        context.user.id = "user-123"
        context.api_key = MagicMock()
        context.api_key.id = "key-123"
        context.request_id = "req-123"
        context.client_ip = "127.0.0.1"
        context.user_agent = "test-agent"
        context.request = MagicMock()
        context.request.method = "POST"
        context.request.url.path = "/v1/messages"
        context.start_time = 1000.0
        return context

    @staticmethod
    def _make_adapter() -> MagicMock:
        """Build an adapter with audit logging on and both audit event overrides set to None."""
        adapter = MagicMock()
        adapter.name = "test-adapter"
        adapter.audit_log_enabled = True
        adapter.audit_success_event = None
        adapter.audit_failure_event = None
        return adapter

    def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
        """A successful request is logged with the user id and status code."""
        context = self._make_context()
        adapter = self._make_adapter()

        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            with patch("time.time", return_value=1001.0):
                pipeline._record_audit_event(context, adapter, success=True, status_code=200)

        mock_log.assert_called_once()
        call_kwargs = mock_log.call_args.kwargs
        assert call_kwargs["user_id"] == "user-123"
        assert call_kwargs["status_code"] == 200

    def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
        """A failed request is logged with its status code and error message."""
        context = self._make_context()
        adapter = self._make_adapter()

        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            with patch("time.time", return_value=1001.0):
                pipeline._record_audit_event(
                    context, adapter, success=False, status_code=500, error="Internal error"
                )

        mock_log.assert_called_once()
        call_kwargs = mock_log.call_args.kwargs
        assert call_kwargs["status_code"] == 500
        assert call_kwargs["error_message"] == "Internal error"

    def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
        """Without a database session, auditing is skipped without raising."""
        context = MagicMock()
        context.db = None
        adapter = MagicMock()
        adapter.audit_log_enabled = True

        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            # Must not raise.
            pipeline._record_audit_event(context, adapter, success=True)

        # log_event must not be reached.
        mock_log.assert_not_called()

    def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
        """When the adapter disables audit logging, nothing is logged."""
        context = MagicMock()
        context.db = MagicMock()
        adapter = MagicMock()
        adapter.audit_log_enabled = False

        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            pipeline._record_audit_event(context, adapter, success=True)

        mock_log.assert_not_called()

    def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
        """An exception from the audit service does not escape to the caller."""
        context = self._make_context()
        adapter = self._make_adapter()

        with patch.object(
            pipeline.audit_service,
            "log_event",
            side_effect=Exception("DB error"),
        ) as mock_log:
            with patch("time.time", return_value=1001.0):
                # Must not raise: audit logging failures are swallowed.
                pipeline._record_audit_event(context, adapter, success=True)

        # Prove the failing call was actually made; without this the test would
        # also pass if the recorder skipped log_event entirely.
        mock_log.assert_called_once()
class TestPipelineAuthentication:
    """Tests for API-key client authentication in ApiRequestPipeline."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _build_request(headers: dict) -> MagicMock:
        """Mock request on the ``/v1/messages`` route carrying *headers*."""
        request = MagicMock()
        request.headers = headers
        request.url.path = "/v1/messages"
        request.state = MagicMock()
        return request

    @staticmethod
    def _build_adapter(extracted_key) -> MagicMock:
        """Mock adapter whose ``extract_api_key`` always returns *extracted_key*."""
        adapter = MagicMock()
        adapter.extract_api_key = MagicMock(return_value=extracted_key)
        return adapter

    def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
        """A request with no extractable API key is rejected with 401."""
        request = self._build_request({})
        adapter = self._build_adapter(None)

        with pytest.raises(HTTPException) as exc_info:
            pipeline._authenticate_client(request, MagicMock(), adapter)

        error = exc_info.value
        assert error.status_code == 401
        assert "API密钥" in error.detail

    def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
        """A key the auth service does not recognise is rejected with 401."""
        request = self._build_request({"Authorization": "Bearer sk-invalid"})
        adapter = self._build_adapter("sk-invalid")
        unknown_key = patch.object(pipeline.auth_service, "authenticate_api_key", return_value=None)

        with unknown_key, pytest.raises(HTTPException) as exc_info:
            pipeline._authenticate_client(request, MagicMock(), adapter)

        assert exc_info.value.status_code == 401

    def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """An authenticated user who fails the quota check gets QuotaExceededException."""
        from src.core.exceptions import QuotaExceededException

        user = MagicMock()
        user.id = "user-123"
        user.quota_usd = 100.0
        user.used_usd = 100.0

        api_key = MagicMock()
        api_key.id = "key-123"
        api_key.is_standalone = False

        request = self._build_request({"Authorization": "Bearer sk-test"})
        adapter = self._build_adapter("sk-test")

        auth_ok = patch.object(
            pipeline.auth_service,
            "authenticate_api_key",
            return_value=(user, api_key),
        )
        quota_denied = patch.object(
            pipeline.usage_service,
            "check_user_quota",
            return_value=(False, "配额不足"),
        )
        with auth_ok, quota_denied:
            with pytest.raises(QuotaExceededException):
                pipeline._authenticate_client(request, MagicMock(), adapter)
class TestPipelineAdminAuth:
    """Tests for admin JWT authentication in ApiRequestPipeline."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _request_with_headers(headers: dict) -> MagicMock:
        """Mock request carrying *headers* and a mock ``state`` namespace."""
        request = MagicMock()
        request.headers = headers
        request.state = MagicMock()
        return request

    @staticmethod
    def _db_returning(user) -> MagicMock:
        """Mock session whose ``query(...).filter(...).first()`` chain returns *user*."""
        db = MagicMock()
        db.query.return_value.filter.return_value.first.return_value = user
        return db

    @staticmethod
    def _make_admin() -> MagicMock:
        """Mock active user with id ``admin-123``."""
        admin = MagicMock()
        admin.id = "admin-123"
        admin.is_active = True
        return admin

    @pytest.mark.asyncio
    async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
        """A request without an Authorization header is rejected with 401."""
        request = MagicMock()
        request.headers = {}

        with pytest.raises(HTTPException) as exc_info:
            await pipeline._authenticate_admin(request, MagicMock())

        assert exc_info.value.status_code == 401
        assert "管理员凭证" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
        """A token rejected by verify_token surfaces as a 401."""
        request = self._request_with_headers({"authorization": "Bearer invalid-token"})

        # verify_token is awaited by the pipeline (see the sibling tests), so patch
        # it explicitly as an AsyncMock rather than relying on auto-detection.
        with patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            side_effect=HTTPException(status_code=401, detail="Invalid token"),
        ) as mock_verify:
            with pytest.raises(HTTPException) as exc_info:
                await pipeline._authenticate_admin(request, MagicMock())

        assert exc_info.value.status_code == 401
        # Pin that the 401 came from token verification, not from header parsing.
        mock_verify.assert_awaited_once_with("invalid-token", token_type="access")

    @pytest.mark.asyncio
    async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
        """A valid token resolves to the admin user and records user_id on request.state."""
        admin = self._make_admin()
        request = self._request_with_headers({"authorization": "Bearer valid-token"})

        with patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "admin-123"},
        ) as mock_verify:
            result = await pipeline._authenticate_admin(request, self._db_returning(admin))

        mock_verify.assert_awaited_once_with("valid-token", token_type="access")
        assert result == admin
        assert request.state.user_id == "admin-123"

    @pytest.mark.asyncio
    async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
        """A lowercase ``bearer`` prefix is parsed the same as ``Bearer``."""
        admin = self._make_admin()
        request = self._request_with_headers({"authorization": "bearer valid-token"})

        with patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "admin-123"},
        ) as mock_verify:
            result = await pipeline._authenticate_admin(request, self._db_returning(admin))

        mock_verify.assert_awaited_once_with("valid-token", token_type="access")
        assert result == admin
class TestPipelineUserAuth:
    """Tests for regular-user JWT authentication in ApiRequestPipeline."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @pytest.mark.asyncio
    async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
        """A lowercase ``bearer`` scheme prefix is still parsed correctly."""
        user = MagicMock()
        user.id, user.is_active = "user-123", True

        # The pipeline looks the user up via query(...).filter(...).first().
        db = MagicMock()
        lookup = db.query.return_value.filter.return_value
        lookup.first.return_value = user

        request = MagicMock()
        request.headers = {"authorization": "bearer valid-token"}
        request.state = MagicMock()

        verify_patch = patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "user-123"},
        )
        with verify_patch as verify:
            outcome = await pipeline._authenticate_user(request, db)

        verify.assert_awaited_once_with("valid-token", token_type="access")
        assert outcome == user