mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 16:22:27 +08:00
Initial commit
This commit is contained in:
254
src/services/cache/provider_cache.py
vendored
Normal file
254
src/services/cache/provider_cache.py
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheKeys, CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
class ProviderCacheService:
|
||||
"""Provider 配置缓存服务"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.PROVIDER
|
||||
|
||||
@staticmethod
|
||||
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
|
||||
"""
|
||||
获取 Provider(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
Provider 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.provider_by_id(provider_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Provider 缓存命中: {provider_id}")
|
||||
return ProviderCacheService._dict_to_provider(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if provider:
|
||||
provider_dict = ProviderCacheService._provider_to_dict(provider)
|
||||
await CacheService.set(
|
||||
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Provider 已缓存: {provider_id}")
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
|
||||
"""
|
||||
获取 Endpoint(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
endpoint_id: Endpoint ID
|
||||
|
||||
Returns:
|
||||
ProviderEndpoint 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
|
||||
return ProviderCacheService._dict_to_endpoint(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if endpoint:
|
||||
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
|
||||
await CacheService.set(
|
||||
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
|
||||
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
|
||||
"""
|
||||
获取 API Key(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
api_key_id: API Key ID
|
||||
|
||||
Returns:
|
||||
ProviderAPIKey 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.api_key_by_id(api_key_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"API Key 缓存命中: {api_key_id}")
|
||||
return ProviderCacheService._dict_to_api_key(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if api_key:
|
||||
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
|
||||
await CacheService.set(
|
||||
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"API Key 已缓存: {api_key_id}")
|
||||
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_provider_cache(provider_id: str):
|
||||
"""
|
||||
清除 Provider 缓存
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
|
||||
logger.debug(f"Provider 缓存已清除: {provider_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_endpoint_cache(endpoint_id: str):
|
||||
"""
|
||||
清除 Endpoint 缓存
|
||||
|
||||
Args:
|
||||
endpoint_id: Endpoint ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
|
||||
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_api_key_cache(api_key_id: str):
|
||||
"""
|
||||
清除 API Key 缓存
|
||||
|
||||
Args:
|
||||
api_key_id: API Key ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
|
||||
logger.debug(f"API Key 缓存已清除: {api_key_id}")
|
||||
|
||||
@staticmethod
|
||||
def _provider_to_dict(provider: Provider) -> dict:
|
||||
"""将 Provider 对象转换为字典(用于缓存)"""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"is_active": provider.is_active,
|
||||
"priority": provider.priority,
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"rpm_used": provider.rpm_used,
|
||||
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
||||
"config": provider.config,
|
||||
"description": provider.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_provider(provider_dict: dict) -> Provider:
|
||||
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
|
||||
from datetime import datetime
|
||||
|
||||
provider = Provider(
|
||||
id=provider_dict["id"],
|
||||
name=provider_dict["name"],
|
||||
api_format=provider_dict["api_format"],
|
||||
base_url=provider_dict.get("base_url"),
|
||||
is_active=provider_dict["is_active"],
|
||||
priority=provider_dict.get("priority", 0),
|
||||
rpm_limit=provider_dict.get("rpm_limit"),
|
||||
rpm_used=provider_dict.get("rpm_used", 0),
|
||||
config=provider_dict.get("config"),
|
||||
description=provider_dict.get("description"),
|
||||
)
|
||||
|
||||
if provider_dict.get("rpm_reset_at"):
|
||||
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
|
||||
"""将 Endpoint 对象转换为字典"""
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"provider_id": endpoint.provider_id,
|
||||
"name": endpoint.name,
|
||||
"base_url": endpoint.base_url,
|
||||
"is_active": endpoint.is_active,
|
||||
"priority": endpoint.priority,
|
||||
"weight": endpoint.weight,
|
||||
"custom_path": endpoint.custom_path,
|
||||
"config": endpoint.config,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
|
||||
"""从字典重建 Endpoint 对象"""
|
||||
endpoint = ProviderEndpoint(
|
||||
id=endpoint_dict["id"],
|
||||
provider_id=endpoint_dict["provider_id"],
|
||||
name=endpoint_dict["name"],
|
||||
base_url=endpoint_dict["base_url"],
|
||||
is_active=endpoint_dict["is_active"],
|
||||
priority=endpoint_dict.get("priority", 0),
|
||||
weight=endpoint_dict.get("weight", 1.0),
|
||||
custom_path=endpoint_dict.get("custom_path"),
|
||||
config=endpoint_dict.get("config"),
|
||||
)
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
|
||||
"""将 API Key 对象转换为字典"""
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"endpoint_id": api_key.endpoint_id,
|
||||
"key_value": api_key.key_value,
|
||||
"is_active": api_key.is_active,
|
||||
"max_rpm": api_key.max_rpm,
|
||||
"current_rpm": api_key.current_rpm,
|
||||
"health_score": api_key.health_score,
|
||||
"circuit_breaker_state": api_key.circuit_breaker_state,
|
||||
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
|
||||
"""从字典重建 API Key 对象"""
|
||||
api_key = ProviderAPIKey(
|
||||
id=api_key_dict["id"],
|
||||
endpoint_id=api_key_dict["endpoint_id"],
|
||||
key_value=api_key_dict["key_value"],
|
||||
is_active=api_key_dict["is_active"],
|
||||
max_rpm=api_key_dict.get("max_rpm"),
|
||||
current_rpm=api_key_dict.get("current_rpm", 0),
|
||||
health_score=api_key_dict.get("health_score", 1.0),
|
||||
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
|
||||
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
|
||||
)
|
||||
return api_key
|
||||
Reference in New Issue
Block a user