5 Commits

Author SHA1 Message Date
fawney19
21587449c8 fix: improve error classification and logging system
- Enhance error classifier to properly handle API key failures with fallback support
- Add error reason/code parsing for better AWS and multi-provider compatibility
- Improve error message structure detection for non-standard formats
- Refactor file logging with size-based rotation (100MB) instead of daily
- Optimize production logging by disabling backtrace and diagnose
- Clean up model validation and remove redundant configurations
2025-12-18 10:57:31 +08:00
fawney19
3d0ab353d3 refactor: migrate Pydantic Config to v2 ConfigDict 2025-12-18 02:20:53 +08:00
fawney19
b2a857c164 refactor: consolidate transaction management and remove legacy modules
- Remove unused context.py module (replaced by request.state)
- Remove provider_cache.py (no longer needed)
- Unify environment loading in config/settings.py instead of __init__.py
- Add deprecation warning for get_async_db() (consolidating on sync Session)
- Enhance database.py documentation with comprehensive transaction strategy
- Simplify audit logging to reuse request-level Session (no separate connections)
- Extract UsageService._build_usage_params() helper to reduce code duplication
- Update model and user cache implementations with refined transaction handling
- Remove unnecessary sessionmaker from pipeline
- Clean up audit service exception handling
2025-12-18 01:59:40 +08:00
fawney19
4d1d863916 refactor: improve authentication and user data handling
- Replace user cache queries with direct database queries to ensure data consistency
- Fix token_type parameter in verify_token calls (access token verification)
- Fix role-based permission check using dictionary ranking instead of string comparison
- Fix logout operation to use correct JWT claim name (user_id instead of sub)
- Simplify user authentication flow by removing unnecessary cache layer
- Optimize session initialization in main.py using create_session helper
- Remove unused imports and exception variables
2025-12-18 01:09:22 +08:00
fawney19
b579420690 refactor: optimize database session lifecycle and middleware architecture
- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
2025-12-18 00:35:46 +08:00
39 changed files with 1936 additions and 1256 deletions

View File

@@ -22,7 +22,7 @@
/>
</Transition>
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
<!-- 对话框内容 -->
<Transition
enter-active-class="duration-300 ease-out"

View File

@@ -3,10 +3,8 @@
A proxy server that enables AI models to work with multiple API providers.
"""
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# 注意: dotenv 加载已统一移至 src/config/settings.py
# 不要在此处重复加载
try:
from src._version import __version__

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class RequestTraceResponse(BaseModel):

View File

@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
token = authorization.replace("Bearer ", "").strip()
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:
return None

View File

@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override]
from ..models.database import SystemConfig
from src.models.database import SystemConfig
db = context.db
payload = context.ensure_json_body()

View File

@@ -5,13 +5,12 @@ from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
from src.models.database import ApiKey, AuditEventType, User, UserRole
from src.services.auth.service import AuthService
from src.services.cache.user_cache import UserCacheService
from src.services.system.audit import AuditService
from src.services.usage.service import UsageService
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的管理员令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
token = authorization.replace("Bearer ", "").strip()
try:
payload = await self.auth_service.verify_token(token)
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
raise
except Exception as exc:
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
if not user_id:
raise HTTPException(status_code=401, detail="无效的用户令牌")
# 使用缓存查询用户
user = await UserCacheService.get_user_by_id(db, user_id)
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
user = db.query(User).filter(User.id == user_id).first()
if not user or not user.is_active:
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
"""记录审计事件
事务策略:复用请求级 Session,不单独提交。
审计记录随主事务一起提交,由中间件统一管理。
"""
if not getattr(adapter, "audit_log_enabled", True):
return
bind = context.db.get_bind()
if bind is None:
if context.db is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
error=error,
)
SessionMaker = sessionmaker(bind=bind)
audit_session = SessionMaker()
try:
# 复用请求级 Session,不创建新的连接
# 审计记录随主事务一起提交,由中间件统一管理
self.audit_service.log_event(
db=audit_session,
db=context.db,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
error_message=error,
metadata=metadata,
)
audit_session.commit()
except Exception as exc:
audit_session.rollback()
# 审计失败不应影响主请求,仅记录警告
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
finally:
audit_session.close()
def _build_audit_metadata(
self,

View File

@@ -411,9 +411,10 @@ class BaseMessageHandler:
QuotaExceededException,
RateLimitException,
ModelNotSupportedException,
UpstreamClientException,
)
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
# 业务异常:简洁日志,不打印堆栈
logger.error(f"{message}: [{type(error).__name__}] {error}")
else:

View File

@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
if _redis_manager is None:
_redis_manager = RedisClientManager()
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
if _redis_manager.get_client() is None:
await _redis_manager.initialize(require_redis=require_redis)
return _redis_manager.get_client()

View File

@@ -41,8 +41,8 @@ class CacheSize:
class ConcurrencyDefaults:
"""并发控制默认值"""
# 自适应并发初始限制(保守值)
INITIAL_LIMIT = 3
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
INITIAL_LIMIT = 50
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
COOLDOWN_AFTER_429_MINUTES = 5
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
MIN_SAMPLES_FOR_DECISION = 5
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 1
INCREASE_STEP = 2
# 缩容乘数 - 遇到 429 时的缩容比例
DECREASE_MULTIPLIER = 0.7
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
# 0.85 表示降到触发 429 时并发数的 85%
DECREASE_MULTIPLIER = 0.85
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 100
MAX_CONCURRENT_LIMIT = 200
# 最小并发限制下限
MIN_CONCURRENT_LIMIT = 1
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
PROBE_INCREASE_MIN_REQUESTS = 10
# === 缓存用户预留比例 ===
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
# 0.1 表示缓存用户预留 10%,新用户可用 90%
CACHE_RESERVATION_RATIO = 0.1
class CircuitBreakerDefaults:
"""熔断器配置默认值(滑动窗口 + 半开状态模式)

View File

@@ -122,9 +122,9 @@ class Config:
# 并发控制配置
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(防止死锁)
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%)
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 10%,新用户可用 90%)
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))

View File

@@ -46,6 +46,11 @@ class BatchCommitter:
def mark_dirty(self, session: Session):
"""标记 Session 有待提交的更改"""
# 请求级事务由中间件统一 commit/rollback,避免后台任务在请求中途误提交。
if session is None:
return
if session.info.get("managed_by_middleware"):
return
self._pending_sessions.add(session)
async def _batch_commit_loop(self):

View File

