refactor(backend): optimize cache system and model/provider services

fawney19 committed on 2025-12-15 14:30:21 +08:00
parent 56fb6bf36c
commit 7068aa9130
12 changed files with 170 additions and 517 deletions

View File

@@ -36,7 +36,7 @@ from src.core.logger import logger
 from sqlalchemy.orm import Session, selectinload
 from src.core.enums import APIFormat
-from src.core.exceptions import ProviderNotAvailableException
+from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
 from src.models.database import (
     ApiKey,
     Model,
@@ -227,19 +227,6 @@ class CacheAwareScheduler:
             if provider_offset == 0:
                 # No candidates found; give a friendly error message
                 error_msg = f"模型 '{model_name}' 不可用"
-                # Look up similar models as suggestions
-                from src.services.model.mapping_resolver import get_model_mapping_resolver
-
-                resolver = get_model_mapping_resolver()
-                similar_models = resolver.find_similar_models(db, model_name, limit=3)
-                if similar_models:
-                    suggestions = [
-                        f"{name} (相似度: {score:.0%})" for name, score in similar_models
-                    ]
-                    error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
                 raise ProviderNotAvailableException(error_msg)
             break
@@ -579,13 +566,20 @@ class CacheAwareScheduler:
         target_format = normalize_api_format(api_format)

-        # 0. Resolve model_name to a GlobalModel (canonical identifier for cache affinity)
-        from src.services.model.mapping_resolver import get_model_mapping_resolver
-
-        mapping_resolver = get_model_mapping_resolver()
-        global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
+        # 0. Resolve model_name to a GlobalModel (direct lookup; callers must use the canonical name)
+        from src.models.database import GlobalModel
+
+        global_model = (
+            db.query(GlobalModel)
+            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
+            .first()
+        )
+        if not global_model:
+            logger.warning(f"GlobalModel not found: {model_name}")
+            raise ModelNotSupportedException(model=model_name)

         # Use GlobalModel.id as the cache-affinity identifier so aliases and canonical names hit the same cache entry
-        global_model_id: str = str(global_model.id) if global_model else model_name
+        global_model_id: str = str(global_model.id)

         # Get the merged access restrictions (ApiKey + User)
         restrictions = self._get_effective_restrictions(user_api_key)
@@ -685,7 +679,6 @@ class CacheAwareScheduler:
             db.query(Provider)
             .options(
                 selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
-                selectinload(Provider.model_mappings),
                 # Eager-load both the models and global_model relationships so get_effective_* methods inherit defaults correctly
                 selectinload(Provider.models).selectinload(Model.global_model),
             )
@@ -715,63 +708,33 @@ class CacheAwareScheduler:
         - The capabilities a model supports are global and unrelated to any specific key
         - If the model lacks a required capability, every key of the provider should be skipped

-        Mapping fallback logic:
-        - If a model mapping exists, try the mapped model first
-        - If the mapped model fails because a capability is not met, fall back to the original model
-        - Other failures (model not found, provider has no implementation, etc.) do not trigger the fallback

         Args:
             db: database session
             provider: Provider object
-            model_name: model name
+            model_name: model name (must be a GlobalModel.name)
             is_stream: whether this is a streaming request; if True, streaming support is also checked
             capability_requirements: required capabilities (optional), used to check whether the model supports them

         Returns:
             (is_supported, skip_reason, supported_capabilities) - whether supported, the skip reason, and the model's supported capabilities
         """
-        from src.services.model.mapping_resolver import get_model_mapping_resolver
-
-        mapping_resolver = get_model_mapping_resolver()
-
-        # Get the mapped model and note whether a mapping was applied
-        global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
-            db, model_name, str(provider.id)
-        )
+        from src.models.database import GlobalModel
+
+        # Look up directly by GlobalModel.name
+        global_model = (
+            db.query(GlobalModel)
+            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
+            .first()
+        )
         if not global_model:
             return False, "模型不存在", None

-        # Check the mapped model first
+        # Check model support
         is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
             db, provider, global_model, model_name, is_stream, capability_requirements
         )

-        # If the mapped model failed on capabilities and a mapping was applied, fall back to the original model
-        if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason:
-            logger.debug(
-                f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足,"
-                f"回退尝试原模型 {model_name}"
-            )
-            # Get the original model (without applying the mapping)
-            original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
-            if original_global_model and original_global_model.id != global_model.id:
-                # Try the original model
-                is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
-                    db,
-                    provider,
-                    original_global_model,
-                    model_name,
-                    is_stream,
-                    capability_requirements,
-                )
-                if is_supported:
-                    logger.debug(
-                        f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
-                    )
-
         return is_supported, skip_reason, caps

     async def _check_model_support_for_global_model(
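The scheduler hunks above collapse the old mapping-resolver indirection into a single indexed lookup. The sketch below restates that path in isolation, as a minimal illustration rather than the project's actual method: the helper name `resolve_global_model` is invented for this example, and only the `GlobalModel` columns and `ModelNotSupportedException` that appear in this diff are assumed to exist.

```python
# Illustrative only: direct lookup by canonical name, failing fast when the name is unknown.
from sqlalchemy.orm import Session

from src.core.exceptions import ModelNotSupportedException
from src.models.database import GlobalModel


def resolve_global_model(db: Session, model_name: str) -> GlobalModel:
    """Resolve a client-supplied model name to its active GlobalModel row (hypothetical helper)."""
    global_model = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
        .first()
    )
    if not global_model:
        # Aliases are no longer resolved here; callers must send the canonical GlobalModel.name.
        raise ModelNotSupportedException(model=model_name)
    return global_model
```

Under this scheme a request for an unknown or alias name is rejected immediately instead of being rerouted through a mapping table.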

View File

@@ -6,8 +6,7 @@
 2. RedisCache: Redis cache (distributed)

 Use cases:
-- ModelMappingResolver: model-mapping and alias-resolution cache
-- ModelMapper: model-mapping cache
+- ModelCacheService: model-resolution cache
 - other services that need caching
 """

View File

@@ -3,18 +3,14 @@
 Centralizes the invalidation logic for the various caches. Supports:
 1. Invalidating related caches when a GlobalModel changes
-2. Invalidating alias/fallback caches when a ModelMapping changes
-3. Invalidating model-mapping caches when a Model changes
-4. Supporting both sync and async cache backends
+2. Invalidating model-mapping caches when a Model changes
+3. Supporting both sync and async cache backends
 """

-import asyncio
 from typing import Optional

 from src.core.logger import logger


 class CacheInvalidationService:
     """
@@ -25,14 +21,8 @@ class CacheInvalidationService:
     def __init__(self):
         """Initialize the cache-invalidation service."""
-        self._mapping_resolver = None
         self._model_mappers = []  # there may be multiple ModelMapperMiddleware instances

-    def set_mapping_resolver(self, mapping_resolver):
-        """Register the model-mapping resolver instance."""
-        self._mapping_resolver = mapping_resolver
-        logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")
-
     def register_model_mapper(self, model_mapper):
         """Register a ModelMapper instance."""
         if model_mapper not in self._model_mappers:
@@ -48,37 +38,12 @@ class CacheInvalidationService:
""" """
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}") logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
# 异步失效模型解析器中的缓存
if self._mapping_resolver:
asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache())
# 失效所有 ModelMapper 中与此模型相关的缓存 # 失效所有 ModelMapper 中与此模型相关的缓存
for mapper in self._model_mappers: for mapper in self._model_mappers:
# 清空所有缓存(因为不知道哪些 provider 使用了这个模型) # 清空所有缓存(因为不知道哪些 provider 使用了这个模型)
mapper.clear_cache() mapper.clear_cache()
logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存") logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存")
def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
"""
ModelMapping 变更时的缓存失效
Args:
source_model: 变更的源模型名
provider_id: 相关 ProviderNone 表示全局)
"""
logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
if self._mapping_resolver:
asyncio.create_task(
self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
)
for mapper in self._model_mappers:
if provider_id:
mapper.refresh_cache(provider_id)
else:
mapper.clear_cache()
def on_model_changed(self, provider_id: str, global_model_id: str): def on_model_changed(self, provider_id: str, global_model_id: str):
""" """
Model 变更时的缓存失效 Model 变更时的缓存失效
@@ -98,9 +63,6 @@ class CacheInvalidationService:
"""清空所有缓存""" """清空所有缓存"""
logger.info("[CacheInvalidation] 清空所有缓存") logger.info("[CacheInvalidation] 清空所有缓存")
if self._mapping_resolver:
asyncio.create_task(self._mapping_resolver.clear_cache())
for mapper in self._model_mappers: for mapper in self._model_mappers:
mapper.clear_cache() mapper.clear_cache()

View File

@@ -1,5 +1,5 @@
""" """
Model 映射缓存服务 - 减少模型映射和别名查询 Model 映射缓存服务 - 减少模型查询
""" """
from typing import Optional from typing import Optional
@@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
 from src.config.constants import CacheTTL
 from src.core.cache_service import CacheService
 from src.core.logger import logger
-from src.models.database import GlobalModel, Model, ModelMapping
+from src.models.database import GlobalModel, Model


 class ModelCacheService:
@@ -158,56 +157,6 @@ class ModelCacheService:
         return global_model

-    @staticmethod
-    async def resolve_alias(
-        db: Session, source_model: str, provider_id: Optional[str] = None
-    ) -> Optional[str]:
-        """
-        Resolve a model alias (with caching)
-
-        Args:
-            db: database session
-            source_model: source model name or alias
-            provider_id: provider ID (optional, for provider-specific aliases)
-
-        Returns:
-            the target GlobalModel ID, or None
-        """
-        # Build the cache key
-        if provider_id:
-            cache_key = f"alias:provider:{provider_id}:{source_model}"
-        else:
-            cache_key = f"alias:global:{source_model}"
-
-        # 1. Try the cache first
-        cached_result = await CacheService.get(cache_key)
-        if cached_result:
-            logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
-            return cached_result
-
-        # 2. Cache miss: query the database
-        query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
-        if provider_id:
-            # Provider-specific aliases take precedence
-            query = query.filter(ModelMapping.provider_id == provider_id)
-        else:
-            # Global aliases
-            query = query.filter(ModelMapping.provider_id.is_(None))
-        mapping = query.first()
-
-        # 3. Write to the cache
-        target_global_model_id = mapping.target_global_model_id if mapping else None
-        await CacheService.set(
-            cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL
-        )
-        if mapping:
-            logger.debug(f"别名已缓存: {source_model}{target_global_model_id}")
-        return target_global_model_id
-
     @staticmethod
     async def invalidate_model_cache(
         model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
@@ -237,17 +186,6 @@ class ModelCacheService:
await CacheService.delete(f"global_model:name:{name}") await CacheService.delete(f"global_model:name:{name}")
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}") logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
@staticmethod
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
"""清除别名缓存"""
if provider_id:
cache_key = f"alias:provider:{provider_id}:{source_model}"
else:
cache_key = f"alias:global:{source_model}"
await CacheService.delete(cache_key)
logger.debug(f"别名缓存已清除: {source_model}")
@staticmethod @staticmethod
def _model_to_dict(model: Model) -> dict: def _model_to_dict(model: Model) -> dict:
"""将 Model 对象转换为字典""" """将 Model 对象转换为字典"""
@@ -256,6 +194,7 @@ class ModelCacheService:
"provider_id": model.provider_id, "provider_id": model.provider_id,
"global_model_id": model.global_model_id, "global_model_id": model.global_model_id,
"provider_model_name": model.provider_model_name, "provider_model_name": model.provider_model_name,
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
"is_active": model.is_active, "is_active": model.is_active,
"is_available": model.is_available if hasattr(model, "is_available") else True, "is_available": model.is_available if hasattr(model, "is_available") else True,
"price_per_request": ( "price_per_request": (
@@ -266,6 +205,7 @@ class ModelCacheService:
"supports_function_calling": model.supports_function_calling, "supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming, "supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking, "supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": getattr(model, "supports_image_generation", None),
"config": model.config, "config": model.config,
} }
@@ -277,6 +217,7 @@ class ModelCacheService:
provider_id=model_dict["provider_id"], provider_id=model_dict["provider_id"],
global_model_id=model_dict["global_model_id"], global_model_id=model_dict["global_model_id"],
provider_model_name=model_dict["provider_model_name"], provider_model_name=model_dict["provider_model_name"],
provider_model_aliases=model_dict.get("provider_model_aliases"),
is_active=model_dict["is_active"], is_active=model_dict["is_active"],
is_available=model_dict.get("is_available", True), is_available=model_dict.get("is_available", True),
price_per_request=model_dict.get("price_per_request"), price_per_request=model_dict.get("price_per_request"),
@@ -285,6 +226,7 @@ class ModelCacheService:
             supports_function_calling=model_dict.get("supports_function_calling"),
             supports_streaming=model_dict.get("supports_streaming"),
             supports_extended_thinking=model_dict.get("supports_extended_thinking"),
+            supports_image_generation=model_dict.get("supports_image_generation"),
             config=model_dict.get("config"),
         )
         return model
@@ -296,12 +238,11 @@ class ModelCacheService:
"id": global_model.id, "id": global_model.id,
"name": global_model.name, "name": global_model.name,
"display_name": global_model.display_name, "display_name": global_model.display_name,
"family": global_model.family, "default_supports_vision": global_model.default_supports_vision,
"group_id": global_model.group_id, "default_supports_function_calling": global_model.default_supports_function_calling,
"supports_vision": global_model.supports_vision, "default_supports_streaming": global_model.default_supports_streaming,
"supports_thinking": global_model.supports_thinking, "default_supports_extended_thinking": global_model.default_supports_extended_thinking,
"context_window": global_model.context_window, "default_supports_image_generation": global_model.default_supports_image_generation,
"max_output_tokens": global_model.max_output_tokens,
"is_active": global_model.is_active, "is_active": global_model.is_active,
"description": global_model.description, "description": global_model.description,
} }
@@ -313,12 +254,11 @@ class ModelCacheService:
id=global_model_dict["id"], id=global_model_dict["id"],
name=global_model_dict["name"], name=global_model_dict["name"],
display_name=global_model_dict.get("display_name"), display_name=global_model_dict.get("display_name"),
family=global_model_dict.get("family"), default_supports_vision=global_model_dict.get("default_supports_vision", False),
group_id=global_model_dict.get("group_id"), default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False),
supports_vision=global_model_dict.get("supports_vision", False), default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
supports_thinking=global_model_dict.get("supports_thinking", False), default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False),
context_window=global_model_dict.get("context_window"), default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False),
max_output_tokens=global_model_dict.get("max_output_tokens"),
is_active=global_model_dict.get("is_active", True), is_active=global_model_dict.get("is_active", True),
description=global_model_dict.get("description"), description=global_model_dict.get("description"),
) )
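The `_global_model_to_dict` / `_dict_to_global_model` pair above now carries the `default_supports_*` flags instead of family, group, and context-window fields. As a rough illustration of why the round trip stays dict-based, the sketch below pushes the same shape through JSON so it can live in Redis and come back as plain data; the standalone helpers and the JSON transport are assumptions for this example, not the project's CacheService API.

```python
# Illustrative only: the real code goes through CacheService and rebuilds a detached
# GlobalModel ORM instance from the dict; here we only show the serialization shape.
import json

CACHED_FIELDS = (
    "id", "name", "display_name",
    "default_supports_vision", "default_supports_function_calling",
    "default_supports_streaming", "default_supports_extended_thinking",
    "default_supports_image_generation",
    "is_active", "description",
)


def to_cache_payload(global_model) -> str:
    """Keep only plain, JSON-safe columns so the value can be stored in Redis."""
    return json.dumps({field: getattr(global_model, field) for field in CACHED_FIELDS})


def from_cache_payload(payload: str) -> dict:
    """Read side: the resulting dict is what a constructor like _dict_to_global_model consumes."""
    data = json.loads(payload)
    # Mirror the diff's fallback defaults for entries written before a field existed.
    data.setdefault("default_supports_streaming", True)
    return data
```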

View File

@@ -6,7 +6,7 @@
 Use cases:
 1. In multi-instance deployments, keep caches consistent across all instances
-2. When a GlobalModel/ModelMapping changes, invalidate the caches of all instances
+2. When a GlobalModel/Model changes, invalidate the caches of all instances
 """

 import asyncio
@@ -29,7 +29,6 @@ class CacheSyncService:
     # Redis channel names
     CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
-    CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
     CHANNEL_MODEL = "cache:invalidate:model"
     CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"
@@ -58,7 +57,6 @@ class CacheSyncService:
             # Subscribe to every cache-invalidation channel
             await self._pubsub.subscribe(
                 self.CHANNEL_GLOBAL_MODEL,
-                self.CHANNEL_MODEL_MAPPING,
                 self.CHANNEL_MODEL,
                 self.CHANNEL_CLEAR_ALL,
             )
@@ -68,7 +66,7 @@ class CacheSyncService:
             self._running = True
             logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
-                        f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
+                        f"{self.CHANNEL_GLOBAL_MODEL}, "
                         f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
         except Exception as e:
             logger.error(f"[CacheSync] 启动失败: {e}")
@@ -141,14 +139,6 @@ class CacheSyncService:
"""发布 GlobalModel 变更通知""" """发布 GlobalModel 变更通知"""
await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name}) await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})
async def publish_model_mapping_changed(
self, source_model: str, provider_id: Optional[str] = None
):
"""发布 ModelMapping 变更通知"""
await self._publish(
self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
)
async def publish_model_changed(self, provider_id: str, global_model_id: str): async def publish_model_changed(self, provider_id: str, global_model_id: str):
"""发布 Model 变更通知""" """发布 Model 变更通知"""
await self._publish( await self._publish(
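With the `model_mapping` channel gone, only the GlobalModel, Model, and clear-all channels remain. The following is a minimal, self-contained sketch of that fan-out using `redis.asyncio`; the channel name and payload shape follow the diff, while the connection setup and the helper names are illustrative.

```python
# Minimal pub/sub sketch of the remaining invalidation fan-out (illustrative wiring).
import asyncio
import json

import redis.asyncio as redis

CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"


async def publish_global_model_changed(r: redis.Redis, model_name: str) -> None:
    # Every subscribed instance drops its local caches for this model.
    await r.publish(CHANNEL_GLOBAL_MODEL, json.dumps({"model_name": model_name}))


async def main() -> None:
    r = redis.Redis()
    pubsub = r.pubsub()
    await pubsub.subscribe(CHANNEL_GLOBAL_MODEL)

    await publish_global_model_changed(r, "example-global-model")

    msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
    if msg:
        print(json.loads(msg["data"]))  # {'model_name': 'example-global-model'}


asyncio.run(main())
```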

View File

@@ -1,19 +1,15 @@
""" """
模型服务模块 模型服务模块
包含模型管理、模型映射、成本计算等功能。 包含模型管理、成本计算等功能。
""" """
from src.services.model.cost import ModelCostService from src.services.model.cost import ModelCostService
from src.services.model.global_model import GlobalModelService from src.services.model.global_model import GlobalModelService
from src.services.model.mapper import ModelMapperMiddleware
from src.services.model.mapping_resolver import ModelMappingResolver
from src.services.model.service import ModelService from src.services.model.service import ModelService
__all__ = [ __all__ = [
"ModelService", "ModelService",
"GlobalModelService", "GlobalModelService",
"ModelMapperMiddleware",
"ModelMappingResolver",
"ModelCostService", "ModelCostService",
] ]

View File

@@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union
 from sqlalchemy.orm import Session

 from src.core.logger import logger
-from src.models.database import GlobalModel, Model, ModelMapping, Provider
+from src.models.database import GlobalModel, Model, Provider

 ProviderRef = Union[str, Provider, None]
@@ -161,16 +161,11 @@ class ModelCostService:
         result = None
         if provider_obj:
-            from src.services.model.mapping_resolver import resolve_model_to_global_name
-
-            global_model_name = await resolve_model_to_global_name(
-                self.db, model, provider_obj.id
-            )
+            # Look up directly by GlobalModel.name
             global_model = (
                 self.db.query(GlobalModel)
                 .filter(
-                    GlobalModel.name == global_model_name,
+                    GlobalModel.name == model,
                     GlobalModel.is_active == True,
                 )
                 .first()
@@ -226,17 +221,14 @@ class ModelCostService:
         Note: if the model has tiered pricing configured, this method returns the first tier's
         price as the default; actual billing should use compute_cost_with_tiered_pricing.

-        Billing logic (based on mapping_type):
-        1. Look up a ModelMapping (if one exists)
-        2. If mapping_type='alias': use the target GlobalModel's price
-        3. If mapping_type='mapping': try the price of the GlobalModel named by source_model
-           - if that GlobalModel exists and has a Model implementation, use its price
-           - otherwise fall back to the target GlobalModel's price
-        4. If no ModelMapping is found, try matching GlobalModel.name directly
+        Billing logic:
+        1. Match directly by GlobalModel.name
+        2. Find that provider's Model implementation
+        3. Read its price configuration

         Args:
             provider: Provider object or provider name
-            model: the requested model name (may be a GlobalModel.name or an alias)
+            model: the requested model name (must be a GlobalModel.name)

         Returns:
             (input_price, output_price) tuple
@@ -253,106 +245,7 @@ class ModelCostService:
         output_price = None

         if provider_obj:
-            # Step 1: look up a ModelMapping to determine the mapping_type
-            from src.models.database import ModelMapping
-
-            mapping = None
-            # Check provider-specific mappings first
-            mapping = (
-                self.db.query(ModelMapping)
-                .filter(
-                    ModelMapping.source_model == model,
-                    ModelMapping.provider_id == provider_obj.id,
-                    ModelMapping.is_active == True,
-                )
-                .first()
-            )
-            # Then check global mappings
-            if not mapping:
-                mapping = (
-                    self.db.query(ModelMapping)
-                    .filter(
-                        ModelMapping.source_model == model,
-                        ModelMapping.provider_id.is_(None),
-                        ModelMapping.is_active == True,
-                    )
-                    .first()
-                )
-
-            if mapping:
-                # A mapping exists; pick the billing model based on mapping_type
-                if mapping.mapping_type == "mapping":
-                    # "mapping" mode: try the price of the GlobalModel named by source_model
-                    source_global_model = (
-                        self.db.query(GlobalModel)
-                        .filter(
-                            GlobalModel.name == model,
-                            GlobalModel.is_active == True,
-                        )
-                        .first()
-                    )
-                    if source_global_model:
-                        source_model_obj = (
-                            self.db.query(Model)
-                            .filter(
-                                Model.provider_id == provider_obj.id,
-                                Model.global_model_id == source_global_model.id,
-                                Model.is_active == True,
-                            )
-                            .first()
-                        )
-                        if source_model_obj:
-                            # Check for tiered pricing
-                            tiered = source_model_obj.get_effective_tiered_pricing()
-                            if tiered and tiered.get("tiers"):
-                                first_tier = tiered["tiers"][0]
-                                input_price = first_tier.get("input_price_per_1m", 0)
-                                output_price = first_tier.get("output_price_per_1m", 0)
-                            else:
-                                input_price = source_model_obj.get_effective_input_price()
-                                output_price = source_model_obj.get_effective_output_price()
-                            logger.debug(f"[mapping模式] 使用源模型价格: {model} "
-                                         f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
-
-                # "alias" mode, or "mapping" mode with no source-model price: use the target GlobalModel's price
-                if input_price is None:
-                    target_global_model = (
-                        self.db.query(GlobalModel)
-                        .filter(
-                            GlobalModel.id == mapping.target_global_model_id,
-                            GlobalModel.is_active == True,
-                        )
-                        .first()
-                    )
-                    if target_global_model:
-                        target_model_obj = (
-                            self.db.query(Model)
-                            .filter(
-                                Model.provider_id == provider_obj.id,
-                                Model.global_model_id == target_global_model.id,
-                                Model.is_active == True,
-                            )
-                            .first()
-                        )
-                        if target_model_obj:
-                            # Check for tiered pricing
-                            tiered = target_model_obj.get_effective_tiered_pricing()
-                            if tiered and tiered.get("tiers"):
-                                first_tier = tiered["tiers"][0]
-                                input_price = first_tier.get("input_price_per_1m", 0)
-                                output_price = first_tier.get("output_price_per_1m", 0)
-                            else:
-                                input_price = target_model_obj.get_effective_input_price()
-                                output_price = target_model_obj.get_effective_output_price()
-                            mode_label = (
-                                "alias模式"
-                                if mapping.mapping_type == "alias"
-                                else "mapping模式(回退)"
-                            )
-                            logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
-                                         f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
-            else:
-                # No mapping; try matching GlobalModel.name directly
+            # Look up directly by GlobalModel.name (callers must use the canonical model name)
             global_model = (
                 self.db.query(GlobalModel)
                 .filter(
@@ -404,15 +297,14 @@ class ModelCostService:
""" """
返回给定 provider/model 的 (input_price, output_price)。 返回给定 provider/model 的 (input_price, output_price)。
新架构逻辑: 逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是) 1. 直接通过 GlobalModel.name 匹配
2. 解析为 GlobalModel.name 2. 查找该 Provider 的 Model 实现
3. 查找该 Provider 的 Model 实现 3. 获取价格配置
4. 获取价格配置
Args: Args:
provider: Provider 对象或提供商名称 provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名 model: 用户请求的模型名(必须是 GlobalModel.name
Returns: Returns:
(input_price, output_price) 元组 (input_price, output_price) 元组
@@ -434,15 +326,9 @@ class ModelCostService:
""" """
异步版本: 返回缓存创建/读取价格(每 1M tokens 异步版本: 返回缓存创建/读取价格(每 1M tokens
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取缓存价格配置
Args: Args:
provider: Provider 对象或提供商名称 provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名 model: 用户请求的模型名(必须是 GlobalModel.name
input_price: 基础输入价格(用于 Claude 模型的默认估算) input_price: 基础输入价格(用于 Claude 模型的默认估算)
Returns: Returns:
@@ -460,22 +346,17 @@ class ModelCostService:
         cache_read_price = None

         if provider_obj:
-            # Step 1: check whether this is an alias
-            from src.services.model.mapping_resolver import resolve_model_to_global_name
-
-            global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
-
-            # Step 2: find the GlobalModel
+            # Look up directly by GlobalModel.name (callers must use the canonical model name)
             global_model = (
                 self.db.query(GlobalModel)
                 .filter(
-                    GlobalModel.name == global_model_name,
+                    GlobalModel.name == model,
                     GlobalModel.is_active == True,
                 )
                 .first()
             )

-            # Step 3: find that provider's Model implementation
+            # Find that provider's Model implementation
             if global_model:
                 model_obj = (
                     self.db.query(Model)
@@ -517,15 +398,9 @@ class ModelCostService:
""" """
异步版本: 返回按次计费价格(每次请求的固定费用)。 异步版本: 返回按次计费价格(每次请求的固定费用)。
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取按次计费价格配置
Args: Args:
provider: Provider 对象或提供商名称 provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名 model: 用户请求的模型名(必须是 GlobalModel.name
Returns: Returns:
按次计费价格,如果没有配置则返回 None 按次计费价格,如果没有配置则返回 None
@@ -534,22 +409,17 @@ class ModelCostService:
         price_per_request = None

         if provider_obj:
-            # Step 1: check whether this is an alias
-            from src.services.model.mapping_resolver import resolve_model_to_global_name
-
-            global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
-
-            # Step 2: find the GlobalModel
+            # Look up directly by GlobalModel.name (callers must use the canonical model name)
             global_model = (
                 self.db.query(GlobalModel)
                 .filter(
-                    GlobalModel.name == global_model_name,
+                    GlobalModel.name == model,
                     GlobalModel.is_active == True,
                 )
                 .first()
             )

-            # Step 3: find that provider's Model implementation
+            # Find that provider's Model implementation
             if global_model:
                 model_obj = (
                     self.db.query(Model)
@@ -595,15 +465,14 @@ class ModelCostService:
""" """
返回缓存创建/读取价格(每 1M tokens 返回缓存创建/读取价格(每 1M tokens
新架构逻辑: 逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是) 1. 直接通过 GlobalModel.name 匹配
2. 解析为 GlobalModel.name 2. 查找该 Provider 的 Model 实现
3. 查找该 Provider 的 Model 实现 3. 获取缓存价格配置
4. 获取缓存价格配置
Args: Args:
provider: Provider 对象或提供商名称 provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名 model: 用户请求的模型名(必须是 GlobalModel.name
input_price: 基础输入价格(用于 Claude 模型的默认估算) input_price: 基础输入价格(用于 Claude 模型的默认估算)
Returns: Returns:
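Every pricing path in this file now reduces to the same three steps: GlobalModel by name, the provider's Model row, then the effective or first-tier price. The sketch below restates that path once; `lookup_prices` is an invented helper name, and only the accessors that appear in this diff (`get_effective_*`, `get_effective_tiered_pricing`) are assumed to exist on `Model`.

```python
# Minimal sketch of the simplified billing lookup (not the project's actual method).
from typing import Optional, Tuple

from sqlalchemy.orm import Session

from src.models.database import GlobalModel, Model


def lookup_prices(db: Session, provider_id: str, model_name: str) -> Optional[Tuple[float, float]]:
    """Return (input_price, output_price) per 1M tokens, or None when the provider lacks the model."""
    global_model = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
        .first()
    )
    if not global_model:
        return None

    model_obj = (
        db.query(Model)
        .filter(
            Model.provider_id == provider_id,
            Model.global_model_id == global_model.id,
            Model.is_active == True,
        )
        .first()
    )
    if not model_obj:
        return None

    # Tiered pricing: quote the first tier as the default, matching the diff's behaviour.
    tiered = model_obj.get_effective_tiered_pricing()
    if tiered and tiered.get("tiers"):
        first = tiered["tiers"][0]
        return first.get("input_price_per_1m", 0), first.get("output_price_per_1m", 0)
    return model_obj.get_effective_input_price(), model_obj.get_effective_output_price()
```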

View File

@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload
 from src.core.exceptions import InvalidRequestException, NotFoundException
 from src.core.logger import logger
-from src.models.database import GlobalModel, Model, ModelMapping
+from src.models.database import GlobalModel, Model
 from src.models.pydantic_models import GlobalModelUpdate

View File

@@ -5,18 +5,13 @@
 from typing import Dict, List, Optional

-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, joinedload

 from src.core.cache_utils import SyncLRUCache
 from src.core.logger import logger
 from src.models.claude import ClaudeMessagesRequest
-from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
+from src.models.database import GlobalModel, Model, Provider, ProviderEndpoint
 from src.services.cache.model_cache import ModelCacheService
-from src.services.model.mapping_resolver import (
-    get_model_mapping_resolver,
-    resolve_model_to_global_name,
-)


 class ModelMapperMiddleware:
@@ -71,10 +66,10 @@ class ModelMapperMiddleware:
         if mapping:
             # Apply the mapping
             original_model = request.model
-            request.model = mapping.model.provider_model_name
+            request.model = mapping.model.select_provider_model_name()
             logger.debug(f"Applied model mapping for provider {provider.name}: "
-                         f"{original_model} -> {mapping.model.provider_model_name}")
+                         f"{original_model} -> {request.model}")
         else:
             # No mapping found; keep the original model name
             logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
@@ -84,17 +79,16 @@ class ModelMapperMiddleware:
     async def get_mapping(
         self, source_model: str, provider_id: str
-    ) -> Optional[ModelMapping]:  # UUID
+    ) -> Optional[object]:
         """
         Get the model mapping

-        Logic:
-        1. Use the unified ModelMappingResolver to resolve the alias (with caching)
-        2. Find that provider's Model implementation via the GlobalModel
-        3. Use a separate mapping cache
+        Logic:
+        1. Look up directly by GlobalModel.name
+        2. Once the GlobalModel is found, look up that provider's Model implementation

         Args:
-            source_model: the requested model name or alias
+            source_model: the requested model name (must be a GlobalModel.name)
             provider_id: provider ID (UUID)

         Returns:
@@ -107,25 +101,25 @@ class ModelMapperMiddleware:
         mapping = None

-        # Steps 1 & 2: go through the unified model-mapping resolver
-        mapping_resolver = get_model_mapping_resolver()
-        global_model = await mapping_resolver.get_global_model_by_request(
-            self.db, source_model, provider_id
+        # Step 1: look up the GlobalModel directly by name
+        global_model = (
+            self.db.query(GlobalModel)
+            .filter(GlobalModel.name == source_model, GlobalModel.is_active == True)
+            .first()
         )
         if not global_model:
-            logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
+            logger.debug(f"GlobalModel not found: {source_model}")
            self._cache[cache_key] = None
             return None

-        # Step 3: check whether this provider implements the GlobalModel (cached)
+        # Step 2: check whether this provider implements the GlobalModel (cached)
         model = await ModelCacheService.get_model_by_provider_and_global_model(
             self.db, provider_id, global_model.id
         )

         if model:
-            # Only return a mapping when the model name actually changes
-            if model.provider_model_name != source_model:
+            # Build the mapping object
             mapping = type(
                 "obj",
                 (object,),
@@ -139,30 +133,27 @@ class ModelMapperMiddleware:
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} " logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
f"(provider={provider_id[:8]}...)") f"(provider={provider_id[:8]}...)")
else:
logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")
# 缓存结果 # 缓存结果
self._cache[cache_key] = mapping self._cache[cache_key] = mapping
return mapping return mapping
def get_all_mappings(self, provider_id: str) -> List[ModelMapping]: # UUID def get_all_mappings(self, provider_id: str) -> List[object]:
""" """
获取提供商的所有可用模型(通过 GlobalModel) 获取提供商的所有可用模型(通过 GlobalModel)
方案 A: 返回该 Provider 所有可用的 GlobalModel
Args: Args:
provider_id: 提供商ID (UUID) provider_id: 提供商ID (UUID)
Returns: Returns:
模型映射列表(模拟的 ModelMapping 对象列表) 模型映射列表
""" """
# 查询该 Provider 的所有活跃 Model # 查询该 Provider 的所有活跃 Model(使用 joinedload 避免 N+1
models = ( models = (
self.db.query(Model) self.db.query(Model)
.join(GlobalModel) .join(GlobalModel)
.options(joinedload(Model.global_model))
.filter( .filter(
Model.provider_id == provider_id, Model.provider_id == provider_id,
Model.is_active == True, Model.is_active == True,
@@ -171,7 +162,7 @@ class ModelMapperMiddleware:
             .all()
         )

-        # Build a list of compatible ModelMapping objects
+        # Build a list of compatible mapping objects
         mappings = []
         for model in models:
             mapping = type(
@@ -188,7 +179,7 @@ class ModelMapperMiddleware:
         return mappings

-    def get_supported_models(self, provider_id: str) -> List[str]:  # UUID
+    def get_supported_models(self, provider_id: str) -> List[str]:
         """
         Get all source model names supported by this provider
@@ -223,15 +214,6 @@ class ModelMapperMiddleware:
         if not mapping.is_active:
             return False, f"Model mapping for {request.model} is disabled"

-        # Do not cap max_tokens: as a relay service we should not restrict user requests
-        # if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
-        #     return False, (
-        #         f"Requested max_tokens {request.max_tokens} exceeds limit "
-        #         f"{mapping.max_output_tokens} for model {request.model}"
-        #     )
-
-        # More validation (e.g. input-length checks) could be added here
-
         return True, None

     def clear_cache(self):
@@ -239,7 +221,7 @@ class ModelMapperMiddleware:
         self._cache.clear()
         logger.debug("Model mapping cache cleared")

-    def refresh_cache(self, provider_id: Optional[str] = None):  # UUID
+    def refresh_cache(self, provider_id: Optional[str] = None):
         """
         Refresh the cache
@@ -285,16 +267,10 @@ class ModelRoutingMiddleware:
""" """
根据模型名选择提供商 根据模型名选择提供商
逻辑:
1. 如果指定了提供商,使用指定的提供商
2. 如果没指定,使用默认提供商
3. 选定提供商后会检查该提供商的模型映射在apply_mapping中处理
4. 如果指定了allowed_api_formats只选择符合格式的提供商
Args: Args:
model_name: 请求的模型名 model_name: 请求的模型名
preferred_provider: 首选提供商名称 preferred_provider: 首选提供商名称
allowed_api_formats: 允许的API格式列表(如 ['CLAUDE', 'CLAUDE_CLI'] allowed_api_formats: 允许的API格式列表
request_id: 请求ID用于日志关联 request_id: 请求ID用于日志关联
Returns: Returns:
@@ -313,14 +289,12 @@ class ModelRoutingMiddleware:
             if provider:
                 # Check the API format against the provider's endpoints
                 if allowed_api_formats:
-                    # Check whether any active endpoint matches the requirement
                     has_matching_endpoint = any(
                         ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
                         for ep in provider.endpoints
                     )
                     if not has_matching_endpoint:
                         logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
-                        # Don't return this provider; keep searching
                 else:
                     logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
                     return provider
@@ -330,10 +304,9 @@ class ModelRoutingMiddleware:
             else:
                 logger.warning(f"Specified provider {preferred_provider} not found or inactive")

-        # 2. Find the highest-priority active provider (smallest provider_priority)
+        # 2. Find the highest-priority active provider
         query = self.db.query(Provider).filter(Provider.is_active == True)

-        # If API-format filtering was requested, only keep providers with a matching endpoint
         if allowed_api_formats:
             query = (
                 query.join(ProviderEndpoint)
@@ -344,32 +317,27 @@ class ModelRoutingMiddleware:
                 .distinct()
             )

-        # Order by provider_priority: highest priority (smallest number) first
         best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()

         if best_provider:
             logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
             return best_provider

-        # 3. No active providers at all
         if allowed_api_formats:
-            logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
+            logger.error(f"No active providers found with allowed API formats {allowed_api_formats}.")
         else:
-            logger.error("No active providers found. Please configure at least one provider.")
+            logger.error("No active providers found.")
         return None

     def get_available_models(self) -> Dict[str, List[str]]:
         """
         Get all available models and their providers

-        Plan A: query based on GlobalModel
-
         Returns:
             dict mapping GlobalModel.name to the list of provider names that support it
         """
         result = {}

-        # Query all active GlobalModels and their providers
         models = (
             self.db.query(GlobalModel.name, Provider.name)
             .join(Model, GlobalModel.id == Model.global_model_id)
@@ -392,28 +360,23 @@ class ModelRoutingMiddleware:
""" """
获取某个模型最便宜的提供商 获取某个模型最便宜的提供商
方案 A: 通过 GlobalModel 查找
Args: Args:
model_name: GlobalModel 名称或别名 model_name: GlobalModel 名称
Returns: Returns:
最便宜的提供商 最便宜的提供商
""" """
# 步骤 1: 解析模型名 # 直接查找 GlobalModel
global_model_name = await resolve_model_to_global_name(self.db, model_name)
# 步骤 2: 查找 GlobalModel
global_model = ( global_model = (
self.db.query(GlobalModel) self.db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True) .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first() .first()
) )
if not global_model: if not global_model:
return None return None
# 步骤 3: 查询所有支持该模型的 Provider 及其价格 # 查询所有支持该模型的 Provider 及其价格
models_with_providers = ( models_with_providers = (
self.db.query(Provider, Model) self.db.query(Provider, Model)
.join(Model, Provider.id == Model.provider_id) .join(Model, Provider.id == Model.provider_id)
@@ -428,15 +391,16 @@ class ModelRoutingMiddleware:
         if not models_with_providers:
             return None

-        # Sort by total price (input + output)
+        # Sort by total price
         cheapest = min(
-            models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
+            models_with_providers,
+            key=lambda x: x[1].get_effective_input_price() + x[1].get_effective_output_price()
         )

         provider = cheapest[0]
         model = cheapest[1]
         logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
-                     f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")
+                     f"(input: ${model.get_effective_input_price()}/M, output: ${model.get_effective_output_price()}/M)")

         return provider
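`get_mapping` above keeps the ad-hoc `type("obj", (object,), {...})` object that `apply_mapping` and `validate_request` consume. For readers, here is an equivalent, more explicit shape; it is only a sketch of one possible refactor, the field set is abbreviated to the attributes actually referenced in this diff (`model`, `provider_model_name`, `is_active`), and `ResolvedMapping`/`build_mapping` are invented names.

```python
# Illustrative alternative to the anonymous type() object; not what the project ships.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ResolvedMapping:
    source_model: str            # the GlobalModel.name the client asked for
    provider_model_name: str     # what this provider calls the model
    model: object                # the underlying Model row
    is_active: bool = True


def build_mapping(source_model: str, model) -> Optional[ResolvedMapping]:
    """Wrap a provider's Model row in the lightweight object apply_mapping expects (sketch)."""
    if model is None:
        return None
    return ResolvedMapping(
        source_model=source_model,
        provider_model_name=model.provider_model_name,
        model=model,
    )
```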

View File

@@ -50,6 +50,7 @@ class ModelService:
             provider_id=provider_id,
             global_model_id=model_data.global_model_id,
             provider_model_name=model_data.provider_model_name,
+            provider_model_aliases=model_data.provider_model_aliases,
             price_per_request=model_data.price_per_request,
             tiered_pricing=model_data.tiered_pricing,
             supports_vision=model_data.supports_vision,
@@ -191,7 +192,6 @@ class ModelService:
         Deletion logic in the new architecture:
         - A Model is just a provider's implementation of a GlobalModel; deleting it does not affect the GlobalModel
         - Check whether it is the last implementation of that GlobalModel (if so, warn but allow deletion)
-        - Do not check ModelMapping (mappings are relationships between GlobalModels; aliases are stored in that table too)
         """
         model = db.query(Model).filter(Model.id == model_id).first()
         if not model:
@@ -326,6 +326,7 @@ class ModelService:
             provider_id=model.provider_id,
             global_model_id=model.global_model_id,
             provider_model_name=model.provider_model_name,
+            provider_model_aliases=model.provider_model_aliases,
             # Original configured values (may be empty)
             price_per_request=model.price_per_request,
             tiered_pricing=model.tiered_pricing,

View File

@@ -7,13 +7,11 @@ from typing import Dict
 from sqlalchemy.orm import Session

-from src.core.logger import logger
 from src.models.database import GlobalModel, Model, Provider
 from src.services.model.cost import ModelCostService
 from src.services.model.mapper import ModelMapperMiddleware, ModelRoutingMiddleware


 class ProviderService:
     """Provider service class"""
@@ -34,30 +32,15 @@ class ProviderService:
         Check whether the model is available (strict allow-list mode)

         Args:
-            model_name: model name
+            model_name: model name (must be a GlobalModel.name)

         Returns:
             the Model object if it exists and is active, otherwise None
         """
-        # First check whether there is a direct Model record
-        model = (
-            self.db.query(Model)
-            .filter(Model.provider_model_name == model_name, Model.is_active == True)
-            .first()
-        )
-        if model:
-            return model
-
-        # Plan A: check whether it is an alias (global alias system)
-        from src.services.model.mapping_resolver import resolve_model_to_global_name
-
-        global_model_name = await resolve_model_to_global_name(self.db, model_name)
-
-        # Find the GlobalModel
+        # Look up the GlobalModel directly
         global_model = (
             self.db.query(GlobalModel)
-            .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
+            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
             .first()
         )
@@ -79,34 +62,15 @@ class ProviderService:
         Args:
             provider_id: provider ID
-            model_name: model name
+            model_name: model name (must be a GlobalModel.name)

         Returns:
             the Model object if this provider supports the model and it is active, otherwise None
         """
-        # First check whether this provider has a direct Model record
-        model = (
-            self.db.query(Model)
-            .filter(
-                Model.provider_id == provider_id,
-                Model.provider_model_name == model_name,
-                Model.is_active == True,
-            )
-            .first()
-        )
-        if model:
-            return model
-
-        # Plan A: check whether it is an alias
-        from src.services.model.mapping_resolver import resolve_model_to_global_name
-
-        global_model_name = await resolve_model_to_global_name(self.db, model_name, provider_id)
-
-        # Find the GlobalModel
+        # Look up the GlobalModel directly
         global_model = (
             self.db.query(GlobalModel)
-            .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
+            .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
             .first()
         )
@@ -148,12 +112,19 @@ class ProviderService:
         Get all available models

         Returns:
-            a mapping of models to the providers that support them
+            dict mapping model names to the providers that support them
         """
         return self.router.get_available_models()

-    def clear_cache(self):
-        """Clear all caches."""
-        self.mapper.clear_cache()
-        self.cost_service.clear_cache()
-        logger.info("Provider service cache cleared")
+    def select_provider(self, model_name: str, preferred_provider=None):
+        """
+        Select a provider
+
+        Args:
+            model_name: model name
+            preferred_provider: preferred provider
+
+        Returns:
+            Provider object
+        """
+        return self.router.select_provider(model_name, preferred_provider)

View File

@@ -26,11 +26,10 @@ class UsageService:
     ) -> tuple[float, float]:
         """Asynchronously get the model price (input price, output price, per 1M tokens).

-        New-architecture lookup logic:
-        1. Use ModelMappingResolver to resolve the alias (if it is one)
-        2. Resolve it to a GlobalModel.name
-        3. Find that provider's Model implementation and read its price
-        4. Fall back to the system default price if nothing is found
+        Lookup logic:
+        1. Match directly by GlobalModel.name
+        2. Find that provider's Model implementation and read its price
+        3. Fall back to the system default price if nothing is found
         """
         service = ModelCostService(db)
@@ -40,11 +39,10 @@ class UsageService:
     def get_model_price(cls, db: Session, provider: str, model: str) -> tuple[float, float]:
         """Get the model price (input price, output price, per 1M tokens).

-        New-architecture lookup logic:
-        1. Use ModelMappingResolver to resolve the alias (if it is one)
-        2. Resolve it to a GlobalModel.name
-        3. Find that provider's Model implementation and read its price
-        4. Fall back to the system default price if nothing is found
+        Lookup logic:
+        1. Match directly by GlobalModel.name
+        2. Find that provider's Model implementation and read its price
+        3. Fall back to the system default price if nothing is found
         """
         service = ModelCostService(db)