Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
"""
模型服务模块
包含模型管理、模型映射、成本计算等功能。
"""
from src.services.model.cost import ModelCostService
from src.services.model.global_model import GlobalModelService
from src.services.model.mapper import ModelMapperMiddleware
from src.services.model.mapping_resolver import ModelMappingResolver
from src.services.model.service import ModelService
__all__ = [
"ModelService",
"GlobalModelService",
"ModelMapperMiddleware",
"ModelMappingResolver",
"ModelCostService",
]

946
src/services/model/cost.py Normal file
View File

@@ -0,0 +1,946 @@
"""
模型成本服务
负责统一的价格解析、缓存以及成本计算逻辑。
支持固定价格、按次计费和阶梯计费三种模式。
计费策略:
- 不同 API format 可以有不同的计费逻辑
- 通过 PricingStrategy 抽象,支持自定义总输入上下文计算、缓存 TTL 差异化等
"""
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping, Provider
ProviderRef = Union[str, Provider, None]
@dataclass
class TieredPriceResult:
    """Result of a tiered-pricing price lookup."""

    input_price_per_1m: float
    output_price_per_1m: float
    cache_creation_price_per_1m: Optional[float] = None
    cache_read_price_per_1m: Optional[float] = None
    tier_index: int = 0  # index of the tier that was hit
@dataclass
class CostBreakdown:
    """Itemized cost breakdown (all amounts in the account currency)."""

    input_cost: float
    output_cost: float
    cache_creation_cost: float
    cache_read_cost: float
    cache_cost: float  # cache_creation_cost + cache_read_cost
    request_cost: float
    total_cost: float
class ModelCostService:
    """Centralizes model pricing and cost computation so mapper/usage code does not duplicate it."""

    # NOTE(review): these caches are class-level, so they are shared across all
    # service instances (and DB sessions) for the process lifetime; invalidate
    # via clear_cache() when pricing changes.
    _price_cache: Dict[str, Dict[str, float]] = {}
    _cache_price_cache: Dict[str, Dict[str, float]] = {}
    _tiered_pricing_cache: Dict[str, Optional[dict]] = {}

    def __init__(self, db: Session):
        # Database session used for all price lookups.
        self.db = db