@@ -1,168 +0,0 @@
"""
统一的请求上下文
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
这确保了数据在各层之间传递时不会丢失。
使用方式:
1. Pipeline 层创建 RequestContext
2. 各层通过 context 访问和更新信息
3. Adapter 层使用 context 记录 Usage
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class RequestContext:
"""
请求上下文 - 贯穿整个请求生命周期
设计原则:
1. 在请求开始时创建,包含所有已知信息
2. 在请求执行过程中逐步填充 Provider 信息
3. 在请求结束时用于记录 Usage
"""
# ==================== 请求标识 ====================
request_id: str
# ==================== 认证信息 ====================
user: Any # User model
api_key: Any # ApiKey model
db: Any # Database session
# ==================== 请求信息 ====================
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
model: str # 用户请求的模型名
is_stream: bool = False
# ==================== 原始请求 ====================
original_headers: Dict[str, str] = field(default_factory=dict)
original_body: Dict[str, Any] = field(default_factory=dict)
# ==================== 客户端信息 ====================
client_ip: str = "unknown"
user_agent: str = ""
# ==================== 计时 ====================
start_time: float = field(default_factory=time.time)
# ==================== Provider 信息(请求执行后填充)====================
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
provider_api_key_id: Optional[str] = None
# ==================== 模型映射信息 ====================
resolved_model: Optional[str] = None # 映射后的模型名
original_model: Optional[str] = None # 原始模型名(用于价格计算)
# ==================== 请求/响应头 ====================
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_response_headers: Dict[str, str] = field(default_factory=dict)
# ==================== 追踪信息 ====================
attempt_id: Optional[str] = None
# ==================== 能力需求 ====================
capability_requirements: Dict[str, bool] = field(default_factory=dict)
# 运行时计算的能力需求,来源于:
# 1. 用户 model_capability_settings
# 2. 用户 ApiKey.force_capabilities
# 3. 请求头 X-Require-Capability
# 4. 失败重试时动态添加
@classmethod
def create(
cls,
*,
db: Any,
user: Any,
api_key: Any,
api_format: str,
model: str,
is_stream: bool = False,
original_headers: Optional[Dict[str, str]] = None,
original_body: Optional[Dict[str, Any]] = None,
client_ip: str = "unknown",
user_agent: str = "",
request_id: Optional[str] = None,
) -> "RequestContext":
"""创建请求上下文"""
return cls(
request_id=request_id or str(uuid.uuid4()),
db=db,
user=user,
api_key=api_key,
api_format=api_format,
model=model,
is_stream=is_stream,
original_headers=original_headers or {},
original_body=original_body or {},
client_ip=client_ip,
user_agent=user_agent,
original_model=model, # 初始时原始模型等于请求模型
)
def update_provider_info(
self,
*,
provider_name: str,
provider_id: str,
endpoint_id: str,
provider_api_key_id: str,
resolved_model: Optional[str] = None,
) -> None:
"""更新 Provider 信息(请求执行后调用)"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.provider_api_key_id = provider_api_key_id
if resolved_model:
self.resolved_model = resolved_model
def update_headers(
self,
*,
request_headers: Optional[Dict[str, str]] = None,
response_headers: Optional[Dict[str, str]] = None,
) -> None:
"""更新请求/响应头"""
if request_headers:
self.provider_request_headers = request_headers
if response_headers:
self.provider_response_headers = response_headers
@property
def elapsed_ms(self) -> int:
"""计算已经过的时间(毫秒)"""
return int((time.time() - self.start_time) * 1000)
@property
def effective_model(self) -> str:
"""获取有效的模型名(映射后优先)"""
return self.resolved_model or self.model
@property
def billing_model(self) -> str:
"""获取计费模型名(原始模型优先)"""
return self.original_model or self.model
def to_metadata_dict(self) -> Dict[str, Any]:
"""转换为元数据字典(用于 Usage 记录)"""
return {
"api_format": self.api_format,
"provider": self.provider_name or "unknown",
"model": self.effective_model,
"original_model": self.billing_model,
"provider_id": self.provider_id,
"provider_endpoint_id": self.endpoint_id,
"provider_api_key_id": self.provider_api_key_id,
"provider_request_headers": self.provider_request_headers,
"provider_response_headers": self.provider_response_headers,
"attempt_id": self.attempt_id,
}

View File

