diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 23ecaea9..9b9c9d78 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -402,22 +402,27 @@ def load_pipeline( # make sure XL models are actually being used # TODO: why is this needed? - logger.info( - "text encoder matches: %s, %s", - pipe.text_encoder == components["text_encoder"], - type(pipe.text_encoder), - ) - pipe.text_encoder = components["text_encoder"] - logger.info( - "text encoder 2 matches: %s, %s", - pipe.text_encoder_2 == components["text_encoder_2"], - type(pipe.text_encoder_2), - ) - pipe.text_encoder_2 = components["text_encoder_2"] - logger.info( - "unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet) - ) - pipe.unet = components["unet"] + if "text_encoder" in components: + logger.info( + "text encoder matches: %s, %s", + pipe.text_encoder == components["text_encoder"], + type(pipe.text_encoder), + ) + pipe.text_encoder = components["text_encoder"] + + if "text_encoder_2" in components: + logger.info( + "text encoder 2 matches: %s, %s", + pipe.text_encoder_2 == components["text_encoder_2"], + type(pipe.text_encoder_2), + ) + pipe.text_encoder_2 = components["text_encoder_2"] + + if "unet" in components: + logger.info( + "unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet) + ) + pipe.unet = components["unet"] if not server.show_progress: pipe.set_progress_bar_config(disable=True) @@ -447,6 +452,8 @@ def load_pipeline( if hasattr(pipe, "vae_encoder"): pipe.vae_encoder.set_window_size(latent_window, params.overlap) + run_gc([device]) + return pipe