From 3d0ab353d3a0080172e62a7e83790f12ce703688 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Thu, 18 Dec 2025 02:20:53 +0800 Subject: [PATCH] refactor: migrate Pydantic Config to v2 ConfigDict --- src/api/admin/monitoring/trace.py | 5 +- src/models/api.py | 11 +- src/models/api_key.py | 5 +- src/models/database.py | 3 +- src/models/endpoint_models.py | 11 +- src/models/pydantic_models.py | 5 +- tests/api/test_pipeline.py | 363 +++++++++++++++++++++++++++ tests/services/__init__.py | 1 + tests/services/test_auth.py | 299 ++++++++++++++++++++++ tests/services/test_usage_service.py | 292 +++++++++++++++++++++ 10 files changed, 970 insertions(+), 25 deletions(-) create mode 100644 tests/api/test_pipeline.py create mode 100644 tests/services/__init__.py create mode 100644 tests/services/test_auth.py create mode 100644 tests/services/test_usage_service.py diff --git a/src/api/admin/monitoring/trace.py b/src/api/admin/monitoring/trace.py index de89a26..2c47383 100644 --- a/src/api/admin/monitoring/trace.py +++ b/src/api/admin/monitoring/trace.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Request -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from sqlalchemy.orm import Session from src.api.base.admin_adapter import AdminApiAdapter @@ -52,8 +52,7 @@ class CandidateResponse(BaseModel): started_at: Optional[datetime] = None finished_at: Optional[datetime] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class RequestTraceResponse(BaseModel): diff --git a/src/models/api.py b/src/models/api.py index 0b7968b..70c357b 100644 --- a/src/models/api.py +++ b/src/models/api.py @@ -6,7 +6,7 @@ import re from datetime import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from ..core.enums import UserRole @@ -336,8 +336,7 @@ class ProviderResponse(BaseModel): active_models_count: int = 0 api_keys_count: int = 0 - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ========== 模型管理 ========== @@ -442,8 +441,7 @@ class ModelResponse(BaseModel): global_model_name: Optional[str] = None global_model_display_name: Optional[str] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ModelDetailResponse(BaseModel): @@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ========== 系统设置 ========== diff --git a/src/models/api_key.py b/src/models/api_key.py index 7191dc7..7d50a6f 100644 --- a/src/models/api_key.py +++ b/src/models/api_key.py @@ -5,7 +5,7 @@ Provider API Key相关的API模型 from datetime import datetime from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class ProviderAPIKeyBase(BaseModel): @@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ProviderAPIKeyStats(BaseModel): diff --git a/src/models/database.py b/src/models/database.py index b8fd53f..e0027d9 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -27,8 +27,7 @@ from sqlalchemy import ( UniqueConstraint, ) from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy.orm import declarative_base, relationship from ..config import config from ..core.enums import ProviderBillingType, UserRole diff --git a/src/models/endpoint_models.py b/src/models/endpoint_models.py index 1a38ff2..c9bd0f2 100644 --- a/src/models/endpoint_models.py +++ b/src/models/endpoint_models.py @@ -6,7 +6,7 @@ import re from datetime import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator # ========== ProviderEndpoint CRUD ========== @@ -141,8 +141,7 @@ class ProviderEndpointResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ========== ProviderAPIKey 相关(新架构) ========== @@ -384,8 +383,7 @@ class EndpointAPIKeyResponse(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ========== 健康监控相关 ========== @@ -535,8 +533,7 @@ class ProviderWithEndpointsSummary(BaseModel): created_at: datetime updated_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ========== 健康监控可视化模型 ========== diff --git a/src/models/pydantic_models.py b/src/models/pydantic_models.py index c632715..3625a0a 100644 --- a/src/models/pydantic_models.py +++ b/src/models/pydantic_models.py @@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理) from datetime import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator # ========== 阶梯计费相关模型 ========== @@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel): created_at: datetime updated_at: Optional[datetime] - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class GlobalModelWithStats(GlobalModelResponse): diff --git a/tests/api/test_pipeline.py b/tests/api/test_pipeline.py new file mode 100644 index 0000000..8e9afb8 --- /dev/null +++ b/tests/api/test_pipeline.py @@ -0,0 +1,363 @@ +""" +API Pipeline 测试 + +测试 ApiRequestPipeline 的核心功能: +- 认证流程(API Key、JWT Token) +- 配额计算 +- 审计日志记录 +""" + +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import HTTPException + +from src.api.base.pipeline import ApiRequestPipeline + + +class TestPipelineQuotaCalculation: + """测试 Pipeline 配额计算""" + + @pytest.fixture + def pipeline(self) -> ApiRequestPipeline: + return ApiRequestPipeline() + + def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None: + """测试有配额限制时计算剩余配额""" + mock_user = MagicMock() + mock_user.quota_usd = 100.0 + mock_user.used_usd = 30.0 + + remaining = pipeline._calculate_quota_remaining(mock_user) + + assert remaining == 70.0 + + def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None: + """测试无配额限制时返回 None""" + mock_user = MagicMock() + mock_user.quota_usd = None + mock_user.used_usd = 30.0 + + remaining = pipeline._calculate_quota_remaining(mock_user) + + assert remaining is None + + def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None: + """测试负配额时返回 None""" + mock_user = MagicMock() + mock_user.quota_usd = -1 + mock_user.used_usd = 0.0 + + remaining = pipeline._calculate_quota_remaining(mock_user) + + assert remaining is None + + def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None: + """测试配额已超时返回 0""" + mock_user = MagicMock() + mock_user.quota_usd = 100.0 + mock_user.used_usd = 150.0 + + remaining = pipeline._calculate_quota_remaining(mock_user) + + assert remaining == 0.0 + + def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None: + """测试用户为 None 时返回 None""" + remaining = pipeline._calculate_quota_remaining(None) + + assert remaining is None + + +class TestPipelineAuditLogging: + """测试 Pipeline 审计日志""" + + @pytest.fixture + def pipeline(self) -> ApiRequestPipeline: + return ApiRequestPipeline() + + def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None: + """测试记录成功的审计事件""" + mock_context = MagicMock() + mock_context.db = MagicMock() + mock_context.user = MagicMock() + mock_context.user.id = "user-123" + mock_context.api_key = MagicMock() + mock_context.api_key.id = "key-123" + mock_context.request_id = "req-123" + mock_context.client_ip = "127.0.0.1" + mock_context.user_agent = "test-agent" + mock_context.request = MagicMock() + mock_context.request.method = "POST" + mock_context.request.url.path = "/v1/messages" + mock_context.start_time = 1000.0 + + mock_adapter = MagicMock() + mock_adapter.name = "test-adapter" + mock_adapter.audit_log_enabled = True + mock_adapter.audit_success_event = None + mock_adapter.audit_failure_event = None + + with patch.object( + pipeline.audit_service, + "log_event", + ) as mock_log: + with patch("time.time", return_value=1001.0): + pipeline._record_audit_event( + mock_context, mock_adapter, success=True, status_code=200 + ) + + mock_log.assert_called_once() + call_kwargs = mock_log.call_args[1] + assert call_kwargs["user_id"] == "user-123" + assert call_kwargs["status_code"] == 200 + + def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None: + """测试记录失败的审计事件""" + mock_context = MagicMock() + mock_context.db = MagicMock() + mock_context.user = MagicMock() + mock_context.user.id = "user-123" + mock_context.api_key = MagicMock() + mock_context.api_key.id = "key-123" + mock_context.request_id = "req-123" + mock_context.client_ip = "127.0.0.1" + mock_context.user_agent = "test-agent" + mock_context.request = MagicMock() + mock_context.request.method = "POST" + mock_context.request.url.path = "/v1/messages" + mock_context.start_time = 1000.0 + + mock_adapter = MagicMock() + mock_adapter.name = "test-adapter" + mock_adapter.audit_log_enabled = True + mock_adapter.audit_success_event = None + mock_adapter.audit_failure_event = None + + with patch.object( + pipeline.audit_service, + "log_event", + ) as mock_log: + with patch("time.time", return_value=1001.0): + pipeline._record_audit_event( + mock_context, mock_adapter, success=False, status_code=500, error="Internal error" + ) + + mock_log.assert_called_once() + call_kwargs = mock_log.call_args[1] + assert call_kwargs["status_code"] == 500 + assert call_kwargs["error_message"] == "Internal error" + + def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None: + """测试没有数据库会话时跳过审计""" + mock_context = MagicMock() + mock_context.db = None + + mock_adapter = MagicMock() + mock_adapter.audit_log_enabled = True + + with patch.object( + pipeline.audit_service, + "log_event", + ) as mock_log: + # 不应该抛出异常 + pipeline._record_audit_event(mock_context, mock_adapter, success=True) + + # 不应该调用 log_event + mock_log.assert_not_called() + + def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None: + """测试审计日志被禁用时跳过""" + mock_context = MagicMock() + mock_context.db = MagicMock() + + mock_adapter = MagicMock() + mock_adapter.audit_log_enabled = False + + with patch.object( + pipeline.audit_service, + "log_event", + ) as mock_log: + pipeline._record_audit_event(mock_context, mock_adapter, success=True) + + mock_log.assert_not_called() + + def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None: + """测试审计日志异常不影响主流程""" + mock_context = MagicMock() + mock_context.db = MagicMock() + mock_context.user = MagicMock() + mock_context.user.id = "user-123" + mock_context.api_key = MagicMock() + mock_context.api_key.id = "key-123" + mock_context.request_id = "req-123" + mock_context.client_ip = "127.0.0.1" + mock_context.user_agent = "test-agent" + mock_context.request = MagicMock() + mock_context.request.method = "POST" + mock_context.request.url.path = "/v1/messages" + mock_context.start_time = 1000.0 + + mock_adapter = MagicMock() + mock_adapter.name = "test-adapter" + mock_adapter.audit_log_enabled = True + mock_adapter.audit_success_event = None + + with patch.object( + pipeline.audit_service, + "log_event", + side_effect=Exception("DB error"), + ): + with patch("time.time", return_value=1001.0): + # 不应该抛出异常 + pipeline._record_audit_event(mock_context, mock_adapter, success=True) + + +class TestPipelineAuthentication: + """测试 Pipeline 认证相关逻辑""" + + @pytest.fixture + def pipeline(self) -> ApiRequestPipeline: + return ApiRequestPipeline() + + def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None: + """测试缺少 API Key 时抛出异常""" + mock_request = MagicMock() + mock_request.headers = {} + mock_request.url.path = "/v1/messages" + mock_request.state = MagicMock() + + mock_db = MagicMock() + + mock_adapter = MagicMock() + mock_adapter.extract_api_key = MagicMock(return_value=None) + + with pytest.raises(HTTPException) as exc_info: + pipeline._authenticate_client(mock_request, mock_db, mock_adapter) + + assert exc_info.value.status_code == 401 + assert "API密钥" in exc_info.value.detail + + def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None: + """测试无效的 API Key""" + mock_request = MagicMock() + mock_request.headers = {"Authorization": "Bearer sk-invalid"} + mock_request.url.path = "/v1/messages" + mock_request.state = MagicMock() + + mock_db = MagicMock() + + mock_adapter = MagicMock() + mock_adapter.extract_api_key = MagicMock(return_value="sk-invalid") + + with patch.object( + pipeline.auth_service, + "authenticate_api_key", + return_value=None, + ): + with pytest.raises(HTTPException) as exc_info: + pipeline._authenticate_client(mock_request, mock_db, mock_adapter) + + assert exc_info.value.status_code == 401 + + def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None: + """测试配额超限时抛出异常""" + mock_user = MagicMock() + mock_user.id = "user-123" + mock_user.quota_usd = 100.0 + mock_user.used_usd = 100.0 + + mock_api_key = MagicMock() + mock_api_key.id = "key-123" + mock_api_key.is_standalone = False + + mock_request = MagicMock() + mock_request.headers = {"Authorization": "Bearer sk-test"} + mock_request.url.path = "/v1/messages" + mock_request.state = MagicMock() + + mock_db = MagicMock() + + mock_adapter = MagicMock() + mock_adapter.extract_api_key = MagicMock(return_value="sk-test") + + with patch.object( + pipeline.auth_service, + "authenticate_api_key", + return_value=(mock_user, mock_api_key), + ): + with patch.object( + pipeline.usage_service, + "check_user_quota", + return_value=(False, "配额不足"), + ): + from src.core.exceptions import QuotaExceededException + + with pytest.raises(QuotaExceededException): + pipeline._authenticate_client(mock_request, mock_db, mock_adapter) + + +class TestPipelineAdminAuth: + """测试管理员认证""" + + @pytest.fixture + def pipeline(self) -> ApiRequestPipeline: + return ApiRequestPipeline() + + @pytest.mark.asyncio + async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None: + """测试缺少管理员令牌""" + mock_request = MagicMock() + mock_request.headers = {} + + mock_db = MagicMock() + + with pytest.raises(HTTPException) as exc_info: + await pipeline._authenticate_admin(mock_request, mock_db) + + assert exc_info.value.status_code == 401 + assert "管理员凭证" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None: + """测试无效的管理员令牌""" + mock_request = MagicMock() + mock_request.headers = {"authorization": "Bearer invalid-token"} + + mock_db = MagicMock() + + with patch.object( + pipeline.auth_service, + "verify_token", + side_effect=HTTPException(status_code=401, detail="Invalid token"), + ): + with pytest.raises(HTTPException) as exc_info: + await pipeline._authenticate_admin(mock_request, mock_db) + + assert exc_info.value.status_code == 401 + + @pytest.mark.asyncio + async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None: + """测试管理员认证成功""" + mock_user = MagicMock() + mock_user.id = "admin-123" + mock_user.is_active = True + + mock_request = MagicMock() + mock_request.headers = {"authorization": "Bearer valid-token"} + mock_request.state = MagicMock() + + mock_db = MagicMock() + mock_db.query.return_value.filter.return_value.first.return_value = mock_user + + with patch.object( + pipeline.auth_service, + "verify_token", + new_callable=AsyncMock, + return_value={"user_id": "admin-123"}, + ): + result = await pipeline._authenticate_admin(mock_request, mock_db) + + assert result == mock_user + assert mock_request.state.user_id == "admin-123" diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 0000000..70bb6e5 --- /dev/null +++ b/tests/services/__init__.py @@ -0,0 +1 @@ +"""服务层测试""" diff --git a/tests/services/test_auth.py b/tests/services/test_auth.py new file mode 100644 index 0000000..6086726 --- /dev/null +++ b/tests/services/test_auth.py @@ -0,0 +1,299 @@ +""" +认证服务测试 + +测试 AuthService 的核心功能: +- JWT Token 创建和验证 +- 用户登录认证 +- API Key 认证 +""" + +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import jwt + +from src.services.auth.service import ( + AuthService, + JWT_SECRET_KEY, + JWT_ALGORITHM, + JWT_EXPIRATION_HOURS, +) + + +class TestJWTTokenCreation: + """测试 JWT Token 创建""" + + def test_create_access_token_contains_required_fields(self) -> None: + """测试访问令牌包含必要字段""" + data = {"sub": "user123", "email": "test@example.com"} + token = AuthService.create_access_token(data) + + # 解码验证 + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + + assert payload["sub"] == "user123" + assert payload["email"] == "test@example.com" + assert payload["type"] == "access" + assert "exp" in payload + + def test_create_access_token_expiration(self) -> None: + """测试访问令牌过期时间正确""" + data = {"sub": "user123"} + token = AuthService.create_access_token(data) + + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + + # 验证过期时间在预期范围内(允许1分钟误差) + exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc) + expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS) + + assert abs((exp_time - expected_exp).total_seconds()) < 60 + + def test_create_refresh_token_type(self) -> None: + """测试刷新令牌类型正确""" + data = {"sub": "user123"} + token = AuthService.create_refresh_token(data) + + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + + assert payload["type"] == "refresh" + + def test_create_refresh_token_longer_expiration(self) -> None: + """测试刷新令牌过期时间更长""" + data = {"sub": "user123"} + access_token = AuthService.create_access_token(data) + refresh_token = AuthService.create_refresh_token(data) + + access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) + + # 刷新令牌应该比访问令牌过期时间更长 + assert refresh_payload["exp"] > access_payload["exp"] + + +class TestJWTTokenVerification: + """测试 JWT Token 验证""" + + @pytest.mark.asyncio + async def test_verify_valid_access_token(self) -> None: + """测试验证有效的访问令牌""" + data = {"sub": "user123", "email": "test@example.com"} + token = AuthService.create_access_token(data) + + with patch( + "src.services.auth.service.JWTBlacklistService.is_blacklisted", + new_callable=AsyncMock, + return_value=False, + ): + payload = await AuthService.verify_token(token, token_type="access") + + assert payload["sub"] == "user123" + assert payload["type"] == "access" + + @pytest.mark.asyncio + async def test_verify_expired_token_raises_error(self) -> None: + """测试验证过期令牌抛出异常""" + # 创建一个已过期的 token + data = {"sub": "user123", "type": "access"} + expire = datetime.now(timezone.utc) - timedelta(hours=1) + data["exp"] = expire + expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) + + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + await AuthService.verify_token(expired_token) + + assert exc_info.value.status_code == 401 + assert "过期" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_verify_invalid_token_raises_error(self) -> None: + """测试验证无效令牌抛出异常""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + await AuthService.verify_token("invalid.token.here") + + assert exc_info.value.status_code == 401 + + @pytest.mark.asyncio + async def test_verify_wrong_token_type_raises_error(self) -> None: + """测试令牌类型不匹配抛出异常""" + data = {"sub": "user123"} + refresh_token = AuthService.create_refresh_token(data) + + from fastapi import HTTPException + + with patch( + "src.services.auth.service.JWTBlacklistService.is_blacklisted", + new_callable=AsyncMock, + return_value=False, + ): + with pytest.raises(HTTPException) as exc_info: + await AuthService.verify_token(refresh_token, token_type="access") + + assert exc_info.value.status_code == 401 + assert "类型错误" in exc_info.value.detail + + @pytest.mark.asyncio + async def test_verify_blacklisted_token_raises_error(self) -> None: + """测试已撤销的令牌抛出异常""" + data = {"sub": "user123"} + token = AuthService.create_access_token(data) + + from fastapi import HTTPException + + with patch( + "src.services.auth.service.JWTBlacklistService.is_blacklisted", + new_callable=AsyncMock, + return_value=True, + ): + with pytest.raises(HTTPException) as exc_info: + await AuthService.verify_token(token) + + assert exc_info.value.status_code == 401 + assert "撤销" in exc_info.value.detail + + +class TestUserAuthentication: + """测试用户登录认证""" + + @pytest.mark.asyncio + async def test_authenticate_user_success(self) -> None: + """测试用户登录成功""" + # Mock 数据库和用户对象 + mock_user = MagicMock() + mock_user.id = "user-123" + mock_user.email = "test@example.com" + mock_user.is_active = True + mock_user.verify_password.return_value = True + + mock_db = MagicMock() + mock_db.query.return_value.filter.return_value.first.return_value = mock_user + + with patch( + "src.services.auth.service.UserCacheService.invalidate_user_cache", + new_callable=AsyncMock, + ): + result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123") + + assert result == mock_user + mock_user.verify_password.assert_called_once_with("password123") + mock_db.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_authenticate_user_not_found(self) -> None: + """测试用户不存在""" + mock_db = MagicMock() + mock_db.query.return_value.filter.return_value.first.return_value = None + + result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password") + + assert result is None + + @pytest.mark.asyncio + async def test_authenticate_user_wrong_password(self) -> None: + """测试密码错误""" + mock_user = MagicMock() + mock_user.email = "test@example.com" + mock_user.verify_password.return_value = False + + mock_db = MagicMock() + mock_db.query.return_value.filter.return_value.first.return_value = mock_user + + result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword") + + assert result is None + + @pytest.mark.asyncio + async def test_authenticate_user_inactive(self) -> None: + """测试用户已禁用""" + mock_user = MagicMock() + mock_user.email = "test@example.com" + mock_user.is_active = False + mock_user.verify_password.return_value = True + + mock_db = MagicMock() + mock_db.query.return_value.filter.return_value.first.return_value = mock_user + + result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123") + + assert result is None + + +class TestAPIKeyAuthentication: + """测试 API Key 认证""" + + def test_authenticate_api_key_success(self) -> None: + """测试 API Key 认证成功""" + mock_user = MagicMock() + mock_user.id = "user-123" + mock_user.email = "test@example.com" + mock_user.is_active = True + + mock_api_key = MagicMock() + mock_api_key.is_active = True + mock_api_key.expires_at = None + mock_api_key.user = mock_user + mock_api_key.balance_used_usd = 0.0 + + mock_db = MagicMock() + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( + mock_api_key + ) + + with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): + with patch( + "src.services.auth.service.ApiKeyService.check_balance", + return_value=(True, 100.0), + ): + result = AuthService.authenticate_api_key(mock_db, "sk-test-key") + + assert result is not None + assert result[0] == mock_user + assert result[1] == mock_api_key + + def test_authenticate_api_key_not_found(self) -> None: + """测试 API Key 不存在""" + mock_db = MagicMock() + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( + None + ) + + with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): + result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key") + + assert result is None + + def test_authenticate_api_key_inactive(self) -> None: + """测试 API Key 已禁用""" + mock_api_key = MagicMock() + mock_api_key.is_active = False + + mock_db = MagicMock() + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( + mock_api_key + ) + + with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): + result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key") + + assert result is None + + def test_authenticate_api_key_expired(self) -> None: + """测试 API Key 已过期""" + mock_api_key = MagicMock() + mock_api_key.is_active = True + mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1) + + mock_db = MagicMock() + mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = ( + mock_api_key + ) + + with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"): + result = AuthService.authenticate_api_key(mock_db, "sk-expired-key") + + assert result is None diff --git a/tests/services/test_usage_service.py b/tests/services/test_usage_service.py new file mode 100644 index 0000000..2075e1c --- /dev/null +++ b/tests/services/test_usage_service.py @@ -0,0 +1,292 @@ +""" +UsageService 测试 + +测试用量统计服务的核心功能: +- 成本计算 +- 配额检查 +- 用量统计查询 +""" + +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from src.services.usage.service import UsageService + + +class TestCostCalculation: + """测试成本计算""" + + def test_calculate_cost_basic(self) -> None: + """测试基础成本计算""" + # 价格:输入 $3/1M, 输出 $15/1M + result = UsageService.calculate_cost( + input_tokens=1000, + output_tokens=500, + input_price_per_1m=3.0, + output_price_per_1m=15.0, + ) + + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result + + # 1000 tokens * $3 / 1M = $0.003 + assert abs(input_cost - 0.003) < 0.0001 + # 500 tokens * $15 / 1M = $0.0075 + assert abs(output_cost - 0.0075) < 0.0001 + # Total = $0.003 + $0.0075 = $0.0105 + assert abs(total_cost - 0.0105) < 0.0001 + + def test_calculate_cost_with_cache(self) -> None: + """测试带缓存的成本计算""" + result = UsageService.calculate_cost( + input_tokens=1000, + output_tokens=500, + input_price_per_1m=3.0, + output_price_per_1m=15.0, + cache_creation_input_tokens=200, + cache_read_input_tokens=300, + cache_creation_price_per_1m=3.75, # 1.25x input price + cache_read_price_per_1m=0.3, # 0.1x input price + ) + + ( + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + cache_cost, + request_cost, + total_cost, + ) = result + + # 验证缓存成本被计算 + assert cache_creation_cost > 0 + assert cache_read_cost > 0 + assert cache_cost == cache_creation_cost + cache_read_cost + + def test_calculate_cost_with_request_price(self) -> None: + """测试按次计费""" + result = UsageService.calculate_cost( + input_tokens=1000, + output_tokens=500, + input_price_per_1m=3.0, + output_price_per_1m=15.0, + price_per_request=0.01, + ) + + ( + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + cache_cost, + request_cost, + total_cost, + ) = result + + assert request_cost == 0.01 + # Total 包含 request_cost + assert total_cost == input_cost + output_cost + request_cost + + def test_calculate_cost_zero_tokens(self) -> None: + """测试零 token 的成本计算""" + result = UsageService.calculate_cost( + input_tokens=0, + output_tokens=0, + input_price_per_1m=3.0, + output_price_per_1m=15.0, + ) + + ( + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + cache_cost, + request_cost, + total_cost, + ) = result + + assert input_cost == 0 + assert output_cost == 0 + assert total_cost == 0 + + +class TestQuotaCheck: + """测试配额检查""" + + def test_check_user_quota_sufficient(self) -> None: + """测试配额充足""" + mock_user = MagicMock() + mock_user.quota_usd = 100.0 + mock_user.used_usd = 30.0 + mock_user.role = MagicMock() + mock_user.role.value = "user" + + mock_api_key = MagicMock() + mock_api_key.is_standalone = False + + mock_db = MagicMock() + + is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key) + + assert is_ok is True + + def test_check_user_quota_exceeded(self) -> None: + """测试配额超限(当有预估成本时)""" + mock_user = MagicMock() + mock_user.quota_usd = 100.0 + mock_user.used_usd = 99.0 # 接近配额上限 + mock_user.role = MagicMock() + mock_user.role.value = "user" + + mock_api_key = MagicMock() + mock_api_key.is_standalone = False + + mock_db = MagicMock() + + # 当预估成本超过剩余配额时应该返回 False + is_ok, message = UsageService.check_user_quota( + mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key + ) + + assert is_ok is False + assert "配额" in message + + def test_check_user_quota_no_limit(self) -> None: + """测试无配额限制(None)""" + mock_user = MagicMock() + mock_user.quota_usd = None + mock_user.used_usd = 1000.0 + mock_user.role = MagicMock() + mock_user.role.value = "user" + + mock_api_key = MagicMock() + mock_api_key.is_standalone = False + + mock_db = MagicMock() + + is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key) + + assert is_ok is True + + def test_check_user_quota_admin_bypass(self) -> None: + """测试管理员绕过配额检查""" + from src.models.database import UserRole + + mock_user = MagicMock() + mock_user.quota_usd = 0.0 + mock_user.used_usd = 1000.0 + mock_user.role = UserRole.ADMIN + + mock_api_key = MagicMock() + mock_api_key.is_standalone = False + + mock_db = MagicMock() + + is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key) + + assert is_ok is True + + def test_check_standalone_api_key_balance(self) -> None: + """测试独立 API Key 余额检查""" + mock_user = MagicMock() + mock_user.quota_usd = 0.0 + mock_user.used_usd = 0.0 + mock_user.role = MagicMock() + mock_user.role.value = "user" + + mock_api_key = MagicMock() + mock_api_key.is_standalone = True + mock_api_key.current_balance_usd = 50.0 + mock_api_key.balance_used_usd = 10.0 + + mock_db = MagicMock() + + is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key) + + assert is_ok is True + + def test_check_standalone_api_key_insufficient_balance(self) -> None: + """测试独立 API Key 余额不足""" + mock_user = MagicMock() + mock_user.quota_usd = 100.0 + mock_user.used_usd = 0.0 + mock_user.role = MagicMock() + mock_user.role.value = "user" + + mock_api_key = MagicMock() + mock_api_key.is_standalone = True + mock_api_key.current_balance_usd = 10.0 + mock_api_key.balance_used_usd = 9.0 # 剩余 $1 + + mock_db = MagicMock() + + # 需要 mock ApiKeyService.get_remaining_balance + with patch( + "src.services.user.apikey.ApiKeyService.get_remaining_balance", + return_value=1.0, + ): + # 预估成本 $5 超过剩余余额 $1 + is_ok, message = UsageService.check_user_quota( + mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key + ) + + assert is_ok is False + + +class TestUsageStatistics: + """测试用量统计查询 + + 注意:get_usage_summary 方法内部使用了数据库方言特定的日期函数, + 需要真实数据库或更复杂的 mock。这里只测试方法存在性。 + """ + + def test_get_usage_summary_exists(self) -> None: + """测试 get_usage_summary 方法存在""" + assert hasattr(UsageService, "get_usage_summary") + assert callable(getattr(UsageService, "get_usage_summary")) + + +class TestHelperMethods: + """测试辅助方法""" + + @pytest.mark.asyncio + async def test_get_rate_multiplier_and_free_tier_default(self) -> None: + """测试默认费率倍数""" + mock_db = MagicMock() + # 模拟未找到 provider_api_key + mock_db.query.return_value.filter.return_value.first.return_value = None + + rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier( + mock_db, provider_api_key_id=None, provider_id=None + ) + + assert rate_multiplier == 1.0 + assert is_free_tier is False + + @pytest.mark.asyncio + async def test_get_rate_multiplier_from_provider_api_key(self) -> None: + """测试从 ProviderAPIKey 获取费率倍数""" + mock_provider_api_key = MagicMock() + mock_provider_api_key.rate_multiplier = 0.8 + + mock_endpoint = MagicMock() + mock_endpoint.provider_id = "provider-123" + + mock_provider = MagicMock() + mock_provider.billing_type = "standard" + + mock_db = MagicMock() + # 第一次查询返回 provider_api_key + mock_db.query.return_value.filter.return_value.first.side_effect = [ + mock_provider_api_key, + mock_endpoint, + mock_provider, + ] + + rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier( + mock_db, provider_api_key_id="pak-123", provider_id=None + ) + + assert rate_multiplier == 0.8 + assert is_free_tier is False