Files
Aether/src/services/cache/provider_cache.py

172 lines
5.9 KiB
Python
Raw Normal View History

"""
Provider cache service - reduces Provider and ProviderAPIKey lookups.
Caches Provider.billing_type and ProviderAPIKey.rate_multiplier; both are
queried frequently in UsageService.record_usage() but change rarely.
"""
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey
class ProviderCacheService:
    """Cached lookups for Provider and ProviderAPIKey fields.

    Used mainly by UsageService to fetch the rate multiplier and billing type
    without hitting the database on every call. Both positive and negative
    ("row does not exist") results are cached for ``CACHE_TTL`` seconds.

    NOTE(review): the database lookups use a synchronous SQLAlchemy ``Session``
    inside ``async`` methods, so a cache miss blocks the event loop for the
    duration of the query.
    """

    # Original note says 5 minutes; the actual value comes from CacheTTL.PROVIDER.
    CACHE_TTL = CacheTTL.PROVIDER

    # Sentinel stored in the cache to remember that a row does not exist.
    _NOT_FOUND = "NOT_FOUND"

    @staticmethod
    def _rate_multiplier_cache_key(provider_api_key_id: str) -> str:
        """Build the cache key for a ProviderAPIKey's rate_multiplier."""
        return f"provider_api_key:rate_multiplier:{provider_api_key_id}"

    @staticmethod
    def _billing_type_cache_key(provider_id: str) -> str:
        """Build the cache key for a Provider's billing_type."""
        return f"provider:billing_type:{provider_id}"

    @staticmethod
    async def get_provider_api_key_rate_multiplier(
        db: Session, provider_api_key_id: str
    ) -> Optional[float]:
        """
        Get a ProviderAPIKey's rate_multiplier, using the cache when possible.

        Args:
            db: Database session, queried only on a cache miss.
            provider_api_key_id: ProviderAPIKey ID.

        Returns:
            The rate multiplier as a float (a falsy column value is reported
            as 1.0), or None if no such ProviderAPIKey exists.
        """
        cache_key = ProviderCacheService._rate_multiplier_cache_key(provider_api_key_id)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
            # The sentinel means the row was already looked up and is absent.
            if cached_data == ProviderCacheService._NOT_FOUND:
                return None
            try:
                return float(cached_data)
            except (TypeError, ValueError):
                # Corrupted cache entry: drop it and fall back to the database,
                # mirroring how get_provider_billing_type handles bad values.
                await CacheService.delete(cache_key)

        # 2. Cache miss (or corrupted entry): query the database.
        provider_key = (
            db.query(ProviderAPIKey.rate_multiplier)
            .filter(ProviderAPIKey.id == provider_api_key_id)
            .first()
        )

        # 3. Populate the cache, including the negative result.
        if not provider_key:
            await CacheService.set(
                cache_key,
                ProviderCacheService._NOT_FOUND,
                ttl_seconds=ProviderCacheService.CACHE_TTL,
            )
            return None

        # NOTE(review): `or 1.0` also maps an explicit 0 multiplier to 1.0; kept
        # to preserve existing billing behavior - confirm this is intended.
        # float() makes the miss path return the same type as the hit path,
        # even if the column yields an int or Decimal.
        rate_multiplier = float(provider_key.rate_multiplier or 1.0)
        await CacheService.set(
            cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
        )
        logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
        return rate_multiplier

    @staticmethod
    async def get_provider_billing_type(
        db: Session, provider_id: str
    ) -> Optional[ProviderBillingType]:
        """
        Get a Provider's billing_type, using the cache when possible.

        Args:
            db: Database session, queried only on a cache miss.
            provider_id: Provider ID.

        Returns:
            The billing type, or None if no such Provider exists.
        """
        cache_key = ProviderCacheService._billing_type_cache_key(provider_id)

        # 1. Try the cache first (the enum's value is what gets cached).
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
            if cached_data == ProviderCacheService._NOT_FOUND:
                return None
            try:
                return ProviderBillingType(cached_data)
            except ValueError:
                # Invalid cached value: drop it and re-query the database.
                await CacheService.delete(cache_key)

        # 2. Cache miss (or invalid entry): query the database.
        provider = (
            db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
        )

        # 3. Populate the cache, including the negative result.
        if not provider:
            await CacheService.set(
                cache_key,
                ProviderCacheService._NOT_FOUND,
                ttl_seconds=ProviderCacheService.CACHE_TTL,
            )
            return None

        # Calling the enum with a member returns that member unchanged, and
        # calling it with a raw value converts it, so `.value` below is safe either way.
        billing_type = ProviderBillingType(provider.billing_type)
        await CacheService.set(
            cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
        )
        logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
        return billing_type

    @staticmethod
    async def get_rate_multiplier_and_free_tier(
        db: Session,
        provider_api_key_id: Optional[str],
        provider_id: Optional[str],
    ) -> Tuple[float, bool]:
        """
        Get the rate multiplier and whether the provider is on the free tier.

        Cached counterpart of UsageService._get_rate_multiplier_and_free_tier.

        Args:
            db: Database session.
            provider_api_key_id: ProviderAPIKey ID, optional.
            provider_id: Provider ID, optional.

        Returns:
            ``(rate_multiplier, is_free_tier)``. The multiplier defaults to 1.0
            and is_free_tier to False when an ID is missing or not found.
        """
        actual_rate_multiplier = 1.0
        if provider_api_key_id:
            rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
                db, provider_api_key_id
            )
            if rate_multiplier is not None:
                actual_rate_multiplier = rate_multiplier

        is_free_tier = False
        if provider_id:
            billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
            is_free_tier = billing_type == ProviderBillingType.FREE_TIER

        return actual_rate_multiplier, is_free_tier

    @staticmethod
    async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
        """Remove the cached rate_multiplier for a ProviderAPIKey."""
        await CacheService.delete(
            ProviderCacheService._rate_multiplier_cache_key(provider_api_key_id)
        )
        logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")

    @staticmethod
    async def invalidate_provider_cache(provider_id: str) -> None:
        """Remove the cached billing_type for a Provider."""
        await CacheService.delete(ProviderCacheService._billing_type_cache_key(provider_id))
        logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")