fix(api): correctly load text encoder 2 and VAE without LoRAs
This commit is contained in:
parent
d11b37f0b2
commit
56f19256b5
|
@ -200,7 +200,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
components.update(unet_components)
|
components.update(unet_components)
|
||||||
|
|
||||||
vae_components = load_vae(server, device, model)
|
vae_components = load_vae(server, device, model, params)
|
||||||
components.update(vae_components)
|
components.update(vae_components)
|
||||||
|
|
||||||
# additional options for panorama pipeline
|
# additional options for panorama pipeline
|
||||||
|
@ -338,6 +338,16 @@ def load_text_encoders(
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder_2_session = InferenceSession(
|
||||||
|
text_encoder_2.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_2_opts,
|
||||||
|
)
|
||||||
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
|
components["text_encoder_2_session"] = text_encoder_2_session
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
|
|
|
@ -47,11 +47,12 @@ class UNetWrapper(object):
|
||||||
self.prompt_index += 1
|
self.prompt_index += 1
|
||||||
|
|
||||||
if self.xl:
|
if self.xl:
|
||||||
logger.trace(
|
if sample.dtype != encoder_hidden_states.dtype:
|
||||||
"converting UNet sample to hidden state dtype for XL: %s",
|
logger.trace(
|
||||||
encoder_hidden_states.dtype,
|
"converting UNet sample to hidden state dtype for XL: %s",
|
||||||
)
|
encoder_hidden_states.dtype,
|
||||||
sample = sample.astype(encoder_hidden_states.dtype)
|
)
|
||||||
|
sample = sample.astype(encoder_hidden_states.dtype)
|
||||||
else:
|
else:
|
||||||
if sample.dtype != timestep.dtype:
|
if sample.dtype != timestep.dtype:
|
||||||
logger.trace("converting UNet sample to timestep dtype")
|
logger.trace("converting UNet sample to timestep dtype")
|
||||||
|
|
Loading…
Reference in New Issue