mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
Initial commit
This commit is contained in:
19
src/services/cache/__init__.py
vendored
Normal file
19
src/services/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
缓存服务模块
|
||||
|
||||
包含缓存后端、缓存亲和性、缓存同步等功能。
|
||||
|
||||
注意:由于循环依赖问题,部分类需要直接从子模块导入:
|
||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
||||
"""
|
||||
|
||||
# 只导出不会导致循环依赖的基础类
|
||||
from src.services.cache.backend import BaseCacheBackend, LocalCache, RedisCache, get_cache_backend
|
||||
|
||||
__all__ = [
|
||||
"BaseCacheBackend",
|
||||
"LocalCache",
|
||||
"RedisCache",
|
||||
"get_cache_backend",
|
||||
]
|
||||
668
src/services/cache/affinity_manager.py
vendored
Normal file
668
src/services/cache/affinity_manager.py
vendored
Normal file
@@ -0,0 +1,668 @@
|
||||
"""
|
||||
缓存亲和性管理器 (Cache Affinity Manager) - 支持 Redis 或内存存储
|
||||
|
||||
职责:
|
||||
1. 跟踪请求API Key的Provider+Key缓存状态
|
||||
2. 管理缓存有效期
|
||||
3. 提供缓存统计和分析
|
||||
4. 自动失效不支持缓存的Provider
|
||||
|
||||
设计原理:
|
||||
- 每个API Key使用某个Provider的Key后,在缓存TTL期内,应该继续使用同一个Key
|
||||
- 这样可以最大化利用提供商的Prompt Caching机制
|
||||
- 当Key故障时,自动失效该Key的缓存亲和性
|
||||
- 当Provider关闭缓存支持时,自动失效所有相关亲和性
|
||||
|
||||
注意:
|
||||
- affinity_key 参数通常为请求使用的 API Key ID(api_key_id)
|
||||
- 这样可以支持"独立余额Key"场景,每个Key有自己的缓存亲和性
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class CacheAffinity(NamedTuple):
    """Cache affinity record for one (affinity_key, api_format, model_name)."""

    provider_id: str    # Provider this affinity points at
    endpoint_id: str    # Endpoint within the provider
    key_id: str         # Provider key ID
    api_format: str     # API format (claude/openai)
    model_name: str     # Model name
    created_at: float   # Creation timestamp (epoch seconds)
    expire_at: float    # Expiry timestamp (epoch seconds)
    request_count: int  # Requests served under this affinity
|
||||
|
||||
|
||||
class CacheAffinityManager:
    """
    Cache affinity manager (backed by Redis or, as a fallback, process memory).

    Storage layout
    --------------
    Key format: cache_affinity:{affinity_key}:{api_format}:{model_name}
        - affinity_key: usually the requesting API key's ID (supports
          independent-balance-key scenarios)
        - api_format: API format (claude/openai)
        - model_name: model name (keeps affinities separate per model)
    Value format: JSON/dict
        {
            "provider_id": "xxx",
            "endpoint_id": "yyy",
            "key_id": "zzz",
            "model_name": "claude-3-5-sonnet-20241022",
            "created_at": 1234567890.123,
            "expire_at": 1234567890.123,
            "request_count": 5
        }
    TTL: entries expire automatically.

    Design notes:
    - Each API key keeps a separate affinity per API format and per model.
    - Different models use independent affinities, so switching models does
      not invalidate an existing cache affinity.
    - A failover on one endpoint does not affect other endpoints' affinities.
    - Enables more precise cache hit-rate statistics.
    - Supports "independent balance key" scenarios: each key has its own
      cache affinity.
    """

    # Default cache TTL in seconds - shared project constant.
    DEFAULT_CACHE_TTL = CacheTTL.CACHE_AFFINITY
|
||||
|
||||
    def __init__(self, redis_client=None, default_ttl: int = DEFAULT_CACHE_TTL):
        """
        Initialize the cache affinity manager.

        Args:
            redis_client: optional Redis client; when None, falls back to an
                in-process store (single instance / development only).
            default_ttl: default affinity TTL in seconds.
        """
        self.redis = redis_client
        self.default_ttl = default_ttl
        # In-memory fallback store: cache_key -> affinity dict.
        self._memory_store: Dict[str, Dict[str, Any]] = {}
        # Created lazily so construction can happen outside an event loop.
        self._memory_lock: Optional[asyncio.Lock] = None

        # L1 cache (kept even when Redis is used, to cut network round-trips)
        self._l1_cache_ttl = int(os.getenv("CACHE_AFFINITY_L1_TTL", str(CacheTTL.L1_LOCAL)))
        self._l1_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
        self._l1_lock = asyncio.Lock()
        self._l1_max_size = int(os.getenv("CACHE_AFFINITY_L1_MAX_SIZE", "1000"))  # max entries
        self._l1_last_cleanup = time.time()

        # Per-cache-key locks so concurrent updates for the same user+endpoint
        # do not jitter.  NOTE(review): entries are never evicted from this
        # dict, so it grows with the number of distinct cache keys.
        self._request_locks: Dict[str, asyncio.Lock] = {}

        # Counters surfaced by get_stats().
        self._stats = {
            "total_affinities": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "cache_invalidations": 0,
            "provider_switches": 0,  # NOTE(review): never incremented in this file
            "key_switches": 0,
        }

        if self.redis:
            logger.debug("CacheAffinityManager: 使用Redis存储")
        else:
            logger.debug("CacheAffinityManager: Redis不可用,回退到内存存储(仅适用于单实例/开发环境)")
