mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 12:08:30 +08:00
Initial commit
This commit is contained in:
19
src/services/model/__init__.py
Normal file
19
src/services/model/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
模型服务模块
|
||||
|
||||
包含模型管理、模型映射、成本计算等功能。
|
||||
"""
|
||||
|
||||
from src.services.model.cost import ModelCostService
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
from src.services.model.mapping_resolver import ModelMappingResolver
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
__all__ = [
|
||||
"ModelService",
|
||||
"GlobalModelService",
|
||||
"ModelMapperMiddleware",
|
||||
"ModelMappingResolver",
|
||||
"ModelCostService",
|
||||
]
|
||||
946
src/services/model/cost.py
Normal file
946
src/services/model/cost.py
Normal file
@@ -0,0 +1,946 @@
|
||||
"""
|
||||
模型成本服务
|
||||
负责统一的价格解析、缓存以及成本计算逻辑。
|
||||
支持固定价格、按次计费和阶梯计费三种模式。
|
||||
|
||||
计费策略:
|
||||
- 不同 API format 可以有不同的计费逻辑
|
||||
- 通过 PricingStrategy 抽象,支持自定义总输入上下文计算、缓存 TTL 差异化等
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
|
||||
|
||||
# A provider reference: a Provider ORM object, a provider-name string, or None.
ProviderRef = Union[str, Provider, None]
|
||||
|
||||
|
||||
@dataclass
class TieredPriceResult:
    """Result of a tiered-pricing lookup (prices are per 1M tokens)."""

    input_price_per_1m: float
    output_price_per_1m: float
    cache_creation_price_per_1m: Optional[float] = None
    cache_read_price_per_1m: Optional[float] = None
    tier_index: int = 0  # index of the matched tier (0-based)
|
||||
|
||||
|
||||
@dataclass
class CostBreakdown:
    """Itemized cost of a single request."""

    input_cost: float
    output_cost: float
    cache_creation_cost: float
    cache_read_cost: float
    cache_cost: float  # presumably cache_creation_cost + cache_read_cost (cf. compute_cost)
    request_cost: float
    total_cost: float
