refactor: consolidate transaction management and remove legacy modules

- Remove unused context.py module (replaced by request.state)
- Remove provider_cache.py (no longer needed)
- Unify environment loading in config/settings.py instead of __init__.py
- Add deprecation warning for get_async_db() (consolidating on sync Session)
- Enhance database.py documentation with comprehensive transaction strategy
- Simplify audit logging to reuse request-level Session (no separate connections)
- Extract UsageService._build_usage_params() helper to reduce code duplication
- Update model and user cache implementations with refined transaction handling
- Remove unnecessary sessionmaker from pipeline
- Clean up audit service exception handling
This commit is contained in:
fawney19
2025-12-18 01:59:40 +08:00
parent 4d1d863916
commit b2a857c164
9 changed files with 638 additions and 957 deletions

View File

@@ -3,10 +3,8 @@
A proxy server that enables AI models to work with multiple API providers.
"""
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
# 注意: dotenv 加载已统一移至 src/config/settings.py
# 不要在此处重复加载
try:
from src._version import __version__

View File

@@ -5,7 +5,7 @@ from enum import Enum
from typing import Any, Optional, Tuple
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from src.core.exceptions import QuotaExceededException
from src.core.logger import logger
@@ -241,11 +241,15 @@ class ApiRequestPipeline:
status_code: Optional[int] = None,
error: Optional[str] = None,
) -> None:
"""记录审计事件
事务策略:复用请求级 Session不单独提交。
审计记录随主事务一起提交,由中间件统一管理。
"""
if not getattr(adapter, "audit_log_enabled", True):
return
bind = context.db.get_bind()
if bind is None:
if context.db is None:
return
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
@@ -265,11 +269,11 @@ class ApiRequestPipeline:
error=error,
)
SessionMaker = sessionmaker(bind=bind)
audit_session = SessionMaker()
try:
# 复用请求级 Session不创建新的连接
# 审计记录随主事务一起提交,由中间件统一管理
self.audit_service.log_event(
db=audit_session,
db=context.db,
event_type=event_type,
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
user_id=context.user.id if context.user else None,
@@ -281,12 +285,9 @@ class ApiRequestPipeline:
error_message=error,
metadata=metadata,
)
audit_session.commit()
except Exception as exc:
audit_session.rollback()
# 审计失败不应影响主请求,仅记录警告
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
finally:
audit_session.close()
def _build_audit_metadata(
self,

View File

@@ -1,168 +0,0 @@
"""
统一的请求上下文
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
这确保了数据在各层之间传递时不会丢失。
使用方式:
1. Pipeline 层创建 RequestContext
2. 各层通过 context 访问和更新信息
3. Adapter 层使用 context 记录 Usage
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
@dataclass
class RequestContext:
"""
请求上下文 - 贯穿整个请求生命周期
设计原则:
1. 在请求开始时创建,包含所有已知信息
2. 在请求执行过程中逐步填充 Provider 信息
3. 在请求结束时用于记录 Usage
"""
# ==================== 请求标识 ====================
request_id: str
# ==================== 认证信息 ====================
user: Any # User model
api_key: Any # ApiKey model
db: Any # Database session
# ==================== 请求信息 ====================
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
model: str # 用户请求的模型名
is_stream: bool = False
# ==================== 原始请求 ====================
original_headers: Dict[str, str] = field(default_factory=dict)
original_body: Dict[str, Any] = field(default_factory=dict)
# ==================== 客户端信息 ====================
client_ip: str = "unknown"
user_agent: str = ""
# ==================== 计时 ====================
start_time: float = field(default_factory=time.time)
# ==================== Provider 信息(请求执行后填充)====================
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
provider_api_key_id: Optional[str] = None
# ==================== 模型映射信息 ====================
resolved_model: Optional[str] = None # 映射后的模型名
original_model: Optional[str] = None # 原始模型名(用于价格计算)
# ==================== 请求/响应头 ====================
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_response_headers: Dict[str, str] = field(default_factory=dict)
# ==================== 追踪信息 ====================
attempt_id: Optional[str] = None
# ==================== 能力需求 ====================
capability_requirements: Dict[str, bool] = field(default_factory=dict)
# 运行时计算的能力需求,来源于:
# 1. 用户 model_capability_settings
# 2. 用户 ApiKey.force_capabilities
# 3. 请求头 X-Require-Capability
# 4. 失败重试时动态添加
@classmethod
def create(
cls,
*,
db: Any,
user: Any,
api_key: Any,
api_format: str,
model: str,
is_stream: bool = False,
original_headers: Optional[Dict[str, str]] = None,
original_body: Optional[Dict[str, Any]] = None,
client_ip: str = "unknown",
user_agent: str = "",
request_id: Optional[str] = None,
) -> "RequestContext":
"""创建请求上下文"""
return cls(
request_id=request_id or str(uuid.uuid4()),
db=db,
user=user,
api_key=api_key,
api_format=api_format,
model=model,
is_stream=is_stream,
original_headers=original_headers or {},
original_body=original_body or {},
client_ip=client_ip,
user_agent=user_agent,
original_model=model, # 初始时原始模型等于请求模型
)
def update_provider_info(
self,
*,
provider_name: str,
provider_id: str,
endpoint_id: str,
provider_api_key_id: str,
resolved_model: Optional[str] = None,
) -> None:
"""更新 Provider 信息(请求执行后调用)"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.provider_api_key_id = provider_api_key_id
if resolved_model:
self.resolved_model = resolved_model
def update_headers(
self,
*,
request_headers: Optional[Dict[str, str]] = None,
response_headers: Optional[Dict[str, str]] = None,
) -> None:
"""更新请求/响应头"""
if request_headers:
self.provider_request_headers = request_headers
if response_headers:
self.provider_response_headers = response_headers
@property
def elapsed_ms(self) -> int:
"""计算已经过的时间(毫秒)"""
return int((time.time() - self.start_time) * 1000)
@property
def effective_model(self) -> str:
"""获取有效的模型名(映射后优先)"""
return self.resolved_model or self.model
@property
def billing_model(self) -> str:
"""获取计费模型名(原始模型优先)"""
return self.original_model or self.model
def to_metadata_dict(self) -> Dict[str, Any]:
"""转换为元数据字典(用于 Usage 记录)"""
return {
"api_format": self.api_format,
"provider": self.provider_name or "unknown",
"model": self.effective_model,
"original_model": self.billing_model,
"provider_id": self.provider_id,
"provider_endpoint_id": self.endpoint_id,
"provider_api_key_id": self.provider_api_key_id,
"provider_request_headers": self.provider_request_headers,
"provider_response_headers": self.provider_response_headers,
"attempt_id": self.attempt_id,
}

View File

@@ -223,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""获取异步数据库会话"""
"""获取异步数据库会话
.. deprecated::
此方法已废弃,项目统一使用同步 Session。
未来版本可能移除此方法。请使用 get_db() 代替。
"""
import warnings
warnings.warn(
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
DeprecationWarning,
stacklevel=2,
)
# 确保异步引擎已初始化
_ensure_async_engine()
@@ -237,12 +248,34 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
"""获取数据库会话
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback
这里只负责会话的创建和关闭,不自动提交
事务策略说明
============
本项目采用**混合事务管理**策略:
在 FastAPI 请求上下文中通过 Depends(get_db) 调用时,会自动注入 Request 对象,
支持中间件管理的 session 复用;在非请求上下文中直接调用 get_db() 时,
request 为 None退化为独立 session 模式。
1. **LLM 请求路径**
- 由 PluginMiddleware 统一管理事务
- Service 层使用 db.flush() 使更改可见,但不提交
- 请求结束时由中间件统一 commit 或 rollback
- 例外UsageService.record_usage() 会显式 commit因为使用记录需要立即持久化
2. **管理后台 API**
- 路由层显式调用 db.commit()
- 每个操作独立提交,不依赖中间件
3. **后台任务/调度器**
- 使用独立 Session通过 create_session() 或 next(get_db())
- 自行管理事务生命周期
使用方式
========
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
注意事项
========
- 本函数不自动提交事务
- 异常时会自动回滚
- 中间件管理模式下session 关闭由中间件负责
"""
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
if request is not None:

