Mirror of https://github.com/fawney19/Aether.git (synced 2026-01-03 00:02:28 +08:00)
refactor(scheduler): integrate model alias resolution
- Use ModelCacheService.resolve_global_model_by_name_or_alias() for model lookups
- Support both the requested model name and the resolved GlobalModel name in validation
- Track resolved_model_name for proper allow_models checking
- Improve model availability checks to handle alias resolution
- Fix transient/detached object handling in the global_model merge
- Add more descriptive debug logs for alias resolution mismatches
- Clean up code formatting (line length, import organization)
124 changes: src/services/cache/aware_scheduler.py (vendored)
@@ -28,15 +28,17 @@
     - Invalidate cache affinity to avoid repeatedly selecting failed resources
 """
 
+from __future__ import annotations
+
 import time
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
-from src.core.logger import logger
 from sqlalchemy.orm import Session, selectinload
 
 from src.core.enums import APIFormat
 from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
+from src.core.logger import logger
 from src.models.database import (
     ApiKey,
     Model,
@@ -44,10 +46,15 @@ from src.models.database import (
     ProviderAPIKey,
     ProviderEndpoint,
 )
+
+if TYPE_CHECKING:
+    from src.models.database import GlobalModel
+
 from src.services.cache.affinity_manager import (
     CacheAffinityManager,
     get_affinity_manager,
 )
+from src.services.cache.model_cache import ModelCacheService
 from src.services.health.monitor import health_monitor
 from src.services.provider.format import normalize_api_format
 from src.services.rate_limit.adaptive_reservation import (
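Note: the TYPE_CHECKING guard above keeps the GlobalModel import out of the runtime import graph (typically to avoid an import cycle), while the new `from __future__ import annotations` makes annotations lazy strings so the name remains usable in signatures. A minimal self-contained sketch of the pattern, not repository code:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never imported at runtime.
    from src.models.database import GlobalModel


def model_label(model: GlobalModel) -> str:
    # With `from __future__ import annotations`, the annotation above stays a
    # plain string at runtime, so no real import of GlobalModel is required.
    return f"GlobalModel<{model}>"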
@@ -259,9 +266,11 @@ class CacheAwareScheduler:
                     self._metrics["concurrency_denied"] += 1
                     continue
 
-                logger.debug(f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
-                             f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
-                             f"concurrency[{snapshot.describe()}]")
+                logger.debug(
+                    f" └─ Selected Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
+                    f"Key=***{key.api_key[-4:]}, cache_hit={is_cached_user}, "
+                    f"concurrency[{snapshot.describe()}]"
+                )
 
                 if key.cache_ttl_minutes > 0 and global_model_id:
                     ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
@@ -349,7 +358,9 @@
             logger.debug(f" -> no concurrency manager, passing through")
             snapshot = ConcurrencySnapshot(
                 endpoint_current=0,
-                endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
+                endpoint_limit=(
+                    int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
+                ),
                 key_current=0,
                 key_limit=effective_key_limit,
                 is_cached_user=is_cached_user,
@@ -484,10 +495,12 @@
             user = None
 
         # Debug logging
-        logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
+        logger.debug(
+            f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
             f"User={user.id[:8] if user else 'None'}..., "
             f"ApiKey.allowed_models={user_api_key.allowed_models}, "
-            f"User.allowed_models={user.allowed_models if user else 'N/A'}")
+            f"User.allowed_models={user.allowed_models if user else 'N/A'}"
+        )
 
         def merge_restrictions(key_restriction, user_restriction):
             """Merge two restriction lists and return the effective restriction set."""
@@ -566,13 +579,8 @@
 
         target_format = normalize_api_format(api_format)
 
-        # 0. Resolve model_name to a GlobalModel (direct lookup; users must use the canonical name)
-        from src.models.database import GlobalModel
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # 0. Resolve model_name to a GlobalModel (direct and alias matching via ModelCacheService)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
 
         if not global_model:
             logger.warning(f"GlobalModel not found: {model_name}")
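ModelCacheService.resolve_global_model_by_name_or_alias() itself is not part of this diff. As a rough, uncached sketch, the resolver presumably tries the canonical GlobalModel.name first and then falls back to an alias lookup; the `aliases` column below is an assumption, not confirmed by the diff:

from typing import Optional

from sqlalchemy.orm import Session


async def resolve_global_model_by_name_or_alias(db: Session, model_name: str) -> Optional["GlobalModel"]:
    """Hypothetical resolver: canonical name first, then alias fallback."""
    from src.models.database import GlobalModel  # assumed model class

    # 1. Exact match on the canonical name.
    model = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
        .first()
    )
    if model is not None:
        return model

    # 2. Fallback: treat model_name as an alias (assumes a list-valued
    #    GlobalModel.aliases column; the real lookup is likely cached).
    for candidate in db.query(GlobalModel).filter(GlobalModel.is_active == True).all():
        if model_name in (candidate.aliases or []):
            return candidate
    return None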
@@ -580,6 +588,8 @@
 
         # Use GlobalModel.id as the cache-affinity model identifier so that aliases and the canonical name hit the same cache
         global_model_id: str = str(global_model.id)
+        requested_model_name = model_name
+        resolved_model_name = str(global_model.name)
 
         # Get the merged access restrictions (ApiKey + User)
         restrictions = self._get_effective_restrictions(user_api_key)
@@ -591,14 +601,27 @@
         # 0.1 Check whether the API format is allowed
         if allowed_api_formats:
             if target_format.value not in allowed_api_formats:
-                logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
-                             f"allowed formats: {allowed_api_formats}")
+                logger.debug(
+                    f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... is not allowed to use API format {target_format.value}, "
+                    f"allowed formats: {allowed_api_formats}"
+                )
                 return [], global_model_id
 
         # 0.2 Check whether the model is allowed
         if allowed_models:
-            if model_name not in allowed_models:
-                logger.debug(f"User/API Key is not allowed to use model {model_name}, " f"allowed models: {allowed_models}")
+            if (
+                requested_model_name not in allowed_models
+                and resolved_model_name not in allowed_models
+            ):
+                resolved_note = (
+                    f" (resolved to {resolved_model_name})"
+                    if resolved_model_name != requested_model_name
+                    else ""
+                )
+                logger.debug(
+                    f"User/API Key is not allowed to use model {requested_model_name}{resolved_note}, "
+                    f"allowed models: {allowed_models}"
+                )
                 return [], global_model_id
 
         # 1. Query providers
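The effect of the widened condition: a request now passes the allow-list if either the name the user sent or the resolved canonical name is listed. Reduced to a standalone predicate (the model names below are made up for illustration):

def passes_allow_list(requested: str, resolved: str, allowed: list) -> bool:
    # Mirrors the condition above: deny only when *both* names are absent.
    return requested in allowed or resolved in allowed


# An alias passes as long as it resolves to an allowed canonical name.
assert passes_allow_list("gpt4-alias", "gpt-4o", ["gpt-4o"])
assert passes_allow_list("gpt-4o", "gpt-4o", ["gpt-4o"])
assert not passes_allow_list("other-model", "other-model", ["gpt-4o"])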
@@ -629,7 +652,8 @@ class CacheAwareScheduler:
             db=db,
             providers=providers,
             target_format=target_format,
-            model_name=model_name,
+            model_name=requested_model_name,
+            resolved_model_name=resolved_model_name,
             affinity_key=affinity_key,
             max_candidates=max_candidates,
             allowed_endpoints=allowed_endpoints,
@@ -644,8 +668,10 @@
         self._metrics["total_candidates"] += len(candidates)
         self._metrics["last_candidate_count"] = len(candidates)
 
-        logger.debug(f"Prefetched {len(candidates)} available combinations "
-                     f"(api_format={target_format.value}, model={model_name})")
+        logger.debug(
+            f"Prefetched {len(candidates)} available combinations "
+            f"(api_format={target_format.value}, model={model_name})"
+        )
 
         # 4. Apply cache-affinity ordering (using global_model_id as the model identifier)
         if affinity_key and candidates:
@@ -708,33 +734,31 @@ class CacheAwareScheduler:
         - Model capabilities are global and independent of any specific key
         - If the model does not support a required capability, every key of that provider should be skipped
 
+        Two matching modes are supported:
+        1. Direct match on GlobalModel.name
+        2. Alias match via ModelCacheService (global lookup)
 
         Args:
             db: Database session
             provider: Provider object
-            model_name: Model name (must be GlobalModel.name)
+            model_name: Model name (either GlobalModel.name or an alias)
             is_stream: Whether this is a streaming request; if True, streaming support is checked as well
             capability_requirements: Capability requirements (optional), used to check whether the model supports the required capabilities
 
         Returns:
             (is_supported, skip_reason, supported_capabilities) - whether supported, the skip reason, and the capabilities the model supports
         """
-        from src.models.database import GlobalModel
-
-        # Look up directly by GlobalModel.name
-        global_model = (
-            db.query(GlobalModel)
-            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
-            .first()
-        )
+        # Resolve the model name via ModelCacheService (supports aliases)
+        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
 
         if not global_model:
-            return False, "model not found", None
+            # No match found at all
+            return False, "model not found or not configured for this provider", None
 
-        # Check model support
+        # Found the GlobalModel; check whether the current provider supports it
         is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
             db, provider, global_model, model_name, is_stream, capability_requirements
         )
 
         return is_supported, skip_reason, caps
 
     async def _check_model_support_for_global_model(
@@ -761,7 +785,16 @@ class CacheAwareScheduler:
             (is_supported, skip_reason, supported_capabilities)
         """
         # Ensure global_model is attached to the current Session
-        global_model = db.merge(global_model, load=False)
+        # Note: objects rebuilt from the cache are transient, so load=False must not be used
+        # Use load=True (the default) so that SQLAlchemy handles transient objects correctly
+        from sqlalchemy import inspect
+        insp = inspect(global_model)
+        if insp.transient or insp.detached:
+            # transient/detached object: use the default merge (queries the DB to check existence)
+            global_model = db.merge(global_model)
+        else:
+            # persistent object: already attached to the session, no merge needed
+            pass
 
         # Get the list of capabilities the model supports
         model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
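Background for the transient/detached branch: SQLAlchemy's inspect() exposes the instance state, and merge() with the default load=True safely attaches objects that were never in a session (e.g. rebuilt from a cache), whereas merge(obj, load=False) raises for transient instances. A self-contained demonstration against an in-memory SQLite database, not the repository's models:

from sqlalchemy import Column, String, create_engine, inspect
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Thing(Base):
    __tablename__ = "things"
    id = Column(String, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    obj = Thing(id="t1")  # constructed by hand, as if rebuilt from a cache
    assert inspect(obj).transient  # never attached to any session

    # The default merge (load=True) queries the DB and attaches a copy;
    # session.merge(obj, load=False) would raise for this transient object.
    merged = session.merge(obj)
    session.flush()
    assert inspect(merged).persistent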
@@ -796,6 +829,7 @@ class CacheAwareScheduler:
         key: ProviderAPIKey,
         model_name: str,
         capability_requirements: Optional[Dict[str, bool]] = None,
+        resolved_model_name: Optional[str] = None,
     ) -> Tuple[bool, Optional[str]]:
         """
         Check the availability of an API key
@@ -807,6 +841,7 @@
             key: API key object
             model_name: Model name
             capability_requirements: Capability requirements (optional)
+            resolved_model_name: Resolved GlobalModel.name (optional)
 
         Returns:
             (is_available, skip_reason)
@@ -818,7 +853,10 @@
 
         # Model permission check: use the allowed_models whitelist
         # None = allow all models, [] = deny all models, ["a","b"] = allow only the listed models
-        if key.allowed_models is not None and model_name not in key.allowed_models:
+        if key.allowed_models is not None and (
+            model_name not in key.allowed_models
+            and (not resolved_model_name or resolved_model_name not in key.allowed_models)
+        ):
             allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(none)"
             suffix = "..." if len(key.allowed_models) > 3 else ""
             return False, f"model permission mismatch (allowed: {allowed_preview}{suffix})"
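The whitelist semantics spelled out by the comments, plus the new resolved-name fallback, reduce to roughly the following (illustrative only; names are made up):

from typing import List, Optional


def key_allows_model(
    allowed_models: Optional[List[str]],
    model_name: str,
    resolved_model_name: Optional[str] = None,
) -> bool:
    if allowed_models is None:
        return True  # None = no restriction at all
    if model_name in allowed_models:
        return True  # requested name (possibly an alias) is whitelisted
    # Fall back to the resolved canonical name, when one is available.
    return resolved_model_name is not None and resolved_model_name in allowed_models


assert key_allows_model(None, "any-model")
assert not key_allows_model([], "gpt-4o")  # empty list denies everything
assert key_allows_model(["gpt-4o"], "gpt4-alias", resolved_model_name="gpt-4o")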
@@ -843,6 +881,7 @@
         target_format: APIFormat,
         model_name: str,
         affinity_key: Optional[str],
+        resolved_model_name: Optional[str] = None,
         max_candidates: Optional[int] = None,
         allowed_endpoints: Optional[set] = None,
         is_stream: bool = False,
@@ -855,8 +894,9 @@
             db: Database session
             providers: Provider list
             target_format: Target API format
-            model_name: Model name
+            model_name: Model name (as requested by the user; may be an alias)
             affinity_key: Affinity identifier (usually the API key ID)
+            resolved_model_name: Resolved GlobalModel.name (used for Key.allowed_models validation)
             max_candidates: Maximum number of candidates
             allowed_endpoints: Set of allowed endpoint IDs (None means unrestricted)
             is_stream: Whether this is a streaming request; if True, providers without streaming support are filtered out
@@ -901,7 +941,10 @@
                 for key in keys:
                     # Key-level capability check (the model-level capability check was done above)
                     is_available, skip_reason = self._check_key_availability(
-                        key, model_name, capability_requirements
+                        key,
+                        model_name,
+                        capability_requirements,
+                        resolved_model_name=resolved_model_name,
                     )
 
                     candidate = ProviderCandidate(
@@ -970,11 +1013,13 @@
                         candidate.is_cached = True
                         cached_candidates.append(candidate)
                         matched = True
-                        logger.debug(f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
+                        logger.debug(
+                            f"Cache affinity detected: affinity_key={affinity_key[:8]}..., "
                             f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
                             f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
                             f"provider_key=***{key.api_key[-4:]}, "
-                            f"use_count={affinity.request_count}")
+                            f"use_count={affinity.request_count}"
+                        )
                     else:
                         candidate.is_cached = False
                         other_candidates.append(candidate)
@@ -1080,6 +1125,7 @@
                 c.key.internal_priority if c.key else 999999,
                 c.key.id if c.key else "",
             )
+
             result.extend(sorted(group, key=secondary_sort))
 
         return result