1
0
Fork 0

fix(api): override pipeline models if pipeline ignored components

This commit is contained in:
Sean Sube 2023-08-29 19:05:01 -05:00
parent ea9023c2eb
commit 6b31075616
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 23 additions and 16 deletions

View File

@ -402,18 +402,23 @@ def load_pipeline(
# make sure XL models are actually being used
# TODO: why is this needed?
if "text_encoder" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder == components["text_encoder"],
type(pipe.text_encoder),
)
pipe.text_encoder = components["text_encoder"]
if "text_encoder_2" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2 == components["text_encoder_2"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = components["text_encoder_2"]
if "unet" in components:
logger.info(
"unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet)
)
@ -447,6 +452,8 @@ def load_pipeline(
if hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
run_gc([device])
return pipe