refactor(scheduler): integrate model alias resolution

- Use ModelCacheService.resolve_global_model_by_name_or_alias() for model lookups
- Support both requested model name and resolved GlobalModel name in validation
- Track resolved_model_name for proper allowed_models checking (see the sketch below)
- Improve model availability checks to handle alias resolution
- Fix transient/detached object handling in global_model merge
- Add more descriptive debug logs for alias resolution mismatches
- Clean up code formatting (line length, imports organization)
Author: fawney19
Date: 2025-12-15 18:13:35 +08:00
Parent: 51b85915d2
Commit: 8f0a0cbdb1

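To make the intent of the change concrete, here is a minimal sketch of the new validation flow: resolve the requested name through ModelCacheService, then accept the request if either the name the caller sent (possibly an alias) or the resolved canonical GlobalModel.name is whitelisted. is_model_allowed is an illustrative helper, not a function in this repository; the real logic lives inside CacheAwareScheduler as shown in the diff below.

# Illustrative sketch only -- mirrors the new resolve-then-validate flow.
from typing import List, Optional

from sqlalchemy.orm import Session

from src.services.cache.model_cache import ModelCacheService


async def is_model_allowed(
    db: Session, model_name: str, allowed_models: Optional[List[str]]
) -> bool:
    # Resolve the requested name (canonical name or alias) to a GlobalModel.
    global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
    if not global_model:
        return False  # unknown model or alias
    if not allowed_models:
        return True  # no whitelist configured for this ApiKey/User
    resolved_name = str(global_model.name)
    # Accept either the requested name or the resolved canonical name.
    return model_name in allowed_models or resolved_name in allowed_models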

@@ -28,15 +28,17 @@
- Invalidate cache affinity to avoid repeatedly selecting failed resources
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from src.core.logger import logger
from sqlalchemy.orm import Session, selectinload
from src.core.enums import APIFormat
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
from src.core.logger import logger
from src.models.database import (
ApiKey,
Model,
@@ -44,10 +46,15 @@ from src.models.database import (
ProviderAPIKey,
ProviderEndpoint,
)
if TYPE_CHECKING:
from src.models.database import GlobalModel
from src.services.cache.affinity_manager import (
CacheAffinityManager,
get_affinity_manager,
)
from src.services.cache.model_cache import ModelCacheService
from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_reservation import (
@@ -259,9 +266,11 @@ class CacheAwareScheduler:
self._metrics["concurrency_denied"] += 1
continue
logger.debug(f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
logger.debug(
f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, 缓存命中={is_cached_user}, "
f"并发状态[{snapshot.describe()}]")
f"并发状态[{snapshot.describe()}]"
)
if key.cache_ttl_minutes > 0 and global_model_id:
ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
@@ -349,7 +358,9 @@ class CacheAwareScheduler:
logger.debug(f" -> 无并发管理器,直接通过")
snapshot = ConcurrencySnapshot(
endpoint_current=0,
endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
endpoint_limit=(
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
),
key_current=0,
key_limit=effective_key_limit,
is_cached_user=is_cached_user,
@@ -484,10 +495,12 @@ class CacheAwareScheduler:
user = None
# Debug logging
logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
logger.debug(
f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
f"User={user.id[:8] if user else 'None'}..., "
f"ApiKey.allowed_models={user_api_key.allowed_models}, "
f"User.allowed_models={user.allowed_models if user else 'N/A'}")
f"User.allowed_models={user.allowed_models if user else 'N/A'}"
)
def merge_restrictions(key_restriction, user_restriction):
"""合并两个限制列表,返回有效的限制集合"""
@@ -566,13 +579,8 @@ class CacheAwareScheduler:
target_format = normalize_api_format(api_format)
# 0. Resolve model_name to a GlobalModel (direct lookup; users must use the canonical name)
from src.models.database import GlobalModel
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
# 0. Resolve model_name to a GlobalModel via ModelCacheService (supports direct match and alias match)
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
if not global_model:
logger.warning(f"GlobalModel not found: {model_name}")
@@ -580,6 +588,8 @@ class CacheAwareScheduler:
# Use GlobalModel.id as the model identifier for cache affinity so that both aliases and the canonical name hit the same cache
global_model_id: str = str(global_model.id)
requested_model_name = model_name
resolved_model_name = str(global_model.name)
# Get the merged access restrictions (ApiKey + User)
restrictions = self._get_effective_restrictions(user_api_key)
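For orientation, a hypothetical sketch of the resolution order described by the docstrings in this commit: canonical GlobalModel.name first, then a global alias lookup. The aliases column and the queries are assumptions; the real lookup lives in ModelCacheService and is additionally cached, which is not shown.

# Hypothetical sketch of resolve_global_model_by_name_or_alias.
# Assumptions: GlobalModel carries an `aliases` list column; caching omitted.
from typing import Optional

from sqlalchemy.orm import Session

from src.models.database import GlobalModel


def resolve_global_model_sketch(db: Session, model_name: str) -> Optional[GlobalModel]:
    # 1. Direct match on the canonical name (mirrors the removed inline query).
    model = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)  # noqa: E712
        .first()
    )
    if model:
        return model
    # 2. Alias match: global lookup across all active models (assumed schema).
    for candidate in db.query(GlobalModel).filter(GlobalModel.is_active == True):  # noqa: E712
        if model_name in (candidate.aliases or []):
            return candidate
    return None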
@@ -591,14 +601,27 @@ class CacheAwareScheduler:
# 0.1 Check whether the API format is allowed
if allowed_api_formats:
if target_format.value not in allowed_api_formats:
logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
f"允许的格式: {allowed_api_formats}")
logger.debug(
f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
f"允许的格式: {allowed_api_formats}"
)
return [], global_model_id
# 0.2 Check whether the model is allowed
if allowed_models:
if model_name not in allowed_models:
logger.debug(f"用户/API Key 不允许使用模型 {model_name}, " f"允许的模型: {allowed_models}")
if (
requested_model_name not in allowed_models
and resolved_model_name not in allowed_models
):
resolved_note = (
f" (解析为 {resolved_model_name})"
if resolved_model_name != requested_model_name
else ""
)
logger.debug(
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
f"允许的模型: {allowed_models}"
)
return [], global_model_id
# 1. Query Providers
@@ -629,7 +652,8 @@ class CacheAwareScheduler:
db=db,
providers=providers,
target_format=target_format,
model_name=model_name,
model_name=requested_model_name,
resolved_model_name=resolved_model_name,
affinity_key=affinity_key,
max_candidates=max_candidates,
allowed_endpoints=allowed_endpoints,
@@ -644,8 +668,10 @@ class CacheAwareScheduler:
self._metrics["total_candidates"] += len(candidates)
self._metrics["last_candidate_count"] = len(candidates)
logger.debug(f"预先获取到 {len(candidates)} 个可用组合 "
f"(api_format={target_format.value}, model={model_name})")
logger.debug(
f"预先获取到 {len(candidates)} 个可用组合 "
f"(api_format={target_format.value}, model={model_name})"
)
# 4. Apply cache-affinity ordering (using global_model_id as the model identifier)
if affinity_key and candidates:
@@ -708,33 +734,31 @@ class CacheAwareScheduler:
- The capabilities a model supports are global and independent of any specific Key
- If the model does not support a required capability, every Key of the Provider should be skipped
Two matching modes are supported:
1. Direct match on GlobalModel.name
2. Alias match via ModelCacheService (global lookup)
Args:
db: database session
provider: the Provider object
model_name: model name (must be GlobalModel.name)
model_name: model name (may be GlobalModel.name or an alias)
is_stream: whether this is a streaming request; if True, streaming support is checked as well
capability_requirements: capability requirements (optional), used to check whether the model supports them
Returns:
(is_supported, skip_reason, supported_capabilities) - whether supported, the skip reason, and the model's supported capability list
"""
from src.models.database import GlobalModel
# Look up directly by GlobalModel.name
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
# Resolve the model name via ModelCacheService (alias-aware)
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
if not global_model:
return False, "模型不存在", None
# 完全未找到匹配
return False, "模型不存在或 Provider 未配置此模型", None
# 检查模型支持
# 找到 GlobalModel 后,检查当前 Provider 是否支持
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db, provider, global_model, model_name, is_stream, capability_requirements
)
return is_supported, skip_reason, caps
async def _check_model_support_for_global_model(
@@ -761,7 +785,16 @@ class CacheAwareScheduler:
(is_supported, skip_reason, supported_capabilities)
"""
# Ensure global_model is attached to the current Session
global_model = db.merge(global_model, load=False)
# Note: objects rebuilt from the cache are in the transient state and must not use load=False
# Use load=True (the default) so that SQLAlchemy handles transient objects correctly
from sqlalchemy import inspect
insp = inspect(global_model)
if insp.transient or insp.detached:
# transient/detached object: use the default merge (which queries the DB to check for existence)
global_model = db.merge(global_model)
else:
# persistent object: already attached to the session, no merge needed
pass
# Get the list of capabilities the model supports
model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
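The state handling above, isolated as a small self-contained sketch of the SQLAlchemy pattern: inspect() exposes the instance state, and transient objects (for example, ones rebuilt from a cache payload) or detached ones are merged back with the default load=True. Only the helper name is invented; the API calls are standard SQLAlchemy.

# Sketch of the attach-or-reuse pattern applied above.
from sqlalchemy import inspect
from sqlalchemy.orm import Session


def ensure_attached(db: Session, instance):
    state = inspect(instance)
    if state.transient or state.detached:
        # Default merge (load=True): SQLAlchemy queries the database to decide
        # whether the object corresponds to an existing row.
        return db.merge(instance)
    # Already persistent in this session; nothing to do.
    return instance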
@@ -796,6 +829,7 @@ class CacheAwareScheduler:
key: ProviderAPIKey,
model_name: str,
capability_requirements: Optional[Dict[str, bool]] = None,
resolved_model_name: Optional[str] = None,
) -> Tuple[bool, Optional[str]]:
"""
Check the availability of an API Key
@@ -807,6 +841,7 @@ class CacheAwareScheduler:
key: the API Key object
model_name: model name
capability_requirements: capability requirements (optional)
resolved_model_name: the resolved GlobalModel.name (optional)
Returns:
(is_available, skip_reason)
@@ -818,7 +853,10 @@ class CacheAwareScheduler:
# Model permission check: use the allowed_models whitelist
# None = allow all models, [] = deny all models, ["a","b"] = allow only the listed models
if key.allowed_models is not None and model_name not in key.allowed_models:
if key.allowed_models is not None and (
model_name not in key.allowed_models
and (not resolved_model_name or resolved_model_name not in key.allowed_models)
):
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)"
suffix = "..." if len(key.allowed_models) > 3 else ""
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
@@ -843,6 +881,7 @@ class CacheAwareScheduler:
target_format: APIFormat,
model_name: str,
affinity_key: Optional[str],
resolved_model_name: Optional[str] = None,
max_candidates: Optional[int] = None,
allowed_endpoints: Optional[set] = None,
is_stream: bool = False,
@@ -855,8 +894,9 @@ class CacheAwareScheduler:
db: database session
providers: list of Providers
target_format: target API format
model_name: model name
model_name: model name (as requested by the user; may be an alias)
affinity_key: affinity identifier (usually the API Key ID)
resolved_model_name: the resolved GlobalModel.name (used for Key.allowed_models validation)
max_candidates: maximum number of candidates
allowed_endpoints: set of allowed Endpoint IDs (None means no restriction)
is_stream: whether this is a streaming request; if True, Providers that do not support streaming are filtered out
@@ -901,7 +941,10 @@ class CacheAwareScheduler:
for key in keys:
# Key-level capability check (model-level capability checks were done above)
is_available, skip_reason = self._check_key_availability(
key, model_name, capability_requirements
key,
model_name,
capability_requirements,
resolved_model_name=resolved_model_name,
)
candidate = ProviderCandidate(
@@ -970,11 +1013,13 @@ class CacheAwareScheduler:
candidate.is_cached = True
cached_candidates.append(candidate)
matched = True
logger.debug(f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
logger.debug(
f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
f"provider_key=***{key.api_key[-4:]}, "
f"使用次数={affinity.request_count}")
f"使用次数={affinity.request_count}"
)
else:
candidate.is_cached = False
other_candidates.append(candidate)
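The net effect of the affinity loop above, reduced to its essentials: candidates whose (api_format, global_model_id, provider, endpoint, key) combination matches a recorded affinity are tried before all others, preserving relative order within each group. Illustrative only:

def order_by_affinity(candidates):
    # Stable partition: cached candidates first, everything else after.
    cached = [c for c in candidates if c.is_cached]
    others = [c for c in candidates if not c.is_cached]
    return cached + others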
@@ -1080,6 +1125,7 @@ class CacheAwareScheduler:
c.key.internal_priority if c.key else 999999,
c.key.id if c.key else "",
)
result.extend(sorted(group, key=secondary_sort))
return result