1
0
Fork 0

fix(api): correctly load text encoder 2 and VAE without LoRAs

This commit is contained in:
Sean Sube 2023-09-24 09:49:16 -05:00
parent d11b37f0b2
commit 56f19256b5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 17 additions and 6 deletions

View File

@ -200,7 +200,7 @@ def load_pipeline(
) )
components.update(unet_components) components.update(unet_components)
vae_components = load_vae(server, device, model) vae_components = load_vae(server, device, model, params)
components.update(vae_components) components.update(vae_components)
# additional options for panorama pipeline # additional options for panorama pipeline
@ -338,6 +338,16 @@ def load_text_encoders(
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
) )
if params.is_xl():
text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_2_opts,
)
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2_session"] = text_encoder_2_session
else: else:
# blend and load text encoder # blend and load text encoder
lora_names, lora_weights = zip(*loras) lora_names, lora_weights = zip(*loras)

View File

@ -47,11 +47,12 @@ class UNetWrapper(object):
self.prompt_index += 1 self.prompt_index += 1
if self.xl: if self.xl:
logger.trace( if sample.dtype != encoder_hidden_states.dtype:
"converting UNet sample to hidden state dtype for XL: %s", logger.trace(
encoder_hidden_states.dtype, "converting UNet sample to hidden state dtype for XL: %s",
) encoder_hidden_states.dtype,
sample = sample.astype(encoder_hidden_states.dtype) )
sample = sample.astype(encoder_hidden_states.dtype)
else: else:
if sample.dtype != timestep.dtype: if sample.dtype != timestep.dtype:
logger.trace("converting UNet sample to timestep dtype") logger.trace("converting UNet sample to timestep dtype")