diff --git a/src/services/cache/aware_scheduler.py b/src/services/cache/aware_scheduler.py index 553f363..7c0d11a 100644 --- a/src/services/cache/aware_scheduler.py +++ b/src/services/cache/aware_scheduler.py @@ -36,7 +36,7 @@ from src.core.logger import logger from sqlalchemy.orm import Session, selectinload from src.core.enums import APIFormat -from src.core.exceptions import ProviderNotAvailableException +from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException from src.models.database import ( ApiKey, Model, @@ -227,19 +227,6 @@ class CacheAwareScheduler: if provider_offset == 0: # 没有找到任何候选,提供友好的错误提示 error_msg = f"模型 '{model_name}' 不可用" - - # 查找相似模型 - from src.services.model.mapping_resolver import get_model_mapping_resolver - - resolver = get_model_mapping_resolver() - similar_models = resolver.find_similar_models(db, model_name, limit=3) - - if similar_models: - suggestions = [ - f"{name} (相似度: {score:.0%})" for name, score in similar_models - ] - error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions) - raise ProviderNotAvailableException(error_msg) break @@ -579,13 +566,20 @@ class CacheAwareScheduler: target_format = normalize_api_format(api_format) - # 0. 解析 model_name 到 GlobalModel(用于缓存亲和性的规范化标识) - from src.services.model.mapping_resolver import get_model_mapping_resolver - mapping_resolver = get_model_mapping_resolver() - global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None) + # 0. 解析 model_name 到 GlobalModel(直接查找,用户必须使用标准名称) + from src.models.database import GlobalModel + global_model = ( + db.query(GlobalModel) + .filter(GlobalModel.name == model_name, GlobalModel.is_active == True) + .first() + ) + + if not global_model: + logger.warning(f"GlobalModel not found: {model_name}") + raise ModelNotSupportedException(model=model_name) # 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保别名和规范名都能命中同一个缓存 - global_model_id: str = str(global_model.id) if global_model else model_name + global_model_id: str = str(global_model.id) # 获取合并后的访问限制(ApiKey + User) restrictions = self._get_effective_restrictions(user_api_key) @@ -685,7 +679,6 @@ class CacheAwareScheduler: db.query(Provider) .options( selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys), - selectinload(Provider.model_mappings), # 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值 selectinload(Provider.models).selectinload(Model.global_model), ) @@ -715,63 +708,33 @@ class CacheAwareScheduler: - 模型支持的能力是全局的,与具体的 Key 无关 - 如果模型不支持某能力,整个 Provider 的所有 Key 都应该被跳过 - 映射回退逻辑: - - 如果存在模型映射(mapping),先尝试映射后的模型 - - 如果映射后的模型因能力不满足而失败,回退尝试原模型 - - 其他失败原因(如模型不存在、Provider 未实现等)不触发回退 - Args: db: 数据库会话 provider: Provider 对象 - model_name: 模型名称 + model_name: 模型名称(必须是 GlobalModel.name) is_stream: 是否是流式请求,如果为 True 则同时检查流式支持 capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力 Returns: (is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表 """ - from src.services.model.mapping_resolver import get_model_mapping_resolver + from src.models.database import GlobalModel - mapping_resolver = get_model_mapping_resolver() - - # 获取映射后的模型,同时检查是否发生了映射 - global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info( - db, model_name, str(provider.id) + # 直接通过 GlobalModel.name 查找 + global_model = ( + db.query(GlobalModel) + .filter(GlobalModel.name == model_name, GlobalModel.is_active == True) + .first() ) if not global_model: return False, "模型不存在", None - # 尝试检查映射后的模型 + # 检查模型支持 is_supported, skip_reason, caps = await 
self._check_model_support_for_global_model( db, provider, global_model, model_name, is_stream, capability_requirements ) - # 如果映射后的模型因能力不满足而失败,且存在映射,则回退尝试原模型 - if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason: - logger.debug( - f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足," - f"回退尝试原模型 {model_name}" - ) - - # 获取原模型(不应用映射) - original_global_model = await mapping_resolver.get_global_model_direct(db, model_name) - - if original_global_model and original_global_model.id != global_model.id: - # 尝试原模型 - is_supported, skip_reason, caps = await self._check_model_support_for_global_model( - db, - provider, - original_global_model, - model_name, - is_stream, - capability_requirements, - ) - if is_supported: - logger.debug( - f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力" - ) - return is_supported, skip_reason, caps async def _check_model_support_for_global_model( diff --git a/src/services/cache/backend.py b/src/services/cache/backend.py index 0ab384c..b6d859a 100644 --- a/src/services/cache/backend.py +++ b/src/services/cache/backend.py @@ -6,8 +6,7 @@ 2. RedisCache: Redis 缓存(分布式) 使用场景: -- ModelMappingResolver: 模型映射与别名解析缓存 -- ModelMapper: 模型映射缓存 +- ModelCacheService: 模型解析缓存 - 其他需要缓存的服务 """ diff --git a/src/services/cache/invalidation.py b/src/services/cache/invalidation.py index 7e06b30..c789238 100644 --- a/src/services/cache/invalidation.py +++ b/src/services/cache/invalidation.py @@ -3,18 +3,14 @@ 统一管理各种缓存的失效逻辑,支持: 1. GlobalModel 变更时失效相关缓存 -2. ModelMapping 变更时失效别名/降级缓存 -3. Model 变更时失效模型映射缓存 -4. 支持同步和异步缓存后端 +2. Model 变更时失效模型映射缓存 +3. 支持同步和异步缓存后端 """ -import asyncio from typing import Optional from src.core.logger import logger -from src.core.logger import logger - class CacheInvalidationService: """ @@ -25,14 +21,8 @@ class CacheInvalidationService: def __init__(self): """初始化缓存失效服务""" - self._mapping_resolver = None self._model_mappers = [] # 可能有多个 ModelMapperMiddleware 实例 - def set_mapping_resolver(self, mapping_resolver): - """设置模型映射解析器实例""" - self._mapping_resolver = mapping_resolver - logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})") - def register_model_mapper(self, model_mapper): """注册 ModelMapper 实例""" if model_mapper not in self._model_mappers: @@ -48,37 +38,12 @@ class CacheInvalidationService: """ logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}") - # 异步失效模型解析器中的缓存 - if self._mapping_resolver: - asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache()) - # 失效所有 ModelMapper 中与此模型相关的缓存 for mapper in self._model_mappers: # 清空所有缓存(因为不知道哪些 provider 使用了这个模型) mapper.clear_cache() logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存") - def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None): - """ - ModelMapping 变更时的缓存失效 - - Args: - source_model: 变更的源模型名 - provider_id: 相关 Provider(None 表示全局) - """ - logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})") - - if self._mapping_resolver: - asyncio.create_task( - self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id) - ) - - for mapper in self._model_mappers: - if provider_id: - mapper.refresh_cache(provider_id) - else: - mapper.clear_cache() - def on_model_changed(self, provider_id: str, global_model_id: str): """ Model 变更时的缓存失效 @@ -98,9 +63,6 @@ class CacheInvalidationService: """清空所有缓存""" logger.info("[CacheInvalidation] 清空所有缓存") - if self._mapping_resolver: - asyncio.create_task(self._mapping_resolver.clear_cache()) - for mapper in 
self._model_mappers: mapper.clear_cache() diff --git a/src/services/cache/model_cache.py b/src/services/cache/model_cache.py index 391100c..1f9486b 100644 --- a/src/services/cache/model_cache.py +++ b/src/services/cache/model_cache.py @@ -1,5 +1,5 @@ """ -Model 映射缓存服务 - 减少模型映射和别名查询 +Model 映射缓存服务 - 减少模型查询 """ from typing import Optional @@ -9,8 +9,7 @@ from sqlalchemy.orm import Session from src.config.constants import CacheTTL from src.core.cache_service import CacheService from src.core.logger import logger -from src.models.database import GlobalModel, Model, ModelMapping - +from src.models.database import GlobalModel, Model class ModelCacheService: @@ -158,56 +157,6 @@ class ModelCacheService: return global_model - @staticmethod - async def resolve_alias( - db: Session, source_model: str, provider_id: Optional[str] = None - ) -> Optional[str]: - """ - 解析模型别名(带缓存) - - Args: - db: 数据库会话 - source_model: 源模型名称或别名 - provider_id: Provider ID(可选,用于 Provider 特定别名) - - Returns: - 目标 GlobalModel ID 或 None - """ - # 构造缓存键 - if provider_id: - cache_key = f"alias:provider:{provider_id}:{source_model}" - else: - cache_key = f"alias:global:{source_model}" - - # 1. 尝试从缓存获取 - cached_result = await CacheService.get(cache_key) - if cached_result: - logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})") - return cached_result - - # 2. 缓存未命中,查询数据库 - query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model) - - if provider_id: - # Provider 特定别名优先 - query = query.filter(ModelMapping.provider_id == provider_id) - else: - # 全局别名 - query = query.filter(ModelMapping.provider_id.is_(None)) - - mapping = query.first() - - # 3. 写入缓存 - target_global_model_id = mapping.target_global_model_id if mapping else None - await CacheService.set( - cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL - ) - - if mapping: - logger.debug(f"别名已缓存: {source_model} → {target_global_model_id}") - - return target_global_model_id - @staticmethod async def invalidate_model_cache( model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None @@ -237,17 +186,6 @@ class ModelCacheService: await CacheService.delete(f"global_model:name:{name}") logger.debug(f"GlobalModel 缓存已清除: {global_model_id}") - @staticmethod - async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None): - """清除别名缓存""" - if provider_id: - cache_key = f"alias:provider:{provider_id}:{source_model}" - else: - cache_key = f"alias:global:{source_model}" - - await CacheService.delete(cache_key) - logger.debug(f"别名缓存已清除: {source_model}") - @staticmethod def _model_to_dict(model: Model) -> dict: """将 Model 对象转换为字典""" @@ -256,6 +194,7 @@ class ModelCacheService: "provider_id": model.provider_id, "global_model_id": model.global_model_id, "provider_model_name": model.provider_model_name, + "provider_model_aliases": getattr(model, "provider_model_aliases", None), "is_active": model.is_active, "is_available": model.is_available if hasattr(model, "is_available") else True, "price_per_request": ( @@ -266,6 +205,7 @@ class ModelCacheService: "supports_function_calling": model.supports_function_calling, "supports_streaming": model.supports_streaming, "supports_extended_thinking": model.supports_extended_thinking, + "supports_image_generation": getattr(model, "supports_image_generation", None), "config": model.config, } @@ -277,6 +217,7 @@ class ModelCacheService: provider_id=model_dict["provider_id"], global_model_id=model_dict["global_model_id"], 
provider_model_name=model_dict["provider_model_name"], + provider_model_aliases=model_dict.get("provider_model_aliases"), is_active=model_dict["is_active"], is_available=model_dict.get("is_available", True), price_per_request=model_dict.get("price_per_request"), @@ -285,6 +226,7 @@ class ModelCacheService: supports_function_calling=model_dict.get("supports_function_calling"), supports_streaming=model_dict.get("supports_streaming"), supports_extended_thinking=model_dict.get("supports_extended_thinking"), + supports_image_generation=model_dict.get("supports_image_generation"), config=model_dict.get("config"), ) return model @@ -296,12 +238,11 @@ class ModelCacheService: "id": global_model.id, "name": global_model.name, "display_name": global_model.display_name, - "family": global_model.family, - "group_id": global_model.group_id, - "supports_vision": global_model.supports_vision, - "supports_thinking": global_model.supports_thinking, - "context_window": global_model.context_window, - "max_output_tokens": global_model.max_output_tokens, + "default_supports_vision": global_model.default_supports_vision, + "default_supports_function_calling": global_model.default_supports_function_calling, + "default_supports_streaming": global_model.default_supports_streaming, + "default_supports_extended_thinking": global_model.default_supports_extended_thinking, + "default_supports_image_generation": global_model.default_supports_image_generation, "is_active": global_model.is_active, "description": global_model.description, } @@ -313,12 +254,11 @@ class ModelCacheService: id=global_model_dict["id"], name=global_model_dict["name"], display_name=global_model_dict.get("display_name"), - family=global_model_dict.get("family"), - group_id=global_model_dict.get("group_id"), - supports_vision=global_model_dict.get("supports_vision", False), - supports_thinking=global_model_dict.get("supports_thinking", False), - context_window=global_model_dict.get("context_window"), - max_output_tokens=global_model_dict.get("max_output_tokens"), + default_supports_vision=global_model_dict.get("default_supports_vision", False), + default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False), + default_supports_streaming=global_model_dict.get("default_supports_streaming", True), + default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False), + default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False), is_active=global_model_dict.get("is_active", True), description=global_model_dict.get("description"), ) diff --git a/src/services/cache/sync.py b/src/services/cache/sync.py index e3b1c93..47a6e70 100644 --- a/src/services/cache/sync.py +++ b/src/services/cache/sync.py @@ -6,7 +6,7 @@ 使用场景: 1. 多实例部署时,确保所有实例的缓存一致性 -2. GlobalModel/ModelMapping 变更时,同步失效所有实例的缓存 +2. 
GlobalModel/Model 变更时,同步失效所有实例的缓存 """ import asyncio @@ -29,7 +29,6 @@ class CacheSyncService: # Redis 频道名称 CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model" - CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping" CHANNEL_MODEL = "cache:invalidate:model" CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all" @@ -58,7 +57,6 @@ class CacheSyncService: # 订阅所有缓存失效频道 await self._pubsub.subscribe( self.CHANNEL_GLOBAL_MODEL, - self.CHANNEL_MODEL_MAPPING, self.CHANNEL_MODEL, self.CHANNEL_CLEAR_ALL, ) @@ -68,7 +66,7 @@ class CacheSyncService: self._running = True logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: " - f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, " + f"{self.CHANNEL_GLOBAL_MODEL}, " f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}") except Exception as e: logger.error(f"[CacheSync] 启动失败: {e}") @@ -141,14 +139,6 @@ class CacheSyncService: """发布 GlobalModel 变更通知""" await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name}) - async def publish_model_mapping_changed( - self, source_model: str, provider_id: Optional[str] = None - ): - """发布 ModelMapping 变更通知""" - await self._publish( - self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id} - ) - async def publish_model_changed(self, provider_id: str, global_model_id: str): """发布 Model 变更通知""" await self._publish( diff --git a/src/services/model/__init__.py b/src/services/model/__init__.py index f1c628f..25f3316 100644 --- a/src/services/model/__init__.py +++ b/src/services/model/__init__.py @@ -1,19 +1,15 @@ """ 模型服务模块 -包含模型管理、模型映射、成本计算等功能。 +包含模型管理、成本计算等功能。 """ from src.services.model.cost import ModelCostService from src.services.model.global_model import GlobalModelService -from src.services.model.mapper import ModelMapperMiddleware -from src.services.model.mapping_resolver import ModelMappingResolver from src.services.model.service import ModelService __all__ = [ "ModelService", "GlobalModelService", - "ModelMapperMiddleware", - "ModelMappingResolver", "ModelCostService", ] diff --git a/src/services/model/cost.py b/src/services/model/cost.py index 30897cc..4413ae7 100644 --- a/src/services/model/cost.py +++ b/src/services/model/cost.py @@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union from sqlalchemy.orm import Session from src.core.logger import logger -from src.models.database import GlobalModel, Model, ModelMapping, Provider +from src.models.database import GlobalModel, Model, Provider ProviderRef = Union[str, Provider, None] @@ -161,16 +161,11 @@ class ModelCostService: result = None if provider_obj: - from src.services.model.mapping_resolver import resolve_model_to_global_name - - global_model_name = await resolve_model_to_global_name( - self.db, model, provider_obj.id - ) - + # 直接通过 GlobalModel.name 查找 global_model = ( self.db.query(GlobalModel) .filter( - GlobalModel.name == global_model_name, + GlobalModel.name == model, GlobalModel.is_active == True, ) .first() @@ -226,17 +221,14 @@ class ModelCostService: 注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。 实际计费时应使用 compute_cost_with_tiered_pricing 方法。 - 计费逻辑(基于 mapping_type): - 1. 查找 ModelMapping(如果存在) - 2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格 - 3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格 - - 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格 - - 否则回退到目标 GlobalModel 的价格 - 4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name + 计费逻辑: + 1. 直接通过 GlobalModel.name 匹配 + 2. 查找该 Provider 的 Model 实现 + 3. 
获取价格配置 Args: provider: Provider 对象或提供商名称 - model: 用户请求的模型名(可能是 GlobalModel.name 或别名) + model: 用户请求的模型名(必须是 GlobalModel.name) Returns: (input_price, output_price) 元组 @@ -253,136 +245,37 @@ class ModelCostService: output_price = None if provider_obj: - # 步骤 1: 查找 ModelMapping 以确定 mapping_type - from src.models.database import ModelMapping - - mapping = None - # 先查 Provider 特定映射 - mapping = ( - self.db.query(ModelMapping) + # 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称) + global_model = ( + self.db.query(GlobalModel) .filter( - ModelMapping.source_model == model, - ModelMapping.provider_id == provider_obj.id, - ModelMapping.is_active == True, + GlobalModel.name == model, + GlobalModel.is_active == True, ) .first() ) - # 再查全局映射 - if not mapping: - mapping = ( - self.db.query(ModelMapping) + if global_model: + model_obj = ( + self.db.query(Model) .filter( - ModelMapping.source_model == model, - ModelMapping.provider_id.is_(None), - ModelMapping.is_active == True, + Model.provider_id == provider_obj.id, + Model.global_model_id == global_model.id, + Model.is_active == True, ) .first() ) - - if mapping: - # 有映射,根据 mapping_type 决定计费模型 - if mapping.mapping_type == "mapping": - # mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格 - source_global_model = ( - self.db.query(GlobalModel) - .filter( - GlobalModel.name == model, - GlobalModel.is_active == True, - ) - .first() - ) - if source_global_model: - source_model_obj = ( - self.db.query(Model) - .filter( - Model.provider_id == provider_obj.id, - Model.global_model_id == source_global_model.id, - Model.is_active == True, - ) - .first() - ) - if source_model_obj: - # 检查是否有阶梯计费 - tiered = source_model_obj.get_effective_tiered_pricing() - if tiered and tiered.get("tiers"): - first_tier = tiered["tiers"][0] - input_price = first_tier.get("input_price_per_1m", 0) - output_price = first_tier.get("output_price_per_1m", 0) - else: - input_price = source_model_obj.get_effective_input_price() - output_price = source_model_obj.get_effective_output_price() - logger.debug(f"[mapping模式] 使用源模型价格: {model} " - f"(输入: ${input_price}/M, 输出: ${output_price}/M)") - - # alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格 - if input_price is None: - target_global_model = ( - self.db.query(GlobalModel) - .filter( - GlobalModel.id == mapping.target_global_model_id, - GlobalModel.is_active == True, - ) - .first() - ) - if target_global_model: - target_model_obj = ( - self.db.query(Model) - .filter( - Model.provider_id == provider_obj.id, - Model.global_model_id == target_global_model.id, - Model.is_active == True, - ) - .first() - ) - if target_model_obj: - # 检查是否有阶梯计费 - tiered = target_model_obj.get_effective_tiered_pricing() - if tiered and tiered.get("tiers"): - first_tier = tiered["tiers"][0] - input_price = first_tier.get("input_price_per_1m", 0) - output_price = first_tier.get("output_price_per_1m", 0) - else: - input_price = target_model_obj.get_effective_input_price() - output_price = target_model_obj.get_effective_output_price() - mode_label = ( - "alias模式" - if mapping.mapping_type == "alias" - else "mapping模式(回退)" - ) - logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} " - f"(输入: ${input_price}/M, 输出: ${output_price}/M)") - else: - # 没有映射,尝试直接匹配 GlobalModel.name - global_model = ( - self.db.query(GlobalModel) - .filter( - GlobalModel.name == model, - GlobalModel.is_active == True, - ) - .first() - ) - if global_model: - model_obj = ( - self.db.query(Model) - .filter( - Model.provider_id == provider_obj.id, - Model.global_model_id == global_model.id, 
- Model.is_active == True, - ) - .first() - ) - if model_obj: - # 检查是否有阶梯计费 - tiered = model_obj.get_effective_tiered_pricing() - if tiered and tiered.get("tiers"): - first_tier = tiered["tiers"][0] - input_price = first_tier.get("input_price_per_1m", 0) - output_price = first_tier.get("output_price_per_1m", 0) - else: - input_price = model_obj.get_effective_input_price() - output_price = model_obj.get_effective_output_price() - logger.debug(f"找到模型价格配置: {provider_name}/{model} " - f"(输入: ${input_price}/M, 输出: ${output_price}/M)") + if model_obj: + # 检查是否有阶梯计费 + tiered = model_obj.get_effective_tiered_pricing() + if tiered and tiered.get("tiers"): + first_tier = tiered["tiers"][0] + input_price = first_tier.get("input_price_per_1m", 0) + output_price = first_tier.get("output_price_per_1m", 0) + else: + input_price = model_obj.get_effective_input_price() + output_price = model_obj.get_effective_output_price() + logger.debug(f"找到模型价格配置: {provider_name}/{model} " + f"(输入: ${input_price}/M, 输出: ${output_price}/M)") # 如果没有找到价格配置,使用 0.0 并记录警告 if input_price is None: @@ -404,15 +297,14 @@ class ModelCostService: """ 返回给定 provider/model 的 (input_price, output_price)。 - 新架构逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现 - 4. 获取价格配置 + 逻辑: + 1. 直接通过 GlobalModel.name 匹配 + 2. 查找该 Provider 的 Model 实现 + 3. 获取价格配置 Args: provider: Provider 对象或提供商名称 - model: 用户请求的模型名(可能是 GlobalModel.name 或别名) + model: 用户请求的模型名(必须是 GlobalModel.name) Returns: (input_price, output_price) 元组 @@ -434,15 +326,9 @@ class ModelCostService: """ 异步版本: 返回缓存创建/读取价格(每 1M tokens)。 - 新架构逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现 - 4. 获取缓存价格配置 - Args: provider: Provider 对象或提供商名称 - model: 用户请求的模型名(可能是 GlobalModel.name 或别名) + model: 用户请求的模型名(必须是 GlobalModel.name) input_price: 基础输入价格(用于 Claude 模型的默认估算) Returns: @@ -460,22 +346,17 @@ class ModelCostService: cache_read_price = None if provider_obj: - # 步骤 1: 检查是否是别名 - from src.services.model.mapping_resolver import resolve_model_to_global_name - - global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id) - - # 步骤 2: 查找 GlobalModel + # 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称) global_model = ( self.db.query(GlobalModel) .filter( - GlobalModel.name == global_model_name, + GlobalModel.name == model, GlobalModel.is_active == True, ) .first() ) - # 步骤 3: 查找该 Provider 的 Model 实现 + # 查找该 Provider 的 Model 实现 if global_model: model_obj = ( self.db.query(Model) @@ -517,15 +398,9 @@ class ModelCostService: """ 异步版本: 返回按次计费价格(每次请求的固定费用)。 - 新架构逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现 - 4. 
获取按次计费价格配置 - Args: provider: Provider 对象或提供商名称 - model: 用户请求的模型名(可能是 GlobalModel.name 或别名) + model: 用户请求的模型名(必须是 GlobalModel.name) Returns: 按次计费价格,如果没有配置则返回 None @@ -534,22 +409,17 @@ class ModelCostService: price_per_request = None if provider_obj: - # 步骤 1: 检查是否是别名 - from src.services.model.mapping_resolver import resolve_model_to_global_name - - global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id) - - # 步骤 2: 查找 GlobalModel + # 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称) global_model = ( self.db.query(GlobalModel) .filter( - GlobalModel.name == global_model_name, + GlobalModel.name == model, GlobalModel.is_active == True, ) .first() ) - # 步骤 3: 查找该 Provider 的 Model 实现 + # 查找该 Provider 的 Model 实现 if global_model: model_obj = ( self.db.query(Model) @@ -595,15 +465,14 @@ class ModelCostService: """ 返回缓存创建/读取价格(每 1M tokens)。 - 新架构逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现 - 4. 获取缓存价格配置 + 逻辑: + 1. 直接通过 GlobalModel.name 匹配 + 2. 查找该 Provider 的 Model 实现 + 3. 获取缓存价格配置 Args: provider: Provider 对象或提供商名称 - model: 用户请求的模型名(可能是 GlobalModel.name 或别名) + model: 用户请求的模型名(必须是 GlobalModel.name) input_price: 基础输入价格(用于 Claude 模型的默认估算) Returns: diff --git a/src/services/model/global_model.py b/src/services/model/global_model.py index 5d5f554..ef98614 100644 --- a/src/services/model/global_model.py +++ b/src/services/model/global_model.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload from src.core.exceptions import InvalidRequestException, NotFoundException from src.core.logger import logger -from src.models.database import GlobalModel, Model, ModelMapping +from src.models.database import GlobalModel, Model from src.models.pydantic_models import GlobalModelUpdate diff --git a/src/services/model/mapper.py b/src/services/model/mapper.py index e1be4f9..c14f59f 100644 --- a/src/services/model/mapper.py +++ b/src/services/model/mapper.py @@ -5,18 +5,13 @@ from typing import Dict, List, Optional -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, joinedload from src.core.cache_utils import SyncLRUCache from src.core.logger import logger from src.models.claude import ClaudeMessagesRequest -from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint +from src.models.database import GlobalModel, Model, Provider, ProviderEndpoint from src.services.cache.model_cache import ModelCacheService -from src.services.model.mapping_resolver import ( - get_model_mapping_resolver, - resolve_model_to_global_name, -) - class ModelMapperMiddleware: @@ -71,10 +66,10 @@ class ModelMapperMiddleware: if mapping: # 应用映射 original_model = request.model - request.model = mapping.model.provider_model_name + request.model = mapping.model.select_provider_model_name() logger.debug(f"Applied model mapping for provider {provider.name}: " - f"{original_model} -> {mapping.model.provider_model_name}") + f"{original_model} -> {request.model}") else: # 没有找到映射,使用原始模型名 logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, " @@ -84,17 +79,16 @@ class ModelMapperMiddleware: async def get_mapping( self, source_model: str, provider_id: str - ) -> Optional[ModelMapping]: # UUID + ) -> Optional[object]: """ 获取模型映射 - 优化后逻辑: - 1. 使用统一的 ModelMappingResolver 解析别名(带缓存) - 2. 通过 GlobalModel 找到该 Provider 的 Model 实现 - 3. 使用独立的映射缓存 + 简化后的逻辑: + 1. 通过 GlobalModel.name 直接查找 + 2. 
找到 GlobalModel 后,查找该 Provider 的 Model 实现 Args: - source_model: 用户请求的模型名或别名 + source_model: 用户请求的模型名(必须是 GlobalModel.name) provider_id: 提供商ID (UUID) Returns: @@ -107,62 +101,59 @@ class ModelMapperMiddleware: mapping = None - # 步骤 1 & 2: 通过统一的模型映射解析服务 - mapping_resolver = get_model_mapping_resolver() - global_model = await mapping_resolver.get_global_model_by_request( - self.db, source_model, provider_id + # 步骤 1: 直接通过名称查找 GlobalModel + global_model = ( + self.db.query(GlobalModel) + .filter(GlobalModel.name == source_model, GlobalModel.is_active == True) + .first() ) if not global_model: - logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)") + logger.debug(f"GlobalModel not found: {source_model}") self._cache[cache_key] = None return None - # 步骤 3: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存) + # 步骤 2: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存) model = await ModelCacheService.get_model_by_provider_and_global_model( self.db, provider_id, global_model.id ) if model: - # 只有当模型名发生变化时才返回映射 - if model.provider_model_name != source_model: - mapping = type( - "obj", - (object,), - { - "source_model": source_model, - "model": model, - "is_active": True, - "provider_id": provider_id, - }, - )() + # 创建映射对象 + mapping = type( + "obj", + (object,), + { + "source_model": source_model, + "model": model, + "is_active": True, + "provider_id": provider_id, + }, + )() - logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} " - f"(provider={provider_id[:8]}...)") - else: - logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)") + logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} " + f"(provider={provider_id[:8]}...)") # 缓存结果 self._cache[cache_key] = mapping return mapping - def get_all_mappings(self, provider_id: str) -> List[ModelMapping]: # UUID + def get_all_mappings(self, provider_id: str) -> List[object]: """ 获取提供商的所有可用模型(通过 GlobalModel) - 方案 A: 返回该 Provider 所有可用的 GlobalModel - Args: provider_id: 提供商ID (UUID) Returns: - 模型映射列表(模拟的 ModelMapping 对象列表) + 模型映射列表 """ - # 查询该 Provider 的所有活跃 Model + # 查询该 Provider 的所有活跃 Model(使用 joinedload 避免 N+1) models = ( self.db.query(Model) .join(GlobalModel) + .options(joinedload(Model.global_model)) .filter( Model.provider_id == provider_id, Model.is_active == True, @@ -171,7 +162,7 @@ class ModelMapperMiddleware: .all() ) - # 构造兼容的 ModelMapping 对象列表 + # 构造兼容的映射对象列表 mappings = [] for model in models: mapping = type( @@ -188,7 +179,7 @@ class ModelMapperMiddleware: return mappings - def get_supported_models(self, provider_id: str) -> List[str]: # UUID + def get_supported_models(self, provider_id: str) -> List[str]: """ 获取提供商支持的所有源模型名 @@ -223,15 +214,6 @@ class ModelMapperMiddleware: if not mapping.is_active: return False, f"Model mapping for {request.model} is disabled" - # 不限制max_tokens,作为中转服务不应该限制用户的请求 - # if request.max_tokens and request.max_tokens > mapping.max_output_tokens: - # return False, ( - # f"Requested max_tokens {request.max_tokens} exceeds limit " - # f"{mapping.max_output_tokens} for model {request.model}" - # ) - - # 可以添加更多验证逻辑,比如检查输入长度等 - return True, None def clear_cache(self): @@ -239,7 +221,7 @@ class ModelMapperMiddleware: self._cache.clear() logger.debug("Model mapping cache cleared") - def refresh_cache(self, provider_id: Optional[str] = None): # UUID + def refresh_cache(self, provider_id: Optional[str] = None): """ 刷新缓存 @@ -285,16 +267,10 @@ class ModelRoutingMiddleware: """ 根据模型名选择提供商 - 逻辑: - 1. 
如果指定了提供商,使用指定的提供商 - 2. 如果没指定,使用默认提供商 - 3. 选定提供商后,会检查该提供商的模型映射(在apply_mapping中处理) - 4. 如果指定了allowed_api_formats,只选择符合格式的提供商 - Args: model_name: 请求的模型名 preferred_provider: 首选提供商名称 - allowed_api_formats: 允许的API格式列表(如 ['CLAUDE', 'CLAUDE_CLI']) + allowed_api_formats: 允许的API格式列表 request_id: 请求ID(用于日志关联) Returns: @@ -313,14 +289,12 @@ class ModelRoutingMiddleware: if provider: # 检查API格式 - 从 endpoints 中检查 if allowed_api_formats: - # 检查是否有符合要求的活跃端点 has_matching_endpoint = any( ep.is_active and ep.api_format and ep.api_format in allowed_api_formats for ep in provider.endpoints ) if not has_matching_endpoint: logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})") - # 不返回该提供商,继续查找 else: logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}") return provider @@ -330,10 +304,9 @@ class ModelRoutingMiddleware: else: logger.warning(f"Specified provider {preferred_provider} not found or inactive") - # 2. 查找优先级最高的活动提供商(provider_priority 最小) + # 2. 查找优先级最高的活动提供商 query = self.db.query(Provider).filter(Provider.is_active == True) - # 如果指定了API格式过滤,添加过滤条件 - 检查是否有符合要求的 endpoint if allowed_api_formats: query = ( query.join(ProviderEndpoint) @@ -344,32 +317,27 @@ class ModelRoutingMiddleware: .distinct() ) - # 按 provider_priority 排序,优先级最高(数字最小)的在前 best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first() if best_provider: logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}") return best_provider - # 3. 没有任何活动提供商 if allowed_api_formats: - logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.") + logger.error(f"No active providers found with allowed API formats {allowed_api_formats}.") else: - logger.error("No active providers found. 
Please configure at least one provider.") + logger.error("No active providers found.") return None def get_available_models(self) -> Dict[str, List[str]]: """ 获取所有可用的模型及其提供商 - 方案 A: 基于 GlobalModel 查询 - Returns: 字典,键为 GlobalModel.name,值为支持该模型的提供商名列表 """ result = {} - # 查询所有活跃的 GlobalModel 及其 Provider models = ( self.db.query(GlobalModel.name, Provider.name) .join(Model, GlobalModel.id == Model.global_model_id) @@ -392,28 +360,23 @@ class ModelRoutingMiddleware: """ 获取某个模型最便宜的提供商 - 方案 A: 通过 GlobalModel 查找 - Args: - model_name: GlobalModel 名称或别名 + model_name: GlobalModel 名称 Returns: 最便宜的提供商 """ - # 步骤 1: 解析模型名 - global_model_name = await resolve_model_to_global_name(self.db, model_name) - - # 步骤 2: 查找 GlobalModel + # 直接查找 GlobalModel global_model = ( self.db.query(GlobalModel) - .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True) + .filter(GlobalModel.name == model_name, GlobalModel.is_active == True) .first() ) if not global_model: return None - # 步骤 3: 查询所有支持该模型的 Provider 及其价格 + # 查询所有支持该模型的 Provider 及其价格 models_with_providers = ( self.db.query(Provider, Model) .join(Model, Provider.id == Model.provider_id) @@ -428,15 +391,16 @@ class ModelRoutingMiddleware: if not models_with_providers: return None - # 按总价格排序(输入+输出价格) + # 按总价格排序 cheapest = min( - models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m + models_with_providers, + key=lambda x: x[1].get_effective_input_price() + x[1].get_effective_output_price() ) provider = cheapest[0] model = cheapest[1] logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} " - f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)") + f"(input: ${model.get_effective_input_price()}/M, output: ${model.get_effective_output_price()}/M)") return provider diff --git a/src/services/model/service.py b/src/services/model/service.py index 7e73ca7..3702750 100644 --- a/src/services/model/service.py +++ b/src/services/model/service.py @@ -50,6 +50,7 @@ class ModelService: provider_id=provider_id, global_model_id=model_data.global_model_id, provider_model_name=model_data.provider_model_name, + provider_model_aliases=model_data.provider_model_aliases, price_per_request=model_data.price_per_request, tiered_pricing=model_data.tiered_pricing, supports_vision=model_data.supports_vision, @@ -191,7 +192,6 @@ class ModelService: 新架构删除逻辑: - Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel - 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除) - - 不检查 ModelMapping(映射是 GlobalModel 之间的关系,别名也统一存储在此表中) """ model = db.query(Model).filter(Model.id == model_id).first() if not model: @@ -326,6 +326,7 @@ class ModelService: provider_id=model.provider_id, global_model_id=model.global_model_id, provider_model_name=model.provider_model_name, + provider_model_aliases=model.provider_model_aliases, # 原始配置值(可能为空) price_per_request=model.price_per_request, tiered_pricing=model.tiered_pricing, diff --git a/src/services/provider/service.py b/src/services/provider/service.py index 2246529..11fcd1b 100644 --- a/src/services/provider/service.py +++ b/src/services/provider/service.py @@ -7,13 +7,11 @@ from typing import Dict from sqlalchemy.orm import Session -from src.core.logger import logger from src.models.database import GlobalModel, Model, Provider from src.services.model.cost import ModelCostService from src.services.model.mapper import ModelMapperMiddleware, ModelRoutingMiddleware - class ProviderService: """提供商服务类""" @@ -34,30 +32,15 @@ class ProviderService: 检查模型是否可用(严格白名单模式) Args: - 
model_name: 模型名称 + model_name: 模型名称(必须是 GlobalModel.name) Returns: Model对象如果存在且激活,否则None """ - # 首先检查是否有直接的模型记录 - model = ( - self.db.query(Model) - .filter(Model.provider_model_name == model_name, Model.is_active == True) - .first() - ) - - if model: - return model - - # 方案 A:检查是否是别名(全局别名系统) - from src.services.model.mapping_resolver import resolve_model_to_global_name - - global_model_name = await resolve_model_to_global_name(self.db, model_name) - - # 查找 GlobalModel + # 直接查找 GlobalModel global_model = ( self.db.query(GlobalModel) - .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True) + .filter(GlobalModel.name == model_name, GlobalModel.is_active == True) .first() ) @@ -79,34 +62,15 @@ class ProviderService: Args: provider_id: 提供商ID - model_name: 模型名称 + model_name: 模型名称(必须是 GlobalModel.name) Returns: Model对象如果该提供商支持该模型且激活,否则None """ - # 首先检查该提供商下是否有直接的模型记录 - model = ( - self.db.query(Model) - .filter( - Model.provider_id == provider_id, - Model.provider_model_name == model_name, - Model.is_active == True, - ) - .first() - ) - - if model: - return model - - # 方案 A:检查是否是别名 - from src.services.model.mapping_resolver import resolve_model_to_global_name - - global_model_name = await resolve_model_to_global_name(self.db, model_name, provider_id) - - # 查找 GlobalModel + # 直接查找 GlobalModel global_model = ( self.db.query(GlobalModel) - .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True) + .filter(GlobalModel.name == model_name, GlobalModel.is_active == True) .first() ) @@ -148,12 +112,19 @@ class ProviderService: 获取所有可用的模型 Returns: - 模型和支持的提供商映射 + 字典,键为模型名,值为提供商列表 """ return self.router.get_available_models() - def clear_cache(self): - """清空缓存""" - self.mapper.clear_cache() - self.cost_service.clear_cache() - logger.info("Provider service cache cleared") + def select_provider(self, model_name: str, preferred_provider=None): + """ + 选择提供商 + + Args: + model_name: 模型名 + preferred_provider: 首选提供商 + + Returns: + Provider对象 + """ + return self.router.select_provider(model_name, preferred_provider) diff --git a/src/services/usage/service.py b/src/services/usage/service.py index b4a14cf..f7fc7aa 100644 --- a/src/services/usage/service.py +++ b/src/services/usage/service.py @@ -26,11 +26,10 @@ class UsageService: ) -> tuple[float, float]: """异步获取模型价格(输入价格,输出价格)每1M tokens - 新架构查找逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现并获取价格 - 4. 如果找不到则使用系统默认价格 + 查找逻辑: + 1. 直接通过 GlobalModel.name 匹配 + 2. 查找该 Provider 的 Model 实现并获取价格 + 3. 如果找不到则使用系统默认价格 """ service = ModelCostService(db) @@ -40,11 +39,10 @@ class UsageService: def get_model_price(cls, db: Session, provider: str, model: str) -> tuple[float, float]: """获取模型价格(输入价格,输出价格)每1M tokens - 新架构查找逻辑: - 1. 使用 ModelMappingResolver 解析别名(如果是) - 2. 解析为 GlobalModel.name - 3. 查找该 Provider 的 Model 实现并获取价格 - 4. 如果找不到则使用系统默认价格 + 查找逻辑: + 1. 直接通过 GlobalModel.name 匹配 + 2. 查找该 Provider 的 Model 实现并获取价格 + 3. 如果找不到则使用系统默认价格 """ service = ModelCostService(db)
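
Reviewer note: the same two-step lookup now repeats across the scheduler, the cost service, and the provider service — match `GlobalModel.name` exactly (aliases no longer resolve), then find that provider's `Model` implementation. Below is a minimal, self-contained sketch of that order using simplified stand-in tables, not the project's real schema; the real code raises `ModelNotSupportedException` or returns a skip reason instead of `None`.

```python
# Minimal sketch of the direct-lookup order this PR standardizes on.
# The table/column definitions are simplified stand-ins, not the real schema.
from sqlalchemy import Boolean, Column, Float, ForeignKey, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class GlobalModel(Base):
    __tablename__ = "global_models"
    id = Column(String, primary_key=True)
    name = Column(String, unique=True, nullable=False)  # canonical name users must send
    is_active = Column(Boolean, default=True)

class Model(Base):
    __tablename__ = "models"
    id = Column(String, primary_key=True)
    provider_id = Column(String, nullable=False)
    global_model_id = Column(String, ForeignKey("global_models.id"), nullable=False)
    input_price_per_1m = Column(Float, default=0.0)
    output_price_per_1m = Column(Float, default=0.0)
    is_active = Column(Boolean, default=True)

def resolve_prices(db: Session, provider_id: str, model_name: str):
    """Step 1: exact GlobalModel.name match; step 2: the provider's Model implementation."""
    global_model = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
        .first()
    )
    if not global_model:
        return None  # the real scheduler raises ModelNotSupportedException here
    impl = (
        db.query(Model)
        .filter(
            Model.provider_id == provider_id,
            Model.global_model_id == global_model.id,
            Model.is_active == True,
        )
        .first()
    )
    if not impl:
        return None
    return impl.input_price_per_1m, impl.output_price_per_1m

if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add(GlobalModel(id="gm-1", name="claude-sonnet-4"))
        db.add(Model(id="m-1", provider_id="p-1", global_model_id="gm-1",
                     input_price_per_1m=3.0, output_price_per_1m=15.0))
        db.commit()
        print(resolve_prices(db, "p-1", "claude-sonnet-4"))   # (3.0, 15.0)
        print(resolve_prices(db, "p-1", "claude-3-sonnet"))   # None: aliases no longer resolve
```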
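
The cost-service docstrings keep the note that, when tiered pricing is configured, the non-tiered price accessors fall back to the first tier. A small illustration of that fallback follows; only the `tiers` / `*_price_per_1m` shape is taken from the code above, while the `up_to_tokens` key in the sample payload is hypothetical.

```python
# Illustration of the "first tier as default" rule used in cost.py when
# get_effective_tiered_pricing() returns a {"tiers": [...]} payload.
from typing import Optional, Tuple

def default_prices(tiered_pricing: Optional[dict],
                   flat_input: float, flat_output: float) -> Tuple[float, float]:
    """Return the first tier's prices when tiers exist, else the flat prices."""
    if tiered_pricing and tiered_pricing.get("tiers"):
        first_tier = tiered_pricing["tiers"][0]
        return (
            first_tier.get("input_price_per_1m", 0),
            first_tier.get("output_price_per_1m", 0),
        )
    return flat_input, flat_output

print(default_prices(
    {"tiers": [
        # "up_to_tokens" is a made-up key for this sketch, not the real schema.
        {"up_to_tokens": 200_000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
        {"up_to_tokens": None, "input_price_per_1m": 6.0, "output_price_per_1m": 22.5},
    ]},
    flat_input=3.0,
    flat_output=15.0,
))  # -> (3.0, 15.0): first tier wins over the flat price
```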
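
model_cache.py now round-trips the new `provider_model_aliases` and `supports_image_generation` columns through the cached dict, reading them with `getattr(..., None)` / `dict.get(...)` so entries written before this change still deserialize. A toy sketch of that pattern with a stand-in dataclass (the real object is a SQLAlchemy `Model` row):

```python
# Stand-in object only; the point is the getattr/.get defaults that keep
# newly added columns backward compatible with older cache entries.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeModel:
    id: str
    provider_model_name: str
    provider_model_aliases: Optional[List[str]] = None
    supports_image_generation: Optional[bool] = None

def to_cache_dict(model) -> dict:
    return {
        "id": model.id,
        "provider_model_name": model.provider_model_name,
        # Tolerate objects created before these columns existed.
        "provider_model_aliases": getattr(model, "provider_model_aliases", None),
        "supports_image_generation": getattr(model, "supports_image_generation", None),
    }

def from_cache_dict(d: dict) -> FakeModel:
    return FakeModel(
        id=d["id"],
        provider_model_name=d["provider_model_name"],
        provider_model_aliases=d.get("provider_model_aliases"),
        supports_image_generation=d.get("supports_image_generation"),
    )

old_entry = {"id": "m-1", "provider_model_name": "claude-sonnet-4"}  # cached before this PR
print(from_cache_dict(old_entry))  # new fields quietly default to None
```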
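
For completeness, a rough sketch of the pub/sub fan-out that remains in cache/sync.py after the `model_mapping` channel is dropped. It assumes a Redis instance on localhost and an illustrative message payload; the real service wraps this in `CacheSyncService` with its own serialization, error handling, and lifecycle management.

```python
# Sketch of the remaining cache-invalidation channels; requires a local Redis.
import asyncio
import json
import redis.asyncio as redis

CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
CHANNEL_MODEL = "cache:invalidate:model"
CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"

async def main():
    client = redis.Redis()
    pubsub = client.pubsub()
    await pubsub.subscribe(CHANNEL_GLOBAL_MODEL, CHANNEL_MODEL, CHANNEL_CLEAR_ALL)

    async def listen():
        # Each instance listens and invalidates its local caches on receipt.
        async for message in pubsub.listen():
            if message["type"] != "message":
                continue
            print(message["channel"], json.loads(message["data"]))

    task = asyncio.create_task(listen())
    # Another instance announcing that a provider's Model row changed
    # (payload shape here is illustrative, not the exact production format):
    await client.publish(
        CHANNEL_MODEL, json.dumps({"provider_id": "p-1", "global_model_id": "gm-1"})
    )
    await asyncio.sleep(0.1)
    task.cancel()
    await client.close()

asyncio.run(main())
```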