Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

19
src/services/cache/__init__.py vendored Normal file
View File

@@ -0,0 +1,19 @@
"""
缓存服务模块
包含缓存后端、缓存亲和性、缓存同步等功能。
注意:由于循环依赖问题,部分类需要直接从子模块导入:
from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler
"""
# 只导出不会导致循环依赖的基础类
from src.services.cache.backend import BaseCacheBackend, LocalCache, RedisCache, get_cache_backend
__all__ = [
"BaseCacheBackend",
"LocalCache",
"RedisCache",
"get_cache_backend",
]

668
src/services/cache/affinity_manager.py vendored Normal file
View File

@@ -0,0 +1,668 @@
"""
缓存亲和性管理器 (Cache Affinity Manager) - 支持 Redis 或内存存储
职责:
1. 跟踪请求API Key的Provider+Key缓存状态
2. 管理缓存有效期
3. 提供缓存统计和分析
4. 自动失效不支持缓存的Provider
设计原理:
- 每个API Key使用某个Provider的Key后在缓存TTL期内应该继续使用同一个Key
- 这样可以最大化利用提供商的Prompt Caching机制
- 当Key故障时自动失效该Key的缓存亲和性
- 当Provider关闭缓存支持时自动失效所有相关亲和性
注意:
- affinity_key 参数通常为请求使用的 API Key IDapi_key_id
- 这样可以支持"独立余额Key"场景每个Key有自己的缓存亲和性
"""
import asyncio
import json
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from src.config.constants import CacheTTL
from src.core.logger import logger
class CacheAffinity(NamedTuple):
    """Immutable snapshot of one cache-affinity record."""

    provider_id: str
    endpoint_id: str
    key_id: str
    api_format: str  # API format (claude/openai)
    model_name: str  # model name
    created_at: float  # creation timestamp (epoch seconds)
    expire_at: float  # expiry timestamp (epoch seconds)
    request_count: int  # number of requests served under this affinity