|
||||
|
||||
def _is_memory_backend(self) -> bool:
|
||||
"""是否处于内存模式"""
|
||||
return self.redis is None
|
||||
|
||||
def _get_memory_lock(self) -> asyncio.Lock:
|
||||
"""懒初始化内存锁"""
|
||||
if self._memory_lock is None:
|
||||
self._memory_lock = asyncio.Lock()
|
||||
return self._memory_lock
|
||||
|
||||
def _get_cache_key(self, affinity_key: str, api_format: str, model_name: str) -> str:
|
||||
"""
|
||||
生成Redis Key
|
||||
|
||||
Args:
|
||||
affinity_key: 亲和性标识符(通常为API Key ID)
|
||||
api_format: API格式 (claude/openai)
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
格式化的缓存键: cache_affinity:{affinity_key}:{api_format}:{model_name}
|
||||
"""
|
||||
return f"cache_affinity:{affinity_key}:{api_format}:{model_name}"
|
||||
|
||||
async def _get_l1_entry(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
async with self._l1_lock:
|
||||
record = self._l1_cache.get(cache_key)
|
||||
if not record:
|
||||
return None
|
||||
expire_at, payload = record
|
||||
if time.time() > expire_at:
|
||||
self._l1_cache.pop(cache_key, None)
|
||||
return None
|
||||
return dict(payload)
|
||||
|
||||
    async def _set_l1_entry(self, cache_key: str, payload: Optional[Dict[str, Any]]):
        """Store (or, when *payload* is falsy, remove) the L1 entry for *cache_key*.

        A copy of the payload is stored so later caller-side mutations do not
        leak into the cache.
        """
        async with self._l1_lock:
            if not payload:
                self._l1_cache.pop(cache_key, None)
                return
            # Floor of 1s guards against a zero/negative configured TTL.
            expire_at = time.time() + max(1, self._l1_cache_ttl)
            self._l1_cache[cache_key] = (expire_at, dict(payload))

            # Periodically drop expired entries (at most once per 60 seconds).
            current_time = time.time()
            if current_time - self._l1_last_cleanup > 60:
                self._cleanup_l1_cache_unlocked(current_time)
                self._l1_last_cleanup = current_time
|
||||
|
||||
    def _cleanup_l1_cache_unlocked(self, current_time: float) -> int:
        """Remove expired L1 entries (caller must hold self._l1_lock).

        Returns:
            Number of removed entries.
        """
        expired_keys = [
            key for key, (expire_at, _) in self._l1_cache.items()
            if current_time > expire_at
        ]
        for key in expired_keys:
            self._l1_cache.pop(key, None)

        # If the cache is still over capacity, evict the entries closest to
        # expiry until we are back at roughly 80% of the maximum size.
        if len(self._l1_cache) > self._l1_max_size:
            sorted_items = sorted(
                self._l1_cache.items(),
                key=lambda x: x[1][0]  # sort by expire_at
            )
            # Evict the oldest ~20% of entries.
            remove_count = len(self._l1_cache) - int(self._l1_max_size * 0.8)
            for key, _ in sorted_items[:remove_count]:
                self._l1_cache.pop(key, None)
            expired_keys.extend([k for k, _ in sorted_items[:remove_count]])

        if expired_keys:
            logger.debug(f"L1 缓存清理: 移除 {len(expired_keys)} 个条目,当前 {len(self._l1_cache)} 个")

        return len(expired_keys)
|
||||
|
||||
    @asynccontextmanager
    async def _acquire_request_lock(self, cache_key: str):
        """Async context manager serializing operations on one cache key.

        The get-then-set below is safe without extra locking: no await occurs
        between the lookup and the insert, so no other coroutine can run in
        between on a single event loop.
        NOTE(review): locks are never removed from self._request_locks, so the
        dict grows with the number of distinct cache keys.
        """
        lock = self._request_locks.get(cache_key)
        if lock is None:
            lock = asyncio.Lock()
            self._request_locks[cache_key] = lock
        await lock.acquire()
        try:
            yield
        finally:
            lock.release()
|
||||
|
||||
    async def _load_affinity_dict(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Read one affinity dict: L1 first, then Redis or the memory store.

        Returns a copied dict, or None when the key is unknown.  Expiry is NOT
        checked here for the memory backend; callers double-check expire_at.
        """
        # L1 first to avoid a backend round-trip.
        l1_value = await self._get_l1_entry(cache_key)
        if l1_value is not None:
            return l1_value

        if not self._is_memory_backend():
            data = await self.redis.get(cache_key)
            if not data:
                return None
            value = json.loads(data)
            await self._set_l1_entry(cache_key, value)
            return value

        lock = self._get_memory_lock()
        async with lock:
            record = self._memory_store.get(cache_key)
            if record:
                await self._set_l1_entry(cache_key, record)
            return dict(record) if record else None
|
||||
|
||||
    async def _save_affinity_dict(
        self, cache_key: str, ttl: int, affinity_dict: Dict[str, Any]
    ) -> None:
        """Persist one affinity dict and mirror it into the L1 cache.

        Args:
            cache_key: full storage key.
            ttl: expiry in seconds (applied via SETEX on the Redis backend;
                the memory backend relies on the embedded expire_at field).
            affinity_dict: payload to store (copied before insertion).
        """
        if not self._is_memory_backend():
            await self.redis.setex(cache_key, ttl, json.dumps(affinity_dict))
            await self._set_l1_entry(cache_key, affinity_dict)
            return

        lock = self._get_memory_lock()
        async with lock:
            self._memory_store[cache_key] = dict(affinity_dict)
            await self._set_l1_entry(cache_key, affinity_dict)
|
||||
|
||||
async def _delete_affinity_key(self, cache_key: str) -> None:
|
||||
"""删除缓存亲和性"""
|
||||
if not self._is_memory_backend():
|
||||
await self.redis.delete(cache_key)
|
||||
else:
|
||||
lock = self._get_memory_lock()
|
||||
async with lock:
|
||||
self._memory_store.pop(cache_key, None)
|
||||
|
||||
await self._set_l1_entry(cache_key, None)
|
||||
|
||||
async def _snapshot_memory_items(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""复制内存存储内容(仅内存模式使用)"""
|
||||
lock = self._get_memory_lock()
|
||||
async with lock:
|
||||
return {k: dict(v) for k, v in self._memory_store.items()}
|
||||
|
||||
    async def get_affinity(
        self, affinity_key: str, api_format: str, model_name: str
    ) -> Optional[CacheAffinity]:
        """
        Get the cache affinity of one affinity key for a given API format and model.

        Args:
            affinity_key: affinity identifier (usually the API key ID)
            api_format: API format (claude/openai)
            model_name: model name

        Returns:
            A CacheAffinity, or None when absent or expired.
        """
        try:
            cache_key = self._get_cache_key(affinity_key, api_format, model_name)
            async with self._acquire_request_lock(cache_key):
                affinity_dict = await self._load_affinity_dict(cache_key)

                if not affinity_dict:
                    self._stats["cache_misses"] += 1
                    return None

                # Double-check expiry in case the backend TTL has not fired yet.
                current_time = time.time()
                if current_time > affinity_dict["expire_at"]:
                    await self._delete_affinity_key(cache_key)
                    self._stats["cache_misses"] += 1
                    return None

                self._stats["cache_hits"] += 1

                # api_format/model_name may be missing in records written by an
                # older format; fall back to the requested values.
                return CacheAffinity(
                    provider_id=affinity_dict["provider_id"],
                    endpoint_id=affinity_dict["endpoint_id"],
                    key_id=affinity_dict["key_id"],
                    api_format=affinity_dict.get("api_format", api_format),
                    model_name=affinity_dict.get("model_name", model_name),
                    created_at=affinity_dict["created_at"],
                    expire_at=affinity_dict["expire_at"],
                    request_count=affinity_dict["request_count"],
                )

        except Exception as e:
            logger.exception(f"获取缓存亲和性失败: {e}")
            self._stats["cache_misses"] += 1
            return None
|
||||
|
||||
    async def set_affinity(
        self,
        affinity_key: str,
        provider_id: str,
        endpoint_id: str,
        key_id: str,
        api_format: str,
        model_name: str,
        supports_caching: bool = True,
        ttl: Optional[int] = None,
    ) -> None:
        """
        Record/refresh the cache affinity of one affinity key for a given API
        format and model.

        Args:
            affinity_key: affinity identifier (usually the API key ID)
            provider_id: Provider ID
            endpoint_id: Endpoint ID
            key_id: provider Key ID
            api_format: API format (claude/openai)
            model_name: model name
            supports_caching: whether this provider supports prompt caching
            ttl: affinity TTL in seconds; defaults to self.default_ttl

        Note: every call refreshes the expiry (sliding window), keeping the
        affinity to the same Provider/Endpoint/Key alive while it is in use.
        """
        if not supports_caching:
            # Providers without caching get no affinity record.
            logger.debug(f"Provider {provider_id[:8]}... 不支持缓存,跳过亲和性记录")
            return

        ttl = ttl or self.default_ttl
        current_time = time.time()
        expire_at = current_time + ttl  # refreshed on every call
        cache_key = self._get_cache_key(affinity_key, api_format, model_name)

        try:
            async with self._acquire_request_lock(cache_key):
                existing_dict = await self._load_affinity_dict(cache_key)
                existing_affinity: Optional[CacheAffinity] = None
                # Only honor the stored record while it is still unexpired.
                if existing_dict and current_time <= existing_dict.get("expire_at", 0):
                    existing_affinity = CacheAffinity(
                        provider_id=existing_dict["provider_id"],
                        endpoint_id=existing_dict["endpoint_id"],
                        key_id=existing_dict["key_id"],
                        api_format=existing_dict.get("api_format", api_format),
                        model_name=existing_dict.get("model_name", model_name),
                        created_at=existing_dict["created_at"],
                        expire_at=existing_dict["expire_at"],
                        request_count=existing_dict.get("request_count", 0),
                    )

                if existing_affinity:
                    created_at = existing_affinity.created_at
                    request_count = existing_affinity.request_count + 1

                    # Detect a Provider/Endpoint/Key switch and reset counters.
                    if (
                        existing_affinity.provider_id != provider_id
                        or existing_affinity.endpoint_id != endpoint_id
                        or existing_affinity.key_id != key_id
                    ):
                        self._stats["key_switches"] += 1
                        logger.debug(f"Key {affinity_key[:8]}... 在 {api_format} 格式下切换后端: "
                                     f"[{existing_affinity.provider_id[:8]}.../{existing_affinity.endpoint_id[:8]}.../"
                                     f"{existing_affinity.key_id[:8]}...] → "
                                     f"[{provider_id[:8]}.../{endpoint_id[:8]}.../{key_id[:8]}...], 重置计数器")
                        created_at = current_time
                        request_count = 1
                    else:
                        logger.debug(f"刷新缓存亲和性: key={affinity_key[:8]}..., api_format={api_format}, "
                                     f"provider={provider_id[:8]}..., endpoint={endpoint_id[:8]}..., "
                                     f"provider_key={key_id[:8]}..., ttl+={ttl}s")
                else:
                    created_at = current_time
                    request_count = 1
                    self._stats["total_affinities"] += 1

                affinity_dict = {
                    "provider_id": provider_id,
                    "endpoint_id": endpoint_id,
                    "key_id": key_id,
                    "api_format": api_format,
                    "model_name": model_name,
                    "created_at": created_at,
                    "expire_at": expire_at,
                    "request_count": request_count,
                }

                await self._save_affinity_dict(cache_key, ttl, affinity_dict)

                logger.debug(f"设置缓存亲和性: key={affinity_key[:8]}..., api_format={api_format}, "
                             f"model={model_name}, provider={provider_id[:8]}..., endpoint={endpoint_id[:8]}..., "
                             f"provider_key={key_id[:8]}..., ttl={ttl}s")
        except Exception as e:
            logger.exception(f"设置缓存亲和性失败: {e}")
|
||||
|
||||
    async def invalidate_affinity(
        self,
        affinity_key: str,
        api_format: str,
        model_name: str,
        key_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
    ) -> None:
        """
        Invalidate the affinity of one affinity key for a given API format and model.

        Args:
            affinity_key: affinity identifier (usually the API key ID)
            api_format: API format (claude/openai)
            model_name: model name
            key_id: optional provider Key ID filter - invalidate only on match
            provider_id: optional Provider ID filter - invalidate only on match
            endpoint_id: optional Endpoint ID filter - invalidate only on match
        """
        existing_affinity = await self.get_affinity(affinity_key, api_format, model_name)

        if not existing_affinity:
            return

        # Every provided filter must match; otherwise skip.
        should_invalidate = True

        if key_id and existing_affinity.key_id != key_id:
            should_invalidate = False

        if provider_id and existing_affinity.provider_id != provider_id:
            should_invalidate = False

        if endpoint_id and existing_affinity.endpoint_id != endpoint_id:
            should_invalidate = False

        if not should_invalidate:
            logger.debug(f"跳过失效: affinity_key={affinity_key[:8]}..., api_format={api_format}, "
                         f"model={model_name}, 过滤条件不匹配 (key={key_id}, provider={provider_id}, endpoint={endpoint_id})")
            return

        try:
            cache_key = self._get_cache_key(affinity_key, api_format, model_name)
            async with self._acquire_request_lock(cache_key):
                await self._delete_affinity_key(cache_key)

            self._stats["cache_invalidations"] += 1

            logger.debug(f"失效缓存亲和性: affinity_key={affinity_key[:8]}..., api_format={api_format}, "
                         f"model={model_name}, provider={existing_affinity.provider_id[:8]}..., "
                         f"endpoint={existing_affinity.endpoint_id[:8]}..., "
                         f"provider_key={existing_affinity.key_id[:8]}...")
        except Exception as e:
            logger.exception(f"删除缓存亲和性失败: {e}")
|
||||
|
||||
    async def invalidate_all_for_provider(self, provider_id: str) -> int:
        """
        Invalidate every cache affinity pointing at the given Provider.

        Used when a Provider disables its caching support.

        Args:
            provider_id: Provider ID

        Returns:
            Number of invalidated affinities.

        NOTE(review): the Redis branch uses KEYS, which blocks the server on
        large keyspaces; list_affinities already uses SCAN - consider aligning.
        """
        try:
            invalidated_count = 0

            if not self._is_memory_backend():
                pattern = "cache_affinity:*"
                keys = await self.redis.keys(pattern)
            else:
                keys = list((await self._snapshot_memory_items()).keys())

            for key in keys:
                affinity_dict = await self._load_affinity_dict(key)
                if not affinity_dict:
                    continue

                if affinity_dict.get("provider_id") == provider_id:
                    await self._delete_affinity_key(key)
                    invalidated_count += 1
                    self._stats["cache_invalidations"] += 1

            if invalidated_count > 0:
                logger.debug(f"批量失效Provider缓存亲和性: provider={provider_id[:8]}..., "
                             f"失效数量={invalidated_count}")

            return invalidated_count

        except Exception as e:
            logger.exception(f"批量失效Provider缓存亲和性失败: {e}")
            return 0
|
||||
|
||||
    async def clear_all(self) -> int:
        """
        Clear every cache affinity (administrative operation).

        Returns:
            Number of removed entries.

        NOTE(review): the L1 cache is not cleared here, so stale entries may
        still be served until the L1 TTL elapses - confirm whether intended.
        """
        try:
            if not self._is_memory_backend():
                keys = await self.redis.keys("cache_affinity:*")
                if keys:
                    await self.redis.delete(*keys)
                    logger.debug(f"清除所有Redis缓存亲和性: {len(keys)} 个")
                    return len(keys)
                return 0

            lock = self._get_memory_lock()
            async with lock:
                count = len(self._memory_store)
                self._memory_store.clear()
                if count:
                    logger.debug(f"清除所有内存缓存亲和性: {count} 个")
                return count
        except Exception as e:
            logger.exception(f"清除缓存亲和性失败: {e}")
            return 0
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
cache_hit_rate = 0.0
|
||||
total_requests = self._stats["cache_hits"] + self._stats["cache_misses"]
|
||||
|
||||
if total_requests > 0:
|
||||
cache_hit_rate = self._stats["cache_hits"] / total_requests
|
||||
|
||||
storage_type = "redis" if not self._is_memory_backend() else "memory"
|
||||
|
||||
return {
|
||||
"storage_type": storage_type,
|
||||
"total_affinities": self._stats["total_affinities"],
|
||||
"cache_hits": self._stats["cache_hits"],
|
||||
"cache_misses": self._stats["cache_misses"],
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"cache_invalidations": self._stats["cache_invalidations"],
|
||||
"provider_switches": self._stats["provider_switches"],
|
||||
"key_switches": self._stats["key_switches"],
|
||||
"config": {
|
||||
"default_ttl": self.default_ttl,
|
||||
},
|
||||
}
|
||||
|
||||
    async def list_affinities(self) -> List[Dict[str, Any]]:
        """Return every cache affinity record.

        Each returned record contains:
        - affinity_key: affinity identifier (usually the API key ID)
        - provider_id, endpoint_id, key_id: provider details
        - api_format, model_name: API format and model name
        - created_at, expire_at, request_count: cache metadata
        """
        results: List[Dict[str, Any]] = []

        try:
            pattern = "cache_affinity:*"
            cursor = 0

            if not self._is_memory_backend():
                # SCAN-based iteration (non-blocking, unlike KEYS).
                while True:
                    cursor, keys = await self.redis.scan(cursor=cursor, match=pattern, count=200)

                    if keys:
                        values = await self.redis.mget(*keys)
                        for cache_key, data in zip(keys, values):
                            if not data:
                                continue

                            try:
                                affinity = json.loads(data)
                                # Parse cache_affinity:{affinity_key}:{api_format}:{model_name}
                                # NOTE(review): assumes str keys (decode_responses=True);
                                # a bytes client would break split(":") - confirm.
                                parts = cache_key.split(":")
                                affinity_key_value = parts[1] if len(parts) > 1 else cache_key
                                api_format = (
                                    parts[2]
                                    if len(parts) > 2
                                    else affinity.get("api_format", "unknown")
                                )
                                model_name = (
                                    parts[3]
                                    if len(parts) > 3
                                    else affinity.get("model_name", "unknown")
                                )

                                affinity["affinity_key"] = affinity_key_value
                                if "api_format" not in affinity:
                                    affinity["api_format"] = api_format
                                if "model_name" not in affinity:
                                    affinity["model_name"] = model_name
                                results.append(affinity)
                            except json.JSONDecodeError as e:
                                logger.exception(f"解析缓存亲和性记录失败: {cache_key} - {e}")

                    if cursor == 0:
                        break
            else:
                snapshot = await self._snapshot_memory_items()
                expired_keys: List[str] = []
                current_time = time.time()

                for cache_key, affinity in snapshot.items():
                    # Skip (and later purge) expired entries.
                    if current_time > affinity["expire_at"]:
                        expired_keys.append(cache_key)
                        continue

                    # Parse cache_affinity:{affinity_key}:{api_format}:{model_name}
                    parts = cache_key.split(":")
                    affinity_key_value = parts[1] if len(parts) > 1 else cache_key
                    api_format = (
                        parts[2] if len(parts) > 2 else affinity.get("api_format", "unknown")
                    )
                    model_name = (
                        parts[3] if len(parts) > 3 else affinity.get("model_name", "unknown")
                    )

                    affinity_with_key = dict(affinity)
                    affinity_with_key["affinity_key"] = affinity_key_value
                    if "api_format" not in affinity_with_key:
                        affinity_with_key["api_format"] = api_format
                    if "model_name" not in affinity_with_key:
                        affinity_with_key["model_name"] = model_name
                    results.append(affinity_with_key)

                # Purge the expired keys collected above.
                if expired_keys:
                    async with self._get_memory_lock():
                        for key in expired_keys:
                            self._memory_store.pop(key, None)

        except Exception as e:
            logger.exception(f"获取缓存亲和性列表失败: {e}")

        return results
|
||||
|
||||
|
||||
# Global singleton instance.
_affinity_manager: Optional[CacheAffinityManager] = None


async def get_affinity_manager(redis_client=None) -> CacheAffinityManager:
    """
    Return the global CacheAffinityManager (memory mode when Redis is unavailable).

    Args:
        redis_client: optional Redis client.

    Returns:
        The CacheAffinityManager instance.

    NOTE(review): upgrading from memory to Redis rebuilds the manager, so
    affinities and stats accumulated in memory mode are discarded.
    """
    global _affinity_manager

    if _affinity_manager is None:
        _affinity_manager = CacheAffinityManager(redis_client)
    elif redis_client and _affinity_manager.redis is None:
        # Upgrade to Redis storage once it becomes available.
        _affinity_manager = CacheAffinityManager(redis_client)

    return _affinity_manager
|
||||
1316
src/services/cache/aware_scheduler.py
vendored
Normal file
1316
src/services/cache/aware_scheduler.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
330
src/services/cache/backend.py
vendored
Normal file
330
src/services/cache/backend.py
vendored
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
缓存后端抽象层
|
||||
|
||||
提供统一的缓存接口,支持多种后端实现:
|
||||
1. LocalCache: 内存缓存(单实例,线程安全)
|
||||
2. RedisCache: Redis 缓存(分布式)
|
||||
|
||||
使用场景:
|
||||
- ModelMappingResolver: 模型映射与别名解析缓存
|
||||
- ModelMapper: 模型映射缓存
|
||||
- 其他需要缓存的服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
from src.core.logger import logger
|
||||
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class BaseCacheBackend(ABC):
    """Abstract base class for cache backends."""

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None."""
        pass

    @abstractmethod
    async def set(self, key: str, value: Any, ttl: int = 300) -> None:
        """Store *value* under *key* with a TTL in seconds."""
        pass

    @abstractmethod
    async def delete(self, key: str) -> None:
        """Remove *key* from the cache."""
        pass

    @abstractmethod
    async def clear(self, pattern: Optional[str] = None) -> None:
        """Clear the cache (optionally restricted to keys matching *pattern*)."""
        pass

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Return True when *key* is present."""
        pass


class LocalCache(BaseCacheBackend):
    """In-process memory cache backend (LRU + TTL, coroutine-safe)."""

    def __init__(self, max_size: int = 1000, default_ttl: int = 300):
        """
        Initialize the local cache.

        Args:
            max_size: maximum number of entries (LRU eviction beyond it)
            default_ttl: default expiry in seconds
        """
        self._cache: OrderedDict = OrderedDict()  # key -> value, LRU order
        self._expiry: Dict[str, float] = {}       # key -> absolute deadline
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._lock = asyncio.Lock()

    def _purge_if_expired_unlocked(self, key: str) -> bool:
        """Drop *key* when past its deadline; caller must hold self._lock.

        Returns True when the entry was expired and removed.  Shared by
        get() and exists() so the expiry rule lives in one place.
        """
        if key in self._expiry and time.time() > self._expiry[key]:
            del self._cache[key]
            del self._expiry[key]
            return True
        return False

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None when absent/expired."""
        async with self._lock:
            if key not in self._cache:
                return None
            if self._purge_if_expired_unlocked(key):
                return None
            # Refresh LRU position on access.
            self._cache.move_to_end(key)
            return self._cache[key]

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store *value* under *key*; ttl=None means the default TTL."""
        async with self._lock:
            if ttl is None:
                ttl = self._default_ttl

            # If the key already exists, refresh its LRU position.
            if key in self._cache:
                self._cache.move_to_end(key)

            self._cache[key] = value
            self._expiry[key] = time.time() + ttl

            # Evict the least recently used entry when over capacity.
            if len(self._cache) > self._max_size:
                oldest_key = next(iter(self._cache))
                del self._cache[oldest_key]
                self._expiry.pop(oldest_key, None)

    async def delete(self, key: str) -> None:
        """Remove *key* from the cache (no-op when absent)."""
        async with self._lock:
            if key in self._cache:
                del self._cache[key]
            self._expiry.pop(key, None)

    async def clear(self, pattern: Optional[str] = None) -> None:
        """Clear everything, or only keys matching a simple "prefix*" pattern."""
        async with self._lock:
            if pattern is None:
                self._cache.clear()
                self._expiry.clear()
            else:
                # Pattern support is prefix-only: "foo*" removes keys
                # starting with "foo".
                prefix = pattern.rstrip("*")
                keys_to_delete = [k for k in self._cache.keys() if k.startswith(prefix)]
                for key in keys_to_delete:
                    del self._cache[key]
                    self._expiry.pop(key, None)

    async def exists(self, key: str) -> bool:
        """Return True when *key* is present and not expired."""
        async with self._lock:
            if key not in self._cache:
                return False
            return not self._purge_if_expired_unlocked(key)

    def get_stats(self) -> Dict[str, Any]:
        """Return basic backend statistics."""
        return {
            "backend": "local",
            "size": len(self._cache),
            "max_size": self._max_size,
            "default_ttl": self._default_ttl,
        }
|
||||
|
||||
|
||||
class RedisCache(BaseCacheBackend):
    """Redis cache backend (distributed)."""

    def __init__(
        self, redis_client: aioredis.Redis, key_prefix: str = "cache", default_ttl: int = 300
    ):
        """
        Initialize the Redis cache.

        Args:
            redis_client: Redis client instance
            key_prefix: prefix prepended to every cache key
            default_ttl: default expiry in seconds
        """
        self._redis = redis_client
        self._key_prefix = key_prefix
        self._default_ttl = default_ttl

    def _make_key(self, key: str) -> str:
        """Build the full Redis key."""
        return f"{self._key_prefix}:{key}"

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None on miss/error."""
        try:
            redis_key = self._make_key(key)
            value = await self._redis.get(redis_key)
            if value is None:
                return None

            # Try to deserialize as JSON first.
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                # Not JSON: return the raw value (str or bytes depending on
                # the client's decode_responses setting - confirm with callers).
                return value
        except Exception as e:
            logger.error(f"[RedisCache] 获取缓存失败: {key}, 错误: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store *value* under *key*; ttl=None means the default TTL."""
        if ttl is None:
            ttl = self._default_ttl

        try:
            redis_key = self._make_key(key)

            # Serialize: containers and scalars as JSON, anything else via str().
            if isinstance(value, (dict, list, tuple)):
                serialized = json.dumps(value)
            elif isinstance(value, (int, float, bool)):
                serialized = json.dumps(value)
            else:
                serialized = str(value)

            await self._redis.setex(redis_key, ttl, serialized)
        except Exception as e:
            logger.error(f"[RedisCache] 设置缓存失败: {key}, 错误: {e}")

    async def delete(self, key: str) -> None:
        """Remove *key* from Redis (errors are logged, not raised)."""
        try:
            redis_key = self._make_key(key)
            await self._redis.delete(redis_key)
        except Exception as e:
            logger.error(f"[RedisCache] 删除缓存失败: {key}, 错误: {e}")

    async def clear(self, pattern: Optional[str] = None) -> None:
        """Delete all keys under the prefix matching *pattern* (default: all)."""
        try:
            if pattern is None:
                # Clear every key carrying our prefix.
                pattern = "*"

            redis_pattern = self._make_key(pattern)
            cursor = 0
            deleted_count = 0

            # SCAN-based iteration avoids blocking the server like KEYS would.
            while True:
                cursor, keys = await self._redis.scan(cursor, match=redis_pattern, count=100)
                if keys:
                    await self._redis.delete(*keys)
                    deleted_count += len(keys)
                if cursor == 0:
                    break

            logger.info(f"[RedisCache] 清空缓存: {redis_pattern}, 删除 {deleted_count} 个键")
        except Exception as e:
            logger.error(f"[RedisCache] 清空缓存失败: {pattern}, 错误: {e}")

    async def exists(self, key: str) -> bool:
        """Return True when *key* exists in Redis (False on error)."""
        try:
            redis_key = self._make_key(key)
            return await self._redis.exists(redis_key) > 0
        except Exception as e:
            logger.error(f"[RedisCache] 检查键存在失败: {key}, 错误: {e}")
            return False

    async def publish_invalidation(self, channel: str, key: str) -> None:
        """Publish a cache-invalidation message (for distributed sync)."""
        try:
            message = json.dumps({"key": key, "timestamp": time.time()})
            await self._redis.publish(channel, message)
            logger.debug(f"[RedisCache] 发布缓存失效: {channel} -> {key}")
        except Exception as e:
            logger.error(f"[RedisCache] 发布缓存失效失败: {channel}, {key}, 错误: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Return basic backend statistics."""
        return {
            "backend": "redis",
            "key_prefix": self._key_prefix,
            "default_ttl": self._default_ttl,
        }
|
||||
|
||||
|
||||
# Cache backend factory registry: "{name}:{backend_type}" -> backend instance.
_cache_backends: Dict[str, BaseCacheBackend] = {}


async def get_cache_backend(
    name: str, backend_type: str = "auto", max_size: int = 1000, ttl: int = 300
) -> BaseCacheBackend:
    """
    Return a (cached) cache backend instance.

    Args:
        name: cache name (distinguishes cache instances; also used as the
            Redis key prefix)
        backend_type: backend type (auto/local/redis)
        max_size: maximum capacity for LocalCache
        ttl: default expiry in seconds

    Returns:
        A BaseCacheBackend instance.
    """
    cache_key = f"{name}:{backend_type}"

    if cache_key in _cache_backends:
        return _cache_backends[cache_key]

    # Build the backend according to the requested type.
    if backend_type == "redis":
        # Prefer Redis, falling back to local when unavailable.
        redis_client = get_redis_client_sync()

        if redis_client is None:
            logger.warning(f"[CacheBackend] Redis 未初始化,{name} 降级为本地缓存")
            backend = LocalCache(max_size=max_size, default_ttl=ttl)
        else:
            backend = RedisCache(redis_client=redis_client, key_prefix=name, default_ttl=ttl)
            logger.info(f"[CacheBackend] {name} 使用 Redis 缓存")

    elif backend_type == "local":
        # Force the local in-memory backend.
        backend = LocalCache(max_size=max_size, default_ttl=ttl)
        logger.info(f"[CacheBackend] {name} 使用本地缓存")

    else:  # auto
        # Auto-select: prefer Redis, degrade to local.
        redis_client = get_redis_client_sync()

        if redis_client is not None:
            backend = RedisCache(redis_client=redis_client, key_prefix=name, default_ttl=ttl)
            logger.debug(f"[CacheBackend] {name} 自动选择 Redis 缓存")
        else:
            backend = LocalCache(max_size=max_size, default_ttl=ttl)
            logger.debug(f"[CacheBackend] {name} 自动选择本地缓存(Redis 不可用)")

    _cache_backends[cache_key] = backend
    return backend
|
||||
125
src/services/cache/invalidation.py
vendored
Normal file
125
src/services/cache/invalidation.py
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
缓存失效服务
|
||||
|
||||
统一管理各种缓存的失效逻辑,支持:
|
||||
1. GlobalModel 变更时失效相关缓存
|
||||
2. ModelMapping 变更时失效别名/降级缓存
|
||||
3. Model 变更时失效模型映射缓存
|
||||
4. 支持同步和异步缓存后端
|
||||
"""
|
||||
|
||||
import asyncio
from typing import Optional

from src.core.logger import logger
|
||||
|
||||
|
||||
class CacheInvalidationService:
    """
    Cache invalidation service.

    Provides a unified invalidation entry point: when database models
    change, the registered resolver / mapper caches are cleared
    automatically.
    """

    def __init__(self):
        """Initialize the cache invalidation service."""
        self._mapping_resolver = None
        self._model_mappers = []  # there may be several ModelMapperMiddleware instances
        # asyncio.create_task only keeps a weak reference to the task it
        # returns; without a strong reference here, a pending invalidation
        # task could be garbage-collected before it runs.
        self._background_tasks: set = set()

    def _spawn(self, coro) -> None:
        """Run *coro* as a fire-and-forget task, kept alive until completion."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def set_mapping_resolver(self, mapping_resolver):
        """Register the model mapping resolver instance."""
        self._mapping_resolver = mapping_resolver
        logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")

    def register_model_mapper(self, model_mapper):
        """Register a ModelMapper instance (idempotent)."""
        if model_mapper not in self._model_mappers:
            self._model_mappers.append(model_mapper)
            logger.debug(f"[CacheInvalidation] ModelMapper 已注册 (实例: {id(model_mapper)},总数: {len(self._model_mappers)})")

    def on_global_model_changed(self, model_name: str):
        """
        Invalidate caches after a GlobalModel change.

        Args:
            model_name: The changed GlobalModel.name.
        """
        logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")

        # Asynchronously invalidate the resolver-side cache.
        if self._mapping_resolver:
            self._spawn(self._mapping_resolver.invalidate_global_model_cache())

        # Invalidate every ModelMapper cache related to this model.
        for mapper in self._model_mappers:
            # Clear everything, since we do not know which providers use this model.
            mapper.clear_cache()
            logger.debug("[CacheInvalidation] 已清空 ModelMapper 缓存")

    def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
        """
        Invalidate caches after a ModelMapping change.

        Args:
            source_model: The changed source model name.
            provider_id: The related provider (None means global).
        """
        logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")

        if self._mapping_resolver:
            self._spawn(
                self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
            )

        for mapper in self._model_mappers:
            if provider_id:
                mapper.refresh_cache(provider_id)
            else:
                mapper.clear_cache()

    def on_model_changed(self, provider_id: str, global_model_id: str):
        """
        Invalidate caches after a Model change.

        Args:
            provider_id: Provider ID.
            global_model_id: GlobalModel ID.
        """
        logger.info(f"[CacheInvalidation] Model 变更: provider={provider_id[:8]}..., "
                    f"global_model={global_model_id[:8]}...")

        # Invalidate the provider-specific ModelMapper caches.
        for mapper in self._model_mappers:
            mapper.refresh_cache(provider_id)

    def clear_all_caches(self):
        """Clear every registered cache."""
        logger.info("[CacheInvalidation] 清空所有缓存")

        if self._mapping_resolver:
            self._spawn(self._mapping_resolver.clear_cache())

        for mapper in self._model_mappers:
            mapper.clear_cache()
|
||||
|
||||
|
||||
# Global singleton instance (created lazily).
_cache_invalidation_service: Optional[CacheInvalidationService] = None


def get_cache_invalidation_service() -> CacheInvalidationService:
    """
    Return the global cache invalidation service instance.

    Returns:
        The process-wide CacheInvalidationService singleton, created on
        first use.
    """
    global _cache_invalidation_service

    if _cache_invalidation_service is not None:
        return _cache_invalidation_service

    _cache_invalidation_service = CacheInvalidationService()
    logger.debug("[CacheInvalidation] 初始化缓存失效服务")
    return _cache_invalidation_service
|
||||
325
src/services/cache/model_cache.py
vendored
Normal file
325
src/services/cache/model_cache.py
vendored
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型映射和别名查询
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
|
||||
|
||||
|
||||
class ModelCacheService:
    """Read-through caches for Model / GlobalModel / alias lookups."""

    # Cache TTL in seconds - shared project constant.
    CACHE_TTL = CacheTTL.MODEL

    # Sentinel stored in the alias cache for "no mapping exists". A plain
    # None cannot be cached usefully: a None read is indistinguishable from
    # a cache miss, which previously made every negative alias lookup hit
    # the database on each call.
    _ALIAS_MISS = "__alias_miss__"

    @staticmethod
    async def get_model_by_id(db: Session, model_id: str) -> Optional[Model]:
        """
        Fetch a Model by ID, with caching.

        Args:
            db: Database session.
            model_id: Model ID.

        Returns:
            Model object or None. Cache hits are rebuilt as *detached*
            objects (not attached to the session).
        """
        cache_key = f"model:id:{model_id}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Model 缓存命中: {model_id}")
            return ModelCacheService._dict_to_model(cached_data)

        # 2. Cache miss: query the database.
        model = db.query(Model).filter(Model.id == model_id).first()

        # 3. Populate the cache (positive results only).
        if model:
            model_dict = ModelCacheService._model_to_dict(model)
            await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
            logger.debug(f"Model 已缓存: {model_id}")

        return model

    @staticmethod
    async def get_global_model_by_id(db: Session, global_model_id: str) -> Optional[GlobalModel]:
        """
        Fetch a GlobalModel by ID, with caching.

        Args:
            db: Database session.
            global_model_id: GlobalModel ID.

        Returns:
            GlobalModel object (detached on cache hits) or None.
        """
        cache_key = f"global_model:id:{global_model_id}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"GlobalModel 缓存命中: {global_model_id}")
            return ModelCacheService._dict_to_global_model(cached_data)

        # 2. Cache miss: query the database.
        global_model = db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()

        # 3. Populate the cache.
        if global_model:
            global_model_dict = ModelCacheService._global_model_to_dict(global_model)
            await CacheService.set(
                cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
            )
            logger.debug(f"GlobalModel 已缓存: {global_model_id}")

        return global_model

    @staticmethod
    async def get_model_by_provider_and_global_model(
        db: Session, provider_id: str, global_model_id: str
    ) -> Optional[Model]:
        """
        Fetch an active Model by Provider ID + GlobalModel ID, with caching.

        Args:
            db: Database session.
            provider_id: Provider ID.
            global_model_id: GlobalModel ID.

        Returns:
            Model object (detached on cache hits) or None.
        """
        cache_key = f"model:provider_global:{provider_id}:{global_model_id}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
            return ModelCacheService._dict_to_model(cached_data)

        # 2. Cache miss: query the database (active rows only).
        model = (
            db.query(Model)
            .filter(
                Model.provider_id == provider_id,
                Model.global_model_id == global_model_id,
                Model.is_active == True,  # noqa: E712 - SQLAlchemy column expression
            )
            .first()
        )

        # 3. Populate the cache.
        if model:
            model_dict = ModelCacheService._model_to_dict(model)
            await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
            logger.debug(f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")

        return model

    @staticmethod
    async def get_global_model_by_name(db: Session, name: str) -> Optional[GlobalModel]:
        """
        Fetch a GlobalModel by name, with caching.

        Args:
            db: Database session.
            name: GlobalModel name.

        Returns:
            GlobalModel object (detached on cache hits) or None.
        """
        cache_key = f"global_model:name:{name}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"GlobalModel 缓存命中(名称): {name}")
            return ModelCacheService._dict_to_global_model(cached_data)

        # 2. Cache miss: query the database.
        global_model = db.query(GlobalModel).filter(GlobalModel.name == name).first()

        # 3. Populate the cache.
        if global_model:
            global_model_dict = ModelCacheService._global_model_to_dict(global_model)
            await CacheService.set(
                cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
            )
            logger.debug(f"GlobalModel 已缓存(名称): {name}")

        return global_model

    @staticmethod
    async def resolve_alias(
        db: Session, source_model: str, provider_id: Optional[str] = None
    ) -> Optional[str]:
        """
        Resolve a model alias, with caching of both hits and misses.

        Args:
            db: Database session.
            source_model: Source model name or alias.
            provider_id: Provider ID (optional, for provider-specific aliases).

        Returns:
            Target GlobalModel ID, or None when no mapping exists.
        """
        # Build the cache key.
        if provider_id:
            cache_key = f"alias:provider:{provider_id}:{source_model}"
        else:
            cache_key = f"alias:global:{source_model}"

        # 1. Try the cache first. Negative results are stored as a sentinel
        # so a missing mapping does not fall through to the database on
        # every call. (Assumes CacheService.get returns None for absent
        # keys, as the other methods in this module rely on.)
        cached_result = await CacheService.get(cache_key)
        if cached_result is not None:
            logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
            return None if cached_result == ModelCacheService._ALIAS_MISS else cached_result

        # 2. Cache miss: query the database.
        query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)

        if provider_id:
            # Provider-specific alias takes precedence.
            query = query.filter(ModelMapping.provider_id == provider_id)
        else:
            # Global alias.
            query = query.filter(ModelMapping.provider_id.is_(None))

        mapping = query.first()

        # 3. Populate the cache; store the sentinel for negative results.
        target_global_model_id = mapping.target_global_model_id if mapping else None
        await CacheService.set(
            cache_key,
            ModelCacheService._ALIAS_MISS if target_global_model_id is None else target_global_model_id,
            ttl_seconds=ModelCacheService.CACHE_TTL,
        )

        if mapping:
            logger.debug(f"别名已缓存: {source_model} → {target_global_model_id}")

        return target_global_model_id

    @staticmethod
    async def invalidate_model_cache(
        model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
    ):
        """Clear Model caches.

        Args:
            model_id: Model ID.
            provider_id: Provider ID (needed to clear the provider_global entry).
            global_model_id: GlobalModel ID (needed to clear the provider_global entry).
        """
        # Drop the model:id entry.
        await CacheService.delete(f"model:id:{model_id}")

        # Drop the provider_global entry when both parts of its key are known.
        if provider_id and global_model_id:
            await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
            logger.debug(f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...")
        else:
            logger.debug(f"Model 缓存已清除: {model_id}")

    @staticmethod
    async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None):
        """Clear GlobalModel caches (both the ID-keyed and, if given, name-keyed entries)."""
        await CacheService.delete(f"global_model:id:{global_model_id}")
        if name:
            await CacheService.delete(f"global_model:name:{name}")
        logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")

    @staticmethod
    async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
        """Clear the alias cache entry for *source_model* (global or provider-specific)."""
        if provider_id:
            cache_key = f"alias:provider:{provider_id}:{source_model}"
        else:
            cache_key = f"alias:global:{source_model}"

        await CacheService.delete(cache_key)
        logger.debug(f"别名缓存已清除: {source_model}")

    @staticmethod
    def _model_to_dict(model: Model) -> dict:
        """Serialize a Model object into a JSON-friendly dict for caching."""
        return {
            "id": model.id,
            "provider_id": model.provider_id,
            "global_model_id": model.global_model_id,
            "provider_model_name": model.provider_model_name,
            "is_active": model.is_active,
            # Older schema rows may lack is_available; default to True.
            "is_available": model.is_available if hasattr(model, "is_available") else True,
            # Decimal columns become floats so the dict stays serializable.
            "price_per_request": (
                float(model.price_per_request) if model.price_per_request else None
            ),
            "tiered_pricing": model.tiered_pricing,
            "supports_vision": model.supports_vision,
            "supports_function_calling": model.supports_function_calling,
            "supports_streaming": model.supports_streaming,
            "supports_extended_thinking": model.supports_extended_thinking,
            "config": model.config,
        }

    @staticmethod
    def _dict_to_model(model_dict: dict) -> Model:
        """Rebuild a *detached* Model object from a cached dict."""
        model = Model(
            id=model_dict["id"],
            provider_id=model_dict["provider_id"],
            global_model_id=model_dict["global_model_id"],
            provider_model_name=model_dict["provider_model_name"],
            is_active=model_dict["is_active"],
            is_available=model_dict.get("is_available", True),
            price_per_request=model_dict.get("price_per_request"),
            tiered_pricing=model_dict.get("tiered_pricing"),
            supports_vision=model_dict.get("supports_vision"),
            supports_function_calling=model_dict.get("supports_function_calling"),
            supports_streaming=model_dict.get("supports_streaming"),
            supports_extended_thinking=model_dict.get("supports_extended_thinking"),
            config=model_dict.get("config"),
        )
        return model

    @staticmethod
    def _global_model_to_dict(global_model: GlobalModel) -> dict:
        """Serialize a GlobalModel object into a JSON-friendly dict for caching."""
        return {
            "id": global_model.id,
            "name": global_model.name,
            "display_name": global_model.display_name,
            "family": global_model.family,
            "group_id": global_model.group_id,
            "supports_vision": global_model.supports_vision,
            "supports_thinking": global_model.supports_thinking,
            "context_window": global_model.context_window,
            "max_output_tokens": global_model.max_output_tokens,
            "is_active": global_model.is_active,
            "description": global_model.description,
        }

    @staticmethod
    def _dict_to_global_model(global_model_dict: dict) -> GlobalModel:
        """Rebuild a *detached* GlobalModel object from a cached dict."""
        global_model = GlobalModel(
            id=global_model_dict["id"],
            name=global_model_dict["name"],
            display_name=global_model_dict.get("display_name"),
            family=global_model_dict.get("family"),
            group_id=global_model_dict.get("group_id"),
            supports_vision=global_model_dict.get("supports_vision", False),
            supports_thinking=global_model_dict.get("supports_thinking", False),
            context_window=global_model_dict.get("context_window"),
            max_output_tokens=global_model_dict.get("max_output_tokens"),
            is_active=global_model_dict.get("is_active", True),
            description=global_model_dict.get("description"),
        )
        return global_model
|
||||
254
src/services/cache/provider_cache.py
vendored
Normal file
254
src/services/cache/provider_cache.py
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheKeys, CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
class ProviderCacheService:
    """Read-through caches for Provider / Endpoint / APIKey lookups.

    Cache hits are rebuilt as *detached* ORM objects (not attached to any
    Session); callers must not flush or lazy-load relationships on them.
    """

    # Cache TTL in seconds - shared project constant.
    CACHE_TTL = CacheTTL.PROVIDER

    @staticmethod
    async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
        """
        Fetch a Provider by ID, with caching.

        Args:
            db: Database session.
            provider_id: Provider ID.

        Returns:
            Provider object (detached on cache hits) or None.
        """
        cache_key = CacheKeys.provider_by_id(provider_id)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Provider 缓存命中: {provider_id}")
            return ProviderCacheService._dict_to_provider(cached_data)

        # 2. Cache miss: query the database.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()

        # 3. Populate the cache (positive results only; misses are re-queried).
        if provider:
            provider_dict = ProviderCacheService._provider_to_dict(provider)
            await CacheService.set(
                cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Provider 已缓存: {provider_id}")

        return provider

    @staticmethod
    async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
        """
        Fetch an Endpoint by ID, with caching.

        Args:
            db: Database session.
            endpoint_id: Endpoint ID.

        Returns:
            ProviderEndpoint object (detached on cache hits) or None.
        """
        cache_key = CacheKeys.endpoint_by_id(endpoint_id)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
            return ProviderCacheService._dict_to_endpoint(cached_data)

        # 2. Cache miss: query the database.
        endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()

        # 3. Populate the cache.
        if endpoint:
            endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
            await CacheService.set(
                cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Endpoint 已缓存: {endpoint_id}")

        return endpoint

    @staticmethod
    async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
        """
        Fetch an API Key by ID, with caching.

        Args:
            db: Database session.
            api_key_id: API Key ID.

        Returns:
            ProviderAPIKey object (detached on cache hits) or None.

        Note:
            The cached dict contains key_value verbatim; the cache backend
            is trusted storage for secrets here.
        """
        cache_key = CacheKeys.api_key_by_id(api_key_id)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"API Key 缓存命中: {api_key_id}")
            return ProviderCacheService._dict_to_api_key(cached_data)

        # 2. Cache miss: query the database.
        api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()

        # 3. Populate the cache.
        if api_key:
            api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
            await CacheService.set(
                cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"API Key 已缓存: {api_key_id}")

        return api_key

    @staticmethod
    async def invalidate_provider_cache(provider_id: str):
        """
        Clear the cached Provider entry.

        Args:
            provider_id: Provider ID.
        """
        await CacheService.delete(CacheKeys.provider_by_id(provider_id))
        logger.debug(f"Provider 缓存已清除: {provider_id}")

    @staticmethod
    async def invalidate_endpoint_cache(endpoint_id: str):
        """
        Clear the cached Endpoint entry.

        Args:
            endpoint_id: Endpoint ID.
        """
        await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
        logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")

    @staticmethod
    async def invalidate_api_key_cache(api_key_id: str):
        """
        Clear the cached API Key entry.

        Args:
            api_key_id: API Key ID.
        """
        await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
        logger.debug(f"API Key 缓存已清除: {api_key_id}")

    @staticmethod
    def _provider_to_dict(provider: Provider) -> dict:
        """Serialize a Provider object into a JSON-friendly dict (for caching)."""
        return {
            "id": provider.id,
            "name": provider.name,
            "api_format": provider.api_format,
            "base_url": provider.base_url,
            "is_active": provider.is_active,
            "priority": provider.priority,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            # Datetime round-trips as an ISO-8601 string.
            "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
            "config": provider.config,
            "description": provider.description,
        }

    @staticmethod
    def _dict_to_provider(provider_dict: dict) -> Provider:
        """Rebuild a Provider object from a cached dict (detached, not in a Session)."""
        from datetime import datetime

        provider = Provider(
            id=provider_dict["id"],
            name=provider_dict["name"],
            api_format=provider_dict["api_format"],
            base_url=provider_dict.get("base_url"),
            is_active=provider_dict["is_active"],
            priority=provider_dict.get("priority", 0),
            rpm_limit=provider_dict.get("rpm_limit"),
            rpm_used=provider_dict.get("rpm_used", 0),
            config=provider_dict.get("config"),
            description=provider_dict.get("description"),
        )

        # rpm_reset_at is assigned after construction; parse the ISO string back.
        if provider_dict.get("rpm_reset_at"):
            provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])

        return provider

    @staticmethod
    def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
        """Serialize an Endpoint object into a JSON-friendly dict (for caching)."""
        return {
            "id": endpoint.id,
            "provider_id": endpoint.provider_id,
            "name": endpoint.name,
            "base_url": endpoint.base_url,
            "is_active": endpoint.is_active,
            "priority": endpoint.priority,
            "weight": endpoint.weight,
            "custom_path": endpoint.custom_path,
            "config": endpoint.config,
        }

    @staticmethod
    def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
        """Rebuild an Endpoint object from a cached dict (detached, not in a Session)."""
        endpoint = ProviderEndpoint(
            id=endpoint_dict["id"],
            provider_id=endpoint_dict["provider_id"],
            name=endpoint_dict["name"],
            base_url=endpoint_dict["base_url"],
            is_active=endpoint_dict["is_active"],
            priority=endpoint_dict.get("priority", 0),
            weight=endpoint_dict.get("weight", 1.0),
            custom_path=endpoint_dict.get("custom_path"),
            config=endpoint_dict.get("config"),
        )
        return endpoint

    @staticmethod
    def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
        """Serialize an API Key object into a JSON-friendly dict (for caching)."""
        return {
            "id": api_key.id,
            "endpoint_id": api_key.endpoint_id,
            "key_value": api_key.key_value,
            "is_active": api_key.is_active,
            "max_rpm": api_key.max_rpm,
            "current_rpm": api_key.current_rpm,
            "health_score": api_key.health_score,
            "circuit_breaker_state": api_key.circuit_breaker_state,
            "adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
        }

    @staticmethod
    def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
        """Rebuild an API Key object from a cached dict (detached, not in a Session)."""
        api_key = ProviderAPIKey(
            id=api_key_dict["id"],
            endpoint_id=api_key_dict["endpoint_id"],
            key_value=api_key_dict["key_value"],
            is_active=api_key_dict["is_active"],
            max_rpm=api_key_dict.get("max_rpm"),
            current_rpm=api_key_dict.get("current_rpm", 0),
            health_score=api_key_dict.get("health_score", 1.0),
            circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
            adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
        )
        return api_key
|
||||
209
src/services/cache/sync.py
vendored
Normal file
209
src/services/cache/sync.py
vendored
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
缓存同步服务(Redis Pub/Sub)
|
||||
|
||||
提供分布式缓存失效同步功能,用于多实例部署场景。
|
||||
当一个实例修改数据并失效本地缓存时,通过 Redis pub/sub 通知其他实例同步失效。
|
||||
|
||||
使用场景:
|
||||
1. 多实例部署时,确保所有实例的缓存一致性
|
||||
2. GlobalModel/ModelMapping 变更时,同步失效所有实例的缓存
|
||||
"""
|
||||
|
||||
import asyncio
import json
from typing import Callable, Dict, Optional

import redis.asyncio as aioredis

from src.clients.redis_client import get_redis_client_sync
from src.core.logger import logger
|
||||
|
||||
|
||||
class CacheSyncService:
    """
    Cache synchronization service.

    Implements distributed cache-invalidation fan-out via Redis pub/sub:
    one instance publishes a change notice, every subscribed instance
    dispatches it to a locally registered handler.
    """

    # Redis channel names for each invalidation event type.
    CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
    CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
    CHANNEL_MODEL = "cache:invalidate:model"
    CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"

    def __init__(self, redis_client: aioredis.Redis):
        """
        Initialize the cache sync service.

        Args:
            redis_client: Redis client instance used for pub/sub and publishing.
        """
        self._redis = redis_client
        self._pubsub: Optional[aioredis.client.PubSub] = None  # created in start()
        self._listener_task: Optional[asyncio.Task] = None  # background listen loop
        self._handlers: Dict[str, Callable] = {}  # channel -> async handler
        self._running = False  # guards double start/stop

    async def start(self):
        """Start the service: subscribe to all channels and spawn the listener task."""
        if self._running:
            logger.warning("[CacheSync] 服务已在运行")
            return

        try:
            self._pubsub = self._redis.pubsub()

            # Subscribe to every invalidation channel in one round trip.
            await self._pubsub.subscribe(
                self.CHANNEL_GLOBAL_MODEL,
                self.CHANNEL_MODEL_MAPPING,
                self.CHANNEL_MODEL,
                self.CHANNEL_CLEAR_ALL,
            )

            # Run the listen loop in the background.
            self._listener_task = asyncio.create_task(self._listen())
            self._running = True

            logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
                        f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
                        f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
        except Exception as e:
            logger.error(f"[CacheSync] 启动失败: {e}")
            raise

    async def stop(self):
        """Stop the service: cancel the listener, unsubscribe, and close the pubsub."""
        if not self._running:
            return

        self._running = False

        # Cancel the listener task before tearing down the subscription.
        if self._listener_task:
            self._listener_task.cancel()
            try:
                await self._listener_task
            except asyncio.CancelledError:
                pass

        # Unsubscribe and release the pubsub connection.
        if self._pubsub:
            await self._pubsub.unsubscribe()
            await self._pubsub.close()

        logger.info("[CacheSync] 缓存同步服务已停止")

    def register_handler(self, channel: str, handler: Callable):
        """
        Register an invalidation handler for a channel.

        Args:
            channel: Redis channel name.
            handler: Async callable invoked with the decoded message payload.
                Only one handler per channel; a later registration replaces
                the earlier one.
        """
        self._handlers[channel] = handler
        logger.debug(f"[CacheSync] 注册处理器: {channel}")

    async def _listen(self):
        """Listen loop: decode each pub/sub message and dispatch it to its handler.

        Per-message errors are logged and swallowed so one bad payload does
        not kill the loop; cancellation terminates it cleanly.
        """
        logger.info("[CacheSync] 开始监听缓存失效消息")

        try:
            async for message in self._pubsub.listen():
                if message["type"] == "message":
                    # NOTE(review): channel/data may be bytes unless the Redis
                    # client was created with decode_responses=True — confirm
                    # against the client setup.
                    channel = message["channel"]
                    data = message["data"]

                    # Decode the JSON payload.
                    try:
                        payload = json.loads(data)
                        logger.debug(f"[CacheSync] 收到消息: {channel} -> {payload}")

                        # Dispatch to the registered handler, if any.
                        if channel in self._handlers:
                            handler = self._handlers[channel]
                            await handler(payload)
                        else:
                            logger.warning(f"[CacheSync] 未找到处理器: {channel}")
                    except json.JSONDecodeError as e:
                        logger.error(f"[CacheSync] 消息解析失败: {data}, 错误: {e}")
                    except Exception as e:
                        logger.error(f"[CacheSync] 处理消息失败: {channel}, 错误: {e}")
        except asyncio.CancelledError:
            logger.info("[CacheSync] 监听任务已取消")
        except Exception as e:
            logger.error(f"[CacheSync] 监听失败: {e}")

    async def publish_global_model_changed(self, model_name: str):
        """Broadcast a GlobalModel change notification."""
        await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})

    async def publish_model_mapping_changed(
        self, source_model: str, provider_id: Optional[str] = None
    ):
        """Broadcast a ModelMapping change notification."""
        await self._publish(
            self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
        )

    async def publish_model_changed(self, provider_id: str, global_model_id: str):
        """Broadcast a Model change notification."""
        await self._publish(
            self.CHANNEL_MODEL, {"provider_id": provider_id, "global_model_id": global_model_id}
        )

    async def publish_clear_all(self):
        """Broadcast a clear-all-caches notification."""
        await self._publish(self.CHANNEL_CLEAR_ALL, {})

    async def _publish(self, channel: str, data: dict):
        """Publish *data* (JSON-encoded) to a Redis channel; failures are logged, not raised."""
        try:
            message = json.dumps(data)
            await self._redis.publish(channel, message)
            logger.debug(f"[CacheSync] 发布消息: {channel} -> {data}")
        except Exception as e:
            logger.error(f"[CacheSync] 发布消息失败: {channel}, 错误: {e}")
|
||||
|
||||
|
||||
# Global singleton instance (created lazily by get_cache_sync_service).
_cache_sync_service: Optional[CacheSyncService] = None


async def get_cache_sync_service(redis_client: Optional[aioredis.Redis] = None) -> Optional[CacheSyncService]:
    """
    Get the cache sync service instance.

    Args:
        redis_client: Redis client instance (only consulted on the first
            call; when omitted, the process-wide client is used).

    Returns:
        The CacheSyncService instance, or None when Redis is unavailable.
    """
    global _cache_sync_service

    if _cache_sync_service is None:
        if redis_client is None:
            # Fall back to the process-wide Redis client.
            redis_client = get_redis_client_sync()

        if redis_client is None:
            # No Redis at all: distributed sync is disabled, not an error.
            logger.warning("[CacheSync] Redis 不可用,分布式缓存同步已禁用")
            return None

        _cache_sync_service = CacheSyncService(redis_client)
        logger.info("[CacheSync] 缓存同步服务已初始化")

    return _cache_sync_service
|
||||
|
||||
|
||||
async def close_cache_sync_service():
    """Stop and discard the global cache sync service, if one exists."""
    global _cache_sync_service

    service = _cache_sync_service
    if service is None:
        return

    await service.stop()
    _cache_sync_service = None
|
||||
155
src/services/cache/user_cache.py
vendored
Normal file
155
src/services/cache/user_cache.py
vendored
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
用户缓存服务 - 减少数据库查询
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheKeys, CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
|
||||
|
||||
|
||||
class UserCacheService:
|
||||
"""用户缓存服务"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.USER
|
||||
|
||||
    @staticmethod
    async def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """
        Fetch a user by ID, with read-through caching.

        Args:
            db: Database session.
            user_id: User ID.

        Returns:
            User object or None. Cache hits are rebuilt as *detached*
            User objects (not attached to the session).
        """
        cache_key = CacheKeys.user_by_id(user_id)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"用户缓存命中: {user_id}")
            # Rebuild a User object from the cached dict.
            return UserCacheService._dict_to_user(db, cached_data)

        # 2. Cache miss: query the database.
        user = db.query(User).filter(User.id == user_id).first()

        # 3. Populate the cache (positive results only; misses are re-queried).
        if user:
            user_dict = UserCacheService._user_to_dict(user)
            await CacheService.set(cache_key, user_dict, ttl_seconds=UserCacheService.CACHE_TTL)
            logger.debug(f"用户已缓存: {user_id}")

        return user
