# Aether/src/services/cache/aware_scheduler.py
"""
Cache-Aware Scheduler

Responsibilities:
1. Centralize the selection logic for Provider / Endpoint / Key.
2. Integrate cache-affinity management: prefer the Provider+Key that already holds a cache.
3. Coordinate concurrency control with cache priority.
4. Implement failover: first switch Keys within the same Endpoint, then switch Providers by provider_priority.

Core design:
===============
1. A user's first request: pick the best Provider+Endpoint+Key by provider_priority.
2. Subsequent requests from the same user:
   - Prefer the cached Endpoint+Key (to exploit prompt caching).
   - If the cached Key is at its concurrency limit, try other Keys on the same Endpoint.
   - If the Endpoint is unavailable, switch to another Provider by provider_priority.
3. Concurrency control (dynamic reservation):
   - Probe phase: a low reservation (10%) lets the system quickly learn the real concurrency limit.
   - Stable phase: the reservation ratio is adjusted dynamically (10%-35%) based on confidence and load.
   - Confidence factors: consecutive successes, time since the last 429 cooldown adjustment, historical stability.
   - Cached users may use all slots; new users may only use (1 - reservation ratio) of the slots.
4. Failover:
   - Key failure: switch to another Key within the same Endpoint (checking model support).
   - Endpoint failure: switch to another Provider by provider_priority.
   - Note: different Endpoints speak mutually incompatible protocols, so Endpoints are never switched within the same Provider.
   - Invalidate the cache affinity to avoid repeatedly selecting the failed resource.
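
Illustrative usage (a minimal sketch; the caller-side session, Redis client and model name
shown here are assumptions, not part of this module):

    scheduler = await get_cache_aware_scheduler(redis_client)
    provider, endpoint, key = await scheduler.select_with_cache_affinity(
        db=db_session,                      # the current request's SQLAlchemy Session
        affinity_key=str(user_api_key.id),  # a stable per-user identifier
        api_format="openai",
        model_name="gpt-4o",                # hypothetical model name
    )
    # provider/endpoint/key are then used to dispatch the upstream request;
    # ProviderNotAvailableException is raised when no candidate is usable.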
"""
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from src.core.logger import logger
from sqlalchemy.orm import Session, selectinload
from src.core.enums import APIFormat
from src.core.exceptions import ProviderNotAvailableException
from src.models.database import (
ApiKey,
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
)
from src.services.cache.affinity_manager import (
CacheAffinityManager,
get_affinity_manager,
)
from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_reservation import (
AdaptiveReservationManager,
ReservationResult,
get_adaptive_reservation_manager,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
@dataclass
class ProviderCandidate:
"""A candidate Provider/Endpoint/Key combination, plus whether it hit the cache affinity."""
provider: Provider
endpoint: ProviderEndpoint
key: ProviderAPIKey
is_cached: bool = False
is_skipped: bool = False  # whether this candidate was skipped
skip_reason: Optional[str] = None  # reason for skipping
@dataclass
class ConcurrencySnapshot:
endpoint_current: int
endpoint_limit: Optional[int]
key_current: int
key_limit: Optional[int]
is_cached_user: bool = False
# Dynamic reservation info
reservation_ratio: float = 0.0
reservation_phase: str = "unknown"
reservation_confidence: float = 0.0
def describe(self) -> str:
endpoint_limit_text = str(self.endpoint_limit) if self.endpoint_limit is not None else "inf"
key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
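# Example output (illustrative values): "endpoint=3/10, key=2/8, cached=True, reserve=20%(stable)"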
return (
f"endpoint={self.endpoint_current}/{endpoint_limit_text}, "
f"key={self.key_current}/{key_limit_text}, "
f"cached={self.is_cached_user}, "
f"reserve={reservation_text}({self.reservation_phase})"
)
class CacheAwareScheduler:
"""
Cache-aware scheduler.

This is the core component for provider selection. It combines:
- the Provider / Endpoint / Key three-level architecture
- cache-affinity management
- concurrency control with dynamic reservation
- health monitoring
"""
# Static default (the effective value is computed dynamically by AdaptiveReservationManager)
CACHE_RESERVATION_RATIO = 0.3
# Priority-mode constants
PRIORITY_MODE_PROVIDER = "provider"  # provider-priority mode
PRIORITY_MODE_GLOBAL_KEY = "global_key"  # global key-priority mode
ALLOWED_PRIORITY_MODES = {
PRIORITY_MODE_PROVIDER,
PRIORITY_MODE_GLOBAL_KEY,
}
def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
"""
Initialize the scheduler.

Note: the db Session is no longer kept on the instance, to avoid reusing a closed
session across requests; pass the current request's db Session into each method call.

Args:
    redis_client: Redis client (optional)
    priority_mode: candidate ordering strategy ("provider" | "global_key")
"""
self.redis = redis_client
self.priority_mode = self._normalize_priority_mode(
priority_mode or self.PRIORITY_MODE_PROVIDER
)
logger.debug(f"[CacheAwareScheduler] Initialized priority mode: {self.priority_mode}")
# Sub-components (initialized lazily on first use)
self._affinity_manager: Optional[CacheAffinityManager] = None
self._concurrency_manager = None
# Dynamic reservation manager (synchronous initialization)
self._reservation_manager: AdaptiveReservationManager = get_adaptive_reservation_manager()
self._metrics = {
"total_batches": 0,
"last_batch_size": 0,
"total_candidates": 0,
"last_candidate_count": 0,
"cache_hits": 0,
"cache_misses": 0,
"concurrency_denied": 0,
"last_api_format": None,
"last_model_name": None,
"last_updated_at": None,
# Metrics related to dynamic reservation
"reservation_probe_count": 0,
"reservation_stable_count": 0,
"avg_reservation_ratio": 0.0,
"last_reservation_result": None,
}
async def _ensure_initialized(self):
"""Ensure all async components are initialized."""
if self._affinity_manager is None:
self._affinity_manager = await get_affinity_manager(self.redis)
if self._concurrency_manager is None:
self._concurrency_manager = await get_concurrency_manager()
async def select_with_cache_affinity(
self,
db: Session,
affinity_key: str,
api_format: Union[str, APIFormat],
model_name: str,
excluded_endpoints: Optional[List[str]] = None,
excluded_keys: Optional[List[str]] = None,
provider_batch_size: int = 20,
max_candidates_per_batch: Optional[int] = None,
) -> Tuple[Provider, ProviderEndpoint, ProviderAPIKey]:
"""
Cache-aware selection - the core method.

Logic: fetch all candidates in one pass (cache hits ranked first), check them in order
against the exclusion lists and concurrency limits, return the first usable combination,
and refresh the cache affinity when appropriate.

Args:
    db: database session
    affinity_key: affinity identifier (usually the API Key ID)
    api_format: API format
    model_name: model name
    excluded_endpoints: list of Endpoint IDs to exclude
    excluded_keys: list of Provider Key IDs to exclude
    provider_batch_size: number of Providers fetched per batch
    max_candidates_per_batch: maximum candidates per batch
"""
await self._ensure_initialized()
excluded_endpoints_set = set(excluded_endpoints or [])
excluded_keys_set = set(excluded_keys or [])
normalized_format = normalize_api_format(api_format)
logger.debug(
f"[CacheAwareScheduler] select_with_cache_affinity: "
f"affinity_key={affinity_key[:8]}..., api_format={normalized_format.value}, model={model_name}"
)
self._metrics["last_api_format"] = normalized_format.value
self._metrics["last_model_name"] = model_name
provider_offset = 0
global_model_id = None  # used for cache affinity
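# Fetch candidates in Provider-sized batches: each batch is checked against the exclusion
# lists and the concurrency limits, and the first usable combination is returned. When a
# batch yields nothing usable, the offset advances and the next batch is fetched.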
while True:
candidates, resolved_global_model_id = await self.list_all_candidates(
db=db,
api_format=normalized_format,
model_name=model_name,
affinity_key=affinity_key,
provider_offset=provider_offset,
provider_limit=provider_batch_size,
max_candidates=max_candidates_per_batch,
)
if resolved_global_model_id and global_model_id is None:
global_model_id = resolved_global_model_id
if not candidates:
if provider_offset == 0:
# No candidates were found at all; produce a helpful error message
error_msg = f"Model '{model_name}' is not available"
# Look up similar models to suggest
from src.services.model.mapping_resolver import get_model_mapping_resolver
resolver = get_model_mapping_resolver()
similar_models = resolver.find_similar_models(db, model_name, limit=3)
if similar_models:
suggestions = [
f"{name} (similarity: {score:.0%})" for name, score in similar_models
]
error_msg += "\n\nYou may have meant one of these models:\n - " + "\n - ".join(suggestions)
raise ProviderNotAvailableException(error_msg)
break
self._metrics["total_batches"] += 1
self._metrics["last_batch_size"] = len(candidates)
self._metrics["last_updated_at"] = int(time.time())
for candidate in candidates:
provider = candidate.provider
endpoint = candidate.endpoint
key = candidate.key
if endpoint.id in excluded_endpoints_set:
logger.debug(f" └─ Endpoint {endpoint.id[:8]}... is in the exclusion list, skipping")
continue
if key.id in excluded_keys_set:
logger.debug(f" └─ Key {key.id[:8]}... is in the exclusion list, skipping")
continue
is_cached_user = bool(candidate.is_cached)
can_use, snapshot = await self._check_concurrent_available(
endpoint,
key,
is_cached_user=is_cached_user,
)
if not can_use:
logger.debug(f" └─ Key {key.id[:8]}... concurrency exhausted ({snapshot.describe()})")
self._metrics["concurrency_denied"] += 1
continue
logger.debug(f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
f"concurrency[{snapshot.describe()}]")
if key.cache_ttl_minutes > 0 and global_model_id:
ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
api_format_str = (
normalized_format.value
if isinstance(normalized_format, APIFormat)
else normalized_format
)
await self.set_cache_affinity(
affinity_key=affinity_key,
provider_id=str(provider.id),
endpoint_id=str(endpoint.id),
key_id=str(key.id),
api_format=api_format_str,
global_model_id=global_model_id,
ttl=int(ttl) if ttl is not None else None,
)
if is_cached_user:
self._metrics["cache_hits"] += 1
else:
self._metrics["cache_misses"] += 1
return provider, endpoint, key
provider_offset += provider_batch_size
raise ProviderNotAvailableException(f"No provider resources are currently available (model={model_name})")
def _get_effective_concurrent_limit(self, key: ProviderAPIKey) -> Optional[int]:
"""
Get the effective concurrency limit.

Logic:
- max_concurrent=NULL: adaptive mode; use learned_max_concurrent (None if nothing has been learned yet)
- max_concurrent=<number>: fixed limit; use that value directly

Args:
    key: the API Key object
Returns:
    the effective concurrency limit; None means unlimited
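
Examples (illustrative values):
    max_concurrent=None with learned_max_concurrent=8 -> 8
    max_concurrent=None with nothing learned yet      -> None (unlimited)
    max_concurrent=5                                   -> 5 (fixed)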
"""
if key.max_concurrent is None:
# Adaptive mode: use the learned value
learned = key.learned_max_concurrent
return int(learned) if learned is not None else None
else:
# Fixed-limit mode
return int(key.max_concurrent)
async def _check_concurrent_available(
self,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
is_cached_user: bool = False,
) -> Tuple[bool, ConcurrencySnapshot]:
"""
Check whether concurrency is available (dynamic reservation mechanism).

Core logic - dynamic cache reservation:
- total slots: the effective concurrency limit (a fixed value, or the learned value)
- reservation ratio: computed dynamically by AdaptiveReservationManager from confidence and load
- cached users: may use all slots
- new users: may use total slots x (1 - dynamic reservation ratio)

Args:
    endpoint: the ProviderEndpoint object
    key: the ProviderAPIKey object
    is_cached_user: whether this request comes from a cached user
Returns:
    (can_use, ConcurrencySnapshot)
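
Worked example (illustrative numbers): with an effective limit of 10 and a 30% reservation,
new users may occupy at most max(1, ceil(10 * 0.7)) = 7 slots, while cached users may use all 10.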
"""
# Get the effective concurrency limit
effective_key_limit = self._get_effective_concurrent_limit(key)
logger.debug(
f" -> Concurrency check: _concurrency_manager={self._concurrency_manager is not None}, "
f"is_cached_user={is_cached_user}, effective_limit={effective_key_limit}"
)
if not self._concurrency_manager:
# Concurrency manager unavailable: allow the request through
logger.debug(" -> No concurrency manager, passing through")
snapshot = ConcurrencySnapshot(
endpoint_current=0,
endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
key_current=0,
key_limit=effective_key_limit,
is_cached_user=is_cached_user,
)
return True, snapshot
# Get the current concurrency counts
endpoint_count, key_count = await self._concurrency_manager.get_current_concurrency(
endpoint_id=str(endpoint.id),
key_id=str(key.id),
)
can_use = True
# Check the Endpoint-level limit
if endpoint.max_concurrent is not None:
if endpoint_count >= endpoint.max_concurrent:
can_use = False
# Compute the dynamic reservation ratio
reservation_result = self._reservation_manager.calculate_reservation(
key=key,
current_concurrent=key_count,
effective_limit=effective_key_limit,
)
# Update metrics
if reservation_result.phase == "probe":
self._metrics["reservation_probe_count"] += 1
else:
self._metrics["reservation_stable_count"] += 1
# Maintain a moving average of the reservation ratio
total_reservations = (
self._metrics["reservation_probe_count"] + self._metrics["reservation_stable_count"]
)
if total_reservations > 0:
# Exponential moving average
alpha = 0.1
self._metrics["avg_reservation_ratio"] = (
alpha * reservation_result.ratio
+ (1 - alpha) * self._metrics["avg_reservation_ratio"]
)
self._metrics["last_reservation_result"] = {
"ratio": reservation_result.ratio,
"phase": reservation_result.phase,
"confidence": reservation_result.confidence,
"load_factor": reservation_result.load_factor,
}
available_for_new = None
reservation_ratio = reservation_result.ratio
# Check the Key-level limit (using the dynamic reservation ratio)
if effective_key_limit is not None:
if is_cached_user:
# Cached user: may use all slots
if key_count >= effective_key_limit:
can_use = False
else:
# New user: may only use (1 - dynamic reservation ratio) of the slots
# max(...) guarantees at least one slot remains available to new users
import math
available_for_new = max(1, math.ceil(effective_key_limit * (1 - reservation_ratio)))
if key_count >= available_for_new:
logger.debug(
f"Key {key.id[:8]}... new-user quota exhausted "
f"({key_count}/{available_for_new}, total {effective_key_limit}, "
f"reserved {reservation_ratio:.0%}[{reservation_result.phase}])"
)
can_use = False
key_limit_for_snapshot: Optional[int]
if is_cached_user:
key_limit_for_snapshot = effective_key_limit
elif effective_key_limit is not None:
key_limit_for_snapshot = (
available_for_new if available_for_new is not None else effective_key_limit
)
else:
key_limit_for_snapshot = None
snapshot = ConcurrencySnapshot(
endpoint_current=endpoint_count,
endpoint_limit=endpoint.max_concurrent,
key_current=key_count,
key_limit=key_limit_for_snapshot,
is_cached_user=is_cached_user,
reservation_ratio=reservation_ratio,
reservation_phase=reservation_result.phase,
reservation_confidence=reservation_result.confidence,
)
return can_use, snapshot
def _get_effective_restrictions(
self,
user_api_key: Optional[ApiKey],
) -> Dict[str, Optional[set]]:
"""
Get the effective access restrictions (merging ApiKey and User restrictions).

Logic:
- if both the ApiKey and the User define a restriction, take the intersection
- if only one side defines a restriction, use that side's restriction
- if neither side defines one, return None (unrestricted)

Args:
    user_api_key: the user's API Key object (may carry the user relationship)
Returns:
    a dict with allowed_providers, allowed_endpoints, allowed_models and allowed_api_formats
"""
result = {
"allowed_providers": None,
"allowed_endpoints": None,
"allowed_models": None,
"allowed_api_formats": None,
}
if not user_api_key:
return result
# Get the User-level restrictions
# Note: this may trigger lazy loading; the session must still be valid
try:
user = user_api_key.user if hasattr(user_api_key, "user") else None
except Exception as e:
logger.warning(f"Could not load the User associated with the ApiKey: {e}; using ApiKey-level restrictions only")
user = None
# Debug logging
logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
f"User={user.id[:8] if user else 'None'}..., "
f"ApiKey.allowed_models={user_api_key.allowed_models}, "
f"User.allowed_models={user.allowed_models if user else 'N/A'}")
def merge_restrictions(key_restriction, user_restriction):
"""Merge two restriction lists and return the effective restriction set."""
key_set = set(key_restriction) if key_restriction else None
user_set = set(user_restriction) if user_restriction else None
if key_set and user_set:
# Both sides restricted: take the intersection
return key_set & user_set
elif key_set:
return key_set
elif user_set:
return user_set
else:
return None
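# Example (illustrative values): ApiKey.allowed_models=["claude-3", "gpt-4o"] merged with
# User.allowed_models=["gpt-4o"] yields {"gpt-4o"}; if only one side sets a restriction,
# that side's set is used unchanged.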
# Merge allowed_providers
result["allowed_providers"] = merge_restrictions(
user_api_key.allowed_providers, user.allowed_providers if user else None
)
# Merge allowed_endpoints
result["allowed_endpoints"] = merge_restrictions(
user_api_key.allowed_endpoints if hasattr(user_api_key, "allowed_endpoints") else None,
user.allowed_endpoints if user else None,
)
# Merge allowed_models
result["allowed_models"] = merge_restrictions(
user_api_key.allowed_models, user.allowed_models if user else None
)
# API formats come from the ApiKey only (the User does not set this restriction)
if user_api_key.allowed_api_formats:
result["allowed_api_formats"] = set(user_api_key.allowed_api_formats)
return result
async def list_all_candidates(
self,
db: Session,
api_format: Union[str, APIFormat],
model_name: str,
affinity_key: Optional[str] = None,
user_api_key: Optional[ApiKey] = None,
provider_offset: int = 0,
provider_limit: Optional[int] = None,
max_candidates: Optional[int] = None,
is_stream: bool = False,
capability_requirements: Optional[Dict[str, bool]] = None,
) -> Tuple[List[ProviderCandidate], str]:
"""
Pre-fetch all available Provider/Endpoint/Key combinations.

The logic is split into:
1. _query_providers: database query
2. _build_candidates: candidate construction
3. _apply_cache_affinity: cache-affinity handling

Args:
    db: database session
    api_format: API format
    model_name: model name
    affinity_key: affinity identifier (usually the API Key ID), used for cache affinity
    user_api_key: the user's API Key, used for access-restriction filtering (User-level restrictions are merged in)
    provider_offset: Provider pagination offset
    provider_limit: Provider pagination limit
    max_candidates: maximum number of candidates
    is_stream: whether this is a streaming request; if True, Providers without streaming support are filtered out
    capability_requirements: capability requirements, used to filter out Keys that cannot satisfy them
Returns:
    (candidate list, global_model_id) - global_model_id is used for cache affinity
"""
await self._ensure_initialized()
target_format = normalize_api_format(api_format)
# 0. Resolve model_name to a GlobalModel (the canonical identifier used for cache affinity)
from src.services.model.mapping_resolver import get_model_mapping_resolver
mapping_resolver = get_model_mapping_resolver()
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
# Use GlobalModel.id as the affinity model identifier so aliases and canonical names hit the same cache
global_model_id: str = str(global_model.id) if global_model else model_name
# Get the merged access restrictions (ApiKey + User)
restrictions = self._get_effective_restrictions(user_api_key)
allowed_api_formats = restrictions["allowed_api_formats"]
allowed_providers = restrictions["allowed_providers"]
allowed_endpoints = restrictions["allowed_endpoints"]
allowed_models = restrictions["allowed_models"]
# 0.1 Check whether the API format is allowed
if allowed_api_formats:
if target_format.value not in allowed_api_formats:
logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
f"allowed formats: {allowed_api_formats}")
return [], global_model_id
# 0.2 Check whether the model is allowed
if allowed_models:
if model_name not in allowed_models:
logger.debug(f"The user/API Key is not allowed to use model {model_name}, " f"allowed models: {allowed_models}")
return [], global_model_id
# 1. Query Providers
providers = self._query_providers(
db=db,
provider_offset=provider_offset,
provider_limit=provider_limit,
)
if not providers:
return [], global_model_id
# 1.5 Filter by allowed_providers (the merged ApiKey and User restrictions)
if allowed_providers:
original_count = len(providers)
# Match by provider id or name
providers = [
p for p in providers if p.id in allowed_providers or p.name in allowed_providers
]
if original_count != len(providers):
logger.debug(f"User/API Key provider filter: {original_count} -> {len(providers)}")
if not providers:
return [], global_model_id
# 2. Build the candidate list (allowed_endpoints, is_stream and capability_requirements are used for filtering)
candidates = await self._build_candidates(
db=db,
providers=providers,
target_format=target_format,
model_name=model_name,
affinity_key=affinity_key,
max_candidates=max_candidates,
allowed_endpoints=allowed_endpoints,
is_stream=is_stream,
capability_requirements=capability_requirements,
)
# 3. Apply priority-mode sorting
candidates = self._apply_priority_mode_sort(candidates, affinity_key)
# Update metrics
self._metrics["total_candidates"] += len(candidates)
self._metrics["last_candidate_count"] = len(candidates)
logger.debug(f"Pre-fetched {len(candidates)} available combinations "
f"(api_format={target_format.value}, model={model_name})")
# 4. Apply cache-affinity ordering (using global_model_id as the model identifier)
if affinity_key and candidates:
candidates = await self._apply_cache_affinity(
candidates=candidates,
affinity_key=affinity_key,
api_format=target_format,
global_model_id=global_model_id,
)
return candidates, global_model_id
def _query_providers(
self,
db: Session,
provider_offset: int = 0,
provider_limit: Optional[int] = None,
) -> List[Provider]:
"""
Query active Providers (with eager loading).

Args:
    db: database session
    provider_offset: pagination offset
    provider_limit: pagination limit
Returns:
    list of Providers
"""
provider_query = (
db.query(Provider)
.options(
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
selectinload(Provider.model_mappings),
# Also load the models and global_model relationships so the get_effective_* methods can inherit defaults correctly
selectinload(Provider.models).selectinload(Model.global_model),
)
.filter(Provider.is_active == True)
.order_by(Provider.provider_priority.asc())
)
if provider_offset:
provider_query = provider_query.offset(provider_offset)
if provider_limit:
provider_query = provider_query.limit(provider_limit)
return provider_query.all()
async def _check_model_support(
self,
db: Session,
provider: Provider,
model_name: str,
is_stream: bool = False,
capability_requirements: Optional[Dict[str, bool]] = None,
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""
Check whether a Provider supports the given model (optionally also checking streaming support and capability requirements).

Model-capability checks happen here rather than at the Key level because:
- the capabilities a model supports are global and independent of any specific Key
- if the model lacks a required capability, all of the Provider's Keys should be skipped

Mapping fallback logic:
- if a model mapping exists, the mapped model is tried first
- if the mapped model fails only because a required capability is missing, fall back to the original model
- other failures (model missing, Provider has no implementation, etc.) do not trigger the fallback

Args:
    db: database session
    provider: the Provider object
    model_name: model name
    is_stream: whether this is a streaming request; if True, streaming support is checked as well
    capability_requirements: optional capability requirements to check against the model
Returns:
    (is_supported, skip_reason, supported_capabilities) - whether the model is supported, the skip reason, and the model's supported capabilities
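
Example of the fallback (hypothetical names): a request for "gpt-4o" is mapped to
"provider-tuned-4o"; if the mapped model fails only because a required capability
(e.g. vision) is missing, the original "gpt-4o" is re-checked before the Provider is skipped.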
"""
from src.services.model.mapping_resolver import get_model_mapping_resolver
mapping_resolver = get_model_mapping_resolver()
# Resolve the mapped model and record whether a mapping was applied
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
db, model_name, str(provider.id)
)
if not global_model:
return False, "Model does not exist", None
# Check the mapped model first
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db, provider, global_model, model_name, is_stream, capability_requirements
)
# If the mapped model failed due to a missing capability and a mapping was applied, fall back to the original model
if not is_supported and is_mapped and skip_reason and "does not support capability" in skip_reason:
logger.debug(
f"Provider {provider.name}: mapped model {global_model.name} does not satisfy the required capabilities, "
f"falling back to the original model {model_name}"
)
# Get the original model (without applying the mapping)
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
if original_global_model and original_global_model.id != global_model.id:
# Try the original model
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db,
provider,
original_global_model,
model_name,
is_stream,
capability_requirements,
)
if is_supported:
logger.debug(
f"Provider {provider.name}: original model {original_global_model.name} supports the required capabilities"
)
return is_supported, skip_reason, caps
async def _check_model_support_for_global_model(
self,
db: Session,
provider: Provider,
global_model: "GlobalModel",
model_name: str,
is_stream: bool = False,
capability_requirements: Optional[Dict[str, bool]] = None,
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""
Check whether a Provider supports the given GlobalModel.

Args:
    db: database session
    provider: the Provider object
    global_model: the GlobalModel object
    model_name: the model name the user requested (used in error messages)
    is_stream: whether this is a streaming request
    capability_requirements: capability requirements
Returns:
    (is_supported, skip_reason, supported_capabilities)
"""
# Make sure global_model is attached to the current Session
global_model = db.merge(global_model, load=False)
# Get the list of capabilities the model supports
model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
# Check whether this Provider implements the GlobalModel
for model in provider.models:
if model.global_model_id == global_model.id and model.is_active:
# Check streaming support
if is_stream:
supports_streaming = model.get_effective_supports_streaming()
if not supports_streaming:
return False, f"Model {model_name} does not support streaming on this Provider", None
# Check whether the model supports the required capabilities (a Provider-level check, not a Key-level one)
# Only check when model_supported_capabilities is non-empty:
# an empty list means the model has no configured capability restrictions and supports everything by default
if capability_requirements and model_supported_capabilities:
for cap_name, is_required in capability_requirements.items():
if is_required and cap_name not in model_supported_capabilities:
return (
False,
f"Model {model_name} does not support capability: {cap_name}",
list(model_supported_capabilities),
)
return True, None, list(model_supported_capabilities)
return False, "Provider has no implementation of this model", None
def _check_key_availability(
self,
key: ProviderAPIKey,
model_name: str,
capability_requirements: Optional[Dict[str, bool]] = None,
) -> Tuple[bool, Optional[str]]:
"""
Check whether an API Key is usable.

Note: model-capability checks have moved to _check_model_support (Provider level);
only Key-level capability matching is checked here.

Args:
    key: the API Key object
    model_name: model name
    capability_requirements: optional capability requirements
Returns:
    (is_available, skip_reason)
"""
# Check the circuit-breaker state (the detailed status method yields a richer skip reason)
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(key)
if not is_available:
return False, circuit_reason or "Circuit breaker is open"
# Model permission check using the allowed_models whitelist:
# None = all models allowed, [] = all models denied, ["a","b"] = only the listed models allowed
if key.allowed_models is not None and model_name not in key.allowed_models:
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(none)"
suffix = "..." if len(key.allowed_models) > 3 else ""
return False, f"Model not permitted (allowed: {allowed_preview}{suffix})"
# Key-level capability matching
# Note: model-level capability checks were already done in _check_model_support
# Always run this check, even when capability_requirements is empty,
# because check_capability_match also verifies that a Key's EXCLUSIVE capabilities are not wasted
from src.core.key_capabilities import check_capability_match
key_caps: Dict[str, bool] = dict(key.capabilities or {})
is_match, skip_reason = check_capability_match(key_caps, capability_requirements)
if not is_match:
return False, skip_reason
return True, None
async def _build_candidates(
self,
db: Session,
providers: List[Provider],
target_format: APIFormat,
model_name: str,
affinity_key: Optional[str],
max_candidates: Optional[int] = None,
allowed_endpoints: Optional[set] = None,
is_stream: bool = False,
capability_requirements: Optional[Dict[str, bool]] = None,
) -> List[ProviderCandidate]:
"""
Build the candidate list.

Args:
    db: database session
    providers: list of Providers
    target_format: target API format
    model_name: model name
    affinity_key: affinity identifier (usually the API Key ID)
    max_candidates: maximum number of candidates
    allowed_endpoints: set of allowed Endpoint IDs; None means unrestricted
    is_stream: whether this is a streaming request; if True, Providers without streaming support are filtered out
    capability_requirements: optional capability requirements
Returns:
    the candidate list
"""
candidates: List[ProviderCandidate] = []
for provider in providers:
# Check model support (also covers streaming support and model capability requirements)
# Capability checks happen at the Provider level: if the model lacks a required capability, the whole Provider is skipped
supports_model, skip_reason, _model_caps = await self._check_model_support(
db, provider, model_name, is_stream, capability_requirements
)
if not supports_model:
logger.debug(f"Provider {provider.name} does not support model {model_name}: {skip_reason}")
continue
for endpoint in provider.endpoints:
# endpoint.api_format is a string; target_format is an enum
endpoint_format_str = (
endpoint.api_format
if isinstance(endpoint.api_format, str)
else endpoint.api_format.value
)
if not endpoint.is_active or endpoint_format_str != target_format.value:
continue
# Check whether the Endpoint is in the allowed list
if allowed_endpoints and endpoint.id not in allowed_endpoints:
logger.debug(
f"Endpoint {endpoint.id[:8]}... is not in the user/API Key allowed list, skipping"
)
continue
# Get the active Keys and order them by internal_priority with load balancing
active_keys = [key for key in endpoint.api_keys if key.is_active]
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key)
for key in keys:
# Key-level capability check (model-level checks were already done above)
is_available, skip_reason = self._check_key_availability(
key, model_name, capability_requirements
)
candidate = ProviderCandidate(
provider=provider,
endpoint=endpoint,
key=key,
is_skipped=not is_available,
skip_reason=skip_reason,
)
candidates.append(candidate)
if max_candidates and len(candidates) >= max_candidates:
return candidates
return candidates
async def _apply_cache_affinity(
self,
candidates: List[ProviderCandidate],
affinity_key: str,
api_format: APIFormat,
global_model_id: str,
) -> List[ProviderCandidate]:
"""
Apply cache-affinity ordering.

Candidates that match the cache affinity are moved to the front of the list.

Args:
    candidates: candidate list
    affinity_key: affinity identifier (usually the API Key ID)
    api_format: API format
    global_model_id: GlobalModel ID (the canonical model identifier)
Returns:
    the reordered candidate list
"""
try:
# Look up this affinity identifier's cache affinity for the current API format and model
api_format_str = api_format.value if isinstance(api_format, APIFormat) else api_format
affinity = await self._affinity_manager.get_affinity(
affinity_key, api_format_str, global_model_id
)
if not affinity:
# No cache affinity recorded: mark all candidates as non-cached
for candidate in candidates:
candidate.is_cached = False
return candidates
# Partition candidates by whether they match the cache affinity
cached_candidates: List[ProviderCandidate] = []
other_candidates: List[ProviderCandidate] = []
matched = False
for candidate in candidates:
provider = candidate.provider
endpoint = candidate.endpoint
key = candidate.key
if (
provider.id == affinity.provider_id
and endpoint.id == affinity.endpoint_id
and key.id == affinity.key_id
):
candidate.is_cached = True
cached_candidates.append(candidate)
matched = True
logger.debug(f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
f"provider_key=***{key.api_key[-4:]}, "
f"request_count={affinity.request_count}")
else:
candidate.is_cached = False
other_candidates.append(candidate)
if not matched:
logger.debug(f"A cache affinity exists for API format {api_format_str}, but that combination is unavailable")
# Reorder: cached candidates first
if cached_candidates:
result = cached_candidates + other_candidates
logger.debug(f"{len(cached_candidates)} cached combinations promoted to the front")
return result
return candidates
except Exception as e:
logger.warning(f"Cache affinity check failed: {e}; falling back to the default ordering")
return candidates
def _normalize_priority_mode(self, mode: Optional[str]) -> str:
normalized = (mode or "").strip().lower()
if normalized not in self.ALLOWED_PRIORITY_MODES:
if normalized:
logger.warning(f"[CacheAwareScheduler] Invalid priority mode '{mode}', falling back to provider")
return self.PRIORITY_MODE_PROVIDER
return normalized
def set_priority_mode(self, mode: Optional[str]) -> None:
"""Update the candidate ordering strategy at runtime."""
normalized = self._normalize_priority_mode(mode)
if normalized == self.priority_mode:
return
self.priority_mode = normalized
logger.debug(f"[CacheAwareScheduler] Priority mode switched to: {self.priority_mode}")
def _apply_priority_mode_sort(
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
) -> List[ProviderCandidate]:
"""
Sort the candidate list according to the priority mode (smaller numbers rank higher).

- provider: provider-priority mode. Keeps the incoming order; the query already sorts by
  Provider.provider_priority, and Key.internal_priority (the priority within an Endpoint)
  is applied per Endpoint, with hash spreading for load balancing within the same priority.
- global_key: global key-priority mode. Sorts by Key.global_priority ascending; Keys with a
  global_priority come before those with NULL, and load is spread by hashing within the
  same global_priority.
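
Example (illustrative values): in global_key mode, a Key with global_priority=1 on a
lower-priority Provider is tried before a Key with global_priority=5 on the top Provider;
in provider mode, the Provider order always wins.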
"""
if not candidates:
return candidates
if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
# Global key-priority mode: group by global_priority, hash-spread within each group for load balancing
return self._sort_by_global_priority_with_hash(candidates, affinity_key)
# Provider-priority mode: keep the incoming order (the query already sorts by provider_priority)
return candidates
def _sort_by_global_priority_with_hash(
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
) -> List[ProviderCandidate]:
"""
Group candidates by global_priority and hash-spread within each group for load balancing.

Sorting logic:
1. Group by global_priority (smaller numbers first, NULL last).
2. Within each global_priority group, spread candidates using an affinity_key hash.
3. The same user therefore keeps selecting the same Key (preserving cache affinity).
"""
import hashlib
from collections import defaultdict
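# Illustrative behaviour: if Keys K1 and K2 share global_priority=1, user A's hashes of
# "A:K1" and "A:K2" put one of them first on every request, while user B may hash the other
# way round - load spreads across Keys, yet each user keeps a stable, cache-friendly order.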
# Group by global_priority
priority_groups: Dict[int, List[ProviderCandidate]] = defaultdict(list)
for candidate in candidates:
global_priority = (
candidate.key.global_priority
if candidate.key and candidate.key.global_priority is not None
else 999999  # NULL sorts last
)
priority_groups[global_priority].append(candidate)
result = []
for priority in sorted(priority_groups.keys()):  # smaller numbers rank first
group = priority_groups[priority]
if len(group) > 1 and affinity_key:
# Hash-spread within the same priority group for load balancing
scored_candidates = []
for candidate in group:
key_id = candidate.key.id if candidate.key else ""
hash_input = f"{affinity_key}:{key_id}"
hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
scored_candidates.append((hash_value, candidate))
# Sort by hash value
sorted_group = [c for _, c in sorted(scored_candidates)]
result.extend(sorted_group)
else:
# A single candidate or no affinity_key: sort by secondary criteria
def secondary_sort(c: ProviderCandidate):
return (
c.provider.provider_priority,
c.key.internal_priority if c.key else 999999,
c.key.id if c.key else "",
)
result.extend(sorted(group, key=secondary_sort))
return result
def _shuffle_keys_by_internal_priority(
self,
keys: List[ProviderAPIKey],
affinity_key: Optional[str] = None,
) -> List[ProviderAPIKey]:
"""
Group API Keys by internal_priority and deterministically shuffle within each group based on affinity_key.

Goals:
- smaller numbers are used first
- load is balanced among Keys that share the same priority
- the affinity_key hash keeps requests with the same key stable, so cache affinity is not broken

Args:
    keys: list of API Keys
    affinity_key: affinity identifier (usually the API Key ID), used for the deterministic shuffle
Returns:
    the ordered Key list
"""
if not keys:
return []
# Group by internal_priority
from collections import defaultdict
priority_groups: Dict[int, List[ProviderAPIKey]] = defaultdict(list)
for key in keys:
priority = key.internal_priority if key.internal_priority is not None else 999999
priority_groups[priority].append(key)
# Deterministically shuffle the Keys within each priority group
result = []
for priority in sorted(priority_groups.keys()):  # smaller numbers rank first
group_keys = priority_groups[priority]
if len(group_keys) > 1 and affinity_key:
# Hash strategy: compute an independent hash value for each key
import hashlib
key_scores = []
for key in group_keys:
# Hash the combination of affinity_key and key.id
hash_input = f"{affinity_key}:{key.id}"
hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
key_scores.append((hash_value, key))
# Sort by hash value
sorted_group = [key for _, key in sorted(key_scores)]
result.extend(sorted_group)
else:
# A single Key or no affinity_key: keep a stable order (sorted by key id)
result.extend(sorted(group_keys, key=lambda k: k.id))
return result
async def invalidate_cache(
self,
affinity_key: str,
api_format: str,
global_model_id: str,
endpoint_id: Optional[str] = None,
key_id: Optional[str] = None,
provider_id: Optional[str] = None,
):
"""
Invalidate an affinity identifier's cache affinity for a specific API format and model.

Args:
    affinity_key: affinity identifier (usually the API Key ID)
    api_format: API format (claude/openai)
    global_model_id: GlobalModel ID (the canonical model identifier)
    endpoint_id: endpoint ID (optional); if given, invalidate only when the Endpoint matches
    key_id: Provider Key ID (optional)
    provider_id: Provider ID (optional)
"""
await self._ensure_initialized()
await self._affinity_manager.invalidate_affinity(
affinity_key=affinity_key,
api_format=api_format,
model_name=global_model_id,
endpoint_id=endpoint_id,
key_id=key_id,
provider_id=provider_id,
)
async def set_cache_affinity(
self,
affinity_key: str,
provider_id: str,
endpoint_id: str,
key_id: str,
api_format: str,
global_model_id: str,
ttl: Optional[int] = None,
):
"""
Record a cache affinity (called by the orchestrator).

Args:
    affinity_key: affinity identifier (usually the API Key ID)
    provider_id: Provider ID
    endpoint_id: Endpoint ID
    key_id: Provider Key ID
    api_format: API format
    global_model_id: GlobalModel ID (the canonical model identifier)
    ttl: cache TTL (seconds)

Note: every call refreshes the expiry, implementing a sliding-window TTL.
"""
await self._ensure_initialized()
await self._affinity_manager.set_affinity(
affinity_key=affinity_key,
provider_id=provider_id,
endpoint_id=endpoint_id,
key_id=key_id,
api_format=api_format,
model_name=global_model_id,
supports_caching=True,
ttl=ttl,
)
async def get_stats(self) -> dict:
"""Return scheduler statistics."""
await self._ensure_initialized()
affinity_stats = self._affinity_manager.get_stats()
metrics = dict(self._metrics)
cache_total = metrics["cache_hits"] + metrics["cache_misses"]
metrics["cache_hit_rate"] = metrics["cache_hits"] / cache_total if cache_total else 0.0
metrics["avg_candidates_per_batch"] = (
metrics["total_candidates"] / metrics["total_batches"]
if metrics["total_batches"]
else 0.0
)
# Dynamic reservation statistics
reservation_stats = self._reservation_manager.get_stats()
total_reservation_checks = (
metrics["reservation_probe_count"] + metrics["reservation_stable_count"]
)
if total_reservation_checks > 0:
probe_ratio = metrics["reservation_probe_count"] / total_reservation_checks
else:
probe_ratio = 0.0
return {
"scheduler": "cache_aware",
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
"dynamic_reservation": {
"enabled": True,
"config": reservation_stats["config"],
"current_avg_ratio": round(metrics["avg_reservation_ratio"], 3),
"probe_phase_ratio": round(probe_ratio, 3),
"total_checks": total_reservation_checks,
"last_result": metrics["last_reservation_result"],
},
"affinity_stats": affinity_stats,
"scheduler_metrics": metrics,
}
# Global singleton
_scheduler: Optional[CacheAwareScheduler] = None
async def get_cache_aware_scheduler(
redis_client=None,
priority_mode: Optional[str] = None,
) -> CacheAwareScheduler:
"""
Get the global CacheAwareScheduler instance.

Note: a db argument is no longer accepted, to avoid holding a request-scoped Session;
pass the current request's db Session into each scheduler method instead.

Args:
    redis_client: Redis client (optional)
    priority_mode: externally supplied priority mode ("provider" | "global_key")
Returns:
    the CacheAwareScheduler instance
"""
global _scheduler
if _scheduler is None:
_scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
elif priority_mode:
_scheduler.set_priority_mode(priority_mode)
return _scheduler