class CacheAffinityManager:
    """
    Cache affinity manager (backed by Redis or in-memory storage).

    Storage layout
    ----------------------
    Key format: cache_affinity:{affinity_key}:{api_format}:{model_name}
    - affinity_key: usually the API Key ID of the request (supports the
      "independent balance key" scenario)
    - api_format: API format (claude/openai)
    - model_name: model name (keeps affinities separate per model)

    Value format: JSON/Dict
    {
        "provider_id": "xxx",
        "endpoint_id": "yyy",
        "key_id": "zzz",
        "model_name": "claude-3-5-sonnet-20241022",
        "created_at": 1234567890.123,
        "expire_at": 1234567890.123,
        "request_count": 5
    }
    TTL: entries expire automatically.

    Design notes:
    - Each affinity key keeps a separate affinity per API format and per model
    - Requests for different models use independent affinities, so switching
      models does not invalidate an existing affinity
    - Failover on one endpoint does not disturb affinities of other endpoints
    - Enables more precise cache hit-rate statistics
    - Supports the "independent balance key" scenario: every key has its own
      cache affinity
    """

    # Default cache TTL - uses the shared constant
    DEFAULT_CACHE_TTL = CacheTTL.CACHE_AFFINITY

    def __init__(self, redis_client=None, default_ttl: int = DEFAULT_CACHE_TTL):
        """
        Initialize the cache affinity manager.

        Args:
            redis_client: optional Redis client; in-memory fallback when None
            default_ttl: default affinity TTL in seconds
        """
        self.redis = redis_client
        self.default_ttl = default_ttl
        self._memory_store: Dict[str, Dict[str, Any]] = {}
        self._memory_lock: Optional[asyncio.Lock] = None
        # L1 cache (enabled even when Redis is used, to cut network round-trips)
        self._l1_cache_ttl = int(os.getenv("CACHE_AFFINITY_L1_TTL", str(CacheTTL.L1_LOCAL)))
        self._l1_cache: Dict[str, Tuple[float, Dict[str, Any]]] = {}
        self._l1_lock = asyncio.Lock()
        self._l1_max_size = int(os.getenv("CACHE_AFFINITY_L1_MAX_SIZE", "1000"))  # max cached entries
        self._l1_last_cleanup = time.time()
        # Per-cache-key locks so concurrent updates for the same user+endpoint
        # do not thrash each other.
        # NOTE(review): entries are never removed from this dict, so it grows
        # with the number of distinct cache keys — confirm this is acceptable.
        self._request_locks: Dict[str, asyncio.Lock] = {}
        # Statistics counters.
        # NOTE(review): "provider_switches" is never incremented anywhere in
        # this class (only "key_switches" is) — confirm whether that is intended.
        self._stats = {
            "total_affinities": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "cache_invalidations": 0,
            "provider_switches": 0,
            "key_switches": 0,
        }
        if self.redis:
            logger.debug("CacheAffinityManager: 使用Redis存储")
        else:
            logger.debug("CacheAffinityManager: Redis不可用回退到内存存储(仅适用于单实例/开发环境)")

    def _is_memory_backend(self) -> bool:
        """Return True when running in in-memory mode (no Redis client)."""
        return self.redis is None

    def _get_memory_lock(self) -> asyncio.Lock:
        """Lazily create the lock guarding the in-memory store."""
        if self._memory_lock is None:
            self._memory_lock = asyncio.Lock()
        return self._memory_lock

    def _get_cache_key(self, affinity_key: str, api_format: str, model_name: str) -> str:
        """
        Build the storage key.

        Args:
            affinity_key: affinity identifier (usually an API Key ID)
            api_format: API format (claude/openai)
            model_name: model name

        Returns:
            Formatted cache key: cache_affinity:{affinity_key}:{api_format}:{model_name}
        """
        return f"cache_affinity:{affinity_key}:{api_format}:{model_name}"

    async def _get_l1_entry(self, cache_key: str) -> Optional[Dict[str, Any]]:
        # Return a copy of the L1 entry, dropping it when past its expiry.
        async with self._l1_lock:
            record = self._l1_cache.get(cache_key)
            if not record:
                return None
            expire_at, payload = record
            if time.time() > expire_at:
                self._l1_cache.pop(cache_key, None)
                return None
            return dict(payload)

    async def _set_l1_entry(self, cache_key: str, payload: Optional[Dict[str, Any]]):
        # Store (or, when payload is falsy, evict) an L1 entry.
        async with self._l1_lock:
            if not payload:
                self._l1_cache.pop(cache_key, None)
                return
            expire_at = time.time() + max(1, self._l1_cache_ttl)
            self._l1_cache[cache_key] = (expire_at, dict(payload))
            # Periodically clean up expired entries (at most once every 60 s)
            current_time = time.time()
            if current_time - self._l1_last_cleanup > 60:
                self._cleanup_l1_cache_unlocked(current_time)
                self._l1_last_cleanup = current_time

    def _cleanup_l1_cache_unlocked(self, current_time: float) -> int:
        """Remove expired L1 cache entries (caller must hold the L1 lock).

        Returns:
            Number of entries removed.
        """
        expired_keys = [
            key for key, (expire_at, _) in self._l1_cache.items()
            if current_time > expire_at
        ]
        for key in expired_keys:
            self._l1_cache.pop(key, None)
        # If the cache is still over capacity, drop the entries with the
        # earliest expiry times.
        if len(self._l1_cache) > self._l1_max_size:
            sorted_items = sorted(
                self._l1_cache.items(),
                key=lambda x: x[1][0]  # sort by expire_at
            )
            # Remove the oldest ~20% of entries
            remove_count = len(self._l1_cache) - int(self._l1_max_size * 0.8)
            for key, _ in sorted_items[:remove_count]:
                self._l1_cache.pop(key, None)
            expired_keys.extend([k for k, _ in sorted_items[:remove_count]])
        if expired_keys:
            logger.debug(f"L1 缓存清理: 移除 {len(expired_keys)} 个条目,当前 {len(self._l1_cache)}")
        return len(expired_keys)

    @asynccontextmanager
    async def _acquire_request_lock(self, cache_key: str):
        """Serialize operations on one cache key within this process."""
        lock = self._request_locks.get(cache_key)
        if lock is None:
            # No await between the get and the set, so this is race-free
            # under a single asyncio event loop.
            lock = asyncio.Lock()
            self._request_locks[cache_key] = lock
        await lock.acquire()
        try:
            yield
        finally:
            lock.release()

    async def _load_affinity_dict(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Read an affinity dict: L1 cache first, then Redis or memory."""
        # Try the L1 cache first
        l1_value = await self._get_l1_entry(cache_key)
        if l1_value is not None:
            return l1_value
        if not self._is_memory_backend():
            data = await self.redis.get(cache_key)
            if not data:
                return None
            value = json.loads(data)
            await self._set_l1_entry(cache_key, value)
            return value
        lock = self._get_memory_lock()
        async with lock:
            record = self._memory_store.get(cache_key)
            if record:
                await self._set_l1_entry(cache_key, record)
            return dict(record) if record else None

    async def _save_affinity_dict(
        self, cache_key: str, ttl: int, affinity_dict: Dict[str, Any]
    ) -> None:
        """Persist an affinity dict (Redis with TTL, or the in-memory store).

        In memory mode no real TTL is applied; expiry relies on the
        expire_at double-check performed by get_affinity/list_affinities.
        """
        if not self._is_memory_backend():
            await self.redis.setex(cache_key, ttl, json.dumps(affinity_dict))
            await self._set_l1_entry(cache_key, affinity_dict)
            return
        lock = self._get_memory_lock()
        async with lock:
            self._memory_store[cache_key] = dict(affinity_dict)
            await self._set_l1_entry(cache_key, affinity_dict)

    async def _delete_affinity_key(self, cache_key: str) -> None:
        """Delete one affinity record from the backing store and the L1 cache."""
        if not self._is_memory_backend():
            await self.redis.delete(cache_key)
        else:
            lock = self._get_memory_lock()
            async with lock:
                self._memory_store.pop(cache_key, None)
        await self._set_l1_entry(cache_key, None)

    async def _snapshot_memory_items(self) -> Dict[str, Dict[str, Any]]:
        """Copy the in-memory store (memory mode only)."""
        lock = self._get_memory_lock()
        async with lock:
            return {k: dict(v) for k, v in self._memory_store.items()}

    async def get_affinity(
        self, affinity_key: str, api_format: str, model_name: str
    ) -> Optional[CacheAffinity]:
        """
        Get the cache affinity of an affinity key for a given API format and model.

        Args:
            affinity_key: affinity identifier (usually an API Key ID)
            api_format: API format (claude/openai)
            model_name: model name

        Returns:
            A CacheAffinity, or None when absent or expired.
        """
        try:
            cache_key = self._get_cache_key(affinity_key, api_format, model_name)
            async with self._acquire_request_lock(cache_key):
                affinity_dict = await self._load_affinity_dict(cache_key)
                if not affinity_dict:
                    self._stats["cache_misses"] += 1
                    return None
                # Double-check expiry in case the TTL has not purged the entry yet
                current_time = time.time()
                if current_time > affinity_dict["expire_at"]:
                    await self._delete_affinity_key(cache_key)
                    self._stats["cache_misses"] += 1
                    return None
                self._stats["cache_hits"] += 1
                return CacheAffinity(
                    provider_id=affinity_dict["provider_id"],
                    endpoint_id=affinity_dict["endpoint_id"],
                    key_id=affinity_dict["key_id"],
                    api_format=affinity_dict.get("api_format", api_format),
                    model_name=affinity_dict.get("model_name", model_name),
                    created_at=affinity_dict["created_at"],
                    expire_at=affinity_dict["expire_at"],
                    request_count=affinity_dict["request_count"],
                )
        except Exception as e:
            logger.exception(f"获取缓存亲和性失败: {e}")
            self._stats["cache_misses"] += 1
            return None

    async def set_affinity(
        self,
        affinity_key: str,
        provider_id: str,
        endpoint_id: str,
        key_id: str,
        api_format: str,
        model_name: str,
        supports_caching: bool = True,
        ttl: Optional[int] = None,
    ) -> None:
        """
        Record/refresh the affinity of an affinity key for a given API format and model.

        Args:
            affinity_key: affinity identifier (usually an API Key ID)
            provider_id: Provider ID
            endpoint_id: Endpoint ID
            key_id: Key ID
            api_format: API format (claude/openai)
            model_name: model name
            supports_caching: whether the provider supports prompt caching
            ttl: affinity TTL in seconds; defaults to self.default_ttl

        Note: every call refreshes the expiry (sliding window) so that the
        same Provider/Endpoint/Key combination keeps being preferred.
        """
        if not supports_caching:
            # Providers without caching support get no affinity record
            logger.debug(f"Provider {provider_id[:8]}... 不支持缓存,跳过亲和性记录")
            return
        ttl = ttl or self.default_ttl
        current_time = time.time()
        expire_at = current_time + ttl  # refreshed on every call
        cache_key = self._get_cache_key(affinity_key, api_format, model_name)
        try:
            async with self._acquire_request_lock(cache_key):
                existing_dict = await self._load_affinity_dict(cache_key)
                existing_affinity: Optional[CacheAffinity] = None
                if existing_dict and current_time <= existing_dict.get("expire_at", 0):
                    existing_affinity = CacheAffinity(
                        provider_id=existing_dict["provider_id"],
                        endpoint_id=existing_dict["endpoint_id"],
                        key_id=existing_dict["key_id"],
                        api_format=existing_dict.get("api_format", api_format),
                        model_name=existing_dict.get("model_name", model_name),
                        created_at=existing_dict["created_at"],
                        expire_at=existing_dict["expire_at"],
                        request_count=existing_dict.get("request_count", 0),
                    )
                if existing_affinity:
                    created_at = existing_affinity.created_at
                    request_count = existing_affinity.request_count + 1
                    # Detect a Provider/Endpoint/Key switch
                    if (
                        existing_affinity.provider_id != provider_id
                        or existing_affinity.endpoint_id != endpoint_id
                        or existing_affinity.key_id != key_id
                    ):
                        self._stats["key_switches"] += 1
                        logger.debug(f"Key {affinity_key[:8]}... 在 {api_format} 格式下切换后端: "
                                     f"[{existing_affinity.provider_id[:8]}.../{existing_affinity.endpoint_id[:8]}.../"
                                     f"{existing_affinity.key_id[:8]}...] → "
                                     f"[{provider_id[:8]}.../{endpoint_id[:8]}.../{key_id[:8]}...], 重置计数器")
                        created_at = current_time
                        request_count = 1
                    else:
                        logger.debug(f"刷新缓存亲和性: key={affinity_key[:8]}..., api_format={api_format}, "
                                     f"provider={provider_id[:8]}..., endpoint={endpoint_id[:8]}..., "
                                     f"provider_key={key_id[:8]}..., ttl+={ttl}s")
                else:
                    created_at = current_time
                    request_count = 1
                    self._stats["total_affinities"] += 1
                affinity_dict = {
                    "provider_id": provider_id,
                    "endpoint_id": endpoint_id,
                    "key_id": key_id,
                    "api_format": api_format,
                    "model_name": model_name,
                    "created_at": created_at,
                    "expire_at": expire_at,
                    "request_count": request_count,
                }
                await self._save_affinity_dict(cache_key, ttl, affinity_dict)
                logger.debug(f"设置缓存亲和性: key={affinity_key[:8]}..., api_format={api_format}, "
                             f"model={model_name}, provider={provider_id[:8]}..., endpoint={endpoint_id[:8]}..., "
                             f"provider_key={key_id[:8]}..., ttl={ttl}s")
        except Exception as e:
            logger.exception(f"设置缓存亲和性失败: {e}")

    async def invalidate_affinity(
        self,
        affinity_key: str,
        api_format: str,
        model_name: str,
        key_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        endpoint_id: Optional[str] = None,
    ) -> None:
        """
        Invalidate the affinity of an affinity key for a given API format and model.

        Args:
            affinity_key: affinity identifier (usually an API Key ID)
            api_format: API format (claude/openai)
            model_name: model name
            key_id: optional — only invalidate when the stored key matches
            provider_id: optional — only invalidate when the provider matches
            endpoint_id: optional — only invalidate when the endpoint matches
        """
        existing_affinity = await self.get_affinity(affinity_key, api_format, model_name)
        if not existing_affinity:
            return
        # Check the optional filter conditions
        should_invalidate = True
        if key_id and existing_affinity.key_id != key_id:
            should_invalidate = False
        if provider_id and existing_affinity.provider_id != provider_id:
            should_invalidate = False
        if endpoint_id and existing_affinity.endpoint_id != endpoint_id:
            should_invalidate = False
        if not should_invalidate:
            logger.debug(f"跳过失效: affinity_key={affinity_key[:8]}..., api_format={api_format}, "
                         f"model={model_name}, 过滤条件不匹配 (key={key_id}, provider={provider_id}, endpoint={endpoint_id})")
            return
        try:
            cache_key = self._get_cache_key(affinity_key, api_format, model_name)
            async with self._acquire_request_lock(cache_key):
                await self._delete_affinity_key(cache_key)
                self._stats["cache_invalidations"] += 1
                logger.debug(f"失效缓存亲和性: affinity_key={affinity_key[:8]}..., api_format={api_format}, "
                             f"model={model_name}, provider={existing_affinity.provider_id[:8]}..., "
                             f"endpoint={existing_affinity.endpoint_id[:8]}..., "
                             f"provider_key={existing_affinity.key_id[:8]}...")
        except Exception as e:
            logger.exception(f"删除缓存亲和性失败: {e}")

    async def invalidate_all_for_provider(self, provider_id: str) -> int:
        """
        Invalidate every affinity that points at the given provider.

        Used when a provider turns off caching support.

        Args:
            provider_id: Provider ID

        Returns:
            Number of invalidated affinities.
        """
        try:
            invalidated_count = 0
            if not self._is_memory_backend():
                # NOTE(review): KEYS walks the whole keyspace and can block
                # Redis; list_affinities uses SCAN — consider aligning.
                pattern = "cache_affinity:*"
                keys = await self.redis.keys(pattern)
            else:
                keys = list((await self._snapshot_memory_items()).keys())
            for key in keys:
                affinity_dict = await self._load_affinity_dict(key)
                if not affinity_dict:
                    continue
                if affinity_dict.get("provider_id") == provider_id:
                    await self._delete_affinity_key(key)
                    invalidated_count += 1
                    self._stats["cache_invalidations"] += 1
            if invalidated_count > 0:
                logger.debug(f"批量失效Provider缓存亲和性: provider={provider_id[:8]}..., "
                             f"失效数量={invalidated_count}")
            return invalidated_count
        except Exception as e:
            logger.exception(f"批量失效Provider缓存亲和性失败: {e}")
            return 0

    async def clear_all(self) -> int:
        """
        Clear every cache affinity (admin function).

        NOTE(review): L1 entries are not cleared here, so reads may still be
        served from L1 for up to the L1 TTL — confirm this is acceptable.

        Returns:
            Number of cleared records.
        """
        try:
            if not self._is_memory_backend():
                keys = await self.redis.keys("cache_affinity:*")
                if keys:
                    await self.redis.delete(*keys)
                    logger.debug(f"清除所有Redis缓存亲和性: {len(keys)}")
                    return len(keys)
                return 0
            lock = self._get_memory_lock()
            async with lock:
                count = len(self._memory_store)
                self._memory_store.clear()
            if count:
                logger.debug(f"清除所有内存缓存亲和性: {count}")
            return count
        except Exception as e:
            logger.exception(f"清除缓存亲和性失败: {e}")
            return 0

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the counters plus the derived hit rate."""
        cache_hit_rate = 0.0
        total_requests = self._stats["cache_hits"] + self._stats["cache_misses"]
        if total_requests > 0:
            cache_hit_rate = self._stats["cache_hits"] / total_requests
        storage_type = "redis" if not self._is_memory_backend() else "memory"
        return {
            "storage_type": storage_type,
            "total_affinities": self._stats["total_affinities"],
            "cache_hits": self._stats["cache_hits"],
            "cache_misses": self._stats["cache_misses"],
            "cache_hit_rate": cache_hit_rate,
            "cache_invalidations": self._stats["cache_invalidations"],
            "provider_switches": self._stats["provider_switches"],
            "key_switches": self._stats["key_switches"],
        }

    async def list_affinities(self) -> List[Dict[str, Any]]:
        """List all cache affinities.

        Each returned record contains:
        - affinity_key: affinity identifier (usually an API Key ID)
        - provider_id, endpoint_id, key_id: provider-side identifiers
        - api_format, model_name: API format and model name
        - created_at, expire_at, request_count: cache metadata
        """
        results: List[Dict[str, Any]] = []
        try:
            pattern = "cache_affinity:*"
            cursor = 0
            if not self._is_memory_backend():
                while True:
                    cursor, keys = await self.redis.scan(cursor=cursor, match=pattern, count=200)
                    if keys:
                        values = await self.redis.mget(*keys)
                        for cache_key, data in zip(keys, values):
                            if not data:
                                continue
                            try:
                                affinity = json.loads(data)
                                # Parse cache_affinity:{affinity_key}:{api_format}:{model_name}
                                # NOTE(review): assumes the client returns str keys
                                # (decode_responses=True); bytes keys would break
                                # split(":") — confirm client configuration.
                                parts = cache_key.split(":")
                                affinity_key_value = parts[1] if len(parts) > 1 else cache_key
                                api_format = (
                                    parts[2]
                                    if len(parts) > 2
                                    else affinity.get("api_format", "unknown")
                                )
                                model_name = (
                                    parts[3]
                                    if len(parts) > 3
                                    else affinity.get("model_name", "unknown")
                                )
                                affinity["affinity_key"] = affinity_key_value
                                if "api_format" not in affinity:
                                    affinity["api_format"] = api_format
                                if "model_name" not in affinity:
                                    affinity["model_name"] = model_name
                                results.append(affinity)
                            except json.JSONDecodeError as e:
                                logger.exception(f"解析缓存亲和性记录失败: {cache_key} - {e}")
                    if cursor == 0:
                        break
            else:
                snapshot = await self._snapshot_memory_items()
                expired_keys: List[str] = []
                current_time = time.time()
                for cache_key, affinity in snapshot.items():
                    if current_time > affinity["expire_at"]:
                        expired_keys.append(cache_key)
                        continue
                    # Parse cache_affinity:{affinity_key}:{api_format}:{model_name}
                    parts = cache_key.split(":")
                    affinity_key_value = parts[1] if len(parts) > 1 else cache_key
                    api_format = (
                        parts[2] if len(parts) > 2 else affinity.get("api_format", "unknown")
                    )
                    model_name = (
                        parts[3] if len(parts) > 3 else affinity.get("model_name", "unknown")
                    )
                    affinity_with_key = dict(affinity)
                    affinity_with_key["affinity_key"] = affinity_key_value
                    if "api_format" not in affinity_with_key:
                        affinity_with_key["api_format"] = api_format
                    if "model_name" not in affinity_with_key:
                        affinity_with_key["model_name"] = model_name
                    results.append(affinity_with_key)
                # Purge the expired keys from the memory store
                if expired_keys:
                    async with self._get_memory_lock():
                        for key in expired_keys:
                            self._memory_store.pop(key, None)
        except Exception as e:
            logger.exception(f"获取缓存亲和性列表失败: {e}")
        return results
# Global singleton instance
_affinity_manager: Optional[CacheAffinityManager] = None


async def get_affinity_manager(redis_client=None) -> CacheAffinityManager:
    """Return the process-wide CacheAffinityManager.

    Falls back to in-memory mode while no Redis client is available; when a
    Redis client shows up after a memory-mode instance was created, the
    manager is rebuilt on top of Redis.

    Args:
        redis_client: optional Redis client

    Returns:
        The shared CacheAffinityManager instance.
    """
    global _affinity_manager
    must_build = _affinity_manager is None or (
        bool(redis_client) and _affinity_manager.redis is None
    )
    if must_build:
        _affinity_manager = CacheAffinityManager(redis_client)
    return _affinity_manager

1316
src/services/cache/aware_scheduler.py vendored Normal file

File diff suppressed because it is too large Load Diff

330
src/services/cache/backend.py vendored Normal file
View File

@@ -0,0 +1,330 @@
"""
缓存后端抽象层
提供统一的缓存接口,支持多种后端实现:
1. LocalCache: 内存缓存(单实例,线程安全)
2. RedisCache: Redis 缓存(分布式)
使用场景:
- ModelMappingResolver: 模型映射与别名解析缓存
- ModelMapper: 模型映射缓存
- 其他需要缓存的服务
"""
import asyncio
import json
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, Optional
import redis.asyncio as aioredis
from src.core.logger import logger
from src.clients.redis_client import get_redis_client_sync
from src.core.logger import logger
class BaseCacheBackend(ABC):
"""缓存后端抽象基类"""
@abstractmethod
async def get(self, key: str) -> Optional[Any]:
"""获取缓存值"""
pass
@abstractmethod
async def set(self, key: str, value: Any, ttl: int = 300) -> None:
"""设置缓存值"""
pass
@abstractmethod
async def delete(self, key: str) -> None:
"""删除缓存值"""
pass
@abstractmethod
async def clear(self, pattern: Optional[str] = None) -> None:
"""清空缓存(支持模式匹配)"""
pass
@abstractmethod
async def exists(self, key: str) -> bool:
"""检查键是否存在"""
pass
class LocalCache(BaseCacheBackend):
"""本地内存缓存后端LRU + TTL线程安全"""
def __init__(self, max_size: int = 1000, default_ttl: int = 300):
"""
初始化本地缓存
Args:
max_size: 最大缓存条目数
default_ttl: 默认过期时间(秒)
"""
self._cache: OrderedDict = OrderedDict()
self._expiry: Dict[str, float] = {}
self._max_size = max_size
self._default_ttl = default_ttl
self._lock = asyncio.Lock()
async def get(self, key: str) -> Optional[Any]:
"""获取缓存值(线程安全)"""
async with self._lock:
if key not in self._cache:
return None
# 检查过期
if key in self._expiry and time.time() > self._expiry[key]:
# 过期,删除
del self._cache[key]
del self._expiry[key]
return None
# 更新访问顺序LRU
self._cache.move_to_end(key)
return self._cache[key]
async def set(self, key: str, value: Any, ttl: int = None) -> None:
"""设置缓存值(线程安全)"""
async with self._lock:
if ttl is None:
ttl = self._default_ttl
# 如果键已存在,更新访问顺序
if key in self._cache:
self._cache.move_to_end(key)
self._cache[key] = value
self._expiry[key] = time.time() + ttl
# 检查容量限制,淘汰最旧项
if len(self._cache) > self._max_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
if oldest_key in self._expiry:
del self._expiry[oldest_key]
async def delete(self, key: str) -> None:
"""删除缓存值(线程安全)"""
async with self._lock:
if key in self._cache:
del self._cache[key]
if key in self._expiry:
del self._expiry[key]
async def clear(self, pattern: Optional[str] = None) -> None:
"""清空缓存(线程安全)"""
async with self._lock:
if pattern is None:
# 清空所有
self._cache.clear()
self._expiry.clear()
else:
# 模式匹配删除(简单实现:支持前缀匹配)
prefix = pattern.rstrip("*")
keys_to_delete = [k for k in self._cache.keys() if k.startswith(prefix)]
for key in keys_to_delete:
del self._cache[key]
if key in self._expiry:
del self._expiry[key]
async def exists(self, key: str) -> bool:
"""检查键是否存在(线程安全)"""
async with self._lock:
if key not in self._cache:
return False
# 检查过期
if key in self._expiry and time.time() > self._expiry[key]:
del self._cache[key]
del self._expiry[key]
return False
return True
def get_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
"backend": "local",
"size": len(self._cache),
"max_size": self._max_size,
"default_ttl": self._default_ttl,
}
class RedisCache(BaseCacheBackend):
    """Redis cache backend (distributed)."""

    def __init__(
        self, redis_client: aioredis.Redis, key_prefix: str = "cache", default_ttl: int = 300
    ):
        """
        Initialize the Redis cache.

        Args:
            redis_client: Redis client instance
            key_prefix: prefix applied to every cache key
            default_ttl: default expiry in seconds
        """
        self._redis = redis_client
        self._key_prefix = key_prefix
        self._default_ttl = default_ttl

    def _make_key(self, key: str) -> str:
        """Build the fully-qualified Redis key."""
        return f"{self._key_prefix}:{key}"

    async def get(self, key: str) -> Optional[Any]:
        """Fetch a cached value, JSON-decoding it when possible."""
        try:
            redis_key = self._make_key(key)
            value = await self._redis.get(redis_key)
            if value is None:
                return None
            # Try JSON deserialization first
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                # Not JSON: return the raw value.
                # NOTE(review): depending on the client's decode_responses
                # setting this may be str or bytes — confirm.
                return value
        except Exception as e:
            logger.error(f"[RedisCache] 获取缓存失败: {key}, 错误: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Serialize *value* (JSON for containers/numbers, str otherwise) and store with TTL."""
        if ttl is None:
            ttl = self._default_ttl
        try:
            redis_key = self._make_key(key)
            # Serialize the value
            if isinstance(value, (dict, list, tuple)):
                serialized = json.dumps(value)
            elif isinstance(value, (int, float, bool)):
                serialized = json.dumps(value)
            else:
                serialized = str(value)
            await self._redis.setex(redis_key, ttl, serialized)
        except Exception as e:
            logger.error(f"[RedisCache] 设置缓存失败: {key}, 错误: {e}")

    async def delete(self, key: str) -> None:
        """Delete one cached value (errors are logged, not raised)."""
        try:
            redis_key = self._make_key(key)
            await self._redis.delete(redis_key)
        except Exception as e:
            logger.error(f"[RedisCache] 删除缓存失败: {key}, 错误: {e}")

    async def clear(self, pattern: Optional[str] = None) -> None:
        """Delete all prefixed keys (or those matching *pattern*) via incremental SCAN."""
        try:
            if pattern is None:
                # No pattern: clear every key under this prefix
                pattern = "*"
            redis_pattern = self._make_key(pattern)
            cursor = 0
            deleted_count = 0
            while True:
                cursor, keys = await self._redis.scan(cursor, match=redis_pattern, count=100)
                if keys:
                    await self._redis.delete(*keys)
                    deleted_count += len(keys)
                if cursor == 0:
                    break
            logger.info(f"[RedisCache] 清空缓存: {redis_pattern}, 删除 {deleted_count} 个键")
        except Exception as e:
            logger.error(f"[RedisCache] 清空缓存失败: {pattern}, 错误: {e}")

    async def exists(self, key: str) -> bool:
        """Return True when the key exists (False on errors)."""
        try:
            redis_key = self._make_key(key)
            return await self._redis.exists(redis_key) > 0
        except Exception as e:
            logger.error(f"[RedisCache] 检查键存在失败: {key}, 错误: {e}")
            return False

    async def publish_invalidation(self, channel: str, key: str) -> None:
        """Publish a cache-invalidation message (for cross-instance sync)."""
        try:
            message = json.dumps({"key": key, "timestamp": time.time()})
            await self._redis.publish(channel, message)
            logger.debug(f"[RedisCache] 发布缓存失效: {channel} -> {key}")
        except Exception as e:
            logger.error(f"[RedisCache] 发布缓存失效失败: {channel}, {key}, 错误: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Return backend type and configuration."""
        return {
            "backend": "redis",
            "key_prefix": self._key_prefix,
            "default_ttl": self._default_ttl,
        }
# Cache backend factory registry
_cache_backends: Dict[str, BaseCacheBackend] = {}


async def get_cache_backend(
    name: str, backend_type: str = "auto", max_size: int = 1000, ttl: int = 300
) -> BaseCacheBackend:
    """
    Return (and memoize) the cache backend registered under *name*.

    Args:
        name: cache name (distinguishes independent cache instances)
        backend_type: backend type (auto/local/redis)
        max_size: maximum capacity for LocalCache
        ttl: default expiry in seconds

    Returns:
        A BaseCacheBackend instance.
    """
    registry_key = f"{name}:{backend_type}"
    cached = _cache_backends.get(registry_key)
    if cached is not None:
        return cached

    def _make_local() -> LocalCache:
        return LocalCache(max_size=max_size, default_ttl=ttl)

    def _make_redis(client) -> RedisCache:
        return RedisCache(redis_client=client, key_prefix=name, default_ttl=ttl)

    if backend_type == "redis":
        # Redis requested explicitly; degrade to local when unavailable
        client = get_redis_client_sync()
        if client is None:
            logger.warning(f"[CacheBackend] Redis 未初始化,{name} 降级为本地缓存")
            backend = _make_local()
        else:
            backend = _make_redis(client)
            logger.info(f"[CacheBackend] {name} 使用 Redis 缓存")
    elif backend_type == "local":
        # Local cache forced
        backend = _make_local()
        logger.info(f"[CacheBackend] {name} 使用本地缓存")
    else:  # auto
        # Prefer Redis, fall back to local
        client = get_redis_client_sync()
        if client is not None:
            backend = _make_redis(client)
            logger.debug(f"[CacheBackend] {name} 自动选择 Redis 缓存")
        else:
            backend = _make_local()
            logger.debug(f"[CacheBackend] {name} 自动选择本地缓存Redis 不可用)")

    _cache_backends[registry_key] = backend
    return backend

125
src/services/cache/invalidation.py vendored Normal file
View File

@@ -0,0 +1,125 @@
"""
缓存失效服务
统一管理各种缓存的失效逻辑,支持:
1. GlobalModel 变更时失效相关缓存
2. ModelMapping 变更时失效别名/降级缓存
3. Model 变更时失效模型映射缓存
4. 支持同步和异步缓存后端
"""
import asyncio
from typing import Optional
from src.core.logger import logger
from src.core.logger import logger
class CacheInvalidationService:
    """
    Cache invalidation service.

    Single entry point for clearing dependent caches when database models
    change (GlobalModel / ModelMapping / Model). Supports a mapping resolver
    (async invalidation) and any number of ModelMapper instances (sync).
    """

    def __init__(self):
        """Initialize with no registered resolver/mappers."""
        self._mapping_resolver = None
        self._model_mappers = []  # there may be several ModelMapperMiddleware instances
        # Strong references to in-flight invalidation tasks. The event loop
        # keeps only weak references to tasks, so without this set a pending
        # invalidation task could be garbage-collected before it runs.
        self._background_tasks = set()

    def _spawn(self, coro) -> None:
        """Schedule *coro* as a background task and keep it alive until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def set_mapping_resolver(self, mapping_resolver):
        """Register the model-mapping resolver instance."""
        self._mapping_resolver = mapping_resolver
        logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")

    def register_model_mapper(self, model_mapper):
        """Register a ModelMapper instance (idempotent)."""
        if model_mapper not in self._model_mappers:
            self._model_mappers.append(model_mapper)
            logger.debug(f"[CacheInvalidation] ModelMapper 已注册 (实例: {id(model_mapper)},总数: {len(self._model_mappers)})")

    def on_global_model_changed(self, model_name: str):
        """
        Invalidate caches after a GlobalModel change.

        Args:
            model_name: the changed GlobalModel.name
        """
        logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
        # Asynchronously invalidate the resolver's cache
        if self._mapping_resolver:
            self._spawn(self._mapping_resolver.invalidate_global_model_cache())
        # Invalidate every ModelMapper cache related to this model
        for mapper in self._model_mappers:
            # Clear everything: we cannot tell which providers used this model
            mapper.clear_cache()
            logger.debug("[CacheInvalidation] 已清空 ModelMapper 缓存")

    def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
        """
        Invalidate caches after a ModelMapping change.

        Args:
            source_model: the changed source model name
            provider_id: related provider (None means global mapping)
        """
        logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
        if self._mapping_resolver:
            self._spawn(
                self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
            )
        for mapper in self._model_mappers:
            if provider_id:
                mapper.refresh_cache(provider_id)
            else:
                mapper.clear_cache()

    def on_model_changed(self, provider_id: str, global_model_id: str):
        """
        Invalidate caches after a Model change.

        Args:
            provider_id: Provider ID
            global_model_id: GlobalModel ID
        """
        logger.info(f"[CacheInvalidation] Model 变更: provider={provider_id[:8]}..., "
                    f"global_model={global_model_id[:8]}...")
        # Only the affected provider's ModelMapper cache needs refreshing
        for mapper in self._model_mappers:
            mapper.refresh_cache(provider_id)

    def clear_all_caches(self):
        """Clear every registered cache (resolver asynchronously, mappers synchronously)."""
        logger.info("[CacheInvalidation] 清空所有缓存")
        if self._mapping_resolver:
            self._spawn(self._mapping_resolver.clear_cache())
        for mapper in self._model_mappers:
            mapper.clear_cache()
# Global singleton instance
_cache_invalidation_service: Optional[CacheInvalidationService] = None


def get_cache_invalidation_service() -> CacheInvalidationService:
    """Return the process-wide CacheInvalidationService, creating it lazily.

    Returns:
        The shared CacheInvalidationService instance.
    """
    global _cache_invalidation_service
    service = _cache_invalidation_service
    if service is None:
        service = CacheInvalidationService()
        logger.debug("[CacheInvalidation] 初始化缓存失效服务")
        _cache_invalidation_service = service
    return service

325
src/services/cache/model_cache.py vendored Normal file
View File

@@ -0,0 +1,325 @@
"""
Model 映射缓存服务 - 减少模型映射和别名查询
"""
from typing import Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping
class ModelCacheService:
"""Model 映射缓存服务"""
# 缓存 TTL- 使用统一常量
CACHE_TTL = CacheTTL.MODEL
@staticmethod
async def get_model_by_id(db: Session, model_id: str) -> Optional[Model]:
"""
获取 Model带缓存
Args:
db: 数据库会话
model_id: Model ID
Returns:
Model 对象或 None
"""
cache_key = f"model:id:{model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Model 缓存命中: {model_id}")
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
model = db.query(Model).filter(Model.id == model_id).first()
# 3. 写入缓存
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(f"Model 已缓存: {model_id}")
return model
@staticmethod
async def get_global_model_by_id(db: Session, global_model_id: str) -> Optional[GlobalModel]:
"""
获取 GlobalModel带缓存
Args:
db: 数据库会话
global_model_id: GlobalModel ID
Returns:
GlobalModel 对象或 None
"""
cache_key = f"global_model:id:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"GlobalModel 缓存命中: {global_model_id}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 缓存未命中,查询数据库
global_model = db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
# 3. 写入缓存
if global_model:
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 已缓存: {global_model_id}")
return global_model
@staticmethod
async def get_model_by_provider_and_global_model(
db: Session, provider_id: str, global_model_id: str
) -> Optional[Model]:
"""
通过 Provider ID 和 GlobalModel ID 获取 Model带缓存
Args:
db: 数据库会话
provider_id: Provider ID
global_model_id: GlobalModel ID
Returns:
Model 对象或 None
"""
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.global_model_id == global_model_id,
Model.is_active == True,
)
.first()
)
# 3. 写入缓存
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
return model
@staticmethod
async def get_global_model_by_name(db: Session, name: str) -> Optional[GlobalModel]:
"""
通过名称获取 GlobalModel带缓存
Args:
db: 数据库会话
name: GlobalModel 名称
Returns:
GlobalModel 对象或 None
"""
cache_key = f"global_model:name:{name}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"GlobalModel 缓存命中(名称): {name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 缓存未命中,查询数据库
global_model = db.query(GlobalModel).filter(GlobalModel.name == name).first()
# 3. 写入缓存
if global_model:
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 已缓存(名称): {name}")
return global_model
@staticmethod
async def resolve_alias(
    db: Session, source_model: str, provider_id: Optional[str] = None
) -> Optional[str]:
    """Resolve a model alias to its target GlobalModel ID (cache-first).

    Args:
        db: Database session.
        source_model: Source model name or alias.
        provider_id: Provider ID (optional, for provider-specific aliases).

    Returns:
        Target GlobalModel ID, or None when no mapping exists.
    """
    # Provider-specific aliases and global aliases use separate cache keys.
    if provider_id:
        cache_key = f"alias:provider:{provider_id}:{source_model}"
    else:
        cache_key = f"alias:global:{source_model}"

    # 1. Try the cache. Negative results are stored as "" (empty string)
    #    because a cached None is indistinguishable from a cache miss: the
    #    previous `if cached_result:` check made negative entries useless,
    #    so every unmapped name hit the database on every call.
    cached_result = await CacheService.get(cache_key)
    if cached_result is not None:
        logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
        return cached_result or None  # "" sentinel -> no mapping

    # 2. Cache miss - query the database.
    query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
    if provider_id:
        # Provider-specific alias takes precedence.
        query = query.filter(ModelMapping.provider_id == provider_id)
    else:
        # Global alias (provider_id IS NULL).
        query = query.filter(ModelMapping.provider_id.is_(None))
    mapping = query.first()

    # 3. Write through; store "" for "no mapping" so the negative result is
    #    served from cache until the TTL expires (real IDs are never empty).
    target_global_model_id = mapping.target_global_model_id if mapping else None
    await CacheService.set(
        cache_key, target_global_model_id or "", ttl_seconds=ModelCacheService.CACHE_TTL
    )
    if mapping:
        logger.debug(f"别名已缓存: {source_model}{target_global_model_id}")
    return target_global_model_id
@staticmethod
async def invalidate_model_cache(
    model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
):
    """Evict the cached Model entries.

    Args:
        model_id: Model ID.
        provider_id: Provider ID (needed to evict the provider_global entry).
        global_model_id: GlobalModel ID (needed to evict the provider_global entry).
    """
    # The ID-keyed entry is always evicted.
    await CacheService.delete(f"model:id:{model_id}")

    # Without both extra IDs, the provider_global entry cannot be addressed.
    if not (provider_id and global_model_id):
        logger.debug(f"Model 缓存已清除: {model_id}")
        return

    await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
    logger.debug(f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...")
@staticmethod
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None):
    """Evict the cached GlobalModel entries (by ID, and by name when given)."""
    stale_keys = [f"global_model:id:{global_model_id}"]
    if name:
        stale_keys.append(f"global_model:name:{name}")
    for key in stale_keys:
        await CacheService.delete(key)
    logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
@staticmethod
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
    """Evict one alias cache entry (provider-specific when provider_id is given)."""
    cache_key = (
        f"alias:provider:{provider_id}:{source_model}"
        if provider_id
        else f"alias:global:{source_model}"
    )
    await CacheService.delete(cache_key)
    logger.debug(f"别名缓存已清除: {source_model}")
@staticmethod
def _model_to_dict(model: Model) -> dict:
    """Serialize a Model row into a cache-safe dict.

    The price check must be `is not None`: the previous truthiness test
    (`if model.price_per_request`) silently turned a legitimate price of 0
    into None during the cache round trip.
    """
    return {
        "id": model.id,
        "provider_id": model.provider_id,
        "global_model_id": model.global_model_id,
        "provider_model_name": model.provider_model_name,
        "is_active": model.is_active,
        # Tolerate instances that predate the is_available attribute.
        "is_available": model.is_available if hasattr(model, "is_available") else True,
        # Decimal -> float for JSON-serializable caching; keep 0 as 0.0.
        "price_per_request": (
            float(model.price_per_request) if model.price_per_request is not None else None
        ),
        "tiered_pricing": model.tiered_pricing,
        "supports_vision": model.supports_vision,
        "supports_function_calling": model.supports_function_calling,
        "supports_streaming": model.supports_streaming,
        "supports_extended_thinking": model.supports_extended_thinking,
        "config": model.config,
    }
@staticmethod
def _dict_to_model(model_dict: dict) -> Model:
    """Rebuild a detached Model instance from its cached dict form."""
    # Required fields raise KeyError if the cached payload is malformed.
    kwargs = {
        key: model_dict[key]
        for key in ("id", "provider_id", "global_model_id", "provider_model_name", "is_active")
    }
    # Optional fields fall back to their serialization defaults.
    optional_defaults = {
        "is_available": True,
        "price_per_request": None,
        "tiered_pricing": None,
        "supports_vision": None,
        "supports_function_calling": None,
        "supports_streaming": None,
        "supports_extended_thinking": None,
        "config": None,
    }
    for key, default in optional_defaults.items():
        kwargs[key] = model_dict.get(key, default)
    return Model(**kwargs)
@staticmethod
def _global_model_to_dict(global_model: GlobalModel) -> dict:
    """Serialize a GlobalModel row into a cache-safe dict."""
    fields = (
        "id",
        "name",
        "display_name",
        "family",
        "group_id",
        "supports_vision",
        "supports_thinking",
        "context_window",
        "max_output_tokens",
        "is_active",
        "description",
    )
    return {field: getattr(global_model, field) for field in fields}
@staticmethod
def _dict_to_global_model(global_model_dict: dict) -> GlobalModel:
    """Rebuild a detached GlobalModel instance from its cached dict form."""
    # id and name are mandatory; everything else has a serialization default.
    optional_defaults = {
        "display_name": None,
        "family": None,
        "group_id": None,
        "supports_vision": False,
        "supports_thinking": False,
        "context_window": None,
        "max_output_tokens": None,
        "is_active": True,
        "description": None,
    }
    kwargs = {"id": global_model_dict["id"], "name": global_model_dict["name"]}
    for key, default in optional_defaults.items():
        kwargs[key] = global_model_dict.get(key, default)
    return GlobalModel(**kwargs)

254
src/services/cache/provider_cache.py vendored Normal file
View File

@@ -0,0 +1,254 @@
"""
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
"""
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
class ProviderCacheService:
    """Read-through cache for Provider configuration (Provider / Endpoint / APIKey).

    Every getter checks CacheService first and falls back to the database on a
    miss, writing the row back into the cache. Objects rebuilt from cached
    dicts are *detached* (not bound to any Session): use ``db.merge()`` or a
    fresh query before mutating them.
    """

    # Cache TTL - unified constant shared with the other config caches.
    CACHE_TTL = CacheTTL.PROVIDER

    @staticmethod
    async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
        """Fetch a Provider by ID, cache-first.

        Args:
            db: Database session.
            provider_id: Provider ID.

        Returns:
            Provider object (detached when served from cache) or None.
        """
        cache_key = CacheKeys.provider_by_id(provider_id)
        # 1. Try the cache.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Provider 缓存命中: {provider_id}")
            return ProviderCacheService._dict_to_provider(cached_data)
        # 2. Cache miss - query the database.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        # 3. Write through on success; "not found" is NOT negatively cached.
        if provider:
            provider_dict = ProviderCacheService._provider_to_dict(provider)
            await CacheService.set(
                cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Provider 已缓存: {provider_id}")
        return provider

    @staticmethod
    async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
        """Fetch a ProviderEndpoint by ID, cache-first.

        Args:
            db: Database session.
            endpoint_id: Endpoint ID.

        Returns:
            ProviderEndpoint object (detached when served from cache) or None.
        """
        cache_key = CacheKeys.endpoint_by_id(endpoint_id)
        # 1. Try the cache.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
            return ProviderCacheService._dict_to_endpoint(cached_data)
        # 2. Cache miss - query the database.
        endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
        # 3. Write through on success.
        if endpoint:
            endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
            await CacheService.set(
                cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Endpoint 已缓存: {endpoint_id}")
        return endpoint

    @staticmethod
    async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
        """Fetch a ProviderAPIKey by ID, cache-first.

        Args:
            db: Database session.
            api_key_id: API Key ID.

        Returns:
            ProviderAPIKey object (detached when served from cache) or None.

        NOTE(review): the cached dict includes ``key_value`` (the secret), so
        the configured cache backend stores credentials - confirm it is
        private/trusted.
        """
        cache_key = CacheKeys.api_key_by_id(api_key_id)
        # 1. Try the cache.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"API Key 缓存命中: {api_key_id}")
            return ProviderCacheService._dict_to_api_key(cached_data)
        # 2. Cache miss - query the database.
        api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
        # 3. Write through on success.
        if api_key:
            api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
            await CacheService.set(
                cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"API Key 已缓存: {api_key_id}")
        return api_key

    @staticmethod
    async def invalidate_provider_cache(provider_id: str):
        """Evict the cached Provider entry.

        Args:
            provider_id: Provider ID.
        """
        await CacheService.delete(CacheKeys.provider_by_id(provider_id))
        logger.debug(f"Provider 缓存已清除: {provider_id}")

    @staticmethod
    async def invalidate_endpoint_cache(endpoint_id: str):
        """Evict the cached Endpoint entry.

        Args:
            endpoint_id: Endpoint ID.
        """
        await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
        logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")

    @staticmethod
    async def invalidate_api_key_cache(api_key_id: str):
        """Evict the cached API Key entry.

        Args:
            api_key_id: API Key ID.
        """
        await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
        logger.debug(f"API Key 缓存已清除: {api_key_id}")

    @staticmethod
    def _provider_to_dict(provider: Provider) -> dict:
        """Serialize a Provider row into a cache-safe dict (datetime -> ISO string)."""
        return {
            "id": provider.id,
            "name": provider.name,
            "api_format": provider.api_format,
            "base_url": provider.base_url,
            "is_active": provider.is_active,
            "priority": provider.priority,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
            "config": provider.config,
            "description": provider.description,
        }

    @staticmethod
    def _dict_to_provider(provider_dict: dict) -> Provider:
        """Rebuild a Provider from its cached dict (detached object, not in a Session)."""
        from datetime import datetime

        provider = Provider(
            id=provider_dict["id"],
            name=provider_dict["name"],
            api_format=provider_dict["api_format"],
            base_url=provider_dict.get("base_url"),
            is_active=provider_dict["is_active"],
            priority=provider_dict.get("priority", 0),
            rpm_limit=provider_dict.get("rpm_limit"),
            rpm_used=provider_dict.get("rpm_used", 0),
            config=provider_dict.get("config"),
            description=provider_dict.get("description"),
        )
        # Restore the datetime field from its ISO-8601 string form.
        if provider_dict.get("rpm_reset_at"):
            provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
        return provider

    @staticmethod
    def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
        """Serialize an Endpoint row into a cache-safe dict."""
        return {
            "id": endpoint.id,
            "provider_id": endpoint.provider_id,
            "name": endpoint.name,
            "base_url": endpoint.base_url,
            "is_active": endpoint.is_active,
            "priority": endpoint.priority,
            "weight": endpoint.weight,
            "custom_path": endpoint.custom_path,
            "config": endpoint.config,
        }

    @staticmethod
    def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
        """Rebuild an Endpoint from its cached dict (detached object)."""
        endpoint = ProviderEndpoint(
            id=endpoint_dict["id"],
            provider_id=endpoint_dict["provider_id"],
            name=endpoint_dict["name"],
            base_url=endpoint_dict["base_url"],
            is_active=endpoint_dict["is_active"],
            priority=endpoint_dict.get("priority", 0),
            weight=endpoint_dict.get("weight", 1.0),
            custom_path=endpoint_dict.get("custom_path"),
            config=endpoint_dict.get("config"),
        )
        return endpoint

    @staticmethod
    def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
        """Serialize an API Key row (including the secret value) into a dict."""
        return {
            "id": api_key.id,
            "endpoint_id": api_key.endpoint_id,
            "key_value": api_key.key_value,
            "is_active": api_key.is_active,
            "max_rpm": api_key.max_rpm,
            "current_rpm": api_key.current_rpm,
            "health_score": api_key.health_score,
            "circuit_breaker_state": api_key.circuit_breaker_state,
            "adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
        }

    @staticmethod
    def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
        """Rebuild an API Key from its cached dict (detached object)."""
        api_key = ProviderAPIKey(
            id=api_key_dict["id"],
            endpoint_id=api_key_dict["endpoint_id"],
            key_value=api_key_dict["key_value"],
            is_active=api_key_dict["is_active"],
            max_rpm=api_key_dict.get("max_rpm"),
            current_rpm=api_key_dict.get("current_rpm", 0),
            health_score=api_key_dict.get("health_score", 1.0),
            circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
            adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
        )
        return api_key

209
src/services/cache/sync.py vendored Normal file
View File

@@ -0,0 +1,209 @@
"""
缓存同步服务Redis Pub/Sub
提供分布式缓存失效同步功能,用于多实例部署场景。
当一个实例修改数据并失效本地缓存时,通过 Redis pub/sub 通知其他实例同步失效。
使用场景:
1. 多实例部署时,确保所有实例的缓存一致性
2. GlobalModel/ModelMapping 变更时,同步失效所有实例的缓存
"""
import asyncio
import json
from typing import Callable, Dict, Optional
import redis.asyncio as aioredis
from src.core.logger import logger
from src.clients.redis_client import get_redis_client_sync
from src.core.logger import logger
class CacheSyncService:
    """Distributed cache-invalidation sync over Redis pub/sub.

    When one instance changes data and evicts its local cache, it publishes a
    message; every subscribed instance receives it and evicts its own cache,
    keeping a multi-instance deployment consistent.
    """

    # Redis channel names, one per invalidation scope.
    CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
    CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
    CHANNEL_MODEL = "cache:invalidate:model"
    CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"

    def __init__(self, redis_client: aioredis.Redis):
        """Initialize the sync service.

        Args:
            redis_client: Redis client used for both publish and subscribe.
        """
        self._redis = redis_client
        self._pubsub: Optional[aioredis.client.PubSub] = None  # created in start()
        self._listener_task: Optional[asyncio.Task] = None  # background listener
        self._handlers: Dict[str, Callable] = {}  # channel name -> async handler
        self._running = False

    async def start(self):
        """Subscribe to all invalidation channels and spawn the listener task.

        Idempotent: a second call while running only logs a warning.
        """
        if self._running:
            logger.warning("[CacheSync] 服务已在运行")
            return
        try:
            self._pubsub = self._redis.pubsub()
            # Subscribe to every cache-invalidation channel in one call.
            await self._pubsub.subscribe(
                self.CHANNEL_GLOBAL_MODEL,
                self.CHANNEL_MODEL_MAPPING,
                self.CHANNEL_MODEL,
                self.CHANNEL_CLEAR_ALL,
            )
            # Run the listener in the background.
            self._listener_task = asyncio.create_task(self._listen())
            self._running = True
            logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
                f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
                f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
        except Exception as e:
            logger.error(f"[CacheSync] 启动失败: {e}")
            raise

    async def stop(self):
        """Cancel the listener task and unsubscribe; no-op when not running."""
        if not self._running:
            return
        self._running = False
        # Cancel the listener task and wait for it to unwind.
        if self._listener_task:
            self._listener_task.cancel()
            try:
                await self._listener_task
            except asyncio.CancelledError:
                pass
        # Unsubscribe and release the pub/sub connection.
        if self._pubsub:
            await self._pubsub.unsubscribe()
            await self._pubsub.close()
        logger.info("[CacheSync] 缓存同步服务已停止")

    def register_handler(self, channel: str, handler: Callable):
        """Register the invalidation handler for a channel.

        Args:
            channel: Redis channel name.
            handler: Async callable invoked with the decoded JSON payload.
                One handler per channel; a later registration replaces an
                earlier one.
        """
        self._handlers[channel] = handler
        logger.debug(f"[CacheSync] 注册处理器: {channel}")

    async def _listen(self):
        """Consume pub/sub messages and dispatch them to registered handlers.

        NOTE(review): ``message["channel"]`` / ``message["data"]`` are bytes
        unless the Redis client was created with ``decode_responses=True``;
        with bytes channels the str-keyed handler lookup below would never
        match - confirm the client configuration.
        """
        logger.info("[CacheSync] 开始监听缓存失效消息")
        try:
            async for message in self._pubsub.listen():
                # Skip subscribe/unsubscribe confirmations; only real messages.
                if message["type"] == "message":
                    channel = message["channel"]
                    data = message["data"]
                    # Decode the JSON payload and dispatch.
                    try:
                        payload = json.loads(data)
                        logger.debug(f"[CacheSync] 收到消息: {channel} -> {payload}")
                        # Invoke the handler registered for this channel.
                        if channel in self._handlers:
                            handler = self._handlers[channel]
                            await handler(payload)
                        else:
                            logger.warning(f"[CacheSync] 未找到处理器: {channel}")
                    except json.JSONDecodeError as e:
                        logger.error(f"[CacheSync] 消息解析失败: {data}, 错误: {e}")
                    except Exception as e:
                        # Handler failures are logged but never kill the listener.
                        logger.error(f"[CacheSync] 处理消息失败: {channel}, 错误: {e}")
        except asyncio.CancelledError:
            logger.info("[CacheSync] 监听任务已取消")
        except Exception as e:
            logger.error(f"[CacheSync] 监听失败: {e}")

    async def publish_global_model_changed(self, model_name: str):
        """Broadcast that a GlobalModel changed."""
        await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})

    async def publish_model_mapping_changed(
        self, source_model: str, provider_id: Optional[str] = None
    ):
        """Broadcast that a ModelMapping (alias) changed."""
        await self._publish(
            self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
        )

    async def publish_model_changed(self, provider_id: str, global_model_id: str):
        """Broadcast that a provider-level Model changed."""
        await self._publish(
            self.CHANNEL_MODEL, {"provider_id": provider_id, "global_model_id": global_model_id}
        )

    async def publish_clear_all(self):
        """Broadcast a full cache flush."""
        await self._publish(self.CHANNEL_CLEAR_ALL, {})

    async def _publish(self, channel: str, data: dict):
        """Publish ``data`` as JSON on ``channel``; errors are logged, not raised."""
        try:
            message = json.dumps(data)
            await self._redis.publish(channel, message)
            logger.debug(f"[CacheSync] 发布消息: {channel} -> {data}")
        except Exception as e:
            logger.error(f"[CacheSync] 发布消息失败: {channel}, 错误: {e}")
# Process-local singleton (one per worker process, initialized lazily).
_cache_sync_service: Optional[CacheSyncService] = None
async def get_cache_sync_service(
    redis_client: Optional[aioredis.Redis] = None,
) -> Optional[CacheSyncService]:
    """Return the process-wide CacheSyncService singleton, creating it lazily.

    The returned service is NOT started automatically - call ``start()``
    separately. When Redis is unavailable, None is returned and the singleton
    stays unset, so a later call retries initialization.

    Args:
        redis_client: Redis client; only consulted on first creation. When
            omitted, the globally-configured client is used.

    Returns:
        The CacheSyncService instance, or None when Redis is unavailable
        (distributed cache sync disabled).
    """
    global _cache_sync_service
    if _cache_sync_service is None:
        if redis_client is None:
            # Fall back to the globally-configured Redis client.
            redis_client = get_redis_client_sync()
            if redis_client is None:
                logger.warning("[CacheSync] Redis 不可用,分布式缓存同步已禁用")
                return None
        _cache_sync_service = CacheSyncService(redis_client)
        logger.info("[CacheSync] 缓存同步服务已初始化")
    return _cache_sync_service
async def close_cache_sync_service() -> None:
    """Stop the singleton cache-sync service (if any) and drop the reference."""
    global _cache_sync_service
    service = _cache_sync_service
    if service is None:
        return
    await service.stop()
    _cache_sync_service = None

155
src/services/cache/user_cache.py vendored Normal file
View File

@@ -0,0 +1,155 @@
"""
用户缓存服务 - 减少数据库查询
"""
from typing import Optional
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheKeys, CacheService
from src.core.logger import logger
from src.models.database import User
class UserCacheService:
    """Read-through cache for User rows, keyed by user ID and by email."""

    # Cache TTL - unified constant.
    CACHE_TTL = CacheTTL.USER

    @staticmethod
    async def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
        """Fetch a user by ID, cache-first.

        Args:
            db: Database session.
            user_id: User ID.

        Returns:
            User object (detached when served from cache) or None.
        """
        cache_key = CacheKeys.user_by_id(user_id)
        # 1. Try the cache.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"用户缓存命中: {user_id}")
            # Rebuild a detached User object from the cached dict.
            return UserCacheService._dict_to_user(db, cached_data)
        # 2. Cache miss - query the database.
        user = db.query(User).filter(User.id == user_id).first()
        # 3. Write through on success; "not found" is NOT negatively cached.
        if user:
            user_dict = UserCacheService._user_to_dict(user)
            await CacheService.set(cache_key, user_dict, ttl_seconds=UserCacheService.CACHE_TTL)
            logger.debug(f"用户已缓存: {user_id}")
        return user

    @staticmethod
    async def get_user_by_email(db: Session, email: str) -> Optional[User]:
        """Fetch a user by email, cache-first.

        Args:
            db: Database session.
            email: User email address.

        Returns:
            User object (detached when served from cache) or None.
        """
        cache_key = CacheKeys.user_by_email(email)
        # 1. Try the cache.
        cached_data = await CacheService.get(cache_key)
        if cached_data:
            logger.debug(f"用户缓存命中(邮箱): {email}")
            return UserCacheService._dict_to_user(db, cached_data)
        # 2. Cache miss - query the database.
        user = db.query(User).filter(User.email == email).first()
        # 3. Write through on success.
        if user:
            user_dict = UserCacheService._user_to_dict(user)
            await CacheService.set(cache_key, user_dict, ttl_seconds=UserCacheService.CACHE_TTL)
            logger.debug(f"用户已缓存(邮箱): {email}")
        return user

    @staticmethod
    async def invalidate_user_cache(user_id: str, email: Optional[str] = None):
        """Evict a user's cache entries.

        Args:
            user_id: User ID.
            email: User email (optional). When omitted, the email-keyed entry
                stays until its TTL expires.
        """
        # Delete the ID-keyed entry.
        await CacheService.delete(CacheKeys.user_by_id(user_id))
        # Delete the email-keyed entry when the caller knows the email.
        if email:
            await CacheService.delete(CacheKeys.user_by_email(email))
        logger.debug(f"用户缓存已清除: {user_id}")

    @staticmethod
    def _user_to_dict(user: User) -> dict:
        """Serialize a User row into a cache-safe dict (Decimal -> float, datetime -> ISO)."""
        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value if user.role else None,
            "is_active": user.is_active,
            "quota_usd": float(user.quota_usd) if user.quota_usd is not None else None,
            # NOTE(review): assumes used_usd is never None - confirm the column is NOT NULL.
            "used_usd": float(user.used_usd),
            "created_at": user.created_at.isoformat() if user.created_at else None,
            "last_login_at": user.last_login_at.isoformat() if user.last_login_at else None,
            "model_capability_settings": user.model_capability_settings,
        }

    @staticmethod
    def _dict_to_user(db: Session, user_dict: dict) -> User:
        """Rebuild a User object from its cached dict.

        The result is a *detached* object, not attached to any Session; use
        ``db.merge()`` or re-query before persisting changes.

        NOTE(review): ``db`` is currently unused - kept for signature
        stability with existing callers.
        """
        from datetime import datetime

        from src.models.database import UserRole

        user = User(
            id=user_dict["id"],
            email=user_dict["email"],
            username=user_dict["username"],
            is_active=user_dict["is_active"],
            used_usd=user_dict["used_usd"],
        )
        # Restore optional fields only when present in the cached dict.
        if user_dict.get("role"):
            user.role = UserRole(user_dict["role"])
        if user_dict.get("quota_usd") is not None:
            user.quota_usd = user_dict["quota_usd"]
        if user_dict.get("created_at"):
            user.created_at = datetime.fromisoformat(user_dict["created_at"])
        if user_dict.get("last_login_at"):
            user.last_login_at = datetime.fromisoformat(user_dict["last_login_at"])
        if user_dict.get("model_capability_settings") is not None:
            user.model_capability_settings = user_dict["model_capability_settings"]
        return user