@@ -9,7 +9,7 @@
输出策略:
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
- 文件: 始终保存 DEBUG 级别,保留30天,每日轮转
- 文件: 始终保存 DEBUG 级别,保留30天,按大小轮转 (100MB)
使用方式:
from src.core.logger import logger
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
if IS_DOCKER:
# 生产环境:禁用 backtrace 和 diagnose,减少日志噪音
logger.add(
sys.stdout,
format=CONSOLE_FORMAT_PROD,
level=LOG_LEVEL,
filter=_log_filter, # type: ignore[arg-type]
colorize=False,
backtrace=False,
diagnose=False,
)
else:
logger.add(
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
log_dir = PROJECT_ROOT / "logs"
log_dir.mkdir(exist_ok=True)
# 文件日志通用配置
file_log_config = {
"format": FILE_FORMAT,
"filter": _log_filter,
"rotation": "100 MB",
"retention": "30 days",
"compression": "gz",
"enqueue": True,
"encoding": "utf-8",
"catch": True,
}
# 生产环境禁用详细堆栈
if IS_DOCKER:
file_log_config["backtrace"] = False
file_log_config["diagnose"] = False
# 主日志文件 - 所有级别
logger.add(
log_dir / "app.log",
format=FILE_FORMAT,
level="DEBUG",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**file_log_config, # type: ignore[arg-type]
)
# 错误日志文件 - 仅 ERROR 及以上
error_log_config = file_log_config.copy()
error_log_config["rotation"] = "50 MB"
logger.add(
log_dir / "error.log",
format=FILE_FORMAT,
level="ERROR",
filter=_log_filter, # type: ignore[arg-type]
rotation="00:00",
retention="30 days",
compression="gz",
enqueue=True,
encoding="utf-8",
**error_log_config, # type: ignore[arg-type]
)
# ============================================================================

View File

@@ -5,6 +5,7 @@
import time
from typing import AsyncGenerator, Generator, Optional
from starlette.requests import Request
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
@@ -150,9 +151,22 @@ def _log_pool_capacity():
theoretical = config.db_pool_size + config.db_max_overflow
workers = max(1, config.worker_processes)
total_estimated = theoretical * workers
logger.info("数据库连接池配置")
if total_estimated > config.db_pool_warn_threshold:
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
safe_limit = config.pg_max_connections - config.pg_reserved_connections
logger.info(
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
config.db_pool_size,
config.db_max_overflow,
workers,
total_estimated,
safe_limit,
)
if total_estimated > safe_limit:
logger.warning(
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved)"
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
total_estimated,
safe_limit,
)
def _ensure_async_engine() -> AsyncEngine:
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
# 创建异步引擎
_async_engine = create_async_engine(
ASYNC_DATABASE_URL,
poolclass=QueuePool, # 使用队列连接池
# AsyncEngine 不能使用 QueuePool默认使用 AsyncAdaptedQueuePool
pool_size=config.db_pool_size,
max_overflow=config.db_max_overflow,
pool_timeout=config.db_pool_timeout,
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""获取异步数据库会话"""
"""获取异步数据库会话
.. deprecated::
此方法已废弃,项目统一使用同步 Session。
未来版本可能移除此方法。请使用 get_db() 代替。
"""
import warnings
warnings.warn(
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
DeprecationWarning,
stacklevel=2,
)
# 确保异步引擎已初始化
_ensure_async_engine()
@@ -220,16 +245,61 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
def get_db() -> Generator[Session, None, None]:
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
"""获取数据库会话
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
这里只负责会话的创建和关闭,不自动提交
事务策略说明
============
本项目采用**混合事务管理**策略:
1. **LLM 请求路径**
- 由 PluginMiddleware 统一管理事务
- Service 层使用 db.flush() 使更改可见,但不提交
- 请求结束时由中间件统一 commit 或 rollback
- 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化
2. **管理后台 API**
- 路由层显式调用 db.commit()
- 每个操作独立提交,不依赖中间件
3. **后台任务/调度器**
- 使用独立 Session(通过 create_session() 或 next(get_db()))
- 自行管理事务生命周期
使用方式
========
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
注意事项
========
- 本函数不自动提交事务
- 异常时会自动回滚
- 中间件管理模式下session 关闭由中间件负责
"""
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
if request is not None:
existing_db = getattr(getattr(request, "state", None), "db", None)
if isinstance(existing_db, Session):
yield existing_db
return
# 确保引擎已初始化
_ensure_engine()
db = _SessionLocal()
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state
# 并由中间件负责 commit/rollback/close,这里不关闭,避免流式响应提前释放会话
managed_by_middleware = bool(
request is not None
and hasattr(request, "state")
and getattr(request.state, "db_managed_by_middleware", False)
)
if managed_by_middleware:
request.state.db = db
db.info["managed_by_middleware"] = True
try:
yield db
# 不再自动 commit,由业务代码显式管理事务
@@ -241,12 +311,13 @@ def get_db() -> Generator[Session, None, None]:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
raise
finally:
try:
db.close() # 确保连接返回池
except Exception as close_error:
# 记录关闭错误(如 IllegalStateChangeError)
# 连接池会处理连接的回收
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
if not managed_by_middleware:
try:
db.close() # 确保连接返回池
except Exception as close_error:
# 记录关闭错误(如 IllegalStateChangeError)
# 连接池会处理连接的回收
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def create_session() -> Session:
@@ -336,7 +407,7 @@ def init_admin_user(db: Session):
admin.set_password(config.admin_password)
db.add(admin)
db.commit() # 刷新以获取ID但不提交
db.flush() # 分配ID但不提交事务(由外层 init_db 统一 commit)
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
except Exception as e:

View File

@@ -3,7 +3,6 @@
采用模块化架构设计
"""
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
@@ -39,14 +38,12 @@ async def initialize_providers():
"""从数据库初始化提供商(仅用于日志记录)"""
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.database import get_db
from src.database.database import create_session
from src.models.database import Provider
try:
# 创建数据库会话
db_gen = get_db()
db: Session = next(db_gen)
db: Session = create_session()
try:
# 从数据库加载所有活跃的提供商
@@ -75,7 +72,7 @@ async def initialize_providers():
finally:
db.close()
except Exception as e:
except Exception:
logger.exception("从数据库初始化提供商失败")

View File

@@ -14,7 +14,6 @@ from starlette.responses import Response as StarletteResponse
from src.config import config
from src.core.logger import logger
from src.database import get_db
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
start_time = time.time()
request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
request.state.db_managed_by_middleware = True
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
db_func = get_db
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
if get_db in request.app.dependency_overrides:
db_func = request.app.dependency_overrides[get_db]
logger.debug("Using overridden get_db from app.dependency_overrides")
# 创建数据库会话供需要的插件或后续处理使用
db_gen = db_func()
db = None
response = None
exception_to_raise = None
try:
# 获取数据库会话
db = next(db_gen)
request.state.db = db
# 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
# 3. 提交关键数据库事务(在返回响应前)
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
try:
db.commit()
db = getattr(request.state, "db", None)
if isinstance(db, Session):
db.commit()
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
db.rollback()
try:
if isinstance(db, Session):
db.rollback()
except Exception:
pass
await self._call_error_plugins(request, commit_error, start_time)
# 返回 500 错误,因为数据可能不一致
response = JSONResponse(
status_code=500,
@@ -139,14 +132,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
except RuntimeError as e:
if str(e) == "No response returned.":
if db:
db.rollback()
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time)
if db:
if isinstance(db, Session):
try:
db.commit()
except Exception:
@@ -167,14 +164,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
except Exception as e:
# 回滚数据库事务
if db:
db.rollback()
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
# 错误处理插件调用
await self._call_error_plugins(request, e, start_time)
# 尝试提交错误日志
if db:
if isinstance(db, Session):
try:
db.commit()
except:
@@ -183,38 +184,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
exception_to_raise = e
finally:
# 确保数据库会话被正确关闭
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
if db is not None:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
# 检查会话是否可以安全地进行回滚
# 只有当没有进行中的事务操作时才尝试回滚
if db.is_active and not db.get_transaction().is_active:
# 事务不在活跃状态,可以安全回滚
pass
elif db.is_active:
# 事务在活跃状态,尝试回滚
try:
db.rollback()
except Exception as rollback_error:
# 回滚失败(可能是 commit 正在进行中),忽略错误
logger.debug(f"Rollback skipped: {rollback_error}")
except Exception:
# 检查状态时出错,忽略
pass
# 通过触发生成器的 finally 块来关闭会话(标准模式)
# 这会调用 get_db() 的 finally 块,执行 db.close()
try:
next(db_gen, None)
except StopIteration:
# 正常情况:生成器已耗尽
pass
except Exception as cleanup_error:
# 忽略 IllegalStateChangeError 等清理错误
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
logger.warning(f"Database cleanup warning: {cleanup_error}")
db.close()
except Exception as close_error:
# 连接池会处理连接的回收,这里的异常不应影响响应
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
# 在 finally 块之后处理异常和响应
if exception_to_raise:
@@ -250,7 +226,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
return False
async def _get_rate_limit_key_and_config(
self, request: Request, db: Session
self, request: Request
) -> tuple[Optional[str], Optional[int]]:
"""
获取速率限制的key和配置
@@ -318,14 +294,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
# 如果没有限流插件,允许通过
return None
# 获取数据库会话
db = getattr(request.state, "db", None)
if not db:
logger.warning("速率限制检查:无法获取数据库会话")
return None
# 获取速率限制的key和配置从数据库
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
# 获取速率限制的 key 和配置
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
if not key:
# 不需要限流的端点(如未分类路径),静默跳过
return None
@@ -336,7 +306,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
key=key,
endpoint=request.url.path,
method=request.method,
rate_limit=rate_limit_value, # 传入数据库配置的限制值
rate_limit=rate_limit_value, # 传入配置的限制值
)
# 类型检查确保返回的是RateLimitResult类型
if isinstance(result, RateLimitResult):

View File

@@ -107,20 +107,6 @@ class CreateProviderRequest(BaseModel):
if not re.match(r"^https?://", v, re.IGNORECASE):
v = f"https://{v}"
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v
@field_validator("billing_type")
@@ -195,19 +181,6 @@ class CreateEndpointRequest(BaseModel):
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
@field_validator("api_format")

View File

@@ -6,7 +6,7 @@ import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from ..core.enums import UserRole
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
active_models_count: int = 0
api_keys_count: int = 0
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== 模型管理 ==========
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
global_model_name: Optional[str] = None
global_model_display_name: Optional[str] = None
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ModelDetailResponse(BaseModel):
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== 系统设置 ==========

View File

@@ -5,7 +5,7 @@ Provider API Key相关的API模型
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class ProviderAPIKeyBase(BaseModel):
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class ProviderAPIKeyStats(BaseModel):

View File

@@ -27,8 +27,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship
from ..config import config
from ..core.enums import ProviderBillingType, UserRole

View File

@@ -6,7 +6,7 @@ import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
# ========== ProviderEndpoint CRUD ==========
@@ -45,24 +45,9 @@ class ProviderEndpointCreate(BaseModel):
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: str) -> str:
"""验证 API URLSSRF 防护)"""
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
@@ -83,27 +68,13 @@ class ProviderEndpointUpdate(BaseModel):
@field_validator("base_url")
@classmethod
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
"""验证 API URLSSRF 防护)"""
"""验证 API URL"""
if v is None:
return v
if not re.match(r"^https?://", v, re.IGNORECASE):
raise ValueError("URL 必须以 http:// 或 https:// 开头")
# 防止 SSRF 攻击:禁止内网地址
forbidden_patterns = [
r"localhost",
r"127\.0\.0\.1",
r"0\.0\.0\.0",
r"192\.168\.",
r"10\.",
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
r"169\.254\.",
]
for pattern in forbidden_patterns:
if re.search(pattern, v, re.IGNORECASE):
raise ValueError("不允许使用内网地址")
return v.rstrip("/") # 移除末尾斜杠
@@ -141,8 +112,7 @@ class ProviderEndpointResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== ProviderAPIKey 相关(新架构) ==========
@@ -384,8 +354,7 @@ class EndpointAPIKeyResponse(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== 健康监控相关 ==========
@@ -535,8 +504,7 @@ class ProviderWithEndpointsSummary(BaseModel):
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
# ========== 健康监控可视化模型 ==========

View File

@@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理)
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
# ========== 阶梯计费相关模型 ==========
@@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel):
created_at: datetime
updated_at: Optional[datetime]
class Config:
from_attributes = True
model_config = ConfigDict(from_attributes=True)
class GlobalModelWithStats(GlobalModelResponse):

View File

@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
try:
# 验证JWT token
payload = AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
logger.debug(f"JWT token验证成功, payload: {payload}")
# 从payload中提取用户信息

View File

@@ -93,8 +93,8 @@ class AuthService:
@staticmethod
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""用户登录认证"""
# 使用缓存查询用户
user = await UserCacheService.get_user_by_email(db, email)
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
user = db.query(User).filter(User.email == email).first()
if not user:
logger.warning(f"登录失败 - 用户不存在: {email}")
@@ -109,13 +109,10 @@ class AuthService:
return None
# 更新最后登录时间
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
db_user = db.query(User).filter(User.id == user.id).first()
if db_user:
db_user.last_login_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
user.last_login_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁
# 清除缓存,因为用户信息已更新
await UserCacheService.invalidate_user_cache(user.id, user.email)
logger.info(f"用户登录成功: {email} (ID: {user.id})")
return user
@@ -198,7 +195,10 @@ class AuthService:
if user.role == UserRole.ADMIN:
return True
if user.role.value >= required_role.value:
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin' 为 True)
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
return True
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
@@ -230,7 +230,7 @@ class AuthService:
)
if success:
user_id = payload.get("sub")
user_id = payload.get("user_id")
logger.info(f"用户登出成功: user_id={user_id}")
return success

View File

