refactor: migrate Pydantic Config to v2 ConfigDict

This commit is contained in:
fawney19
2025-12-18 02:20:53 +08:00
parent b2a857c164
commit 3d0ab353d3
10 changed files with 970 additions and 25 deletions

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None finished_at: Optional[datetime] = None
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class RequestTraceResponse(BaseModel): class RequestTraceResponse(BaseModel):

View File

@@ -6,7 +6,7 @@ import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
from ..core.enums import UserRole from ..core.enums import UserRole
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
active_models_count: int = 0 active_models_count: int = 0
api_keys_count: int = 0 api_keys_count: int = 0
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
# ========== 模型管理 ========== # ========== 模型管理 ==========
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
global_model_name: Optional[str] = None global_model_name: Optional[str] = None
global_model_display_name: Optional[str] = None global_model_display_name: Optional[str] = None
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class ModelDetailResponse(BaseModel): class ModelDetailResponse(BaseModel):
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
# ========== 系统设置 ========== # ========== 系统设置 ==========

View File

@@ -5,7 +5,7 @@ Provider API Key相关的API模型
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class ProviderAPIKeyBase(BaseModel): class ProviderAPIKeyBase(BaseModel):
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class ProviderAPIKeyStats(BaseModel): class ProviderAPIKeyStats(BaseModel):

View File

@@ -27,8 +27,7 @@ from sqlalchemy import (
UniqueConstraint, UniqueConstraint,
) )
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.orm import relationship
from ..config import config from ..config import config
from ..core.enums import ProviderBillingType, UserRole from ..core.enums import ProviderBillingType, UserRole

View File

@@ -6,7 +6,7 @@ import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
# ========== ProviderEndpoint CRUD ========== # ========== ProviderEndpoint CRUD ==========
@@ -141,8 +141,7 @@ class ProviderEndpointResponse(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
# ========== ProviderAPIKey 相关(新架构) ========== # ========== ProviderAPIKey 相关(新架构) ==========
@@ -384,8 +383,7 @@ class EndpointAPIKeyResponse(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
# ========== 健康监控相关 ========== # ========== 健康监控相关 ==========
@@ -535,8 +533,7 @@ class ProviderWithEndpointsSummary(BaseModel):
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
# ========== 健康监控可视化模型 ========== # ========== 健康监控可视化模型 ==========

View File

@@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理)
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
# ========== 阶梯计费相关模型 ========== # ========== 阶梯计费相关模型 ==========
@@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel):
created_at: datetime created_at: datetime
updated_at: Optional[datetime] updated_at: Optional[datetime]
class Config: model_config = ConfigDict(from_attributes=True)
from_attributes = True
class GlobalModelWithStats(GlobalModelResponse): class GlobalModelWithStats(GlobalModelResponse):

363
tests/api/test_pipeline.py Normal file
View File

