Files
Aether/src/api/admin/models/catalog.py

157 lines
6.3 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
统一模型目录 Admin API
基于 GlobalModel 的聚合视图

Unified model catalog Admin API: an aggregated view built around GlobalModel.
"""
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List

from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session, joinedload

from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import GlobalModel, Model
from src.models.pydantic_models import (
    ModelCapabilities,
    ModelCatalogItem,
    ModelCatalogProviderDetail,
    ModelCatalogResponse,
    ModelPriceRange,
)
# Routes for the admin model catalog, all mounted under the "/catalog" prefix.
router = APIRouter(tags=["Admin - Model Catalog"], prefix="/catalog")

# Pipeline instance that executes this module's API adapters.
pipeline = ApiRequestPipeline()
@router.get("", response_model=ModelCatalogResponse)
async def get_model_catalog(
    request: Request,
    db: Session = Depends(get_db),
) -> ModelCatalogResponse:
    """Return the unified model catalog.

    The work is delegated to ``AdminGetModelCatalogAdapter``, executed through
    the module-level ``pipeline`` with the adapter's own ``mode``.
    """
    catalog_adapter = AdminGetModelCatalogAdapter()
    result = await pipeline.run(
        adapter=catalog_adapter,
        http_request=request,
        db=db,
        mode=catalog_adapter.mode,
    )
    return result
@dataclass
class AdminGetModelCatalogAdapter(AdminApiAdapter):
    """Admin query for the unified model catalog.

    Architecture:
    1. Data is aggregated around ``GlobalModel``.
    2. The ``Model`` table supplies the linked providers and their prices.
    """

    async def handle(self, context):  # type: ignore[override]
        """Build the catalog of all active GlobalModels and their active providers.

        Args:
            context: Pipeline request context; only ``context.db`` (a SQLAlchemy
                ``Session``) is read.

        Returns:
            ModelCatalogResponse with one ``ModelCatalogItem`` per active
            GlobalModel, including GlobalModels that have no providers.
        """
        db: Session = context.db

        # 1. Every active GlobalModel becomes one catalog item.
        global_models: List[GlobalModel] = (
            db.query(GlobalModel).filter(GlobalModel.is_active == True).all()  # noqa: E712
        )

        # 2. Every active Model implementation. global_model is eager-loaded so
        #    the effective-price/capability getters can take the GlobalModel
        #    defaults into account.
        models: List[Model] = (
            db.query(Model)
            .options(joinedload(Model.provider), joinedload(Model.global_model))
            .filter(Model.is_active == True)  # noqa: E712
            .all()
        )

        # Group implementations by GlobalModel ID; models not linked to a
        # GlobalModel are not part of the catalog.
        models_by_global_model: Dict[str, List[Model]] = defaultdict(list)
        for model in models:
            if model.global_model_id:
                models_by_global_model[model.global_model_id].append(model)

        # 3. Build one catalog item per GlobalModel.
        catalog_items: List[ModelCatalogItem] = []
        for gm in global_models:
            gm_config = gm.config or {}
            # Start from the GlobalModel's config flags; a capability supported
            # by any provider is then enabled for the whole GlobalModel.
            capability_flags = self._capability_flags_from_config(gm_config)
            provider_entries: List[ModelCatalogProviderDetail] = []

            # .get() rather than [] so a GlobalModel without providers does not
            # insert an empty list into the defaultdict.
            for model in models_by_global_model.get(gm.id, []):
                provider = model.provider
                if not provider:
                    continue
                entry = self._build_provider_entry(model, provider)
                provider_entries.append(entry)
                # Reuse the entry's values so each getter is evaluated once
                # and the aggregate matches what is displayed per provider.
                capability_flags["supports_vision"] |= bool(entry.supports_vision)
                capability_flags["supports_function_calling"] |= bool(
                    entry.supports_function_calling
                )
                capability_flags["supports_streaming"] |= bool(entry.supports_streaming)

            catalog_items.append(
                ModelCatalogItem(
                    global_model_name=gm.name,
                    display_name=gm.display_name,
                    description=gm_config.get("description"),
                    providers=provider_entries,
                    price_range=self._build_price_range(gm),
                    total_providers=len(provider_entries),
                    capabilities=ModelCapabilities(**capability_flags),
                )
            )

        return ModelCatalogResponse(
            models=catalog_items,
            total=len(catalog_items),
        )

    @staticmethod
    def _capability_flags_from_config(gm_config: Dict[str, Any]) -> Dict[str, bool]:
        """Read GlobalModel-level capability defaults from its config JSON.

        Vision and function calling default to off, streaming defaults to on.
        Values are coerced to ``bool`` so a null entry in the config cannot
        leak ``None`` into ``ModelCapabilities``.
        """
        return {
            "supports_vision": bool(gm_config.get("vision", False)),
            "supports_function_calling": bool(gm_config.get("function_calling", False)),
            "supports_streaming": bool(gm_config.get("streaming", True)),
        }

    @staticmethod
    def _build_provider_entry(model: Model, provider: Any) -> ModelCatalogProviderDetail:
        """Describe one provider implementation using its effective prices and capabilities.

        "Effective" values are those returned by the ``Model.get_effective_*``
        getters, which take GlobalModel defaults into account.
        """
        effective_tiered = model.get_effective_tiered_pricing()
        # Without tiered pricing the implementation is reported as one tier.
        tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
        return ModelCatalogProviderDetail(
            provider_id=provider.id,
            provider_name=provider.name,
            provider_display_name=provider.display_name,
            model_id=model.id,
            target_model=model.provider_model_name,
            # Display the effective prices.
            input_price_per_1m=model.get_effective_input_price(),
            output_price_per_1m=model.get_effective_output_price(),
            cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
            cache_read_price_per_1m=model.get_effective_cache_read_price(),
            cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
            price_per_request=model.get_effective_price_per_request(),
            effective_tiered_pricing=effective_tiered,
            tier_count=tier_count,
            supports_vision=model.get_effective_supports_vision(),
            supports_function_calling=model.get_effective_supports_function_calling(),
            supports_streaming=model.get_effective_supports_streaming(),
            is_active=bool(model.is_active),
        )

    @staticmethod
    def _build_price_range(gm: GlobalModel) -> ModelPriceRange:
        """Price range shown in the catalog: the GlobalModel's first pricing tier.

        This deliberately reflects the GlobalModel default rather than an
        aggregate over providers, so min and max carry the same value. A price
        that is missing or null in the first tier is reported as 0.
        """
        tiers = (gm.default_tiered_pricing or {}).get("tiers") or []
        first_tier = (tiers[0] if tiers else None) or {}

        def _price(key: str) -> Any:
            value = first_tier.get(key)
            # Treat an explicit null like a missing key (default 0).
            return 0 if value is None else value

        input_price = _price("input_price_per_1m")
        output_price = _price("output_price_per_1m")
        return ModelPriceRange(
            min_input=input_price,
            max_input=input_price,
            min_output=output_price,
            max_output=output_price,
        )