# ------------------------------------------------------------------
# 阶梯计费相关方法
# ------------------------------------------------------------------
@staticmethod
def get_tier_for_tokens(
tiered_pricing: dict,
total_input_tokens: int
) -> Optional[dict]:
"""
根据总输入 token 数确定价格阶梯。
Args:
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
total_input_tokens: 总输入 token 数input_tokens + cache_read_tokens
Returns:
匹配的阶梯配置,如果未找到返回 None
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
return tiers[-1] if tiers else None
@staticmethod
def get_cache_read_price_for_ttl(
tier: dict,
cache_ttl_minutes: Optional[int] = None
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格。
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟),如果为 None 使用默认价格
Returns:
缓存读取价格
"""
# 首先检查是否有 TTL 差异化定价
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
# 找到匹配或最接近的 TTL 价格
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 如果超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
# 使用默认的缓存读取价格
return tier.get("cache_read_price_per_1m")
async def get_tiered_pricing_async(
self, provider: ProviderRef, model: str
) -> Optional[dict]:
"""
异步获取模型的阶梯计费配置。
Args:
provider: Provider 对象或提供商名称
model: 模型名称
Returns:
阶梯计费配置,如果未配置返回 None
"""
result = await self.get_tiered_pricing_with_source_async(provider, model)
return result.get("pricing") if result else None
    async def get_tiered_pricing_with_source_async(
        self, provider: ProviderRef, model: str
    ) -> Optional[dict]:
        """Async lookup of a model's tiered-pricing config plus its origin.

        Args:
            provider: Provider object or provider name.
            model: Model name.

        Returns:
            Dict with keys:
            - "pricing": the tiered-pricing config
            - "source": "provider" (Model-level override) or "global"
              (GlobalModel default)
            or None when no tiered pricing is configured anywhere.
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}:tiered_with_source"
        if cache_key in self._tiered_pricing_cache:
            return self._tiered_pricing_cache[cache_key]
        provider_obj = self._resolve_provider(provider)
        result = None
        if provider_obj:
            # Resolve a possible alias to the canonical GlobalModel name first.
            from src.services.model.mapping_resolver import resolve_model_to_global_name
            global_model_name = await resolve_model_to_global_name(
                self.db, model, provider_obj.id
            )
            global_model = (
                self.db.query(GlobalModel)
                .filter(
                    GlobalModel.name == global_model_name,
                    GlobalModel.is_active == True,
                )
                .first()
            )
            if global_model:
                model_obj = (
                    self.db.query(Model)
                    .filter(
                        Model.provider_id == provider_obj.id,
                        Model.global_model_id == global_model.id,
                        Model.is_active == True,
                    )
                    .first()
                )
                if model_obj:
                    # Determine the pricing source: a provider-level override
                    # on the Model wins over the GlobalModel default.
                    if model_obj.tiered_pricing is not None:
                        result = {
                            "pricing": model_obj.tiered_pricing,
                            "source": "provider"
                        }
                    elif global_model.default_tiered_pricing is not None:
                        result = {
                            "pricing": global_model.default_tiered_pricing,
                            "source": "global"
                        }
        # Negative results (None) are cached too, to avoid repeated lookups.
        self._tiered_pricing_cache[cache_key] = result
        return result
def get_tiered_pricing(self, provider: ProviderRef, model: str) -> Optional[dict]:
"""同步获取模型的阶梯计费配置。"""
import asyncio
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.get_tiered_pricing_async(provider, model))
# ------------------------------------------------------------------
# 公共方法
# ------------------------------------------------------------------
    async def get_model_price_async(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
        """Async: return (input_price, output_price) for the given provider/model.

        Note: when the model is configured with tiered pricing, this returns
        the FIRST tier's prices as a default; actual billing should go through
        compute_cost_with_tiered_pricing.

        Billing logic (driven by mapping_type):
        1. Look up a ModelMapping (provider-specific first, then global).
        2. mapping_type == 'alias': use the target GlobalModel's price.
        3. mapping_type == 'mapping': try the source model's own GlobalModel
           price; if the source GlobalModel has no Model implementation for
           this provider, fall back to the target GlobalModel's price.
        4. No ModelMapping found: match GlobalModel.name directly.

        Args:
            provider: Provider object or provider name.
            model: Requested model name (GlobalModel.name or an alias).

        Returns:
            (input_price, output_price) tuple; (0.0, 0.0) when unconfigured.
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}"
        if cache_key in self._price_cache:
            prices = self._price_cache[cache_key]
            return prices["input"], prices["output"]
        provider_obj = self._resolve_provider(provider)
        input_price = None
        output_price = None
        if provider_obj:
            # Step 1: look up a ModelMapping to determine the mapping_type.
            from src.models.database import ModelMapping
            mapping = None
            # Provider-specific mapping takes precedence...
            mapping = (
                self.db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == model,
                    ModelMapping.provider_id == provider_obj.id,
                    ModelMapping.is_active == True,
                )
                .first()
            )
            # ...then fall back to a global (provider-less) mapping.
            if not mapping:
                mapping = (
                    self.db.query(ModelMapping)
                    .filter(
                        ModelMapping.source_model == model,
                        ModelMapping.provider_id.is_(None),
                        ModelMapping.is_active == True,
                    )
                    .first()
                )
            if mapping:
                # A mapping exists; mapping_type decides which model is billed.
                if mapping.mapping_type == "mapping":
                    # 'mapping' mode: prefer the source model's own GlobalModel price.
                    source_global_model = (
                        self.db.query(GlobalModel)
                        .filter(
                            GlobalModel.name == model,
                            GlobalModel.is_active == True,
                        )
                        .first()
                    )
                    if source_global_model:
                        source_model_obj = (
                            self.db.query(Model)
                            .filter(
                                Model.provider_id == provider_obj.id,
                                Model.global_model_id == source_global_model.id,
                                Model.is_active == True,
                            )
                            .first()
                        )
                        if source_model_obj:
                            # Tiered pricing configured? Use the first tier as the default price.
                            tiered = source_model_obj.get_effective_tiered_pricing()
                            if tiered and tiered.get("tiers"):
                                first_tier = tiered["tiers"][0]
                                input_price = first_tier.get("input_price_per_1m", 0)
                                output_price = first_tier.get("output_price_per_1m", 0)
                            else:
                                input_price = source_model_obj.get_effective_input_price()
                                output_price = source_model_obj.get_effective_output_price()
                            logger.debug(f"[mapping模式] 使用源模型价格: {model} "
                                         f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
                # 'alias' mode, or 'mapping' mode without a source price:
                # use the target GlobalModel's price.
                if input_price is None:
                    target_global_model = (
                        self.db.query(GlobalModel)
                        .filter(
                            GlobalModel.id == mapping.target_global_model_id,
                            GlobalModel.is_active == True,
                        )
                        .first()
                    )
                    if target_global_model:
                        target_model_obj = (
                            self.db.query(Model)
                            .filter(
                                Model.provider_id == provider_obj.id,
                                Model.global_model_id == target_global_model.id,
                                Model.is_active == True,
                            )
                            .first()
                        )
                        if target_model_obj:
                            # Tiered pricing configured? Use the first tier as the default price.
                            tiered = target_model_obj.get_effective_tiered_pricing()
                            if tiered and tiered.get("tiers"):
                                first_tier = tiered["tiers"][0]
                                input_price = first_tier.get("input_price_per_1m", 0)
                                output_price = first_tier.get("output_price_per_1m", 0)
                            else:
                                input_price = target_model_obj.get_effective_input_price()
                                output_price = target_model_obj.get_effective_output_price()
                            mode_label = (
                                "alias模式"
                                if mapping.mapping_type == "alias"
                                else "mapping模式(回退)"
                            )
                            logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
                                         f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
            else:
                # No mapping: try to match GlobalModel.name directly.
                global_model = (
                    self.db.query(GlobalModel)
                    .filter(
                        GlobalModel.name == model,
                        GlobalModel.is_active == True,
                    )
                    .first()
                )
                if global_model:
                    model_obj = (
                        self.db.query(Model)
                        .filter(
                            Model.provider_id == provider_obj.id,
                            Model.global_model_id == global_model.id,
                            Model.is_active == True,
                        )
                        .first()
                    )
                    if model_obj:
                        # Tiered pricing configured? Use the first tier as the default price.
                        tiered = model_obj.get_effective_tiered_pricing()
                        if tiered and tiered.get("tiers"):
                            first_tier = tiered["tiers"][0]
                            input_price = first_tier.get("input_price_per_1m", 0)
                            output_price = first_tier.get("output_price_per_1m", 0)
                        else:
                            input_price = model_obj.get_effective_input_price()
                            output_price = model_obj.get_effective_output_price()
                        logger.debug(f"找到模型价格配置: {provider_name}/{model} "
                                     f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
        # No price found anywhere: default to 0.0 and possibly warn below.
        if input_price is None:
            input_price = 0.0
        if output_price is None:
            output_price = 0.0
        # Per-request-billed models may legitimately have zero token prices,
        # so only warn when there is no per-request price either.
        if input_price == 0.0 and output_price == 0.0:
            # Asynchronously check the per-request price.
            price_per_request = await self.get_request_price_async(provider, model)
            if price_per_request is None or price_per_request == 0.0:
                logger.warning(f"未找到模型价格配置: {provider_name}/{model},请在 GlobalModel 中配置价格")
        self._price_cache[cache_key] = {"input": input_price, "output": output_price}
        return input_price, output_price
def get_model_price(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
"""
返回给定 provider/model 的 (input_price, output_price)。
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
Returns:
(input_price, output_price) 元组
"""
import asyncio
# 在同步上下文中调用异步方法
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.get_model_price_async(provider, model))
    async def get_cache_prices_async(
        self, provider: ProviderRef, model: str, input_price: float
    ) -> Tuple[Optional[float], Optional[float]]:
        """Async: return cache creation/read prices (per 1M tokens).

        Resolution:
        1. Resolve a possible alias via the mapping resolver.
        2. Find the GlobalModel, then this Provider's Model implementation.
        3. Read its cache-price config (first tier when tiered pricing is set).
        4. Missing values are estimated from `input_price`
           (creation = 1.25x, read = 0.1x).

        Args:
            provider: Provider object or provider name.
            model: Requested model name (GlobalModel.name or an alias).
            input_price: Base input price used for the default estimate
                (Claude-style defaults).

        Returns:
            (cache_creation_price, cache_read_price) tuple.
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}"
        if cache_key in self._cache_price_cache:
            prices = self._cache_price_cache[cache_key]
            return prices["creation"], prices["read"]
        provider_obj = self._resolve_provider(provider)
        cache_creation_price = None
        cache_read_price = None
        if provider_obj:
            # Step 1: resolve a possible alias to the canonical name.
            from src.services.model.mapping_resolver import resolve_model_to_global_name
            global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
            # Step 2: find the GlobalModel.
            global_model = (
                self.db.query(GlobalModel)
                .filter(
                    GlobalModel.name == global_model_name,
                    GlobalModel.is_active == True,
                )
                .first()
            )
            # Step 3: find this Provider's Model implementation.
            if global_model:
                model_obj = (
                    self.db.query(Model)
                    .filter(
                        Model.provider_id == provider_obj.id,
                        Model.global_model_id == global_model.id,
                        Model.is_active == True,
                    )
                    .first()
                )
                if model_obj:
                    # Tiered pricing configured?
                    tiered = model_obj.get_effective_tiered_pricing()
                    if tiered and tiered.get("tiers"):
                        # Use the first tier's cache prices as the defaults.
                        first_tier = tiered["tiers"][0]
                        cache_creation_price = first_tier.get("cache_creation_price_per_1m")
                        cache_read_price = first_tier.get("cache_read_price_per_1m")
                    else:
                        # get_effective_* falls back to the GlobalModel defaults automatically.
                        cache_creation_price = model_obj.get_effective_cache_creation_price()
                        cache_read_price = model_obj.get_effective_cache_read_price()
        # Default cache-price estimate when unconfigured, derived from the input price.
        if cache_creation_price is None or cache_read_price is None:
            if cache_creation_price is None:
                cache_creation_price = input_price * 1.25
            if cache_read_price is None:
                cache_read_price = input_price * 0.1
        self._cache_price_cache[cache_key] = {
            "creation": cache_creation_price,
            "read": cache_read_price,
        }
        return cache_creation_price, cache_read_price
    async def get_request_price_async(self, provider: ProviderRef, model: str) -> Optional[float]:
        """Async: return the per-request price (fixed fee per call).

        Resolution:
        1. Resolve a possible alias via the mapping resolver.
        2. Find the GlobalModel, then this Provider's Model implementation.
        3. Read its per-request price (falls back to the GlobalModel default).

        Args:
            provider: Provider object or provider name.
            model: Requested model name (GlobalModel.name or an alias).

        Returns:
            Per-request price, or None when unconfigured.
        """
        provider_obj = self._resolve_provider(provider)
        price_per_request = None
        if provider_obj:
            # Step 1: resolve a possible alias to the canonical name.
            from src.services.model.mapping_resolver import resolve_model_to_global_name
            global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
            # Step 2: find the GlobalModel.
            global_model = (
                self.db.query(GlobalModel)
                .filter(
                    GlobalModel.name == global_model_name,
                    GlobalModel.is_active == True,
                )
                .first()
            )
            # Step 3: find this Provider's Model implementation.
            if global_model:
                model_obj = (
                    self.db.query(Model)
                    .filter(
                        Model.provider_id == provider_obj.id,
                        Model.global_model_id == global_model.id,
                        Model.is_active == True,
                    )
                    .first()
                )
                if model_obj:
                    # get_effective_* falls back to the GlobalModel default automatically.
                    price_per_request = model_obj.get_effective_price_per_request()
        return price_per_request
def get_request_price(self, provider: ProviderRef, model: str) -> Optional[float]:
"""
返回按次计费价格(每次请求的固定费用)。
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
Returns:
按次计费价格,如果没有配置则返回 None
"""
import asyncio
# 在同步上下文中调用异步方法
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.get_request_price_async(provider, model))
def get_cache_prices(
self, provider: ProviderRef, model: str, input_price: float
) -> Tuple[Optional[float], Optional[float]]:
"""
返回缓存创建/读取价格(每 1M tokens
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取缓存价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
input_price: 基础输入价格(用于 Claude 模型的默认估算)
Returns:
(cache_creation_price, cache_read_price) 元组
"""
import asyncio
# 在同步上下文中调用异步方法
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(self.get_cache_prices_async(provider, model, input_price))
def calculate_cost(
self,
provider: Provider,
model: str,
input_tokens: int,
output_tokens: int,
) -> Dict[str, float]:
"""返回与旧 ModelMapper.calculate_cost 相同结构的费用信息。"""
input_price, output_price = self.get_model_price(provider, model)
input_cost, output_cost, _, _, _, _, total_cost = self.compute_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_price_per_1m=input_price,
output_price_per_1m=output_price,
)
return {
"input_cost": round(input_cost, 6),
"output_cost": round(output_cost, 6),
"total_cost": round(total_cost, 6),
"input_price_per_1m": input_price,
"output_price_per_1m": output_price,
}
@staticmethod
def compute_cost(
*,
input_tokens: int,
output_tokens: int,
input_price_per_1m: float,
output_price_per_1m: float,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
cache_creation_price_per_1m: Optional[float] = None,
cache_read_price_per_1m: Optional[float] = None,
price_per_request: Optional[float] = None,
) -> Tuple[float, float, float, float, float, float, float]:
"""成本计算核心逻辑(固定价格模式),供 UsageService 等复用。
Returns:
Tuple of (input_cost, output_cost, cache_creation_cost,
cache_read_cost, cache_cost, request_cost, total_cost)
"""
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * cache_creation_price_per_1m
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
cache_cost = cache_creation_cost + cache_read_cost
# 按次计费成本
request_cost = price_per_request if price_per_request is not None else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return (
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
)
@staticmethod
def compute_cost_with_tiered_pricing(
*,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
tiered_pricing: Optional[dict] = None,
cache_ttl_minutes: Optional[int] = None,
price_per_request: Optional[float] = None,
# 回退价格(当没有阶梯配置时使用)
fallback_input_price_per_1m: float = 0.0,
fallback_output_price_per_1m: float = 0.0,
fallback_cache_creation_price_per_1m: Optional[float] = None,
fallback_cache_read_price_per_1m: Optional[float] = None,
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
"""
支持阶梯计费的成本计算核心逻辑。
阶梯判定:使用 input_tokens + cache_read_input_tokens总输入上下文
Args:
input_tokens: 输入 token 数
output_tokens: 输出 token 数
cache_creation_input_tokens: 缓存创建 token 数
cache_read_input_tokens: 缓存读取 token 数
tiered_pricing: 阶梯计费配置
cache_ttl_minutes: 缓存时长(分钟),用于 TTL 差异化定价
price_per_request: 按次计费价格
fallback_*: 回退价格配置
Returns:
Tuple of (input_cost, output_cost, cache_creation_cost,
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
tier_index: 命中的阶梯索引0-based如果未使用阶梯计费则为 None
"""
# 计算总输入上下文(用于阶梯判定)
total_input_context = input_tokens + cache_read_input_tokens
tier_index = None
input_price_per_1m = fallback_input_price_per_1m
output_price_per_1m = fallback_output_price_per_1m
cache_creation_price_per_1m = fallback_cache_creation_price_per_1m
cache_read_price_per_1m = fallback_cache_read_price_per_1m
# 如果有阶梯配置,查找匹配的阶梯
if tiered_pricing and tiered_pricing.get("tiers"):
tier = ModelCostService.get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
# 找到阶梯索引
tier_index = tiered_pricing["tiers"].index(tier)
input_price_per_1m = tier.get("input_price_per_1m", fallback_input_price_per_1m)
output_price_per_1m = tier.get("output_price_per_1m", fallback_output_price_per_1m)
cache_creation_price_per_1m = tier.get(
"cache_creation_price_per_1m", fallback_cache_creation_price_per_1m
)
# 获取缓存读取价格(考虑 TTL 差异化)
cache_read_price_per_1m = ModelCostService.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if cache_read_price_per_1m is None:
cache_read_price_per_1m = fallback_cache_read_price_per_1m
logger.debug(
f"[阶梯计费] 总输入上下文: {total_input_context}, "
f"命中阶梯: {tier_index + 1}, "
f"输入价格: ${input_price_per_1m}/M, "
f"输出价格: ${output_price_per_1m}/M"
)
# 计算成本
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * cache_creation_price_per_1m
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
cache_cost = cache_creation_cost + cache_read_cost
# 按次计费成本
request_cost = price_per_request if price_per_request is not None else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return (
input_cost,
output_cost,
cache_creation_cost,
cache_read_cost,
cache_cost,
request_cost,
total_cost,
tier_index,
)
@classmethod
def clear_cache(cls):
"""清理价格相关缓存。"""
cls._price_cache.clear()
cls._cache_price_cache.clear()
cls._tiered_pricing_cache.clear()
# ------------------------------------------------------------------
# 内部辅助
# ------------------------------------------------------------------
def _provider_name(self, provider: ProviderRef) -> str:
if isinstance(provider, Provider):
return provider.name
return provider or "unknown"
def _resolve_provider(self, provider: ProviderRef) -> Optional[Provider]:
if isinstance(provider, Provider):
return provider
if not provider or provider == "unknown":
return None
return self.db.query(Provider).filter(Provider.name == provider).first()
# ------------------------------------------------------------------
# 基于策略模式的计费方法
# ------------------------------------------------------------------
    async def compute_cost_with_strategy_async(
        self,
        provider: ProviderRef,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        api_format: Optional[str] = None,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
        """Compute cost using the billing strategy for `api_format` (async).

        The adapter registered for the API format (chat adapters first, then
        CLI adapters) owns the billing logic; without one, the generic
        compute_cost_with_tiered_pricing fallback runs. Supports tiered
        pricing and TTL-differentiated cache pricing.

        Args:
            provider: Provider object or provider name.
            model: Model name.
            input_tokens: Input token count.
            output_tokens: Output token count.
            cache_creation_input_tokens: Cache-creation token count.
            cache_read_input_tokens: Cache-read token count.
            api_format: API format (selects the billing strategy).
            cache_ttl_minutes: Cache duration in minutes, for TTL-differentiated pricing.

        Returns:
            (input_cost, output_cost, cache_creation_cost, cache_read_cost,
             cache_cost, request_cost, total_cost, tier_index)
        """
        # Gather every price input up front.
        input_price, output_price = await self.get_model_price_async(provider, model)
        cache_creation_price, cache_read_price = await self.get_cache_prices_async(
            provider, model, input_price
        )
        request_price = await self.get_request_price_async(provider, model)
        tiered_pricing = await self.get_tiered_pricing_async(provider, model)
        # Pick the adapter instance for this API format:
        # check Chat adapters first, then CLI adapters.
        from src.api.handlers.base.chat_adapter_base import get_adapter_instance
        from src.api.handlers.base.cli_adapter_base import get_cli_adapter_instance
        adapter = None
        if api_format:
            adapter = get_adapter_instance(api_format)
            if adapter is None:
                adapter = get_cli_adapter_instance(api_format)
        if adapter:
            # Adapter-specific billing.
            result = adapter.compute_cost(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                input_price_per_1m=input_price,
                output_price_per_1m=output_price,
                cache_creation_price_per_1m=cache_creation_price,
                cache_read_price_per_1m=cache_read_price,
                price_per_request=request_price,
                tiered_pricing=tiered_pricing,
                cache_ttl_minutes=cache_ttl_minutes,
            )
            return (
                result["input_cost"],
                result["output_cost"],
                result["cache_creation_cost"],
                result["cache_read_cost"],
                result["cache_cost"],
                result["request_cost"],
                result["total_cost"],
                result["tier_index"],
            )
        else:
            # No adapter registered: fall back to the generic static computation.
            return self.compute_cost_with_tiered_pricing(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                tiered_pricing=tiered_pricing,
                cache_ttl_minutes=cache_ttl_minutes,
                price_per_request=request_price,
                fallback_input_price_per_1m=input_price,
                fallback_output_price_per_1m=output_price,
                fallback_cache_creation_price_per_1m=cache_creation_price,
                fallback_cache_read_price_per_1m=cache_read_price,
            )
def compute_cost_with_strategy(
self,
provider: ProviderRef,
model: str,
input_tokens: int,
output_tokens: int,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
api_format: Optional[str] = None,
cache_ttl_minutes: Optional[int] = None,
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
"""
使用计费策略计算成本(同步版本)
"""
import asyncio
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self.compute_cost_with_strategy_async(
provider=provider,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
api_format=api_format,
cache_ttl_minutes=cache_ttl_minutes,
)
)

View File

@@ -0,0 +1,299 @@
"""
GlobalModel 服务层
提供 GlobalModel 的 CRUD 操作、查询和统计功能
"""
from typing import Dict, List, Optional
from sqlalchemy import and_, func
from sqlalchemy.orm import Session, joinedload
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping
from src.models.pydantic_models import GlobalModelUpdate
class GlobalModelService:
    """GlobalModel service: static CRUD, query, and statistics helpers."""
@staticmethod
def get_global_model(db: Session, global_model_id: str) -> GlobalModel:
"""
获取单个 GlobalModel
Args:
global_model_id: GlobalModel 的 UUID 或 name
"""
# 先尝试通过 ID 查找
global_model = db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
# 如果没找到,尝试通过 name 查找
if not global_model:
global_model = db.query(GlobalModel).filter(GlobalModel.name == global_model_id).first()
if not global_model:
raise NotFoundException(f"GlobalModel {global_model_id} not found")
return global_model
@staticmethod
def get_global_model_by_name(db: Session, name: str) -> Optional[GlobalModel]:
"""通过名称获取 GlobalModel"""
return db.query(GlobalModel).filter(GlobalModel.name == name).first()
@staticmethod
def list_global_models(
db: Session,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None,
) -> List[GlobalModel]:
"""列出 GlobalModel"""
query = db.query(GlobalModel)
if is_active is not None:
query = query.filter(GlobalModel.is_active == is_active)
if search:
search_pattern = f"%{search}%"
query = query.filter(
(GlobalModel.name.ilike(search_pattern))
| (GlobalModel.display_name.ilike(search_pattern))
| (GlobalModel.description.ilike(search_pattern))
)
# 按名称排序
query = query.order_by(GlobalModel.name)
return query.offset(skip).limit(limit).all()
@staticmethod
def create_global_model(
db: Session,
name: str,
display_name: str,
description: Optional[str] = None,
official_url: Optional[str] = None,
icon_url: Optional[str] = None,
is_active: Optional[bool] = True,
# 按次计费配置
default_price_per_request: Optional[float] = None,
# 阶梯计费配置(必填)
default_tiered_pricing: dict = None,
# 默认能力配置
default_supports_vision: Optional[bool] = None,
default_supports_function_calling: Optional[bool] = None,
default_supports_streaming: Optional[bool] = None,
default_supports_extended_thinking: Optional[bool] = None,
# Key 能力配置
supported_capabilities: Optional[List[str]] = None,
) -> GlobalModel:
"""创建 GlobalModel"""
# 检查名称是否已存在
existing = GlobalModelService.get_global_model_by_name(db, name)
if existing:
raise InvalidRequestException(f"GlobalModel with name '{name}' already exists")
global_model = GlobalModel(
name=name,
display_name=display_name,
description=description,
official_url=official_url,
icon_url=icon_url,
is_active=is_active,
# 按次计费配置
default_price_per_request=default_price_per_request,
# 阶梯计费配置
default_tiered_pricing=default_tiered_pricing,
# 默认能力配置
default_supports_vision=default_supports_vision,
default_supports_function_calling=default_supports_function_calling,
default_supports_streaming=default_supports_streaming,
default_supports_extended_thinking=default_supports_extended_thinking,
# Key 能力配置
supported_capabilities=supported_capabilities,
)
db.add(global_model)
db.commit()
db.refresh(global_model)
return global_model
    @staticmethod
    def update_global_model(
        db: Session,
        global_model_id: str,
        update_data: GlobalModelUpdate,
    ) -> GlobalModel:
        """Update a GlobalModel.

        Uses exclude_unset=True to distinguish "field not provided" from
        "explicitly set to None":
        - Fields the caller never set are left untouched.
        - Fields explicitly set to None are cleared.
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)
        # Only fields the caller explicitly set (including explicit None).
        data_dict = update_data.model_dump(exclude_unset=True)
        # Tiered-pricing config may arrive as a pydantic model
        # (TieredPricingConfig); persist a plain dict instead.
        if "default_tiered_pricing" in data_dict:
            tiered_pricing = data_dict["default_tiered_pricing"]
            if tiered_pricing is not None and hasattr(tiered_pricing, "model_dump"):
                data_dict["default_tiered_pricing"] = tiered_pricing.model_dump()
        for field, value in data_dict.items():
            setattr(global_model, field, value)
        db.commit()
        db.refresh(global_model)
        return global_model
    @staticmethod
    def delete_global_model(db: Session, global_model_id: str) -> None:
        """Delete a GlobalModel.

        Default behavior: cascade-deletes every associated provider Model
        implementation before removing the GlobalModel itself, all in a
        single transaction.
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)
        # Collect associated Models via global_model.id (provider relation
        # eager-loaded for the log message).
        associated_models = (
            db.query(Model)
            .options(joinedload(Model.provider))
            .filter(Model.global_model_id == global_model.id)
            .all()
        )
        # Cascade-delete all associated provider Model implementations.
        if associated_models:
            logger.info(f"删除 GlobalModel {global_model.name}{len(associated_models)} 个关联 Provider 模型")
            for model in associated_models:
                db.delete(model)
        # Finally remove the GlobalModel and commit everything together.
        db.delete(global_model)
        db.commit()
    @staticmethod
    def get_global_model_stats(db: Session, global_model_id: str) -> Dict:
        """Return statistics for a GlobalModel: model/provider counts and price range.

        The price range is derived from the FIRST tier of each Model's
        effective tiered-pricing config; Models without tiers contribute
        nothing to the range.
        """
        global_model = GlobalModelService.get_global_model(db, global_model_id)
        # All Models implementing this GlobalModel (provider eager-loaded).
        models = (
            db.query(Model)
            .options(joinedload(Model.provider))
            .filter(Model.global_model_id == global_model.id)
            .all()
        )
        # Distinct providers implementing it.
        provider_ids = set(model.provider_id for model in models)
        # Extract the price range from tiered pricing (first tier only).
        input_prices = []
        output_prices = []
        for m in models:
            tiered = m.get_effective_tiered_pricing()
            if tiered and tiered.get("tiers"):
                first_tier = tiered["tiers"][0]
                if first_tier.get("input_price_per_1m") is not None:
                    input_prices.append(first_tier["input_price_per_1m"])
                if first_tier.get("output_price_per_1m") is not None:
                    output_prices.append(first_tier["output_price_per_1m"])
        return {
            "global_model_id": global_model.id,
            "name": global_model.name,
            "total_models": len(models),
            "total_providers": len(provider_ids),
            "price_range": {
                "min_input": min(input_prices) if input_prices else None,
                "max_input": max(input_prices) if input_prices else None,
                "min_output": min(output_prices) if output_prices else None,
                "max_output": max(output_prices) if output_prices else None,
            },
        }
    @staticmethod
    def batch_assign_to_providers(
        db: Session,
        global_model_id: str,
        provider_ids: List[str],
        create_models: bool = False,
    ) -> Dict:
        """Add a GlobalModel implementation to several Providers at once.

        Per provider: if an implementation already exists, it is reported as
        an error entry and skipped. When create_models is True, a Model is
        created with pricing/capability fields left as None so it inherits
        the GlobalModel defaults. Each creation commits individually; a
        failure rolls back only that provider's work.

        Returns:
            {"success": [...], "errors": [...]} accumulated per provider.
        """
        # NOTE(review): ModelService import appears unused in this method.
        from .service import ModelService
        global_model = GlobalModelService.get_global_model(db, global_model_id)
        results = {
            "success": [],
            "errors": [],
        }
        for provider_id in provider_ids:
            try:
                # Does this Provider already implement the GlobalModel?
                # (lookup via global_model.id)
                existing_model = (
                    db.query(Model)
                    .filter(
                        Model.provider_id == provider_id,
                        Model.global_model_id == global_model.id,
                    )
                    .first()
                )
                if existing_model:
                    results["errors"].append(
                        {
                            "provider_id": provider_id,
                            "error": "Model already exists for this provider",
                        }
                    )
                    continue
                if create_models:
                    # New Model with pricing/capabilities left as None so the
                    # GlobalModel defaults apply.
                    model = Model(
                        provider_id=provider_id,
                        global_model_id=global_model.id,
                        provider_model_name=global_model.name,  # defaults to the GlobalModel name
                        # Billing: None -> GlobalModel defaults apply.
                        price_per_request=None,
                        tiered_pricing=None,
                        # Capabilities: None -> GlobalModel defaults apply.
                        supports_vision=None,
                        supports_function_calling=None,
                        supports_streaming=None,
                        supports_extended_thinking=None,
                        is_active=True,
                    )
                    db.add(model)
                    db.commit()
                    results["success"].append(
                        {"provider_id": provider_id, "model_id": model.id, "created": True}
                    )
                else:
                    results["errors"].append(
                        {
                            "provider_id": provider_id,
                            "error": "create_models=False, no existing model found",
                        }
                    )
            except Exception as e:
                # Per-provider isolation: roll back only this provider's work.
                db.rollback()
                results["errors"].append({"provider_id": provider_id, "error": str(e)})
        db.commit()
        return results

View File

@@ -0,0 +1,442 @@
"""
模型映射中间件
根据数据库中的配置,将用户请求的模型映射到提供商的实际模型
"""
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from src.core.cache_utils import SyncLRUCache
from src.core.logger import logger
from src.models.claude import ClaudeMessagesRequest
from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
from src.services.cache.model_cache import ModelCacheService
from src.services.model.mapping_resolver import (
get_model_mapping_resolver,
resolve_model_to_global_name,
)
class ModelMapperMiddleware:
    """
    Model mapping middleware.

    Resolves the model name a user requested to the provider's actual model
    name, using a TTL-bounded LRU cache so repeated lookups skip the resolver
    and database.
    """

    def __init__(self, db: Session, cache_max_size: int = 1000, cache_ttl: int = 300):
        """
        Initialize the model mapper middleware.

        Args:
            db: Database session
            cache_max_size: Maximum cache capacity (default 1000)
            cache_ttl: Cache entry lifetime in seconds (default 300)
        """
        self.db = db
        self._cache = SyncLRUCache(max_size=cache_max_size, ttl=cache_ttl)
        logger.debug(f"[ModelMapper] 初始化max_size={cache_max_size}, ttl={cache_ttl}s")
        # Register with the cache-invalidation service so external model /
        # mapping changes can clear this instance's local cache.
        try:
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.register_model_mapper(self)
            logger.debug("[ModelMapper] 已注册到缓存失效服务")
        except Exception as e:
            # Best effort: without registration, stale entries still expire
            # via the cache TTL.
            logger.warning(f"[ModelMapper] 注册缓存失效服务失败: {e}")

    async def apply_mapping(
        self, request: ClaudeMessagesRequest, provider: Provider
    ) -> ClaudeMessagesRequest:
        """
        Apply the model mapping to a request (mutates ``request.model``).

        Args:
            request: Original request
            provider: Target provider

        Returns:
            The same request object, with ``model`` rewritten when a mapping exists
        """
        # Model name as requested by the caller.
        source_model = request.model
        # Look up the mapping (cached).
        mapping = await self.get_mapping(source_model, provider.id)
        if mapping:
            # Apply the mapping.
            original_model = request.model
            request.model = mapping.model.provider_model_name
            logger.debug(f"Applied model mapping for provider {provider.name}: "
                        f"{original_model} -> {mapping.model.provider_model_name}")
        else:
            # No mapping found — forward with the original model name.
            logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
                        f"forwarding with original model name")
        return request

    async def get_mapping(
        self, source_model: str, provider_id: str
    ) -> Optional[ModelMapping]:  # UUID
        """
        Get the model mapping for a provider.

        Resolution steps:
        1. Resolve the alias via the shared ModelMappingResolver (cached)
        2. Find this provider's Model implementation of that GlobalModel
        3. Cache the result locally (negative results cached as None)

        Args:
            source_model: Requested model name or alias
            provider_id: Provider ID (UUID)

        Returns:
            Mapping object (exposes a ``model`` attribute); None when the
            provider has no implementation or no rename is needed
        """
        # Check the local cache first.
        # NOTE(review): membership test followed by indexing could race with
        # TTL eviction inside SyncLRUCache — confirm SyncLRUCache tolerates
        # this pattern.
        cache_key = f"{provider_id}:{source_model}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        mapping = None
        # Steps 1 & 2: resolve via the unified model-mapping resolver.
        mapping_resolver = get_model_mapping_resolver()
        global_model = await mapping_resolver.get_global_model_by_request(
            self.db, source_model, provider_id
        )
        if not global_model:
            logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
            self._cache[cache_key] = None
            return None
        # Step 3: does this provider implement the GlobalModel? (cached)
        model = await ModelCacheService.get_model_by_provider_and_global_model(
            self.db, provider_id, global_model.id
        )
        if model:
            # Only return a mapping when the model name actually changes.
            if model.provider_model_name != source_model:
                # Lightweight anonymous object mimicking a ModelMapping row.
                mapping = type(
                    "obj",
                    (object,),
                    {
                        "source_model": source_model,
                        "model": model,
                        "is_active": True,
                        "provider_id": provider_id,
                    },
                )()
                logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
                            f"(provider={provider_id[:8]}...)")
            else:
                logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")
        # Cache the outcome (None means "no mapping").
        self._cache[cache_key] = mapping
        return mapping

    def get_all_mappings(self, provider_id: str) -> List[ModelMapping]:  # UUID
        """
        List all usable models for a provider (via GlobalModel).

        Scheme A: return every active GlobalModel this provider implements.

        Args:
            provider_id: Provider ID (UUID)

        Returns:
            List of mapping-like objects (anonymous ModelMapping stand-ins)
        """
        # All active Models of this provider whose GlobalModel is also active.
        models = (
            self.db.query(Model)
            .join(GlobalModel)
            .filter(
                Model.provider_id == provider_id,
                Model.is_active == True,
                GlobalModel.is_active == True,
            )
            .all()
        )
        # Build ModelMapping-compatible anonymous objects.
        mappings = []
        for model in models:
            mapping = type(
                "obj",
                (object,),
                {
                    "source_model": model.global_model.name,
                    "model": model,
                    "is_active": True,
                    "provider_id": provider_id,
                },
            )()
            mappings.append(mapping)
        return mappings

    def get_supported_models(self, provider_id: str) -> List[str]:  # UUID
        """
        List all source model names supported by a provider.

        Args:
            provider_id: Provider ID (UUID)

        Returns:
            List of supported model names
        """
        mappings = self.get_all_mappings(provider_id)
        return [mapping.source_model for mapping in mappings]

    async def validate_request(
        self, request: ClaudeMessagesRequest, provider: Provider
    ) -> tuple[bool, Optional[str]]:
        """
        Validate a request against its mapping's constraints.

        Args:
            request: Request object
            provider: Provider object

        Returns:
            (is_valid, error_message)
        """
        mapping = await self.get_mapping(request.model, provider.id)
        if not mapping:
            # No mapping — possibly a natively supported model.
            return True, None
        if not mapping.is_active:
            return False, f"Model mapping for {request.model} is disabled"
        # max_tokens is intentionally NOT limited: as a relay service we
        # should not constrain the user's request.
        # if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
        #     return False, (
        #         f"Requested max_tokens {request.max_tokens} exceeds limit "
        #         f"{mapping.max_output_tokens} for model {request.model}"
        #     )
        # More validation (e.g. input length checks) could be added here.
        return True, None

    def clear_cache(self):
        """Clear the local mapping cache."""
        self._cache.clear()
        logger.debug("Model mapping cache cleared")

    def refresh_cache(self, provider_id: Optional[str] = None):  # UUID
        """
        Refresh the cache.

        Args:
            provider_id: When given, only that provider's entries are evicted (UUID)
        """
        if provider_id:
            # Evict entries belonging to this provider only.
            keys_to_remove = [
                key for key in self._cache.keys() if key.startswith(f"{provider_id}:")
            ]
            for key in keys_to_remove:
                del self._cache[key]
            logger.debug(f"Refreshed cache for provider {provider_id}")
        else:
            # Drop everything.
            self.clear_cache()
class ModelRoutingMiddleware:
    """
    Model routing middleware.

    Selects an appropriate provider for a requested model name.
    """

    def __init__(self, db: Session):
        """
        Initialize the routing middleware.

        Args:
            db: Database session
        """
        self.db = db
        self.mapper = ModelMapperMiddleware(db)

    def select_provider(
        self,
        model_name: str,
        preferred_provider: Optional[str] = None,
        allowed_api_formats: Optional[List[str]] = None,
        request_id: Optional[str] = None,
    ) -> Optional[Provider]:
        """
        Select a provider for the given model name.

        Logic:
        1. If a provider is explicitly requested, use it
        2. Otherwise fall back to the highest-priority active provider
        3. The selected provider's model mapping is applied later in apply_mapping
        4. When allowed_api_formats is given, only providers with a matching
           active endpoint qualify

        Args:
            model_name: Requested model name
            preferred_provider: Preferred provider name
            allowed_api_formats: Allowed API formats, e.g. ['CLAUDE', 'CLAUDE_CLI']
            request_id: Request ID for log correlation

        Returns:
            The chosen provider, or None when nothing matches
        """
        request_prefix = f"ID:{request_id} | " if request_id else ""
        # 1. Explicitly requested provider takes precedence.
        if preferred_provider:
            provider = (
                self.db.query(Provider)
                .filter(Provider.name == preferred_provider, Provider.is_active == True)
                .first()
            )
            if provider:
                # Check API format against the provider's endpoints.
                if allowed_api_formats:
                    # Is there at least one matching active endpoint?
                    has_matching_endpoint = any(
                        ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
                        for ep in provider.endpoints
                    )
                    if not has_matching_endpoint:
                        logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
                        # Do not return this provider; fall through to search.
                    else:
                        logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
                        return provider
                else:
                    logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
                    return provider
            else:
                logger.warning(f"Specified provider {preferred_provider} not found or inactive")
        # 2. Highest-priority active provider (smallest provider_priority wins).
        query = self.db.query(Provider).filter(Provider.is_active == True)
        # When API formats are constrained, require a matching active endpoint.
        if allowed_api_formats:
            query = (
                query.join(ProviderEndpoint)
                .filter(
                    ProviderEndpoint.is_active == True,
                    ProviderEndpoint.api_format.in_(allowed_api_formats),
                )
                .distinct()
            )
        # Order by priority ascending; id breaks ties deterministically.
        best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()
        if best_provider:
            logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
            return best_provider
        # 3. No active provider at all.
        if allowed_api_formats:
            logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
        else:
            logger.error("No active providers found. Please configure at least one provider.")
        return None

    def get_available_models(self) -> Dict[str, List[str]]:
        """
        List all available models and their providers.

        Scheme A: query is GlobalModel-based.

        Returns:
            Dict keyed by GlobalModel.name, value is the list of provider
            names implementing that model
        """
        result = {}
        # All active GlobalModels joined with their active providers.
        models = (
            self.db.query(GlobalModel.name, Provider.name)
            .join(Model, GlobalModel.id == Model.global_model_id)
            .join(Provider, Model.provider_id == Provider.id)
            .filter(
                GlobalModel.is_active == True, Model.is_active == True, Provider.is_active == True
            )
            .all()
        )
        for global_model_name, provider_name in models:
            if global_model_name not in result:
                result[global_model_name] = []
            if provider_name not in result[global_model_name]:
                result[global_model_name].append(provider_name)
        return result

    async def get_cheapest_provider(self, model_name: str) -> Optional[Provider]:
        """
        Find the cheapest provider for a model.

        Scheme A: resolved through GlobalModel.

        Args:
            model_name: GlobalModel name or alias

        Returns:
            The cheapest provider, or None
        """
        # Step 1: resolve the (possibly aliased) model name.
        global_model_name = await resolve_model_to_global_name(self.db, model_name)
        # Step 2: locate the GlobalModel.
        global_model = (
            self.db.query(GlobalModel)
            .filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
            .first()
        )
        if not global_model:
            return None
        # Step 3: every active provider implementing this model, with prices.
        models_with_providers = (
            self.db.query(Provider, Model)
            .join(Model, Provider.id == Model.provider_id)
            .filter(
                Model.global_model_id == global_model.id,
                Model.is_active == True,
                Provider.is_active == True,
            )
            .all()
        )
        if not models_with_providers:
            return None
        # Rank by combined input+output price.
        # NOTE(review): assumes Model exposes input_price_per_1m /
        # output_price_per_1m directly; elsewhere effective prices come from
        # get_effective_* accessors and these fields may be None — confirm.
        cheapest = min(
            models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
        )
        provider = cheapest[0]
        model = cheapest[1]
        logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
                    f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")
        return provider

View File

@@ -0,0 +1,432 @@
"""
模型映射解析服务
负责统一的模型别名/降级解析,按优先级顺序:
1. 映射mappingProvider 特定 → 全局
2. 别名aliasProvider 特定 → 全局
3. 直接匹配 GlobalModel.name
支持特性:
- 带缓存(本地或 Redis减少数据库访问
- 提供模糊匹配能力,用于提示相似模型
"""
from __future__ import annotations
import os
from typing import List, Optional, Tuple
from src.core.logger import logger
from sqlalchemy.orm import Session
from src.config.constants import CacheSize, CacheTTL
from src.core.logger import logger
from src.models.database import GlobalModel, ModelMapping
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
class ModelMappingResolver:
    """Unified ModelMapping resolution service (cache shareable across processes).

    Resolution priority for a requested model name:
    1. mapping (mapping_type='mapping'): provider-specific, then global
    2. alias (mapping_type='alias'): provider-specific, then global
    3. direct match on GlobalModel.name
    """

    def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
        # Cache configuration; backends are created lazily on first use.
        self._cache_ttl = cache_ttl
        self._cache_backend_type = cache_backend_type
        self._mapping_cache: Optional[BaseCacheBackend] = None
        self._global_model_cache: Optional[BaseCacheBackend] = None
        self._initialized = False
        # Hit/miss counters reported by get_stats().
        self._stats = {
            "mapping_hits": 0,
            "mapping_misses": 0,
            "global_hits": 0,
            "global_misses": 0,
        }

    async def _ensure_initialized(self):
        # Lazily create the two cache backends (name->GlobalModel-id and
        # GlobalModel-id->name). Idempotent.
        if self._initialized:
            return
        self._mapping_cache = await get_cache_backend(
            name="model_mapping_resolver:mapping",
            backend_type=self._cache_backend_type,
            max_size=CacheSize.MODEL_MAPPING,
            ttl=self._cache_ttl,
        )
        self._global_model_cache = await get_cache_backend(
            name="model_mapping_resolver:global",
            backend_type=self._cache_backend_type,
            max_size=CacheSize.MODEL_MAPPING,
            ttl=self._cache_ttl,
        )
        self._initialized = True
        logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")

    def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
        # Provider-scoped cache key; 'global' buckets provider-less lookups.
        return f"{provider_id or 'global'}:{source_model}"

    async def _lookup_target_global_model_id(
        self,
        db: Session,
        source_model: str,
        provider_id: Optional[str] = None,
    ) -> Optional[str]:
        """
        Find the target GlobalModel ID, in priority order:
        1. mapping (mapping_type='mapping'): provider-specific, then global
        2. alias (mapping_type='alias'): provider-specific, then global
        3. direct match on GlobalModel.name
        """
        await self._ensure_initialized()
        cache_key = self._cache_key(source_model, provider_id)
        cached = await self._mapping_cache.get(cache_key)
        if cached is not None:
            self._stats["mapping_hits"] += 1
            # "" is the cached negative result; convert it back to None.
            return cached or None
        self._stats["mapping_misses"] += 1
        target_id: Optional[str] = None
        # Priority 1: mappings (mapping_type='mapping')
        # 1.1 provider-specific mapping
        if provider_id:
            mapping = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id == provider_id,
                    ModelMapping.mapping_type == "mapping",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if mapping:
                target_id = mapping.target_global_model_id
                logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
        # 1.2 global mapping
        if not target_id:
            mapping = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id.is_(None),
                    ModelMapping.mapping_type == "mapping",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if mapping:
                target_id = mapping.target_global_model_id
                logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
        # Priority 2: aliases (mapping_type='alias')
        # 2.1 provider-specific alias
        if not target_id and provider_id:
            alias = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id == provider_id,
                    ModelMapping.mapping_type == "alias",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if alias:
                target_id = alias.target_global_model_id
                logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
        # 2.2 global alias
        if not target_id:
            alias = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id.is_(None),
                    ModelMapping.mapping_type == "alias",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if alias:
                target_id = alias.target_global_model_id
                logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
        # Priority 3: direct GlobalModel.name match
        if not target_id:
            global_model = (
                db.query(GlobalModel)
                .filter(
                    GlobalModel.name == source_model,
                    GlobalModel.is_active == True,
                )
                .first()
            )
            if global_model:
                target_id = global_model.id
                logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
        # Cache the outcome; "" marks a miss so negatives are cached too.
        cached_value = target_id if target_id is not None else ""
        await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
        return target_id

    async def resolve_to_global_model_name(
        self,
        db: Session,
        source_model: str,
        provider_id: Optional[str] = None,
    ) -> str:
        """Resolve a model name/alias to GlobalModel.name. Returns the input unchanged when not found."""
        target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
        if not target_id:
            return source_model
        await self._ensure_initialized()
        # id -> name lookups have their own cache.
        cached_name = await self._global_model_cache.get(target_id)
        if cached_name:
            self._stats["global_hits"] += 1
            return cached_name
        self._stats["global_misses"] += 1
        global_model = (
            db.query(GlobalModel)
            .filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
            .first()
        )
        if global_model:
            await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
            return global_model.name
        return source_model

    async def get_global_model_by_request(
        self,
        db: Session,
        source_model: str,
        provider_id: Optional[str] = None,
    ) -> Optional[GlobalModel]:
        """Resolve and return the GlobalModel object (bound to the current Session)."""
        target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
        if not target_id:
            return None
        global_model = (
            db.query(GlobalModel)
            .filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
            .first()
        )
        return global_model

    async def get_global_model_with_mapping_info(
        self,
        db: Session,
        source_model: str,
        provider_id: Optional[str] = None,
    ) -> Tuple[Optional[GlobalModel], bool]:
        """
        Resolve to a GlobalModel and report whether a mapping fired.

        Args:
            db: Database session
            source_model: Requested model name
            provider_id: Provider ID (optional)

        Returns:
            (global_model, is_mapped) — the GlobalModel and whether a mapping applied.
            is_mapped=True means source_model was redirected by a 'mapping' rule;
            is_mapped=False means a direct or alias match.
        """
        await self._ensure_initialized()
        # First detect whether a 'mapping'-type rule exists at all
        # (this check is uncached, unlike the resolution itself).
        has_mapping = False
        # Provider-specific mapping
        if provider_id:
            mapping = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id == provider_id,
                    ModelMapping.mapping_type == "mapping",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if mapping:
                has_mapping = True
        # Global mapping
        if not has_mapping:
            mapping = (
                db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == source_model,
                    ModelMapping.provider_id.is_(None),
                    ModelMapping.mapping_type == "mapping",
                    ModelMapping.is_active == True,
                )
                .first()
            )
            if mapping:
                has_mapping = True
        # Then resolve the GlobalModel itself.
        global_model = await self.get_global_model_by_request(db, source_model, provider_id)
        return global_model, has_mapping

    async def get_global_model_direct(
        self,
        db: Session,
        source_model: str,
    ) -> Optional[GlobalModel]:
        """
        Fetch a GlobalModel by name WITHOUT applying 'mapping' rules.
        Only global aliases and direct name matches are considered.

        Args:
            db: Database session
            source_model: Model name

        Returns:
            GlobalModel object or None
        """
        # Priority 1: global alias
        alias = (
            db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == source_model,
                ModelMapping.provider_id.is_(None),
                ModelMapping.mapping_type == "alias",
                ModelMapping.is_active == True,
            )
            .first()
        )
        if alias:
            global_model = (
                db.query(GlobalModel)
                .filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
                .first()
            )
            if global_model:
                return global_model
        # Priority 2: direct GlobalModel.name match
        global_model = (
            db.query(GlobalModel)
            .filter(
                GlobalModel.name == source_model,
                GlobalModel.is_active == True,
            )
            .first()
        )
        return global_model

    def find_similar_models(
        self,
        db: Session,
        invalid_model: str,
        limit: int = 3,
        threshold: float = 0.4,
    ) -> List[Tuple[str, float]]:
        """Suggest similar GlobalModel.name values for an unknown model name."""
        from difflib import SequenceMatcher

        all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
        similarities: List[Tuple[str, float]] = []
        invalid_lower = invalid_model.lower()
        for model in all_models:
            model_name = model.name
            ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
            # Substring containment gets a similarity bonus.
            if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
                ratio += 0.2
            if ratio >= threshold:
                similarities.append((model_name, ratio))
        similarities.sort(key=lambda item: item[1], reverse=True)
        return similarities[:limit]

    async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
        # Evict the provider-scoped entry and, when applicable, the global one.
        await self._ensure_initialized()
        keys = [self._cache_key(source_model, provider_id)]
        if provider_id:
            keys.append(self._cache_key(source_model, None))
        for key in keys:
            await self._mapping_cache.delete(key)

    async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
        # Evict one id->name entry, or all of them when no id is given.
        await self._ensure_initialized()
        if global_model_id:
            await self._global_model_cache.delete(global_model_id)
        else:
            await self._global_model_cache.clear()

    async def clear_cache(self):
        # Drop both caches entirely.
        await self._ensure_initialized()
        await self._mapping_cache.clear()
        await self._global_model_cache.clear()

    def get_stats(self) -> dict:
        # Aggregate hit rates plus backend details (once initialized).
        total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"]
        total_global = self._stats["global_hits"] + self._stats["global_misses"]
        stats = {
            "mapping_hit_rate": (
                self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0
            ),
            "global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0,
            "stats": self._stats,
        }
        if self._initialized:
            stats["mapping_cache_backend"] = self._mapping_cache.get_stats()
            stats["global_cache_backend"] = self._global_model_cache.get_stats()
        return stats
# Module-level singleton; created on first call to get_model_mapping_resolver.
_model_mapping_resolver: Optional[ModelMappingResolver] = None


def get_model_mapping_resolver(
    cache_ttl: Optional[int] = None, cache_backend_type: Optional[str] = None
) -> ModelMappingResolver:
    """Return the process-wide ModelMappingResolver singleton.

    The configuration arguments only take effect on the very first call;
    later calls return the already-built instance unchanged.

    Args:
        cache_ttl: Cache TTL in seconds. Defaults to CacheTTL.MODEL_MAPPING
            (previously hard-coded to 300, which could drift from the class
            default — now kept consistent).
        cache_backend_type: Cache backend ("auto" by default, overridable via
            the ALIAS_CACHE_BACKEND environment variable).
    """
    global _model_mapping_resolver
    if _model_mapping_resolver is None:
        # Resolve defaults lazily so None keeps backward compatibility.
        if cache_ttl is None:
            cache_ttl = CacheTTL.MODEL_MAPPING
        if cache_backend_type is None:
            cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")
        _model_mapping_resolver = ModelMappingResolver(
            cache_ttl=cache_ttl,
            cache_backend_type=cache_backend_type,
        )
        logger.debug(f"[ModelMappingResolver] 初始化cache_ttl={cache_ttl}s, backend={cache_backend_type})")
        # Register with the cache-invalidation service (best effort).
        try:
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.set_mapping_resolver(_model_mapping_resolver)
        except Exception as exc:
            logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")
    return _model_mapping_resolver
async def resolve_model_to_global_name(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> str:
    """Convenience wrapper: resolve a model name/alias via the shared resolver."""
    return await get_model_mapping_resolver().resolve_to_global_model_name(
        db, source_model, provider_id
    )
async def get_global_model_by_request(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> Optional[GlobalModel]:
    """Convenience wrapper: resolve to a GlobalModel via the shared resolver."""
    shared_resolver = get_model_mapping_resolver()
    return await shared_resolver.get_global_model_by_request(db, source_model, provider_id)

View File

@@ -0,0 +1,48 @@
"""
计费相关数据类
定义计费计算所需的数据结构。
实际的计费逻辑已移至 ChatAdapterBase每种 API 格式可以覆盖计费方法。
数据类:
- UsageTokens: 请求的 token 使用量
- PricingConfig: 价格配置
- CostResult: 计费结果
"""
from dataclasses import dataclass
from typing import Optional
@dataclass
class UsageTokens:
    """Token usage for a single request."""

    input_tokens: int = 0  # regular (non-cached) prompt tokens
    output_tokens: int = 0  # completion tokens
    cache_creation_input_tokens: int = 0  # prompt tokens written to the cache
    cache_read_input_tokens: int = 0  # prompt tokens served from the cache
@dataclass
class PricingConfig:
    """Pricing configuration for a model."""

    input_price_per_1m: float = 0.0  # USD per 1M input tokens
    output_price_per_1m: float = 0.0  # USD per 1M output tokens
    cache_creation_price_per_1m: Optional[float] = None  # None = no separate cache-write price
    cache_read_price_per_1m: Optional[float] = None  # None = no separate cache-read price
    price_per_request: Optional[float] = None  # flat per-request fee, if any
    tiered_pricing: Optional[dict] = None  # tier table; overrides flat prices when set
    cache_ttl_minutes: Optional[int] = None  # cache TTL used for TTL-differentiated billing
@dataclass
class CostResult:
    """Computed cost breakdown for a request."""

    input_cost: float = 0.0  # cost of regular input tokens
    output_cost: float = 0.0  # cost of output tokens
    cache_creation_cost: float = 0.0  # cost of cache-write tokens
    cache_read_cost: float = 0.0  # cost of cache-read tokens
    cache_cost: float = 0.0  # combined cache cost
    request_cost: float = 0.0  # flat per-request fee
    total_cost: float = 0.0  # sum of all components
    tier_index: Optional[int] = None  # index of the pricing tier that was hit

View File

@@ -0,0 +1,356 @@
"""
模型管理服务
"""
import asyncio
from typing import List, Optional
from sqlalchemy import and_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
from src.models.database import Model, Provider
from src.services.cache.invalidation import get_cache_invalidation_service
from src.services.cache.model_cache import ModelCacheService
class ModelService:
"""模型管理服务"""
@staticmethod
def create_model(db: Session, provider_id: str, model_data: ModelCreate) -> Model:
"""创建模型"""
# 检查提供商是否存在
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise NotFoundException(f"提供商 {provider_id} 不存在")
# 检查同一提供商下是否已存在同名模型
existing = (
db.query(Model)
.filter(
and_(
Model.provider_id == provider_id,
Model.provider_model_name == model_data.provider_model_name,
)
)
.first()
)
if existing:
raise InvalidRequestException(
f"提供商 {provider.name} 下已存在模型 {model_data.provider_model_name}"
)
try:
model = Model(
provider_id=provider_id,
global_model_id=model_data.global_model_id,
provider_model_name=model_data.provider_model_name,
price_per_request=model_data.price_per_request,
tiered_pricing=model_data.tiered_pricing,
supports_vision=model_data.supports_vision,
supports_function_calling=model_data.supports_function_calling,
supports_streaming=model_data.supports_streaming,
supports_extended_thinking=model_data.supports_extended_thinking,
is_active=model_data.is_active if model_data.is_active is not None else True,
config=model_data.config,
)
db.add(model)
db.commit()
db.refresh(model)
# 显式加载 global_model 关系
if model.global_model_id:
from sqlalchemy.orm import joinedload
model = (
db.query(Model)
.options(joinedload(Model.global_model))
.filter(Model.id == model.id)
.first()
)
logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
return model
except IntegrityError as e:
db.rollback()
logger.error(f"创建模型失败: {str(e)}")
raise InvalidRequestException("创建模型失败,请检查输入数据")
@staticmethod
def get_model(db: Session, model_id: str) -> Model: # UUID
"""获取模型详情"""
from sqlalchemy.orm import joinedload
model = (
db.query(Model)
.options(joinedload(Model.global_model))
.filter(Model.id == model_id)
.first()
)
if not model:
raise NotFoundException(f"模型 {model_id} 不存在")
return model
@staticmethod
def get_models_by_provider(
db: Session,
provider_id: str, # UUID
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
) -> List[Model]:
"""获取提供商的模型列表"""
from sqlalchemy.orm import joinedload
query = (
db.query(Model)
.options(joinedload(Model.global_model))
.filter(Model.provider_id == provider_id)
)
if is_active is not None:
query = query.filter(Model.is_active == is_active)
# 按创建时间排序
query = query.order_by(Model.created_at.desc())
return query.offset(skip).limit(limit).all()
@staticmethod
def get_all_models(
db: Session,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
category: Optional[str] = None,
) -> List[Model]:
"""获取所有模型列表"""
query = db.query(Model)
if is_active is not None:
query = query.filter(Model.is_active == is_active)
# 按提供商和创建时间排序
query = query.order_by(Model.provider_id, Model.created_at.desc())
return query.offset(skip).limit(limit).all()
    @staticmethod
    def update_model(db: Session, model_id: str, model_data: ModelUpdate) -> Model:  # UUID
        """Update a model's fields (only fields explicitly set on model_data).

        Raises:
            NotFoundException: the model does not exist.
            InvalidRequestException: database integrity error.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")
        # Only apply fields the caller actually provided.
        update_data = model_data.model_dump(exclude_unset=True)
        # Debug logging around the mutation.
        logger.debug(f"更新模型 {model_id} 收到的数据: {update_data}")
        logger.debug(f"更新前的 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
        for field, value in update_data.items():
            setattr(model, field, value)
        logger.debug(f"更新后的 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
        try:
            db.commit()
            db.refresh(model)
            # Invalidate the Redis cache asynchronously (does not block return).
            # NOTE(review): asyncio.create_task requires a running event loop
            # in this thread; calling this from a pure sync context raises
            # RuntimeError — confirm call sites.
            asyncio.create_task(
                ModelCacheService.invalidate_model_cache(
                    model_id=model.id,
                    provider_id=model.provider_id,
                    global_model_id=model.global_model_id,
                )
            )
            # Clear in-memory caches (ModelMapperMiddleware instances).
            if model.provider_id and model.global_model_id:
                cache_service = get_cache_invalidation_service()
                cache_service.on_model_changed(model.provider_id, model.global_model_id)
            logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
            return model
        except IntegrityError as e:
            db.rollback()
            logger.error(f"更新模型失败: {str(e)}")
            raise InvalidRequestException("更新模型失败,请检查输入数据")
    @staticmethod
    def delete_model(db: Session, model_id: str):  # UUID
        """Delete a model.

        Deletion semantics under the new architecture:
        - A Model is only a Provider's implementation of a GlobalModel;
          deleting it does not affect the GlobalModel itself.
        - If this is the GlobalModel's last active implementation, a warning
          is logged but deletion still proceeds.
        - ModelMapping is NOT checked (mappings relate GlobalModels to each
          other; aliases live in the same table).

        Raises:
            NotFoundException: the model does not exist.
            InvalidRequestException: the delete failed at the database level.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")
        # Is this the GlobalModel's last active implementation?
        if model.global_model_id:
            other_implementations = (
                db.query(Model)
                .filter(
                    Model.global_model_id == model.global_model_id,
                    Model.id != model_id,
                    Model.is_active == True,
                )
                .count()
            )
            if other_implementations == 0:
                logger.warning(f"警告:删除模型 {model_id}Provider: {model.provider_id[:8]}...)后,"
                             f"GlobalModel '{model.global_model_id}' 将没有任何活跃的关联提供商")
        try:
            db.delete(model)
            db.commit()
            logger.info(f"删除模型成功: id={model_id}, provider_model_name={model.provider_model_name}, "
                       f"global_model_id={model.global_model_id[:8] if model.global_model_id else 'None'}...")
        except Exception as e:
            db.rollback()
            logger.error(f"删除模型失败: {str(e)}")
            raise InvalidRequestException("删除模型失败")
    @staticmethod
    def toggle_model_availability(db: Session, model_id: str, is_available: bool) -> Model:  # UUID
        """Toggle a model's availability flag and invalidate related caches.

        Raises:
            NotFoundException: the model does not exist.
        """
        model = db.query(Model).filter(Model.id == model_id).first()
        if not model:
            raise NotFoundException(f"模型 {model_id} 不存在")
        model.is_available = is_available
        db.commit()
        db.refresh(model)
        # Invalidate the Redis cache asynchronously.
        # NOTE(review): asyncio.create_task requires a running event loop in
        # this thread — confirm call sites (same hazard as update_model).
        asyncio.create_task(
            ModelCacheService.invalidate_model_cache(
                model_id=model.id,
                provider_id=model.provider_id,
                global_model_id=model.global_model_id,
            )
        )
        # Clear in-memory caches (ModelMapperMiddleware instances).
        if model.provider_id and model.global_model_id:
            cache_service = get_cache_invalidation_service()
            cache_service.on_model_changed(model.provider_id, model.global_model_id)
        status = "可用" if is_available else "不可用"
        logger.info(f"更新模型可用状态: id={model_id}, status={status}")
        return model
@staticmethod
def get_model_by_name(db: Session, provider_id: str, model_name: str) -> Optional[Model]:
"""根据 provider_model_name 获取模型"""
return (
db.query(Model)
.filter(and_(Model.provider_id == provider_id, Model.provider_model_name == model_name))
.first()
)
    @staticmethod
    def batch_create_models(
        db: Session, provider_id: str, models_data: List[ModelCreate]
    ) -> List[Model]:  # UUID
        """Batch-create models under one provider.

        Entries whose provider_model_name already exists are skipped with a
        warning; the rest are committed in a single transaction.

        Raises:
            NotFoundException: the provider does not exist.
            InvalidRequestException: the batch commit hit an integrity error
                (the whole batch is rolled back).
        """
        # The provider must exist.
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise NotFoundException(f"提供商 {provider_id} 不存在")
        created_models = []
        for model_data in models_data:
            # Skip names that already exist for this provider.
            existing = (
                db.query(Model)
                .filter(
                    and_(
                        Model.provider_id == provider_id,
                        Model.provider_model_name == model_data.provider_model_name,
                    )
                )
                .first()
            )
            if existing:
                logger.warning(f"模型 {model_data.provider_model_name} 已存在,跳过创建")
                continue
            model = Model(
                provider_id=provider_id,
                global_model_id=model_data.global_model_id,
                provider_model_name=model_data.provider_model_name,
                price_per_request=model_data.price_per_request,
                tiered_pricing=model_data.tiered_pricing,
                supports_vision=model_data.supports_vision,
                supports_function_calling=model_data.supports_function_calling,
                supports_streaming=model_data.supports_streaming,
                supports_extended_thinking=model_data.supports_extended_thinking,
                is_active=model_data.is_active,
                config=model_data.config,
            )
            db.add(model)
            created_models.append(model)
        if created_models:
            try:
                # Single commit for the whole batch; all-or-nothing.
                db.commit()
                for model in created_models:
                    db.refresh(model)
                logger.info(f"批量创建 {len(created_models)} 个模型成功")
            except IntegrityError as e:
                db.rollback()
                logger.error(f"批量创建模型失败: {str(e)}")
                raise InvalidRequestException("批量创建模型失败")
        return created_models
@staticmethod
def convert_to_response(model: Model) -> ModelResponse:
"""转换为响应模型(新架构:从 GlobalModel 获取显示信息和默认值)"""
return ModelResponse(
id=model.id,
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=model.provider_model_name,
# 原始配置值(可能为空)
price_per_request=model.price_per_request,
tiered_pricing=model.tiered_pricing,
supports_vision=model.supports_vision,
supports_function_calling=model.supports_function_calling,
supports_streaming=model.supports_streaming,
supports_extended_thinking=model.supports_extended_thinking,
supports_image_generation=model.supports_image_generation,
# 有效值(合并 Model 和 GlobalModel 默认值)
effective_tiered_pricing=model.get_effective_tiered_pricing(),
effective_input_price=model.get_effective_input_price(),
effective_output_price=model.get_effective_output_price(),
effective_price_per_request=model.get_effective_price_per_request(),
effective_supports_vision=model.get_effective_supports_vision(),
effective_supports_function_calling=model.get_effective_supports_function_calling(),
effective_supports_streaming=model.get_effective_supports_streaming(),
effective_supports_extended_thinking=model.get_effective_supports_extended_thinking(),
effective_supports_image_generation=model.get_effective_supports_image_generation(),
is_active=model.is_active,
is_available=model.is_available if model.is_available is not None else True,
created_at=model.created_at,
updated_at=model.updated_at,
# GlobalModel 信息(如果存在)
global_model_name=model.global_model.name if model.global_model else None,
global_model_display_name=(
model.global_model.display_name if model.global_model else None
),
)