View File

@@ -1,5 +1,21 @@
"""
Model 映射缓存服务 - 减少模型查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作CacheService真正的 async使用 aioredis
- 数据库查询db.query同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
"""
import time
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
class ModelCacheService:
"""Model 映射缓存服务"""
"""Model 映射缓存服务
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
所有公开方法均为 async需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.MODEL

View File

@@ -1,254 +0,0 @@
"""
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
"""
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
class ProviderCacheService:
"""Provider 配置缓存服务"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.PROVIDER
@staticmethod
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
"""
获取 Provider带缓存
Args:
db: 数据库会话
provider_id: Provider ID
Returns:
Provider 对象或 None
"""
cache_key = CacheKeys.provider_by_id(provider_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Provider 缓存命中: {provider_id}")
return ProviderCacheService._dict_to_provider(cached_data)
# 2. 缓存未命中,查询数据库
provider = db.query(Provider).filter(Provider.id == provider_id).first()
# 3. 写入缓存
if provider:
provider_dict = ProviderCacheService._provider_to_dict(provider)
await CacheService.set(
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Provider 已缓存: {provider_id}")
return provider
@staticmethod
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
"""
获取 Endpoint带缓存
Args:
db: 数据库会话
endpoint_id: Endpoint ID
Returns:
ProviderEndpoint 对象或 None
"""
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
return ProviderCacheService._dict_to_endpoint(cached_data)
# 2. 缓存未命中,查询数据库
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
# 3. 写入缓存
if endpoint:
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
await CacheService.set(
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
return endpoint
@staticmethod
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
"""
获取 API Key带缓存
Args:
db: 数据库会话
api_key_id: API Key ID
Returns:
ProviderAPIKey 对象或 None
"""
cache_key = CacheKeys.api_key_by_id(api_key_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"API Key 缓存命中: {api_key_id}")
return ProviderCacheService._dict_to_api_key(cached_data)
# 2. 缓存未命中,查询数据库
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
# 3. 写入缓存
if api_key:
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
await CacheService.set(
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"API Key 已缓存: {api_key_id}")
return api_key
@staticmethod
async def invalidate_provider_cache(provider_id: str):
"""
清除 Provider 缓存
Args:
provider_id: Provider ID
"""
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
logger.debug(f"Provider 缓存已清除: {provider_id}")
@staticmethod
async def invalidate_endpoint_cache(endpoint_id: str):
"""
清除 Endpoint 缓存
Args:
endpoint_id: Endpoint ID
"""
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
@staticmethod
async def invalidate_api_key_cache(api_key_id: str):
"""
清除 API Key 缓存
Args:
api_key_id: API Key ID
"""
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
logger.debug(f"API Key 缓存已清除: {api_key_id}")
@staticmethod
def _provider_to_dict(provider: Provider) -> dict:
"""将 Provider 对象转换为字典(用于缓存)"""
return {
"id": provider.id,
"name": provider.name,
"api_format": provider.api_format,
"base_url": provider.base_url,
"is_active": provider.is_active,
"priority": provider.priority,
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
"config": provider.config,
"description": provider.description,
}
@staticmethod
def _dict_to_provider(provider_dict: dict) -> Provider:
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
from datetime import datetime
provider = Provider(
id=provider_dict["id"],
name=provider_dict["name"],
api_format=provider_dict["api_format"],
base_url=provider_dict.get("base_url"),
is_active=provider_dict["is_active"],
priority=provider_dict.get("priority", 0),
rpm_limit=provider_dict.get("rpm_limit"),
rpm_used=provider_dict.get("rpm_used", 0),
config=provider_dict.get("config"),
description=provider_dict.get("description"),
)
if provider_dict.get("rpm_reset_at"):
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
return provider
@staticmethod
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
"""将 Endpoint 对象转换为字典"""
return {
"id": endpoint.id,
"provider_id": endpoint.provider_id,
"name": endpoint.name,
"base_url": endpoint.base_url,
"is_active": endpoint.is_active,
"priority": endpoint.priority,
"weight": endpoint.weight,
"custom_path": endpoint.custom_path,
"config": endpoint.config,
}
@staticmethod
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
"""从字典重建 Endpoint 对象"""
endpoint = ProviderEndpoint(
id=endpoint_dict["id"],
provider_id=endpoint_dict["provider_id"],
name=endpoint_dict["name"],
base_url=endpoint_dict["base_url"],
is_active=endpoint_dict["is_active"],
priority=endpoint_dict.get("priority", 0),
weight=endpoint_dict.get("weight", 1.0),
custom_path=endpoint_dict.get("custom_path"),
config=endpoint_dict.get("config"),
)
return endpoint
@staticmethod
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
"""将 API Key 对象转换为字典"""
return {
"id": api_key.id,
"endpoint_id": api_key.endpoint_id,
"key_value": api_key.key_value,
"is_active": api_key.is_active,
"max_rpm": api_key.max_rpm,
"current_rpm": api_key.current_rpm,
"health_score": api_key.health_score,
"circuit_breaker_state": api_key.circuit_breaker_state,
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
}
@staticmethod
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
"""从字典重建 API Key 对象"""
api_key = ProviderAPIKey(
id=api_key_dict["id"],
endpoint_id=api_key_dict["endpoint_id"],
key_value=api_key_dict["key_value"],
is_active=api_key_dict["is_active"],
max_rpm=api_key_dict.get("max_rpm"),
current_rpm=api_key_dict.get("current_rpm", 0),
health_score=api_key_dict.get("health_score", 1.0),
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
)
return api_key

View File

@@ -1,5 +1,22 @@
"""
用户缓存服务 - 减少数据库查询
架构说明
========
本服务采用混合 async/sync 模式:
- 缓存操作CacheService真正的 async使用 aioredis
- 数据库查询db.query同步的 SQLAlchemy Session
设计决策
--------
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
2. 缓存未命中时的同步查询FastAPI 会在线程池中执行,不会阻塞事件循环
3. 调用方必须在 async 上下文中使用 await
使用示例
--------
user = await UserCacheService.get_user_by_id(db, user_id)
await UserCacheService.invalidate_user_cache(user_id, email)
"""
from typing import Optional
@@ -12,9 +29,12 @@ from src.core.logger import logger
from src.models.database import User
class UserCacheService:
"""用户缓存服务"""
"""用户缓存服务
提供 User 的缓存查询功能,减少数据库访问。
所有公开方法均为 async需要在 async 上下文中调用。
"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.USER

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import get_db
from src.models.database import AuditEventType, AuditLog
from src.utils.transaction_manager import transactional
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
class AuditService:
"""审计服务"""
"""审计服务
事务策略:本服务不负责事务提交,由中间件统一管理。
所有方法只做 db.add/flush提交由请求结束时的中间件处理。
"""
@staticmethod
@transactional(commit=False) # 不自动提交,让调用方决定
def log_event(
db: Session,
event_type: AuditEventType,
@@ -54,47 +56,44 @@ class AuditService:
Returns:
审计日志记录
Note:
不在此方法内提交事务,由调用方或中间件统一管理。
"""
try:
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
db.add(audit_log)
db.commit() # 立即提交事务,释放数据库锁
db.refresh(audit_log)
db.add(audit_log)
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
db.flush()
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
# 同时记录到系统日志
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
db.rollback()
return None
return audit_log
@staticmethod
def log_login_attempt(

File diff suppressed because it is too large Load Diff