@@ -59,7 +59,6 @@ from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_reservation import (
AdaptiveReservationManager,
ReservationResult,
get_adaptive_reservation_manager,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
@@ -112,8 +111,6 @@ class CacheAwareScheduler:
- 健康度监控
"""
# 静态常量作为默认值(实际由 AdaptiveReservationManager 动态计算)
CACHE_RESERVATION_RATIO = 0.3
# 优先级模式常量
PRIORITY_MODE_PROVIDER = "provider" # 提供商优先模式
PRIORITY_MODE_GLOBAL_KEY = "global_key" # 全局 Key 优先模式
@@ -1320,7 +1317,6 @@ class CacheAwareScheduler:
return {
"scheduler": "cache_aware",
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
"dynamic_reservation": {
"enabled": True,
"config": reservation_stats["config"],

View File

@@ -1,5 +1,21 @@
"""
Model 映射缓存服务 - 减少模型查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作CacheService真正的 async使用 aioredis
- 数据库查询db.query同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
"""
import time
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
class ModelCacheService:
"""Model 映射缓存服务"""
"""Model 映射缓存服务
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
所有公开方法均为 async需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.MODEL

View File

@@ -1,254 +0,0 @@
"""
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
"""
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
class ProviderCacheService:
    """Provider configuration cache service.

    Caches Provider / ProviderEndpoint / ProviderAPIKey lookups in the shared
    CacheService to reduce database queries.  On a cache hit the object is
    rebuilt from the cached dict, so it is a DETACHED instance -- it is NOT
    attached to the SQLAlchemy Session passed in by the caller.
    """

    # Cache TTL - uses the shared constant
    CACHE_TTL = CacheTTL.PROVIDER

    @staticmethod
    async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
        """
        Fetch a Provider (cached).

        Args:
            db: database session
            provider_id: Provider ID

        Returns:
            Provider object (detached if served from cache) or None
        """
        cache_key = CacheKeys.provider_by_id(provider_id)
        # 1. Try the cache first
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Provider 缓存命中: {provider_id}")
            return ProviderCacheService._dict_to_provider(cached_data)
        # 2. Cache miss - query the database
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        # 3. Populate the cache (only existing rows are cached; misses are not)
        if provider:
            provider_dict = ProviderCacheService._provider_to_dict(provider)
            await CacheService.set(
                cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Provider 已缓存: {provider_id}")
        return provider

    @staticmethod
    async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
        """
        Fetch an Endpoint (cached).

        Args:
            db: database session
            endpoint_id: Endpoint ID

        Returns:
            ProviderEndpoint object (detached if served from cache) or None
        """
        cache_key = CacheKeys.endpoint_by_id(endpoint_id)
        # 1. Try the cache first
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
            return ProviderCacheService._dict_to_endpoint(cached_data)
        # 2. Cache miss - query the database
        endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
        # 3. Populate the cache
        if endpoint:
            endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
            await CacheService.set(
                cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Endpoint 已缓存: {endpoint_id}")
        return endpoint

    @staticmethod
    async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
        """
        Fetch an API Key (cached).

        Args:
            db: database session
            api_key_id: API Key ID

        Returns:
            ProviderAPIKey object (detached if served from cache) or None
        """
        cache_key = CacheKeys.api_key_by_id(api_key_id)
        # 1. Try the cache first
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"API Key 缓存命中: {api_key_id}")
            return ProviderCacheService._dict_to_api_key(cached_data)
        # 2. Cache miss - query the database
        api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
        # 3. Populate the cache
        if api_key:
            api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
            await CacheService.set(
                cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"API Key 已缓存: {api_key_id}")
        return api_key

    @staticmethod
    async def invalidate_provider_cache(provider_id: str) -> None:
        """
        Evict a Provider from the cache.

        Args:
            provider_id: Provider ID
        """
        await CacheService.delete(CacheKeys.provider_by_id(provider_id))
        logger.debug(f"Provider 缓存已清除: {provider_id}")

    @staticmethod
    async def invalidate_endpoint_cache(endpoint_id: str) -> None:
        """
        Evict an Endpoint from the cache.

        Args:
            endpoint_id: Endpoint ID
        """
        await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
        logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")

    @staticmethod
    async def invalidate_api_key_cache(api_key_id: str) -> None:
        """
        Evict an API Key from the cache.

        Args:
            api_key_id: API Key ID
        """
        await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
        logger.debug(f"API Key 缓存已清除: {api_key_id}")

    @staticmethod
    def _provider_to_dict(provider: Provider) -> dict:
        """Serialize a Provider into a cache-friendly dict (datetime -> ISO string)."""
        return {
            "id": provider.id,
            "name": provider.name,
            "api_format": provider.api_format,
            "base_url": provider.base_url,
            "is_active": provider.is_active,
            "priority": provider.priority,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
            "config": provider.config,
            "description": provider.description,
        }

    @staticmethod
    def _dict_to_provider(provider_dict: dict) -> Provider:
        """Rebuild a Provider from a cached dict (detached object, not in any Session)."""
        from datetime import datetime

        provider = Provider(
            id=provider_dict["id"],
            name=provider_dict["name"],
            api_format=provider_dict["api_format"],
            base_url=provider_dict.get("base_url"),
            is_active=provider_dict["is_active"],
            priority=provider_dict.get("priority", 0),
            rpm_limit=provider_dict.get("rpm_limit"),
            rpm_used=provider_dict.get("rpm_used", 0),
            config=provider_dict.get("config"),
            description=provider_dict.get("description"),
        )
        # rpm_reset_at was stored as an ISO string; restore it as a datetime
        if provider_dict.get("rpm_reset_at"):
            provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
        return provider

    @staticmethod
    def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
        """Serialize an Endpoint into a cache-friendly dict."""
        return {
            "id": endpoint.id,
            "provider_id": endpoint.provider_id,
            "name": endpoint.name,
            "base_url": endpoint.base_url,
            "is_active": endpoint.is_active,
            "priority": endpoint.priority,
            "weight": endpoint.weight,
            "custom_path": endpoint.custom_path,
            "config": endpoint.config,
        }

    @staticmethod
    def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
        """Rebuild an Endpoint from a cached dict (detached object)."""
        endpoint = ProviderEndpoint(
            id=endpoint_dict["id"],
            provider_id=endpoint_dict["provider_id"],
            name=endpoint_dict["name"],
            base_url=endpoint_dict["base_url"],
            is_active=endpoint_dict["is_active"],
            priority=endpoint_dict.get("priority", 0),
            weight=endpoint_dict.get("weight", 1.0),
            custom_path=endpoint_dict.get("custom_path"),
            config=endpoint_dict.get("config"),
        )
        return endpoint

    @staticmethod
    def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
        """Serialize an API Key into a cache-friendly dict."""
        return {
            "id": api_key.id,
            "endpoint_id": api_key.endpoint_id,
            "key_value": api_key.key_value,
            "is_active": api_key.is_active,
            "max_rpm": api_key.max_rpm,
            "current_rpm": api_key.current_rpm,
            "health_score": api_key.health_score,
            "circuit_breaker_state": api_key.circuit_breaker_state,
            "adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
        }

    @staticmethod
    def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
        """Rebuild an API Key from a cached dict (detached object)."""
        api_key = ProviderAPIKey(
            id=api_key_dict["id"],
            endpoint_id=api_key_dict["endpoint_id"],
            key_value=api_key_dict["key_value"],
            is_active=api_key_dict["is_active"],
            max_rpm=api_key_dict.get("max_rpm"),
            current_rpm=api_key_dict.get("current_rpm", 0),
            health_score=api_key_dict.get("health_score", 1.0),
            circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
            adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
        )
        return api_key

View File

@@ -1,5 +1,22 @@
"""
用户缓存服务 - 减少数据库查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作CacheService真正的 async使用 aioredis
- 数据库查询db.query同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
user = await UserCacheService.get_user_by_id(db, user_id)
await UserCacheService.invalidate_user_cache(user_id, email)
"""
from typing import Optional
@@ -12,9 +29,12 @@ from src.core.logger import logger
from src.models.database import User
class UserCacheService:
"""用户缓存服务"""
"""用户缓存服务
提供 User 的缓存查询功能,减少数据库访问。
所有公开方法均为 async需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.USER

View File

@@ -69,24 +69,29 @@ class ErrorClassifier:
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
# 这里主要用于匹配非标准格式或第三方代理的错误消息
#
# 重要:不要在此列表中包含 Provider Key 配置问题(如 invalid_api_key
# 这类错误应该触发故障转移,而不是直接返回给用户
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
"could not process image", # 图片处理失败
"image too large", # 图片过大
"invalid image", # 无效图片
"unsupported image", # 不支持的图片格式
"content_policy_violation", # 内容违规
"invalid_api_key", # 无效的 API Key不同于认证失败
"context_length_exceeded", # 上下文长度超限
"content_length_limit", # 请求内容长度超限 (Claude API)
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
"max_tokens", # token 数超限
"invalid_prompt", # 无效的提示词
"content too long", # 内容过长
"input is too long", # 输入过长 (AWS)
"message is too long", # 消息过长
"prompt is too long", # Prompt 超长(第三方代理常见格式)
"image exceeds", # 图片超出限制
"pdf too large", # PDF 过大
"file too large", # 文件过大
"tool_use_id", # tool_result 引用了不存在的 tool_use兼容非标准代理
"validationexception", # AWS 验证异常
)
def __init__(
@@ -110,18 +115,124 @@ class ErrorClassifier:
# 表示客户端错误的 error type不区分大小写
# 这些 type 表明是请求本身的问题,不应重试
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
"invalid_request_error", # Claude/OpenAI 标准客户端错误类型
"invalid_argument", # Gemini 参数错误
"failed_precondition", # Gemini 前置条件错误
# Claude/OpenAI 标准
"invalid_request_error",
# Gemini
"invalid_argument",
"failed_precondition",
# AWS
"validationexception",
# 通用
"validation_error",
"bad_request",
)
# 表示客户端错误的 reason/code 字段值
CLIENT_ERROR_REASONS: Tuple[str, ...] = (
"CONTENT_LENGTH_EXCEEDS_THRESHOLD",
"CONTEXT_LENGTH_EXCEEDED",
"MAX_TOKENS_EXCEEDED",
"INVALID_CONTENT",
"CONTENT_POLICY_VIOLATION",
)
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
"""
解析错误响应为结构化数据
支持多种格式:
- {"error": {"type": "...", "message": "..."}} (Claude/OpenAI)
- {"error": {"message": "...", "__type": "..."}} (AWS)
- {"errorMessage": "..."} (Lambda)
- {"error": "..."}
- {"message": "...", "reason": "..."}
Returns:
结构化的错误信息: {
"type": str, # 错误类型
"message": str, # 错误消息
"reason": str, # 错误原因/代码
"raw": str, # 原始文本
}
"""
result = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
if not error_text:
return result
try:
data = json.loads(error_text)
# 格式 1: {"error": {"type": "...", "message": "..."}}
if isinstance(data.get("error"), dict):
error_obj = data["error"]
result["type"] = str(error_obj.get("type", ""))
result["message"] = str(error_obj.get("message", ""))
# AWS 格式: {"error": {"__type": "...", "message": "...", "reason": "..."}}
# __type 直接在 error 对象中,而不是嵌套在 message 里
if "__type" in error_obj:
result["type"] = result["type"] or str(error_obj.get("__type", ""))
if "reason" in error_obj:
result["reason"] = str(error_obj.get("reason", ""))
if "code" in error_obj:
result["reason"] = result["reason"] or str(error_obj.get("code", ""))
# 嵌套 JSON 格式: message 字段本身是 JSON 字符串
# 支持多种嵌套格式:
# - AWS: {"__type": "...", "message": "...", "reason": "..."}
# - 第三方代理: {"error": {"type": "...", "message": "..."}}
if result["message"].startswith("{"):
try:
nested = json.loads(result["message"])
if isinstance(nested, dict):
# AWS 格式
if "__type" in nested:
result["type"] = result["type"] or str(nested.get("__type", ""))
result["message"] = str(nested.get("message", result["message"]))
result["reason"] = str(nested.get("reason", ""))
# 第三方代理格式: {"error": {"message": "..."}}
elif isinstance(nested.get("error"), dict):
inner_error = nested["error"]
inner_msg = str(inner_error.get("message", ""))
if inner_msg:
result["message"] = inner_msg
# 简单格式: {"message": "..."}
elif "message" in nested:
result["message"] = str(nested["message"])
except json.JSONDecodeError:
pass
# 格式 2: {"error": "..."}
elif isinstance(data.get("error"), str):
result["message"] = str(data["error"])
# 格式 3: {"errorMessage": "..."} (Lambda)
elif "errorMessage" in data:
result["message"] = str(data["errorMessage"])
# 格式 4: {"message": "...", "reason": "..."}
elif "message" in data:
result["message"] = str(data["message"])
result["reason"] = str(data.get("reason", ""))
# 提取顶层的 reason/code
if not result["reason"]:
result["reason"] = str(data.get("reason", data.get("code", "")))
except (json.JSONDecodeError, TypeError, KeyError):
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
return result
def _is_client_error(self, error_text: Optional[str]) -> bool:
"""
检测错误响应是否为客户端错误(不应重试)
判断逻辑:
判断逻辑(按优先级)
1. 检查 error.type 是否为已知的客户端错误类型
2. 检查错误文本是否包含已知的客户端错误模式
2. 检查 reason/code 是否为已知的客户端错误原因
3. 回退到关键词匹配
Args:
error_text: 错误响应文本
@@ -132,67 +243,53 @@ class ErrorClassifier:
if not error_text:
return False
# 尝试解析 JSON 并检查 error type
try:
data = json.loads(error_text)
if isinstance(data.get("error"), dict):
error_type = data["error"].get("type", "")
if error_type and any(
t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
):
return True
except (json.JSONDecodeError, TypeError, KeyError):
pass
parsed = self._parse_error_response(error_text)
# 回退到关键词匹配
error_lower = error_text.lower()
return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
# 1. 检查 error type
if parsed["type"]:
error_type_lower = parsed["type"].lower()
if any(t.lower() in error_type_lower for t in self.CLIENT_ERROR_TYPES):
return True
# 2. 检查 reason/code
if parsed["reason"]:
reason_upper = parsed["reason"].upper()
if any(r in reason_upper for r in self.CLIENT_ERROR_REASONS):
return True
# 3. 回退到关键词匹配(合并 message 和 raw
search_text = f"{parsed['message']} {parsed['raw']}".lower()
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
"""
从错误响应中提取错误消息
支持格式:
- {"error": {"message": "..."}} (OpenAI/Claude)
- {"error": {"type": "...", "message": "..."}}
- {"error": "..."}
- {"message": "..."}
Args:
error_text: 错误响应文本
Returns:
提取的错误消息,如果无法解析则返回原始文本
提取的错误消息
"""
if not error_text:
return None
try:
data = json.loads(error_text)
parsed = self._parse_error_response(error_text)
# {"error": {"message": "..."}} 或 {"error": {"type": "...", "message": "..."}}
if isinstance(data.get("error"), dict):
error_obj = data["error"]
message = error_obj.get("message", "")
error_type = error_obj.get("type", "")
if message:
if error_type:
return f"{error_type}: {message}"
return str(message)
# 构建可读的错误消息
parts = []
if parsed["type"]:
parts.append(parsed["type"])
if parsed["reason"]:
parts.append(f"[{parsed['reason']}]")
if parsed["message"]:
parts.append(parsed["message"])
# {"error": "..."}
if isinstance(data.get("error"), str):
return str(data["error"])
# {"message": "..."}
if isinstance(data.get("message"), str):
return str(data["message"])
except (json.JSONDecodeError, TypeError, KeyError):
pass
if parts:
return ": ".join(parts) if len(parts) > 1 else parts[0]
# 无法解析,返回原始文本(截断)
return error_text[:500] if len(error_text) > 500 else error_text
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
def classify(
self,

View File

@@ -5,6 +5,10 @@
- 使用滑动窗口采样,容忍并发波动
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
AIMD 参数说明:
- 扩容:加性增加 (+INCREASE_STEP)
- 缩容:乘性减少 (*DECREASE_MULTIPLIER默认 0.85)
"""
from datetime import datetime, timezone
@@ -34,7 +38,7 @@ class AdaptiveConcurrencyManager:
核心算法:基于滑动窗口利用率的 AIMD
- 滑动窗口记录最近 N 次请求的利用率
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
- 遇到 429 错误时乘性减少 (*0.7)
- 遇到 429 错误时乘性减少 (*0.85)
- 长时间无 429 且有流量时触发探测性扩容
扩容条件(满足任一即可):

View File

@@ -11,7 +11,6 @@
import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple
@@ -40,6 +39,7 @@ class ConcurrencyManager:
self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {}
self._owns_redis: bool = False
self._memory_initialized = True
async def initialize(self) -> None:
@@ -47,41 +47,29 @@ class ConcurrencyManager:
if self._redis is not None:
return
# 优先使用 REDIS_URL如果没有则根据密码构建 URL
redis_url = os.getenv("REDIS_URL")
if not redis_url:
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
redis_password = os.getenv("REDIS_PASSWORD")
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
# 测试连接
await self._redis.ping()
# 脱敏显示(隐藏密码)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
logger.info(f"[OK] Redis 连接成功: {safe_url}")
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
from src.clients.redis_client import get_redis_client
self._redis = await get_redis_client(require_redis=False)
self._owns_redis = False
if self._redis:
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
else:
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
except Exception as e:
logger.error(f"[ERROR] Redis 连接失败: {e}")
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
self._redis = None
self._owns_redis = False
async def close(self) -> None:
"""关闭 Redis 连接"""
if self._redis:
if self._redis and self._owns_redis:
await self._redis.close()
self._redis = None
logger.info("Redis 连接已关闭")
logger.info("ConcurrencyManager Redis 连接已关闭")
self._redis = None
self._owns_redis = False
def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key"""

View File

@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
@@ -72,11 +72,7 @@ class RPMLimiter:
# 获取当前分钟窗口
now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0)
window_end = (
window_start.replace(minute=window_start.minute + 1)
if window_start.minute < 59
else window_start.replace(hour=window_start.hour + 1, minute=0)
)
window_end = window_start + timedelta(minutes=1)
# 查找或创建追踪记录
tracking = (

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import get_db
from src.models.database import AuditEventType, AuditLog
from src.utils.transaction_manager import transactional
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
class AuditService:
"""审计服务"""
"""审计服务
事务策略:本服务不负责事务提交,由中间件统一管理。
所有方法只做 db.add/flush提交由请求结束时的中间件处理。
"""
@staticmethod
@transactional(commit=False) # 不自动提交,让调用方决定
def log_event(
db: Session,
event_type: AuditEventType,
@@ -54,47 +56,44 @@ class AuditService:
Returns:
审计日志记录
Note:
不在此方法内提交事务,由调用方或中间件统一管理。
"""
try:
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
db.add(audit_log)
db.commit() # 立即提交事务,释放数据库锁
db.refresh(audit_log)
db.add(audit_log)
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
db.flush()
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
db.rollback()
return None
return audit_log
@staticmethod
def log_login_attempt(

File diff suppressed because it is too large Load Diff

View File

@@ -41,7 +41,7 @@ async def get_current_user(
try:
# 验证Token格式和签名
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
except HTTPException as token_error:
# 保持原始的HTTP状态码如401 Unauthorized不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
@@ -144,7 +144,7 @@ async def get_current_user_from_header(
token = authorization.replace("Bearer ", "")
try:
payload = await AuthService.verify_token(token)
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")
if not user_id:

363
tests/api/test_pipeline.py Normal file
View File

@@ -0,0 +1,363 @@
"""
API Pipeline 测试
测试 ApiRequestPipeline 的核心功能:
- 认证流程API Key、JWT Token
- 配额计算
- 审计日志记录
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import HTTPException
from src.api.base.pipeline import ApiRequestPipeline
class TestPipelineQuotaCalculation:
    """Tests for ApiRequestPipeline's remaining-quota computation."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _user(quota_usd, used_usd):
        """Build a mock user carrying the given quota/usage figures."""
        user = MagicMock()
        user.quota_usd = quota_usd
        user.used_usd = used_usd
        return user

    def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
        """With a quota configured, remaining = quota - used."""
        assert pipeline._calculate_quota_remaining(self._user(100.0, 30.0)) == 70.0

    def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
        """quota_usd of None means unlimited -> None."""
        assert pipeline._calculate_quota_remaining(self._user(None, 30.0)) is None

    def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
        """A negative quota is treated as unlimited -> None."""
        assert pipeline._calculate_quota_remaining(self._user(-1, 0.0)) is None

    def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """Usage beyond the quota clamps the remainder to 0."""
        assert pipeline._calculate_quota_remaining(self._user(100.0, 150.0)) == 0.0

    def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
        """No user at all -> None."""
        assert pipeline._calculate_quota_remaining(None) is None
class TestPipelineAuditLogging:
    """Tests for the pipeline's audit-log recording.

    The MagicMock wiring for the request context and adapter was duplicated
    across five tests; it is factored into _make_context / _make_adapter so
    each test states only what it actually varies.
    """

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @staticmethod
    def _make_context():
        """Build a fully-populated mock request context."""
        ctx = MagicMock()
        ctx.db = MagicMock()
        ctx.user = MagicMock()
        ctx.user.id = "user-123"
        ctx.api_key = MagicMock()
        ctx.api_key.id = "key-123"
        ctx.request_id = "req-123"
        ctx.client_ip = "127.0.0.1"
        ctx.user_agent = "test-agent"
        ctx.request = MagicMock()
        ctx.request.method = "POST"
        ctx.request.url.path = "/v1/messages"
        ctx.start_time = 1000.0
        return ctx

    @staticmethod
    def _make_adapter(enabled: bool = True):
        """Build a mock adapter with audit logging toggled by *enabled*."""
        adapter = MagicMock()
        adapter.name = "test-adapter"
        adapter.audit_log_enabled = enabled
        adapter.audit_success_event = None
        adapter.audit_failure_event = None
        return adapter

    def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
        """A successful request is logged with user id and status code."""
        ctx = self._make_context()
        adapter = self._make_adapter()
        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            with patch("time.time", return_value=1001.0):
                pipeline._record_audit_event(ctx, adapter, success=True, status_code=200)
        mock_log.assert_called_once()
        call_kwargs = mock_log.call_args[1]
        assert call_kwargs["user_id"] == "user-123"
        assert call_kwargs["status_code"] == 200

    def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
        """A failed request is logged with status code and error message."""
        ctx = self._make_context()
        adapter = self._make_adapter()
        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            with patch("time.time", return_value=1001.0):
                pipeline._record_audit_event(
                    ctx, adapter, success=False, status_code=500, error="Internal error"
                )
        mock_log.assert_called_once()
        call_kwargs = mock_log.call_args[1]
        assert call_kwargs["status_code"] == 500
        assert call_kwargs["error_message"] == "Internal error"

    def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
        """Without a database session, auditing is skipped entirely."""
        ctx = self._make_context()
        ctx.db = None
        adapter = self._make_adapter()
        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            # must not raise
            pipeline._record_audit_event(ctx, adapter, success=True)
        mock_log.assert_not_called()

    def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
        """With audit logging disabled on the adapter, nothing is recorded."""
        ctx = self._make_context()
        adapter = self._make_adapter(enabled=False)
        with patch.object(pipeline.audit_service, "log_event") as mock_log:
            pipeline._record_audit_event(ctx, adapter, success=True)
        mock_log.assert_not_called()

    def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
        """An exception raised by log_event must not propagate to the caller."""
        ctx = self._make_context()
        adapter = self._make_adapter()
        with patch.object(
            pipeline.audit_service,
            "log_event",
            side_effect=Exception("DB error"),
        ):
            with patch("time.time", return_value=1001.0):
                # must not raise
                pipeline._record_audit_event(ctx, adapter, success=True)
class TestPipelineAuthentication:
    """Tests for the pipeline's client (API key) authentication."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
        """A request carrying no API key is rejected with 401."""
        request = MagicMock()
        request.headers = {}
        request.url.path = "/v1/messages"
        request.state = MagicMock()
        adapter = MagicMock()
        adapter.extract_api_key = MagicMock(return_value=None)
        db = MagicMock()
        with pytest.raises(HTTPException) as exc_info:
            pipeline._authenticate_client(request, db, adapter)
        assert exc_info.value.status_code == 401
        assert "API密钥" in exc_info.value.detail

    def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
        """An unknown API key is rejected with 401."""
        request = MagicMock()
        request.headers = {"Authorization": "Bearer sk-invalid"}
        request.url.path = "/v1/messages"
        request.state = MagicMock()
        adapter = MagicMock()
        adapter.extract_api_key = MagicMock(return_value="sk-invalid")
        db = MagicMock()
        with patch.object(pipeline.auth_service, "authenticate_api_key", return_value=None):
            with pytest.raises(HTTPException) as exc_info:
                pipeline._authenticate_client(request, db, adapter)
        assert exc_info.value.status_code == 401

    def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """A valid key whose owner is out of quota raises QuotaExceededException."""
        from src.core.exceptions import QuotaExceededException

        user = MagicMock()
        user.id = "user-123"
        user.quota_usd = 100.0
        user.used_usd = 100.0
        api_key = MagicMock()
        api_key.id = "key-123"
        api_key.is_standalone = False
        request = MagicMock()
        request.headers = {"Authorization": "Bearer sk-test"}
        request.url.path = "/v1/messages"
        request.state = MagicMock()
        adapter = MagicMock()
        adapter.extract_api_key = MagicMock(return_value="sk-test")
        db = MagicMock()
        with patch.object(
            pipeline.auth_service, "authenticate_api_key", return_value=(user, api_key)
        ):
            with patch.object(
                pipeline.usage_service, "check_user_quota", return_value=(False, "配额不足")
            ):
                with pytest.raises(QuotaExceededException):
                    pipeline._authenticate_client(request, db, adapter)
class TestPipelineAdminAuth:
    """Tests for the pipeline's admin (JWT) authentication."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @pytest.mark.asyncio
    async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
        """No Authorization header -> 401 mentioning admin credentials."""
        request = MagicMock()
        request.headers = {}
        db = MagicMock()
        with pytest.raises(HTTPException) as exc_info:
            await pipeline._authenticate_admin(request, db)
        assert exc_info.value.status_code == 401
        assert "管理员凭证" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
        """A token that fails verification propagates the 401."""
        request = MagicMock()
        request.headers = {"authorization": "Bearer invalid-token"}
        db = MagicMock()
        with patch.object(
            pipeline.auth_service,
            "verify_token",
            side_effect=HTTPException(status_code=401, detail="Invalid token"),
        ):
            with pytest.raises(HTTPException) as exc_info:
                await pipeline._authenticate_admin(request, db)
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
        """A valid token resolves the admin user and stamps request.state."""
        admin = MagicMock()
        admin.id = "admin-123"
        admin.is_active = True
        request = MagicMock()
        request.headers = {"authorization": "Bearer valid-token"}
        request.state = MagicMock()
        db = MagicMock()
        db.query.return_value.filter.return_value.first.return_value = admin
        with patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "admin-123"},
        ):
            result = await pipeline._authenticate_admin(request, db)
        assert result == admin
        assert request.state.user_id == "admin-123"

View File

@@ -0,0 +1 @@
"""服务层测试"""

299
tests/services/test_auth.py Normal file
View File

@@ -0,0 +1,299 @@
"""
认证服务测试
测试 AuthService 的核心功能:
- JWT Token 创建和验证
- 用户登录认证
- API Key 认证
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import jwt
from src.services.auth.service import (
AuthService,
JWT_SECRET_KEY,
JWT_ALGORITHM,
JWT_EXPIRATION_HOURS,
)
class TestJWTTokenCreation:
    """Tests for JWT token creation."""

    def test_create_access_token_contains_required_fields(self) -> None:
        """An access token carries the caller's claims plus a type and an expiry."""
        claims = {"sub": "user123", "email": "test@example.com"}
        token = AuthService.create_access_token(claims)
        # Decode and inspect the payload directly.
        decoded = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
        assert decoded["sub"] == "user123"
        assert decoded["email"] == "test@example.com"
        assert decoded["type"] == "access"
        assert "exp" in decoded

    def test_create_access_token_expiration(self) -> None:
        """Access-token expiry matches JWT_EXPIRATION_HOURS within one minute."""
        token = AuthService.create_access_token({"sub": "user123"})
        decoded = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
        actual_exp = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc)
        expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
        # Allow up to one minute of clock skew between issue and check.
        assert abs((actual_exp - expected_exp).total_seconds()) < 60

    def test_create_refresh_token_type(self) -> None:
        """A refresh token is stamped with the "refresh" type."""
        token = AuthService.create_refresh_token({"sub": "user123"})
        decoded = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
        assert decoded["type"] == "refresh"

    def test_create_refresh_token_longer_expiration(self) -> None:
        """A refresh token outlives an access token issued at the same time."""
        claims = {"sub": "user123"}
        access_payload = jwt.decode(
            AuthService.create_access_token(claims), JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]
        )
        refresh_payload = jwt.decode(
            AuthService.create_refresh_token(claims), JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]
        )
        assert refresh_payload["exp"] > access_payload["exp"]
class TestJWTTokenVerification:
    """Tests for JWT token verification."""

    def _blacklist_patch(self, revoked: bool):
        """Return a patch that makes the blacklist lookup report *revoked*."""
        return patch(
            "src.services.auth.service.JWTBlacklistService.is_blacklisted",
            new_callable=AsyncMock,
            return_value=revoked,
        )

    @pytest.mark.asyncio
    async def test_verify_valid_access_token(self) -> None:
        """A freshly issued access token verifies and returns its claims."""
        token = AuthService.create_access_token({"sub": "user123", "email": "test@example.com"})
        with self._blacklist_patch(False):
            payload = await AuthService.verify_token(token, token_type="access")
        assert payload["sub"] == "user123"
        assert payload["type"] == "access"

    @pytest.mark.asyncio
    async def test_verify_expired_token_raises_error(self) -> None:
        """An expired token is rejected with 401 and an expiry message."""
        from fastapi import HTTPException
        # Forge a token whose expiry is one hour in the past.
        claims = {
            "sub": "user123",
            "type": "access",
            "exp": datetime.now(timezone.utc) - timedelta(hours=1),
        }
        stale_token = jwt.encode(claims, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
        with pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(stale_token)
        assert exc_info.value.status_code == 401
        assert "过期" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_verify_invalid_token_raises_error(self) -> None:
        """A malformed token is rejected with 401."""
        from fastapi import HTTPException
        with pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token("invalid.token.here")
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_verify_wrong_token_type_raises_error(self) -> None:
        """Presenting a refresh token where an access token is required fails."""
        from fastapi import HTTPException
        refresh_token = AuthService.create_refresh_token({"sub": "user123"})
        with self._blacklist_patch(False), pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(refresh_token, token_type="access")
        assert exc_info.value.status_code == 401
        assert "类型错误" in exc_info.value.detail

    @pytest.mark.asyncio
    async def test_verify_blacklisted_token_raises_error(self) -> None:
        """A revoked (blacklisted) token is rejected with 401."""
        from fastapi import HTTPException
        token = AuthService.create_access_token({"sub": "user123"})
        with self._blacklist_patch(True), pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(token)
        assert exc_info.value.status_code == 401
        assert "撤销" in exc_info.value.detail
class TestUserAuthentication:
    """Tests for password-based user login."""

    @staticmethod
    def _db_with_user(user):
        """Build a mock Session whose user lookup returns *user*."""
        db = MagicMock()
        db.query.return_value.filter.return_value.first.return_value = user
        return db

    @pytest.mark.asyncio
    async def test_authenticate_user_success(self) -> None:
        """Correct credentials for an active user return the user and commit."""
        user = MagicMock()
        user.id = "user-123"
        user.email = "test@example.com"
        user.is_active = True
        user.verify_password.return_value = True
        db = self._db_with_user(user)
        with patch(
            "src.services.auth.service.UserCacheService.invalidate_user_cache",
            new_callable=AsyncMock,
        ):
            result = await AuthService.authenticate_user(db, "test@example.com", "password123")
        assert result == user
        user.verify_password.assert_called_once_with("password123")
        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_authenticate_user_not_found(self) -> None:
        """An unknown e-mail address yields None."""
        db = self._db_with_user(None)
        result = await AuthService.authenticate_user(db, "nonexistent@example.com", "password")
        assert result is None

    @pytest.mark.asyncio
    async def test_authenticate_user_wrong_password(self) -> None:
        """A failed password check yields None."""
        user = MagicMock()
        user.email = "test@example.com"
        user.verify_password.return_value = False
        db = self._db_with_user(user)
        result = await AuthService.authenticate_user(db, "test@example.com", "wrongpassword")
        assert result is None

    @pytest.mark.asyncio
    async def test_authenticate_user_inactive(self) -> None:
        """A disabled account yields None even with the right password."""
        user = MagicMock()
        user.email = "test@example.com"
        user.is_active = False
        user.verify_password.return_value = True
        db = self._db_with_user(user)
        result = await AuthService.authenticate_user(db, "test@example.com", "password123")
        assert result is None
class TestAPIKeyAuthentication:
    """Tests for API-key authentication."""

    @staticmethod
    def _db_with_key(api_key):
        """Build a mock Session whose eager-loaded key lookup returns *api_key*."""
        db = MagicMock()
        db.query.return_value.options.return_value.filter.return_value.first.return_value = (
            api_key
        )
        return db

    def test_authenticate_api_key_success(self) -> None:
        """An active, unexpired key with balance resolves to (user, key)."""
        user = MagicMock()
        user.id = "user-123"
        user.email = "test@example.com"
        user.is_active = True
        api_key = MagicMock()
        api_key.is_active = True
        api_key.expires_at = None
        api_key.user = user
        api_key.balance_used_usd = 0.0
        db = self._db_with_key(api_key)
        hash_patch = patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key")
        balance_patch = patch(
            "src.services.auth.service.ApiKeyService.check_balance",
            return_value=(True, 100.0),
        )
        with hash_patch, balance_patch:
            result = AuthService.authenticate_api_key(db, "sk-test-key")
        assert result is not None
        assert result[0] == user
        assert result[1] == api_key

    def test_authenticate_api_key_not_found(self) -> None:
        """A key that is not in the database yields None."""
        db = self._db_with_key(None)
        with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
            result = AuthService.authenticate_api_key(db, "sk-invalid-key")
        assert result is None

    def test_authenticate_api_key_inactive(self) -> None:
        """A disabled key yields None."""
        api_key = MagicMock()
        api_key.is_active = False
        db = self._db_with_key(api_key)
        with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
            result = AuthService.authenticate_api_key(db, "sk-inactive-key")
        assert result is None

    def test_authenticate_api_key_expired(self) -> None:
        """A key past its expiry date yields None."""
        api_key = MagicMock()
        api_key.is_active = True
        api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
        db = self._db_with_key(api_key)
        with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
            result = AuthService.authenticate_api_key(db, "sk-expired-key")
        assert result is None

View File

@@ -0,0 +1,292 @@
"""
UsageService 测试
测试用量统计服务的核心功能:
- 成本计算
- 配额检查
- 用量统计查询
"""
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from src.services.usage.service import UsageService
class TestCostCalculation:
    """Tests for UsageService.calculate_cost.

    calculate_cost returns a 7-tuple:
    (input_cost, output_cost, cache_creation_cost, cache_read_cost,
     cache_cost, request_cost, total_cost)

    Float results are compared with ``pytest.approx`` throughout: the
    original exact ``==`` comparisons were brittle (they break if the
    service rounds or reorders its float additions) and were inconsistent
    with the tolerance-based asserts already used in this class.
    """

    def test_calculate_cost_basic(self) -> None:
        """Input/output costs scale linearly with the per-1M-token prices."""
        # Prices: input $3/1M tokens, output $15/1M tokens.
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
        )
        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result
        # 1000 tokens * $3 / 1M = $0.003
        assert input_cost == pytest.approx(0.003, abs=1e-4)
        # 500 tokens * $15 / 1M = $0.0075
        assert output_cost == pytest.approx(0.0075, abs=1e-4)
        # Total = $0.003 + $0.0075 = $0.0105
        assert total_cost == pytest.approx(0.0105, abs=1e-4)

    def test_calculate_cost_with_cache(self) -> None:
        """Cache creation/read tokens are billed and summed into cache_cost."""
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            cache_creation_input_tokens=200,
            cache_read_input_tokens=300,
            cache_creation_price_per_1m=3.75,  # 1.25x input price
            cache_read_price_per_1m=0.3,  # 0.1x input price
        )
        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result
        # Both cache components must be billed, and cache_cost is their sum.
        assert cache_creation_cost > 0
        assert cache_read_cost > 0
        assert cache_cost == pytest.approx(cache_creation_cost + cache_read_cost)

    def test_calculate_cost_with_request_price(self) -> None:
        """A per-request price is passed through and included in the total."""
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            price_per_request=0.01,
        )
        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result
        assert request_cost == pytest.approx(0.01)
        # The total must include the per-request charge.
        assert total_cost == pytest.approx(input_cost + output_cost + request_cost)

    def test_calculate_cost_zero_tokens(self) -> None:
        """Zero tokens produce zero cost across the board."""
        result = UsageService.calculate_cost(
            input_tokens=0,
            output_tokens=0,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
        )
        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result
        # 0 * price is exactly 0.0, so exact comparison is safe here.
        assert input_cost == 0
        assert output_cost == 0
        assert total_cost == 0
