fix(api): move all unet tensors to the training device (#119)
This commit is contained in:
parent
d8d5bcd927
commit
45b09168dd
|
@ -110,11 +110,10 @@ def convert_diffusion_stable(
|
|||
# UNET
|
||||
if single_vae:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||
# unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"]
|
||||
unet_scale = 4
|
||||
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
|
||||
else:
|
||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||
unet_scale = False
|
||||
unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
|
||||
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
|
|
Loading…
Reference in New Issue