From 8f0a0cbdb146401d278471e519c294728a2c88e8 Mon Sep 17 00:00:00 2001
From: fawney19
Date: Mon, 15 Dec 2025 18:13:35 +0800
Subject: [PATCH] refactor(scheduler): integrate model alias resolution

- Use ModelCacheService.resolve_global_model_by_name_or_alias() for model lookups
- Support both the requested model name and the resolved GlobalModel name in validation
- Track resolved_model_name for proper allowed_models checking
- Improve model availability checks to handle alias resolution
- Fix transient/detached object handling in the global_model merge
- Add more descriptive debug logs for alias resolution mismatches
- Clean up code formatting (line length, import organization)
---
 src/services/cache/aware_scheduler.py | 124 ++++++++++++++++++--------
 1 file changed, 85 insertions(+), 39 deletions(-)

diff --git a/src/services/cache/aware_scheduler.py b/src/services/cache/aware_scheduler.py
index 7c0d11a..4938db6 100644
--- a/src/services/cache/aware_scheduler.py
+++ b/src/services/cache/aware_scheduler.py
@@ -28,15 +28,17 @@
 - Invalidate cache affinity to avoid repeatedly selecting failed resources
 """
 
+from __future__ import annotations
+
 import time
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
-from src.core.logger import logger
 from sqlalchemy.orm import Session, selectinload
 
 from src.core.enums import APIFormat
 from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
+from src.core.logger import logger
 from src.models.database import (
     ApiKey,
     Model,
@@ -44,10 +46,15 @@ from src.models.database import (
     ProviderAPIKey,
     ProviderEndpoint,
 )
+
+if TYPE_CHECKING:
+    from src.models.database import GlobalModel
+
 from src.services.cache.affinity_manager import (
     CacheAffinityManager,
     get_affinity_manager,
 )
+from src.services.cache.model_cache import ModelCacheService
 from src.services.health.monitor import health_monitor
 from src.services.provider.format import normalize_api_format
 from src.services.rate_limit.adaptive_reservation import (
@@ -259,9 +266,11 @@ class CacheAwareScheduler:
                     self._metrics["concurrency_denied"] += 1
                     continue
 
-                logger.debug(f"  └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
+                logger.debug(
+                    f"  └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
                     f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
-                    f"concurrency state [{snapshot.describe()}]")
+                    f"concurrency state [{snapshot.describe()}]"
+                )
 
                 if key.cache_ttl_minutes > 0 and global_model_id:
                     ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
@@ -349,7 +358,9 @@ class CacheAwareScheduler:
             logger.debug(f"  -> no concurrency manager, passing through")
             snapshot = ConcurrencySnapshot(
                 endpoint_current=0,
-                endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
+                endpoint_limit=(
+                    int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
+                ),
                 key_current=0,
                 key_limit=effective_key_limit,
                 is_cached_user=is_cached_user,
@@ -484,10 +495,12 @@ class CacheAwareScheduler:
             user = None
 
         # Debug logging
-        logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
+        logger.debug(
+            f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
             f"User={user.id[:8] if user else 'None'}..., "
             f"ApiKey.allowed_models={user_api_key.allowed_models}, "
-            f"User.allowed_models={user.allowed_models if user else 'N/A'}")
+            f"User.allowed_models={user.allowed_models if user else 'N/A'}"
+        )
 
         def merge_restrictions(key_restriction, user_restriction):
             """Merge two restriction lists and return the effective restriction set"""
@@ -566,13 +579,8 @@ class CacheAwareScheduler:
         target_format = normalize_api_format(api_format)
 
-        # 0. Resolve model_name to a GlobalModel (direct lookup; users must use the canonical name)
-        from src.models.database import GlobalModel
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # 0. Resolve model_name to a GlobalModel (direct and alias matching via ModelCacheService)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
 
         if not global_model:
             logger.warning(f"GlobalModel not found: {model_name}")
@@ -580,6 +588,8 @@ class CacheAwareScheduler:
 
         # Use GlobalModel.id as the cache-affinity model identifier so aliases and the canonical name hit the same cache
         global_model_id: str = str(global_model.id)
+        requested_model_name = model_name
+        resolved_model_name = str(global_model.name)
 
         # Get the merged access restrictions (ApiKey + User)
         restrictions = self._get_effective_restrictions(user_api_key)
@@ -591,14 +601,27 @@ class CacheAwareScheduler:
         # 0.1 Check whether the API format is allowed
         if allowed_api_formats:
             if target_format.value not in allowed_api_formats:
-                logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
-                    f"allowed formats: {allowed_api_formats}")
+                logger.debug(
+                    f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
+                    f"allowed formats: {allowed_api_formats}"
+                )
                 return [], global_model_id
 
         # 0.2 Check whether the model is allowed
         if allowed_models:
-            if model_name not in allowed_models:
-                logger.debug(f"User/API Key is not allowed to use model {model_name}, " f"allowed models: {allowed_models}")
+            if (
+                requested_model_name not in allowed_models
+                and resolved_model_name not in allowed_models
+            ):
+                resolved_note = (
+                    f" (resolved to {resolved_model_name})"
+                    if resolved_model_name != requested_model_name
+                    else ""
+                )
+                logger.debug(
+                    f"User/API Key is not allowed to use model {requested_model_name}{resolved_note}, "
+                    f"allowed models: {allowed_models}"
+                )
                 return [], global_model_id
 
         # 1. Query Providers
@@ -629,7 +652,8 @@ class CacheAwareScheduler:
             db=db,
             providers=providers,
             target_format=target_format,
-            model_name=model_name,
+            model_name=requested_model_name,
+            resolved_model_name=resolved_model_name,
             affinity_key=affinity_key,
             max_candidates=max_candidates,
             allowed_endpoints=allowed_endpoints,
@@ -644,8 +668,10 @@ class CacheAwareScheduler:
         self._metrics["total_candidates"] += len(candidates)
         self._metrics["last_candidate_count"] = len(candidates)
 
-        logger.debug(f"Pre-fetched {len(candidates)} available combinations "
-            f"(api_format={target_format.value}, model={model_name})")
+        logger.debug(
+            f"Pre-fetched {len(candidates)} available combinations "
+            f"(api_format={target_format.value}, model={model_name})"
+        )
 
         # 4. Apply cache-affinity ordering (using global_model_id as the model identifier)
         if affinity_key and candidates:
@@ -708,33 +734,31 @@ class CacheAwareScheduler:
         - A model's supported capabilities are global and independent of any specific Key
        - If the model does not support a capability, every Key of that Provider should be skipped
 
+        Two matching modes are supported:
+        1. Direct match on GlobalModel.name
+        2. Alias match via ModelCacheService (global lookup)
+
         Args:
             db: database session
             provider: Provider object
-            model_name: model name (must be GlobalModel.name)
+            model_name: model name (either GlobalModel.name or an alias)
             is_stream: whether this is a streaming request; if True, streaming support is also checked
             capability_requirements: capability requirements (optional), used to check whether the model supports the required capabilities
 
         Returns:
             (is_supported, skip_reason, supported_capabilities)
             - whether supported, the skip reason, and the model's supported capability list
         """
-        from src.models.database import GlobalModel
-
-        # Look up directly by GlobalModel.name
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # Resolve the model name via ModelCacheService (supports aliases)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
 
         if not global_model:
-            return False, "model does not exist", None
+            # No match found at all
+            return False, "model does not exist or the Provider has not configured it", None
 
-        # Check model support
+        # Once the GlobalModel is found, check whether the current Provider supports it
        is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
             db, provider, global_model, model_name, is_stream, capability_requirements
         )
-
         return is_supported, skip_reason, caps
 
     async def _check_model_support_for_global_model(
@@ -761,7 +785,16 @@ class CacheAwareScheduler:
             (is_supported, skip_reason, supported_capabilities)
         """
         # Ensure global_model is attached to the current Session
-        global_model = db.merge(global_model, load=False)
+        # Note: objects rebuilt from the cache are in the transient state, so load=False must not be used
+        # load=True (the default) lets SQLAlchemy handle transient objects correctly
+        from sqlalchemy import inspect
+        insp = inspect(global_model)
+        if insp.transient or insp.detached:
+            # transient/detached object: use the default merge (queries the DB to check existence)
+            global_model = db.merge(global_model)
+        else:
+            # persistent object: already attached to the session, no merge needed
+            pass
 
         # Get the list of capabilities the model supports
         model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
@@ -796,6 +829,7 @@ class CacheAwareScheduler:
         key: ProviderAPIKey,
         model_name: str,
         capability_requirements: Optional[Dict[str, bool]] = None,
+        resolved_model_name: Optional[str] = None,
     ) -> Tuple[bool, Optional[str]]:
         """
         Check the availability of an API Key
@@ -807,6 +841,7 @@ class CacheAwareScheduler:
             key: API Key object
             model_name: model name
             capability_requirements: capability requirements (optional)
+            resolved_model_name: resolved GlobalModel.name (optional)
 
         Returns:
             (is_available, skip_reason)
@@ -818,7 +853,10 @@ class CacheAwareScheduler:
 
         # Model permission check: use the allowed_models whitelist
         # None = allow all models, [] = deny all models, ["a","b"] = allow only the listed models
-        if key.allowed_models is not None and model_name not in key.allowed_models:
+        if key.allowed_models is not None and (
+            model_name not in key.allowed_models
+            and (not resolved_model_name or resolved_model_name not in key.allowed_models)
+        ):
             allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(none)"
             suffix = "..." if len(key.allowed_models) > 3 else ""
             return False, f"model not permitted (allowed: {allowed_preview}{suffix})"
@@ -843,6 +881,7 @@ class CacheAwareScheduler:
         target_format: APIFormat,
         model_name: str,
         affinity_key: Optional[str],
+        resolved_model_name: Optional[str] = None,
         max_candidates: Optional[int] = None,
         allowed_endpoints: Optional[set] = None,
         is_stream: bool = False,
@@ -855,8 +894,9 @@ class CacheAwareScheduler:
             db: database session
             providers: list of Providers
             target_format: target API format
-            model_name: model name
+            model_name: model name (as requested by the user; may be an alias)
             affinity_key: affinity identifier (usually the API Key ID)
+            resolved_model_name: resolved GlobalModel.name (used for Key.allowed_models checks)
             max_candidates: maximum number of candidates
             allowed_endpoints: set of allowed Endpoint IDs (None means no restriction)
             is_stream: whether this is a streaming request; if True, Providers without streaming support are filtered out
@@ -901,7 +941,10 @@ class CacheAwareScheduler:
             for key in keys:
                 # Key-level capability check (model-level capability checks were done above)
                 is_available, skip_reason = self._check_key_availability(
-                    key, model_name, capability_requirements
+                    key,
+                    model_name,
+                    capability_requirements,
+                    resolved_model_name=resolved_model_name,
                 )
 
                 candidate = ProviderCandidate(
@@ -970,11 +1013,13 @@ class CacheAwareScheduler:
                     candidate.is_cached = True
                     cached_candidates.append(candidate)
                     matched = True
-                    logger.debug(f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
+                    logger.debug(
+                        f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
                         f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
                         f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
                         f"provider_key=***{key.api_key[-4:]}, "
-                        f"request_count={affinity.request_count}")
+                        f"request_count={affinity.request_count}"
+                    )
                 else:
                     candidate.is_cached = False
                     other_candidates.append(candidate)
@@ -1080,6 +1125,7 @@ class CacheAwareScheduler:
                 c.key.internal_priority if c.key else 999999,
                 c.key.id if c.key else "",
             )
+
             result.extend(sorted(group, key=secondary_sort))
 
         return result
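
Note for reviewers: the core behavioral change is the dual-name whitelist check
(None = allow all models, [] = deny all, and either the requested alias or the
resolved canonical GlobalModel.name may satisfy the list). Below is a minimal,
self-contained sketch of that semantics; the function name and the sample model
names are illustrative only, not part of the scheduler:

    from typing import List, Optional

    def is_model_allowed(
        requested_name: str,
        resolved_name: Optional[str],
        allowed_models: Optional[List[str]],
    ) -> bool:
        # No whitelist configured: every model is allowed
        if allowed_models is None:
            return True
        # The requested name (possibly an alias) is whitelisted directly
        if requested_name in allowed_models:
            return True
        # Fall back to the resolved canonical name when resolution succeeded
        return resolved_name is not None and resolved_name in allowed_models

    # An alias request passes when only the canonical name is whitelisted
    assert is_model_allowed("sonnet-alias", "claude-sonnet-4", ["claude-sonnet-4"])
    # An empty whitelist denies everything, even a resolvable alias
    assert not is_model_allowed("sonnet-alias", "claude-sonnet-4", [])
    # No whitelist means no restriction
    assert is_model_allowed("any-model", None, None)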