From 09f600ab54fd8346ded95d3bc44945754b29f23a Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 11 Nov 2023 22:43:41 -0600
Subject: [PATCH] feat(api): skip regions on last timestep

---
 api/onnx_web/diffusers/pipelines/panorama.py  | 170 ++++++++--------
 .../diffusers/pipelines/panorama_xl.py        | 190 +++++++++---------
 2 files changed, 182 insertions(+), 178 deletions(-)

diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index d90a5def..fb5ea0e8 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -557,6 +557,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         value = np.zeros_like(latents)
 
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+            last = i == (len(self.scheduler.timesteps) - 1)
             count.fill(0)
             value.fill(0)
 
@@ -603,97 +604,98 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                 count[:, :, h_start:h_end, w_start:w_end] += 1
 
-            for r in range(len(regions)):
-                top, left, bottom, right, weight, feather, prompt = regions[r]
-                logger.debug(
-                    "running region prompt: %s, %s, %s, %s, %s, %s, %s",
-                    top,
-                    left,
-                    bottom,
-                    right,
-                    weight,
-                    feather,
-                    prompt,
-                )
-
-                # convert coordinates to latent space
-                h_start = top // LATENT_FACTOR
-                h_end = bottom // LATENT_FACTOR
-                w_start = left // LATENT_FACTOR
-                w_end = right // LATENT_FACTOR
-
-                # get the latents corresponding to the current view coordinates
-                latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
-                logger.trace(
-                    "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
-                    h_start,
-                    h_end,
-                    w_start,
-                    w_end,
-                    latents_for_region.shape,
-                )
-
-                # expand the latents if we are doing classifier free guidance
-                latent_region_input = (
-                    np.concatenate([latents_for_region] * 2)
-                    if do_classifier_free_guidance
-                    else latents_for_region
-                )
-                latent_region_input = self.scheduler.scale_model_input(
-                    torch.from_numpy(latent_region_input), t
-                )
-                latent_region_input = latent_region_input.cpu().numpy()
-
-                # predict the noise residual
-                timestep = np.array([t], dtype=timestep_dtype)
-                region_noise_pred = self.unet(
-                    sample=latent_region_input,
-                    timestep=timestep,
-                    encoder_hidden_states=region_embeds[r],
-                )
-                region_noise_pred = region_noise_pred[0]
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    region_noise_pred_uncond, region_noise_pred_text = np.split(
-                        region_noise_pred, 2
-                    )
-                    region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                        region_noise_pred_text - region_noise_pred_uncond
-                    )
-
-                # compute the previous noisy sample x_t -> x_t-1
-                scheduler_output = self.scheduler.step(
-                    torch.from_numpy(region_noise_pred),
-                    t,
-                    torch.from_numpy(latents_for_region),
-                    **extra_step_kwargs,
-                )
-                latents_region_denoised = scheduler_output.prev_sample.numpy()
-
-                if feather[0] > 0.0:
-                    mask = make_tile_mask(
-                        (h_end - h_start, w_end - w_start),
-                        (h_end - h_start, w_end - w_start),
-                        feather[0],
-                        feather[1],
-                    )
-                    mask = np.expand_dims(mask, axis=0)
-                    mask = np.repeat(mask, 4, axis=0)
-                    mask = np.expand_dims(mask, axis=0)
-                else:
-                    mask = 1
-
-                if weight >= 10.0:
-                    value[:, :, h_start:h_end, w_start:w_end] = (
-                        latents_region_denoised * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] = mask
-                else:
-                    value[:, :, h_start:h_end, w_start:w_end] += (
-                        latents_region_denoised * weight * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] += weight * mask
+            if not last:
+                for r, region in enumerate(regions):
+                    top, left, bottom, right, weight, feather, prompt = region
+                    logger.debug(
+                        "running region prompt: %s, %s, %s, %s, %s, %s, %s",
+                        top,
+                        left,
+                        bottom,
+                        right,
+                        weight,
+                        feather,
+                        prompt,
+                    )
+
+                    # convert coordinates to latent space
+                    h_start = top // LATENT_FACTOR
+                    h_end = bottom // LATENT_FACTOR
+                    w_start = left // LATENT_FACTOR
+                    w_end = right // LATENT_FACTOR
+
+                    # get the latents corresponding to the current view coordinates
+                    latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
+                    logger.trace(
+                        "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
+                        h_start,
+                        h_end,
+                        w_start,
+                        w_end,
+                        latents_for_region.shape,
+                    )
+
+                    # expand the latents if we are doing classifier free guidance
+                    latent_region_input = (
+                        np.concatenate([latents_for_region] * 2)
+                        if do_classifier_free_guidance
+                        else latents_for_region
+                    )
+                    latent_region_input = self.scheduler.scale_model_input(
+                        torch.from_numpy(latent_region_input), t
+                    )
+                    latent_region_input = latent_region_input.cpu().numpy()
+
+                    # predict the noise residual
+                    timestep = np.array([t], dtype=timestep_dtype)
+                    region_noise_pred = self.unet(
+                        sample=latent_region_input,
+                        timestep=timestep,
+                        encoder_hidden_states=region_embeds[r],
+                    )
+                    region_noise_pred = region_noise_pred[0]
+
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        region_noise_pred_uncond, region_noise_pred_text = np.split(
+                            region_noise_pred, 2
+                        )
+                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
+                            region_noise_pred_text - region_noise_pred_uncond
+                        )
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    scheduler_output = self.scheduler.step(
+                        torch.from_numpy(region_noise_pred),
+                        t,
+                        torch.from_numpy(latents_for_region),
+                        **extra_step_kwargs,
+                    )
+                    latents_region_denoised = scheduler_output.prev_sample.numpy()
+
+                    if feather[0] > 0.0:
+                        mask = make_tile_mask(
+                            (h_end - h_start, w_end - w_start),
+                            (h_end - h_start, w_end - w_start),
+                            feather[0],
+                            feather[1],
+                        )
+                        mask = np.expand_dims(mask, axis=0)
+                        mask = np.repeat(mask, 4, axis=0)
+                        mask = np.expand_dims(mask, axis=0)
+                    else:
+                        mask = 1
+
+                    if weight >= 10.0:
+                        value[:, :, h_start:h_end, w_start:w_end] = (
+                            latents_region_denoised * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] = mask
+                    else:
+                        value[:, :, h_start:h_end, w_start:w_end] += (
+                            latents_region_denoised * weight * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] += weight * mask
 
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py
index 24147160..650ed17a 100644
--- a/api/onnx_web/diffusers/pipelines/panorama_xl.py
+++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py
@@ -388,6 +388,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         for i, t in enumerate(self.progress_bar(timesteps)):
+            last = i == (len(timesteps) - 1)
             count.fill(0)
             value.fill(0)
 
@@ -443,106 +444,107 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1
 
-                for r in range(len(regions)):
-                    top, left, bottom, right, weight, feather, prompt = regions[r]
-                    logger.debug(
-                        "running region prompt: %s, %s, %s, %s, %s, %s, %s",
-                        top,
-                        left,
-                        bottom,
-                        right,
-                        weight,
-                        feather,
-                        prompt,
-                    )
-
-                    # convert coordinates to latent space
-                    h_start = top // LATENT_FACTOR
-                    h_end = bottom // LATENT_FACTOR
-                    w_start = left // LATENT_FACTOR
-                    w_end = right // LATENT_FACTOR
-
-                    # get the latents corresponding to the current view coordinates
-                    latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
-                    logger.trace(
-                        "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
-                        h_start,
-                        h_end,
-                        w_start,
-                        w_end,
-                        latents_for_region.shape,
-                    )
-
-                    # expand the latents if we are doing classifier free guidance
-                    latent_region_input = (
-                        np.concatenate([latents_for_region] * 2)
-                        if do_classifier_free_guidance
-                        else latents_for_region
-                    )
-                    latent_region_input = self.scheduler.scale_model_input(
-                        torch.from_numpy(latent_region_input), t
-                    )
-                    latent_region_input = latent_region_input.cpu().numpy()
-
-                    # predict the noise residual
-                    timestep = np.array([t], dtype=timestep_dtype)
-                    region_noise_pred = self.unet(
-                        sample=latent_region_input,
-                        timestep=timestep,
-                        encoder_hidden_states=region_embeds[r],
-                        text_embeds=add_region_embeds[r],
-                        time_ids=add_time_ids,
-                    )
-                    region_noise_pred = region_noise_pred[0]
-
-                    # perform guidance
-                    if do_classifier_free_guidance:
-                        region_noise_pred_uncond, region_noise_pred_text = np.split(
-                            region_noise_pred, 2
-                        )
-                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                            region_noise_pred_text - region_noise_pred_uncond
-                        )
-                        if guidance_rescale > 0.0:
-                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                            region_noise_pred = rescale_noise_cfg(
-                                region_noise_pred,
-                                region_noise_pred_text,
-                                guidance_rescale=guidance_rescale,
-                            )
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    scheduler_output = self.scheduler.step(
-                        torch.from_numpy(region_noise_pred),
-                        t,
-                        torch.from_numpy(latents_for_region),
-                        **extra_step_kwargs,
-                    )
-                    latents_region_denoised = scheduler_output.prev_sample.numpy()
-
-                    if feather[0] > 0.0:
-                        mask = make_tile_mask(
-                            (h_end - h_start, w_end - w_start),
-                            (h_end - h_start, w_end - w_start),
-                            feather[0],
-                            feather[1],
-                        )
-                        mask = np.expand_dims(mask, axis=0)
-                        mask = np.repeat(mask, 4, axis=0)
-                        mask = np.expand_dims(mask, axis=0)
-                    else:
-                        mask = 1
-
-                    if weight >= 10.0:
-                        value[:, :, h_start:h_end, w_start:w_end] = (
-                            latents_region_denoised * mask
-                        )
-                        count[:, :, h_start:h_end, w_start:w_end] = mask
-                    else:
-                        value[:, :, h_start:h_end, w_start:w_end] += (
-                            latents_region_denoised * weight * mask
-                        )
-                        count[:, :, h_start:h_end, w_start:w_end] += weight * mask
+                if not last:
+                    for r, region in enumerate(regions):
+                        top, left, bottom, right, weight, feather, prompt = region
+                        logger.debug(
+                            "running region prompt: %s, %s, %s, %s, %s, %s, %s",
+                            top,
+                            left,
+                            bottom,
+                            right,
+                            weight,
+                            feather,
+                            prompt,
+                        )
+
+                        # convert coordinates to latent space
+                        h_start = top // LATENT_FACTOR
+                        h_end = bottom // LATENT_FACTOR
+                        w_start = left // LATENT_FACTOR
+                        w_end = right // LATENT_FACTOR
+
+                        # get the latents corresponding to the current view coordinates
+                        latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
+                        logger.trace(
+                            "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
+                            h_start,
+                            h_end,
+                            w_start,
+                            w_end,
+                            latents_for_region.shape,
+                        )
+
+                        # expand the latents if we are doing classifier free guidance
+                        latent_region_input = (
+                            np.concatenate([latents_for_region] * 2)
+                            if do_classifier_free_guidance
+                            else latents_for_region
+                        )
+                        latent_region_input = self.scheduler.scale_model_input(
+                            torch.from_numpy(latent_region_input), t
+                        )
+                        latent_region_input = latent_region_input.cpu().numpy()
+
+                        # predict the noise residual
+                        timestep = np.array([t], dtype=timestep_dtype)
+                        region_noise_pred = self.unet(
+                            sample=latent_region_input,
+                            timestep=timestep,
+                            encoder_hidden_states=region_embeds[r],
+                            text_embeds=add_region_embeds[r],
+                            time_ids=add_time_ids,
+                        )
+                        region_noise_pred = region_noise_pred[0]
+
+                        # perform guidance
+                        if do_classifier_free_guidance:
+                            region_noise_pred_uncond, region_noise_pred_text = np.split(
+                                region_noise_pred, 2
+                            )
+                            region_noise_pred = region_noise_pred_uncond + guidance_scale * (
+                                region_noise_pred_text - region_noise_pred_uncond
+                            )
+                            if guidance_rescale > 0.0:
+                                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                                region_noise_pred = rescale_noise_cfg(
+                                    region_noise_pred,
+                                    region_noise_pred_text,
+                                    guidance_rescale=guidance_rescale,
+                                )
+
+                        # compute the previous noisy sample x_t -> x_t-1
+                        scheduler_output = self.scheduler.step(
+                            torch.from_numpy(region_noise_pred),
+                            t,
+                            torch.from_numpy(latents_for_region),
+                            **extra_step_kwargs,
+                        )
+                        latents_region_denoised = scheduler_output.prev_sample.numpy()
+
+                        if feather[0] > 0.0:
+                            mask = make_tile_mask(
+                                (h_end - h_start, w_end - w_start),
+                                (h_end - h_start, w_end - w_start),
+                                feather[0],
+                                feather[1],
+                            )
+                            mask = np.expand_dims(mask, axis=0)
+                            mask = np.repeat(mask, 4, axis=0)
+                            mask = np.expand_dims(mask, axis=0)
+                        else:
+                            mask = 1
+
+                        if weight >= 10.0:
+                            value[:, :, h_start:h_end, w_start:w_end] = (
+                                latents_region_denoised * mask
+                            )
+                            count[:, :, h_start:h_end, w_start:w_end] = mask
+                        else:
+                            value[:, :, h_start:h_end, w_start:w_end] += (
+                                latents_region_denoised * weight * mask
+                            )
+                            count[:, :, h_start:h_end, w_start:w_end] += weight * mask
 
                 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
                 latents = np.where(count > 0, value / count, value)
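-- 

Beyond the diff itself, the control-flow change is small enough to show in isolation. Below is a minimal sketch of the denoising loop after this patch; `denoising_loop`, `denoise_views`, and `denoise_region` are hypothetical stand-ins for the pipelines' view blending and region refinement, not functions from onnx-web.

    # minimal sketch of the loop structure after this patch; the helpers
    # passed in are hypothetical stand-ins, not onnx-web functions
    def denoising_loop(timesteps, regions, denoise_views, denoise_region):
        for i, t in enumerate(timesteps):
            last = i == (len(timesteps) - 1)

            # panorama views are denoised and blended on every timestep
            denoise_views(t)

            # region prompts are skipped once the final timestep is reached
            if not last:
                for region in regions:
                    denoise_region(t, region)

    # example: three timesteps, one region; the region callback fires twice
    denoising_loop(
        [0.8, 0.5, 0.1],
        ["region prompt"],
        denoise_views=lambda t: print("views", t),
        denoise_region=lambda t, r: print("region", t, r),
    )

Because `latents` is recomputed from `value / count` immediately after the region loop, skipping the regions on the final timestep means the last update to `latents` comes entirely from the blended panorama views.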