class TestQuotaCheck:
    """Tests for UsageService.check_user_quota."""

    @staticmethod
    def _regular_user(quota_usd, used_usd):
        """Build a mock non-admin user with the given quota and usage."""
        user = MagicMock()
        user.quota_usd = quota_usd
        user.used_usd = used_usd
        user.role = MagicMock()
        user.role.value = "user"
        return user

    @staticmethod
    def _api_key(standalone):
        """Build a mock API key, standalone or account-backed."""
        key = MagicMock()
        key.is_standalone = standalone
        return key

    def test_check_user_quota_sufficient(self) -> None:
        """A user well under quota passes the check."""
        user = self._regular_user(100.0, 30.0)
        is_ok, message = UsageService.check_user_quota(
            MagicMock(), user, api_key=self._api_key(False)
        )
        assert is_ok is True

    def test_check_user_quota_exceeded(self) -> None:
        """An estimated cost beyond the remaining quota fails the check."""
        user = self._regular_user(100.0, 99.0)  # nearly at the cap
        is_ok, message = UsageService.check_user_quota(
            MagicMock(), user, estimated_cost=5.0, api_key=self._api_key(False)
        )
        assert is_ok is False
        assert "配额" in message

    def test_check_user_quota_no_limit(self) -> None:
        """quota_usd of None means unlimited usage."""
        user = self._regular_user(None, 1000.0)
        is_ok, message = UsageService.check_user_quota(
            MagicMock(), user, api_key=self._api_key(False)
        )
        assert is_ok is True

    def test_check_user_quota_admin_bypass(self) -> None:
        """Admins bypass the quota check entirely."""
        from src.models.database import UserRole
        user = MagicMock()
        user.quota_usd = 0.0
        user.used_usd = 1000.0
        user.role = UserRole.ADMIN
        is_ok, message = UsageService.check_user_quota(
            MagicMock(), user, api_key=self._api_key(False)
        )
        assert is_ok is True

    def test_check_standalone_api_key_balance(self) -> None:
        """A standalone key with remaining balance passes."""
        user = self._regular_user(0.0, 0.0)
        key = self._api_key(True)
        key.current_balance_usd = 50.0
        key.balance_used_usd = 10.0
        is_ok, message = UsageService.check_user_quota(MagicMock(), user, api_key=key)
        assert is_ok is True

    def test_check_standalone_api_key_insufficient_balance(self) -> None:
        """A standalone key fails when the estimated cost exceeds its balance."""
        user = self._regular_user(100.0, 0.0)
        key = self._api_key(True)
        key.current_balance_usd = 10.0
        key.balance_used_usd = 9.0  # $1 remaining
        # The service consults ApiKeyService.get_remaining_balance for standalone keys.
        with patch(
            "src.services.user.apikey.ApiKeyService.get_remaining_balance",
            return_value=1.0,
        ):
            # An estimated cost of $5 exceeds the $1 remaining balance.
            is_ok, message = UsageService.check_user_quota(
                MagicMock(), user, estimated_cost=5.0, api_key=key
            )
        assert is_ok is False
