diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 36a35b7a..28bdac75 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -7,7 +7,12 @@ from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
     ORTStableDiffusionXLImg2ImgPipeline,
     ORTStableDiffusionXLPipeline,
 )
-from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet
+from optimum.onnxruntime.modeling_diffusion import (
+    ORTModelTextEncoder,
+    ORTModelUnet,
+    ORTModelVaeDecoder,
+    ORTModelVaeEncoder,
+)
 from transformers import CLIPTokenizer
 
 from ..constants import ONNX_MODEL
@@ -363,26 +368,40 @@ def load_pipeline(
                 sess_options=device.sess_options(),
             )
         )
-    elif (
-        not params.is_xl() and path.exists(vae_decoder) and path.exists(vae_encoder)
-    ):
-        logger.debug("loading VAE decoder from %s", vae_decoder)
-        components["vae_decoder"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
+    elif path.exists(vae_decoder) and path.exists(vae_encoder):
+        if params.is_xl():
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder_session"] = OnnxRuntimeModel.load_model(
                 vae_decoder,
                 provider=device.ort_provider("vae"),
                 sess_options=device.sess_options(),
             )
-        )
 
-        logger.debug("loading VAE encoder from %s", vae_encoder)
-        components["vae_encoder"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
                 vae_encoder,
                 provider=device.ort_provider("vae"),
                 sess_options=device.sess_options(),
             )
-        )
+
+        else:
+            logger.debug("loading VAE decoder from %s", vae_decoder)
+            components["vae_decoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_decoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
+
+            logger.debug("loading VAE encoder from %s", vae_encoder)
+            components["vae_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    vae_encoder,
+                    provider=device.ort_provider("vae"),
+                    sess_options=device.sess_options(),
+                )
+            )
 
     # additional options for panorama pipeline
     if params.is_panorama():
@@ -402,33 +421,30 @@ def load_pipeline(
 
         # make sure XL models are actually being used
        if "text_encoder_session" in components:
-            logger.info(
-                "text encoder matches: %s, %s",
-                pipe.text_encoder.session == components["text_encoder_session"],
-                type(pipe.text_encoder),
-            )
             pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
 
         if "text_encoder_2_session" in components:
-            logger.info(
-                "text encoder 2 matches: %s, %s",
-                pipe.text_encoder_2.session == components["text_encoder_2_session"],
-                type(pipe.text_encoder_2),
-            )
             pipe.text_encoder_2 = ORTModelTextEncoder(
                 text_encoder_2_session, text_encoder_2
             )
 
         if "unet_session" in components:
-            logger.info(
-                "unet matches: %s, %s",
-                pipe.unet.session == components["unet_session"],
-                type(pipe.unet),
-            )
+            # unload old UNet first
             pipe.unet = None
             run_gc([device])
 
+            # load correct one
             pipe.unet = ORTModelUnet(unet_session, unet_model)
 
+        if "vae_decoder_session" in components:
+            pipe.vae_decoder = ORTModelVaeDecoder(
+                components["vae_decoder_session"], vae_decoder
+            )
+
+        if "vae_encoder_session" in components:
+            pipe.vae_encoder = ORTModelVaeEncoder(
+                components["vae_encoder_session"], vae_encoder
+            )
+
     if not server.show_progress:
         pipe.set_progress_bar_config(disable=True)
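
For context on the branch this diff introduces: diffusers' ONNX pipelines accept OnnxRuntimeModel wrappers directly as constructor components, while optimum's ORT SDXL pipelines build their own parts, so the raw InferenceSession is kept in the components dict and the part is swapped in after the pipeline object exists. A minimal sketch of the two styles follows; wire_onnx_vae_decoder is a hypothetical helper, the ORTModelVaeDecoder(session, parent) shape is assumed from optimum's modeling_diffusion, and the sketch passes the pipeline as the parent object where the diff above passes the decoder path.

# Sketch only, not part of this diff. Assumes diffusers'
# OnnxRuntimeModel.load_model() returns a raw ONNX Runtime InferenceSession
# and optimum's ORTModelVaeDecoder takes (session, parent_model).
from diffusers import OnnxRuntimeModel
from optimum.onnxruntime.modeling_diffusion import ORTModelVaeDecoder


def wire_onnx_vae_decoder(pipe, model_path, is_xl):
    if is_xl:
        # optimum SDXL pipelines: keep the raw session and replace the
        # pipeline part after the pipeline has been constructed
        session = OnnxRuntimeModel.load_model(model_path)
        pipe.vae_decoder = ORTModelVaeDecoder(session, pipe)
    else:
        # diffusers ONNX pipelines: components are OnnxRuntimeModel
        # wrappers, normally passed through the constructor kwargs
        pipe.vae_decoder = OnnxRuntimeModel(OnnxRuntimeModel.load_model(model_path))
    return pipe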