|
||||
|
||||
    @staticmethod
    async def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """
        Fetch a user by email, with read-through caching.

        Args:
            db: Database session.
            email: User email.

        Returns:
            User object or None. Cache hits are rebuilt as *detached*
            User objects (not attached to the session).
        """
        cache_key = CacheKeys.user_by_email(email)

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"用户缓存命中(邮箱): {email}")
            return UserCacheService._dict_to_user(db, cached_data)

        # 2. Cache miss: query the database.
        user = db.query(User).filter(User.email == email).first()

        # 3. Populate the cache (positive results only).
        if user:
            user_dict = UserCacheService._user_to_dict(user)
            await CacheService.set(cache_key, user_dict, ttl_seconds=UserCacheService.CACHE_TTL)
            logger.debug(f"用户已缓存(邮箱): {email}")

        return user
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_user_cache(user_id: str, email: Optional[str] = None):
|
||||
"""
|
||||
清除用户缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
email: 用户邮箱(可选)
|
||||
"""
|
||||
# 删除 ID 缓存
|
||||
await CacheService.delete(CacheKeys.user_by_id(user_id))
|
||||
|
||||
# 删除邮箱缓存
|
||||
if email:
|
||||
await CacheService.delete(CacheKeys.user_by_email(email))
|
||||
|
||||
logger.debug(f"用户缓存已清除: {user_id}")
|
||||
|
||||
@staticmethod
|
||||
def _user_to_dict(user: User) -> dict:
|
||||
"""将 User 对象转换为字典(用于缓存)"""
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"role": user.role.value if user.role else None,
|
||||
"is_active": user.is_active,
|
||||
"quota_usd": float(user.quota_usd) if user.quota_usd is not None else None,
|
||||
"used_usd": float(user.used_usd),
|
||||
"created_at": user.created_at.isoformat() if user.created_at else None,
|
||||
"last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
|
||||
"model_capability_settings": user.model_capability_settings,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_user(db: Session, user_dict: dict) -> User:
|
||||
"""
|
||||
从字典重建 User 对象
|
||||
|
||||
注意:这是一个"分离"的对象,不在 Session 中
|
||||
如果需要修改,需要使用 db.merge() 或重新查询
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.database import UserRole
|
||||
|
||||
user = User(
|
||||
id=user_dict["id"],
|
||||
email=user_dict["email"],
|
||||
username=user_dict["username"],
|
||||
is_active=user_dict["is_active"],
|
||||
used_usd=user_dict["used_usd"],
|
||||
)
|
||||
|
||||
# 设置可选字段
|
||||
if user_dict.get("role"):
|
||||
user.role = UserRole(user_dict["role"])
|
||||
|
||||
if user_dict.get("quota_usd") is not None:
|
||||
user.quota_usd = user_dict["quota_usd"]
|
||||
|
||||
if user_dict.get("created_at"):
|
||||
user.created_at = datetime.fromisoformat(user_dict["created_at"])
|
||||
|
||||
if user_dict.get("last_login_at"):
|
||||
user.last_login_at = datetime.fromisoformat(user_dict["last_login_at"])
|
||||
|
||||
if user_dict.get("model_capability_settings") is not None:
|
||||
user.model_capability_settings = user_dict["model_capability_settings"]
|
||||
|
||||
return user
|
||||
Reference in New Issue
Block a user