refactor(scheduler): integrate model alias resolution

- Use ModelCacheService.resolve_global_model_by_name_or_alias() for model lookups
- Support both requested model name and resolved GlobalModel name in validation
- Track resolved_model_name for proper allowed_models checking (see the sketch below)
- Improve model availability checks to handle alias resolution
- Fix transient/detached object handling in global_model merge
- Add more descriptive debug logs for alias resolution mismatches
- Clean up code formatting (line length, import organization)
fawney19
2025-12-15 18:13:35 +08:00
parent 51b85915d2
commit 8f0a0cbdb1
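For context, the following is a minimal, self-contained sketch (not the project's actual code) of the behavior this commit introduces: a request may name a model by an alias, so an allow-list check must accept either the requested name or the resolved GlobalModel name. The ALIASES/CANONICAL data and the resolve() helper are hypothetical stand-ins for ModelCacheService.resolve_global_model_by_name_or_alias().

from typing import Dict, List, Optional

ALIASES: Dict[str, str] = {"gpt-4o-latest": "gpt-4o"}  # assumed example data
CANONICAL: List[str] = ["gpt-4o", "claude-3-5-sonnet"]  # assumed example data


def resolve(name: str) -> Optional[str]:
    """Return the canonical model name for a direct or alias match, else None."""
    if name in CANONICAL:
        return name
    return ALIASES.get(name)


def is_model_allowed(requested: str, allowed_models: Optional[List[str]]) -> bool:
    """None = allow all; otherwise accept either the requested or the resolved name."""
    if allowed_models is None:
        return True
    resolved = resolve(requested)
    return requested in allowed_models or (
        resolved is not None and resolved in allowed_models
    )


if __name__ == "__main__":
    # An allow-list that only names the canonical model still admits the alias.
    print(is_model_allowed("gpt-4o-latest", ["gpt-4o"]))            # True
    print(is_model_allowed("gpt-4o-latest", ["claude-3-5-sonnet"]))  # False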


@@ -28,15 +28,17 @@
 - Invalidate cache affinity to avoid repeatedly selecting failed resources
 """
+from __future__ import annotations
+
 import time
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
-from src.core.logger import logger
 from sqlalchemy.orm import Session, selectinload
 
 from src.core.enums import APIFormat
 from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
+from src.core.logger import logger
 from src.models.database import (
     ApiKey,
     Model,
@@ -44,10 +46,15 @@ from src.models.database import (
     ProviderAPIKey,
     ProviderEndpoint,
 )
+
+if TYPE_CHECKING:
+    from src.models.database import GlobalModel
+
 from src.services.cache.affinity_manager import (
     CacheAffinityManager,
     get_affinity_manager,
 )
+from src.services.cache.model_cache import ModelCacheService
 from src.services.health.monitor import health_monitor
 from src.services.provider.format import normalize_api_format
 from src.services.rate_limit.adaptive_reservation import (
@@ -259,9 +266,11 @@ class CacheAwareScheduler:
                 self._metrics["concurrency_denied"] += 1
                 continue
 
-            logger.debug(f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
-                         f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
-                         f"concurrency state[{snapshot.describe()}]")
+            logger.debug(
+                f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
+                f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
+                f"concurrency state[{snapshot.describe()}]"
+            )
 
             if key.cache_ttl_minutes > 0 and global_model_id:
                 ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
@@ -349,7 +358,9 @@ class CacheAwareScheduler:
logger.debug(f" -> 无并发管理器,直接通过") logger.debug(f" -> 无并发管理器,直接通过")
snapshot = ConcurrencySnapshot( snapshot = ConcurrencySnapshot(
endpoint_current=0, endpoint_current=0,
endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None, endpoint_limit=(
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
),
key_current=0, key_current=0,
key_limit=effective_key_limit, key_limit=effective_key_limit,
is_cached_user=is_cached_user, is_cached_user=is_cached_user,
@@ -484,10 +495,12 @@ class CacheAwareScheduler:
             user = None
 
         # Debug log
-        logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
-                     f"User={user.id[:8] if user else 'None'}..., "
-                     f"ApiKey.allowed_models={user_api_key.allowed_models}, "
-                     f"User.allowed_models={user.allowed_models if user else 'N/A'}")
+        logger.debug(
+            f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
+            f"User={user.id[:8] if user else 'None'}..., "
+            f"ApiKey.allowed_models={user_api_key.allowed_models}, "
+            f"User.allowed_models={user.allowed_models if user else 'N/A'}"
+        )
 
         def merge_restrictions(key_restriction, user_restriction):
             """Merge two restriction lists and return the effective set of restrictions"""
@@ -566,13 +579,8 @@ class CacheAwareScheduler:
         target_format = normalize_api_format(api_format)
 
-        # 0. Resolve model_name to GlobalModel (direct lookup; the user must use the canonical name)
-        from src.models.database import GlobalModel
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # 0. Resolve model_name to GlobalModel (supports direct and alias matching via ModelCacheService)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
 
         if not global_model:
             logger.warning(f"GlobalModel not found: {model_name}")
@@ -580,6 +588,8 @@ class CacheAwareScheduler:
         # Use GlobalModel.id as the model identifier for cache affinity so aliases and the canonical name hit the same cache entry
         global_model_id: str = str(global_model.id)
+        requested_model_name = model_name
+        resolved_model_name = str(global_model.name)
 
         # Get the merged access restrictions (ApiKey + User)
         restrictions = self._get_effective_restrictions(user_api_key)
@@ -591,14 +601,27 @@ class CacheAwareScheduler:
         # 0.1 Check whether the API format is allowed
         if allowed_api_formats:
             if target_format.value not in allowed_api_formats:
-                logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
-                             f"allowed formats: {allowed_api_formats}")
+                logger.debug(
+                    f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
+                    f"allowed formats: {allowed_api_formats}"
+                )
                 return [], global_model_id
 
         # 0.2 Check whether the model is allowed
         if allowed_models:
-            if model_name not in allowed_models:
-                logger.debug(f"User/API Key is not allowed to use model {model_name}, " f"allowed models: {allowed_models}")
+            if (
+                requested_model_name not in allowed_models
+                and resolved_model_name not in allowed_models
+            ):
+                resolved_note = (
+                    f" (resolved to {resolved_model_name})"
+                    if resolved_model_name != requested_model_name
+                    else ""
+                )
+                logger.debug(
+                    f"User/API Key is not allowed to use model {requested_model_name}{resolved_note}, "
+                    f"allowed models: {allowed_models}"
+                )
                 return [], global_model_id
 
         # 1. Query Providers
@@ -629,7 +652,8 @@ class CacheAwareScheduler:
             db=db,
             providers=providers,
             target_format=target_format,
-            model_name=model_name,
+            model_name=requested_model_name,
+            resolved_model_name=resolved_model_name,
             affinity_key=affinity_key,
             max_candidates=max_candidates,
             allowed_endpoints=allowed_endpoints,
@@ -644,8 +668,10 @@ class CacheAwareScheduler:
         self._metrics["total_candidates"] += len(candidates)
         self._metrics["last_candidate_count"] = len(candidates)
 
-        logger.debug(f"Pre-fetched {len(candidates)} available combinations "
-                     f"(api_format={target_format.value}, model={model_name})")
+        logger.debug(
+            f"Pre-fetched {len(candidates)} available combinations "
+            f"(api_format={target_format.value}, model={model_name})"
+        )
 
         # 4. Apply cache-affinity sorting (using global_model_id as the model identifier)
         if affinity_key and candidates:
@@ -708,33 +734,31 @@ class CacheAwareScheduler:
         - The capabilities a model supports are global and unrelated to any specific Key
         - If the model does not support a capability, every Key of that Provider should be skipped
 
+        Two matching modes are supported:
+        1. Direct match on GlobalModel.name
+        2. Alias match via ModelCacheService (global lookup)
+
         Args:
             db: Database session
             provider: Provider object
-            model_name: Model name (must be GlobalModel.name)
+            model_name: Model name (either GlobalModel.name or an alias)
             is_stream: Whether this is a streaming request; if True, streaming support is checked as well
             capability_requirements: Capability requirements (optional), used to check whether the model supports the required capabilities
 
         Returns:
             (is_supported, skip_reason, supported_capabilities) - whether supported, the skip reason, and the list of capabilities the model supports
         """
-        from src.models.database import GlobalModel
-
-        # Look up directly by GlobalModel.name
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # Resolve the model name via ModelCacheService (alias-aware)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
+
         if not global_model:
-            return False, "Model does not exist", None
+            # No match found at all
+            return False, "Model does not exist or the Provider has not configured this model", None
 
-        # Check model support
+        # With the GlobalModel resolved, check whether the current Provider supports it
         is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
             db, provider, global_model, model_name, is_stream, capability_requirements
         )
         return is_supported, skip_reason, caps
 
     async def _check_model_support_for_global_model(
@@ -761,7 +785,16 @@ class CacheAwareScheduler:
             (is_supported, skip_reason, supported_capabilities)
         """
-        # Ensure global_model is attached to the current Session
-        global_model = db.merge(global_model, load=False)
+        # Note: objects rebuilt from the cache are transient, so load=False must not be used.
+        # The default load=True lets SQLAlchemy handle transient objects correctly.
+        from sqlalchemy import inspect
+
+        insp = inspect(global_model)
+        if insp.transient or insp.detached:
+            # transient/detached object: use the default merge (queries the DB to check whether it exists)
+            global_model = db.merge(global_model)
+        else:
+            # persistent object: already attached to the session, no merge needed
+            pass
 
         # Get the list of capabilities the model supports
         model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
@@ -796,6 +829,7 @@ class CacheAwareScheduler:
         key: ProviderAPIKey,
         model_name: str,
         capability_requirements: Optional[Dict[str, bool]] = None,
+        resolved_model_name: Optional[str] = None,
     ) -> Tuple[bool, Optional[str]]:
         """
         Check the availability of an API Key
@@ -807,6 +841,7 @@ class CacheAwareScheduler:
             key: API Key object
             model_name: Model name
             capability_requirements: Capability requirements (optional)
+            resolved_model_name: The resolved GlobalModel.name (optional)
 
         Returns:
             (is_available, skip_reason)
@@ -818,7 +853,10 @@ class CacheAwareScheduler:
         # Model permission check: use the allowed_models whitelist
         # None = allow all models, [] = deny all models, ["a","b"] = only allow the listed models
-        if key.allowed_models is not None and model_name not in key.allowed_models:
+        if key.allowed_models is not None and (
+            model_name not in key.allowed_models
+            and (not resolved_model_name or resolved_model_name not in key.allowed_models)
+        ):
             allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(none)"
             suffix = "..." if len(key.allowed_models) > 3 else ""
             return False, f"Model permission mismatch (allowed: {allowed_preview}{suffix})"
@@ -843,6 +881,7 @@ class CacheAwareScheduler:
         target_format: APIFormat,
         model_name: str,
         affinity_key: Optional[str],
+        resolved_model_name: Optional[str] = None,
         max_candidates: Optional[int] = None,
         allowed_endpoints: Optional[set] = None,
         is_stream: bool = False,
@@ -855,8 +894,9 @@ class CacheAwareScheduler:
             db: Database session
             providers: List of Providers
             target_format: Target API format
-            model_name: Model name
+            model_name: Model name (as requested by the user; may be an alias)
             affinity_key: Affinity identifier (usually the API Key ID)
+            resolved_model_name: The resolved GlobalModel.name (used for Key.allowed_models checks)
             max_candidates: Maximum number of candidates
             allowed_endpoints: Set of allowed Endpoint IDs (None means unrestricted)
             is_stream: Whether this is a streaming request; if True, Providers without streaming support are filtered out
@@ -901,7 +941,10 @@ class CacheAwareScheduler:
         for key in keys:
             # Key-level capability check (model-level capability checks were completed above)
             is_available, skip_reason = self._check_key_availability(
-                key, model_name, capability_requirements
+                key,
+                model_name,
+                capability_requirements,
+                resolved_model_name=resolved_model_name,
             )
 
             candidate = ProviderCandidate(
@@ -970,11 +1013,13 @@ class CacheAwareScheduler:
                     candidate.is_cached = True
                     cached_candidates.append(candidate)
                     matched = True
-                    logger.debug(f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
-                                 f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
-                                 f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
-                                 f"provider_key=***{key.api_key[-4:]}, "
-                                 f"use_count={affinity.request_count}")
+                    logger.debug(
+                        f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
+                        f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
+                        f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
+                        f"provider_key=***{key.api_key[-4:]}, "
+                        f"use_count={affinity.request_count}"
+                    )
                 else:
                     candidate.is_cached = False
                     other_candidates.append(candidate)
@@ -1080,6 +1125,7 @@ class CacheAwareScheduler:
                     c.key.internal_priority if c.key else 999999,
                     c.key.id if c.key else "",
                 )
+
             result.extend(sorted(group, key=secondary_sort))
 
         return result
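For reference, a minimal sketch of the transient/detached merge guard used in the hunk above, run against a throwaway in-memory SQLite database. The GlobalModel class here is a stand-in for src.models.database.GlobalModel, not the project's real model; only the sqlalchemy.inspect() / Session.merge() usage mirrors the change.

from sqlalchemy import String, create_engine, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class GlobalModel(Base):  # hypothetical stand-in for src.models.database.GlobalModel
    __tablename__ = "global_models"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(String)


def attach(db: Session, obj: GlobalModel) -> GlobalModel:
    """Attach obj to the session, mirroring the transient/detached check in the diff."""
    state = inspect(obj)
    if state.transient or state.detached:
        # Default merge: SQLAlchemy queries the DB to decide whether the row exists.
        return db.merge(obj)
    return obj  # already persistent in this session, no merge needed


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add(GlobalModel(id="m1", name="gpt-4o"))
        db.commit()
        # An object rebuilt from a cache is transient: same primary key, no session.
        cached = GlobalModel(id="m1", name="gpt-4o")
        attached = attach(db, cached)
        print(inspect(attached).persistent)  # True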