mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor(backend): optimize cache system and model/provider services
This commit is contained in:
79
src/services/cache/aware_scheduler.py
vendored
79
src/services/cache/aware_scheduler.py
vendored
@@ -36,7 +36,7 @@ from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.exceptions import ProviderNotAvailableException
|
||||
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Model,
|
||||
@@ -227,19 +227,6 @@ class CacheAwareScheduler:
|
||||
if provider_offset == 0:
|
||||
# 没有找到任何候选,提供友好的错误提示
|
||||
error_msg = f"模型 '{model_name}' 不可用"
|
||||
|
||||
# 查找相似模型
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
resolver = get_model_mapping_resolver()
|
||||
similar_models = resolver.find_similar_models(db, model_name, limit=3)
|
||||
|
||||
if similar_models:
|
||||
suggestions = [
|
||||
f"{name} (相似度: {score:.0%})" for name, score in similar_models
|
||||
]
|
||||
error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
|
||||
|
||||
raise ProviderNotAvailableException(error_msg)
|
||||
break
|
||||
|
||||
@@ -579,13 +566,20 @@ class CacheAwareScheduler:
|
||||
|
||||
target_format = normalize_api_format(api_format)
|
||||
|
||||
# 0. 解析 model_name 到 GlobalModel(用于缓存亲和性的规范化标识)
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
# 0. 解析 model_name 到 GlobalModel(直接查找,用户必须使用标准名称)
|
||||
from src.models.database import GlobalModel
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
logger.warning(f"GlobalModel not found: {model_name}")
|
||||
raise ModelNotSupportedException(model=model_name)
|
||||
|
||||
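# Use GlobalModel.id as the model identifier for cache affinity; model_name must already be a canonical GlobalModel.name (aliases are no longer resolved here)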
|
||||
global_model_id: str = str(global_model.id) if global_model else model_name
|
||||
global_model_id: str = str(global_model.id)
|
||||
|
||||
# 获取合并后的访问限制(ApiKey + User)
|
||||
restrictions = self._get_effective_restrictions(user_api_key)
|
||||
@@ -685,7 +679,6 @@ class CacheAwareScheduler:
|
||||
db.query(Provider)
|
||||
.options(
|
||||
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
|
||||
selectinload(Provider.model_mappings),
|
||||
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
|
||||
selectinload(Provider.models).selectinload(Model.global_model),
|
||||
)
|
||||
@@ -715,63 +708,33 @@ class CacheAwareScheduler:
|
||||
- 模型支持的能力是全局的,与具体的 Key 无关
|
||||
- 如果模型不支持某能力,整个 Provider 的所有 Key 都应该被跳过
|
||||
|
||||
映射回退逻辑:
|
||||
- 如果存在模型映射(mapping),先尝试映射后的模型
|
||||
- 如果映射后的模型因能力不满足而失败,回退尝试原模型
|
||||
- 其他失败原因(如模型不存在、Provider 未实现等)不触发回退
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: Provider 对象
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(必须是 GlobalModel.name)
|
||||
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
|
||||
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
|
||||
|
||||
Returns:
|
||||
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
|
||||
# 获取映射后的模型,同时检查是否发生了映射
|
||||
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
|
||||
db, model_name, str(provider.id)
|
||||
# 直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return False, "模型不存在", None
|
||||
|
||||
# 尝试检查映射后的模型
|
||||
# 检查模型支持
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db, provider, global_model, model_name, is_stream, capability_requirements
|
||||
)
|
||||
|
||||
# 如果映射后的模型因能力不满足而失败,且存在映射,则回退尝试原模型
|
||||
if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足,"
|
||||
f"回退尝试原模型 {model_name}"
|
||||
)
|
||||
|
||||
# 获取原模型(不应用映射)
|
||||
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
|
||||
|
||||
if original_global_model and original_global_model.id != global_model.id:
|
||||
# 尝试原模型
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db,
|
||||
provider,
|
||||
original_global_model,
|
||||
model_name,
|
||||
is_stream,
|
||||
capability_requirements,
|
||||
)
|
||||
if is_supported:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
|
||||
)
|
||||
|
||||
return is_supported, skip_reason, caps
|
||||
|
||||
async def _check_model_support_for_global_model(
|
||||
|
||||
3
src/services/cache/backend.py
vendored
3
src/services/cache/backend.py
vendored
@@ -6,8 +6,7 @@
|
||||
2. RedisCache: Redis 缓存(分布式)
|
||||
|
||||
使用场景:
|
||||
- ModelMappingResolver: 模型映射与别名解析缓存
|
||||
- ModelMapper: 模型映射缓存
|
||||
- ModelCacheService: 模型解析缓存
|
||||
- 其他需要缓存的服务
|
||||
"""
|
||||
|
||||
|
||||
42
src/services/cache/invalidation.py
vendored
42
src/services/cache/invalidation.py
vendored
@@ -3,18 +3,14 @@
|
||||
|
||||
统一管理各种缓存的失效逻辑,支持:
|
||||
1. GlobalModel 变更时失效相关缓存
|
||||
2. ModelMapping 变更时失效别名/降级缓存
|
||||
3. Model 变更时失效模型映射缓存
|
||||
4. 支持同步和异步缓存后端
|
||||
2. Model 变更时失效模型映射缓存
|
||||
3. 支持同步和异步缓存后端
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class CacheInvalidationService:
|
||||
"""
|
||||
@@ -25,14 +21,8 @@ class CacheInvalidationService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化缓存失效服务"""
|
||||
self._mapping_resolver = None
|
||||
self._model_mappers = [] # 可能有多个 ModelMapperMiddleware 实例
|
||||
|
||||
def set_mapping_resolver(self, mapping_resolver):
|
||||
"""设置模型映射解析器实例"""
|
||||
self._mapping_resolver = mapping_resolver
|
||||
logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")
|
||||
|
||||
def register_model_mapper(self, model_mapper):
|
||||
"""注册 ModelMapper 实例"""
|
||||
if model_mapper not in self._model_mappers:
|
||||
@@ -48,37 +38,12 @@ class CacheInvalidationService:
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
|
||||
|
||||
# 异步失效模型解析器中的缓存
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache())
|
||||
|
||||
# 失效所有 ModelMapper 中与此模型相关的缓存
|
||||
for mapper in self._model_mappers:
|
||||
# 清空所有缓存(因为不知道哪些 provider 使用了这个模型)
|
||||
mapper.clear_cache()
|
||||
logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存")
|
||||
|
||||
def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
|
||||
"""
|
||||
ModelMapping 变更时的缓存失效
|
||||
|
||||
Args:
|
||||
source_model: 变更的源模型名
|
||||
provider_id: 相关 Provider(None 表示全局)
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(
|
||||
self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
|
||||
)
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
if provider_id:
|
||||
mapper.refresh_cache(provider_id)
|
||||
else:
|
||||
mapper.clear_cache()
|
||||
|
||||
def on_model_changed(self, provider_id: str, global_model_id: str):
|
||||
"""
|
||||
Model 变更时的缓存失效
|
||||
@@ -98,9 +63,6 @@ class CacheInvalidationService:
|
||||
"""清空所有缓存"""
|
||||
logger.info("[CacheInvalidation] 清空所有缓存")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.clear_cache())
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
mapper.clear_cache()
|
||||
|
||||
|
||||
92
src/services/cache/model_cache.py
vendored
92
src/services/cache/model_cache.py
vendored
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型映射和别名查询
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
@@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
|
||||
from src.models.database import GlobalModel, Model
|
||||
|
||||
|
||||
class ModelCacheService:
|
||||
@@ -158,56 +157,6 @@ class ModelCacheService:
|
||||
|
||||
return global_model
|
||||
|
||||
@staticmethod
|
||||
async def resolve_alias(
|
||||
db: Session, source_model: str, provider_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
解析模型别名(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 源模型名称或别名
|
||||
provider_id: Provider ID(可选,用于 Provider 特定别名)
|
||||
|
||||
Returns:
|
||||
目标 GlobalModel ID 或 None
|
||||
"""
|
||||
# 构造缓存键
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_result = await CacheService.get(cache_key)
|
||||
if cached_result:
|
||||
logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
|
||||
return cached_result
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
|
||||
|
||||
if provider_id:
|
||||
# Provider 特定别名优先
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
else:
|
||||
# 全局别名
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
mapping = query.first()
|
||||
|
||||
# 3. 写入缓存
|
||||
target_global_model_id = mapping.target_global_model_id if mapping else None
|
||||
await CacheService.set(
|
||||
cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"别名已缓存: {source_model} → {target_global_model_id}")
|
||||
|
||||
return target_global_model_id
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_model_cache(
|
||||
model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
|
||||
@@ -237,17 +186,6 @@ class ModelCacheService:
|
||||
await CacheService.delete(f"global_model:name:{name}")
|
||||
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
|
||||
"""清除别名缓存"""
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
|
||||
await CacheService.delete(cache_key)
|
||||
logger.debug(f"别名缓存已清除: {source_model}")
|
||||
|
||||
@staticmethod
|
||||
def _model_to_dict(model: Model) -> dict:
|
||||
"""将 Model 对象转换为字典"""
|
||||
@@ -256,6 +194,7 @@ class ModelCacheService:
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
|
||||
"is_active": model.is_active,
|
||||
"is_available": model.is_available if hasattr(model, "is_available") else True,
|
||||
"price_per_request": (
|
||||
@@ -266,6 +205,7 @@ class ModelCacheService:
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
"supports_extended_thinking": model.supports_extended_thinking,
|
||||
"supports_image_generation": getattr(model, "supports_image_generation", None),
|
||||
"config": model.config,
|
||||
}
|
||||
|
||||
@@ -277,6 +217,7 @@ class ModelCacheService:
|
||||
provider_id=model_dict["provider_id"],
|
||||
global_model_id=model_dict["global_model_id"],
|
||||
provider_model_name=model_dict["provider_model_name"],
|
||||
provider_model_aliases=model_dict.get("provider_model_aliases"),
|
||||
is_active=model_dict["is_active"],
|
||||
is_available=model_dict.get("is_available", True),
|
||||
price_per_request=model_dict.get("price_per_request"),
|
||||
@@ -285,6 +226,7 @@ class ModelCacheService:
|
||||
supports_function_calling=model_dict.get("supports_function_calling"),
|
||||
supports_streaming=model_dict.get("supports_streaming"),
|
||||
supports_extended_thinking=model_dict.get("supports_extended_thinking"),
|
||||
supports_image_generation=model_dict.get("supports_image_generation"),
|
||||
config=model_dict.get("config"),
|
||||
)
|
||||
return model
|
||||
@@ -296,12 +238,11 @@ class ModelCacheService:
|
||||
"id": global_model.id,
|
||||
"name": global_model.name,
|
||||
"display_name": global_model.display_name,
|
||||
"family": global_model.family,
|
||||
"group_id": global_model.group_id,
|
||||
"supports_vision": global_model.supports_vision,
|
||||
"supports_thinking": global_model.supports_thinking,
|
||||
"context_window": global_model.context_window,
|
||||
"max_output_tokens": global_model.max_output_tokens,
|
||||
"default_supports_vision": global_model.default_supports_vision,
|
||||
"default_supports_function_calling": global_model.default_supports_function_calling,
|
||||
"default_supports_streaming": global_model.default_supports_streaming,
|
||||
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
|
||||
"default_supports_image_generation": global_model.default_supports_image_generation,
|
||||
"is_active": global_model.is_active,
|
||||
"description": global_model.description,
|
||||
}
|
||||
@@ -313,12 +254,11 @@ class ModelCacheService:
|
||||
id=global_model_dict["id"],
|
||||
name=global_model_dict["name"],
|
||||
display_name=global_model_dict.get("display_name"),
|
||||
family=global_model_dict.get("family"),
|
||||
group_id=global_model_dict.get("group_id"),
|
||||
supports_vision=global_model_dict.get("supports_vision", False),
|
||||
supports_thinking=global_model_dict.get("supports_thinking", False),
|
||||
context_window=global_model_dict.get("context_window"),
|
||||
max_output_tokens=global_model_dict.get("max_output_tokens"),
|
||||
default_supports_vision=global_model_dict.get("default_supports_vision", False),
|
||||
default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False),
|
||||
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
|
||||
default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False),
|
||||
default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False),
|
||||
is_active=global_model_dict.get("is_active", True),
|
||||
description=global_model_dict.get("description"),
|
||||
)
|
||||
|
||||
14
src/services/cache/sync.py
vendored
14
src/services/cache/sync.py
vendored
@@ -6,7 +6,7 @@
|
||||
|
||||
使用场景:
|
||||
1. 多实例部署时,确保所有实例的缓存一致性
|
||||
2. GlobalModel/ModelMapping 变更时,同步失效所有实例的缓存
|
||||
2. GlobalModel/Model 变更时,同步失效所有实例的缓存
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -29,7 +29,6 @@ class CacheSyncService:
|
||||
|
||||
# Redis 频道名称
|
||||
CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
|
||||
CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
|
||||
CHANNEL_MODEL = "cache:invalidate:model"
|
||||
CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"
|
||||
|
||||
@@ -58,7 +57,6 @@ class CacheSyncService:
|
||||
# 订阅所有缓存失效频道
|
||||
await self._pubsub.subscribe(
|
||||
self.CHANNEL_GLOBAL_MODEL,
|
||||
self.CHANNEL_MODEL_MAPPING,
|
||||
self.CHANNEL_MODEL,
|
||||
self.CHANNEL_CLEAR_ALL,
|
||||
)
|
||||
@@ -68,7 +66,7 @@ class CacheSyncService:
|
||||
self._running = True
|
||||
|
||||
logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
|
||||
f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
|
||||
f"{self.CHANNEL_GLOBAL_MODEL}, "
|
||||
f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CacheSync] 启动失败: {e}")
|
||||
@@ -141,14 +139,6 @@ class CacheSyncService:
|
||||
"""发布 GlobalModel 变更通知"""
|
||||
await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})
|
||||
|
||||
async def publish_model_mapping_changed(
|
||||
self, source_model: str, provider_id: Optional[str] = None
|
||||
):
|
||||
"""发布 ModelMapping 变更通知"""
|
||||
await self._publish(
|
||||
self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
|
||||
)
|
||||
|
||||
async def publish_model_changed(self, provider_id: str, global_model_id: str):
|
||||
"""发布 Model 变更通知"""
|
||||
await self._publish(
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
"""
|
||||
模型服务模块
|
||||
|
||||
包含模型管理、模型映射、成本计算等功能。
|
||||
包含模型管理、成本计算等功能。
|
||||
"""
|
||||
|
||||
from src.services.model.cost import ModelCostService
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
from src.services.model.mapping_resolver import ModelMappingResolver
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
__all__ = [
|
||||
"ModelService",
|
||||
"GlobalModelService",
|
||||
"ModelMapperMiddleware",
|
||||
"ModelMappingResolver",
|
||||
"ModelCostService",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
|
||||
|
||||
ProviderRef = Union[str, Provider, None]
|
||||
@@ -161,16 +161,11 @@ class ModelCostService:
|
||||
result = None
|
||||
|
||||
if provider_obj:
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(
|
||||
self.db, model, provider_obj.id
|
||||
)
|
||||
|
||||
# 直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
@@ -226,17 +221,14 @@ class ModelCostService:
|
||||
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
|
||||
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
|
||||
|
||||
计费逻辑(基于 mapping_type):
|
||||
1. 查找 ModelMapping(如果存在)
|
||||
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
|
||||
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
|
||||
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
|
||||
- 否则回退到目标 GlobalModel 的价格
|
||||
4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name
|
||||
计费逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -253,136 +245,37 @@ class ModelCostService:
|
||||
output_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
|
||||
from src.models.database import ModelMapping
|
||||
|
||||
mapping = None
|
||||
# 先查 Provider 特定映射
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id == provider_obj.id,
|
||||
ModelMapping.is_active == True,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# 再查全局映射
|
||||
if not mapping:
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.is_active == True,
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# 有映射,根据 mapping_type 决定计费模型
|
||||
if mapping.mapping_type == "mapping":
|
||||
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
|
||||
source_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_global_model:
|
||||
source_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == source_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = source_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = source_model_obj.get_effective_input_price()
|
||||
output_price = source_model_obj.get_effective_output_price()
|
||||
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
|
||||
if input_price is None:
|
||||
target_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == mapping.target_global_model_id,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_global_model:
|
||||
target_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == target_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = target_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = target_model_obj.get_effective_input_price()
|
||||
output_price = target_model_obj.get_effective_output_price()
|
||||
mode_label = (
|
||||
"alias模式"
|
||||
if mapping.mapping_type == "alias"
|
||||
else "mapping模式(回退)"
|
||||
)
|
||||
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
else:
|
||||
# 没有映射,尝试直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# 如果没有找到价格配置,使用 0.0 并记录警告
|
||||
if input_price is None:
|
||||
@@ -404,15 +297,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -434,15 +326,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
@@ -460,22 +346,17 @@ class ModelCostService:
|
||||
cache_read_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -517,15 +398,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取按次计费价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
@@ -534,22 +409,17 @@ class ModelCostService:
|
||||
price_per_request = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -595,15 +465,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import GlobalModelUpdate
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,13 @@
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.cache_utils import SyncLRUCache
|
||||
from src.core.logger import logger
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
|
||||
from src.models.database import GlobalModel, Model, Provider, ProviderEndpoint
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
from src.services.model.mapping_resolver import (
|
||||
get_model_mapping_resolver,
|
||||
resolve_model_to_global_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ModelMapperMiddleware:
|
||||
@@ -71,10 +66,10 @@ class ModelMapperMiddleware:
|
||||
if mapping:
|
||||
# 应用映射
|
||||
original_model = request.model
|
||||
request.model = mapping.model.provider_model_name
|
||||
request.model = mapping.model.select_provider_model_name()
|
||||
|
||||
logger.debug(f"Applied model mapping for provider {provider.name}: "
|
||||
f"{original_model} -> {mapping.model.provider_model_name}")
|
||||
f"{original_model} -> {request.model}")
|
||||
else:
|
||||
# 没有找到映射,使用原始模型名
|
||||
logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
|
||||
@@ -84,17 +79,16 @@ class ModelMapperMiddleware:
|
||||
|
||||
async def get_mapping(
|
||||
self, source_model: str, provider_id: str
|
||||
) -> Optional[ModelMapping]: # UUID
|
||||
) -> Optional[object]:
|
||||
"""
|
||||
获取模型映射
|
||||
|
||||
优化后逻辑:
|
||||
1. 使用统一的 ModelMappingResolver 解析别名(带缓存)
|
||||
2. 通过 GlobalModel 找到该 Provider 的 Model 实现
|
||||
3. 使用独立的映射缓存
|
||||
简化后的逻辑:
|
||||
1. 通过 GlobalModel.name 直接查找
|
||||
2. 找到 GlobalModel 后,查找该 Provider 的 Model 实现
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名或别名
|
||||
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
provider_id: 提供商ID (UUID)
|
||||
|
||||
Returns:
|
||||
@@ -107,62 +101,59 @@ class ModelMapperMiddleware:
|
||||
|
||||
mapping = None
|
||||
|
||||
# 步骤 1 & 2: 通过统一的模型映射解析服务
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(
|
||||
self.db, source_model, provider_id
|
||||
# 步骤 1: 直接通过名称查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == source_model, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
|
||||
logger.debug(f"GlobalModel not found: {source_model}")
|
||||
self._cache[cache_key] = None
|
||||
return None
|
||||
|
||||
# 步骤 3: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存)
|
||||
# 步骤 2: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存)
|
||||
model = await ModelCacheService.get_model_by_provider_and_global_model(
|
||||
self.db, provider_id, global_model.id
|
||||
)
|
||||
|
||||
if model:
|
||||
# 只有当模型名发生变化时才返回映射
|
||||
if model.provider_model_name != source_model:
|
||||
mapping = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"source_model": source_model,
|
||||
"model": model,
|
||||
"is_active": True,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)()
|
||||
# 创建映射对象
|
||||
mapping = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"source_model": source_model,
|
||||
"model": model,
|
||||
"is_active": True,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)()
|
||||
|
||||
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
|
||||
f"(provider={provider_id[:8]}...)")
|
||||
else:
|
||||
logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")
|
||||
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
|
||||
f"(provider={provider_id[:8]}...)")
|
||||
|
||||
# 缓存结果
|
||||
self._cache[cache_key] = mapping
|
||||
|
||||
return mapping
|
||||
|
||||
def get_all_mappings(self, provider_id: str) -> List[ModelMapping]: # UUID
|
||||
def get_all_mappings(self, provider_id: str) -> List[object]:
|
||||
"""
|
||||
获取提供商的所有可用模型(通过 GlobalModel)
|
||||
|
||||
方案 A: 返回该 Provider 所有可用的 GlobalModel
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID (UUID)
|
||||
|
||||
Returns:
|
||||
模型映射列表(模拟的 ModelMapping 对象列表)
|
||||
模型映射列表
|
||||
"""
|
||||
# 查询该 Provider 的所有活跃 Model
|
||||
# 查询该 Provider 的所有活跃 Model(使用 joinedload 避免 N+1)
|
||||
models = (
|
||||
self.db.query(Model)
|
||||
.join(GlobalModel)
|
||||
.options(joinedload(Model.global_model))
|
||||
.filter(
|
||||
Model.provider_id == provider_id,
|
||||
Model.is_active == True,
|
||||
@@ -171,7 +162,7 @@ class ModelMapperMiddleware:
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构造兼容的 ModelMapping 对象列表
|
||||
# 构造兼容的映射对象列表
|
||||
mappings = []
|
||||
for model in models:
|
||||
mapping = type(
|
||||
@@ -188,7 +179,7 @@ class ModelMapperMiddleware:
|
||||
|
||||
return mappings
|
||||
|
||||
def get_supported_models(self, provider_id: str) -> List[str]: # UUID
|
||||
def get_supported_models(self, provider_id: str) -> List[str]:
|
||||
"""
|
||||
获取提供商支持的所有源模型名
|
||||
|
||||
@@ -223,15 +214,6 @@ class ModelMapperMiddleware:
|
||||
if not mapping.is_active:
|
||||
return False, f"Model mapping for {request.model} is disabled"
|
||||
|
||||
# 不限制max_tokens,作为中转服务不应该限制用户的请求
|
||||
# if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
|
||||
# return False, (
|
||||
# f"Requested max_tokens {request.max_tokens} exceeds limit "
|
||||
# f"{mapping.max_output_tokens} for model {request.model}"
|
||||
# )
|
||||
|
||||
# 可以添加更多验证逻辑,比如检查输入长度等
|
||||
|
||||
return True, None
|
||||
|
||||
def clear_cache(self):
|
||||
@@ -239,7 +221,7 @@ class ModelMapperMiddleware:
|
||||
self._cache.clear()
|
||||
logger.debug("Model mapping cache cleared")
|
||||
|
||||
def refresh_cache(self, provider_id: Optional[str] = None): # UUID
|
||||
def refresh_cache(self, provider_id: Optional[str] = None):
|
||||
"""
|
||||
刷新缓存
|
||||
|
||||
@@ -285,16 +267,10 @@ class ModelRoutingMiddleware:
|
||||
"""
|
||||
根据模型名选择提供商
|
||||
|
||||
逻辑:
|
||||
1. 如果指定了提供商,使用指定的提供商
|
||||
2. 如果没指定,使用默认提供商
|
||||
3. 选定提供商后,会检查该提供商的模型映射(在apply_mapping中处理)
|
||||
4. 如果指定了allowed_api_formats,只选择符合格式的提供商
|
||||
|
||||
Args:
|
||||
model_name: 请求的模型名
|
||||
preferred_provider: 首选提供商名称
|
||||
allowed_api_formats: 允许的API格式列表(如 ['CLAUDE', 'CLAUDE_CLI'])
|
||||
allowed_api_formats: 允许的API格式列表
|
||||
request_id: 请求ID(用于日志关联)
|
||||
|
||||
Returns:
|
||||
@@ -313,14 +289,12 @@ class ModelRoutingMiddleware:
|
||||
if provider:
|
||||
# 检查API格式 - 从 endpoints 中检查
|
||||
if allowed_api_formats:
|
||||
# 检查是否有符合要求的活跃端点
|
||||
has_matching_endpoint = any(
|
||||
ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
|
||||
for ep in provider.endpoints
|
||||
)
|
||||
if not has_matching_endpoint:
|
||||
logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
|
||||
# 不返回该提供商,继续查找
|
||||
else:
|
||||
logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
|
||||
return provider
|
||||
@@ -330,10 +304,9 @@ class ModelRoutingMiddleware:
|
||||
else:
|
||||
logger.warning(f"Specified provider {preferred_provider} not found or inactive")
|
||||
|
||||
# 2. 查找优先级最高的活动提供商(provider_priority 最小)
|
||||
# 2. 查找优先级最高的活动提供商
|
||||
query = self.db.query(Provider).filter(Provider.is_active == True)
|
||||
|
||||
# 如果指定了API格式过滤,添加过滤条件 - 检查是否有符合要求的 endpoint
|
||||
if allowed_api_formats:
|
||||
query = (
|
||||
query.join(ProviderEndpoint)
|
||||
@@ -344,32 +317,27 @@ class ModelRoutingMiddleware:
|
||||
.distinct()
|
||||
)
|
||||
|
||||
# 按 provider_priority 排序,优先级最高(数字最小)的在前
|
||||
best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()
|
||||
|
||||
if best_provider:
|
||||
logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
|
||||
return best_provider
|
||||
|
||||
# 3. 没有任何活动提供商
|
||||
if allowed_api_formats:
|
||||
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
|
||||
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}.")
|
||||
else:
|
||||
logger.error("No active providers found. Please configure at least one provider.")
|
||||
logger.error("No active providers found.")
|
||||
return None
|
||||
|
||||
def get_available_models(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
获取所有可用的模型及其提供商
|
||||
|
||||
方案 A: 基于 GlobalModel 查询
|
||||
|
||||
Returns:
|
||||
字典,键为 GlobalModel.name,值为支持该模型的提供商名列表
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 查询所有活跃的 GlobalModel 及其 Provider
|
||||
models = (
|
||||
self.db.query(GlobalModel.name, Provider.name)
|
||||
.join(Model, GlobalModel.id == Model.global_model_id)
|
||||
@@ -392,28 +360,23 @@ class ModelRoutingMiddleware:
|
||||
"""
|
||||
获取某个模型最便宜的提供商
|
||||
|
||||
方案 A: 通过 GlobalModel 查找
|
||||
|
||||
Args:
|
||||
model_name: GlobalModel 名称或别名
|
||||
model_name: GlobalModel 名称
|
||||
|
||||
Returns:
|
||||
最便宜的提供商
|
||||
"""
|
||||
# 步骤 1: 解析模型名
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return None
|
||||
|
||||
# 步骤 3: 查询所有支持该模型的 Provider 及其价格
|
||||
# 查询所有支持该模型的 Provider 及其价格
|
||||
models_with_providers = (
|
||||
self.db.query(Provider, Model)
|
||||
.join(Model, Provider.id == Model.provider_id)
|
||||
@@ -428,15 +391,16 @@ class ModelRoutingMiddleware:
|
||||
if not models_with_providers:
|
||||
return None
|
||||
|
||||
# 按总价格排序(输入+输出价格)
|
||||
# 按总价格排序
|
||||
cheapest = min(
|
||||
models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
|
||||
models_with_providers,
|
||||
key=lambda x: x[1].get_effective_input_price() + x[1].get_effective_output_price()
|
||||
)
|
||||
|
||||
provider = cheapest[0]
|
||||
model = cheapest[1]
|
||||
|
||||
logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
|
||||
f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")
|
||||
f"(input: ${model.get_effective_input_price()}/M, output: ${model.get_effective_output_price()}/M)")
|
||||
|
||||
return provider
|
||||
|
||||
@@ -50,6 +50,7 @@ class ModelService:
|
||||
provider_id=provider_id,
|
||||
global_model_id=model_data.global_model_id,
|
||||
provider_model_name=model_data.provider_model_name,
|
||||
provider_model_aliases=model_data.provider_model_aliases,
|
||||
price_per_request=model_data.price_per_request,
|
||||
tiered_pricing=model_data.tiered_pricing,
|
||||
supports_vision=model_data.supports_vision,
|
||||
@@ -191,7 +192,6 @@ class ModelService:
|
||||
新架构删除逻辑:
|
||||
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
|
||||
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
|
||||
- 不检查 ModelMapping(映射是 GlobalModel 之间的关系,别名也统一存储在此表中)
|
||||
"""
|
||||
model = db.query(Model).filter(Model.id == model_id).first()
|
||||
if not model:
|
||||
@@ -326,6 +326,7 @@ class ModelService:
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
# 原始配置值(可能为空)
|
||||
price_per_request=model.price_per_request,
|
||||
tiered_pricing=model.tiered_pricing,
|
||||
|
||||
@@ -7,13 +7,11 @@ from typing import Dict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
from src.services.model.cost import ModelCostService
|
||||
from src.services.model.mapper import ModelMapperMiddleware, ModelRoutingMiddleware
|
||||
|
||||
|
||||
|
||||
class ProviderService:
|
||||
"""提供商服务类"""
|
||||
|
||||
@@ -34,30 +32,15 @@ class ProviderService:
|
||||
检查模型是否可用(严格白名单模式)
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
Model对象如果存在且激活,否则None
|
||||
"""
|
||||
# 首先检查是否有直接的模型记录
|
||||
model = (
|
||||
self.db.query(Model)
|
||||
.filter(Model.provider_model_name == model_name, Model.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
||||
# 方案 A:检查是否是别名(全局别名系统)
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name)
|
||||
|
||||
# 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -79,34 +62,15 @@ class ProviderService:
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
Model对象如果该提供商支持该模型且激活,否则None
|
||||
"""
|
||||
# 首先检查该提供商下是否有直接的模型记录
|
||||
model = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_id,
|
||||
Model.provider_model_name == model_name,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
||||
# 方案 A:检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name, provider_id)
|
||||
|
||||
# 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -148,12 +112,19 @@ class ProviderService:
|
||||
获取所有可用的模型
|
||||
|
||||
Returns:
|
||||
模型和支持的提供商映射
|
||||
字典,键为模型名,值为提供商列表
|
||||
"""
|
||||
return self.router.get_available_models()
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
self.mapper.clear_cache()
|
||||
self.cost_service.clear_cache()
|
||||
logger.info("Provider service cache cleared")
|
||||
def select_provider(self, model_name: str, preferred_provider=None):
|
||||
"""
|
||||
选择提供商
|
||||
|
||||
Args:
|
||||
model_name: 模型名
|
||||
preferred_provider: 首选提供商
|
||||
|
||||
Returns:
|
||||
Provider对象
|
||||
"""
|
||||
return self.router.select_provider(model_name, preferred_provider)
|
||||
|
||||
@@ -26,11 +26,10 @@ class UsageService:
|
||||
) -> tuple[float, float]:
|
||||
"""异步获取模型价格(输入价格,输出价格)每1M tokens
|
||||
|
||||
新架构查找逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现并获取价格
|
||||
4. 如果找不到则使用系统默认价格
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现并获取价格
|
||||
3. 如果找不到则使用系统默认价格
|
||||
"""
|
||||
|
||||
service = ModelCostService(db)
|
||||
@@ -40,11 +39,10 @@ class UsageService:
|
||||
def get_model_price(cls, db: Session, provider: str, model: str) -> tuple[float, float]:
|
||||
"""获取模型价格(输入价格,输出价格)每1M tokens
|
||||
|
||||
新架构查找逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现并获取价格
|
||||
4. 如果找不到则使用系统默认价格
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现并获取价格
|
||||
3. 如果找不到则使用系统默认价格
|
||||
"""
|
||||
|
||||
service = ModelCostService(db)
|
||||
|
||||
Reference in New Issue
Block a user