mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21587449c8 | ||
|
|
3d0ab353d3 | ||
|
|
b2a857c164 | ||
|
|
4d1d863916 | ||
|
|
b579420690 |
@@ -22,7 +22,7 @@
|
||||
/>
|
||||
</Transition>
|
||||
|
||||
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
||||
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
|
||||
<!-- 对话框内容 -->
|
||||
<Transition
|
||||
enter-active-class="duration-300 ease-out"
|
||||
|
||||
@@ -3,10 +3,8 @@
|
||||
A proxy server that enables AI models to work with multiple API providers.
|
||||
"""
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
# 注意: dotenv 加载已统一移至 src/config/settings.py
|
||||
# 不要在此处重复加载
|
||||
|
||||
try:
|
||||
from src._version import __version__
|
||||
|
||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RequestTraceResponse(BaseModel):
|
||||
|
||||
@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from ..models.database import SystemConfig
|
||||
from src.models.database import SystemConfig
|
||||
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
@@ -5,13 +5,12 @@ from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
|
||||
status_code: Optional[int] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""记录审计事件
|
||||
|
||||
事务策略:复用请求级 Session,不单独提交。
|
||||
审计记录随主事务一起提交,由中间件统一管理。
|
||||
"""
|
||||
if not getattr(adapter, "audit_log_enabled", True):
|
||||
return
|
||||
|
||||
bind = context.db.get_bind()
|
||||
if bind is None:
|
||||
if context.db is None:
|
||||
return
|
||||
|
||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
|
||||
error=error,
|
||||
)
|
||||
|
||||
SessionMaker = sessionmaker(bind=bind)
|
||||
audit_session = SessionMaker()
|
||||
try:
|
||||
# 复用请求级 Session,不创建新的连接
|
||||
# 审计记录随主事务一起提交,由中间件统一管理
|
||||
self.audit_service.log_event(
|
||||
db=audit_session,
|
||||
db=context.db,
|
||||
event_type=event_type,
|
||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||
user_id=context.user.id if context.user else None,
|
||||
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
|
||||
error_message=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
audit_session.commit()
|
||||
except Exception as exc:
|
||||
audit_session.rollback()
|
||||
# 审计失败不应影响主请求,仅记录警告
|
||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||
finally:
|
||||
audit_session.close()
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
|
||||
@@ -411,9 +411,10 @@ class BaseMessageHandler:
|
||||
QuotaExceededException,
|
||||
RateLimitException,
|
||||
ModelNotSupportedException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
|
||||
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
||||
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
|
||||
# 业务异常:简洁日志,不打印堆栈
|
||||
logger.error(f"{message}: [{type(error).__name__}] {error}")
|
||||
else:
|
||||
|
||||
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
|
||||
|
||||
if _redis_manager is None:
|
||||
_redis_manager = RedisClientManager()
|
||||
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
|
||||
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
|
||||
if _redis_manager.get_client() is None:
|
||||
await _redis_manager.initialize(require_redis=require_redis)
|
||||
|
||||
return _redis_manager.get_client()
|
||||
|
||||
@@ -41,8 +41,8 @@ class CacheSize:
|
||||
class ConcurrencyDefaults:
|
||||
"""并发控制默认值"""
|
||||
|
||||
# 自适应并发初始限制(保守值)
|
||||
INITIAL_LIMIT = 3
|
||||
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
||||
INITIAL_LIMIT = 50
|
||||
|
||||
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
||||
COOLDOWN_AFTER_429_MINUTES = 5
|
||||
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
|
||||
MIN_SAMPLES_FOR_DECISION = 5
|
||||
|
||||
# 扩容步长 - 每次扩容增加的并发数
|
||||
INCREASE_STEP = 1
|
||||
INCREASE_STEP = 2
|
||||
|
||||
# 缩容乘数 - 遇到 429 时的缩容比例
|
||||
DECREASE_MULTIPLIER = 0.7
|
||||
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
|
||||
# 0.85 表示降到触发 429 时并发数的 85%
|
||||
DECREASE_MULTIPLIER = 0.85
|
||||
|
||||
# 最大并发限制上限
|
||||
MAX_CONCURRENT_LIMIT = 100
|
||||
MAX_CONCURRENT_LIMIT = 200
|
||||
|
||||
# 最小并发限制下限
|
||||
MIN_CONCURRENT_LIMIT = 1
|
||||
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
|
||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||
PROBE_INCREASE_MIN_REQUESTS = 10
|
||||
|
||||
# === 缓存用户预留比例 ===
|
||||
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
|
||||
# 0.1 表示缓存用户预留 10%,新用户可用 90%
|
||||
CACHE_RESERVATION_RATIO = 0.1
|
||||
|
||||
|
||||
class CircuitBreakerDefaults:
|
||||
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
|
||||
|
||||
@@ -122,9 +122,9 @@ class Config:
|
||||
|
||||
# 并发控制配置
|
||||
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁
|
||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%)
|
||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 10%,新用户可用 90%)
|
||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||
|
||||
# HTTP 请求超时配置(秒)
|
||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||
|
||||
@@ -46,6 +46,11 @@ class BatchCommitter:
|
||||
|
||||
def mark_dirty(self, session: Session):
|
||||
"""标记 Session 有待提交的更改"""
|
||||
# 请求级事务由中间件统一 commit/rollback;避免后台任务在请求中途误提交。
|
||||
if session is None:
|
||||
return
|
||||
if session.info.get("managed_by_middleware"):
|
||||
return
|
||||
self._pending_sessions.add(session)
|
||||
|
||||
async def _batch_commit_loop(self):
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""
|
||||
统一的请求上下文
|
||||
|
||||
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
|
||||
这确保了数据在各层之间传递时不会丢失。
|
||||
|
||||
使用方式:
|
||||
1. Pipeline 层创建 RequestContext
|
||||
2. 各层通过 context 访问和更新信息
|
||||
3. Adapter 层使用 context 记录 Usage
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""
|
||||
请求上下文 - 贯穿整个请求生命周期
|
||||
|
||||
设计原则:
|
||||
1. 在请求开始时创建,包含所有已知信息
|
||||
2. 在请求执行过程中逐步填充 Provider 信息
|
||||
3. 在请求结束时用于记录 Usage
|
||||
"""
|
||||
|
||||
# ==================== 请求标识 ====================
|
||||
request_id: str
|
||||
|
||||
# ==================== 认证信息 ====================
|
||||
user: Any # User model
|
||||
api_key: Any # ApiKey model
|
||||
db: Any # Database session
|
||||
|
||||
# ==================== 请求信息 ====================
|
||||
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
|
||||
model: str # 用户请求的模型名
|
||||
is_stream: bool = False
|
||||
|
||||
# ==================== 原始请求 ====================
|
||||
original_headers: Dict[str, str] = field(default_factory=dict)
|
||||
original_body: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# ==================== 客户端信息 ====================
|
||||
client_ip: str = "unknown"
|
||||
user_agent: str = ""
|
||||
|
||||
# ==================== 计时 ====================
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
# ==================== Provider 信息(请求执行后填充)====================
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
provider_api_key_id: Optional[str] = None
|
||||
|
||||
# ==================== 模型映射信息 ====================
|
||||
resolved_model: Optional[str] = None # 映射后的模型名
|
||||
original_model: Optional[str] = None # 原始模型名(用于价格计算)
|
||||
|
||||
# ==================== 请求/响应头 ====================
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# ==================== 追踪信息 ====================
|
||||
attempt_id: Optional[str] = None
|
||||
|
||||
# ==================== 能力需求 ====================
|
||||
capability_requirements: Dict[str, bool] = field(default_factory=dict)
|
||||
# 运行时计算的能力需求,来源于:
|
||||
# 1. 用户 model_capability_settings
|
||||
# 2. 用户 ApiKey.force_capabilities
|
||||
# 3. 请求头 X-Require-Capability
|
||||
# 4. 失败重试时动态添加
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
*,
|
||||
db: Any,
|
||||
user: Any,
|
||||
api_key: Any,
|
||||
api_format: str,
|
||||
model: str,
|
||||
is_stream: bool = False,
|
||||
original_headers: Optional[Dict[str, str]] = None,
|
||||
original_body: Optional[Dict[str, Any]] = None,
|
||||
client_ip: str = "unknown",
|
||||
user_agent: str = "",
|
||||
request_id: Optional[str] = None,
|
||||
) -> "RequestContext":
|
||||
"""创建请求上下文"""
|
||||
return cls(
|
||||
request_id=request_id or str(uuid.uuid4()),
|
||||
db=db,
|
||||
user=user,
|
||||
api_key=api_key,
|
||||
api_format=api_format,
|
||||
model=model,
|
||||
is_stream=is_stream,
|
||||
original_headers=original_headers or {},
|
||||
original_body=original_body or {},
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
original_model=model, # 初始时原始模型等于请求模型
|
||||
)
|
||||
|
||||
def update_provider_info(
|
||||
self,
|
||||
*,
|
||||
provider_name: str,
|
||||
provider_id: str,
|
||||
endpoint_id: str,
|
||||
provider_api_key_id: str,
|
||||
resolved_model: Optional[str] = None,
|
||||
) -> None:
|
||||
"""更新 Provider 信息(请求执行后调用)"""
|
||||
self.provider_name = provider_name
|
||||
self.provider_id = provider_id
|
||||
self.endpoint_id = endpoint_id
|
||||
self.provider_api_key_id = provider_api_key_id
|
||||
if resolved_model:
|
||||
self.resolved_model = resolved_model
|
||||
|
||||
def update_headers(
|
||||
self,
|
||||
*,
|
||||
request_headers: Optional[Dict[str, str]] = None,
|
||||
response_headers: Optional[Dict[str, str]] = None,
|
||||
) -> None:
|
||||
"""更新请求/响应头"""
|
||||
if request_headers:
|
||||
self.provider_request_headers = request_headers
|
||||
if response_headers:
|
||||
self.provider_response_headers = response_headers
|
||||
|
||||
@property
|
||||
def elapsed_ms(self) -> int:
|
||||
"""计算已经过的时间(毫秒)"""
|
||||
return int((time.time() - self.start_time) * 1000)
|
||||
|
||||
@property
|
||||
def effective_model(self) -> str:
|
||||
"""获取有效的模型名(映射后优先)"""
|
||||
return self.resolved_model or self.model
|
||||
|
||||
@property
|
||||
def billing_model(self) -> str:
|
||||
"""获取计费模型名(原始模型优先)"""
|
||||
return self.original_model or self.model
|
||||
|
||||
def to_metadata_dict(self) -> Dict[str, Any]:
|
||||
"""转换为元数据字典(用于 Usage 记录)"""
|
||||
return {
|
||||
"api_format": self.api_format,
|
||||
"provider": self.provider_name or "unknown",
|
||||
"model": self.effective_model,
|
||||
"original_model": self.billing_model,
|
||||
"provider_id": self.provider_id,
|
||||
"provider_endpoint_id": self.endpoint_id,
|
||||
"provider_api_key_id": self.provider_api_key_id,
|
||||
"provider_request_headers": self.provider_request_headers,
|
||||
"provider_response_headers": self.provider_response_headers,
|
||||
"attempt_id": self.attempt_id,
|
||||
}
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
输出策略:
|
||||
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
|
||||
- 文件: 始终保存 DEBUG 级别,保留30天,每日轮转
|
||||
- 文件: 始终保存 DEBUG 级别,保留30天,按大小轮转 (100MB)
|
||||
|
||||
使用方式:
|
||||
from src.core.logger import logger
|
||||
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
||||
|
||||
|
||||
if IS_DOCKER:
|
||||
# 生产环境:禁用 backtrace 和 diagnose,减少日志噪音
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_PROD,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=False,
|
||||
backtrace=False,
|
||||
diagnose=False,
|
||||
)
|
||||
else:
|
||||
logger.add(
|
||||
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir = PROJECT_ROOT / "logs"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 文件日志通用配置
|
||||
file_log_config = {
|
||||
"format": FILE_FORMAT,
|
||||
"filter": _log_filter,
|
||||
"rotation": "100 MB",
|
||||
"retention": "30 days",
|
||||
"compression": "gz",
|
||||
"enqueue": True,
|
||||
"encoding": "utf-8",
|
||||
"catch": True,
|
||||
}
|
||||
|
||||
# 生产环境禁用详细堆栈
|
||||
if IS_DOCKER:
|
||||
file_log_config["backtrace"] = False
|
||||
file_log_config["diagnose"] = False
|
||||
|
||||
# 主日志文件 - 所有级别
|
||||
logger.add(
|
||||
log_dir / "app.log",
|
||||
format=FILE_FORMAT,
|
||||
level="DEBUG",
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
enqueue=True,
|
||||
encoding="utf-8",
|
||||
**file_log_config, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# 错误日志文件 - 仅 ERROR 及以上
|
||||
error_log_config = file_log_config.copy()
|
||||
error_log_config["rotation"] = "50 MB"
|
||||
logger.add(
|
||||
log_dir / "error.log",
|
||||
format=FILE_FORMAT,
|
||||
level="ERROR",
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
enqueue=True,
|
||||
encoding="utf-8",
|
||||
**error_log_config, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import time
|
||||
from typing import AsyncGenerator, Generator, Optional
|
||||
|
||||
from starlette.requests import Request
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
@@ -150,9 +151,22 @@ def _log_pool_capacity():
|
||||
theoretical = config.db_pool_size + config.db_max_overflow
|
||||
workers = max(1, config.worker_processes)
|
||||
total_estimated = theoretical * workers
|
||||
logger.info("数据库连接池配置")
|
||||
if total_estimated > config.db_pool_warn_threshold:
|
||||
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
|
||||
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||
logger.info(
|
||||
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
|
||||
config.db_pool_size,
|
||||
config.db_max_overflow,
|
||||
workers,
|
||||
total_estimated,
|
||||
safe_limit,
|
||||
)
|
||||
if total_estimated > safe_limit:
|
||||
logger.warning(
|
||||
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
|
||||
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||
total_estimated,
|
||||
safe_limit,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_async_engine() -> AsyncEngine:
|
||||
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
|
||||
# 创建异步引擎
|
||||
_async_engine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
poolclass=QueuePool, # 使用队列连接池
|
||||
# AsyncEngine 不能使用 QueuePool;默认使用 AsyncAdaptedQueuePool
|
||||
pool_size=config.db_pool_size,
|
||||
max_overflow=config.db_max_overflow,
|
||||
pool_timeout=config.db_pool_timeout,
|
||||
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
|
||||
|
||||
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取异步数据库会话"""
|
||||
"""获取异步数据库会话
|
||||
|
||||
.. deprecated::
|
||||
此方法已废弃,项目统一使用同步 Session。
|
||||
未来版本可能移除此方法。请使用 get_db() 代替。
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# 确保异步引擎已初始化
|
||||
_ensure_async_engine()
|
||||
|
||||
@@ -220,16 +245,61 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
|
||||
"""获取数据库会话
|
||||
|
||||
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
|
||||
这里只负责会话的创建和关闭,不自动提交
|
||||
事务策略说明
|
||||
============
|
||||
本项目采用**混合事务管理**策略:
|
||||
|
||||
1. **LLM 请求路径**:
|
||||
- 由 PluginMiddleware 统一管理事务
|
||||
- Service 层使用 db.flush() 使更改可见,但不提交
|
||||
- 请求结束时由中间件统一 commit 或 rollback
|
||||
- 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化
|
||||
|
||||
2. **管理后台 API**:
|
||||
- 路由层显式调用 db.commit()
|
||||
- 每个操作独立提交,不依赖中间件
|
||||
|
||||
3. **后台任务/调度器**:
|
||||
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||
- 自行管理事务生命周期
|
||||
|
||||
使用方式
|
||||
========
|
||||
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||
|
||||
注意事项
|
||||
========
|
||||
- 本函数不自动提交事务
|
||||
- 异常时会自动回滚
|
||||
- 中间件管理模式下,session 关闭由中间件负责
|
||||
"""
|
||||
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
|
||||
if request is not None:
|
||||
existing_db = getattr(getattr(request, "state", None), "db", None)
|
||||
if isinstance(existing_db, Session):
|
||||
yield existing_db
|
||||
return
|
||||
|
||||
# 确保引擎已初始化
|
||||
_ensure_engine()
|
||||
|
||||
db = _SessionLocal()
|
||||
|
||||
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state,
|
||||
# 并由中间件负责 commit/rollback/close(这里不关闭,避免流式响应提前释放会话)。
|
||||
managed_by_middleware = bool(
|
||||
request is not None
|
||||
and hasattr(request, "state")
|
||||
and getattr(request.state, "db_managed_by_middleware", False)
|
||||
)
|
||||
if managed_by_middleware:
|
||||
request.state.db = db
|
||||
db.info["managed_by_middleware"] = True
|
||||
|
||||
try:
|
||||
yield db
|
||||
# 不再自动 commit,由业务代码显式管理事务
|
||||
@@ -241,12 +311,13 @@ def get_db() -> Generator[Session, None, None]:
|
||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
db.close() # 确保连接返回池
|
||||
except Exception as close_error:
|
||||
# 记录关闭错误(如 IllegalStateChangeError)
|
||||
# 连接池会处理连接的回收
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
if not managed_by_middleware:
|
||||
try:
|
||||
db.close() # 确保连接返回池
|
||||
except Exception as close_error:
|
||||
# 记录关闭错误(如 IllegalStateChangeError)
|
||||
# 连接池会处理连接的回收
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
@@ -336,7 +407,7 @@ def init_admin_user(db: Session):
|
||||
admin.set_password(config.admin_password)
|
||||
|
||||
db.add(admin)
|
||||
db.commit() # 刷新以获取ID,但不提交
|
||||
db.flush() # 分配ID,但不提交事务(由外层 init_db 统一 commit)
|
||||
|
||||
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
|
||||
except Exception as e:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
采用模块化架构设计
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -39,14 +38,12 @@ async def initialize_providers():
|
||||
"""从数据库初始化提供商(仅用于日志记录)"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
from src.database.database import create_session
|
||||
from src.models.database import Provider
|
||||
|
||||
try:
|
||||
# 创建数据库会话
|
||||
db_gen = get_db()
|
||||
db: Session = next(db_gen)
|
||||
db: Session = create_session()
|
||||
|
||||
try:
|
||||
# 从数据库加载所有活跃的提供商
|
||||
@@ -75,7 +72,7 @@ async def initialize_providers():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("从数据库初始化提供商失败")
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.plugins.manager import get_plugin_manager
|
||||
from src.plugins.rate_limit.base import RateLimitResult
|
||||
|
||||
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
start_time = time.time()
|
||||
request.state.request_id = request.headers.get("x-request-id", "")
|
||||
request.state.start_time = start_time
|
||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||
request.state.db_managed_by_middleware = True
|
||||
|
||||
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
|
||||
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
|
||||
db_func = get_db
|
||||
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
|
||||
if get_db in request.app.dependency_overrides:
|
||||
db_func = request.app.dependency_overrides[get_db]
|
||||
logger.debug("Using overridden get_db from app.dependency_overrides")
|
||||
|
||||
# 创建数据库会话供需要的插件或后续处理使用
|
||||
db_gen = db_func()
|
||||
db = None
|
||||
response = None
|
||||
exception_to_raise = None
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
db = next(db_gen)
|
||||
request.state.db = db
|
||||
|
||||
# 1. 限流插件调用(可选功能)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 3. 提交关键数据库事务(在返回响应前)
|
||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
||||
try:
|
||||
db.commit()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
db.rollback()
|
||||
try:
|
||||
if isinstance(db, Session):
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
await self._call_error_plugins(request, commit_error, start_time)
|
||||
# 返回 500 错误,因为数据可能不一致
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
@@ -139,14 +132,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except RuntimeError as e:
|
||||
if str(e) == "No response returned.":
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error("Downstream handler completed without returning a response")
|
||||
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception:
|
||||
@@ -167,14 +164,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except Exception as e:
|
||||
# 回滚数据库事务
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 错误处理插件调用
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
# 尝试提交错误日志
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except:
|
||||
@@ -183,38 +184,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
exception_to_raise = e
|
||||
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
|
||||
if db is not None:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
# 检查会话是否可以安全地进行回滚
|
||||
# 只有当没有进行中的事务操作时才尝试回滚
|
||||
if db.is_active and not db.get_transaction().is_active:
|
||||
# 事务不在活跃状态,可以安全回滚
|
||||
pass
|
||||
elif db.is_active:
|
||||
# 事务在活跃状态,尝试回滚
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception as rollback_error:
|
||||
# 回滚失败(可能是 commit 正在进行中),忽略错误
|
||||
logger.debug(f"Rollback skipped: {rollback_error}")
|
||||
except Exception:
|
||||
# 检查状态时出错,忽略
|
||||
pass
|
||||
|
||||
# 通过触发生成器的 finally 块来关闭会话(标准模式)
|
||||
# 这会调用 get_db() 的 finally 块,执行 db.close()
|
||||
try:
|
||||
next(db_gen, None)
|
||||
except StopIteration:
|
||||
# 正常情况:生成器已耗尽
|
||||
pass
|
||||
except Exception as cleanup_error:
|
||||
# 忽略 IllegalStateChangeError 等清理错误
|
||||
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
|
||||
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
|
||||
logger.warning(f"Database cleanup warning: {cleanup_error}")
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
# 在 finally 块之后处理异常和响应
|
||||
if exception_to_raise:
|
||||
@@ -250,7 +226,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
return False
|
||||
|
||||
async def _get_rate_limit_key_and_config(
|
||||
self, request: Request, db: Session
|
||||
self, request: Request
|
||||
) -> tuple[Optional[str], Optional[int]]:
|
||||
"""
|
||||
获取速率限制的key和配置
|
||||
@@ -318,14 +294,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 如果没有限流插件,允许通过
|
||||
return None
|
||||
|
||||
# 获取数据库会话
|
||||
db = getattr(request.state, "db", None)
|
||||
if not db:
|
||||
logger.warning("速率限制检查:无法获取数据库会话")
|
||||
return None
|
||||
|
||||
# 获取速率限制的key和配置(从数据库)
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
|
||||
# 获取速率限制的 key 和配置
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
|
||||
if not key:
|
||||
# 不需要限流的端点(如未分类路径),静默跳过
|
||||
return None
|
||||
@@ -336,7 +306,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
key=key,
|
||||
endpoint=request.url.path,
|
||||
method=request.method,
|
||||
rate_limit=rate_limit_value, # 传入数据库配置的限制值
|
||||
rate_limit=rate_limit_value, # 传入配置的限制值
|
||||
)
|
||||
# 类型检查:确保返回的是RateLimitResult类型
|
||||
if isinstance(result, RateLimitResult):
|
||||
|
||||
@@ -107,20 +107,6 @@ class CreateProviderRequest(BaseModel):
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
v = f"https://{v}"
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("billing_type")
|
||||
@@ -195,19 +181,6 @@ class CreateEndpointRequest(BaseModel):
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
@field_validator("api_format")
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
|
||||
active_models_count: int = 0
|
||||
api_keys_count: int = 0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 模型管理 ==========
|
||||
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
|
||||
global_model_name: Optional[str] = None
|
||||
global_model_display_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ModelDetailResponse(BaseModel):
|
||||
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 系统设置 ==========
|
||||
|
||||
@@ -5,7 +5,7 @@ Provider API Key相关的API模型
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ProviderAPIKeyBase(BaseModel):
|
||||
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProviderAPIKeyStats(BaseModel):
|
||||
|
||||
@@ -27,8 +27,7 @@ from sqlalchemy import (
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
# ========== ProviderEndpoint CRUD ==========
|
||||
|
||||
@@ -45,24 +45,9 @@ class ProviderEndpointCreate(BaseModel):
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: str) -> str:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
@@ -83,27 +68,13 @@ class ProviderEndpointUpdate(BaseModel):
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
"""验证 API URL"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
@@ -141,8 +112,7 @@ class ProviderEndpointResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== ProviderAPIKey 相关(新架构) ==========
|
||||
@@ -384,8 +354,7 @@ class EndpointAPIKeyResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 健康监控相关 ==========
|
||||
@@ -535,8 +504,7 @@ class ProviderWithEndpointsSummary(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 健康监控可视化模型 ==========
|
||||
|
||||
@@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理)
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
@@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GlobalModelWithStats(GlobalModelResponse):
|
||||
|
||||
@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
|
||||
|
||||
try:
|
||||
# 验证JWT token
|
||||
payload = AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
logger.debug(f"JWT token验证成功, payload: {payload}")
|
||||
|
||||
# 从payload中提取用户信息
|
||||
|
||||
@@ -93,8 +93,8 @@ class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_email(db, email)
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
@@ -109,13 +109,10 @@ class AuthService:
|
||||
return None
|
||||
|
||||
# 更新最后登录时间
|
||||
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
|
||||
db_user = db.query(User).filter(User.id == user.id).first()
|
||||
if db_user:
|
||||
db_user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
@@ -198,7 +195,10 @@ class AuthService:
|
||||
if user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
if user.role.value >= required_role.value:
|
||||
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin')
|
||||
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
|
||||
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
|
||||
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
|
||||
return True
|
||||
|
||||
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
||||
@@ -230,7 +230,7 @@ class AuthService:
|
||||
)
|
||||
|
||||
if success:
|
||||
user_id = payload.get("sub")
|
||||
user_id = payload.get("user_id")
|
||||
logger.info(f"用户登出成功: user_id={user_id}")
|
||||
|
||||
return success
|
||||
|
||||
4
src/services/cache/aware_scheduler.py
vendored
4
src/services/cache/aware_scheduler.py
vendored
@@ -59,7 +59,6 @@ from src.services.health.monitor import health_monitor
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.rate_limit.adaptive_reservation import (
|
||||
AdaptiveReservationManager,
|
||||
ReservationResult,
|
||||
get_adaptive_reservation_manager,
|
||||
)
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
@@ -112,8 +111,6 @@ class CacheAwareScheduler:
|
||||
- 健康度监控
|
||||
"""
|
||||
|
||||
# 静态常量作为默认值(实际由 AdaptiveReservationManager 动态计算)
|
||||
CACHE_RESERVATION_RATIO = 0.3
|
||||
# 优先级模式常量
|
||||
PRIORITY_MODE_PROVIDER = "provider" # 提供商优先模式
|
||||
PRIORITY_MODE_GLOBAL_KEY = "global_key" # 全局 Key 优先模式
|
||||
@@ -1320,7 +1317,6 @@ class CacheAwareScheduler:
|
||||
|
||||
return {
|
||||
"scheduler": "cache_aware",
|
||||
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
|
||||
"dynamic_reservation": {
|
||||
"enabled": True,
|
||||
"config": reservation_stats["config"],
|
||||
|
||||
22
src/services/cache/model_cache.py
vendored
22
src/services/cache/model_cache.py
vendored
@@ -1,5 +1,21 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
|
||||
架构说明
|
||||
========
|
||||
本服务采用混合 async/sync 模式:
|
||||
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||
|
||||
设计决策
|
||||
--------
|
||||
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||
3. 调用方必须在 async 上下文中使用 await
|
||||
|
||||
使用示例
|
||||
--------
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
|
||||
|
||||
|
||||
class ModelCacheService:
|
||||
"""Model 映射缓存服务"""
|
||||
"""Model 映射缓存服务
|
||||
|
||||
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
|
||||
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||
"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.MODEL
|
||||
|
||||
254
src/services/cache/provider_cache.py
vendored
254
src/services/cache/provider_cache.py
vendored
@@ -1,254 +0,0 @@
|
||||
"""
|
||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheKeys, CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
class ProviderCacheService:
|
||||
"""Provider 配置缓存服务"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.PROVIDER
|
||||
|
||||
@staticmethod
|
||||
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
|
||||
"""
|
||||
获取 Provider(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
Provider 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.provider_by_id(provider_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Provider 缓存命中: {provider_id}")
|
||||
return ProviderCacheService._dict_to_provider(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if provider:
|
||||
provider_dict = ProviderCacheService._provider_to_dict(provider)
|
||||
await CacheService.set(
|
||||
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Provider 已缓存: {provider_id}")
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
|
||||
"""
|
||||
获取 Endpoint(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
endpoint_id: Endpoint ID
|
||||
|
||||
Returns:
|
||||
ProviderEndpoint 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
|
||||
return ProviderCacheService._dict_to_endpoint(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if endpoint:
|
||||
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
|
||||
await CacheService.set(
|
||||
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
|
||||
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
|
||||
"""
|
||||
获取 API Key(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
api_key_id: API Key ID
|
||||
|
||||
Returns:
|
||||
ProviderAPIKey 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.api_key_by_id(api_key_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"API Key 缓存命中: {api_key_id}")
|
||||
return ProviderCacheService._dict_to_api_key(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if api_key:
|
||||
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
|
||||
await CacheService.set(
|
||||
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"API Key 已缓存: {api_key_id}")
|
||||
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_provider_cache(provider_id: str):
|
||||
"""
|
||||
清除 Provider 缓存
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
|
||||
logger.debug(f"Provider 缓存已清除: {provider_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_endpoint_cache(endpoint_id: str):
|
||||
"""
|
||||
清除 Endpoint 缓存
|
||||
|
||||
Args:
|
||||
endpoint_id: Endpoint ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
|
||||
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_api_key_cache(api_key_id: str):
|
||||
"""
|
||||
清除 API Key 缓存
|
||||
|
||||
Args:
|
||||
api_key_id: API Key ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
|
||||
logger.debug(f"API Key 缓存已清除: {api_key_id}")
|
||||
|
||||
@staticmethod
|
||||
def _provider_to_dict(provider: Provider) -> dict:
|
||||
"""将 Provider 对象转换为字典(用于缓存)"""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"is_active": provider.is_active,
|
||||
"priority": provider.priority,
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"rpm_used": provider.rpm_used,
|
||||
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
||||
"config": provider.config,
|
||||
"description": provider.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_provider(provider_dict: dict) -> Provider:
|
||||
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
|
||||
from datetime import datetime
|
||||
|
||||
provider = Provider(
|
||||
id=provider_dict["id"],
|
||||
name=provider_dict["name"],
|
||||
api_format=provider_dict["api_format"],
|
||||
base_url=provider_dict.get("base_url"),
|
||||
is_active=provider_dict["is_active"],
|
||||
priority=provider_dict.get("priority", 0),
|
||||
rpm_limit=provider_dict.get("rpm_limit"),
|
||||
rpm_used=provider_dict.get("rpm_used", 0),
|
||||
config=provider_dict.get("config"),
|
||||
description=provider_dict.get("description"),
|
||||
)
|
||||
|
||||
if provider_dict.get("rpm_reset_at"):
|
||||
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
|
||||
"""将 Endpoint 对象转换为字典"""
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"provider_id": endpoint.provider_id,
|
||||
"name": endpoint.name,
|
||||
"base_url": endpoint.base_url,
|
||||
"is_active": endpoint.is_active,
|
||||
"priority": endpoint.priority,
|
||||
"weight": endpoint.weight,
|
||||
"custom_path": endpoint.custom_path,
|
||||
"config": endpoint.config,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
|
||||
"""从字典重建 Endpoint 对象"""
|
||||
endpoint = ProviderEndpoint(
|
||||
id=endpoint_dict["id"],
|
||||
provider_id=endpoint_dict["provider_id"],
|
||||
name=endpoint_dict["name"],
|
||||
base_url=endpoint_dict["base_url"],
|
||||
is_active=endpoint_dict["is_active"],
|
||||
priority=endpoint_dict.get("priority", 0),
|
||||
weight=endpoint_dict.get("weight", 1.0),
|
||||
custom_path=endpoint_dict.get("custom_path"),
|
||||
config=endpoint_dict.get("config"),
|
||||
)
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
|
||||
"""将 API Key 对象转换为字典"""
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"endpoint_id": api_key.endpoint_id,
|
||||
"key_value": api_key.key_value,
|
||||
"is_active": api_key.is_active,
|
||||
"max_rpm": api_key.max_rpm,
|
||||
"current_rpm": api_key.current_rpm,
|
||||
"health_score": api_key.health_score,
|
||||
"circuit_breaker_state": api_key.circuit_breaker_state,
|
||||
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
|
||||
"""从字典重建 API Key 对象"""
|
||||
api_key = ProviderAPIKey(
|
||||
id=api_key_dict["id"],
|
||||
endpoint_id=api_key_dict["endpoint_id"],
|
||||
key_value=api_key_dict["key_value"],
|
||||
is_active=api_key_dict["is_active"],
|
||||
max_rpm=api_key_dict.get("max_rpm"),
|
||||
current_rpm=api_key_dict.get("current_rpm", 0),
|
||||
health_score=api_key_dict.get("health_score", 1.0),
|
||||
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
|
||||
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
|
||||
)
|
||||
return api_key
|
||||
24
src/services/cache/user_cache.py
vendored
24
src/services/cache/user_cache.py
vendored
@@ -1,5 +1,22 @@
|
||||
"""
|
||||
用户缓存服务 - 减少数据库查询
|
||||
|
||||
架构说明
|
||||
========
|
||||
本服务采用混合 async/sync 模式:
|
||||
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||
|
||||
设计决策
|
||||
--------
|
||||
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||
3. 调用方必须在 async 上下文中使用 await
|
||||
|
||||
使用示例
|
||||
--------
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
await UserCacheService.invalidate_user_cache(user_id, email)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
@@ -12,9 +29,12 @@ from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
|
||||
|
||||
|
||||
class UserCacheService:
|
||||
"""用户缓存服务"""
|
||||
"""用户缓存服务
|
||||
|
||||
提供 User 的缓存查询功能,减少数据库访问。
|
||||
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||
"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.USER
|
||||
|
||||
@@ -69,24 +69,29 @@ class ErrorClassifier:
|
||||
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
|
||||
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
|
||||
# 这里主要用于匹配非标准格式或第三方代理的错误消息
|
||||
#
|
||||
# 重要:不要在此列表中包含 Provider Key 配置问题(如 invalid_api_key)
|
||||
# 这类错误应该触发故障转移,而不是直接返回给用户
|
||||
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
|
||||
"could not process image", # 图片处理失败
|
||||
"image too large", # 图片过大
|
||||
"invalid image", # 无效图片
|
||||
"unsupported image", # 不支持的图片格式
|
||||
"content_policy_violation", # 内容违规
|
||||
"invalid_api_key", # 无效的 API Key(不同于认证失败)
|
||||
"context_length_exceeded", # 上下文长度超限
|
||||
"content_length_limit", # 请求内容长度超限 (Claude API)
|
||||
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
|
||||
"max_tokens", # token 数超限
|
||||
"invalid_prompt", # 无效的提示词
|
||||
"content too long", # 内容过长
|
||||
"input is too long", # 输入过长 (AWS)
|
||||
"message is too long", # 消息过长
|
||||
"prompt is too long", # Prompt 超长(第三方代理常见格式)
|
||||
"image exceeds", # 图片超出限制
|
||||
"pdf too large", # PDF 过大
|
||||
"file too large", # 文件过大
|
||||
"tool_use_id", # tool_result 引用了不存在的 tool_use(兼容非标准代理)
|
||||
"validationexception", # AWS 验证异常
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -110,18 +115,124 @@ class ErrorClassifier:
|
||||
# 表示客户端错误的 error type(不区分大小写)
|
||||
# 这些 type 表明是请求本身的问题,不应重试
|
||||
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
|
||||
"invalid_request_error", # Claude/OpenAI 标准客户端错误类型
|
||||
"invalid_argument", # Gemini 参数错误
|
||||
"failed_precondition", # Gemini 前置条件错误
|
||||
# Claude/OpenAI 标准
|
||||
"invalid_request_error",
|
||||
# Gemini
|
||||
"invalid_argument",
|
||||
"failed_precondition",
|
||||
# AWS
|
||||
"validationexception",
|
||||
# 通用
|
||||
"validation_error",
|
||||
"bad_request",
|
||||
)
|
||||
|
||||
# 表示客户端错误的 reason/code 字段值
|
||||
CLIENT_ERROR_REASONS: Tuple[str, ...] = (
|
||||
"CONTENT_LENGTH_EXCEEDS_THRESHOLD",
|
||||
"CONTEXT_LENGTH_EXCEEDED",
|
||||
"MAX_TOKENS_EXCEEDED",
|
||||
"INVALID_CONTENT",
|
||||
"CONTENT_POLICY_VIOLATION",
|
||||
)
|
||||
|
||||
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析错误响应为结构化数据
|
||||
|
||||
支持多种格式:
|
||||
- {"error": {"type": "...", "message": "..."}} (Claude/OpenAI)
|
||||
- {"error": {"message": "...", "__type": "..."}} (AWS)
|
||||
- {"errorMessage": "..."} (Lambda)
|
||||
- {"error": "..."}
|
||||
- {"message": "...", "reason": "..."}
|
||||
|
||||
Returns:
|
||||
结构化的错误信息: {
|
||||
"type": str, # 错误类型
|
||||
"message": str, # 错误消息
|
||||
"reason": str, # 错误原因/代码
|
||||
"raw": str, # 原始文本
|
||||
}
|
||||
"""
|
||||
result = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
|
||||
|
||||
if not error_text:
|
||||
return result
|
||||
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
|
||||
# 格式 1: {"error": {"type": "...", "message": "..."}}
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_obj = data["error"]
|
||||
result["type"] = str(error_obj.get("type", ""))
|
||||
result["message"] = str(error_obj.get("message", ""))
|
||||
|
||||
# AWS 格式: {"error": {"__type": "...", "message": "...", "reason": "..."}}
|
||||
# __type 直接在 error 对象中,而不是嵌套在 message 里
|
||||
if "__type" in error_obj:
|
||||
result["type"] = result["type"] or str(error_obj.get("__type", ""))
|
||||
if "reason" in error_obj:
|
||||
result["reason"] = str(error_obj.get("reason", ""))
|
||||
if "code" in error_obj:
|
||||
result["reason"] = result["reason"] or str(error_obj.get("code", ""))
|
||||
|
||||
# 嵌套 JSON 格式: message 字段本身是 JSON 字符串
|
||||
# 支持多种嵌套格式:
|
||||
# - AWS: {"__type": "...", "message": "...", "reason": "..."}
|
||||
# - 第三方代理: {"error": {"type": "...", "message": "..."}}
|
||||
if result["message"].startswith("{"):
|
||||
try:
|
||||
nested = json.loads(result["message"])
|
||||
if isinstance(nested, dict):
|
||||
# AWS 格式
|
||||
if "__type" in nested:
|
||||
result["type"] = result["type"] or str(nested.get("__type", ""))
|
||||
result["message"] = str(nested.get("message", result["message"]))
|
||||
result["reason"] = str(nested.get("reason", ""))
|
||||
# 第三方代理格式: {"error": {"message": "..."}}
|
||||
elif isinstance(nested.get("error"), dict):
|
||||
inner_error = nested["error"]
|
||||
inner_msg = str(inner_error.get("message", ""))
|
||||
if inner_msg:
|
||||
result["message"] = inner_msg
|
||||
# 简单格式: {"message": "..."}
|
||||
elif "message" in nested:
|
||||
result["message"] = str(nested["message"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 格式 2: {"error": "..."}
|
||||
elif isinstance(data.get("error"), str):
|
||||
result["message"] = str(data["error"])
|
||||
|
||||
# 格式 3: {"errorMessage": "..."} (Lambda)
|
||||
elif "errorMessage" in data:
|
||||
result["message"] = str(data["errorMessage"])
|
||||
|
||||
# 格式 4: {"message": "...", "reason": "..."}
|
||||
elif "message" in data:
|
||||
result["message"] = str(data["message"])
|
||||
result["reason"] = str(data.get("reason", ""))
|
||||
|
||||
# 提取顶层的 reason/code
|
||||
if not result["reason"]:
|
||||
result["reason"] = str(data.get("reason", data.get("code", "")))
|
||||
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
|
||||
|
||||
return result
|
||||
|
||||
def _is_client_error(self, error_text: Optional[str]) -> bool:
|
||||
"""
|
||||
检测错误响应是否为客户端错误(不应重试)
|
||||
|
||||
判断逻辑:
|
||||
判断逻辑(按优先级):
|
||||
1. 检查 error.type 是否为已知的客户端错误类型
|
||||
2. 检查错误文本是否包含已知的客户端错误模式
|
||||
2. 检查 reason/code 是否为已知的客户端错误原因
|
||||
3. 回退到关键词匹配
|
||||
|
||||
Args:
|
||||
error_text: 错误响应文本
|
||||
@@ -132,67 +243,53 @@ class ErrorClassifier:
|
||||
if not error_text:
|
||||
return False
|
||||
|
||||
# 尝试解析 JSON 并检查 error type
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_type = data["error"].get("type", "")
|
||||
if error_type and any(
|
||||
t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
|
||||
):
|
||||
return True
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
pass
|
||||
parsed = self._parse_error_response(error_text)
|
||||
|
||||
# 回退到关键词匹配
|
||||
error_lower = error_text.lower()
|
||||
return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||
# 1. 检查 error type
|
||||
if parsed["type"]:
|
||||
error_type_lower = parsed["type"].lower()
|
||||
if any(t.lower() in error_type_lower for t in self.CLIENT_ERROR_TYPES):
|
||||
return True
|
||||
|
||||
# 2. 检查 reason/code
|
||||
if parsed["reason"]:
|
||||
reason_upper = parsed["reason"].upper()
|
||||
if any(r in reason_upper for r in self.CLIENT_ERROR_REASONS):
|
||||
return True
|
||||
|
||||
# 3. 回退到关键词匹配(合并 message 和 raw)
|
||||
search_text = f"{parsed['message']} {parsed['raw']}".lower()
|
||||
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||
|
||||
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
从错误响应中提取错误消息
|
||||
|
||||
支持格式:
|
||||
- {"error": {"message": "..."}} (OpenAI/Claude)
|
||||
- {"error": {"type": "...", "message": "..."}}
|
||||
- {"error": "..."}
|
||||
- {"message": "..."}
|
||||
|
||||
Args:
|
||||
error_text: 错误响应文本
|
||||
|
||||
Returns:
|
||||
提取的错误消息,如果无法解析则返回原始文本
|
||||
提取的错误消息
|
||||
"""
|
||||
if not error_text:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
parsed = self._parse_error_response(error_text)
|
||||
|
||||
# {"error": {"message": "..."}} 或 {"error": {"type": "...", "message": "..."}}
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_obj = data["error"]
|
||||
message = error_obj.get("message", "")
|
||||
error_type = error_obj.get("type", "")
|
||||
if message:
|
||||
if error_type:
|
||||
return f"{error_type}: {message}"
|
||||
return str(message)
|
||||
# 构建可读的错误消息
|
||||
parts = []
|
||||
if parsed["type"]:
|
||||
parts.append(parsed["type"])
|
||||
if parsed["reason"]:
|
||||
parts.append(f"[{parsed['reason']}]")
|
||||
if parsed["message"]:
|
||||
parts.append(parsed["message"])
|
||||
|
||||
# {"error": "..."}
|
||||
if isinstance(data.get("error"), str):
|
||||
return str(data["error"])
|
||||
|
||||
# {"message": "..."}
|
||||
if isinstance(data.get("message"), str):
|
||||
return str(data["message"])
|
||||
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
pass
|
||||
if parts:
|
||||
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# 无法解析,返回原始文本(截断)
|
||||
return error_text[:500] if len(error_text) > 500 else error_text
|
||||
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
|
||||
|
||||
def classify(
|
||||
self,
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
- 使用滑动窗口采样,容忍并发波动
|
||||
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
||||
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
||||
|
||||
AIMD 参数说明:
|
||||
- 扩容:加性增加 (+INCREASE_STEP)
|
||||
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
@@ -34,7 +38,7 @@ class AdaptiveConcurrencyManager:
|
||||
核心算法:基于滑动窗口利用率的 AIMD
|
||||
- 滑动窗口记录最近 N 次请求的利用率
|
||||
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
||||
- 遇到 429 错误时乘性减少 (*0.7)
|
||||
- 遇到 429 错误时乘性减少 (*0.85)
|
||||
- 长时间无 429 且有流量时触发探测性扩容
|
||||
|
||||
扩容条件(满足任一即可):
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta # noqa: F401 - kept for potential future use
|
||||
from typing import Optional, Tuple
|
||||
@@ -40,6 +39,7 @@ class ConcurrencyManager:
|
||||
self._memory_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._memory_endpoint_counts: dict[str, int] = {}
|
||||
self._memory_key_counts: dict[str, int] = {}
|
||||
self._owns_redis: bool = False
|
||||
self._memory_initialized = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@@ -47,41 +47,29 @@ class ConcurrencyManager:
|
||||
if self._redis is not None:
|
||||
return
|
||||
|
||||
# 优先使用 REDIS_URL,如果没有则根据密码构建 URL
|
||||
redis_url = os.getenv("REDIS_URL")
|
||||
|
||||
if not redis_url:
|
||||
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
|
||||
redis_password = os.getenv("REDIS_PASSWORD")
|
||||
if redis_password:
|
||||
redis_url = f"redis://:{redis_password}@localhost:6379/0"
|
||||
else:
|
||||
redis_url = "redis://localhost:6379/0"
|
||||
|
||||
try:
|
||||
self._redis = await aioredis.from_url(
|
||||
redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_timeout=5.0,
|
||||
socket_connect_timeout=5.0,
|
||||
)
|
||||
# 测试连接
|
||||
await self._redis.ping()
|
||||
# 脱敏显示(隐藏密码)
|
||||
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
|
||||
logger.info(f"[OK] Redis 连接成功: {safe_url}")
|
||||
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
self._redis = await get_redis_client(require_redis=False)
|
||||
self._owns_redis = False
|
||||
if self._redis:
|
||||
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
|
||||
else:
|
||||
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] Redis 连接失败: {e}")
|
||||
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
|
||||
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
|
||||
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
|
||||
self._redis = None
|
||||
self._owns_redis = False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Redis 连接"""
|
||||
if self._redis:
|
||||
if self._redis and self._owns_redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
logger.info("Redis 连接已关闭")
|
||||
logger.info("ConcurrencyManager Redis 连接已关闭")
|
||||
self._redis = None
|
||||
self._owns_redis = False
|
||||
|
||||
def _get_endpoint_key(self, endpoint_id: str) -> str:
|
||||
"""获取 Endpoint 并发计数的 Redis Key"""
|
||||
|
||||
@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -72,11 +72,7 @@ class RPMLimiter:
|
||||
# 获取当前分钟窗口
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now.replace(second=0, microsecond=0)
|
||||
window_end = (
|
||||
window_start.replace(minute=window_start.minute + 1)
|
||||
if window_start.minute < 59
|
||||
else window_start.replace(hour=window_start.hour + 1, minute=0)
|
||||
)
|
||||
window_end = window_start + timedelta(minutes=1)
|
||||
|
||||
# 查找或创建追踪记录
|
||||
tracking = (
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, AuditLog
|
||||
from src.utils.transaction_manager import transactional
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""审计服务"""
|
||||
"""审计服务
|
||||
|
||||
事务策略:本服务不负责事务提交,由中间件统一管理。
|
||||
所有方法只做 db.add/flush,提交由请求结束时的中间件处理。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@transactional(commit=False) # 不自动提交,让调用方决定
|
||||
def log_event(
|
||||
db: Session,
|
||||
event_type: AuditEventType,
|
||||
@@ -54,47 +56,44 @@ class AuditService:
|
||||
|
||||
Returns:
|
||||
审计日志记录
|
||||
|
||||
Note:
|
||||
不在此方法内提交事务,由调用方或中间件统一管理。
|
||||
"""
|
||||
try:
|
||||
audit_log = AuditLog(
|
||||
event_type=event_type.value,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
audit_log = AuditLog(
|
||||
event_type=event_type.value,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
db.refresh(audit_log)
|
||||
db.add(audit_log)
|
||||
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
|
||||
db.flush()
|
||||
|
||||
# 同时记录到系统日志
|
||||
log_message = (
|
||||
f"AUDIT [{event_type.value}] - {description} | "
|
||||
f"user_id={user_id}, ip={ip_address}"
|
||||
)
|
||||
# 同时记录到系统日志
|
||||
log_message = (
|
||||
f"AUDIT [{event_type.value}] - {description} | "
|
||||
f"user_id={user_id}, ip={ip_address}"
|
||||
)
|
||||
|
||||
if event_type in [
|
||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||
]:
|
||||
logger.warning(log_message)
|
||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||
logger.info(log_message)
|
||||
else:
|
||||
logger.debug(log_message)
|
||||
if event_type in [
|
||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||
]:
|
||||
logger.warning(log_message)
|
||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||
logger.info(log_message)
|
||||
else:
|
||||
logger.debug(log_message)
|
||||
|
||||
return audit_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
return audit_log
|
||||
|
||||
@staticmethod
|
||||
def log_login_attempt(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -41,7 +41,7 @@ async def get_current_user(
|
||||
try:
|
||||
# 验证Token格式和签名
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
except HTTPException as token_error:
|
||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
||||
@@ -144,7 +144,7 @@ async def get_current_user_from_header(
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
|
||||
363
tests/api/test_pipeline.py
Normal file
363
tests/api/test_pipeline.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
API Pipeline 测试
|
||||
|
||||
测试 ApiRequestPipeline 的核心功能:
|
||||
- 认证流程(API Key、JWT Token)
|
||||
- 配额计算
|
||||
- 审计日志记录
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
|
||||
|
||||
class TestPipelineQuotaCalculation:
|
||||
"""测试 Pipeline 配额计算"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试有配额限制时计算剩余配额"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 30.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining == 70.0
|
||||
|
||||
def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无配额限制时返回 None"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = None
|
||||
mock_user.used_usd = 30.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试负配额时返回 None"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = -1
|
||||
mock_user.used_usd = 0.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试配额已超时返回 0"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 150.0
|
||||
|
||||
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||
|
||||
assert remaining == 0.0
|
||||
|
||||
def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试用户为 None 时返回 None"""
|
||||
remaining = pipeline._calculate_quota_remaining(None)
|
||||
|
||||
assert remaining is None
|
||||
|
||||
|
||||
class TestPipelineAuditLogging:
|
||||
"""测试 Pipeline 审计日志"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试记录成功的审计事件"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
mock_adapter.audit_failure_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
with patch("time.time", return_value=1001.0):
|
||||
pipeline._record_audit_event(
|
||||
mock_context, mock_adapter, success=True, status_code=200
|
||||
)
|
||||
|
||||
mock_log.assert_called_once()
|
||||
call_kwargs = mock_log.call_args[1]
|
||||
assert call_kwargs["user_id"] == "user-123"
|
||||
assert call_kwargs["status_code"] == 200
|
||||
|
||||
def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试记录失败的审计事件"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
mock_adapter.audit_failure_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
with patch("time.time", return_value=1001.0):
|
||||
pipeline._record_audit_event(
|
||||
mock_context, mock_adapter, success=False, status_code=500, error="Internal error"
|
||||
)
|
||||
|
||||
mock_log.assert_called_once()
|
||||
call_kwargs = mock_log.call_args[1]
|
||||
assert call_kwargs["status_code"] == 500
|
||||
assert call_kwargs["error_message"] == "Internal error"
|
||||
|
||||
def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试没有数据库会话时跳过审计"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = None
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.audit_log_enabled = True
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
# 不应该抛出异常
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
# 不应该调用 log_event
|
||||
mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试审计日志被禁用时跳过"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.audit_log_enabled = False
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
) as mock_log:
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试审计日志异常不影响主流程"""
|
||||
mock_context = MagicMock()
|
||||
mock_context.db = MagicMock()
|
||||
mock_context.user = MagicMock()
|
||||
mock_context.user.id = "user-123"
|
||||
mock_context.api_key = MagicMock()
|
||||
mock_context.api_key.id = "key-123"
|
||||
mock_context.request_id = "req-123"
|
||||
mock_context.client_ip = "127.0.0.1"
|
||||
mock_context.user_agent = "test-agent"
|
||||
mock_context.request = MagicMock()
|
||||
mock_context.request.method = "POST"
|
||||
mock_context.request.url.path = "/v1/messages"
|
||||
mock_context.start_time = 1000.0
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test-adapter"
|
||||
mock_adapter.audit_log_enabled = True
|
||||
mock_adapter.audit_success_event = None
|
||||
|
||||
with patch.object(
|
||||
pipeline.audit_service,
|
||||
"log_event",
|
||||
side_effect=Exception("DB error"),
|
||||
):
|
||||
with patch("time.time", return_value=1001.0):
|
||||
# 不应该抛出异常
|
||||
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||
|
||||
|
||||
class TestPipelineAuthentication:
|
||||
"""测试 Pipeline 认证相关逻辑"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试缺少 API Key 时抛出异常"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value=None)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "API密钥" in exc_info.value.detail
|
||||
|
||||
def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无效的 API Key"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"Authorization": "Bearer sk-invalid"}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value="sk-invalid")
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"authenticate_api_key",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试配额超限时抛出异常"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 100.0
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.id = "key-123"
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"Authorization": "Bearer sk-test"}
|
||||
mock_request.url.path = "/v1/messages"
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.extract_api_key = MagicMock(return_value="sk-test")
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"authenticate_api_key",
|
||||
return_value=(mock_user, mock_api_key),
|
||||
):
|
||||
with patch.object(
|
||||
pipeline.usage_service,
|
||||
"check_user_quota",
|
||||
return_value=(False, "配额不足"),
|
||||
):
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
|
||||
with pytest.raises(QuotaExceededException):
|
||||
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||
|
||||
|
||||
class TestPipelineAdminAuth:
|
||||
"""测试管理员认证"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试缺少管理员令牌"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "管理员凭证" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试无效的管理员令牌"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "Bearer invalid-token"}
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
side_effect=HTTPException(status_code=401, detail="Invalid token"),
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试管理员认证成功"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "admin-123"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "Bearer valid-token"}
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"user_id": "admin-123"},
|
||||
):
|
||||
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
assert result == mock_user
|
||||
assert mock_request.state.user_id == "admin-123"
|
||||
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""服务层测试"""
|
||||
299
tests/services/test_auth.py
Normal file
299
tests/services/test_auth.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
认证服务测试
|
||||
|
||||
测试 AuthService 的核心功能:
|
||||
- JWT Token 创建和验证
|
||||
- 用户登录认证
|
||||
- API Key 认证
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from src.services.auth.service import (
|
||||
AuthService,
|
||||
JWT_SECRET_KEY,
|
||||
JWT_ALGORITHM,
|
||||
JWT_EXPIRATION_HOURS,
|
||||
)
|
||||
|
||||
|
||||
class TestJWTTokenCreation:
|
||||
"""测试 JWT Token 创建"""
|
||||
|
||||
def test_create_access_token_contains_required_fields(self) -> None:
|
||||
"""测试访问令牌包含必要字段"""
|
||||
data = {"sub": "user123", "email": "test@example.com"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
# 解码验证
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
assert payload["sub"] == "user123"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert payload["type"] == "access"
|
||||
assert "exp" in payload
|
||||
|
||||
def test_create_access_token_expiration(self) -> None:
|
||||
"""测试访问令牌过期时间正确"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证过期时间在预期范围内(允许1分钟误差)
|
||||
exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
|
||||
|
||||
assert abs((exp_time - expected_exp).total_seconds()) < 60
|
||||
|
||||
def test_create_refresh_token_type(self) -> None:
|
||||
"""测试刷新令牌类型正确"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_refresh_token(data)
|
||||
|
||||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_create_refresh_token_longer_expiration(self) -> None:
|
||||
"""测试刷新令牌过期时间更长"""
|
||||
data = {"sub": "user123"}
|
||||
access_token = AuthService.create_access_token(data)
|
||||
refresh_token = AuthService.create_refresh_token(data)
|
||||
|
||||
access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 刷新令牌应该比访问令牌过期时间更长
|
||||
assert refresh_payload["exp"] > access_payload["exp"]
|
||||
|
||||
|
||||
class TestJWTTokenVerification:
|
||||
"""测试 JWT Token 验证"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_valid_access_token(self) -> None:
|
||||
"""测试验证有效的访问令牌"""
|
||||
data = {"sub": "user123", "email": "test@example.com"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
|
||||
assert payload["sub"] == "user123"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_expired_token_raises_error(self) -> None:
|
||||
"""测试验证过期令牌抛出异常"""
|
||||
# 创建一个已过期的 token
|
||||
data = {"sub": "user123", "type": "access"}
|
||||
expire = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
data["exp"] = expire
|
||||
expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(expired_token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "过期" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_invalid_token_raises_error(self) -> None:
|
||||
"""测试验证无效令牌抛出异常"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token("invalid.token.here")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_wrong_token_type_raises_error(self) -> None:
|
||||
"""测试令牌类型不匹配抛出异常"""
|
||||
data = {"sub": "user123"}
|
||||
refresh_token = AuthService.create_refresh_token(data)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(refresh_token, token_type="access")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "类型错误" in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_blacklisted_token_raises_error(self) -> None:
|
||||
"""测试已撤销的令牌抛出异常"""
|
||||
data = {"sub": "user123"}
|
||||
token = AuthService.create_access_token(data)
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await AuthService.verify_token(token)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "撤销" in exc_info.value.detail
|
||||
|
||||
|
||||
class TestUserAuthentication:
|
||||
"""测试用户登录认证"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_success(self) -> None:
|
||||
"""测试用户登录成功"""
|
||||
# Mock 数据库和用户对象
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = True
|
||||
mock_user.verify_password.return_value = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch(
|
||||
"src.services.auth.service.UserCacheService.invalidate_user_cache",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
|
||||
assert result == mock_user
|
||||
mock_user.verify_password.assert_called_once_with("password123")
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_not_found(self) -> None:
|
||||
"""测试用户不存在"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_wrong_password(self) -> None:
|
||||
"""测试密码错误"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.verify_password.return_value = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_inactive(self) -> None:
|
||||
"""测试用户已禁用"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = False
|
||||
mock_user.verify_password.return_value = True
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAPIKeyAuthentication:
|
||||
"""测试 API Key 认证"""
|
||||
|
||||
def test_authenticate_api_key_success(self) -> None:
|
||||
"""测试 API Key 认证成功"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.email = "test@example.com"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = True
|
||||
mock_api_key.expires_at = None
|
||||
mock_api_key.user = mock_user
|
||||
mock_api_key.balance_used_usd = 0.0
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
with patch(
|
||||
"src.services.auth.service.ApiKeyService.check_balance",
|
||||
return_value=(True, 100.0),
|
||||
):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-test-key")
|
||||
|
||||
assert result is not None
|
||||
assert result[0] == mock_user
|
||||
assert result[1] == mock_api_key
|
||||
|
||||
def test_authenticate_api_key_not_found(self) -> None:
|
||||
"""测试 API Key 不存在"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_authenticate_api_key_inactive(self) -> None:
|
||||
"""测试 API Key 已禁用"""
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_authenticate_api_key_expired(self) -> None:
|
||||
"""测试 API Key 已过期"""
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_active = True
|
||||
mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||
mock_api_key
|
||||
)
|
||||
|
||||
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||
result = AuthService.authenticate_api_key(mock_db, "sk-expired-key")
|
||||
|
||||
assert result is None
|
||||
292
tests/services/test_usage_service.py
Normal file
292
tests/services/test_usage_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
UsageService 测试
|
||||
|
||||
测试用量统计服务的核心功能:
|
||||
- 成本计算
|
||||
- 配额检查
|
||||
- 用量统计查询
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
class TestCostCalculation:
|
||||
"""测试成本计算"""
|
||||
|
||||
def test_calculate_cost_basic(self) -> None:
|
||||
"""测试基础成本计算"""
|
||||
# 价格:输入 $3/1M, 输出 $15/1M
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
)
|
||||
|
||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result
|
||||
|
||||
# 1000 tokens * $3 / 1M = $0.003
|
||||
assert abs(input_cost - 0.003) < 0.0001
|
||||
# 500 tokens * $15 / 1M = $0.0075
|
||||
assert abs(output_cost - 0.0075) < 0.0001
|
||||
# Total = $0.003 + $0.0075 = $0.0105
|
||||
assert abs(total_cost - 0.0105) < 0.0001
|
||||
|
||||
def test_calculate_cost_with_cache(self) -> None:
|
||||
"""测试带缓存的成本计算"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
cache_creation_input_tokens=200,
|
||||
cache_read_input_tokens=300,
|
||||
cache_creation_price_per_1m=3.75, # 1.25x input price
|
||||
cache_read_price_per_1m=0.3, # 0.1x input price
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
# 验证缓存成本被计算
|
||||
assert cache_creation_cost > 0
|
||||
assert cache_read_cost > 0
|
||||
assert cache_cost == cache_creation_cost + cache_read_cost
|
||||
|
||||
def test_calculate_cost_with_request_price(self) -> None:
|
||||
"""测试按次计费"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
price_per_request=0.01,
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
assert request_cost == 0.01
|
||||
# Total 包含 request_cost
|
||||
assert total_cost == input_cost + output_cost + request_cost
|
||||
|
||||
def test_calculate_cost_zero_tokens(self) -> None:
|
||||
"""测试零 token 的成本计算"""
|
||||
result = UsageService.calculate_cost(
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
input_price_per_1m=3.0,
|
||||
output_price_per_1m=15.0,
|
||||
)
|
||||
|
||||
(
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
) = result
|
||||
|
||||
assert input_cost == 0
|
||||
assert output_cost == 0
|
||||
assert total_cost == 0
|
||||
|
||||
|
||||
class TestQuotaCheck:
|
||||
"""测试配额检查"""
|
||||
|
||||
def test_check_user_quota_sufficient(self) -> None:
|
||||
"""测试配额充足"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 30.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_user_quota_exceeded(self) -> None:
|
||||
"""测试配额超限(当有预估成本时)"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 99.0 # 接近配额上限
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 当预估成本超过剩余配额时应该返回 False
|
||||
is_ok, message = UsageService.check_user_quota(
|
||||
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||
)
|
||||
|
||||
assert is_ok is False
|
||||
assert "配额" in message
|
||||
|
||||
def test_check_user_quota_no_limit(self) -> None:
|
||||
"""测试无配额限制(None)"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = None
|
||||
mock_user.used_usd = 1000.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_user_quota_admin_bypass(self) -> None:
|
||||
"""测试管理员绕过配额检查"""
|
||||
from src.models.database import UserRole
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 0.0
|
||||
mock_user.used_usd = 1000.0
|
||||
mock_user.role = UserRole.ADMIN
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = False
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_standalone_api_key_balance(self) -> None:
|
||||
"""测试独立 API Key 余额检查"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 0.0
|
||||
mock_user.used_usd = 0.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = True
|
||||
mock_api_key.current_balance_usd = 50.0
|
||||
mock_api_key.balance_used_usd = 10.0
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||
|
||||
assert is_ok is True
|
||||
|
||||
def test_check_standalone_api_key_insufficient_balance(self) -> None:
|
||||
"""测试独立 API Key 余额不足"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.quota_usd = 100.0
|
||||
mock_user.used_usd = 0.0
|
||||
mock_user.role = MagicMock()
|
||||
mock_user.role.value = "user"
|
||||
|
||||
mock_api_key = MagicMock()
|
||||
mock_api_key.is_standalone = True
|
||||
mock_api_key.current_balance_usd = 10.0
|
||||
mock_api_key.balance_used_usd = 9.0 # 剩余 $1
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 需要 mock ApiKeyService.get_remaining_balance
|
||||
with patch(
|
||||
"src.services.user.apikey.ApiKeyService.get_remaining_balance",
|
||||
return_value=1.0,
|
||||
):
|
||||
# 预估成本 $5 超过剩余余额 $1
|
||||
is_ok, message = UsageService.check_user_quota(
|
||||
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||
)
|
||||
|
||||
assert is_ok is False
|
||||
|
||||
|
||||
class TestUsageStatistics:
|
||||
"""测试用量统计查询
|
||||
|
||||
注意:get_usage_summary 方法内部使用了数据库方言特定的日期函数,
|
||||
需要真实数据库或更复杂的 mock。这里只测试方法存在性。
|
||||
"""
|
||||
|
||||
def test_get_usage_summary_exists(self) -> None:
|
||||
"""测试 get_usage_summary 方法存在"""
|
||||
assert hasattr(UsageService, "get_usage_summary")
|
||||
assert callable(getattr(UsageService, "get_usage_summary"))
|
||||
|
||||
|
||||
class TestHelperMethods:
|
||||
"""测试辅助方法"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
|
||||
"""测试默认费率倍数"""
|
||||
mock_db = MagicMock()
|
||||
# 模拟未找到 provider_api_key
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||
|
||||
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||
mock_db, provider_api_key_id=None, provider_id=None
|
||||
)
|
||||
|
||||
assert rate_multiplier == 1.0
|
||||
assert is_free_tier is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
|
||||
"""测试从 ProviderAPIKey 获取费率倍数"""
|
||||
mock_provider_api_key = MagicMock()
|
||||
mock_provider_api_key.rate_multiplier = 0.8
|
||||
|
||||
mock_endpoint = MagicMock()
|
||||
mock_endpoint.provider_id = "provider-123"
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.billing_type = "standard"
|
||||
|
||||
mock_db = MagicMock()
|
||||
# 第一次查询返回 provider_api_key
|
||||
mock_db.query.return_value.filter.return_value.first.side_effect = [
|
||||
mock_provider_api_key,
|
||||
mock_endpoint,
|
||||
mock_provider,
|
||||
]
|
||||
|
||||
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||
mock_db, provider_api_key_id="pak-123", provider_id=None
|
||||
)
|
||||
|
||||
assert rate_multiplier == 0.8
|
||||
assert is_free_tier is False
|
||||
Reference in New Issue
Block a user