1
0
Fork 0

fix(api): set VAE attn processor during conversion

This commit is contained in:
Sean Sube 2023-09-24 15:02:21 -05:00
parent 6e2896f7f7
commit 0ecae65f88
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 4 additions and 3 deletions

View File

@ -380,6 +380,10 @@ def convert_diffusion_diffusers(
else: else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path) pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.vae.set_attn_processor(AttnProcessor())
optimize_pipeline(conversion, pipeline) optimize_pipeline(conversion, pipeline)
output_path = Path(dest_path) output_path = Path(dest_path)
@ -430,9 +434,6 @@ def convert_diffusion_diffusers(
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool) unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
unet_in_channels = pipeline.unet.config.in_channels unet_in_channels = pipeline.unet.config.in_channels
unet_sample_size = pipeline.unet.config.sample_size unet_sample_size = pipeline.unet.config.sample_size
unet_path = output_path / "unet" / ONNX_MODEL unet_path = output_path / "unet" / ONNX_MODEL