From 56f19256b5cdd2b1d5c01cb6805f163d1a8e4f92 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 24 Sep 2023 09:49:16 -0500 Subject: [PATCH] fix(api): correctly load text encoder 2 and VAE without LoRAs --- api/onnx_web/diffusers/load.py | 12 +++++++++++- api/onnx_web/diffusers/patches/unet.py | 11 ++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 3e0c2920..2773922c 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -200,7 +200,7 @@ def load_pipeline( ) components.update(unet_components) - vae_components = load_vae(server, device, model) + vae_components = load_vae(server, device, model, params) components.update(vae_components) # additional options for panorama pipeline @@ -338,6 +338,16 @@ def load_text_encoders( sess_options=device.sess_options(), ) ) + + if params.is_xl(): + text_encoder_2_session = InferenceSession( + text_encoder_2.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_2_opts, + ) + text_encoder_2_session._model_path = path.join(model, "text_encoder_2") + components["text_encoder_2_session"] = text_encoder_2_session + else: # blend and load text encoder lora_names, lora_weights = zip(*loras) diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py index fbe0f4f8..6e15597f 100644 --- a/api/onnx_web/diffusers/patches/unet.py +++ b/api/onnx_web/diffusers/patches/unet.py @@ -47,11 +47,12 @@ class UNetWrapper(object): self.prompt_index += 1 if self.xl: - logger.trace( - "converting UNet sample to hidden state dtype for XL: %s", - encoder_hidden_states.dtype, - ) - sample = sample.astype(encoder_hidden_states.dtype) + if sample.dtype != encoder_hidden_states.dtype: + logger.trace( + "converting UNet sample to hidden state dtype for XL: %s", + encoder_hidden_states.dtype, + ) + sample = sample.astype(encoder_hidden_states.dtype) else: if sample.dtype != timestep.dtype: logger.trace("converting UNet sample to timestep dtype")