1
0
Fork 0

fix(api): override pipeline models if pipeline ignored components

This commit is contained in:
Sean Sube 2023-08-29 19:05:01 -05:00
parent ea9023c2eb
commit 6b31075616
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 23 additions and 16 deletions

View File

@ -402,22 +402,27 @@ def load_pipeline(
# make sure XL models are actually being used
# TODO: why is this needed?
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder == components["text_encoder"],
type(pipe.text_encoder),
)
pipe.text_encoder = components["text_encoder"]
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2 == components["text_encoder_2"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = components["text_encoder_2"]
logger.info(
"unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet)
)
pipe.unet = components["unet"]
if "text_encoder" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder == components["text_encoder"],
type(pipe.text_encoder),
)
pipe.text_encoder = components["text_encoder"]
if "text_encoder_2" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2 == components["text_encoder_2"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = components["text_encoder_2"]
if "unet" in components:
logger.info(
"unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet)
)
pipe.unet = components["unet"]
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)
@ -447,6 +452,8 @@ def load_pipeline(
if hasattr(pipe, "vae_encoder"):
pipe.vae_encoder.set_window_size(latent_window, params.overlap)
run_gc([device])
return pipe