diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py index fc8ef876..c52882ac 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion_stable.py @@ -110,11 +110,10 @@ def convert_diffusion_stable( # UNET if single_vae: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"] - # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"] - unet_scale = 4 + unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] - unet_scale = False + unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool) unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size