From e8d6ab64c1c3103df33ff62cfca93b23ab8d1433 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Mon, 1 May 2023 19:09:52 -0500
Subject: [PATCH] add view iteration to panoramic inpaint

---
 api/onnx_web/chain/upscale_outpaint.py       |  3 +
 api/onnx_web/diffusers/pipelines/panorama.py | 58 +++++++++++++-------
 2 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index 2ca66cff..bc59bba6 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -123,6 +123,9 @@ def upscale_outpaint(
         draw_mask.rectangle((left, top, left + tile, top + tile), fill="black")
         return result.images[0]
 
+    if params.pipeline == "panorama":
+        logger.debug("outpainting with one shot panorama, no tiling")
+        return outpaint(source, (source.width, source.height, max(source.width, source.height)))
     if overlap == 0:
         logger.debug("outpainting with 0 margin, using grid tiling")
         output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])
diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index 9f6893a3..8c908681 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -1088,30 +1088,48 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         )
         timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
 
+        # panorama additions
+        views = self.get_views(height, width)
+        count = np.zeros_like(latents)
+        value = np.zeros_like(latents)
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            # concat latents, mask, masked_image_latnets in the channel dimension
-            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-            latent_model_input = latent_model_input.cpu().numpy()
-            latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
+            count.fill(0)
+            value.fill(0)
 
-            # predict the noise residual
-            timestep = np.array([t], dtype=timestep_dtype)
-            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
-                0
-            ]
+            for h_start, h_end, w_start, w_end in views:
+                # get the latents corresponding to the current view coordinates
+                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
 
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+                latent_model_input = latent_model_input.cpu().numpy()
+                latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
 
-            # compute the previous noisy sample x_t -> x_t-1
-            scheduler_output = self.scheduler.step(
-                torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
-            )
-            latents = scheduler_output.prev_sample.numpy()
+                # predict the noise residual
+                timestep = np.array([t], dtype=timestep_dtype)
+                noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[
+                    0
+                ]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1, using the
+                # view-sized latents so shapes match the view-sized noise_pred
+                scheduler_output = self.scheduler.step(
+                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
+                )
+                latents_view_denoised = scheduler_output.prev_sample.numpy()
+
+                value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                count[:, :, h_start:h_end, w_start:w_end] += 1
+
+            # take the MultiDiffusion step. Eq. 5 in the MultiDiffusion paper: https://arxiv.org/abs/2302.08113
+            latents = np.where(count > 0, value / count, value)
 
             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
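
Note, not part of the patch: the new loop depends on a get_views helper on the
pipeline that this diff does not show. A minimal sketch of what that helper is
assumed to look like, mirroring the sliding-window scheme of the upstream
diffusers StableDiffusionPanoramaPipeline (the window_size=64 and stride=8
latent-space defaults are assumptions, not values taken from this repo):

    # hypothetical sketch of the view computation, as a method on the pipeline;
    # window_size and stride are measured in latent pixels (1/8 of image pixels)
    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
        # convert from image space to latent space
        panorama_height //= 8
        panorama_width //= 8
        # number of window positions along each axis
        num_blocks_height = (panorama_height - window_size) // stride + 1
        num_blocks_width = (panorama_width - window_size) // stride + 1
        views = []
        for i in range(num_blocks_height * num_blocks_width):
            h_start = (i // num_blocks_width) * stride
            h_end = h_start + window_size
            w_start = (i % num_blocks_width) * stride
            w_end = w_start + window_size
            views.append((h_start, h_end, w_start, w_end))
        return views

Each denoising step then runs the UNet once per view and blends the overlapping
results with the per-pixel average value / count, which is the uniform-weight
form of the MultiDiffusion update (Eq. 5 in https://arxiv.org/abs/2302.08113).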