1
0
Fork 0

avoid passing None as encoder

This commit is contained in:
Sean Sube 2023-02-25 13:12:58 -06:00
parent 3626d69f40
commit 18f59f034d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 10 additions and 10 deletions

View File

@ -179,17 +179,18 @@ def load_pipeline(
custom_pipeline = None custom_pipeline = None
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler_type.from_pretrained( components = {
model, "scheduler": scheduler_type.from_pretrained(
provider=device.ort_provider(), model,
sess_options=device.sess_options(), provider=device.ort_provider(),
subfolder="scheduler", sess_options=device.sess_options(),
) subfolder="scheduler",
)
}
text_encoder = None
if inversion is not None: if inversion is not None:
logger.debug("loading text encoder from %s", inversion) logger.debug("loading text encoder from %s", inversion)
text_encoder = OnnxRuntimeModel.from_pretrained( components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
path.join(inversion, "text_encoder"), path.join(inversion, "text_encoder"),
provider=device.ort_provider(), provider=device.ort_provider(),
sess_options=device.sess_options(), sess_options=device.sess_options(),
@ -202,8 +203,7 @@ def load_pipeline(
sess_options=device.sess_options(), sess_options=device.sess_options(),
revision="onnx", revision="onnx",
safety_checker=None, safety_checker=None,
scheduler=scheduler, **components,
text_encoder=text_encoder,
) )
if not server.show_progress: if not server.show_progress: