From 45b09168dd26ecddb0210c18033323ea41eae334 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 9 Feb 2023 22:04:01 -0600 Subject: [PATCH] fix(api): move all unet tensors to the training device (#119) --- api/onnx_web/convert/diffusion_stable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py index fc8ef876..c52882ac 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion_stable.py @@ -110,11 +110,10 @@ def convert_diffusion_stable( # UNET if single_vae: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"] - # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"] - unet_scale = 4 + unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] - unet_scale = False + unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool) unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size