refactor(backend): update model catalog and provider APIs after mappings removal

This commit is contained in:
fawney19
2025-12-15 14:30:10 +08:00
parent 728f9bb126
commit 56fb6bf36c
6 changed files with 85 additions and 566 deletions

View File

@@ -20,14 +20,12 @@ from src.models.api import (
ProviderStatsResponse,
PublicGlobalModelListResponse,
PublicGlobalModelResponse,
PublicModelMappingResponse,
PublicModelResponse,
PublicProviderResponse,
)
from src.models.database import (
GlobalModel,
Model,
ModelMapping,
Provider,
ProviderEndpoint,
RequestCandidate,
@@ -72,24 +70,6 @@ async def get_public_models(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
async def get_public_model_mappings(
request: Request,
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
alias: Optional[str] = Query(None, description="别名过滤原source_model"),
skip: int = Query(0, description="跳过记录数"),
limit: int = Query(100, description="返回记录数限制"),
db: Session = Depends(get_db),
):
adapter = PublicModelMappingsAdapter(
provider_id=provider_id,
alias=alias,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/stats", response_model=ProviderStatsResponse)
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
adapter = PublicStatsAdapter()
@@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
.count()
)
mappings_count = (
db.query(ModelMapping)
.filter(
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
)
.count()
)
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
active_endpoints_count = (
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
@@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
provider_priority=provider.provider_priority,
models_count=models_count,
active_models_count=active_models_count,
mappings_count=mappings_count,
endpoints_count=endpoints_count,
active_endpoints_count=active_endpoints_count,
)
@@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter):
return response
@dataclass
class PublicModelMappingsAdapter(PublicApiAdapter):
provider_id: Optional[str]
alias: Optional[str] # 原 source_model改为 alias
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求模型映射列表")
query = (
db.query(ModelMapping, GlobalModel, Provider)
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
.filter(
and_(
ModelMapping.is_active.is_(True),
GlobalModel.is_active.is_(True),
)
)
)
if self.provider_id is not None:
provider_global_model_ids = (
db.query(Model.global_model_id)
.join(Provider, Model.provider_id == Provider.id)
.filter(
Provider.id == self.provider_id,
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
.distinct()
)
query = query.filter(
or_(
ModelMapping.provider_id == self.provider_id,
and_(
ModelMapping.provider_id.is_(None),
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
),
)
)
else:
query = query.filter(ModelMapping.provider_id.is_(None))
if self.alias is not None:
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
results = query.offset(self.skip).limit(self.limit).all()
response = []
for mapping, global_model, provider in results:
scope = "provider" if mapping.provider_id else "global"
mapping_data = PublicModelMappingResponse(
id=mapping.id,
source_model=mapping.source_model,
target_global_model_id=mapping.target_global_model_id,
target_global_model_name=global_model.name if global_model else None,
target_global_model_display_name=(
global_model.display_name if global_model else None
),
provider_id=mapping.provider_id,
scope=scope,
is_active=mapping.is_active,
)
response.append(mapping_data.model_dump())
logger.debug(f"返回 {len(response)} 个模型映射")
return response
class PublicStatsAdapter(PublicApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
@@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter):
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.count()
)
from ...models.database import ModelMapping
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
formats = (
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
)
@@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter):
active_providers=active_providers,
total_models=active_models,
active_models=active_models,
total_mappings=active_mappings,
supported_formats=supported_formats,
)
logger.debug("返回系统统计信息")