@@ -0,0 +1,363 @@
"""
API Pipeline 测试
测试 ApiRequestPipeline 的核心功能:
- 认证流程API Key、JWT Token
- 配额计算
- 审计日志记录
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import HTTPException
from src.api.base.pipeline import ApiRequestPipeline
class TestPipelineQuotaCalculation:
"""测试 Pipeline 配额计算"""
@pytest.fixture
def pipeline(self) -> ApiRequestPipeline:
return ApiRequestPipeline()
def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
"""测试有配额限制时计算剩余配额"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 30.0
remaining = pipeline._calculate_quota_remaining(mock_user)
assert remaining == 70.0
def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
"""测试无配额限制时返回 None"""
mock_user = MagicMock()
mock_user.quota_usd = None
mock_user.used_usd = 30.0
remaining = pipeline._calculate_quota_remaining(mock_user)
assert remaining is None
def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
"""测试负配额时返回 None"""
mock_user = MagicMock()
mock_user.quota_usd = -1
mock_user.used_usd = 0.0
remaining = pipeline._calculate_quota_remaining(mock_user)
assert remaining is None
def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
"""测试配额已超时返回 0"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 150.0
remaining = pipeline._calculate_quota_remaining(mock_user)
assert remaining == 0.0
def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
"""测试用户为 None 时返回 None"""
remaining = pipeline._calculate_quota_remaining(None)
assert remaining is None
class TestPipelineAuditLogging:
"""测试 Pipeline 审计日志"""
@pytest.fixture
def pipeline(self) -> ApiRequestPipeline:
return ApiRequestPipeline()
def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
"""测试记录成功的审计事件"""
mock_context = MagicMock()
mock_context.db = MagicMock()
mock_context.user = MagicMock()
mock_context.user.id = "user-123"
mock_context.api_key = MagicMock()
mock_context.api_key.id = "key-123"
mock_context.request_id = "req-123"
mock_context.client_ip = "127.0.0.1"
mock_context.user_agent = "test-agent"
mock_context.request = MagicMock()
mock_context.request.method = "POST"
mock_context.request.url.path = "/v1/messages"
mock_context.start_time = 1000.0
mock_adapter = MagicMock()
mock_adapter.name = "test-adapter"
mock_adapter.audit_log_enabled = True
mock_adapter.audit_success_event = None
mock_adapter.audit_failure_event = None
with patch.object(
pipeline.audit_service,
"log_event",
) as mock_log:
with patch("time.time", return_value=1001.0):
pipeline._record_audit_event(
mock_context, mock_adapter, success=True, status_code=200
)
mock_log.assert_called_once()
call_kwargs = mock_log.call_args[1]
assert call_kwargs["user_id"] == "user-123"
assert call_kwargs["status_code"] == 200
def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
"""测试记录失败的审计事件"""
mock_context = MagicMock()
mock_context.db = MagicMock()
mock_context.user = MagicMock()
mock_context.user.id = "user-123"
mock_context.api_key = MagicMock()
mock_context.api_key.id = "key-123"
mock_context.request_id = "req-123"
mock_context.client_ip = "127.0.0.1"
mock_context.user_agent = "test-agent"
mock_context.request = MagicMock()
mock_context.request.method = "POST"
mock_context.request.url.path = "/v1/messages"
mock_context.start_time = 1000.0
mock_adapter = MagicMock()
mock_adapter.name = "test-adapter"
mock_adapter.audit_log_enabled = True
mock_adapter.audit_success_event = None
mock_adapter.audit_failure_event = None
with patch.object(
pipeline.audit_service,
"log_event",
) as mock_log:
with patch("time.time", return_value=1001.0):
pipeline._record_audit_event(
mock_context, mock_adapter, success=False, status_code=500, error="Internal error"
)
mock_log.assert_called_once()
call_kwargs = mock_log.call_args[1]
assert call_kwargs["status_code"] == 500
assert call_kwargs["error_message"] == "Internal error"
def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
"""测试没有数据库会话时跳过审计"""
mock_context = MagicMock()
mock_context.db = None
mock_adapter = MagicMock()
mock_adapter.audit_log_enabled = True
with patch.object(
pipeline.audit_service,
"log_event",
) as mock_log:
# 不应该抛出异常
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
# 不应该调用 log_event
mock_log.assert_not_called()
def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
"""测试审计日志被禁用时跳过"""
mock_context = MagicMock()
mock_context.db = MagicMock()
mock_adapter = MagicMock()
mock_adapter.audit_log_enabled = False
with patch.object(
pipeline.audit_service,
"log_event",
) as mock_log:
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
mock_log.assert_not_called()
def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
"""测试审计日志异常不影响主流程"""
mock_context = MagicMock()
mock_context.db = MagicMock()
mock_context.user = MagicMock()
mock_context.user.id = "user-123"
mock_context.api_key = MagicMock()
mock_context.api_key.id = "key-123"
mock_context.request_id = "req-123"
mock_context.client_ip = "127.0.0.1"
mock_context.user_agent = "test-agent"
mock_context.request = MagicMock()
mock_context.request.method = "POST"
mock_context.request.url.path = "/v1/messages"
mock_context.start_time = 1000.0
mock_adapter = MagicMock()
mock_adapter.name = "test-adapter"
mock_adapter.audit_log_enabled = True
mock_adapter.audit_success_event = None
with patch.object(
pipeline.audit_service,
"log_event",
side_effect=Exception("DB error"),
):
with patch("time.time", return_value=1001.0):
# 不应该抛出异常
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
class TestPipelineAuthentication:
"""测试 Pipeline 认证相关逻辑"""
@pytest.fixture
def pipeline(self) -> ApiRequestPipeline:
return ApiRequestPipeline()
def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
"""测试缺少 API Key 时抛出异常"""
mock_request = MagicMock()
mock_request.headers = {}
mock_request.url.path = "/v1/messages"
mock_request.state = MagicMock()
mock_db = MagicMock()
mock_adapter = MagicMock()
mock_adapter.extract_api_key = MagicMock(return_value=None)
with pytest.raises(HTTPException) as exc_info:
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
assert exc_info.value.status_code == 401
assert "API密钥" in exc_info.value.detail
def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
"""测试无效的 API Key"""
mock_request = MagicMock()
mock_request.headers = {"Authorization": "Bearer sk-invalid"}
mock_request.url.path = "/v1/messages"
mock_request.state = MagicMock()
mock_db = MagicMock()
mock_adapter = MagicMock()
mock_adapter.extract_api_key = MagicMock(return_value="sk-invalid")
with patch.object(
pipeline.auth_service,
"authenticate_api_key",
return_value=None,
):
with pytest.raises(HTTPException) as exc_info:
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
assert exc_info.value.status_code == 401
def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
"""测试配额超限时抛出异常"""
mock_user = MagicMock()
mock_user.id = "user-123"
mock_user.quota_usd = 100.0
mock_user.used_usd = 100.0
mock_api_key = MagicMock()
mock_api_key.id = "key-123"
mock_api_key.is_standalone = False
mock_request = MagicMock()
mock_request.headers = {"Authorization": "Bearer sk-test"}
mock_request.url.path = "/v1/messages"
mock_request.state = MagicMock()
mock_db = MagicMock()
mock_adapter = MagicMock()
mock_adapter.extract_api_key = MagicMock(return_value="sk-test")
with patch.object(
pipeline.auth_service,
"authenticate_api_key",
return_value=(mock_user, mock_api_key),
):
with patch.object(
pipeline.usage_service,
"check_user_quota",
return_value=(False, "配额不足"),
):
from src.core.exceptions import QuotaExceededException
with pytest.raises(QuotaExceededException):
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
class TestPipelineAdminAuth:
"""测试管理员认证"""
@pytest.fixture
def pipeline(self) -> ApiRequestPipeline:
return ApiRequestPipeline()
@pytest.mark.asyncio
async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
"""测试缺少管理员令牌"""
mock_request = MagicMock()
mock_request.headers = {}
mock_db = MagicMock()
with pytest.raises(HTTPException) as exc_info:
await pipeline._authenticate_admin(mock_request, mock_db)
assert exc_info.value.status_code == 401
assert "管理员凭证" in exc_info.value.detail
@pytest.mark.asyncio
async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
"""测试无效的管理员令牌"""
mock_request = MagicMock()
mock_request.headers = {"authorization": "Bearer invalid-token"}
mock_db = MagicMock()
with patch.object(
pipeline.auth_service,
"verify_token",
side_effect=HTTPException(status_code=401, detail="Invalid token"),
):
with pytest.raises(HTTPException) as exc_info:
await pipeline._authenticate_admin(mock_request, mock_db)
assert exc_info.value.status_code == 401
@pytest.mark.asyncio
async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
"""测试管理员认证成功"""
mock_user = MagicMock()
mock_user.id = "admin-123"
mock_user.is_active = True
mock_request = MagicMock()
mock_request.headers = {"authorization": "Bearer valid-token"}
mock_request.state = MagicMock()
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
with patch.object(
pipeline.auth_service,
"verify_token",
new_callable=AsyncMock,
return_value={"user_id": "admin-123"},
):
result = await pipeline._authenticate_admin(mock_request, mock_db)
assert result == mock_user
assert mock_request.state.user_id == "admin-123"

View File

@@ -0,0 +1 @@
"""服务层测试"""

299
tests/services/test_auth.py Normal file
View File

@@ -0,0 +1,299 @@
"""
认证服务测试
测试 AuthService 的核心功能:
- JWT Token 创建和验证
- 用户登录认证
- API Key 认证
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import jwt
from src.services.auth.service import (
AuthService,
JWT_SECRET_KEY,
JWT_ALGORITHM,
JWT_EXPIRATION_HOURS,
)
class TestJWTTokenCreation:
"""测试 JWT Token 创建"""
def test_create_access_token_contains_required_fields(self) -> None:
"""测试访问令牌包含必要字段"""
data = {"sub": "user123", "email": "test@example.com"}
token = AuthService.create_access_token(data)
# 解码验证
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
assert payload["sub"] == "user123"
assert payload["email"] == "test@example.com"
assert payload["type"] == "access"
assert "exp" in payload
def test_create_access_token_expiration(self) -> None:
"""测试访问令牌过期时间正确"""
data = {"sub": "user123"}
token = AuthService.create_access_token(data)
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
# 验证过期时间在预期范围内允许1分钟误差
exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
assert abs((exp_time - expected_exp).total_seconds()) < 60
def test_create_refresh_token_type(self) -> None:
"""测试刷新令牌类型正确"""
data = {"sub": "user123"}
token = AuthService.create_refresh_token(data)
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
assert payload["type"] == "refresh"
def test_create_refresh_token_longer_expiration(self) -> None:
"""测试刷新令牌过期时间更长"""
data = {"sub": "user123"}
access_token = AuthService.create_access_token(data)
refresh_token = AuthService.create_refresh_token(data)
access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
# 刷新令牌应该比访问令牌过期时间更长
assert refresh_payload["exp"] > access_payload["exp"]
class TestJWTTokenVerification:
"""测试 JWT Token 验证"""
@pytest.mark.asyncio
async def test_verify_valid_access_token(self) -> None:
"""测试验证有效的访问令牌"""
data = {"sub": "user123", "email": "test@example.com"}
token = AuthService.create_access_token(data)
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=False,
):
payload = await AuthService.verify_token(token, token_type="access")
assert payload["sub"] == "user123"
assert payload["type"] == "access"
@pytest.mark.asyncio
async def test_verify_expired_token_raises_error(self) -> None:
"""测试验证过期令牌抛出异常"""
# 创建一个已过期的 token
data = {"sub": "user123", "type": "access"}
expire = datetime.now(timezone.utc) - timedelta(hours=1)
data["exp"] = expire
expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(expired_token)
assert exc_info.value.status_code == 401
assert "过期" in exc_info.value.detail
@pytest.mark.asyncio
async def test_verify_invalid_token_raises_error(self) -> None:
"""测试验证无效令牌抛出异常"""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token("invalid.token.here")
assert exc_info.value.status_code == 401
@pytest.mark.asyncio
async def test_verify_wrong_token_type_raises_error(self) -> None:
"""测试令牌类型不匹配抛出异常"""
data = {"sub": "user123"}
refresh_token = AuthService.create_refresh_token(data)
from fastapi import HTTPException
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=False,
):
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(refresh_token, token_type="access")
assert exc_info.value.status_code == 401
assert "类型错误" in exc_info.value.detail
@pytest.mark.asyncio
async def test_verify_blacklisted_token_raises_error(self) -> None:
"""测试已撤销的令牌抛出异常"""
data = {"sub": "user123"}
token = AuthService.create_access_token(data)
from fastapi import HTTPException
with patch(
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
new_callable=AsyncMock,
return_value=True,
):
with pytest.raises(HTTPException) as exc_info:
await AuthService.verify_token(token)
assert exc_info.value.status_code == 401
assert "撤销" in exc_info.value.detail
class TestUserAuthentication:
"""测试用户登录认证"""
@pytest.mark.asyncio
async def test_authenticate_user_success(self) -> None:
"""测试用户登录成功"""
# Mock 数据库和用户对象
mock_user = MagicMock()
mock_user.id = "user-123"
mock_user.email = "test@example.com"
mock_user.is_active = True
mock_user.verify_password.return_value = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
with patch(
"src.services.auth.service.UserCacheService.invalidate_user_cache",
new_callable=AsyncMock,
):
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
assert result == mock_user
mock_user.verify_password.assert_called_once_with("password123")
mock_db.commit.assert_called_once()
@pytest.mark.asyncio
async def test_authenticate_user_not_found(self) -> None:
"""测试用户不存在"""
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = None
result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password")
assert result is None
@pytest.mark.asyncio
async def test_authenticate_user_wrong_password(self) -> None:
"""测试密码错误"""
mock_user = MagicMock()
mock_user.email = "test@example.com"
mock_user.verify_password.return_value = False
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword")
assert result is None
@pytest.mark.asyncio
async def test_authenticate_user_inactive(self) -> None:
"""测试用户已禁用"""
mock_user = MagicMock()
mock_user.email = "test@example.com"
mock_user.is_active = False
mock_user.verify_password.return_value = True
mock_db = MagicMock()
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
assert result is None
class TestAPIKeyAuthentication:
"""测试 API Key 认证"""
def test_authenticate_api_key_success(self) -> None:
"""测试 API Key 认证成功"""
mock_user = MagicMock()
mock_user.id = "user-123"
mock_user.email = "test@example.com"
mock_user.is_active = True
mock_api_key = MagicMock()
mock_api_key.is_active = True
mock_api_key.expires_at = None
mock_api_key.user = mock_user
mock_api_key.balance_used_usd = 0.0
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
with patch(
"src.services.auth.service.ApiKeyService.check_balance",
return_value=(True, 100.0),
):
result = AuthService.authenticate_api_key(mock_db, "sk-test-key")
assert result is not None
assert result[0] == mock_user
assert result[1] == mock_api_key
def test_authenticate_api_key_not_found(self) -> None:
"""测试 API Key 不存在"""
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
None
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key")
assert result is None
def test_authenticate_api_key_inactive(self) -> None:
"""测试 API Key 已禁用"""
mock_api_key = MagicMock()
mock_api_key.is_active = False
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key")
assert result is None
def test_authenticate_api_key_expired(self) -> None:
"""测试 API Key 已过期"""
mock_api_key = MagicMock()
mock_api_key.is_active = True
mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
mock_db = MagicMock()
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
mock_api_key
)
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
result = AuthService.authenticate_api_key(mock_db, "sk-expired-key")
assert result is None

