1
0
Fork 0

add view iteration to panoramic inpaint

This commit is contained in:
Sean Sube 2023-05-01 19:09:52 -05:00
parent ca611f03df
commit e8d6ab64c1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 41 additions and 20 deletions

View File

@ -123,6 +123,9 @@ def upscale_outpaint(
draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
return result.images[0] return result.images[0]
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(source, (source.width, source.height, max(source.width, source.height)))
if overlap == 0: if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling") logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint]) output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])

View File

@ -1088,30 +1088,48 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views = self.get_views(height, width)
count = np.zeros_like(latents)
value = np.zeros_like(latents)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance count.fill(0)
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents value.fill(0)
# concat latents, mask, masked_image_latnets in the channel dimension
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
# predict the noise residual for h_start, h_end, w_start, w_end in views:
timestep = np.array([t], dtype=timestep_dtype) # get the latents corresponding to the current view coordinates
noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
0
]
# perform guidance # expand the latents if we are doing classifier free guidance
if do_classifier_free_guidance: latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) # concat latents, mask, masked_image_latnets in the channel dimension
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
# compute the previous noisy sample x_t -> x_t-1 # predict the noise residual
scheduler_output = self.scheduler.step( timestep = np.array([t], dtype=timestep_dtype)
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
) 0
latents = scheduler_output.prev_sample.numpy() ]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents_view_denoised = scheduler_output.prev_sample.numpy()
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)
# call the callback, if provided # call the callback, if provided
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0: