Files
Aether/src/services/cache/provider_cache.py
2025-12-10 20:52:44 +08:00

255 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
"""
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
class ProviderCacheService:
"""Provider 配置缓存服务"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.PROVIDER
@staticmethod
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
"""
获取 Provider带缓存
Args:
db: 数据库会话
provider_id: Provider ID
Returns:
Provider 对象或 None
"""
cache_key = CacheKeys.provider_by_id(provider_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Provider 缓存命中: {provider_id}")
return ProviderCacheService._dict_to_provider(cached_data)
# 2. 缓存未命中,查询数据库
provider = db.query(Provider).filter(Provider.id == provider_id).first()
# 3. 写入缓存
if provider:
provider_dict = ProviderCacheService._provider_to_dict(provider)
await CacheService.set(
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Provider 已缓存: {provider_id}")
return provider
@staticmethod
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
"""
获取 Endpoint带缓存
Args:
db: 数据库会话
endpoint_id: Endpoint ID
Returns:
ProviderEndpoint 对象或 None
"""
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
return ProviderCacheService._dict_to_endpoint(cached_data)
# 2. 缓存未命中,查询数据库
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
# 3. 写入缓存
if endpoint:
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
await CacheService.set(
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
return endpoint
@staticmethod
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
"""
获取 API Key带缓存
Args:
db: 数据库会话
api_key_id: API Key ID
Returns:
ProviderAPIKey 对象或 None
"""
cache_key = CacheKeys.api_key_by_id(api_key_id)
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"API Key 缓存命中: {api_key_id}")
return ProviderCacheService._dict_to_api_key(cached_data)
# 2. 缓存未命中,查询数据库
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
# 3. 写入缓存
if api_key:
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
await CacheService.set(
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"API Key 已缓存: {api_key_id}")
return api_key
@staticmethod
async def invalidate_provider_cache(provider_id: str):
"""
清除 Provider 缓存
Args:
provider_id: Provider ID
"""
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
logger.debug(f"Provider 缓存已清除: {provider_id}")
@staticmethod
async def invalidate_endpoint_cache(endpoint_id: str):
"""
清除 Endpoint 缓存
Args:
endpoint_id: Endpoint ID
"""
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
@staticmethod
async def invalidate_api_key_cache(api_key_id: str):
"""
清除 API Key 缓存
Args:
api_key_id: API Key ID
"""
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
logger.debug(f"API Key 缓存已清除: {api_key_id}")
@staticmethod
def _provider_to_dict(provider: Provider) -> dict:
"""将 Provider 对象转换为字典(用于缓存)"""
return {
"id": provider.id,
"name": provider.name,
"api_format": provider.api_format,
"base_url": provider.base_url,
"is_active": provider.is_active,
"priority": provider.priority,
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
"config": provider.config,
"description": provider.description,
}
@staticmethod
def _dict_to_provider(provider_dict: dict) -> Provider:
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
from datetime import datetime
provider = Provider(
id=provider_dict["id"],
name=provider_dict["name"],
api_format=provider_dict["api_format"],
base_url=provider_dict.get("base_url"),
is_active=provider_dict["is_active"],
priority=provider_dict.get("priority", 0),
rpm_limit=provider_dict.get("rpm_limit"),
rpm_used=provider_dict.get("rpm_used", 0),
config=provider_dict.get("config"),
description=provider_dict.get("description"),
)
if provider_dict.get("rpm_reset_at"):
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
return provider
@staticmethod
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
"""将 Endpoint 对象转换为字典"""
return {
"id": endpoint.id,
"provider_id": endpoint.provider_id,
"name": endpoint.name,
"base_url": endpoint.base_url,
"is_active": endpoint.is_active,
"priority": endpoint.priority,
"weight": endpoint.weight,
"custom_path": endpoint.custom_path,
"config": endpoint.config,
}
@staticmethod
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
"""从字典重建 Endpoint 对象"""
endpoint = ProviderEndpoint(
id=endpoint_dict["id"],
provider_id=endpoint_dict["provider_id"],
name=endpoint_dict["name"],
base_url=endpoint_dict["base_url"],
is_active=endpoint_dict["is_active"],
priority=endpoint_dict.get("priority", 0),
weight=endpoint_dict.get("weight", 1.0),
custom_path=endpoint_dict.get("custom_path"),
config=endpoint_dict.get("config"),
)
return endpoint
@staticmethod
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
"""将 API Key 对象转换为字典"""
return {
"id": api_key.id,
"endpoint_id": api_key.endpoint_id,
"key_value": api_key.key_value,
"is_active": api_key.is_active,
"max_rpm": api_key.max_rpm,
"current_rpm": api_key.current_rpm,
"health_score": api_key.health_score,
"circuit_breaker_state": api_key.circuit_breaker_state,
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
}
@staticmethod
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
"""从字典重建 API Key 对象"""
api_key = ProviderAPIKey(
id=api_key_dict["id"],
endpoint_id=api_key_dict["endpoint_id"],
key_value=api_key_dict["key_value"],
is_active=api_key_dict["is_active"],
max_rpm=api_key_dict.get("max_rpm"),
current_rpm=api_key_dict.get("current_rpm", 0),
health_score=api_key_dict.get("health_score", 1.0),
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
)
return api_key