1
0
Fork 0

fix(api): move all unet tensors to the training device (#119)

This commit is contained in:
Sean Sube 2023-02-09 22:04:01 -06:00
parent d8d5bcd927
commit 45b09168dd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 2 additions and 3 deletions

View File

@ -110,11 +110,10 @@ def convert_diffusion_stable(
# UNET
if single_vae:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
# unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"]
unet_scale = 4
unet_scale = torch.tensor(4).to(device=ctx.training_device, dtype=torch.int)
else:
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = False
unet_scale = torch.tensor(False).to(device=ctx.training_device, dtype=torch.bool)
unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size