mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
feat: implement upstream model import and batch model assignment with UI components
This commit is contained in:
@@ -151,29 +151,46 @@ async def query_available_models(
|
||||
adapter_class = _get_adapter_for_format(api_format)
|
||||
if not adapter_class:
|
||||
return [], f"Unknown API format: {api_format}"
|
||||
return await adapter_class.fetch_models(
|
||||
models, error = await adapter_class.fetch_models(
|
||||
client, base_url, api_key_value, extra_headers
|
||||
)
|
||||
# 确保所有模型都有 api_format 字段
|
||||
for m in models:
|
||||
if "api_format" not in m:
|
||||
m["api_format"] = api_format
|
||||
return models, error
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||
return [], f"{api_format}: {str(e)}"
|
||||
|
||||
# 限制并发请求数量,避免触发上游速率限制
|
||||
MAX_CONCURRENT_REQUESTS = 5
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
async def fetch_with_semaphore(
|
||||
client: httpx.AsyncClient, config: dict
|
||||
) -> tuple[list, Optional[str]]:
|
||||
async with semaphore:
|
||||
return await fetch_endpoint_models(client, config)
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
results = await asyncio.gather(
|
||||
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||
*[fetch_with_semaphore(client, c) for c in endpoint_configs]
|
||||
)
|
||||
for models, error in results:
|
||||
all_models.extend(models)
|
||||
if error:
|
||||
errors.append(error)
|
||||
|
||||
# 按 model id 去重(保留第一个)
|
||||
seen_ids: set[str] = set()
|
||||
# 按 model id + api_format 去重(保留第一个)
|
||||
seen_keys: set[str] = set()
|
||||
unique_models: list = []
|
||||
for model in all_models:
|
||||
model_id = model.get("id")
|
||||
if model_id and model_id not in seen_ids:
|
||||
seen_ids.add(model_id)
|
||||
api_format = model.get("api_format", "")
|
||||
unique_key = f"{model_id}:{api_format}"
|
||||
if model_id and unique_key not in seen_keys:
|
||||
seen_keys.add(unique_key)
|
||||
unique_models.append(model)
|
||||
|
||||
error = "; ".join(errors) if errors else None
|
||||
|
||||
@@ -22,16 +22,18 @@ from src.models.api import (
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignModelsToProviderRequest,
|
||||
BatchAssignModelsToProviderResponse,
|
||||
ImportFromUpstreamRequest,
|
||||
ImportFromUpstreamResponse,
|
||||
ImportFromUpstreamSuccessItem,
|
||||
ImportFromUpstreamErrorItem,
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(tags=["Model Management"])
|
||||
@@ -158,6 +160,28 @@ async def batch_assign_global_models_to_provider(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{provider_id}/import-from-upstream",
|
||||
response_model=ImportFromUpstreamResponse,
|
||||
)
|
||||
async def import_models_from_upstream(
|
||||
provider_id: str,
|
||||
payload: ImportFromUpstreamRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ImportFromUpstreamResponse:
|
||||
"""
|
||||
从上游提供商导入模型
|
||||
|
||||
流程:
|
||||
1. 根据 model_ids 检查全局模型是否存在(按 name 匹配)
|
||||
2. 如不存在,自动创建新的 GlobalModel(使用默认配置)
|
||||
3. 创建 Model 关联到当前 Provider
|
||||
"""
|
||||
adapter = AdminImportFromUpstreamAdapter(provider_id=provider_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@@ -425,3 +449,130 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
||||
await invalidate_models_list_cache()
|
||||
|
||||
return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
|
||||
"""从上游提供商导入模型"""
|
||||
|
||||
provider_id: str
|
||||
payload: ImportFromUpstreamRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException("Provider not found", "provider")
|
||||
|
||||
success: list[ImportFromUpstreamSuccessItem] = []
|
||||
errors: list[ImportFromUpstreamErrorItem] = []
|
||||
|
||||
# 默认阶梯计费配置(免费)
|
||||
default_tiered_pricing = {
|
||||
"tiers": [
|
||||
{
|
||||
"up_to": None,
|
||||
"input_price_per_1m": 0.0,
|
||||
"output_price_per_1m": 0.0,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
for model_id in self.payload.model_ids:
|
||||
# 输入验证:检查 model_id 长度
|
||||
if not model_id or len(model_id) > 100:
|
||||
errors.append(
|
||||
ImportFromUpstreamErrorItem(
|
||||
model_id=model_id[:50] + "..." if model_id and len(model_id) > 50 else model_id or "<empty>",
|
||||
error="Invalid model_id: must be 1-100 characters",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用 savepoint 确保单个模型导入的原子性
|
||||
savepoint = db.begin_nested()
|
||||
try:
|
||||
# 1. 检查是否已存在同名的 GlobalModel
|
||||
global_model = (
|
||||
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
|
||||
)
|
||||
created_global_model = False
|
||||
|
||||
if not global_model:
|
||||
# 2. 创建新的 GlobalModel
|
||||
global_model = GlobalModel(
|
||||
name=model_id,
|
||||
display_name=model_id,
|
||||
default_tiered_pricing=default_tiered_pricing,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(global_model)
|
||||
db.flush()
|
||||
created_global_model = True
|
||||
logger.info(
|
||||
f"Created new GlobalModel: {model_id} during upstream import"
|
||||
)
|
||||
|
||||
# 3. 检查是否已存在关联
|
||||
existing = (
|
||||
db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == self.provider_id,
|
||||
Model.global_model_id == global_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
# 已存在关联,提交 savepoint 并记录成功
|
||||
savepoint.commit()
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
provider_model_id=existing.id,
|
||||
created_global_model=created_global_model,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# 4. 创建新的 Model 记录
|
||||
new_model = Model(
|
||||
provider_id=self.provider_id,
|
||||
global_model_id=global_model.id,
|
||||
provider_model_name=global_model.name,
|
||||
is_active=True,
|
||||
)
|
||||
db.add(new_model)
|
||||
db.flush()
|
||||
|
||||
# 提交 savepoint
|
||||
savepoint.commit()
|
||||
success.append(
|
||||
ImportFromUpstreamSuccessItem(
|
||||
model_id=model_id,
|
||||
global_model_id=global_model.id,
|
||||
global_model_name=global_model.name,
|
||||
provider_model_id=new_model.id,
|
||||
created_global_model=created_global_model,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# 回滚到 savepoint
|
||||
savepoint.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing model {model_id}: {e}")
|
||||
errors.append(ImportFromUpstreamErrorItem(model_id=model_id, error=str(e)))
|
||||
|
||||
db.commit()
|
||||
logger.info(
|
||||
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
|
||||
)
|
||||
|
||||
# 清除 /v1/models 列表缓存
|
||||
if success:
|
||||
await invalidate_models_list_cache()
|
||||
|
||||
return ImportFromUpstreamResponse(success=success, errors=errors)
|
||||
|
||||
@@ -301,6 +301,36 @@ class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class ImportFromUpstreamRequest(BaseModel):
|
||||
"""从上游提供商导入模型请求"""
|
||||
|
||||
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
|
||||
|
||||
|
||||
class ImportFromUpstreamSuccessItem(BaseModel):
|
||||
"""导入成功的模型信息"""
|
||||
|
||||
model_id: str = Field(..., description="上游模型 ID")
|
||||
global_model_id: str = Field(..., description="GlobalModel ID")
|
||||
global_model_name: str = Field(..., description="GlobalModel 名称")
|
||||
provider_model_id: str = Field(..., description="Provider Model ID")
|
||||
created_global_model: bool = Field(..., description="是否新创建了 GlobalModel")
|
||||
|
||||
|
||||
class ImportFromUpstreamErrorItem(BaseModel):
|
||||
"""导入失败的模型信息"""
|
||||
|
||||
model_id: str = Field(..., description="上游模型 ID")
|
||||
error: str = Field(..., description="错误信息")
|
||||
|
||||
|
||||
class ImportFromUpstreamResponse(BaseModel):
|
||||
"""从上游提供商导入模型响应"""
|
||||
|
||||
success: List[ImportFromUpstreamSuccessItem]
|
||||
errors: List[ImportFromUpstreamErrorItem]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
@@ -311,6 +341,10 @@ __all__ = [
|
||||
"GlobalModelResponse",
|
||||
"GlobalModelUpdate",
|
||||
"GlobalModelWithStats",
|
||||
"ImportFromUpstreamErrorItem",
|
||||
"ImportFromUpstreamRequest",
|
||||
"ImportFromUpstreamResponse",
|
||||
"ImportFromUpstreamSuccessItem",
|
||||
"ModelCapabilities",
|
||||
"ModelCatalogItem",
|
||||
"ModelCatalogProviderDetail",
|
||||
|
||||
Reference in New Issue
Block a user