diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 3b8565d6..7e78ab62 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -230,17 +230,32 @@ def load_pipeline( ) else: logger.debug( - "loading pretrained SD pipeline for %s", pipeline_class.__name__ - ) - pipe = pipeline_class.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - safety_checker=None, - torch_dtype=torch_dtype, - **components, + "assembling SD pipeline for %s", pipeline_class.__name__ ) + if pipeline_class == OnnxStableDiffusionUpscalePipeline: + # upscale uses a single VAE + pipe = pipeline_class( + components["vae"], + components["text_encoder"], + components["tokenizer"], + components["unet"], + scheduler, + scheduler, + ) + else: + pipe = pipeline_class( + components["vae_encoder"], + components["vae_decoder"], + components["text_encoder"], + components["tokenizer"], + components["unet"], + scheduler, + None, + None, + requires_safety_checker=False, + ) + if not server.show_progress: pipe.set_progress_bar_config(disable=True)