class TestUsageStatistics:
    """Tests for usage-summary queries.

    NOTE: get_usage_summary uses dialect-specific date functions internally,
    so exercising it would require a real database or a far heavier mock.
    We only assert that the method exists and is callable.
    """

    def test_get_usage_summary_exists(self) -> None:
        """UsageService exposes a callable get_usage_summary."""
        summary = getattr(UsageService, "get_usage_summary", None)
        assert summary is not None
        assert callable(summary)
class TestHelperMethods:
    """Tests for internal helper methods."""

    @pytest.mark.asyncio
    async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
        """With no provider key found, the defaults are 1.0 and not free tier."""
        db = MagicMock()
        # Simulate the provider_api_key lookup coming back empty.
        db.query.return_value.filter.return_value.first.return_value = None
        rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id=None, provider_id=None
        )
        assert rate_multiplier == 1.0
        assert is_free_tier is False

    @pytest.mark.asyncio
    async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
        """The multiplier comes from the ProviderAPIKey record when present."""
        provider_key = MagicMock()
        provider_key.rate_multiplier = 0.8
        endpoint = MagicMock()
        endpoint.provider_id = "provider-123"
        provider = MagicMock()
        provider.billing_type = "standard"
        db = MagicMock()
        # Successive .first() calls return: provider key, endpoint, provider.
        db.query.return_value.filter.return_value.first.side_effect = [
            provider_key,
            endpoint,
            provider,
        ]
        rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id="pak-123", provider_id=None
        )
        assert rate_multiplier == 0.8
        assert is_free_tier is False