|
||||
|
||||
|
||||
class ModelCostService:
    """Central price resolution and cost computation for models.

    Keeps pricing logic in one place so mapper/usage code does not
    re-implement it (original note: avoid duplication in mapper/usage).
    """

    # Class-level caches shared by all instances; cleared via clear_cache().
    _price_cache: Dict[str, Dict[str, float]] = {}
    _cache_price_cache: Dict[str, Dict[str, float]] = {}
    _tiered_pricing_cache: Dict[str, Optional[dict]] = {}

    def __init__(self, db: Session):
        # SQLAlchemy session used for all price lookups.
        self.db = db
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶梯计费相关方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def get_tier_for_tokens(
|
||||
tiered_pricing: dict,
|
||||
total_input_tokens: int
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯。
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数(input_tokens + cache_read_tokens)
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置,如果未找到返回 None
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
@staticmethod
|
||||
def get_cache_read_price_for_ttl(
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格。
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟),如果为 None 使用默认价格
|
||||
|
||||
Returns:
|
||||
缓存读取价格
|
||||
"""
|
||||
# 首先检查是否有 TTL 差异化定价
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
# 找到匹配或最接近的 TTL 价格
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 如果超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
# 使用默认的缓存读取价格
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
async def get_tiered_pricing_async(
|
||||
self, provider: ProviderRef, model: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
异步获取模型的阶梯计费配置。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
|
||||
Returns:
|
||||
阶梯计费配置,如果未配置返回 None
|
||||
"""
|
||||
result = await self.get_tiered_pricing_with_source_async(provider, model)
|
||||
return result.get("pricing") if result else None
|
||||
|
||||
async def get_tiered_pricing_with_source_async(
|
||||
self, provider: ProviderRef, model: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
异步获取模型的阶梯计费配置及来源信息。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
|
||||
Returns:
|
||||
包含 pricing 和 source 的字典:
|
||||
- pricing: 阶梯计费配置
|
||||
- source: 'provider' 或 'global'
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}:tiered_with_source"
|
||||
|
||||
if cache_key in self._tiered_pricing_cache:
|
||||
return self._tiered_pricing_cache[cache_key]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
result = None
|
||||
|
||||
if provider_obj:
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(
|
||||
self.db, model, provider_obj.id
|
||||
)
|
||||
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 判断定价来源
|
||||
if model_obj.tiered_pricing is not None:
|
||||
result = {
|
||||
"pricing": model_obj.tiered_pricing,
|
||||
"source": "provider"
|
||||
}
|
||||
elif global_model.default_tiered_pricing is not None:
|
||||
result = {
|
||||
"pricing": global_model.default_tiered_pricing,
|
||||
"source": "global"
|
||||
}
|
||||
|
||||
self._tiered_pricing_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
def get_tiered_pricing(self, provider: ProviderRef, model: str) -> Optional[dict]:
|
||||
"""同步获取模型的阶梯计费配置。"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_tiered_pricing_async(provider, model))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公共方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_model_price_async(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
|
||||
"""
|
||||
异步版本: 返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
|
||||
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
|
||||
|
||||
计费逻辑(基于 mapping_type):
|
||||
1. 查找 ModelMapping(如果存在)
|
||||
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
|
||||
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
|
||||
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
|
||||
- 否则回退到目标 GlobalModel 的价格
|
||||
4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}"
|
||||
|
||||
if cache_key in self._price_cache:
|
||||
prices = self._price_cache[cache_key]
|
||||
return prices["input"], prices["output"]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
input_price = None
|
||||
output_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
|
||||
from src.models.database import ModelMapping
|
||||
|
||||
mapping = None
|
||||
# 先查 Provider 特定映射
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id == provider_obj.id,
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# 再查全局映射
|
||||
if not mapping:
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# 有映射,根据 mapping_type 决定计费模型
|
||||
if mapping.mapping_type == "mapping":
|
||||
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
|
||||
source_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_global_model:
|
||||
source_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == source_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = source_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = source_model_obj.get_effective_input_price()
|
||||
output_price = source_model_obj.get_effective_output_price()
|
||||
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
|
||||
if input_price is None:
|
||||
target_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == mapping.target_global_model_id,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_global_model:
|
||||
target_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == target_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = target_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = target_model_obj.get_effective_input_price()
|
||||
output_price = target_model_obj.get_effective_output_price()
|
||||
mode_label = (
|
||||
"alias模式"
|
||||
if mapping.mapping_type == "alias"
|
||||
else "mapping模式(回退)"
|
||||
)
|
||||
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
else:
|
||||
# 没有映射,尝试直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# 如果没有找到价格配置,使用 0.0 并记录警告
|
||||
if input_price is None:
|
||||
input_price = 0.0
|
||||
if output_price is None:
|
||||
output_price = 0.0
|
||||
|
||||
# 检查是否有按次计费配置(按次计费模型的 token 价格可以为 0)
|
||||
if input_price == 0.0 and output_price == 0.0:
|
||||
# 异步检查按次计费价格
|
||||
price_per_request = await self.get_request_price_async(provider, model)
|
||||
if price_per_request is None or price_per_request == 0.0:
|
||||
logger.warning(f"未找到模型价格配置: {provider_name}/{model},请在 GlobalModel 中配置价格")
|
||||
|
||||
self._price_cache[cache_key] = {"input": input_price, "output": output_price}
|
||||
return input_price, output_price
|
||||
|
||||
def get_model_price(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
|
||||
"""
|
||||
返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_model_price_async(provider, model))
|
||||
|
||||
async def get_cache_prices_async(
|
||||
self, provider: ProviderRef, model: str, input_price: float
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
异步版本: 返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
(cache_creation_price, cache_read_price) 元组
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}"
|
||||
|
||||
if cache_key in self._cache_price_cache:
|
||||
prices = self._cache_price_cache[cache_key]
|
||||
return prices["creation"], prices["read"]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
cache_creation_price = None
|
||||
cache_read_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费配置
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
# 使用第一个阶梯的缓存价格作为默认值
|
||||
first_tier = tiered["tiers"][0]
|
||||
cache_creation_price = first_tier.get("cache_creation_price_per_1m")
|
||||
cache_read_price = first_tier.get("cache_read_price_per_1m")
|
||||
else:
|
||||
# 使用 get_effective_* 方法,会自动回退到 GlobalModel 的默认值
|
||||
cache_creation_price = model_obj.get_effective_cache_creation_price()
|
||||
cache_read_price = model_obj.get_effective_cache_read_price()
|
||||
|
||||
# 默认缓存价格估算(如果没有配置)- 基于输入价格计算
|
||||
if cache_creation_price is None or cache_read_price is None:
|
||||
if cache_creation_price is None:
|
||||
cache_creation_price = input_price * 1.25
|
||||
if cache_read_price is None:
|
||||
cache_read_price = input_price * 0.1
|
||||
|
||||
self._cache_price_cache[cache_key] = {
|
||||
"creation": cache_creation_price,
|
||||
"read": cache_read_price,
|
||||
}
|
||||
return cache_creation_price, cache_read_price
|
||||
|
||||
async def get_request_price_async(self, provider: ProviderRef, model: str) -> Optional[float]:
|
||||
"""
|
||||
异步版本: 返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取按次计费价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
"""
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
price_per_request = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 使用 get_effective_* 方法,会自动回退到 GlobalModel 的默认值
|
||||
price_per_request = model_obj.get_effective_price_per_request()
|
||||
|
||||
return price_per_request
|
||||
|
||||
def get_request_price(self, provider: ProviderRef, model: str) -> Optional[float]:
|
||||
"""
|
||||
返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_request_price_async(provider, model))
|
||||
|
||||
def get_cache_prices(
|
||||
self, provider: ProviderRef, model: str, input_price: float
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
(cache_creation_price, cache_read_price) 元组
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_cache_prices_async(provider, model, input_price))
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
provider: Provider,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
) -> Dict[str, float]:
|
||||
"""返回与旧 ModelMapper.calculate_cost 相同结构的费用信息。"""
|
||||
input_price, output_price = self.get_model_price(provider, model)
|
||||
input_cost, output_cost, _, _, _, _, total_cost = self.compute_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
input_price_per_1m=input_price,
|
||||
output_price_per_1m=output_price,
|
||||
)
|
||||
return {
|
||||
"input_cost": round(input_cost, 6),
|
||||
"output_cost": round(output_cost, 6),
|
||||
"total_cost": round(total_cost, 6),
|
||||
"input_price_per_1m": input_price,
|
||||
"output_price_per_1m": output_price,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def compute_cost(
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
cache_creation_price_per_1m: Optional[float] = None,
|
||||
cache_read_price_per_1m: Optional[float] = None,
|
||||
price_per_request: Optional[float] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float]:
|
||||
"""成本计算核心逻辑(固定价格模式),供 UsageService 等复用。
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost)
|
||||
"""
|
||||
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
|
||||
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * cache_creation_price_per_1m
|
||||
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
|
||||
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
|
||||
# 按次计费成本
|
||||
request_cost = price_per_request if price_per_request is not None else 0.0
|
||||
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
return (
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_cost_with_tiered_pricing(
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
price_per_request: Optional[float] = None,
|
||||
# 回退价格(当没有阶梯配置时使用)
|
||||
fallback_input_price_per_1m: float = 0.0,
|
||||
fallback_output_price_per_1m: float = 0.0,
|
||||
fallback_cache_creation_price_per_1m: Optional[float] = None,
|
||||
fallback_cache_read_price_per_1m: Optional[float] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
支持阶梯计费的成本计算核心逻辑。
|
||||
|
||||
阶梯判定:使用 input_tokens + cache_read_input_tokens(总输入上下文)
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟),用于 TTL 差异化定价
|
||||
price_per_request: 按次计费价格
|
||||
fallback_*: 回退价格配置
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
|
||||
tier_index: 命中的阶梯索引(0-based),如果未使用阶梯计费则为 None
|
||||
"""
|
||||
# 计算总输入上下文(用于阶梯判定)
|
||||
total_input_context = input_tokens + cache_read_input_tokens
|
||||
|
||||
tier_index = None
|
||||
input_price_per_1m = fallback_input_price_per_1m
|
||||
output_price_per_1m = fallback_output_price_per_1m
|
||||
cache_creation_price_per_1m = fallback_cache_creation_price_per_1m
|
||||
cache_read_price_per_1m = fallback_cache_read_price_per_1m
|
||||
|
||||
# 如果有阶梯配置,查找匹配的阶梯
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
tier = ModelCostService.get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
if tier:
|
||||
# 找到阶梯索引
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
|
||||
input_price_per_1m = tier.get("input_price_per_1m", fallback_input_price_per_1m)
|
||||
output_price_per_1m = tier.get("output_price_per_1m", fallback_output_price_per_1m)
|
||||
cache_creation_price_per_1m = tier.get(
|
||||
"cache_creation_price_per_1m", fallback_cache_creation_price_per_1m
|
||||
)
|
||||
|
||||
# 获取缓存读取价格(考虑 TTL 差异化)
|
||||
cache_read_price_per_1m = ModelCostService.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if cache_read_price_per_1m is None:
|
||||
cache_read_price_per_1m = fallback_cache_read_price_per_1m
|
||||
|
||||
logger.debug(
|
||||
f"[阶梯计费] 总输入上下文: {total_input_context}, "
|
||||
f"命中阶梯: {tier_index + 1}, "
|
||||
f"输入价格: ${input_price_per_1m}/M, "
|
||||
f"输出价格: ${output_price_per_1m}/M"
|
||||
)
|
||||
|
||||
# 计算成本
|
||||
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
|
||||
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * cache_creation_price_per_1m
|
||||
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
|
||||
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
|
||||
# 按次计费成本
|
||||
request_cost = price_per_request if price_per_request is not None else 0.0
|
||||
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return (
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
tier_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
"""清理价格相关缓存。"""
|
||||
cls._price_cache.clear()
|
||||
cls._cache_price_cache.clear()
|
||||
cls._tiered_pricing_cache.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内部辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _provider_name(self, provider: ProviderRef) -> str:
|
||||
if isinstance(provider, Provider):
|
||||
return provider.name
|
||||
return provider or "unknown"
|
||||
|
||||
def _resolve_provider(self, provider: ProviderRef) -> Optional[Provider]:
|
||||
if isinstance(provider, Provider):
|
||||
return provider
|
||||
if not provider or provider == "unknown":
|
||||
return None
|
||||
return self.db.query(Provider).filter(Provider.name == provider).first()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 基于策略模式的计费方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_cost_with_strategy_async(
|
||||
self,
|
||||
provider: ProviderRef,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
api_format: Optional[str] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
使用计费策略计算成本(异步版本)
|
||||
|
||||
根据 api_format 选择对应的 Adapter 计费逻辑,支持阶梯计费和 TTL 差异化。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
api_format: API 格式(用于选择计费策略)
|
||||
cache_ttl_minutes: 缓存时长(分钟),用于 TTL 差异化定价
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
|
||||
"""
|
||||
# 获取价格配置
|
||||
input_price, output_price = await self.get_model_price_async(provider, model)
|
||||
cache_creation_price, cache_read_price = await self.get_cache_prices_async(
|
||||
provider, model, input_price
|
||||
)
|
||||
request_price = await self.get_request_price_async(provider, model)
|
||||
tiered_pricing = await self.get_tiered_pricing_async(provider, model)
|
||||
|
||||
# 获取对应 API 格式的 Adapter 实例来计算成本
|
||||
# 优先检查 Chat Adapter,然后检查 CLI Adapter
|
||||
from src.api.handlers.base.chat_adapter_base import get_adapter_instance
|
||||
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_instance
|
||||
|
||||
adapter = None
|
||||
if api_format:
|
||||
adapter = get_adapter_instance(api_format)
|
||||
if adapter is None:
|
||||
adapter = get_cli_adapter_instance(api_format)
|
||||
|
||||
if adapter:
|
||||
# 使用 Adapter 的计费方法
|
||||
result = adapter.compute_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price,
|
||||
output_price_per_1m=output_price,
|
||||
cache_creation_price_per_1m=cache_creation_price,
|
||||
cache_read_price_per_1m=cache_read_price,
|
||||
price_per_request=request_price,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
)
|
||||
return (
|
||||
result["input_cost"],
|
||||
result["output_cost"],
|
||||
result["cache_creation_cost"],
|
||||
result["cache_read_cost"],
|
||||
result["cache_cost"],
|
||||
result["request_cost"],
|
||||
result["total_cost"],
|
||||
result["tier_index"],
|
||||
)
|
||||
else:
|
||||
# 回退到默认计算逻辑(无 Adapter 时使用静态方法)
|
||||
return self.compute_cost_with_tiered_pricing(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
price_per_request=request_price,
|
||||
fallback_input_price_per_1m=input_price,
|
||||
fallback_output_price_per_1m=output_price,
|
||||
fallback_cache_creation_price_per_1m=cache_creation_price,
|
||||
fallback_cache_read_price_per_1m=cache_read_price,
|
||||
)
|
||||
|
||||
def compute_cost_with_strategy(
|
||||
self,
|
||||
provider: ProviderRef,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
api_format: Optional[str] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
使用计费策略计算成本(同步版本)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(
|
||||
self.compute_cost_with_strategy_async(
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
api_format=api_format,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
)
|
||||
)
|
||||
299
src/services/model/global_model.py
Normal file
299
src/services/model/global_model.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
GlobalModel 服务层
|
||||
|
||||
提供 GlobalModel 的 CRUD 操作、查询和统计功能
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
from src.models.pydantic_models import GlobalModelUpdate
|
||||
|
||||
|
||||
|
||||
class GlobalModelService:
    """GlobalModel service: CRUD, lookups, statistics, and batch assignment."""

    @staticmethod
    def get_global_model(db: Session, global_model_id: str) -> GlobalModel:
        """
        Fetch a single GlobalModel.

        Args:
            db: database session
            global_model_id: GlobalModel UUID or name

        Returns:
            The matching GlobalModel.

        Raises:
            NotFoundException: when neither the ID nor the name matches.
        """
        # Try the primary key first.
        global_model = db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()

        # Fall back to a lookup by name.
        if not global_model:
            global_model = db.query(GlobalModel).filter(GlobalModel.name == global_model_id).first()

        if not global_model:
            raise NotFoundException(f"GlobalModel {global_model_id} not found")
        return global_model

    @staticmethod
    def get_global_model_by_name(db: Session, name: str) -> Optional[GlobalModel]:
        """Fetch a GlobalModel by exact name; returns None when absent."""
        return db.query(GlobalModel).filter(GlobalModel.name == name).first()

    @staticmethod
    def list_global_models(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        search: Optional[str] = None,
    ) -> List[GlobalModel]:
        """
        List GlobalModels with pagination, an optional active filter, and a
        case-insensitive search over name/display_name/description.
        """
        query = db.query(GlobalModel)

        if is_active is not None:
            query = query.filter(GlobalModel.is_active == is_active)

        if search:
            search_pattern = f"%{search}%"
            query = query.filter(
                (GlobalModel.name.ilike(search_pattern))
                | (GlobalModel.display_name.ilike(search_pattern))
                | (GlobalModel.description.ilike(search_pattern))
            )

        # Stable, name-ordered output.
        query = query.order_by(GlobalModel.name)

        return query.offset(skip).limit(limit).all()

    @staticmethod
    def create_global_model(
        db: Session,
        name: str,
        display_name: str,
        description: Optional[str] = None,
        official_url: Optional[str] = None,
        icon_url: Optional[str] = None,
        is_active: Optional[bool] = True,
        # per-request billing
        default_price_per_request: Optional[float] = None,
        # tiered billing configuration
        default_tiered_pricing: Optional[dict] = None,
        # default capability flags
        default_supports_vision: Optional[bool] = None,
        default_supports_function_calling: Optional[bool] = None,
        default_supports_streaming: Optional[bool] = None,
        default_supports_extended_thinking: Optional[bool] = None,
        # key capability configuration
        supported_capabilities: Optional[List[str]] = None,
    ) -> GlobalModel:
        """
        Create a GlobalModel.

        Raises:
            InvalidRequestException: when a GlobalModel with the same name
            already exists.
        """
        # Reject duplicate names up front.
        existing = GlobalModelService.get_global_model_by_name(db, name)
        if existing:
            raise InvalidRequestException(f"GlobalModel with name '{name}' already exists")

        global_model = GlobalModel(
            name=name,
            display_name=display_name,
            description=description,
            official_url=official_url,
            icon_url=icon_url,
            is_active=is_active,
            # per-request billing
            default_price_per_request=default_price_per_request,
            # tiered billing configuration
            default_tiered_pricing=default_tiered_pricing,
            # default capability flags
            default_supports_vision=default_supports_vision,
            default_supports_function_calling=default_supports_function_calling,
            default_supports_streaming=default_supports_streaming,
            default_supports_extended_thinking=default_supports_extended_thinking,
            # key capability configuration
            supported_capabilities=supported_capabilities,
        )

        db.add(global_model)
        db.commit()
        db.refresh(global_model)

        return global_model

    @staticmethod
    def update_global_model(
        db: Session,
        global_model_id: str,
        update_data: GlobalModelUpdate,
    ) -> GlobalModel:
        """
        Update a GlobalModel.

        Uses exclude_unset=True to distinguish "field not provided" from
        "field explicitly set to None":
        - omitted fields are left untouched
        - fields explicitly set to None are cleared
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)

        # Only the fields that were explicitly set (including explicit None).
        data_dict = update_data.model_dump(exclude_unset=True)

        # Tiered-pricing config may arrive as a TieredPricingConfig model;
        # store it as a plain dict.
        if "default_tiered_pricing" in data_dict:
            tiered_pricing = data_dict["default_tiered_pricing"]
            if tiered_pricing is not None and hasattr(tiered_pricing, "model_dump"):
                data_dict["default_tiered_pricing"] = tiered_pricing.model_dump()

        for field, value in data_dict.items():
            setattr(global_model, field, value)

        db.commit()
        db.refresh(global_model)

        return global_model

    @staticmethod
    def delete_global_model(db: Session, global_model_id: str) -> None:
        """
        Delete a GlobalModel.

        Default behavior: cascade-delete every associated provider Model
        implementation first.
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)

        # All associated Models (eager-load provider via global_model.id).
        associated_models = (
            db.query(Model)
            .options(joinedload(Model.provider))
            .filter(Model.global_model_id == global_model.id)
            .all()
        )

        # Cascade-delete the provider-level implementations.
        if associated_models:
            logger.info(f"删除 GlobalModel {global_model.name} 的 {len(associated_models)} 个关联 Provider 模型")
            for model in associated_models:
                db.delete(model)

        # Finally delete the GlobalModel itself.
        db.delete(global_model)
        db.commit()

    @staticmethod
    def get_global_model_stats(db: Session, global_model_id: str) -> Dict:
        """Collect statistics: model/provider counts and first-tier price range."""
        global_model = GlobalModelService.get_global_model(db, global_model_id)

        # Associated Models (eager-load provider via global_model.id).
        models = (
            db.query(Model)
            .options(joinedload(Model.provider))
            .filter(Model.global_model_id == global_model.id)
            .all()
        )

        # Distinct providers implementing this model.
        provider_ids = set(model.provider_id for model in models)

        # Derive the price range from each model's first pricing tier.
        input_prices = []
        output_prices = []
        for m in models:
            tiered = m.get_effective_tiered_pricing()
            if tiered and tiered.get("tiers"):
                first_tier = tiered["tiers"][0]
                if first_tier.get("input_price_per_1m") is not None:
                    input_prices.append(first_tier["input_price_per_1m"])
                if first_tier.get("output_price_per_1m") is not None:
                    output_prices.append(first_tier["output_price_per_1m"])

        return {
            "global_model_id": global_model.id,
            "name": global_model.name,
            "total_models": len(models),
            "total_providers": len(provider_ids),
            "price_range": {
                "min_input": min(input_prices) if input_prices else None,
                "max_input": max(input_prices) if input_prices else None,
                "min_output": min(output_prices) if output_prices else None,
                "max_output": max(output_prices) if output_prices else None,
            },
        }

    @staticmethod
    def batch_assign_to_providers(
        db: Session,
        global_model_id: str,
        provider_ids: List[str],
        create_models: bool = False,
    ) -> Dict:
        """
        Register a GlobalModel with multiple providers in one call.

        Returns:
            {"success": [...], "errors": [...]} with one entry per provider.
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)

        results = {
            "success": [],
            "errors": [],
        }

        for provider_id in provider_ids:
            try:
                # Skip providers that already implement this GlobalModel.
                existing_model = (
                    db.query(Model)
                    .filter(
                        Model.provider_id == provider_id,
                        Model.global_model_id == global_model.id,
                    )
                    .first()
                )

                if existing_model:
                    results["errors"].append(
                        {
                            "provider_id": provider_id,
                            "error": "Model already exists for this provider",
                        }
                    )
                    continue

                if create_models:
                    # New Model with pricing/capabilities left as None so the
                    # GlobalModel defaults apply.
                    model = Model(
                        provider_id=provider_id,
                        global_model_id=global_model.id,
                        provider_model_name=global_model.name,  # default to the GlobalModel name
                        # billing: None -> inherit GlobalModel defaults
                        price_per_request=None,
                        tiered_pricing=None,
                        # capabilities: None -> inherit GlobalModel defaults
                        supports_vision=None,
                        supports_function_calling=None,
                        supports_streaming=None,
                        supports_extended_thinking=None,
                        is_active=True,
                    )
                    db.add(model)
                    db.commit()

                    results["success"].append(
                        {"provider_id": provider_id, "model_id": model.id, "created": True}
                    )
                else:
                    results["errors"].append(
                        {
                            "provider_id": provider_id,
                            "error": "create_models=False, no existing model found",
                        }
                    )

            except Exception as e:
                # Per-provider isolation: roll back and record the failure,
                # then continue with the next provider.
                db.rollback()
                results["errors"].append({"provider_id": provider_id, "error": str(e)})

        db.commit()
        return results
|
||||
442
src/services/model/mapper.py
Normal file
442
src/services/model/mapper.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""
|
||||
模型映射中间件
|
||||
根据数据库中的配置,将用户请求的模型映射到提供商的实际模型
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.cache_utils import SyncLRUCache
|
||||
from src.core.logger import logger
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
from src.services.model.mapping_resolver import (
|
||||
get_model_mapping_resolver,
|
||||
resolve_model_to_global_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ModelMapperMiddleware:
    """
    Model-mapping middleware.

    Maps the model name in a user request to the provider's actual model name.
    """

    def __init__(self, db: Session, cache_max_size: int = 1000, cache_ttl: int = 300):
        """
        Initialize the model-mapping middleware.

        Args:
            db: database session
            cache_max_size: maximum cache capacity (default 1000)
            cache_ttl: cache expiry in seconds (default 300)
        """
        self.db = db
        self._cache = SyncLRUCache(max_size=cache_max_size, ttl=cache_ttl)

        logger.debug(f"[ModelMapper] 初始化(max_size={cache_max_size}, ttl={cache_ttl}s)")

        # Register with the cache-invalidation service so this mapper's cache
        # is cleared when mappings change elsewhere; best-effort only.
        try:
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.register_model_mapper(self)
            logger.debug("[ModelMapper] 已注册到缓存失效服务")
        except Exception as e:
            logger.warning(f"[ModelMapper] 注册缓存失效服务失败: {e}")

    async def apply_mapping(
        self, request: ClaudeMessagesRequest, provider: Provider
    ) -> ClaudeMessagesRequest:
        """
        Apply model mapping to a request (mutates request.model in place).

        Args:
            request: original request
            provider: target provider

        Returns:
            The request with the mapping applied.
        """
        # Model name as requested by the user.
        source_model = request.model

        # Look up the mapping (cached).
        mapping = await self.get_mapping(source_model, provider.id)

        if mapping:
            # Apply the mapping.
            original_model = request.model
            request.model = mapping.model.provider_model_name

            logger.debug(f"Applied model mapping for provider {provider.name}: "
                         f"{original_model} -> {mapping.model.provider_model_name}")
        else:
            # No mapping found; forward with the original model name.
            logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
                         f"forwarding with original model name")

        return request

    async def get_mapping(
        self, source_model: str, provider_id: str
    ) -> Optional[ModelMapping]:  # UUID
        """
        Look up the model mapping.

        Logic:
        1. Resolve the alias via the shared ModelMappingResolver (cached).
        2. Find this provider's Model implementation of the GlobalModel.
        3. Use a separate local mapping cache (including negative results).

        Args:
            source_model: model name or alias from the user request
            provider_id: provider ID (UUID)

        Returns:
            A mapping object exposing a ``model`` attribute, or None when no
            mapping is found.
        """
        # Check the local cache first.
        cache_key = f"{provider_id}:{source_model}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        mapping = None

        # Steps 1 & 2: resolve via the shared model-mapping resolver.
        mapping_resolver = get_model_mapping_resolver()
        global_model = await mapping_resolver.get_global_model_by_request(
            self.db, source_model, provider_id
        )

        if not global_model:
            logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
            self._cache[cache_key] = None
            return None

        # Step 3: check whether this provider implements the GlobalModel (cached).
        model = await ModelCacheService.get_model_by_provider_and_global_model(
            self.db, provider_id, global_model.id
        )

        if model:
            # Only return a mapping when the model name actually changes.
            if model.provider_model_name != source_model:
                # Lightweight stand-in object compatible with ModelMapping.
                mapping = type(
                    "obj",
                    (object,),
                    {
                        "source_model": source_model,
                        "model": model,
                        "is_active": True,
                        "provider_id": provider_id,
                    },
                )()

                logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
                             f"(provider={provider_id[:8]}...)")
            else:
                logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")

        # Cache the result (a None result is cached too).
        self._cache[cache_key] = mapping

        return mapping

    def get_all_mappings(self, provider_id: str) -> List[ModelMapping]:  # UUID
        """
        Return every model available from a provider (via GlobalModel).

        Plan A: expose all GlobalModels the provider implements.

        Args:
            provider_id: provider ID (UUID)

        Returns:
            A list of ModelMapping-compatible stand-in objects.
        """
        # All of this provider's active Models joined with active GlobalModels.
        models = (
            self.db.query(Model)
            .join(GlobalModel)
            .filter(
                Model.provider_id == provider_id,
                Model.is_active == True,
                GlobalModel.is_active == True,
            )
            .all()
        )

        # Build ModelMapping-compatible stand-in objects.
        mappings = []
        for model in models:
            mapping = type(
                "obj",
                (object,),
                {
                    "source_model": model.global_model.name,
                    "model": model,
                    "is_active": True,
                    "provider_id": provider_id,
                },
            )()
            mappings.append(mapping)

        return mappings

    def get_supported_models(self, provider_id: str) -> List[str]:  # UUID
        """
        Return every source model name supported by a provider.

        Args:
            provider_id: provider ID (UUID)

        Returns:
            List of supported model names.
        """
        mappings = self.get_all_mappings(provider_id)
        return [mapping.source_model for mapping in mappings]

    async def validate_request(
        self, request: ClaudeMessagesRequest, provider: Provider
    ) -> tuple[bool, Optional[str]]:
        """
        Validate a request against the mapping's constraints.

        Args:
            request: request object
            provider: provider object

        Returns:
            (is_valid, error_message)
        """
        mapping = await self.get_mapping(request.model, provider.id)

        if not mapping:
            # No mapping — possibly a natively supported model.
            return True, None

        if not mapping.is_active:
            return False, f"Model mapping for {request.model} is disabled"

        # max_tokens is deliberately NOT limited: as a relay service we should
        # not cap the user's request.
        # if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
        #     return False, (
        #         f"Requested max_tokens {request.max_tokens} exceeds limit "
        #         f"{mapping.max_output_tokens} for model {request.model}"
        #     )

        # More validation (e.g. input length checks) could be added here.

        return True, None

    def clear_cache(self):
        """Clear the whole mapping cache."""
        self._cache.clear()
        logger.debug("Model mapping cache cleared")

    def refresh_cache(self, provider_id: Optional[str] = None):  # UUID
        """
        Refresh the cache.

        Args:
            provider_id: when given, only that provider's entries are dropped (UUID)
        """
        if provider_id:
            # Drop only entries belonging to this provider.
            keys_to_remove = [
                key for key in self._cache.keys() if key.startswith(f"{provider_id}:")
            ]
            for key in keys_to_remove:
                del self._cache[key]
            logger.debug(f"Refreshed cache for provider {provider_id}")
        else:
            # Drop everything.
            self.clear_cache()
|
||||
|
||||
|
||||
class ModelRoutingMiddleware:
    """
    Model-routing middleware.

    Selects an appropriate provider for a requested model name.
    """

    def __init__(self, db: Session):
        """
        Initialize the routing middleware.

        Args:
            db: database session
        """
        self.db = db
        self.mapper = ModelMapperMiddleware(db)

    def select_provider(
        self,
        model_name: str,
        preferred_provider: Optional[str] = None,
        allowed_api_formats: Optional[List[str]] = None,
        request_id: Optional[str] = None,
    ) -> Optional[Provider]:
        """
        Select a provider for the given model name.

        Logic:
        1. If a provider is explicitly requested, use it.
        2. Otherwise use the default (highest-priority) provider.
        3. Model mapping for the chosen provider is applied later in
           apply_mapping.
        4. When allowed_api_formats is given, only providers with a matching
           active endpoint qualify.

        Args:
            model_name: requested model name
            preferred_provider: preferred provider name
            allowed_api_formats: allowed API formats (e.g. ['CLAUDE', 'CLAUDE_CLI'])
            request_id: request ID for log correlation

        Returns:
            The selected provider, or None when no provider qualifies.
        """
        request_prefix = f"ID:{request_id} | " if request_id else ""

        # 1. An explicitly requested provider wins, if usable.
        if preferred_provider:
            provider = (
                self.db.query(Provider)
                .filter(Provider.name == preferred_provider, Provider.is_active == True)
                .first()
            )

            if provider:
                # Check the API format against the provider's endpoints.
                if allowed_api_formats:
                    # Is there any active endpoint with an allowed format?
                    has_matching_endpoint = any(
                        ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
                        for ep in provider.endpoints
                    )
                    if not has_matching_endpoint:
                        logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
                        # Do not return this provider — fall through to
                        # priority-based selection below.
                    else:
                        logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
                        return provider
                else:
                    logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
                    return provider
            else:
                logger.warning(f"Specified provider {preferred_provider} not found or inactive")

        # 2. Highest-priority active provider (smallest provider_priority).
        query = self.db.query(Provider).filter(Provider.is_active == True)

        # Optionally restrict to providers that have a matching active endpoint.
        if allowed_api_formats:
            query = (
                query.join(ProviderEndpoint)
                .filter(
                    ProviderEndpoint.is_active == True,
                    ProviderEndpoint.api_format.in_(allowed_api_formats),
                )
                .distinct()
            )

        # Sort by provider_priority ascending (lower number = higher
        # priority), with id as a stable tie-breaker.
        best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()

        if best_provider:
            logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
            return best_provider

        # 3. No active provider at all.
        if allowed_api_formats:
            logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
        else:
            logger.error("No active providers found. Please configure at least one provider.")
        return None

    def get_available_models(self) -> Dict[str, List[str]]:
        """
        List every available model together with the providers that serve it.

        Plan A: query via GlobalModel.

        Returns:
            Mapping of GlobalModel.name to a list of provider names.
        """
        result = {}

        # All active GlobalModels joined with their active providers.
        models = (
            self.db.query(GlobalModel.name, Provider.name)
            .join(Model, GlobalModel.id == Model.global_model_id)
            .join(Provider, Model.provider_id == Provider.id)
            .filter(
                GlobalModel.is_active == True, Model.is_active == True, Provider.is_active == True
            )
            .all()
        )

        for global_model_name, provider_name in models:
            if global_model_name not in result:
                result[global_model_name] = []
            if provider_name not in result[global_model_name]:
                result[global_model_name].append(provider_name)

        return result

    async def get_cheapest_provider(self, model_name: str) -> Optional[Provider]:
        """
        Find the cheapest provider for a model.

        Plan A: resolve via GlobalModel.

        Args:
            model_name: GlobalModel name or alias

        Returns:
            The cheapest provider, or None when nobody serves the model.
        """
        # Step 1: resolve aliases to the canonical GlobalModel name.
        global_model_name = await resolve_model_to_global_name(self.db, model_name)

        # Step 2: fetch the GlobalModel.
        global_model = (
            self.db.query(GlobalModel)
            .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
            .first()
        )

        if not global_model:
            return None

        # Step 3: every active provider/model pair implementing this GlobalModel.
        models_with_providers = (
            self.db.query(Provider, Model)
            .join(Model, Provider.id == Model.provider_id)
            .filter(
                Model.global_model_id == global_model.id,
                Model.is_active == True,
                Provider.is_active == True,
            )
            .all()
        )

        if not models_with_providers:
            return None

        # Sort by combined price (input + output).
        # NOTE(review): assumes Model.input_price_per_1m / output_price_per_1m
        # exist and are non-None — other code paths in this package store
        # pricing in tiered_pricing and may leave these unset; confirm against
        # the Model schema before relying on this path.
        cheapest = min(
            models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
        )

        provider = cheapest[0]
        model = cheapest[1]

        logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
                     f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")

        return provider
|
||||
432
src/services/model/mapping_resolver.py
Normal file
432
src/services/model/mapping_resolver.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""
|
||||
模型映射解析服务
|
||||
|
||||
负责统一的模型别名/降级解析,按优先级顺序:
|
||||
1. 映射(mapping):Provider 特定 → 全局
|
||||
2. 别名(alias):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
|
||||
支持特性:
|
||||
- 带缓存(本地或 Redis),减少数据库访问
|
||||
- 提供模糊匹配能力,用于提示相似模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheSize, CacheTTL
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, ModelMapping
|
||||
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
|
||||
|
||||
|
||||
class ModelMappingResolver:
|
||||
"""统一的 ModelMapping 解析服务(可跨进程共享缓存)。"""
|
||||
|
||||
def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cache_backend_type = cache_backend_type
|
||||
self._mapping_cache: Optional[BaseCacheBackend] = None
|
||||
self._global_model_cache: Optional[BaseCacheBackend] = None
|
||||
self._initialized = False
|
||||
self._stats = {
|
||||
"mapping_hits": 0,
|
||||
"mapping_misses": 0,
|
||||
"global_hits": 0,
|
||||
"global_misses": 0,
|
||||
}
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._mapping_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:mapping",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._global_model_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:global",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._initialized = True
|
||||
logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")
|
||||
|
||||
def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
|
||||
return f"{provider_id or 'global'}:{source_model}"
|
||||
|
||||
async def _lookup_target_global_model_id(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
按优先级查找目标 GlobalModel ID:
|
||||
1. 映射(mapping_type='mapping'):Provider 特定 → 全局
|
||||
2. 别名(mapping_type='alias'):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
cache_key = self._cache_key(source_model, provider_id)
|
||||
cached = await self._mapping_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
self._stats["mapping_hits"] += 1
|
||||
return cached or None
|
||||
|
||||
self._stats["mapping_misses"] += 1
|
||||
|
||||
target_id: Optional[str] = None
|
||||
|
||||
# 优先级 1:查找映射(mapping_type='mapping')
|
||||
# 1.1 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 1.2 全局映射
|
||||
if not target_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 2:查找别名(mapping_type='alias')
|
||||
# 2.1 Provider 特定别名
|
||||
if not target_id and provider_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 2.2 全局别名
|
||||
if not target_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 3:直接匹配 GlobalModel.name
|
||||
if not target_id:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
target_id = global_model.id
|
||||
logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
|
||||
|
||||
cached_value = target_id if target_id is not None else ""
|
||||
await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
|
||||
return target_id
|
||||
|
||||
async def resolve_to_global_model_name(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return source_model
|
||||
|
||||
await self._ensure_initialized()
|
||||
cached_name = await self._global_model_cache.get(target_id)
|
||||
if cached_name:
|
||||
self._stats["global_hits"] += 1
|
||||
return cached_name
|
||||
|
||||
self._stats["global_misses"] += 1
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
|
||||
return global_model.name
|
||||
|
||||
return source_model
|
||||
|
||||
async def get_global_model_by_request(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""解析并返回 GlobalModel 对象(绑定当前 Session)。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return None
|
||||
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
async def get_global_model_with_mapping_info(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Tuple[Optional[GlobalModel], bool]:
|
||||
"""
|
||||
解析并返回 GlobalModel 对象,同时返回是否发生了映射。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 用户请求的模型名
|
||||
provider_id: Provider ID(可选)
|
||||
|
||||
Returns:
|
||||
(global_model, is_mapped) - GlobalModel 对象和是否发生了映射
|
||||
is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型
|
||||
is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# 先检查是否存在 mapping 类型的映射规则
|
||||
has_mapping = False
|
||||
|
||||
# 检查 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 检查全局映射
|
||||
if not has_mapping:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 获取 GlobalModel
|
||||
global_model = await self.get_global_model_by_request(db, source_model, provider_id)
|
||||
|
||||
return global_model, has_mapping
|
||||
|
||||
async def get_global_model_direct(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
直接通过模型名获取 GlobalModel,不应用任何映射规则。
|
||||
仅查找 alias 和直接匹配。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 模型名
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
"""
|
||||
# 优先级 1:查找别名(alias)
|
||||
# 全局别名
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
return global_model
|
||||
|
||||
# 优先级 2:直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
def find_similar_models(
|
||||
self,
|
||||
db: Session,
|
||||
invalid_model: str,
|
||||
limit: int = 3,
|
||||
threshold: float = 0.4,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""用于提示相似的 GlobalModel.name。"""
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
|
||||
similarities: List[Tuple[str, float]] = []
|
||||
invalid_lower = invalid_model.lower()
|
||||
|
||||
for model in all_models:
|
||||
model_name = model.name
|
||||
ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
|
||||
if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
|
||||
ratio += 0.2
|
||||
if ratio >= threshold:
|
||||
similarities.append((model_name, ratio))
|
||||
|
||||
similarities.sort(key=lambda item: item[1], reverse=True)
|
||||
return similarities[:limit]
|
||||
|
||||
async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
|
||||
await self._ensure_initialized()
|
||||
keys = [self._cache_key(source_model, provider_id)]
|
||||
if provider_id:
|
||||
keys.append(self._cache_key(source_model, None))
|
||||
for key in keys:
|
||||
await self._mapping_cache.delete(key)
|
||||
|
||||
async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
|
||||
await self._ensure_initialized()
|
||||
if global_model_id:
|
||||
await self._global_model_cache.delete(global_model_id)
|
||||
else:
|
||||
await self._global_model_cache.clear()
|
||||
|
||||
async def clear_cache(self):
|
||||
await self._ensure_initialized()
|
||||
await self._mapping_cache.clear()
|
||||
await self._global_model_cache.clear()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"]
|
||||
total_global = self._stats["global_hits"] + self._stats["global_misses"]
|
||||
stats = {
|
||||
"mapping_hit_rate": (
|
||||
self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0
|
||||
),
|
||||
"global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0,
|
||||
"stats": self._stats,
|
||||
}
|
||||
if self._initialized:
|
||||
stats["mapping_cache_backend"] = self._mapping_cache.get_stats()
|
||||
stats["global_cache_backend"] = self._global_model_cache.get_stats()
|
||||
return stats
|
||||
|
||||
|
||||
# Process-wide singleton; created lazily by get_model_mapping_resolver().
_model_mapping_resolver: Optional[ModelMappingResolver] = None
|
||||
|
||||
|
||||
def get_model_mapping_resolver(
    cache_ttl: int = 300, cache_backend_type: Optional[str] = None
) -> ModelMappingResolver:
    """Return the process-wide ModelMappingResolver singleton, creating it on first use.

    Note: *cache_ttl* and *cache_backend_type* only take effect on the call
    that creates the singleton; later calls return the existing instance and
    silently ignore these arguments. When *cache_backend_type* is not given,
    it falls back to the ALIAS_CACHE_BACKEND env var (default "auto").
    """
    global _model_mapping_resolver

    if _model_mapping_resolver is None:
        if cache_backend_type is None:
            cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")
        _model_mapping_resolver = ModelMappingResolver(
            cache_ttl=cache_ttl,
            cache_backend_type=cache_backend_type,
        )
        logger.debug(f"[ModelMappingResolver] 初始化(cache_ttl={cache_ttl}s, backend={cache_backend_type})")

        # Register with the cache-invalidation service so that DB-side changes
        # can evict this resolver's caches. Best-effort: failure is logged,
        # not raised, so the resolver remains usable without invalidation.
        try:
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.set_mapping_resolver(_model_mapping_resolver)
        except Exception as exc:
            logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")

    return _model_mapping_resolver
|
||||
|
||||
|
||||
async def resolve_model_to_global_name(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> str:
    """Module-level convenience wrapper around the shared resolver singleton.

    Resolves *source_model* (optionally scoped to *provider_id*) to the
    corresponding GlobalModel name.
    """
    resolver = get_model_mapping_resolver()
    return await resolver.resolve_to_global_model_name(db, source_model, provider_id)
|
||||
|
||||
|
||||
async def get_global_model_by_request(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> Optional[GlobalModel]:
    """Module-level convenience wrapper around the shared resolver singleton.

    Returns the GlobalModel resolved for *source_model* (optionally scoped
    to *provider_id*), or None when no active model matches.
    """
    resolver = get_model_mapping_resolver()
    return await resolver.get_global_model_by_request(db, source_model, provider_id)
|
||||
48  src/services/model/pricing_strategy.py  (new file)
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
计费相关数据类
|
||||
|
||||
定义计费计算所需的数据结构。
|
||||
实际的计费逻辑已移至 ChatAdapterBase,每种 API 格式可以覆盖计费方法。
|
||||
|
||||
数据类:
|
||||
- UsageTokens: 请求的 token 使用量
|
||||
- PricingConfig: 价格配置
|
||||
- CostResult: 计费结果
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class UsageTokens:
    """Token usage reported for a single request."""

    # Regular (non-cached) prompt tokens.
    input_tokens: int = 0
    # Completion tokens produced by the model.
    output_tokens: int = 0
    # Prompt tokens written into the provider-side cache (cache creation).
    cache_creation_input_tokens: int = 0
    # Prompt tokens served from the provider-side cache (cache hits).
    cache_read_input_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
class PricingConfig:
    """Pricing configuration for a model."""

    # Price per 1 million input tokens.
    input_price_per_1m: float = 0.0
    # Price per 1 million output tokens.
    output_price_per_1m: float = 0.0
    # Optional distinct rates for cache-creation / cache-read tokens;
    # None presumably means no separate cache pricing — confirm in the
    # adapter billing code that consumes this config.
    cache_creation_price_per_1m: Optional[float] = None
    cache_read_price_per_1m: Optional[float] = None
    # Flat fee per call, used for per-request billing.
    price_per_request: Optional[float] = None
    # Tier table for tiered billing; schema is defined by the consumer — TODO confirm.
    tiered_pricing: Optional[dict] = None
    # Cache TTL, for TTL-differentiated cache billing (see module docstring of cost.py).
    cache_ttl_minutes: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
class CostResult:
    """Result of a cost computation for one request."""

    # Cost attributed to regular input tokens.
    input_cost: float = 0.0
    # Cost attributed to output tokens.
    output_cost: float = 0.0
    # Cost attributed to cache-creation input tokens.
    cache_creation_cost: float = 0.0
    # Cost attributed to cache-read input tokens.
    cache_read_cost: float = 0.0
    # Aggregate cache cost (presumably creation + read) — TODO confirm in the billing code.
    cache_cost: float = 0.0
    # Flat per-call fee, when per-request billing applies.
    request_cost: float = 0.0
    # Total cost for the request.
    total_cost: float = 0.0
    tier_index: Optional[int] = None  # index of the pricing tier that was matched
|
||||
356  src/services/model/service.py  (new file)
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
模型管理服务
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
|
||||
from src.models.database import Model, Provider
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
|
||||
|
||||
|
||||
class ModelService:
    """Model management service: CRUD for provider-level Model rows.

    A Model row is one provider's implementation of a GlobalModel; display
    information and default pricing/capability values live on GlobalModel.
    """

    @staticmethod
    def create_model(db: Session, provider_id: str, model_data: ModelCreate) -> Model:
        """Create a model under a provider.

        Raises:
            NotFoundException: if the provider does not exist.
            InvalidRequestException: if a model with the same name already
                exists under the provider, or the insert violates a constraint.
        """
        # Make sure the provider exists before doing anything else.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise NotFoundException(f"提供商 {provider_id} 不存在")

        # Reject duplicates: one provider_model_name per provider.
        existing = (
            db.query(Model)
            .filter(
                and_(
                    Model.provider_id == provider_id,
                    Model.provider_model_name == model_data.provider_model_name,
                )
            )
            .first()
        )
        if existing:
            raise InvalidRequestException(
                f"提供商 {provider.name} 下已存在模型 {model_data.provider_model_name}"
            )

        try:
            model = Model(
                provider_id=provider_id,
                global_model_id=model_data.global_model_id,
                provider_model_name=model_data.provider_model_name,
                price_per_request=model_data.price_per_request,
                tiered_pricing=model_data.tiered_pricing,
                supports_vision=model_data.supports_vision,
                supports_function_calling=model_data.supports_function_calling,
                supports_streaming=model_data.supports_streaming,
                supports_extended_thinking=model_data.supports_extended_thinking,
                # Default to active when the caller leaves is_active unset.
                is_active=model_data.is_active if model_data.is_active is not None else True,
                config=model_data.config,
            )
            db.add(model)
            db.commit()
            db.refresh(model)
            # Re-query with joinedload so the global_model relationship is
            # eagerly populated for the response conversion.
            if model.global_model_id:
                from sqlalchemy.orm import joinedload

                model = (
                    db.query(Model)
                    .options(joinedload(Model.global_model))
                    .filter(Model.id == model.id)
                    .first()
                )

            logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
            return model

        except IntegrityError as e:
            db.rollback()
            logger.error(f"创建模型失败: {str(e)}")
            raise InvalidRequestException("创建模型失败,请检查输入数据")

    @staticmethod
    def get_model(db: Session, model_id: str) -> Model:  # UUID
        """Fetch one model (with its GlobalModel eagerly loaded).

        Raises:
            NotFoundException: if the model does not exist.
        """
        from sqlalchemy.orm import joinedload

        model = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.id == model_id)
            .first()
        )
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")
        return model

    @staticmethod
    def get_models_by_provider(
        db: Session,
        provider_id: str,  # UUID
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
    ) -> List[Model]:
        """List a provider's models, newest first, with optional is_active filter."""
        from sqlalchemy.orm import joinedload

        query = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.provider_id == provider_id)
        )

        if is_active is not None:
            query = query.filter(Model.is_active == is_active)

        # Newest first.
        query = query.order_by(Model.created_at.desc())

        return query.offset(skip).limit(limit).all()

    @staticmethod
    def get_all_models(
        db: Session,
        skip: int = 0,
        limit: int = 100,
        is_active: Optional[bool] = None,
        category: Optional[str] = None,
    ) -> List[Model]:
        """List all models across providers.

        NOTE(review): *category* is accepted but never applied — TODO either
        implement the filter or drop the parameter at the API layer.
        """
        query = db.query(Model)

        if is_active is not None:
            query = query.filter(Model.is_active == is_active)

        # Grouped by provider, newest first within each provider.
        query = query.order_by(Model.provider_id, Model.created_at.desc())

        return query.offset(skip).limit(limit).all()

    @staticmethod
    def update_model(db: Session, model_id: str, model_data: ModelUpdate) -> Model:  # UUID
        """Apply the explicitly-set fields of *model_data* to a model.

        Raises:
            NotFoundException: if the model does not exist.
            InvalidRequestException: on constraint violations.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        # Only fields the caller explicitly set are applied (partial update).
        update_data = model_data.model_dump(exclude_unset=True)

        logger.debug(f"更新模型 {model_id} 收到的数据: {update_data}")
        logger.debug(f"更新前的 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")

        for field, value in update_data.items():
            setattr(model, field, value)

        logger.debug(f"更新后的 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")

        try:
            db.commit()
            db.refresh(model)

            # Invalidate the Redis cache asynchronously (fire-and-forget).
            # NOTE(review): asyncio.create_task requires a running event loop;
            # this assumes the method is called from async request handlers —
            # TODO confirm, otherwise this raises RuntimeError.
            asyncio.create_task(
                ModelCacheService.invalidate_model_cache(
                    model_id=model.id,
                    provider_id=model.provider_id,
                    global_model_id=model.global_model_id,
                )
            )

            # Invalidate in-memory caches held by ModelMapperMiddleware instances.
            if model.provider_id and model.global_model_id:
                cache_service = get_cache_invalidation_service()
                cache_service.on_model_changed(model.provider_id, model.global_model_id)

            logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
            return model
        except IntegrityError as e:
            db.rollback()
            logger.error(f"更新模型失败: {str(e)}")
            raise InvalidRequestException("更新模型失败,请检查输入数据")

    @staticmethod
    def delete_model(db: Session, model_id: str):  # UUID
        """Delete a model.

        Deletion semantics in the new architecture:
        - A Model is only a provider's implementation of a GlobalModel, so
          deleting it does not affect the GlobalModel itself.
        - If this is the GlobalModel's last active implementation, warn but
          still allow the deletion.
        - ModelMapping rows are not checked (mappings and aliases relate
          GlobalModels, not provider Models).
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        # Check whether this is the GlobalModel's last active implementation.
        if model.global_model_id:
            other_implementations = (
                db.query(Model)
                .filter(
                    Model.global_model_id == model.global_model_id,
                    Model.id != model_id,
                    Model.is_active == True,
                )
                .count()
            )

            if other_implementations == 0:
                logger.warning(f"警告:删除模型 {model_id}(Provider: {model.provider_id[:8]}...)后,"
                               f"GlobalModel '{model.global_model_id}' 将没有任何活跃的关联提供商")

        # Capture log values BEFORE delete/commit: with the default
        # expire_on_commit=True, attribute access on a deleted instance after
        # commit raises ObjectDeletedError, which the except below would then
        # misreport as a failed delete.
        provider_model_name = model.provider_model_name
        global_model_id = model.global_model_id
        try:
            db.delete(model)
            db.commit()
            logger.info(f"删除模型成功: id={model_id}, provider_model_name={provider_model_name}, "
                        f"global_model_id={global_model_id[:8] if global_model_id else 'None'}...")
        except Exception as e:
            db.rollback()
            logger.error(f"删除模型失败: {str(e)}")
            raise InvalidRequestException("删除模型失败")

    @staticmethod
    def toggle_model_availability(db: Session, model_id: str, is_available: bool) -> Model:  # UUID
        """Set a model's runtime availability flag and invalidate caches.

        Raises:
            NotFoundException: if the model does not exist.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")

        model.is_available = is_available
        db.commit()
        db.refresh(model)

        # Invalidate the Redis cache (fire-and-forget).
        # NOTE(review): requires a running event loop — see update_model.
        asyncio.create_task(
            ModelCacheService.invalidate_model_cache(
                model_id=model.id,
                provider_id=model.provider_id,
                global_model_id=model.global_model_id,
            )
        )

        # Invalidate in-memory caches held by ModelMapperMiddleware instances.
        if model.provider_id and model.global_model_id:
            cache_service = get_cache_invalidation_service()
            cache_service.on_model_changed(model.provider_id, model.global_model_id)

        status = "可用" if is_available else "不可用"
        logger.info(f"更新模型可用状态: id={model_id}, status={status}")
        return model

    @staticmethod
    def get_model_by_name(db: Session, provider_id: str, model_name: str) -> Optional[Model]:
        """Fetch a model by its provider_model_name within one provider, or None."""
        return (
            db.query(Model)
            .filter(and_(Model.provider_id == provider_id, Model.provider_model_name == model_name))
            .first()
        )

    @staticmethod
    def batch_create_models(
        db: Session, provider_id: str, models_data: List[ModelCreate]
    ) -> List[Model]:  # UUID
        """Create several models under one provider in a single transaction.

        Existing (provider, provider_model_name) pairs are skipped with a
        warning rather than raising.

        Raises:
            NotFoundException: if the provider does not exist.
            InvalidRequestException: if the batch insert violates a constraint.
        """
        # Make sure the provider exists.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise NotFoundException(f"提供商 {provider_id} 不存在")

        created_models = []
        for model_data in models_data:
            # Skip models that already exist under this provider.
            existing = (
                db.query(Model)
                .filter(
                    and_(
                        Model.provider_id == provider_id,
                        Model.provider_model_name == model_data.provider_model_name,
                    )
                )
                .first()
            )

            if existing:
                logger.warning(f"模型 {model_data.provider_model_name} 已存在,跳过创建")
                continue

            model = Model(
                provider_id=provider_id,
                global_model_id=model_data.global_model_id,
                provider_model_name=model_data.provider_model_name,
                price_per_request=model_data.price_per_request,
                tiered_pricing=model_data.tiered_pricing,
                supports_vision=model_data.supports_vision,
                supports_function_calling=model_data.supports_function_calling,
                supports_streaming=model_data.supports_streaming,
                supports_extended_thinking=model_data.supports_extended_thinking,
                # Consistent with create_model: default to active when unset.
                is_active=model_data.is_active if model_data.is_active is not None else True,
                config=model_data.config,
            )
            db.add(model)
            created_models.append(model)

        if created_models:
            try:
                db.commit()
                for model in created_models:
                    db.refresh(model)
                logger.info(f"批量创建 {len(created_models)} 个模型成功")
            except IntegrityError as e:
                db.rollback()
                logger.error(f"批量创建模型失败: {str(e)}")
                raise InvalidRequestException("批量创建模型失败")

        return created_models

    @staticmethod
    def convert_to_response(model: Model) -> ModelResponse:
        """Convert a Model ORM row into a ModelResponse.

        Includes both the raw per-provider override values (which may be
        None) and the effective values merged with GlobalModel defaults.
        """
        return ModelResponse(
            id=model.id,
            provider_id=model.provider_id,
            global_model_id=model.global_model_id,
            provider_model_name=model.provider_model_name,
            # Raw override values (may be None).
            price_per_request=model.price_per_request,
            tiered_pricing=model.tiered_pricing,
            supports_vision=model.supports_vision,
            supports_function_calling=model.supports_function_calling,
            supports_streaming=model.supports_streaming,
            supports_extended_thinking=model.supports_extended_thinking,
            supports_image_generation=model.supports_image_generation,
            # Effective values (Model overrides merged with GlobalModel defaults).
            effective_tiered_pricing=model.get_effective_tiered_pricing(),
            effective_input_price=model.get_effective_input_price(),
            effective_output_price=model.get_effective_output_price(),
            effective_price_per_request=model.get_effective_price_per_request(),
            effective_supports_vision=model.get_effective_supports_vision(),
            effective_supports_function_calling=model.get_effective_supports_function_calling(),
            effective_supports_streaming=model.get_effective_supports_streaming(),
            effective_supports_extended_thinking=model.get_effective_supports_extended_thinking(),
            effective_supports_image_generation=model.get_effective_supports_image_generation(),
            is_active=model.is_active,
            is_available=model.is_available if model.is_available is not None else True,
            created_at=model.created_at,
            updated_at=model.updated_at,
            # GlobalModel display info, when linked.
            global_model_name=model.global_model.name if model.global_model else None,
            global_model_display_name=(
                model.global_model.display_name if model.global_model else None
            ),
        )
|
||||
Reference in New Issue
Block a user