diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 946c85de..5e334941 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -179,17 +179,18 @@ def load_pipeline( custom_pipeline = None logger.debug("loading new diffusion pipeline from %s", model) - scheduler = scheduler_type.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - subfolder="scheduler", - ) + components = { + "scheduler": scheduler_type.from_pretrained( + model, + provider=device.ort_provider(), + sess_options=device.sess_options(), + subfolder="scheduler", + ) + } - text_encoder = None if inversion is not None: logger.debug("loading text encoder from %s", inversion) - text_encoder = OnnxRuntimeModel.from_pretrained( + components["text_encoder"] = OnnxRuntimeModel.from_pretrained( path.join(inversion, "text_encoder"), provider=device.ort_provider(), sess_options=device.sess_options(), @@ -202,8 +203,7 @@ def load_pipeline( sess_options=device.sess_options(), revision="onnx", safety_checker=None, - scheduler=scheduler, - text_encoder=text_encoder, + **components, ) if not server.show_progress: