"""
|
||
缓存感知调度器 (Cache-Aware Scheduler)
|
||
|
||
职责:
|
||
1. 统一管理Provider/Endpoint/Key的选择逻辑
|
||
2. 集成缓存亲和性管理,优先使用有缓存的Provider+Key
|
||
3. 协调并发控制和缓存优先级
|
||
4. 实现故障转移机制(同Endpoint内优先,跨Provider按优先级)
|
||
|
||
核心设计思想:
|
||
===============
|
||
1. 用户首次请求: 按 provider_priority 选择最优 Provider+Endpoint+Key
|
||
2. 用户后续请求:
|
||
- 优先使用缓存的Endpoint+Key (利用Prompt Caching)
|
||
- 如果缓存的Key并发满,尝试同Endpoint其他Key
|
||
- 如果Endpoint不可用,按 provider_priority 切换到其他Provider
|
||
|
||
3. 并发控制(动态预留机制):
|
||
- 探测阶段:使用低预留(10%),让系统快速学习真实并发限制
|
||
- 稳定阶段:根据置信度和负载动态调整预留比例(10%-35%)
|
||
- 置信度因素:连续成功次数、429冷却时间、调整历史稳定性
|
||
- 缓存用户可使用全部槽位,新用户只能用 (1-预留比例) 的槽位
|
||
|
||
4. 故障转移:
|
||
- Key故障: 同Endpoint内切换其他Key(检查模型支持)
|
||
- Endpoint故障: 按 provider_priority 切换到其他Provider
|
||
- 注意:不同Endpoint的协议完全不兼容,不能在同Provider内切换Endpoint
|
||
- 失效缓存亲和性,避免重复选择故障资源
|
||
"""
|
||
|
||
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from sqlalchemy.orm import Session, selectinload

from src.core.enums import APIFormat
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
from src.core.logger import logger
from src.models.database import (
    ApiKey,
    Model,
    Provider,
    ProviderAPIKey,
    ProviderEndpoint,
)

if TYPE_CHECKING:
    from src.models.database import GlobalModel

from src.services.cache.affinity_manager import (
    CacheAffinityManager,
    get_affinity_manager,
)
from src.services.cache.model_cache import ModelCacheService
from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_reservation import (
    AdaptiveReservationManager,
    ReservationResult,
    get_adaptive_reservation_manager,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager


@dataclass
class ProviderCandidate:
    """A candidate provider combination, and whether it hit the cache"""

    provider: Provider
    endpoint: ProviderEndpoint
    key: ProviderAPIKey
    is_cached: bool = False
    is_skipped: bool = False  # whether this candidate was skipped
    skip_reason: Optional[str] = None  # reason it was skipped


@dataclass
class ConcurrencySnapshot:
    endpoint_current: int
    endpoint_limit: Optional[int]
    key_current: int
    key_limit: Optional[int]
    is_cached_user: bool = False
    # dynamic reservation info
    reservation_ratio: float = 0.0
    reservation_phase: str = "unknown"
    reservation_confidence: float = 0.0

    def describe(self) -> str:
        endpoint_limit_text = str(self.endpoint_limit) if self.endpoint_limit is not None else "inf"
        key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
        reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
        return (
            f"endpoint={self.endpoint_current}/{endpoint_limit_text}, "
            f"key={self.key_current}/{key_limit_text}, "
            f"cached={self.is_cached_user}, "
            f"reserve={reservation_text}({self.reservation_phase})"
        )


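# Example of ConcurrencySnapshot.describe() output (hypothetical values):
#
#     ConcurrencySnapshot(
#         endpoint_current=3, endpoint_limit=10,
#         key_current=6, key_limit=8,
#         is_cached_user=True,
#         reservation_ratio=0.2, reservation_phase="stable",
#     ).describe()
#     # -> "endpoint=3/10, key=6/8, cached=True, reserve=20%(stable)"

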
class CacheAwareScheduler:
    """
    Cache-aware scheduler

    The core component for provider selection. It integrates:
    - the Provider/Endpoint/Key three-level architecture
    - cache-affinity management
    - concurrency control (dynamic reservation)
    - health monitoring
    """

    # Static default; the actual ratio is computed dynamically by AdaptiveReservationManager
    CACHE_RESERVATION_RATIO = 0.3
    # Priority-mode constants
    PRIORITY_MODE_PROVIDER = "provider"  # provider-first mode
    PRIORITY_MODE_GLOBAL_KEY = "global_key"  # global-key-first mode
    ALLOWED_PRIORITY_MODES = {
        PRIORITY_MODE_PROVIDER,
        PRIORITY_MODE_GLOBAL_KEY,
    }

    def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
        """
        Initialize the scheduler

        Note: the db Session is no longer kept on the instance, to avoid using a
        closed session across requests; every method takes the current request's
        db Session as an argument.

        Args:
            redis_client: Redis client (optional)
            priority_mode: candidate ordering strategy (provider | global_key)
        """
        self.redis = redis_client
        self.priority_mode = self._normalize_priority_mode(
            priority_mode or self.PRIORITY_MODE_PROVIDER
        )
        logger.debug(f"[CacheAwareScheduler] initialized priority mode: {self.priority_mode}")

        # Sub-components (initialized asynchronously on first use)
        self._affinity_manager: Optional[CacheAffinityManager] = None
        self._concurrency_manager = None
        # Dynamic reservation manager (initialized synchronously)
        self._reservation_manager: AdaptiveReservationManager = get_adaptive_reservation_manager()
        self._metrics = {
            "total_batches": 0,
            "last_batch_size": 0,
            "total_candidates": 0,
            "last_candidate_count": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "concurrency_denied": 0,
            "last_api_format": None,
            "last_model_name": None,
            "last_updated_at": None,
            # dynamic reservation metrics
            "reservation_probe_count": 0,
            "reservation_stable_count": 0,
            "avg_reservation_ratio": 0.0,
            "last_reservation_result": None,
        }

    async def _ensure_initialized(self):
        """Make sure all async components are initialized"""
        if self._affinity_manager is None:
            self._affinity_manager = await get_affinity_manager(self.redis)

        if self._concurrency_manager is None:
            self._concurrency_manager = await get_concurrency_manager()

    async def select_with_cache_affinity(
        self,
        db: Session,
        affinity_key: str,
        api_format: Union[str, APIFormat],
        model_name: str,
        excluded_endpoints: Optional[List[str]] = None,
        excluded_keys: Optional[List[str]] = None,
        provider_batch_size: int = 20,
        max_candidates_per_batch: Optional[int] = None,
    ) -> Tuple[Provider, ProviderEndpoint, ProviderAPIKey]:
        """
        Cache-aware selection - the core method

        Logic: fetch all candidates at once (cache hits first), check the exclusion
        lists and concurrency limits in order, return the first usable combination,
        and refresh cache affinity when needed.

        Args:
            db: database session
            affinity_key: affinity identifier (usually the API Key ID)
            api_format: API format
            model_name: model name
            excluded_endpoints: list of Endpoint IDs to exclude
            excluded_keys: list of Provider Key IDs to exclude
            provider_batch_size: provider batch size
            max_candidates_per_batch: maximum candidates per batch
        """
        await self._ensure_initialized()

        excluded_endpoints_set = set(excluded_endpoints or [])
        excluded_keys_set = set(excluded_keys or [])

        normalized_format = normalize_api_format(api_format)

        logger.debug(
            f"[CacheAwareScheduler] select_with_cache_affinity: "
            f"affinity_key={affinity_key[:8]}..., api_format={normalized_format.value}, model={model_name}"
        )

        self._metrics["last_api_format"] = normalized_format.value
        self._metrics["last_model_name"] = model_name

        provider_offset = 0

        global_model_id = None  # used for cache affinity

        while True:
            candidates, resolved_global_model_id = await self.list_all_candidates(
                db=db,
                api_format=normalized_format,
                model_name=model_name,
                affinity_key=affinity_key,
                provider_offset=provider_offset,
                provider_limit=provider_batch_size,
                max_candidates=max_candidates_per_batch,
            )

            if resolved_global_model_id and global_model_id is None:
                global_model_id = resolved_global_model_id

            if not candidates:
                if provider_offset == 0:
                    # no candidates at all: raise a friendly error
                    error_msg = f"Model '{model_name}' is not available"
                    raise ProviderNotAvailableException(error_msg)
                break

            self._metrics["total_batches"] += 1
            self._metrics["last_batch_size"] = len(candidates)
            self._metrics["last_updated_at"] = int(time.time())

            for candidate in candidates:
                provider = candidate.provider
                endpoint = candidate.endpoint
                key = candidate.key

                # Skip candidates already marked unusable by _build_candidates
                # (circuit breaker open, model not permitted, capability mismatch, ...)
                if candidate.is_skipped:
                    logger.debug(f" └─ Key {key.id[:8]}... skipped: {candidate.skip_reason}")
                    continue

                if endpoint.id in excluded_endpoints_set:
                    logger.debug(f" └─ Endpoint {endpoint.id[:8]}... is on the exclusion list, skipping")
                    continue

                if key.id in excluded_keys_set:
                    logger.debug(f" └─ Key {key.id[:8]}... is on the exclusion list, skipping")
                    continue

                is_cached_user = bool(candidate.is_cached)
                can_use, snapshot = await self._check_concurrent_available(
                    endpoint,
                    key,
                    is_cached_user=is_cached_user,
                )

                if not can_use:
                    logger.debug(f" └─ Key {key.id[:8]}... concurrency full ({snapshot.describe()})")
                    self._metrics["concurrency_denied"] += 1
                    continue

                logger.debug(
                    f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
                    f"Key=***{key.api_key[-4:]}, cache hit={is_cached_user}, "
                    f"concurrency[{snapshot.describe()}]"
                )

                if key.cache_ttl_minutes > 0 and global_model_id:
                    ttl = key.cache_ttl_minutes * 60
                    api_format_str = (
                        normalized_format.value
                        if isinstance(normalized_format, APIFormat)
                        else normalized_format
                    )
                    await self.set_cache_affinity(
                        affinity_key=affinity_key,
                        provider_id=str(provider.id),
                        endpoint_id=str(endpoint.id),
                        key_id=str(key.id),
                        api_format=api_format_str,
                        global_model_id=global_model_id,
                        ttl=int(ttl),
                    )

                if is_cached_user:
                    self._metrics["cache_hits"] += 1
                else:
                    self._metrics["cache_misses"] += 1

                return provider, endpoint, key

            provider_offset += provider_batch_size

        raise ProviderNotAvailableException(f"No provider resources currently available (model={model_name})")

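    # A minimal sketch of a caller driving the failover loop (hypothetical caller;
    # `db`, `api_key_id`, `call_upstream` and the error types are assumptions,
    # not part of this module):
    #
    #     excluded_endpoints: List[str] = []
    #     excluded_keys: List[str] = []
    #     for _ in range(3):  # bounded retries
    #         provider, endpoint, key = await scheduler.select_with_cache_affinity(
    #             db=db,
    #             affinity_key=api_key_id,
    #             api_format="claude",
    #             model_name="example-model",
    #             excluded_endpoints=excluded_endpoints,
    #             excluded_keys=excluded_keys,
    #         )
    #         try:
    #             return await call_upstream(provider, endpoint, key)
    #         except KeyFailure:
    #             excluded_keys.append(str(key.id))            # Key failure: stay on the Endpoint
    #         except EndpointFailure:
    #             excluded_endpoints.append(str(endpoint.id))  # Endpoint failure: next Provider
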
    def _get_effective_concurrent_limit(self, key: ProviderAPIKey) -> Optional[int]:
        """
        Get the effective concurrency limit

        New logic:
        - max_concurrent=NULL: adaptive mode, use learned_max_concurrent (None if nothing learned yet)
        - max_concurrent=<number>: fixed limit, use that value directly

        Args:
            key: API Key object

        Returns:
            the effective concurrency limit (None means unlimited)
        """
        if key.max_concurrent is None:
            # adaptive mode: use the learned value
            learned = key.learned_max_concurrent
            return int(learned) if learned is not None else None
        else:
            # fixed-limit mode
            return int(key.max_concurrent)

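    # Examples of the effective limit (hypothetical keys):
    #   max_concurrent=None, learned_max_concurrent=12   -> 12   (adaptive, learned)
    #   max_concurrent=None, learned_max_concurrent=None -> None (adaptive, unlimited so far)
    #   max_concurrent=8                                 -> 8    (fixed)
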
    async def _check_concurrent_available(
        self,
        endpoint: ProviderEndpoint,
        key: ProviderAPIKey,
        is_cached_user: bool = False,
    ) -> Tuple[bool, ConcurrencySnapshot]:
        """
        Check whether concurrency is available (using the dynamic reservation mechanism)

        Core logic - dynamic cache reservation:
        - total slots: the effective concurrency limit (fixed or learned)
        - reservation ratio: computed dynamically by AdaptiveReservationManager from confidence and load
        - cached users may use: all slots
        - new users may use: total slots × (1 - dynamic reservation ratio)

        Args:
            endpoint: ProviderEndpoint object
            key: ProviderAPIKey object
            is_cached_user: whether this is a cached user

        Returns:
            (available, concurrency snapshot)
        """
        # Effective concurrency limit
        effective_key_limit = self._get_effective_concurrent_limit(key)

        logger.debug(
            f" -> concurrency check: _concurrency_manager={self._concurrency_manager is not None}, "
            f"is_cached_user={is_cached_user}, effective_limit={effective_key_limit}"
        )

        if not self._concurrency_manager:
            # no concurrency manager available; allow directly
            logger.debug(" -> no concurrency manager, passing through")
            snapshot = ConcurrencySnapshot(
                endpoint_current=0,
                endpoint_limit=(
                    int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
                ),
                key_current=0,
                key_limit=effective_key_limit,
                is_cached_user=is_cached_user,
            )
            return True, snapshot

        # Current concurrency counts
        endpoint_count, key_count = await self._concurrency_manager.get_current_concurrency(
            endpoint_id=str(endpoint.id),
            key_id=str(key.id),
        )

        can_use = True

        # Endpoint-level limit
        if endpoint.max_concurrent is not None:
            if endpoint_count >= endpoint.max_concurrent:
                can_use = False

        # Dynamic reservation ratio
        reservation_result = self._reservation_manager.calculate_reservation(
            key=key,
            current_concurrent=key_count,
            effective_limit=effective_key_limit,
        )

        # Update metrics
        if reservation_result.phase == "probe":
            self._metrics["reservation_probe_count"] += 1
        else:
            self._metrics["reservation_stable_count"] += 1

        # Moving average of the reservation ratio
        total_reservations = (
            self._metrics["reservation_probe_count"] + self._metrics["reservation_stable_count"]
        )
        if total_reservations > 0:
            # exponential moving average
            alpha = 0.1
            self._metrics["avg_reservation_ratio"] = (
                alpha * reservation_result.ratio
                + (1 - alpha) * self._metrics["avg_reservation_ratio"]
            )

        self._metrics["last_reservation_result"] = {
            "ratio": reservation_result.ratio,
            "phase": reservation_result.phase,
            "confidence": reservation_result.confidence,
            "load_factor": reservation_result.load_factor,
        }

        available_for_new = None
        reservation_ratio = reservation_result.ratio

        # Key-level limit (using the dynamic reservation ratio)
        if effective_key_limit is not None:
            if is_cached_user:
                # cached user: may use all slots
                if key_count >= effective_key_limit:
                    can_use = False
            else:
                # new user: may only use (1 - dynamic reservation ratio) of the slots;
                # max() guarantees at least 1 usable slot
                import math

                available_for_new = max(1, math.ceil(effective_key_limit * (1 - reservation_ratio)))
                if key_count >= available_for_new:
                    logger.debug(
                        f"Key {key.id[:8]}... new-user quota exhausted "
                        f"({key_count}/{available_for_new}, total {effective_key_limit}, "
                        f"reserved {reservation_ratio:.0%}[{reservation_result.phase}])"
                    )
                    can_use = False

        key_limit_for_snapshot: Optional[int]
        if is_cached_user:
            key_limit_for_snapshot = effective_key_limit
        elif effective_key_limit is not None:
            key_limit_for_snapshot = (
                available_for_new if available_for_new is not None else effective_key_limit
            )
        else:
            key_limit_for_snapshot = None

        snapshot = ConcurrencySnapshot(
            endpoint_current=endpoint_count,
            endpoint_limit=endpoint.max_concurrent,
            key_current=key_count,
            key_limit=key_limit_for_snapshot,
            is_cached_user=is_cached_user,
            reservation_ratio=reservation_ratio,
            reservation_phase=reservation_result.phase,
            reservation_confidence=reservation_result.confidence,
        )

        return can_use, snapshot

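    # Worked example of the reservation arithmetic above (hypothetical numbers):
    #   effective_key_limit = 10, reservation_ratio = 0.30
    #   cached user: may use all 10 slots
    #   new user:    max(1, ceil(10 * (1 - 0.30))) = 7 slots
    #   so at key_count = 7 a new user is denied while a cached user still fits.
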
    def _get_effective_restrictions(
        self,
        user_api_key: Optional[ApiKey],
    ) -> Dict[str, Optional[set]]:
        """
        Get the effective access restrictions (merging ApiKey and User restrictions)

        Logic:
        - if both the ApiKey and the User restrict, take the intersection
        - if only one side restricts, use that side's restriction
        - if neither restricts, return None (unrestricted)

        Args:
            user_api_key: the user's API Key object (may carry a user relationship)

        Returns:
            dict with allowed_providers, allowed_endpoints, allowed_models and allowed_api_formats
        """
        result = {
            "allowed_providers": None,
            "allowed_endpoints": None,
            "allowed_models": None,
            "allowed_api_formats": None,
        }

        if not user_api_key:
            return result

        # Get the User's restrictions.
        # Note: this may trigger lazy loading, so the session must still be alive
        try:
            user = user_api_key.user if hasattr(user_api_key, "user") else None
        except Exception as e:
            logger.warning(f"Could not load the User linked to the ApiKey: {e}; using ApiKey-level restrictions only")
            user = None

        # Debug logging
        logger.debug(
            f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
            f"User={user.id[:8] if user else 'None'}..., "
            f"ApiKey.allowed_models={user_api_key.allowed_models}, "
            f"User.allowed_models={user.allowed_models if user else 'N/A'}"
        )

        def merge_restrictions(key_restriction, user_restriction):
            """Merge two restriction lists into the effective restriction set"""
            key_set = set(key_restriction) if key_restriction else None
            user_set = set(user_restriction) if user_restriction else None

            if key_set and user_set:
                # both restrict: intersection
                return key_set & user_set
            elif key_set:
                return key_set
            elif user_set:
                return user_set
            else:
                return None

        # Merge allowed_providers
        result["allowed_providers"] = merge_restrictions(
            user_api_key.allowed_providers, user.allowed_providers if user else None
        )

        # Merge allowed_endpoints
        result["allowed_endpoints"] = merge_restrictions(
            user_api_key.allowed_endpoints if hasattr(user_api_key, "allowed_endpoints") else None,
            user.allowed_endpoints if user else None,
        )

        # Merge allowed_models
        result["allowed_models"] = merge_restrictions(
            user_api_key.allowed_models, user.allowed_models if user else None
        )

        # API formats come from the ApiKey only (Users do not set this restriction)
        if user_api_key.allowed_api_formats:
            result["allowed_api_formats"] = set(user_api_key.allowed_api_formats)

        return result

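    # Example of the merge semantics (hypothetical restrictions):
    #   ApiKey.allowed_models=["m1", "m2"], User.allowed_models=["m2", "m3"] -> {"m2"} (intersection)
    #   ApiKey.allowed_providers=None,      User.allowed_providers=["p1"]    -> {"p1"} (only one side restricts)
    #   neither side restricts                                               -> None  (unrestricted)
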
    async def list_all_candidates(
        self,
        db: Session,
        api_format: Union[str, APIFormat],
        model_name: str,
        affinity_key: Optional[str] = None,
        user_api_key: Optional[ApiKey] = None,
        provider_offset: int = 0,
        provider_limit: Optional[int] = None,
        max_candidates: Optional[int] = None,
        is_stream: bool = False,
        capability_requirements: Optional[Dict[str, bool]] = None,
    ) -> Tuple[List[ProviderCandidate], str]:
        """
        Pre-fetch all available Provider/Endpoint/Key combinations

        After refactoring, the logic is split into:
        1. _query_providers: database query logic
        2. _build_candidates: candidate construction logic
        3. _apply_cache_affinity: cache-affinity handling

        Args:
            db: database session
            api_format: API format
            model_name: model name
            affinity_key: affinity identifier (usually the API Key ID, used for cache affinity)
            user_api_key: user API Key (used for access-restriction filtering, including User-level limits)
            provider_offset: provider pagination offset
            provider_limit: provider pagination limit
            max_candidates: maximum number of candidates
            is_stream: whether this is a streaming request; if True, filter out Providers without streaming support
            capability_requirements: capability requirements (used to filter Keys that cannot satisfy them)

        Returns:
            (candidate list, global_model_id) - global_model_id is used for cache affinity
        """
        await self._ensure_initialized()

        target_format = normalize_api_format(api_format)

        # 0. Resolve model_name to a GlobalModel (direct match or alias match, via ModelCacheService)
        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)

        if not global_model:
            logger.warning(f"GlobalModel not found: {model_name}")
            raise ModelNotSupportedException(model=model_name)

        # Use GlobalModel.id as the model identifier for cache affinity, so aliases
        # and canonical names hit the same cache entry
        global_model_id: str = str(global_model.id)
        requested_model_name = model_name
        resolved_model_name = str(global_model.name)

        # Get the merged access restrictions (ApiKey + User)
        restrictions = self._get_effective_restrictions(user_api_key)
        allowed_api_formats = restrictions["allowed_api_formats"]
        allowed_providers = restrictions["allowed_providers"]
        allowed_endpoints = restrictions["allowed_endpoints"]
        allowed_models = restrictions["allowed_models"]

        # 0.1 Check that the API format is allowed
        if allowed_api_formats is not None:
            if target_format.value not in allowed_api_formats:
                logger.debug(
                    f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... may not use API format {target_format.value}, "
                    f"allowed formats: {allowed_api_formats}"
                )
                return [], global_model_id

        # 0.2 Check that the model is allowed
        if allowed_models is not None:
            if (
                requested_model_name not in allowed_models
                and resolved_model_name not in allowed_models
            ):
                resolved_note = (
                    f" (resolved to {resolved_model_name})"
                    if resolved_model_name != requested_model_name
                    else ""
                )
                logger.debug(
                    f"User/API Key may not use model {requested_model_name}{resolved_note}, "
                    f"allowed models: {allowed_models}"
                )
                return [], global_model_id

        # 1. Query Providers
        providers = self._query_providers(
            db=db,
            provider_offset=provider_offset,
            provider_limit=provider_limit,
        )

        if not providers:
            return [], global_model_id

        # 1.5 Filter by allowed_providers (merged ApiKey + User restrictions)
        if allowed_providers is not None:
            original_count = len(providers)
            # match by provider id or name
            providers = [
                p for p in providers if p.id in allowed_providers or p.name in allowed_providers
            ]
            if original_count != len(providers):
                logger.debug(f"User/API Key provider filter: {original_count} -> {len(providers)}")

        if not providers:
            return [], global_model_id

        # 2. Build the candidate list (passing allowed_endpoints, is_stream and
        #    capability_requirements for filtering)
        candidates = await self._build_candidates(
            db=db,
            providers=providers,
            target_format=target_format,
            model_name=requested_model_name,
            resolved_model_name=resolved_model_name,
            affinity_key=affinity_key,
            max_candidates=max_candidates,
            allowed_endpoints=allowed_endpoints,
            is_stream=is_stream,
            capability_requirements=capability_requirements,
        )

        # 3. Apply priority-mode ordering
        candidates = self._apply_priority_mode_sort(candidates, affinity_key)

        # Update metrics
        self._metrics["total_candidates"] += len(candidates)
        self._metrics["last_candidate_count"] = len(candidates)

        logger.debug(
            f"Pre-fetched {len(candidates)} available combinations "
            f"(api_format={target_format.value}, model={model_name})"
        )

        # 4. Apply cache-affinity ordering (using global_model_id as the model identifier)
        if affinity_key and candidates:
            candidates = await self._apply_cache_affinity(
                candidates=candidates,
                affinity_key=affinity_key,
                api_format=target_format,
                global_model_id=global_model_id,
            )

        return candidates, global_model_id

    def _query_providers(
        self,
        db: Session,
        provider_offset: int = 0,
        provider_limit: Optional[int] = None,
    ) -> List[Provider]:
        """
        Query active Providers (with eager loading)

        Args:
            db: database session
            provider_offset: pagination offset
            provider_limit: pagination limit

        Returns:
            list of Providers
        """
        provider_query = (
            db.query(Provider)
            .options(
                selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
                # Load both the models and global_model relationships so the
                # get_effective_* methods can inherit defaults correctly
                selectinload(Provider.models).selectinload(Model.global_model),
            )
            .filter(Provider.is_active == True)
            .order_by(Provider.provider_priority.asc())
        )

        if provider_offset:
            provider_query = provider_query.offset(provider_offset)
        if provider_limit:
            provider_query = provider_query.limit(provider_limit)

        return provider_query.all()

    async def _check_model_support(
        self,
        db: Session,
        provider: Provider,
        model_name: str,
        is_stream: bool = False,
        capability_requirements: Optional[Dict[str, bool]] = None,
    ) -> Tuple[bool, Optional[str], Optional[List[str]]]:
        """
        Check whether a Provider supports the given model (optionally checking
        streaming support and capability requirements)

        Model capability checks happen here (not at the Key level) because:
        - the capabilities a model supports are global, independent of any specific Key
        - if the model lacks a capability, every Key of the Provider should be skipped

        Two matching strategies are supported:
        1. direct match on GlobalModel.name
        2. alias match via ModelCacheService (global lookup)

        Args:
            db: database session
            provider: Provider object
            model_name: model name (a GlobalModel.name or an alias)
            is_stream: whether this is a streaming request; if True, also check streaming support
            capability_requirements: capability requirements (optional), checked against the model

        Returns:
            (is_supported, skip_reason, supported_capabilities)
        """
        # Resolve the model name via ModelCacheService (supports aliases)
        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)

        if not global_model:
            # no match at all
            return False, "Model does not exist or the Provider has not configured it", None

        # Found a GlobalModel; check whether the current Provider supports it
        is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
            db, provider, global_model, model_name, is_stream, capability_requirements
        )
        return is_supported, skip_reason, caps

    async def _check_model_support_for_global_model(
        self,
        db: Session,
        provider: Provider,
        global_model: "GlobalModel",
        model_name: str,
        is_stream: bool = False,
        capability_requirements: Optional[Dict[str, bool]] = None,
    ) -> Tuple[bool, Optional[str], Optional[List[str]]]:
        """
        Check whether a Provider supports the given GlobalModel

        Args:
            db: database session
            provider: Provider object
            global_model: GlobalModel object
            model_name: the model name the user requested (for error messages)
            is_stream: whether this is a streaming request
            capability_requirements: capability requirements

        Returns:
            (is_supported, skip_reason, supported_capabilities)
        """
        # Make sure global_model is attached to the current Session.
        # Note: objects rebuilt from the cache are transient, so load=False must not be
        # used; the default load=True lets SQLAlchemy handle transient objects correctly.
        from sqlalchemy import inspect

        insp = inspect(global_model)
        if insp.transient or insp.detached:
            # transient/detached object: default merge (queries the DB to check existence)
            global_model = db.merge(global_model)
        # persistent objects are already attached to the session; no merge needed

        # Capabilities supported by the model
        model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])

        # Check whether this Provider implements the GlobalModel
        for model in provider.models:
            if model.global_model_id == global_model.id and model.is_active:
                logger.debug(
                    f"[_check_model_support_for_global_model] Provider={provider.name}, "
                    f"GlobalModel={global_model.name}, "
                    f"provider_model_name={model.provider_model_name}"
                )
                # Check streaming support
                if is_stream:
                    supports_streaming = model.get_effective_supports_streaming()
                    if not supports_streaming:
                        return False, f"Model {model_name} does not support streaming on this Provider", None

                # Check the required capabilities (at the Provider level, not the Key level).
                # Only checked when model_supported_capabilities is non-empty: an empty
                # list means no capability restrictions were configured, i.e. all
                # capabilities are assumed supported.
                if capability_requirements and model_supported_capabilities:
                    for cap_name, is_required in capability_requirements.items():
                        if is_required and cap_name not in model_supported_capabilities:
                            return (
                                False,
                                f"Model {model_name} does not support capability: {cap_name}",
                                list(model_supported_capabilities),
                            )

                return True, None, list(model_supported_capabilities)

        return False, "Provider does not implement this model", None

    def _check_key_availability(
        self,
        key: ProviderAPIKey,
        model_name: str,
        capability_requirements: Optional[Dict[str, bool]] = None,
        resolved_model_name: Optional[str] = None,
    ) -> Tuple[bool, Optional[str]]:
        """
        Check the availability of an API Key

        Note: model capability checks moved to _check_model_support (Provider level);
        only Key-level capability matching happens here.

        Args:
            key: API Key object
            model_name: model name
            capability_requirements: capability requirements (optional)
            resolved_model_name: the resolved GlobalModel.name (optional)

        Returns:
            (is_available, skip_reason)
        """
        # Circuit-breaker state (the detailed status method yields a richer skip reason)
        is_available, circuit_reason = health_monitor.get_circuit_breaker_status(key)
        if not is_available:
            return False, circuit_reason or "circuit breaker open"

        # Model permission check against the allowed_models whitelist:
        # None = all models allowed, [] = all models rejected, ["a","b"] = only those models
        if key.allowed_models is not None and (
            model_name not in key.allowed_models
            and (not resolved_model_name or resolved_model_name not in key.allowed_models)
        ):
            allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(none)"
            suffix = "..." if len(key.allowed_models) > 3 else ""
            return False, f"model not permitted (allowed: {allowed_preview}{suffix})"

        # Key-level capability matching.
        # Note: model-level capability checks were already done in _check_model_support.
        # The check always runs, even with empty capability_requirements, because
        # check_capability_match also detects when a Key's EXCLUSIVE capability would be wasted.
        from src.core.key_capabilities import check_capability_match

        key_caps: Dict[str, bool] = dict(key.capabilities or {})
        is_match, skip_reason = check_capability_match(key_caps, capability_requirements)
        if not is_match:
            return False, skip_reason

        return True, None

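    # allowed_models whitelist semantics (hypothetical key):
    #   key.allowed_models = None         -> every model passes
    #   key.allowed_models = []           -> every model is rejected
    #   key.allowed_models = ["m1", "m2"] -> passes only if the requested name or the
    #                                        resolved GlobalModel.name is in the list
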
    async def _build_candidates(
        self,
        db: Session,
        providers: List[Provider],
        target_format: APIFormat,
        model_name: str,
        affinity_key: Optional[str],
        resolved_model_name: Optional[str] = None,
        max_candidates: Optional[int] = None,
        allowed_endpoints: Optional[set] = None,
        is_stream: bool = False,
        capability_requirements: Optional[Dict[str, bool]] = None,
    ) -> List[ProviderCandidate]:
        """
        Build the candidate list

        Args:
            db: database session
            providers: list of Providers
            target_format: target API format
            model_name: model name (as requested by the user, possibly an alias)
            affinity_key: affinity identifier (usually the API Key ID)
            resolved_model_name: resolved GlobalModel.name (for Key.allowed_models checks)
            max_candidates: maximum number of candidates
            allowed_endpoints: set of allowed Endpoint IDs (None means unrestricted)
            is_stream: whether this is a streaming request; if True, filter out Providers without streaming support
            capability_requirements: capability requirements (optional)

        Returns:
            candidate list
        """
        candidates: List[ProviderCandidate] = []

        for provider in providers:
            # Check model support (including streaming support and capability requirements).
            # Capability checks happen at the Provider level: if the model lacks a
            # required capability, the whole Provider is skipped.
            supports_model, skip_reason, _model_caps = await self._check_model_support(
                db, provider, model_name, is_stream, capability_requirements
            )
            if not supports_model:
                logger.debug(f"Provider {provider.name} does not support model {model_name}: {skip_reason}")
                continue

            for endpoint in provider.endpoints:
                # endpoint.api_format is a string, target_format is an enum
                endpoint_format_str = (
                    endpoint.api_format
                    if isinstance(endpoint.api_format, str)
                    else endpoint.api_format.value
                )
                if not endpoint.is_active or endpoint_format_str != target_format.value:
                    continue

                # Check that the Endpoint is on the allow list
                if allowed_endpoints is not None and endpoint.id not in allowed_endpoints:
                    logger.debug(
                        f"Endpoint {endpoint.id[:8]}... is not in the user/API Key allow list, skipping"
                    )
                    continue

                # Collect active Keys, ordered by internal_priority with load balancing
                active_keys = [key for key in endpoint.api_keys if key.is_active]
                keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key)

                for key in keys:
                    # Key-level capability check (model-level checks were done above)
                    is_available, skip_reason = self._check_key_availability(
                        key,
                        model_name,
                        capability_requirements,
                        resolved_model_name=resolved_model_name,
                    )

                    candidate = ProviderCandidate(
                        provider=provider,
                        endpoint=endpoint,
                        key=key,
                        is_skipped=not is_available,
                        skip_reason=skip_reason,
                    )
                    candidates.append(candidate)

                    if max_candidates and len(candidates) >= max_candidates:
                        return candidates

        return candidates

    async def _apply_cache_affinity(
        self,
        candidates: List[ProviderCandidate],
        affinity_key: str,
        api_format: APIFormat,
        global_model_id: str,
    ) -> List[ProviderCandidate]:
        """
        Apply cache-affinity ordering

        Candidates that hit the cache are promoted to the front of the list

        Args:
            candidates: candidate list
            affinity_key: affinity identifier (usually the API Key ID)
            api_format: API format
            global_model_id: GlobalModel ID (the canonical model identifier)

        Returns:
            the reordered candidate list
        """
        try:
            # Look up this identifier's cache affinity for the current API format and model
            api_format_str = api_format.value if isinstance(api_format, APIFormat) else api_format
            affinity = await self._affinity_manager.get_affinity(
                affinity_key, api_format_str, global_model_id
            )

            if not affinity:
                # no cache affinity: mark every candidate as non-cached
                for candidate in candidates:
                    candidate.is_cached = False
                return candidates

            # Partition candidates by whether they match the cache affinity
            cached_candidates: List[ProviderCandidate] = []
            other_candidates: List[ProviderCandidate] = []
            matched = False

            for candidate in candidates:
                provider = candidate.provider
                endpoint = candidate.endpoint
                key = candidate.key

                if (
                    provider.id == affinity.provider_id
                    and endpoint.id == affinity.endpoint_id
                    and key.id == affinity.key_id
                ):
                    candidate.is_cached = True
                    cached_candidates.append(candidate)
                    matched = True
                    logger.debug(
                        f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
                        f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
                        f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
                        f"provider_key=***{key.api_key[-4:]}, "
                        f"request_count={affinity.request_count}"
                    )
                else:
                    candidate.is_cached = False
                    other_candidates.append(candidate)

            if not matched:
                logger.debug(f"Cache affinity for API format {api_format_str} exists, but the combination is unavailable")

            # Reorder: cached candidates first
            if cached_candidates:
                result = cached_candidates + other_candidates
                logger.debug(f"{len(cached_candidates)} cached combinations promoted")
                return result

            return candidates

        except Exception as e:
            logger.warning(f"Cache-affinity check failed: {e}; falling back to default ordering")
            return candidates

    def _normalize_priority_mode(self, mode: Optional[str]) -> str:
        normalized = (mode or "").strip().lower()
        if normalized not in self.ALLOWED_PRIORITY_MODES:
            if normalized:
                logger.warning(f"[CacheAwareScheduler] invalid priority mode '{mode}', falling back to provider")
            return self.PRIORITY_MODE_PROVIDER
        return normalized

    def set_priority_mode(self, mode: Optional[str]) -> None:
        """Update the candidate ordering strategy at runtime"""
        normalized = self._normalize_priority_mode(mode)
        if normalized == self.priority_mode:
            return
        self.priority_mode = normalized
        logger.debug(f"[CacheAwareScheduler] switched priority mode to: {self.priority_mode}")

    def _apply_priority_mode_sort(
        self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
    ) -> List[ProviderCandidate]:
        """
        Sort the candidate list by priority mode (smaller numbers win)

        - provider: provider-first mode, keep the existing order (Provider.provider_priority ->
          Key.internal_priority, already guaranteed by the query).
          Key.internal_priority is the priority within an Endpoint; ties are spread by
          hashing for load balancing.
        - global_key: global-key-first mode, sort by Key.global_priority ascending
          (smaller numbers first). Keys with a global_priority come first, NULLs go last;
          ties within the same global_priority are spread by hashing for load balancing.
        """
        if not candidates:
            return candidates

        if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
            # global-key mode: group by global_priority, hash-spread within each group
            return self._sort_by_global_priority_with_hash(candidates, affinity_key)

        # provider-first mode: keep the existing order (provider_priority ordering
        # is already guaranteed by the query)
        return candidates

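    # Ordering example under global_key mode (hypothetical priorities):
    #   candidate keys with global_priority 1, 1, 5, None group as [1, 1], [5], [None -> 999999];
    #   the two priority-1 keys are ordered per affinity_key by hash, then the 5, then the NULL.
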
    def _sort_by_global_priority_with_hash(
        self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
    ) -> List[ProviderCandidate]:
        """
        Sort by global_priority groups, hash-spreading within each group for load balancing

        Ordering logic:
        1. Group by global_priority (smaller numbers first, NULLs last)
        2. Within a global_priority group, spread by hashing the affinity_key
        3. This keeps the same user's requests pinned to the same Key (cache affinity)
        """
        import hashlib
        from collections import defaultdict

        # Group by global_priority
        priority_groups: Dict[int, List[ProviderCandidate]] = defaultdict(list)
        for candidate in candidates:
            global_priority = (
                candidate.key.global_priority
                if candidate.key and candidate.key.global_priority is not None
                else 999999  # NULLs go last
            )
            priority_groups[global_priority].append(candidate)

        result = []
        for priority in sorted(priority_groups.keys()):  # smaller numbers first
            group = priority_groups[priority]

            if len(group) > 1 and affinity_key:
                # hash-spread within the same priority for load balancing
                scored_candidates = []
                for candidate in group:
                    key_id = candidate.key.id if candidate.key else ""
                    hash_input = f"{affinity_key}:{key_id}"
                    hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
                    scored_candidates.append((hash_value, candidate))

                # order by hash value
                sorted_group = [c for _, c in sorted(scored_candidates)]
                result.extend(sorted_group)
            else:
                # single candidate, or no affinity_key: fall back to secondary ordering
                def secondary_sort(c: ProviderCandidate):
                    internal = (
                        c.key.internal_priority
                        if c.key and c.key.internal_priority is not None
                        else 999999
                    )
                    return (
                        c.provider.provider_priority,
                        internal,
                        c.key.id if c.key else "",
                    )

                result.extend(sorted(group, key=secondary_sort))

        return result

    def _shuffle_keys_by_internal_priority(
        self,
        keys: List[ProviderAPIKey],
        affinity_key: Optional[str] = None,
    ) -> List[ProviderAPIKey]:
        """
        Group API Keys by internal_priority and deterministically shuffle within each
        group based on affinity_key

        Goals:
        - smaller numbers are used first
        - load is balanced across Keys of the same priority
        - hashing the affinity_key keeps the ordering stable for the same requester
          (so cache affinity is not broken)

        Args:
            keys: list of API Keys
            affinity_key: affinity identifier (usually the API Key ID, used for the deterministic shuffle)

        Returns:
            the sorted Key list
        """
        if not keys:
            return []

        # Group by internal_priority
        from collections import defaultdict

        priority_groups: Dict[int, List[ProviderAPIKey]] = defaultdict(list)

        for key in keys:
            priority = key.internal_priority if key.internal_priority is not None else 999999
            priority_groups[priority].append(key)

        # Deterministically shuffle the Keys within each priority group
        result = []
        for priority in sorted(priority_groups.keys()):  # smaller numbers come first
            group_keys = priority_groups[priority]

            if len(group_keys) > 1 and affinity_key:
                # improved hashing strategy: compute an independent hash per key
                import hashlib

                key_scores = []
                for key in group_keys:
                    # combined hash of affinity_key + key.id
                    hash_input = f"{affinity_key}:{key.id}"
                    hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
                    key_scores.append((hash_value, key))

                # order by hash value
                sorted_group = [key for _, key in sorted(key_scores)]
                result.extend(sorted_group)
            else:
                # single Key, or no affinity_key: keep a stable order (by id)
                result.extend(sorted(group_keys, key=lambda k: k.id))

        return result

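    # Sketch of the deterministic shuffle (hypothetical IDs): each (affinity_key, key.id)
    # pair maps to a stable score, so one user always sees the same Key order while
    # different users are spread across Keys:
    #
    #     score = int(hashlib.sha256(b"user-A:key-1").hexdigest()[:16], 16)
    #     # "user-A" -> e.g. order [key-2, key-1, key-3]; "user-B" -> e.g. [key-3, key-1, key-2]
    #     # reproducible per user (preserves cache affinity), balanced across users
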
    async def invalidate_cache(
        self,
        affinity_key: str,
        api_format: str,
        global_model_id: str,
        endpoint_id: Optional[str] = None,
        key_id: Optional[str] = None,
        provider_id: Optional[str] = None,
    ):
        """
        Invalidate an affinity identifier's cache affinity for a specific API format and model

        Args:
            affinity_key: affinity identifier (usually the API Key ID)
            api_format: API format (claude/openai)
            global_model_id: GlobalModel ID (the canonical model identifier)
            endpoint_id: endpoint ID (optional; if given, invalidate only when the Endpoint matches)
            key_id: Provider Key ID (optional)
            provider_id: Provider ID (optional)
        """
        await self._ensure_initialized()
        await self._affinity_manager.invalidate_affinity(
            affinity_key=affinity_key,
            api_format=api_format,
            model_name=global_model_id,
            endpoint_id=endpoint_id,
            key_id=key_id,
            provider_id=provider_id,
        )

    async def set_cache_affinity(
        self,
        affinity_key: str,
        provider_id: str,
        endpoint_id: str,
        key_id: str,
        api_format: str,
        global_model_id: str,
        ttl: Optional[int] = None,
    ):
        """
        Record cache affinity (called by the orchestrator)

        Args:
            affinity_key: affinity identifier (usually the API Key ID)
            provider_id: Provider ID
            endpoint_id: Endpoint ID
            key_id: Provider Key ID
            api_format: API format
            global_model_id: GlobalModel ID (the canonical model identifier)
            ttl: cache TTL in seconds

        Note: every call refreshes the expiry, implementing a sliding window
        """
        await self._ensure_initialized()

        await self._affinity_manager.set_affinity(
            affinity_key=affinity_key,
            provider_id=provider_id,
            endpoint_id=endpoint_id,
            key_id=key_id,
            api_format=api_format,
            model_name=global_model_id,
            supports_caching=True,
            ttl=ttl,
        )

    async def get_stats(self) -> dict:
        """Get scheduler statistics"""
        await self._ensure_initialized()

        affinity_stats = self._affinity_manager.get_stats()
        metrics = dict(self._metrics)

        cache_total = metrics["cache_hits"] + metrics["cache_misses"]
        metrics["cache_hit_rate"] = metrics["cache_hits"] / cache_total if cache_total else 0.0
        metrics["avg_candidates_per_batch"] = (
            metrics["total_candidates"] / metrics["total_batches"]
            if metrics["total_batches"]
            else 0.0
        )

        # dynamic reservation statistics
        reservation_stats = self._reservation_manager.get_stats()
        total_reservation_checks = (
            metrics["reservation_probe_count"] + metrics["reservation_stable_count"]
        )
        if total_reservation_checks > 0:
            probe_ratio = metrics["reservation_probe_count"] / total_reservation_checks
        else:
            probe_ratio = 0.0

        return {
            "scheduler": "cache_aware",
            "cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
            "dynamic_reservation": {
                "enabled": True,
                "config": reservation_stats["config"],
                "current_avg_ratio": round(metrics["avg_reservation_ratio"], 3),
                "probe_phase_ratio": round(probe_ratio, 3),
                "total_checks": total_reservation_checks,
                "last_result": metrics["last_reservation_result"],
            },
            "affinity_stats": affinity_stats,
            "scheduler_metrics": metrics,
        }


# Global singleton
_scheduler: Optional[CacheAwareScheduler] = None


async def get_cache_aware_scheduler(
    redis_client=None,
    priority_mode: Optional[str] = None,
) -> CacheAwareScheduler:
    """
    Get the global CacheAwareScheduler instance

    Note: this no longer accepts a db argument, to avoid persisting a request-scoped
    Session; pass the current request's db Session to each scheduler method instead.

    Args:
        redis_client: Redis client (optional)
        priority_mode: externally overridden priority mode (provider | global_key)

    Returns:
        the CacheAwareScheduler instance
    """
    global _scheduler

    if _scheduler is None:
        _scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
    elif priority_mode:
        _scheduler.set_priority_mode(priority_mode)

    return _scheduler
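

# Minimal end-to-end usage sketch (assumed call site; `get_db_session` and
# `user_api_key` are illustrative, not part of this module):
#
#     scheduler = await get_cache_aware_scheduler(redis_client, priority_mode="provider")
#     with get_db_session() as db:
#         provider, endpoint, key = await scheduler.select_with_cache_affinity(
#             db=db,
#             affinity_key=str(user_api_key.id),
#             api_format="claude",
#             model_name="example-model",
#         )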