View File

@@ -0,0 +1,292 @@
"""
UsageService 测试
测试用量统计服务的核心功能:
- 成本计算
- 配额检查
- 用量统计查询
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from src.services.usage.service import UsageService
class TestCostCalculation:
"""测试成本计算"""
def test_calculate_cost_basic(self) -> None:
"""测试基础成本计算"""
# 价格:输入 $3/1M, 输出 $15/1M
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
)
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result
# 1000 tokens * $3 / 1M = $0.003
assert abs(input_cost - 0.003) < 0.0001
# 500 tokens * $15 / 1M = $0.0075
assert abs(output_cost - 0.0075) < 0.0001
# Total = $0.003 + $0.0075 = $0.0105
assert abs(total_cost - 0.0105) < 0.0001
def test_calculate_cost_with_cache(self) -> None:
"""测试带缓存的成本计算"""
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_input_tokens=200,
cache_read_input_tokens=300,
cache_creation_price_per_1m=3.75, # 1.25x input price
cache_read_price_per_1m=0.3, # 0.1x input price
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
# 验证缓存成本被计算
assert cache_creation_cost > 0
assert cache_read_cost > 0
assert cache_cost == cache_creation_cost + cache_read_cost
def test_calculate_cost_with_request_price(self) -> None:
"""测试按次计费"""
result = UsageService.calculate_cost(
input_tokens=1000,
output_tokens=500,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
price_per_request=0.01,
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
assert request_cost == 0.01
# Total 包含 request_cost
assert total_cost == input_cost + output_cost + request_cost
def test_calculate_cost_zero_tokens(self) -> None:
"""测试零 token 的成本计算"""
result = UsageService.calculate_cost(
input_tokens=0,
output_tokens=0,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
)
(
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
) = result
assert input_cost == 0
assert output_cost == 0
assert total_cost == 0
class TestQuotaCheck:
"""测试配额检查"""
def test_check_user_quota_sufficient(self) -> None:
"""测试配额充足"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 30.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_user_quota_exceeded(self) -> None:
"""测试配额超限(当有预估成本时)"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 99.0 # 接近配额上限
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
# 当预估成本超过剩余配额时应该返回 False
is_ok, message = UsageService.check_user_quota(
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
)
assert is_ok is False
assert "配额" in message
def test_check_user_quota_no_limit(self) -> None:
"""测试无配额限制None"""
mock_user = MagicMock()
mock_user.quota_usd = None
mock_user.used_usd = 1000.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_user_quota_admin_bypass(self) -> None:
"""测试管理员绕过配额检查"""
from src.models.database import UserRole
mock_user = MagicMock()
mock_user.quota_usd = 0.0
mock_user.used_usd = 1000.0
mock_user.role = UserRole.ADMIN
mock_api_key = MagicMock()
mock_api_key.is_standalone = False
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_standalone_api_key_balance(self) -> None:
"""测试独立 API Key 余额检查"""
mock_user = MagicMock()
mock_user.quota_usd = 0.0
mock_user.used_usd = 0.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = True
mock_api_key.current_balance_usd = 50.0
mock_api_key.balance_used_usd = 10.0
mock_db = MagicMock()
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
assert is_ok is True
def test_check_standalone_api_key_insufficient_balance(self) -> None:
"""测试独立 API Key 余额不足"""
mock_user = MagicMock()
mock_user.quota_usd = 100.0
mock_user.used_usd = 0.0
mock_user.role = MagicMock()
mock_user.role.value = "user"
mock_api_key = MagicMock()
mock_api_key.is_standalone = True
mock_api_key.current_balance_usd = 10.0
mock_api_key.balance_used_usd = 9.0 # 剩余 $1
mock_db = MagicMock()
# 需要 mock ApiKeyService.get_remaining_balance
with patch(
"src.services.user.apikey.ApiKeyService.get_remaining_balance",
return_value=1.0,
):
# 预估成本 $5 超过剩余余额 $1
is_ok, message = UsageService.check_user_quota(
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
)
assert is_ok is False
class TestUsageStatistics:
"""测试用量统计查询
注意get_usage_summary 方法内部使用了数据库方言特定的日期函数,
需要真实数据库或更复杂的 mock。这里只测试方法存在性。
"""
def test_get_usage_summary_exists(self) -> None:
"""测试 get_usage_summary 方法存在"""
assert hasattr(UsageService, "get_usage_summary")
assert callable(getattr(UsageService, "get_usage_summary"))
class TestHelperMethods:
"""测试辅助方法"""
@pytest.mark.asyncio
async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
"""测试默认费率倍数"""
mock_db = MagicMock()
# 模拟未找到 provider_api_key
mock_db.query.return_value.filter.return_value.first.return_value = None
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
mock_db, provider_api_key_id=None, provider_id=None
)
assert rate_multiplier == 1.0
assert is_free_tier is False
@pytest.mark.asyncio
async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
"""测试从 ProviderAPIKey 获取费率倍数"""
mock_provider_api_key = MagicMock()
mock_provider_api_key.rate_multiplier = 0.8
mock_endpoint = MagicMock()
mock_endpoint.provider_id = "provider-123"
mock_provider = MagicMock()
mock_provider.billing_type = "standard"
mock_db = MagicMock()
# 第一次查询返回 provider_api_key
mock_db.query.return_value.filter.return_value.first.side_effect = [
mock_provider_api_key,
mock_endpoint,
mock_provider,
]
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
mock_db, provider_api_key_id="pak-123", provider_id=None
)
assert rate_multiplier == 0.8
assert is_free_tier is False