From b2a857c16488bb400615c8afa87f91d1a9c3b519 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Thu, 18 Dec 2025 01:59:40 +0800 Subject: [PATCH] refactor: consolidate transaction management and remove legacy modules - Remove unused context.py module (replaced by request.state) - Remove provider_cache.py (no longer needed) - Unify environment loading in config/settings.py instead of __init__.py - Add deprecation warning for get_async_db() (consolidating on sync Session) - Enhance database.py documentation with comprehensive transaction strategy - Simplify audit logging to reuse request-level Session (no separate connections) - Extract UsageService._build_usage_params() helper to reduce code duplication - Update model and user cache implementations with refined transaction handling - Remove unnecessary sessionmaker from pipeline - Clean up audit service exception handling --- src/__init__.py | 6 +- src/api/base/pipeline.py | 21 +- src/core/context.py | 168 ----- src/database/database.py | 45 +- src/services/cache/model_cache.py | 22 +- src/services/cache/provider_cache.py | 254 ------- src/services/cache/user_cache.py | 24 +- src/services/system/audit.py | 77 ++- src/services/usage/service.py | 978 ++++++++++++++------------- 9 files changed, 638 insertions(+), 957 deletions(-) delete mode 100644 src/core/context.py delete mode 100644 src/services/cache/provider_cache.py diff --git a/src/__init__.py b/src/__init__.py index 7abffa6..6377288 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -3,10 +3,8 @@ A proxy server that enables AI models to work with multiple API providers. """ -from dotenv import load_dotenv - -# Load environment variables from .env file -load_dotenv() +# 注意: dotenv 加载已统一移至 src/config/settings.py +# 不要在此处重复加载 try: from src._version import __version__ diff --git a/src/api/base/pipeline.py b/src/api/base/pipeline.py index f49501b..d0fffc5 100644 --- a/src/api/base/pipeline.py +++ b/src/api/base/pipeline.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Optional, Tuple from fastapi import HTTPException, Request -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm import Session from src.core.exceptions import QuotaExceededException from src.core.logger import logger @@ -241,11 +241,15 @@ class ApiRequestPipeline: status_code: Optional[int] = None, error: Optional[str] = None, ) -> None: + """记录审计事件 + + 事务策略:复用请求级 Session,不单独提交。 + 审计记录随主事务一起提交,由中间件统一管理。 + """ if not getattr(adapter, "audit_log_enabled", True): return - bind = context.db.get_bind() - if bind is None: + if context.db is None: return event_type = adapter.audit_success_event if success else adapter.audit_failure_event @@ -265,11 +269,11 @@ class ApiRequestPipeline: error=error, ) - SessionMaker = sessionmaker(bind=bind) - audit_session = SessionMaker() try: + # 复用请求级 Session,不创建新的连接 + # 审计记录随主事务一起提交,由中间件统一管理 self.audit_service.log_event( - db=audit_session, + db=context.db, event_type=event_type, description=f"{context.request.method} {context.request.url.path} via {adapter.name}", user_id=context.user.id if context.user else None, @@ -281,12 +285,9 @@ class ApiRequestPipeline: error_message=error, metadata=metadata, ) - audit_session.commit() except Exception as exc: - audit_session.rollback() + # 审计失败不应影响主请求,仅记录警告 logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}") - finally: - audit_session.close() def _build_audit_metadata( self, diff --git a/src/core/context.py b/src/core/context.py deleted file mode 100644 index 553bf54..0000000 --- a/src/core/context.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -统一的请求上下文 - -RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。 -这确保了数据在各层之间传递时不会丢失。 - -使用方式: -1. Pipeline 层创建 RequestContext -2. 各层通过 context 访问和更新信息 -3. Adapter 层使用 context 记录 Usage -""" - -import time -import uuid -from dataclasses import dataclass, field -from typing import Any, Dict, Optional - - -@dataclass -class RequestContext: - """ - 请求上下文 - 贯穿整个请求生命周期 - - 设计原则: - 1. 在请求开始时创建,包含所有已知信息 - 2. 在请求执行过程中逐步填充 Provider 信息 - 3. 在请求结束时用于记录 Usage - """ - - # ==================== 请求标识 ==================== - request_id: str - - # ==================== 认证信息 ==================== - user: Any # User model - api_key: Any # ApiKey model - db: Any # Database session - - # ==================== 请求信息 ==================== - api_format: str # CLAUDE, OPENAI, GEMINI, etc. - model: str # 用户请求的模型名 - is_stream: bool = False - - # ==================== 原始请求 ==================== - original_headers: Dict[str, str] = field(default_factory=dict) - original_body: Dict[str, Any] = field(default_factory=dict) - - # ==================== 客户端信息 ==================== - client_ip: str = "unknown" - user_agent: str = "" - - # ==================== 计时 ==================== - start_time: float = field(default_factory=time.time) - - # ==================== Provider 信息(请求执行后填充)==================== - provider_name: Optional[str] = None - provider_id: Optional[str] = None - endpoint_id: Optional[str] = None - provider_api_key_id: Optional[str] = None - - # ==================== 模型映射信息 ==================== - resolved_model: Optional[str] = None # 映射后的模型名 - original_model: Optional[str] = None # 原始模型名(用于价格计算) - - # ==================== 请求/响应头 ==================== - provider_request_headers: Dict[str, str] = field(default_factory=dict) - provider_response_headers: Dict[str, str] = field(default_factory=dict) - - # ==================== 追踪信息 ==================== - attempt_id: Optional[str] = None - - # ==================== 能力需求 ==================== - capability_requirements: Dict[str, bool] = field(default_factory=dict) - # 运行时计算的能力需求,来源于: - # 1. 用户 model_capability_settings - # 2. 用户 ApiKey.force_capabilities - # 3. 请求头 X-Require-Capability - # 4. 失败重试时动态添加 - - @classmethod - def create( - cls, - *, - db: Any, - user: Any, - api_key: Any, - api_format: str, - model: str, - is_stream: bool = False, - original_headers: Optional[Dict[str, str]] = None, - original_body: Optional[Dict[str, Any]] = None, - client_ip: str = "unknown", - user_agent: str = "", - request_id: Optional[str] = None, - ) -> "RequestContext": - """创建请求上下文""" - return cls( - request_id=request_id or str(uuid.uuid4()), - db=db, - user=user, - api_key=api_key, - api_format=api_format, - model=model, - is_stream=is_stream, - original_headers=original_headers or {}, - original_body=original_body or {}, - client_ip=client_ip, - user_agent=user_agent, - original_model=model, # 初始时原始模型等于请求模型 - ) - - def update_provider_info( - self, - *, - provider_name: str, - provider_id: str, - endpoint_id: str, - provider_api_key_id: str, - resolved_model: Optional[str] = None, - ) -> None: - """更新 Provider 信息(请求执行后调用)""" - self.provider_name = provider_name - self.provider_id = provider_id - self.endpoint_id = endpoint_id - self.provider_api_key_id = provider_api_key_id - if resolved_model: - self.resolved_model = resolved_model - - def update_headers( - self, - *, - request_headers: Optional[Dict[str, str]] = None, - response_headers: Optional[Dict[str, str]] = None, - ) -> None: - """更新请求/响应头""" - if request_headers: - self.provider_request_headers = request_headers - if response_headers: - self.provider_response_headers = response_headers - - @property - def elapsed_ms(self) -> int: - """计算已经过的时间(毫秒)""" - return int((time.time() - self.start_time) * 1000) - - @property - def effective_model(self) -> str: - """获取有效的模型名(映射后优先)""" - return self.resolved_model or self.model - - @property - def billing_model(self) -> str: - """获取计费模型名(原始模型优先)""" - return self.original_model or self.model - - def to_metadata_dict(self) -> Dict[str, Any]: - """转换为元数据字典(用于 Usage 记录)""" - return { - "api_format": self.api_format, - "provider": self.provider_name or "unknown", - "model": self.effective_model, - "original_model": self.billing_model, - "provider_id": self.provider_id, - "provider_endpoint_id": self.endpoint_id, - "provider_api_key_id": self.provider_api_key_id, - "provider_request_headers": self.provider_request_headers, - "provider_response_headers": self.provider_response_headers, - "attempt_id": self.attempt_id, - } diff --git a/src/database/database.py b/src/database/database.py index e89f2ad..99f5b5e 100644 --- a/src/database/database.py +++ b/src/database/database.py @@ -223,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine: async def get_async_db() -> AsyncGenerator[AsyncSession, None]: - """获取异步数据库会话""" + """获取异步数据库会话 + + .. deprecated:: + 此方法已废弃,项目统一使用同步 Session。 + 未来版本可能移除此方法。请使用 get_db() 代替。 + """ + import warnings + warnings.warn( + "get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。", + DeprecationWarning, + stacklevel=2, + ) # 确保异步引擎已初始化 _ensure_async_engine() @@ -237,12 +248,34 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]: def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment] """获取数据库会话 - 注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback) - 这里只负责会话的创建和关闭,不自动提交 + 事务策略说明 + ============ + 本项目采用**混合事务管理**策略: - 在 FastAPI 请求上下文中通过 Depends(get_db) 调用时,会自动注入 Request 对象, - 支持中间件管理的 session 复用;在非请求上下文中直接调用 get_db() 时, - request 为 None,退化为独立 session 模式。 + 1. **LLM 请求路径**: + - 由 PluginMiddleware 统一管理事务 + - Service 层使用 db.flush() 使更改可见,但不提交 + - 请求结束时由中间件统一 commit 或 rollback + - 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化 + + 2. **管理后台 API**: + - 路由层显式调用 db.commit() + - 每个操作独立提交,不依赖中间件 + + 3. **后台任务/调度器**: + - 使用独立 Session(通过 create_session() 或 next(get_db())) + - 自行管理事务生命周期 + + 使用方式 + ======== + - FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用 + - 非请求上下文:直接调用 get_db(),退化为独立 session 模式 + + 注意事项 + ======== + - 本函数不自动提交事务 + - 异常时会自动回滚 + - 中间件管理模式下,session 关闭由中间件负责 """ # FastAPI 请求上下文:优先复用中间件绑定的 request.state.db if request is not None: diff --git a/src/services/cache/model_cache.py b/src/services/cache/model_cache.py index 37661a7..9ad3cac 100644 --- a/src/services/cache/model_cache.py +++ b/src/services/cache/model_cache.py @@ -1,5 +1,21 @@ """ Model 映射缓存服务 - 减少模型查询 + +架构说明 +======== +本服务采用混合 async/sync 模式: +- 缓存操作(CacheService):真正的 async,使用 aioredis +- 数据库查询(db.query):同步的 SQLAlchemy Session + +设计决策 +-------- +1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优 +2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环 +3. 调用方必须在 async 上下文中使用 await + +使用示例 +-------- + global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4") """ import time @@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model class ModelCacheService: - """Model 映射缓存服务""" + """Model 映射缓存服务 + + 提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。 + 所有公开方法均为 async,需要在 async 上下文中调用。 + """ # 缓存 TTL(秒)- 使用统一常量 CACHE_TTL = CacheTTL.MODEL diff --git a/src/services/cache/provider_cache.py b/src/services/cache/provider_cache.py deleted file mode 100644 index 2c9b25d..0000000 --- a/src/services/cache/provider_cache.py +++ /dev/null @@ -1,254 +0,0 @@ -""" -Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询 -""" - -from typing import Any, Dict, List, Optional - -from sqlalchemy.orm import Session - -from src.config.constants import CacheTTL -from src.core.cache_service import CacheKeys, CacheService -from src.core.logger import logger -from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint - - - -class ProviderCacheService: - """Provider 配置缓存服务""" - - # 缓存 TTL(秒)- 使用统一常量 - CACHE_TTL = CacheTTL.PROVIDER - - @staticmethod - async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]: - """ - 获取 Provider(带缓存) - - Args: - db: 数据库会话 - provider_id: Provider ID - - Returns: - Provider 对象或 None - """ - cache_key = CacheKeys.provider_by_id(provider_id) - - # 1. 尝试从缓存获取 - cached_data = await CacheService.get(cache_key) - if cached_data: - logger.debug(f"Provider 缓存命中: {provider_id}") - return ProviderCacheService._dict_to_provider(cached_data) - - # 2. 缓存未命中,查询数据库 - provider = db.query(Provider).filter(Provider.id == provider_id).first() - - # 3. 写入缓存 - if provider: - provider_dict = ProviderCacheService._provider_to_dict(provider) - await CacheService.set( - cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL - ) - logger.debug(f"Provider 已缓存: {provider_id}") - - return provider - - @staticmethod - async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]: - """ - 获取 Endpoint(带缓存) - - Args: - db: 数据库会话 - endpoint_id: Endpoint ID - - Returns: - ProviderEndpoint 对象或 None - """ - cache_key = CacheKeys.endpoint_by_id(endpoint_id) - - # 1. 尝试从缓存获取 - cached_data = await CacheService.get(cache_key) - if cached_data: - logger.debug(f"Endpoint 缓存命中: {endpoint_id}") - return ProviderCacheService._dict_to_endpoint(cached_data) - - # 2. 缓存未命中,查询数据库 - endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first() - - # 3. 写入缓存 - if endpoint: - endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint) - await CacheService.set( - cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL - ) - logger.debug(f"Endpoint 已缓存: {endpoint_id}") - - return endpoint - - @staticmethod - async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]: - """ - 获取 API Key(带缓存) - - Args: - db: 数据库会话 - api_key_id: API Key ID - - Returns: - ProviderAPIKey 对象或 None - """ - cache_key = CacheKeys.api_key_by_id(api_key_id) - - # 1. 尝试从缓存获取 - cached_data = await CacheService.get(cache_key) - if cached_data: - logger.debug(f"API Key 缓存命中: {api_key_id}") - return ProviderCacheService._dict_to_api_key(cached_data) - - # 2. 缓存未命中,查询数据库 - api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first() - - # 3. 写入缓存 - if api_key: - api_key_dict = ProviderCacheService._api_key_to_dict(api_key) - await CacheService.set( - cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL - ) - logger.debug(f"API Key 已缓存: {api_key_id}") - - return api_key - - @staticmethod - async def invalidate_provider_cache(provider_id: str): - """ - 清除 Provider 缓存 - - Args: - provider_id: Provider ID - """ - await CacheService.delete(CacheKeys.provider_by_id(provider_id)) - logger.debug(f"Provider 缓存已清除: {provider_id}") - - @staticmethod - async def invalidate_endpoint_cache(endpoint_id: str): - """ - 清除 Endpoint 缓存 - - Args: - endpoint_id: Endpoint ID - """ - await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id)) - logger.debug(f"Endpoint 缓存已清除: {endpoint_id}") - - @staticmethod - async def invalidate_api_key_cache(api_key_id: str): - """ - 清除 API Key 缓存 - - Args: - api_key_id: API Key ID - """ - await CacheService.delete(CacheKeys.api_key_by_id(api_key_id)) - logger.debug(f"API Key 缓存已清除: {api_key_id}") - - @staticmethod - def _provider_to_dict(provider: Provider) -> dict: - """将 Provider 对象转换为字典(用于缓存)""" - return { - "id": provider.id, - "name": provider.name, - "api_format": provider.api_format, - "base_url": provider.base_url, - "is_active": provider.is_active, - "priority": provider.priority, - "rpm_limit": provider.rpm_limit, - "rpm_used": provider.rpm_used, - "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None, - "config": provider.config, - "description": provider.description, - } - - @staticmethod - def _dict_to_provider(provider_dict: dict) -> Provider: - """从字典重建 Provider 对象(分离的对象,不在 Session 中)""" - from datetime import datetime - - provider = Provider( - id=provider_dict["id"], - name=provider_dict["name"], - api_format=provider_dict["api_format"], - base_url=provider_dict.get("base_url"), - is_active=provider_dict["is_active"], - priority=provider_dict.get("priority", 0), - rpm_limit=provider_dict.get("rpm_limit"), - rpm_used=provider_dict.get("rpm_used", 0), - config=provider_dict.get("config"), - description=provider_dict.get("description"), - ) - - if provider_dict.get("rpm_reset_at"): - provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"]) - - return provider - - @staticmethod - def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict: - """将 Endpoint 对象转换为字典""" - return { - "id": endpoint.id, - "provider_id": endpoint.provider_id, - "name": endpoint.name, - "base_url": endpoint.base_url, - "is_active": endpoint.is_active, - "priority": endpoint.priority, - "weight": endpoint.weight, - "custom_path": endpoint.custom_path, - "config": endpoint.config, - } - - @staticmethod - def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint: - """从字典重建 Endpoint 对象""" - endpoint = ProviderEndpoint( - id=endpoint_dict["id"], - provider_id=endpoint_dict["provider_id"], - name=endpoint_dict["name"], - base_url=endpoint_dict["base_url"], - is_active=endpoint_dict["is_active"], - priority=endpoint_dict.get("priority", 0), - weight=endpoint_dict.get("weight", 1.0), - custom_path=endpoint_dict.get("custom_path"), - config=endpoint_dict.get("config"), - ) - return endpoint - - @staticmethod - def _api_key_to_dict(api_key: ProviderAPIKey) -> dict: - """将 API Key 对象转换为字典""" - return { - "id": api_key.id, - "endpoint_id": api_key.endpoint_id, - "key_value": api_key.key_value, - "is_active": api_key.is_active, - "max_rpm": api_key.max_rpm, - "current_rpm": api_key.current_rpm, - "health_score": api_key.health_score, - "circuit_breaker_state": api_key.circuit_breaker_state, - "adaptive_concurrency_limit": api_key.adaptive_concurrency_limit, - } - - @staticmethod - def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey: - """从字典重建 API Key 对象""" - api_key = ProviderAPIKey( - id=api_key_dict["id"], - endpoint_id=api_key_dict["endpoint_id"], - key_value=api_key_dict["key_value"], - is_active=api_key_dict["is_active"], - max_rpm=api_key_dict.get("max_rpm"), - current_rpm=api_key_dict.get("current_rpm", 0), - health_score=api_key_dict.get("health_score", 1.0), - circuit_breaker_state=api_key_dict.get("circuit_breaker_state"), - adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"), - ) - return api_key diff --git a/src/services/cache/user_cache.py b/src/services/cache/user_cache.py index e3e5d06..30b399f 100644 --- a/src/services/cache/user_cache.py +++ b/src/services/cache/user_cache.py @@ -1,5 +1,22 @@ """ 用户缓存服务 - 减少数据库查询 + +架构说明 +======== +本服务采用混合 async/sync 模式: +- 缓存操作(CacheService):真正的 async,使用 aioredis +- 数据库查询(db.query):同步的 SQLAlchemy Session + +设计决策 +-------- +1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优 +2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环 +3. 调用方必须在 async 上下文中使用 await + +使用示例 +-------- + user = await UserCacheService.get_user_by_id(db, user_id) + await UserCacheService.invalidate_user_cache(user_id, email) """ from typing import Optional @@ -12,9 +29,12 @@ from src.core.logger import logger from src.models.database import User - class UserCacheService: - """用户缓存服务""" + """用户缓存服务 + + 提供 User 的缓存查询功能,减少数据库访问。 + 所有公开方法均为 async,需要在 async 上下文中调用。 + """ # 缓存 TTL(秒)- 使用统一常量 CACHE_TTL = CacheTTL.USER diff --git a/src/services/system/audit.py b/src/services/system/audit.py index 77c1832..89bc9e9 100644 --- a/src/services/system/audit.py +++ b/src/services/system/audit.py @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session from src.core.logger import logger from src.database import get_db from src.models.database import AuditEventType, AuditLog -from src.utils.transaction_manager import transactional @@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional class AuditService: - """审计服务""" + """审计服务 + + 事务策略:本服务不负责事务提交,由中间件统一管理。 + 所有方法只做 db.add/flush,提交由请求结束时的中间件处理。 + """ @staticmethod - @transactional(commit=False) # 不自动提交,让调用方决定 def log_event( db: Session, event_type: AuditEventType, @@ -54,47 +56,44 @@ class AuditService: Returns: 审计日志记录 + + Note: + 不在此方法内提交事务,由调用方或中间件统一管理。 """ - try: - audit_log = AuditLog( - event_type=event_type.value, - description=description, - user_id=user_id, - api_key_id=api_key_id, - ip_address=ip_address, - user_agent=user_agent, - request_id=request_id, - status_code=status_code, - error_message=error_message, - event_metadata=metadata, - ) + audit_log = AuditLog( + event_type=event_type.value, + description=description, + user_id=user_id, + api_key_id=api_key_id, + ip_address=ip_address, + user_agent=user_agent, + request_id=request_id, + status_code=status_code, + error_message=error_message, + event_metadata=metadata, + ) - db.add(audit_log) - db.commit() # 立即提交事务,释放数据库锁 - db.refresh(audit_log) + db.add(audit_log) + # 使用 flush 使记录可见但不提交事务,事务由中间件统一管理 + db.flush() - # 同时记录到系统日志 - log_message = ( - f"AUDIT [{event_type.value}] - {description} | " - f"user_id={user_id}, ip={ip_address}" - ) + # 同时记录到系统日志 + log_message = ( + f"AUDIT [{event_type.value}] - {description} | " + f"user_id={user_id}, ip={ip_address}" + ) - if event_type in [ - AuditEventType.UNAUTHORIZED_ACCESS, - AuditEventType.SUSPICIOUS_ACTIVITY, - ]: - logger.warning(log_message) - elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]: - logger.info(log_message) - else: - logger.debug(log_message) + if event_type in [ + AuditEventType.UNAUTHORIZED_ACCESS, + AuditEventType.SUSPICIOUS_ACTIVITY, + ]: + logger.warning(log_message) + elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]: + logger.info(log_message) + else: + logger.debug(log_message) - return audit_log - - except Exception as e: - logger.error(f"Failed to log audit event: {e}") - db.rollback() - return None + return audit_log @staticmethod def log_login_attempt( diff --git a/src/services/usage/service.py b/src/services/usage/service.py index c65e211..0cbd236 100644 --- a/src/services/usage/service.py +++ b/src/services/usage/service.py @@ -20,6 +20,336 @@ from src.services.system.config import SystemConfigService class UsageService: """用量统计服务""" + # ==================== 内部数据类 ==================== + + @staticmethod + def _build_usage_params( + *, + db: Session, + user: Optional[User], + api_key: Optional[ApiKey], + provider: str, + model: str, + input_tokens: int, + output_tokens: int, + cache_creation_input_tokens: int, + cache_read_input_tokens: int, + request_type: str, + api_format: Optional[str], + is_stream: bool, + response_time_ms: Optional[int], + first_byte_time_ms: Optional[int], + status_code: int, + error_message: Optional[str], + metadata: Optional[Dict[str, Any]], + request_headers: Optional[Dict[str, Any]], + request_body: Optional[Any], + provider_request_headers: Optional[Dict[str, Any]], + response_headers: Optional[Dict[str, Any]], + response_body: Optional[Any], + request_id: str, + provider_id: Optional[str], + provider_endpoint_id: Optional[str], + provider_api_key_id: Optional[str], + status: str, + target_model: Optional[str], + # 成本计算结果 + input_cost: float, + output_cost: float, + cache_creation_cost: float, + cache_read_cost: float, + cache_cost: float, + request_cost: float, + total_cost: float, + # 价格信息 + input_price: float, + output_price: float, + cache_creation_price: Optional[float], + cache_read_price: Optional[float], + request_price: Optional[float], + # 倍率 + actual_rate_multiplier: float, + is_free_tier: bool, + ) -> Dict[str, Any]: + """构建 Usage 记录的参数字典(内部方法,避免代码重复)""" + + # 根据配置决定是否记录请求详情 + should_log_headers = SystemConfigService.should_log_headers(db) + should_log_body = SystemConfigService.should_log_body(db) + + # 处理请求头(可能需要脱敏) + processed_request_headers = None + if should_log_headers and request_headers: + processed_request_headers = SystemConfigService.mask_sensitive_headers( + db, request_headers + ) + + # 处理提供商请求头(可能需要脱敏) + processed_provider_request_headers = None + if should_log_headers and provider_request_headers: + processed_provider_request_headers = SystemConfigService.mask_sensitive_headers( + db, provider_request_headers + ) + + # 处理请求体和响应体(可能需要截断) + processed_request_body = None + processed_response_body = None + if should_log_body: + if request_body: + processed_request_body = SystemConfigService.truncate_body( + db, request_body, is_request=True + ) + if response_body: + processed_response_body = SystemConfigService.truncate_body( + db, response_body, is_request=False + ) + + # 处理响应头 + processed_response_headers = None + if should_log_headers and response_headers: + processed_response_headers = SystemConfigService.mask_sensitive_headers( + db, response_headers + ) + + # 计算真实成本(表面成本 * 倍率),免费套餐实际费用为 0 + if is_free_tier: + actual_input_cost = 0.0 + actual_output_cost = 0.0 + actual_cache_creation_cost = 0.0 + actual_cache_read_cost = 0.0 + actual_request_cost = 0.0 + actual_total_cost = 0.0 + else: + actual_input_cost = input_cost * actual_rate_multiplier + actual_output_cost = output_cost * actual_rate_multiplier + actual_cache_creation_cost = cache_creation_cost * actual_rate_multiplier + actual_cache_read_cost = cache_read_cost * actual_rate_multiplier + actual_request_cost = request_cost * actual_rate_multiplier + actual_total_cost = total_cost * actual_rate_multiplier + + return { + "user_id": user.id if user else None, + "api_key_id": api_key.id if api_key else None, + "request_id": request_id, + "provider": provider, + "model": model, + "target_model": target_model, + "provider_id": provider_id, + "provider_endpoint_id": provider_endpoint_id, + "provider_api_key_id": provider_api_key_id, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "cache_creation_input_tokens": cache_creation_input_tokens, + "cache_read_input_tokens": cache_read_input_tokens, + "input_cost_usd": input_cost, + "output_cost_usd": output_cost, + "cache_cost_usd": cache_cost, + "cache_creation_cost_usd": cache_creation_cost, + "cache_read_cost_usd": cache_read_cost, + "request_cost_usd": request_cost, + "total_cost_usd": total_cost, + "actual_input_cost_usd": actual_input_cost, + "actual_output_cost_usd": actual_output_cost, + "actual_cache_creation_cost_usd": actual_cache_creation_cost, + "actual_cache_read_cost_usd": actual_cache_read_cost, + "actual_request_cost_usd": actual_request_cost, + "actual_total_cost_usd": actual_total_cost, + "rate_multiplier": actual_rate_multiplier, + "input_price_per_1m": input_price, + "output_price_per_1m": output_price, + "cache_creation_price_per_1m": cache_creation_price, + "cache_read_price_per_1m": cache_read_price, + "price_per_request": request_price, + "request_type": request_type, + "api_format": api_format, + "is_stream": is_stream, + "status_code": status_code, + "error_message": error_message, + "response_time_ms": response_time_ms, + "first_byte_time_ms": first_byte_time_ms, + "status": status, + "request_metadata": metadata, + "request_headers": processed_request_headers, + "request_body": processed_request_body, + "provider_request_headers": processed_provider_request_headers, + "response_headers": processed_response_headers, + "response_body": processed_response_body, + } + + @classmethod + async def _get_rate_multiplier_and_free_tier( + cls, + db: Session, + provider_api_key_id: Optional[str], + provider_id: Optional[str], + ) -> Tuple[float, bool]: + """获取费率倍数和是否免费套餐""" + actual_rate_multiplier = 1.0 + if provider_api_key_id: + provider_key = ( + db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first() + ) + if provider_key and provider_key.rate_multiplier: + actual_rate_multiplier = provider_key.rate_multiplier + + is_free_tier = False + if provider_id: + provider_obj = db.query(Provider).filter(Provider.id == provider_id).first() + if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER: + is_free_tier = True + + return actual_rate_multiplier, is_free_tier + + @classmethod + async def _calculate_costs( + cls, + db: Session, + provider: str, + model: str, + input_tokens: int, + output_tokens: int, + cache_creation_input_tokens: int, + cache_read_input_tokens: int, + api_format: Optional[str], + cache_ttl_minutes: Optional[int], + use_tiered_pricing: bool, + is_failed_request: bool, + ) -> Tuple[float, float, float, float, float, float, float, float, float, + Optional[float], Optional[float], Optional[float], Optional[int]]: + """计算所有成本相关数据 + + Returns: + (input_price, output_price, cache_creation_price, cache_read_price, request_price, + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, + request_cost, total_cost, tier_index) + """ + # 获取模型价格信息 + input_price, output_price = await cls.get_model_price_async(db, provider, model) + cache_creation_price, cache_read_price = await cls.get_cache_prices_async( + db, provider, model, input_price + ) + request_price = await cls.get_request_price_async(db, provider, model) + effective_request_price = None if is_failed_request else request_price + + # 初始化成本变量 + input_cost = 0.0 + output_cost = 0.0 + cache_creation_cost = 0.0 + cache_read_cost = 0.0 + cache_cost = 0.0 + request_cost = 0.0 + total_cost = 0.0 + tier_index = None + + if use_tiered_pricing: + ( + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + cache_cost, + request_cost, + total_cost, + tier_index, + ) = await cls.calculate_cost_with_strategy_async( + db=db, + provider=provider, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_creation_input_tokens=cache_creation_input_tokens, + cache_read_input_tokens=cache_read_input_tokens, + api_format=api_format, + cache_ttl_minutes=cache_ttl_minutes, + ) + if is_failed_request: + total_cost = total_cost - request_cost + request_cost = 0.0 + else: + ( + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + cache_cost, + request_cost, + total_cost, + ) = cls.calculate_cost( + input_tokens=input_tokens, + output_tokens=output_tokens, + input_price_per_1m=input_price, + output_price_per_1m=output_price, + cache_creation_input_tokens=cache_creation_input_tokens, + cache_read_input_tokens=cache_read_input_tokens, + cache_creation_price_per_1m=cache_creation_price, + cache_read_price_per_1m=cache_read_price, + price_per_request=effective_request_price, + ) + + return ( + input_price, output_price, cache_creation_price, cache_read_price, request_price, + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, + request_cost, total_cost, tier_index + ) + + @staticmethod + def _update_existing_usage( + existing_usage: Usage, + usage_params: Dict[str, Any], + target_model: Optional[str], + ) -> None: + """更新已存在的 Usage 记录(内部方法)""" + # 更新关键字段 + existing_usage.provider = usage_params["provider"] + existing_usage.status = usage_params["status"] + existing_usage.status_code = usage_params["status_code"] + existing_usage.error_message = usage_params["error_message"] + existing_usage.response_time_ms = usage_params["response_time_ms"] + existing_usage.first_byte_time_ms = usage_params["first_byte_time_ms"] + + # 更新请求头和请求体(如果有新值) + if usage_params["request_headers"] is not None: + existing_usage.request_headers = usage_params["request_headers"] + if usage_params["request_body"] is not None: + existing_usage.request_body = usage_params["request_body"] + if usage_params["provider_request_headers"] is not None: + existing_usage.provider_request_headers = usage_params["provider_request_headers"] + existing_usage.response_body = usage_params["response_body"] + existing_usage.response_headers = usage_params["response_headers"] + + # 更新 token 和费用信息 + existing_usage.input_tokens = usage_params["input_tokens"] + existing_usage.output_tokens = usage_params["output_tokens"] + existing_usage.total_tokens = usage_params["total_tokens"] + existing_usage.cache_creation_input_tokens = usage_params["cache_creation_input_tokens"] + existing_usage.cache_read_input_tokens = usage_params["cache_read_input_tokens"] + existing_usage.input_cost_usd = usage_params["input_cost_usd"] + existing_usage.output_cost_usd = usage_params["output_cost_usd"] + existing_usage.cache_cost_usd = usage_params["cache_cost_usd"] + existing_usage.cache_creation_cost_usd = usage_params["cache_creation_cost_usd"] + existing_usage.cache_read_cost_usd = usage_params["cache_read_cost_usd"] + existing_usage.request_cost_usd = usage_params["request_cost_usd"] + existing_usage.total_cost_usd = usage_params["total_cost_usd"] + existing_usage.actual_input_cost_usd = usage_params["actual_input_cost_usd"] + existing_usage.actual_output_cost_usd = usage_params["actual_output_cost_usd"] + existing_usage.actual_cache_creation_cost_usd = usage_params["actual_cache_creation_cost_usd"] + existing_usage.actual_cache_read_cost_usd = usage_params["actual_cache_read_cost_usd"] + existing_usage.actual_request_cost_usd = usage_params["actual_request_cost_usd"] + existing_usage.actual_total_cost_usd = usage_params["actual_total_cost_usd"] + existing_usage.rate_multiplier = usage_params["rate_multiplier"] + + # 更新 Provider 侧追踪信息 + existing_usage.provider_id = usage_params["provider_id"] + existing_usage.provider_endpoint_id = usage_params["provider_endpoint_id"] + existing_usage.provider_api_key_id = usage_params["provider_api_key_id"] + + # 更新模型映射信息 + if target_model is not None: + existing_usage.target_model = target_model + + # ==================== 公开 API ==================== + @classmethod async def get_model_price_async( cls, db: Session, provider: str, model: str @@ -157,233 +487,112 @@ class UsageService: api_format: Optional[str] = None, is_stream: bool = False, response_time_ms: Optional[int] = None, - first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB) + first_byte_time_ms: Optional[int] = None, status_code: int = 200, error_message: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, request_headers: Optional[Dict[str, Any]] = None, request_body: Optional[Any] = None, - provider_request_headers: Optional[Dict[str, Any]] = None, # 向提供商发送的请求头 + provider_request_headers: Optional[Dict[str, Any]] = None, response_headers: Optional[Dict[str, Any]] = None, response_body: Optional[Any] = None, - request_id: Optional[str] = None, # 请求ID,如果未提供则自动生成 - # Provider 侧追踪信息(记录最终成功的 Provider/Endpoint/Key) + request_id: Optional[str] = None, provider_id: Optional[str] = None, provider_endpoint_id: Optional[str] = None, provider_api_key_id: Optional[str] = None, - # 请求状态 (pending, streaming, completed, failed) status: str = "completed", - # 阶梯计费相关参数 - cache_ttl_minutes: Optional[int] = None, # 缓存时长(用于 TTL 差异化定价) - use_tiered_pricing: bool = True, # 是否使用阶梯计费(默认启用) - # 模型映射信息 - target_model: Optional[str] = None, # 映射后的目标模型名 + cache_ttl_minutes: Optional[int] = None, + use_tiered_pricing: bool = True, + target_model: Optional[str] = None, ) -> Usage: - """异步记录使用量(支持阶梯计费)""" + """异步记录使用量(简化版,仅插入新记录) - # 使用传入的 request_id 或生成新的 + 此方法用于快速记录使用量,不更新用户/API Key 统计,不支持更新已存在的记录。 + 适用于不需要更新统计信息的场景。 + + 如需完整功能(更新用户统计、支持更新已存在记录),请使用 record_usage()。 + """ + # 生成 request_id if request_id is None: - request_id = str(uuid.uuid4())[:8] # 生成8位短ID以保持一致性 + request_id = str(uuid.uuid4())[:8] - # 如果提供了 provider_api_key_id,从数据库查询 rate_multiplier - actual_rate_multiplier = 1.0 # 默认值 - if provider_api_key_id: - provider_key = ( - db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first() - ) - if provider_key and provider_key.rate_multiplier: - actual_rate_multiplier = provider_key.rate_multiplier - - # 失败的请求不应该计入按次计费费用 - is_failed_request = status_code >= 400 or error_message is not None - - # 获取模型价格信息(用于历史记录) - input_price, output_price = await cls.get_model_price_async(db, provider, model) - cache_creation_price, cache_read_price = await cls.get_cache_prices_async( - db, provider, model, input_price + # 获取费率倍数和是否免费套餐 + actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier( + db, provider_api_key_id, provider_id ) - request_price = await cls.get_request_price_async(db, provider, model) - effective_request_price = None if is_failed_request else request_price - # 初始化成本变量(避免 `in locals()` 反模式) - input_cost = 0.0 - output_cost = 0.0 - cache_creation_cost = 0.0 - cache_read_cost = 0.0 - cache_cost = 0.0 - request_cost = 0.0 - total_cost = 0.0 - tier_index = None - - # 计算成本(支持阶梯计费) - if use_tiered_pricing: - # 使用策略模式计算成本(支持阶梯计费和 TTL 差异化) - ( - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - cache_cost, - request_cost, - total_cost, - tier_index, - ) = await cls.calculate_cost_with_strategy_async( - db=db, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - api_format=api_format, - cache_ttl_minutes=cache_ttl_minutes, - ) - # 如果失败请求,重置按次费用 - if is_failed_request: - total_cost = total_cost - request_cost - request_cost = 0.0 - else: - # 使用固定价格模式 - ( - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - cache_cost, - request_cost, - total_cost, - ) = cls.calculate_cost( - input_tokens=input_tokens, - output_tokens=output_tokens, - input_price_per_1m=input_price, - output_price_per_1m=output_price, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - cache_creation_price_per_1m=cache_creation_price, - cache_read_price_per_1m=cache_read_price, - price_per_request=effective_request_price, - ) - - # 根据配置决定是否记录请求详情 - should_log_headers = SystemConfigService.should_log_headers(db) - should_log_body = SystemConfigService.should_log_body(db) - - # 处理请求头(可能需要脱敏) - processed_request_headers = None - if should_log_headers and request_headers: - processed_request_headers = SystemConfigService.mask_sensitive_headers( - db, request_headers - ) - - # 处理提供商请求头(可能需要脱敏) - processed_provider_request_headers = None - if should_log_headers and provider_request_headers: - processed_provider_request_headers = SystemConfigService.mask_sensitive_headers( - db, provider_request_headers - ) - - # 处理请求体和响应体(可能需要截断) - processed_request_body = None - processed_response_body = None - if should_log_body: - if request_body: - processed_request_body = SystemConfigService.truncate_body( - db, request_body, is_request=True - ) - if response_body: - processed_response_body = SystemConfigService.truncate_body( - db, response_body, is_request=False - ) - - # 处理响应头 - processed_response_headers = None - if should_log_headers and response_headers: - processed_response_headers = SystemConfigService.mask_sensitive_headers( - db, response_headers - ) - - # 检查 Provider 的计费类型,免费套餐的实际费用为 0 - is_free_tier = False - if provider_id: - provider_obj = db.query(Provider).filter(Provider.id == provider_id).first() - if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER: - is_free_tier = True - - # 计算真实成本(表面成本 × 倍率),免费套餐实际费用为 0 - if is_free_tier: - actual_input_cost = 0.0 - actual_output_cost = 0.0 - actual_cache_creation_cost = 0.0 - actual_cache_read_cost = 0.0 - actual_request_cost = 0.0 - actual_total_cost = 0.0 - else: - actual_input_cost = input_cost * actual_rate_multiplier - actual_output_cost = output_cost * actual_rate_multiplier - actual_cache_creation_cost = cache_creation_cost * actual_rate_multiplier - actual_cache_read_cost = cache_read_cost * actual_rate_multiplier - actual_request_cost = request_cost * actual_rate_multiplier - actual_total_cost = total_cost * actual_rate_multiplier - - # 记录使用量 - usage = Usage( - user_id=user.id if user else None, - api_key_id=api_key.id if api_key else None, - request_id=request_id, + # 计算成本 + is_failed_request = status_code >= 400 or error_message is not None + ( + input_price, output_price, cache_creation_price, cache_read_price, request_price, + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, + request_cost, total_cost, tier_index + ) = await cls._calculate_costs( + db=db, + provider=provider, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_creation_input_tokens=cache_creation_input_tokens, + cache_read_input_tokens=cache_read_input_tokens, + api_format=api_format, + cache_ttl_minutes=cache_ttl_minutes, + use_tiered_pricing=use_tiered_pricing, + is_failed_request=is_failed_request, + ) + + # 构建 Usage 参数 + usage_params = cls._build_usage_params( + db=db, + user=user, + api_key=api_key, provider=provider, model=model, - target_model=target_model, # 映射后的目标模型名 - # Provider 侧追踪信息 - provider_id=provider_id, - provider_endpoint_id=provider_endpoint_id, - provider_api_key_id=provider_api_key_id, input_tokens=input_tokens, output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, cache_creation_input_tokens=cache_creation_input_tokens, cache_read_input_tokens=cache_read_input_tokens, - input_cost_usd=input_cost, - output_cost_usd=output_cost, - cache_cost_usd=cache_cost, - cache_creation_cost_usd=cache_creation_cost, - cache_read_cost_usd=cache_read_cost, - request_cost_usd=request_cost, - total_cost_usd=total_cost, - # 真实成本(考虑倍率) - actual_input_cost_usd=actual_input_cost, - actual_output_cost_usd=actual_output_cost, - actual_cache_creation_cost_usd=actual_cache_creation_cost, - actual_cache_read_cost_usd=actual_cache_read_cost, - actual_request_cost_usd=actual_request_cost, - actual_total_cost_usd=actual_total_cost, - rate_multiplier=actual_rate_multiplier, # 使用实际查询到的 rate_multiplier - # 添加历史价格信息 - input_price_per_1m=input_price, - output_price_per_1m=output_price, - cache_creation_price_per_1m=cache_creation_price, - cache_read_price_per_1m=cache_read_price, - price_per_request=request_price, request_type=request_type, api_format=api_format, is_stream=is_stream, + response_time_ms=response_time_ms, + first_byte_time_ms=first_byte_time_ms, status_code=status_code, error_message=error_message, - response_time_ms=response_time_ms, - first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB) - status=status, # 请求状态追踪 - request_metadata=metadata, - request_headers=processed_request_headers, - request_body=processed_request_body, - provider_request_headers=processed_provider_request_headers, - response_headers=processed_response_headers, - response_body=processed_response_body, + metadata=metadata, + request_headers=request_headers, + request_body=request_body, + provider_request_headers=provider_request_headers, + response_headers=response_headers, + response_body=response_body, + request_id=request_id, + provider_id=provider_id, + provider_endpoint_id=provider_endpoint_id, + provider_api_key_id=provider_api_key_id, + status=status, + target_model=target_model, + input_cost=input_cost, + output_cost=output_cost, + cache_creation_cost=cache_creation_cost, + cache_read_cost=cache_read_cost, + cache_cost=cache_cost, + request_cost=request_cost, + total_cost=total_cost, + input_price=input_price, + output_price=output_price, + cache_creation_price=cache_creation_price, + cache_read_price=cache_read_price, + request_price=request_price, + actual_rate_multiplier=actual_rate_multiplier, + is_free_tier=is_free_tier, ) + # 创建 Usage 记录 + usage = Usage(**usage_params) db.add(usage) # 更新 GlobalModel 使用计数(原子操作) from sqlalchemy import update - from src.models.database import GlobalModel db.execute( @@ -392,17 +601,16 @@ class UsageService: .values(usage_count=GlobalModel.usage_count + 1) ) - # 更新 Provider 月度使用量(原子操作)- 使用实际费用(免费套餐为 0) + # 更新 Provider 月度使用量(原子操作) if provider_id: + actual_total_cost = usage_params["actual_total_cost_usd"] db.execute( update(Provider) .where(Provider.id == provider_id) .values(monthly_used_usd=Provider.monthly_used_usd + actual_total_cost) ) - db.commit() # 立即提交事务,释放数据库锁 - # 不需要 refresh,commit 后对象已经有数据库生成的值 - + db.commit() # 立即提交事务,释放数据库锁 return usage @classmethod @@ -421,7 +629,7 @@ class UsageService: api_format: Optional[str] = None, is_stream: bool = False, response_time_ms: Optional[int] = None, - first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB) + first_byte_time_ms: Optional[int] = None, status_code: int = 200, error_message: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, @@ -430,272 +638,113 @@ class UsageService: provider_request_headers: Optional[Dict[str, Any]] = None, response_headers: Optional[Dict[str, Any]] = None, response_body: Optional[Any] = None, - request_id: Optional[str] = None, # 请求ID,如果未提供则自动生成 - # Provider 侧追踪信息(记录最终成功的 Provider/Endpoint/Key) + request_id: Optional[str] = None, provider_id: Optional[str] = None, provider_endpoint_id: Optional[str] = None, provider_api_key_id: Optional[str] = None, - # 请求状态 (pending, streaming, completed, failed) status: str = "completed", - # 阶梯计费相关参数 - cache_ttl_minutes: Optional[int] = None, # 缓存时长(用于 TTL 差异化定价) - use_tiered_pricing: bool = True, # 是否使用阶梯计费(默认启用) - # 模型映射信息 - target_model: Optional[str] = None, # 映射后的目标模型名 + cache_ttl_minutes: Optional[int] = None, + use_tiered_pricing: bool = True, + target_model: Optional[str] = None, ) -> Usage: - """记录使用量(支持阶梯计费)""" + """记录使用量(完整版,支持更新已存在记录和用户统计) - # 使用传入的 request_id 或生成新的 + 此方法支持: + - 检查是否已存在相同 request_id 的记录(更新 vs 插入) + - 更新用户/API Key 使用统计 + - 阶梯计费 + + 如只需简单插入新记录,可使用 record_usage_async()。 + """ + # 生成 request_id if request_id is None: - request_id = str(uuid.uuid4())[:8] # 生成8位短ID以保持一致性 + request_id = str(uuid.uuid4())[:8] - # 如果提供了 provider_api_key_id,从数据库查询 rate_multiplier - actual_rate_multiplier = 1.0 # 默认值 - if provider_api_key_id: - provider_key = ( - db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first() - ) - if provider_key and provider_key.rate_multiplier: - actual_rate_multiplier = provider_key.rate_multiplier - - # 失败的请求不应该计入按次计费费用 - is_failed_request = status_code >= 400 or error_message is not None - - # 获取模型价格信息(用于历史记录) - input_price, output_price = await cls.get_model_price_async(db, provider, model) - cache_creation_price, cache_read_price = await cls.get_cache_prices_async( - db, provider, model, input_price + # 获取费率倍数和是否免费套餐 + actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier( + db, provider_api_key_id, provider_id ) - request_price = await cls.get_request_price_async(db, provider, model) - effective_request_price = None if is_failed_request else request_price - # 初始化成本变量(避免 `in locals()` 反模式) - input_cost = 0.0 - output_cost = 0.0 - cache_creation_cost = 0.0 - cache_read_cost = 0.0 - cache_cost = 0.0 - request_cost = 0.0 - total_cost = 0.0 - - # 计算成本(支持阶梯计费) - if use_tiered_pricing: - # 使用策略模式计算成本(支持阶梯计费和 TTL 差异化) - ( - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - cache_cost, - request_cost, - total_cost, - _tier_index, - ) = await cls.calculate_cost_with_strategy_async( - db=db, - provider=provider, - model=model, - input_tokens=input_tokens, - output_tokens=output_tokens, - cache_creation_input_tokens=cache_creation_input_tokens, - cache_read_input_tokens=cache_read_input_tokens, - api_format=api_format, - cache_ttl_minutes=cache_ttl_minutes, - ) - # 如果失败请求,重置按次费用 - if is_failed_request: - total_cost = total_cost - request_cost - request_cost = 0.0 - else: - # 使用固定价格模式 - ( - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - cache_cost, - request_cost, - total_cost, - ) = cls.calculate_cost( - input_tokens, - output_tokens, - input_price, - output_price, - cache_creation_input_tokens, - cache_read_input_tokens, - cache_creation_price, - cache_read_price, - effective_request_price, - ) - - # 根据配置决定是否记录请求详情 - should_log_headers = SystemConfigService.should_log_headers(db) - should_log_body = SystemConfigService.should_log_body(db) - - # 处理请求头(可能需要脱敏) - processed_request_headers = None - if should_log_headers and request_headers: - processed_request_headers = SystemConfigService.mask_sensitive_headers( - db, request_headers - ) - - # 处理提供商请求头(可能需要脱敏) - processed_provider_request_headers = None - if should_log_headers and provider_request_headers: - processed_provider_request_headers = SystemConfigService.mask_sensitive_headers( - db, provider_request_headers - ) - - # 处理请求体和响应体(可能需要截断) - processed_request_body = None - processed_response_body = None - if should_log_body: - if request_body: - processed_request_body = SystemConfigService.truncate_body( - db, request_body, is_request=True - ) - if response_body: - processed_response_body = SystemConfigService.truncate_body( - db, response_body, is_request=False - ) - - # 处理响应头 - processed_response_headers = None - if should_log_headers and response_headers: - processed_response_headers = SystemConfigService.mask_sensitive_headers( - db, response_headers - ) - - # 检查 Provider 的计费类型,免费套餐的实际费用为 0 - is_free_tier = False - if provider_id: - provider_obj = db.query(Provider).filter(Provider.id == provider_id).first() - if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER: - is_free_tier = True - - # 计算真实成本(表面成本 × 倍率),免费套餐实际费用为 0 - if is_free_tier: - actual_input_cost = 0.0 - actual_output_cost = 0.0 - actual_cache_creation_cost = 0.0 - actual_cache_read_cost = 0.0 - actual_request_cost = 0.0 - actual_total_cost = 0.0 - else: - actual_input_cost = input_cost * actual_rate_multiplier - actual_output_cost = output_cost * actual_rate_multiplier - actual_cache_creation_cost = cache_creation_cost * actual_rate_multiplier - actual_cache_read_cost = cache_read_cost * actual_rate_multiplier - actual_request_cost = request_cost * actual_rate_multiplier - actual_total_cost = total_cost * actual_rate_multiplier - - # 创建使用记录 - usage = Usage( - user_id=user.id if user else None, - api_key_id=api_key.id if api_key else None, - request_id=request_id, + # 计算成本 + is_failed_request = status_code >= 400 or error_message is not None + ( + input_price, output_price, cache_creation_price, cache_read_price, request_price, + input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, + request_cost, total_cost, _tier_index + ) = await cls._calculate_costs( + db=db, + provider=provider, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_creation_input_tokens=cache_creation_input_tokens, + cache_read_input_tokens=cache_read_input_tokens, + api_format=api_format, + cache_ttl_minutes=cache_ttl_minutes, + use_tiered_pricing=use_tiered_pricing, + is_failed_request=is_failed_request, + ) + + # 构建 Usage 参数 + usage_params = cls._build_usage_params( + db=db, + user=user, + api_key=api_key, provider=provider, model=model, - target_model=target_model, # 映射后的目标模型名 - # Provider 侧追踪信息 - provider_id=provider_id, - provider_endpoint_id=provider_endpoint_id, - provider_api_key_id=provider_api_key_id, input_tokens=input_tokens, output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, cache_creation_input_tokens=cache_creation_input_tokens, cache_read_input_tokens=cache_read_input_tokens, - input_cost_usd=input_cost, - output_cost_usd=output_cost, - cache_cost_usd=cache_cost, - cache_creation_cost_usd=cache_creation_cost, - cache_read_cost_usd=cache_read_cost, - request_cost_usd=request_cost, - total_cost_usd=total_cost, - # 真实成本(考虑倍率) - actual_input_cost_usd=actual_input_cost, - actual_output_cost_usd=actual_output_cost, - actual_cache_creation_cost_usd=actual_cache_creation_cost, - actual_cache_read_cost_usd=actual_cache_read_cost, - actual_request_cost_usd=actual_request_cost, - actual_total_cost_usd=actual_total_cost, - rate_multiplier=actual_rate_multiplier, # 使用实际查询到的 rate_multiplier - # 添加历史价格信息 - input_price_per_1m=input_price, - output_price_per_1m=output_price, - cache_creation_price_per_1m=cache_creation_price, - cache_read_price_per_1m=cache_read_price, - price_per_request=request_price, request_type=request_type, api_format=api_format, is_stream=is_stream, + response_time_ms=response_time_ms, + first_byte_time_ms=first_byte_time_ms, status_code=status_code, error_message=error_message, - response_time_ms=response_time_ms, - first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB) - status=status, # 请求状态追踪 - request_metadata=metadata, - request_headers=processed_request_headers, - request_body=processed_request_body, - provider_request_headers=processed_provider_request_headers, - response_headers=processed_response_headers, - response_body=processed_response_body, + metadata=metadata, + request_headers=request_headers, + request_body=request_body, + provider_request_headers=provider_request_headers, + response_headers=response_headers, + response_body=response_body, + request_id=request_id, + provider_id=provider_id, + provider_endpoint_id=provider_endpoint_id, + provider_api_key_id=provider_api_key_id, + status=status, + target_model=target_model, + input_cost=input_cost, + output_cost=output_cost, + cache_creation_cost=cache_creation_cost, + cache_read_cost=cache_read_cost, + cache_cost=cache_cost, + request_cost=request_cost, + total_cost=total_cost, + input_price=input_price, + output_price=output_price, + cache_creation_price=cache_creation_price, + cache_read_price=cache_read_price, + request_price=request_price, + actual_rate_multiplier=actual_rate_multiplier, + is_free_tier=is_free_tier, ) - # 检查是否已存在相同 request_id 的记录(用于更新 pending 记录或防止重试时重复插入) + # 检查是否已存在相同 request_id 的记录 existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first() if existing_usage: - # 已存在记录,更新而非插入 - logger.debug(f"request_id {request_id} 已存在,更新现有记录 (status: {existing_usage.status} -> {status})") - # 更新关键字段 - existing_usage.provider = provider # 更新 provider 名称 - existing_usage.status = status - existing_usage.status_code = status_code - existing_usage.error_message = error_message - existing_usage.response_time_ms = response_time_ms - existing_usage.first_byte_time_ms = first_byte_time_ms # 更新首字时间 - # 更新请求头和请求体(如果有新值) - if processed_request_headers is not None: - existing_usage.request_headers = processed_request_headers - if processed_request_body is not None: - existing_usage.request_body = processed_request_body - if processed_provider_request_headers is not None: - existing_usage.provider_request_headers = processed_provider_request_headers - existing_usage.response_body = processed_response_body - existing_usage.response_headers = processed_response_headers - # 更新 token 和费用信息 - existing_usage.input_tokens = input_tokens - existing_usage.output_tokens = output_tokens - existing_usage.total_tokens = input_tokens + output_tokens - existing_usage.cache_creation_input_tokens = cache_creation_input_tokens - existing_usage.cache_read_input_tokens = cache_read_input_tokens - existing_usage.input_cost_usd = input_cost - existing_usage.output_cost_usd = output_cost - existing_usage.cache_cost_usd = cache_cost - existing_usage.cache_creation_cost_usd = cache_creation_cost - existing_usage.cache_read_cost_usd = cache_read_cost - existing_usage.request_cost_usd = request_cost - existing_usage.total_cost_usd = total_cost - existing_usage.actual_input_cost_usd = actual_input_cost - existing_usage.actual_output_cost_usd = actual_output_cost - existing_usage.actual_cache_creation_cost_usd = actual_cache_creation_cost - existing_usage.actual_cache_read_cost_usd = actual_cache_read_cost - existing_usage.actual_request_cost_usd = actual_request_cost - existing_usage.actual_total_cost_usd = actual_total_cost - existing_usage.rate_multiplier = actual_rate_multiplier - # 更新 Provider 侧追踪信息 - existing_usage.provider_id = provider_id - existing_usage.provider_endpoint_id = provider_endpoint_id - existing_usage.provider_api_key_id = provider_api_key_id - # 更新模型映射信息 - if target_model is not None: - existing_usage.target_model = target_model - # 不需要 db.add,已在会话中 + logger.debug( + f"request_id {request_id} 已存在,更新现有记录 " + f"(status: {existing_usage.status} -> {status})" + ) + cls._update_existing_usage(existing_usage, usage_params, target_model) usage = existing_usage else: + usage = Usage(**usage_params) db.add(usage) - # 确保 user 和 api_key 在会话中(如果存在) + # 确保 user 和 api_key 在会话中 if user and not db.object_session(user): user = db.merge(user) if api_key and not db.object_session(api_key): @@ -703,87 +752,70 @@ class UsageService: # 使用原子更新避免并发竞态条件 from sqlalchemy import func, update + from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel - from src.models.database import ApiKey, User - - # 更新用户使用量(原子操作)- 使用标准计费价格 - # 独立Key不计入创建者的使用记录 + # 更新用户使用量(独立 Key 不计入创建者的使用记录) if user and not (api_key and api_key.is_standalone): db.execute( - update(User) - .where(User.id == user.id) + update(UserModel) + .where(UserModel.id == user.id) .values( - used_usd=User.used_usd + total_cost, - total_usd=User.total_usd + total_cost, + used_usd=UserModel.used_usd + total_cost, + total_usd=UserModel.total_usd + total_cost, updated_at=func.now(), ) ) - # 更新API密钥使用量(原子操作)- 使用标准计费价格 + # 更新 API 密钥使用量 if api_key: - # 独立余额Key需要扣除余额 if api_key.is_standalone: db.execute( - update(ApiKey) - .where(ApiKey.id == api_key.id) + update(ApiKeyModel) + .where(ApiKeyModel.id == api_key.id) .values( - total_requests=ApiKey.total_requests + 1, - total_cost_usd=ApiKey.total_cost_usd + total_cost, - balance_used_usd=ApiKey.balance_used_usd + total_cost, + total_requests=ApiKeyModel.total_requests + 1, + total_cost_usd=ApiKeyModel.total_cost_usd + total_cost, + balance_used_usd=ApiKeyModel.balance_used_usd + total_cost, last_used_at=func.now(), updated_at=func.now(), ) ) else: - # 普通Key只更新统计信息,不扣除余额 db.execute( - update(ApiKey) - .where(ApiKey.id == api_key.id) + update(ApiKeyModel) + .where(ApiKeyModel.id == api_key.id) .values( - total_requests=ApiKey.total_requests + 1, - total_cost_usd=ApiKey.total_cost_usd + total_cost, + total_requests=ApiKeyModel.total_requests + 1, + total_cost_usd=ApiKeyModel.total_cost_usd + total_cost, last_used_at=func.now(), updated_at=func.now(), ) ) - # 更新 GlobalModel 使用计数(原子操作) - from src.models.database import GlobalModel - + # 更新 GlobalModel 使用计数 db.execute( update(GlobalModel) .where(GlobalModel.name == model) .values(usage_count=GlobalModel.usage_count + 1) ) - # 更新 Provider 月度使用量(原子操作)- 使用实际费用(免费套餐为 0) + # 更新 Provider 月度使用量 if provider_id: + actual_total_cost = usage_params["actual_total_cost_usd"] db.execute( update(Provider) .where(Provider.id == provider_id) .values(monthly_used_usd=Provider.monthly_used_usd + actual_total_cost) ) - # 提交事务到数据库 + # 提交事务 try: - db.commit() # 立即提交事务,释放数据库锁 - - # 使用 expire 标记对象过期,下次访问时自动重新加载(避免死锁) - # 不在热路径上立即 refresh,避免行锁等待 - # db.expire(usage) - # if user: - # db.expire(user) - # if api_key: - # db.expire(api_key) - + db.commit() except Exception as e: logger.error(f"提交使用记录时出错: {e}") db.rollback() raise - # 不再记录重复的性能日志和访问日志,因为已经在完成日志中输出了 - # 这些信息都包含在 request_orchestrator.py 的完成日志中 - return usage @staticmethod