mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor: consolidate transaction management and remove legacy modules
- Remove unused context.py module (replaced by request.state) - Remove provider_cache.py (no longer needed) - Unify environment loading in config/settings.py instead of __init__.py - Add deprecation warning for get_async_db() (consolidating on sync Session) - Enhance database.py documentation with comprehensive transaction strategy - Simplify audit logging to reuse request-level Session (no separate connections) - Extract UsageService._build_usage_params() helper to reduce code duplication - Update model and user cache implementations with refined transaction handling - Remove unnecessary sessionmaker from pipeline - Clean up audit service exception handling
This commit is contained in:
@@ -3,10 +3,8 @@
|
|||||||
A proxy server that enables AI models to work with multiple API providers.
|
A proxy server that enables AI models to work with multiple API providers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
# 注意: dotenv 加载已统一移至 src/config/settings.py
|
||||||
|
# 不要在此处重复加载
|
||||||
# Load environment variables from .env file
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src._version import __version__
|
from src._version import __version__
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from enum import Enum
|
|||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.core.exceptions import QuotaExceededException
|
from src.core.exceptions import QuotaExceededException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -241,11 +241,15 @@ class ApiRequestPipeline:
|
|||||||
status_code: Optional[int] = None,
|
status_code: Optional[int] = None,
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""记录审计事件
|
||||||
|
|
||||||
|
事务策略:复用请求级 Session,不单独提交。
|
||||||
|
审计记录随主事务一起提交,由中间件统一管理。
|
||||||
|
"""
|
||||||
if not getattr(adapter, "audit_log_enabled", True):
|
if not getattr(adapter, "audit_log_enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
bind = context.db.get_bind()
|
if context.db is None:
|
||||||
if bind is None:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||||
@@ -265,11 +269,11 @@ class ApiRequestPipeline:
|
|||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
|
|
||||||
SessionMaker = sessionmaker(bind=bind)
|
|
||||||
audit_session = SessionMaker()
|
|
||||||
try:
|
try:
|
||||||
|
# 复用请求级 Session,不创建新的连接
|
||||||
|
# 审计记录随主事务一起提交,由中间件统一管理
|
||||||
self.audit_service.log_event(
|
self.audit_service.log_event(
|
||||||
db=audit_session,
|
db=context.db,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||||
user_id=context.user.id if context.user else None,
|
user_id=context.user.id if context.user else None,
|
||||||
@@ -281,12 +285,9 @@ class ApiRequestPipeline:
|
|||||||
error_message=error,
|
error_message=error,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
audit_session.commit()
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
audit_session.rollback()
|
# 审计失败不应影响主请求,仅记录警告
|
||||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||||
finally:
|
|
||||||
audit_session.close()
|
|
||||||
|
|
||||||
def _build_audit_metadata(
|
def _build_audit_metadata(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,168 +0,0 @@
|
|||||||
"""
|
|
||||||
统一的请求上下文
|
|
||||||
|
|
||||||
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
|
|
||||||
这确保了数据在各层之间传递时不会丢失。
|
|
||||||
|
|
||||||
使用方式:
|
|
||||||
1. Pipeline 层创建 RequestContext
|
|
||||||
2. 各层通过 context 访问和更新信息
|
|
||||||
3. Adapter 层使用 context 记录 Usage
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RequestContext:
|
|
||||||
"""
|
|
||||||
请求上下文 - 贯穿整个请求生命周期
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
1. 在请求开始时创建,包含所有已知信息
|
|
||||||
2. 在请求执行过程中逐步填充 Provider 信息
|
|
||||||
3. 在请求结束时用于记录 Usage
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ==================== 请求标识 ====================
|
|
||||||
request_id: str
|
|
||||||
|
|
||||||
# ==================== 认证信息 ====================
|
|
||||||
user: Any # User model
|
|
||||||
api_key: Any # ApiKey model
|
|
||||||
db: Any # Database session
|
|
||||||
|
|
||||||
# ==================== 请求信息 ====================
|
|
||||||
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
|
|
||||||
model: str # 用户请求的模型名
|
|
||||||
is_stream: bool = False
|
|
||||||
|
|
||||||
# ==================== 原始请求 ====================
|
|
||||||
original_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
original_body: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# ==================== 客户端信息 ====================
|
|
||||||
client_ip: str = "unknown"
|
|
||||||
user_agent: str = ""
|
|
||||||
|
|
||||||
# ==================== 计时 ====================
|
|
||||||
start_time: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
# ==================== Provider 信息(请求执行后填充)====================
|
|
||||||
provider_name: Optional[str] = None
|
|
||||||
provider_id: Optional[str] = None
|
|
||||||
endpoint_id: Optional[str] = None
|
|
||||||
provider_api_key_id: Optional[str] = None
|
|
||||||
|
|
||||||
# ==================== 模型映射信息 ====================
|
|
||||||
resolved_model: Optional[str] = None # 映射后的模型名
|
|
||||||
original_model: Optional[str] = None # 原始模型名(用于价格计算)
|
|
||||||
|
|
||||||
# ==================== 请求/响应头 ====================
|
|
||||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
provider_response_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# ==================== 追踪信息 ====================
|
|
||||||
attempt_id: Optional[str] = None
|
|
||||||
|
|
||||||
# ==================== 能力需求 ====================
|
|
||||||
capability_requirements: Dict[str, bool] = field(default_factory=dict)
|
|
||||||
# 运行时计算的能力需求,来源于:
|
|
||||||
# 1. 用户 model_capability_settings
|
|
||||||
# 2. 用户 ApiKey.force_capabilities
|
|
||||||
# 3. 请求头 X-Require-Capability
|
|
||||||
# 4. 失败重试时动态添加
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
db: Any,
|
|
||||||
user: Any,
|
|
||||||
api_key: Any,
|
|
||||||
api_format: str,
|
|
||||||
model: str,
|
|
||||||
is_stream: bool = False,
|
|
||||||
original_headers: Optional[Dict[str, str]] = None,
|
|
||||||
original_body: Optional[Dict[str, Any]] = None,
|
|
||||||
client_ip: str = "unknown",
|
|
||||||
user_agent: str = "",
|
|
||||||
request_id: Optional[str] = None,
|
|
||||||
) -> "RequestContext":
|
|
||||||
"""创建请求上下文"""
|
|
||||||
return cls(
|
|
||||||
request_id=request_id or str(uuid.uuid4()),
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
api_format=api_format,
|
|
||||||
model=model,
|
|
||||||
is_stream=is_stream,
|
|
||||||
original_headers=original_headers or {},
|
|
||||||
original_body=original_body or {},
|
|
||||||
client_ip=client_ip,
|
|
||||||
user_agent=user_agent,
|
|
||||||
original_model=model, # 初始时原始模型等于请求模型
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_provider_info(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
provider_name: str,
|
|
||||||
provider_id: str,
|
|
||||||
endpoint_id: str,
|
|
||||||
provider_api_key_id: str,
|
|
||||||
resolved_model: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""更新 Provider 信息(请求执行后调用)"""
|
|
||||||
self.provider_name = provider_name
|
|
||||||
self.provider_id = provider_id
|
|
||||||
self.endpoint_id = endpoint_id
|
|
||||||
self.provider_api_key_id = provider_api_key_id
|
|
||||||
if resolved_model:
|
|
||||||
self.resolved_model = resolved_model
|
|
||||||
|
|
||||||
def update_headers(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
request_headers: Optional[Dict[str, str]] = None,
|
|
||||||
response_headers: Optional[Dict[str, str]] = None,
|
|
||||||
) -> None:
|
|
||||||
"""更新请求/响应头"""
|
|
||||||
if request_headers:
|
|
||||||
self.provider_request_headers = request_headers
|
|
||||||
if response_headers:
|
|
||||||
self.provider_response_headers = response_headers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def elapsed_ms(self) -> int:
|
|
||||||
"""计算已经过的时间(毫秒)"""
|
|
||||||
return int((time.time() - self.start_time) * 1000)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def effective_model(self) -> str:
|
|
||||||
"""获取有效的模型名(映射后优先)"""
|
|
||||||
return self.resolved_model or self.model
|
|
||||||
|
|
||||||
@property
|
|
||||||
def billing_model(self) -> str:
|
|
||||||
"""获取计费模型名(原始模型优先)"""
|
|
||||||
return self.original_model or self.model
|
|
||||||
|
|
||||||
def to_metadata_dict(self) -> Dict[str, Any]:
|
|
||||||
"""转换为元数据字典(用于 Usage 记录)"""
|
|
||||||
return {
|
|
||||||
"api_format": self.api_format,
|
|
||||||
"provider": self.provider_name or "unknown",
|
|
||||||
"model": self.effective_model,
|
|
||||||
"original_model": self.billing_model,
|
|
||||||
"provider_id": self.provider_id,
|
|
||||||
"provider_endpoint_id": self.endpoint_id,
|
|
||||||
"provider_api_key_id": self.provider_api_key_id,
|
|
||||||
"provider_request_headers": self.provider_request_headers,
|
|
||||||
"provider_response_headers": self.provider_response_headers,
|
|
||||||
"attempt_id": self.attempt_id,
|
|
||||||
}
|
|
||||||
@@ -223,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
|
|||||||
|
|
||||||
|
|
||||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""获取异步数据库会话"""
|
"""获取异步数据库会话
|
||||||
|
|
||||||
|
.. deprecated::
|
||||||
|
此方法已废弃,项目统一使用同步 Session。
|
||||||
|
未来版本可能移除此方法。请使用 get_db() 代替。
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
# 确保异步引擎已初始化
|
# 确保异步引擎已初始化
|
||||||
_ensure_async_engine()
|
_ensure_async_engine()
|
||||||
|
|
||||||
@@ -237,12 +248,34 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
|
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
|
||||||
"""获取数据库会话
|
"""获取数据库会话
|
||||||
|
|
||||||
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
|
事务策略说明
|
||||||
这里只负责会话的创建和关闭,不自动提交
|
============
|
||||||
|
本项目采用**混合事务管理**策略:
|
||||||
|
|
||||||
在 FastAPI 请求上下文中通过 Depends(get_db) 调用时,会自动注入 Request 对象,
|
1. **LLM 请求路径**:
|
||||||
支持中间件管理的 session 复用;在非请求上下文中直接调用 get_db() 时,
|
- 由 PluginMiddleware 统一管理事务
|
||||||
request 为 None,退化为独立 session 模式。
|
- Service 层使用 db.flush() 使更改可见,但不提交
|
||||||
|
- 请求结束时由中间件统一 commit 或 rollback
|
||||||
|
- 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化
|
||||||
|
|
||||||
|
2. **管理后台 API**:
|
||||||
|
- 路由层显式调用 db.commit()
|
||||||
|
- 每个操作独立提交,不依赖中间件
|
||||||
|
|
||||||
|
3. **后台任务/调度器**:
|
||||||
|
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||||
|
- 自行管理事务生命周期
|
||||||
|
|
||||||
|
使用方式
|
||||||
|
========
|
||||||
|
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||||
|
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||||
|
|
||||||
|
注意事项
|
||||||
|
========
|
||||||
|
- 本函数不自动提交事务
|
||||||
|
- 异常时会自动回滚
|
||||||
|
- 中间件管理模式下,session 关闭由中间件负责
|
||||||
"""
|
"""
|
||||||
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
|
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
|
||||||
if request is not None:
|
if request is not None:
|
||||||
|
|||||||
22
src/services/cache/model_cache.py
vendored
22
src/services/cache/model_cache.py
vendored
@@ -1,5 +1,21 @@
|
|||||||
"""
|
"""
|
||||||
Model 映射缓存服务 - 减少模型查询
|
Model 映射缓存服务 - 减少模型查询
|
||||||
|
|
||||||
|
架构说明
|
||||||
|
========
|
||||||
|
本服务采用混合 async/sync 模式:
|
||||||
|
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||||
|
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||||
|
|
||||||
|
设计决策
|
||||||
|
--------
|
||||||
|
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||||
|
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||||
|
3. 调用方必须在 async 上下文中使用 await
|
||||||
|
|
||||||
|
使用示例
|
||||||
|
--------
|
||||||
|
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
|
|||||||
|
|
||||||
|
|
||||||
class ModelCacheService:
|
class ModelCacheService:
|
||||||
"""Model 映射缓存服务"""
|
"""Model 映射缓存服务
|
||||||
|
|
||||||
|
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
|
||||||
|
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||||
|
"""
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
# 缓存 TTL(秒)- 使用统一常量
|
||||||
CACHE_TTL = CacheTTL.MODEL
|
CACHE_TTL = CacheTTL.MODEL
|
||||||
|
|||||||
254
src/services/cache/provider_cache.py
vendored
254
src/services/cache/provider_cache.py
vendored
@@ -1,254 +0,0 @@
|
|||||||
"""
|
|
||||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from src.config.constants import CacheTTL
|
|
||||||
from src.core.cache_service import CacheKeys, CacheService
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderCacheService:
|
|
||||||
"""Provider 配置缓存服务"""
|
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
|
||||||
CACHE_TTL = CacheTTL.PROVIDER
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
|
|
||||||
"""
|
|
||||||
获取 Provider(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
provider_id: Provider ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Provider 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.provider_by_id(provider_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"Provider 缓存命中: {provider_id}")
|
|
||||||
return ProviderCacheService._dict_to_provider(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if provider:
|
|
||||||
provider_dict = ProviderCacheService._provider_to_dict(provider)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"Provider 已缓存: {provider_id}")
|
|
||||||
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
|
|
||||||
"""
|
|
||||||
获取 Endpoint(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
endpoint_id: Endpoint ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ProviderEndpoint 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
|
|
||||||
return ProviderCacheService._dict_to_endpoint(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if endpoint:
|
|
||||||
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
|
|
||||||
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
|
|
||||||
"""
|
|
||||||
获取 API Key(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
api_key_id: API Key ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ProviderAPIKey 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.api_key_by_id(api_key_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"API Key 缓存命中: {api_key_id}")
|
|
||||||
return ProviderCacheService._dict_to_api_key(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if api_key:
|
|
||||||
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"API Key 已缓存: {api_key_id}")
|
|
||||||
|
|
||||||
return api_key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_provider_cache(provider_id: str):
|
|
||||||
"""
|
|
||||||
清除 Provider 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: Provider ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
|
|
||||||
logger.debug(f"Provider 缓存已清除: {provider_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_endpoint_cache(endpoint_id: str):
|
|
||||||
"""
|
|
||||||
清除 Endpoint 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
endpoint_id: Endpoint ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
|
|
||||||
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_api_key_cache(api_key_id: str):
|
|
||||||
"""
|
|
||||||
清除 API Key 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key_id: API Key ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
|
|
||||||
logger.debug(f"API Key 缓存已清除: {api_key_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _provider_to_dict(provider: Provider) -> dict:
|
|
||||||
"""将 Provider 对象转换为字典(用于缓存)"""
|
|
||||||
return {
|
|
||||||
"id": provider.id,
|
|
||||||
"name": provider.name,
|
|
||||||
"api_format": provider.api_format,
|
|
||||||
"base_url": provider.base_url,
|
|
||||||
"is_active": provider.is_active,
|
|
||||||
"priority": provider.priority,
|
|
||||||
"rpm_limit": provider.rpm_limit,
|
|
||||||
"rpm_used": provider.rpm_used,
|
|
||||||
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
|
||||||
"config": provider.config,
|
|
||||||
"description": provider.description,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_provider(provider_dict: dict) -> Provider:
|
|
||||||
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
provider = Provider(
|
|
||||||
id=provider_dict["id"],
|
|
||||||
name=provider_dict["name"],
|
|
||||||
api_format=provider_dict["api_format"],
|
|
||||||
base_url=provider_dict.get("base_url"),
|
|
||||||
is_active=provider_dict["is_active"],
|
|
||||||
priority=provider_dict.get("priority", 0),
|
|
||||||
rpm_limit=provider_dict.get("rpm_limit"),
|
|
||||||
rpm_used=provider_dict.get("rpm_used", 0),
|
|
||||||
config=provider_dict.get("config"),
|
|
||||||
description=provider_dict.get("description"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider_dict.get("rpm_reset_at"):
|
|
||||||
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
|
|
||||||
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
|
|
||||||
"""将 Endpoint 对象转换为字典"""
|
|
||||||
return {
|
|
||||||
"id": endpoint.id,
|
|
||||||
"provider_id": endpoint.provider_id,
|
|
||||||
"name": endpoint.name,
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"is_active": endpoint.is_active,
|
|
||||||
"priority": endpoint.priority,
|
|
||||||
"weight": endpoint.weight,
|
|
||||||
"custom_path": endpoint.custom_path,
|
|
||||||
"config": endpoint.config,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
|
|
||||||
"""从字典重建 Endpoint 对象"""
|
|
||||||
endpoint = ProviderEndpoint(
|
|
||||||
id=endpoint_dict["id"],
|
|
||||||
provider_id=endpoint_dict["provider_id"],
|
|
||||||
name=endpoint_dict["name"],
|
|
||||||
base_url=endpoint_dict["base_url"],
|
|
||||||
is_active=endpoint_dict["is_active"],
|
|
||||||
priority=endpoint_dict.get("priority", 0),
|
|
||||||
weight=endpoint_dict.get("weight", 1.0),
|
|
||||||
custom_path=endpoint_dict.get("custom_path"),
|
|
||||||
config=endpoint_dict.get("config"),
|
|
||||||
)
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
|
|
||||||
"""将 API Key 对象转换为字典"""
|
|
||||||
return {
|
|
||||||
"id": api_key.id,
|
|
||||||
"endpoint_id": api_key.endpoint_id,
|
|
||||||
"key_value": api_key.key_value,
|
|
||||||
"is_active": api_key.is_active,
|
|
||||||
"max_rpm": api_key.max_rpm,
|
|
||||||
"current_rpm": api_key.current_rpm,
|
|
||||||
"health_score": api_key.health_score,
|
|
||||||
"circuit_breaker_state": api_key.circuit_breaker_state,
|
|
||||||
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
|
|
||||||
"""从字典重建 API Key 对象"""
|
|
||||||
api_key = ProviderAPIKey(
|
|
||||||
id=api_key_dict["id"],
|
|
||||||
endpoint_id=api_key_dict["endpoint_id"],
|
|
||||||
key_value=api_key_dict["key_value"],
|
|
||||||
is_active=api_key_dict["is_active"],
|
|
||||||
max_rpm=api_key_dict.get("max_rpm"),
|
|
||||||
current_rpm=api_key_dict.get("current_rpm", 0),
|
|
||||||
health_score=api_key_dict.get("health_score", 1.0),
|
|
||||||
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
|
|
||||||
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
|
|
||||||
)
|
|
||||||
return api_key
|
|
||||||
24
src/services/cache/user_cache.py
vendored
24
src/services/cache/user_cache.py
vendored
@@ -1,5 +1,22 @@
|
|||||||
"""
|
"""
|
||||||
用户缓存服务 - 减少数据库查询
|
用户缓存服务 - 减少数据库查询
|
||||||
|
|
||||||
|
架构说明
|
||||||
|
========
|
||||||
|
本服务采用混合 async/sync 模式:
|
||||||
|
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||||
|
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||||
|
|
||||||
|
设计决策
|
||||||
|
--------
|
||||||
|
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||||
|
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||||
|
3. 调用方必须在 async 上下文中使用 await
|
||||||
|
|
||||||
|
使用示例
|
||||||
|
--------
|
||||||
|
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||||
|
await UserCacheService.invalidate_user_cache(user_id, email)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -12,9 +29,12 @@ from src.core.logger import logger
|
|||||||
from src.models.database import User
|
from src.models.database import User
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class UserCacheService:
|
class UserCacheService:
|
||||||
"""用户缓存服务"""
|
"""用户缓存服务
|
||||||
|
|
||||||
|
提供 User 的缓存查询功能,减少数据库访问。
|
||||||
|
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||||
|
"""
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
# 缓存 TTL(秒)- 使用统一常量
|
||||||
CACHE_TTL = CacheTTL.USER
|
CACHE_TTL = CacheTTL.USER
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
|||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import AuditEventType, AuditLog
|
from src.models.database import AuditEventType, AuditLog
|
||||||
from src.utils.transaction_manager import transactional
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
|
|||||||
|
|
||||||
|
|
||||||
class AuditService:
|
class AuditService:
|
||||||
"""审计服务"""
|
"""审计服务
|
||||||
|
|
||||||
|
事务策略:本服务不负责事务提交,由中间件统一管理。
|
||||||
|
所有方法只做 db.add/flush,提交由请求结束时的中间件处理。
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@transactional(commit=False) # 不自动提交,让调用方决定
|
|
||||||
def log_event(
|
def log_event(
|
||||||
db: Session,
|
db: Session,
|
||||||
event_type: AuditEventType,
|
event_type: AuditEventType,
|
||||||
@@ -54,47 +56,44 @@ class AuditService:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
审计日志记录
|
审计日志记录
|
||||||
|
|
||||||
|
Note:
|
||||||
|
不在此方法内提交事务,由调用方或中间件统一管理。
|
||||||
"""
|
"""
|
||||||
try:
|
audit_log = AuditLog(
|
||||||
audit_log = AuditLog(
|
event_type=event_type.value,
|
||||||
event_type=event_type.value,
|
description=description,
|
||||||
description=description,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
api_key_id=api_key_id,
|
||||||
api_key_id=api_key_id,
|
ip_address=ip_address,
|
||||||
ip_address=ip_address,
|
user_agent=user_agent,
|
||||||
user_agent=user_agent,
|
request_id=request_id,
|
||||||
request_id=request_id,
|
status_code=status_code,
|
||||||
status_code=status_code,
|
error_message=error_message,
|
||||||
error_message=error_message,
|
event_metadata=metadata,
|
||||||
event_metadata=metadata,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
db.add(audit_log)
|
db.add(audit_log)
|
||||||
db.commit() # 立即提交事务,释放数据库锁
|
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
|
||||||
db.refresh(audit_log)
|
db.flush()
|
||||||
|
|
||||||
# 同时记录到系统日志
|
# 同时记录到系统日志
|
||||||
log_message = (
|
log_message = (
|
||||||
f"AUDIT [{event_type.value}] - {description} | "
|
f"AUDIT [{event_type.value}] - {description} | "
|
||||||
f"user_id={user_id}, ip={ip_address}"
|
f"user_id={user_id}, ip={ip_address}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event_type in [
|
if event_type in [
|
||||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||||
]:
|
]:
|
||||||
logger.warning(log_message)
|
logger.warning(log_message)
|
||||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
else:
|
else:
|
||||||
logger.debug(log_message)
|
logger.debug(log_message)
|
||||||
|
|
||||||
return audit_log
|
return audit_log
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to log audit event: {e}")
|
|
||||||
db.rollback()
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def log_login_attempt(
|
def log_login_attempt(
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user