diff --git a/api/onnx_web/convert/diffusion_stable.py b/api/onnx_web/convert/diffusion_stable.py index 18156c3c..2ee2c861 100644 --- a/api/onnx_web/convert/diffusion_stable.py +++ b/api/onnx_web/convert/diffusion_stable.py @@ -126,9 +126,7 @@ def convert_diffusion_stable( # UNET if single_vae: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"] - unet_scale = torch.tensor(4).to( - device=ctx.training_device, dtype=torch.long - ) + unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.long) else: unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_scale = torch.tensor(False).to(