From 374efca7765a75139154bbdb45fcdce4a8da5380 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 14 Jan 2024 21:29:18 -0600 Subject: [PATCH] attempt to freeze scheduler timestep for panorama --- api/onnx_web/diffusers/pipelines/panorama.py | 27 +++++++++++++++++++ .../diffusers/pipelines/panorama_xl.py | 27 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 317a8515..c3c518fe 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -612,6 +612,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): noise_pred_text - noise_pred_uncond ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), @@ -621,6 +626,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): ) latents_view_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 @@ -686,6 +695,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): * (region_noise_pred_text - region_noise_pred_uncond) ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(region_noise_pred), @@ -695,6 +709,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): ) latents_region_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + if feather[0] > 0.0: mask = make_tile_mask( (h_end - h_start, w_end - w_start), @@ -1027,6 +1045,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): noise_pred_text - noise_pred_uncond ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), @@ -1036,6 +1059,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): ) latents_view_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index d54877ee..70f3be28 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -454,6 +454,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix guidance_rescale=guidance_rescale, ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), @@ -463,6 +468,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) latents_view_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1 @@ -537,6 +546,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix guidance_rescale=guidance_rescale, ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(region_noise_pred), @@ -546,6 +560,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) latents_region_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + if feather[0] > 0.0: mask = make_tile_mask( (h_end - h_start, w_end - w_start), @@ -878,6 +896,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix guidance_rescale=guidance_rescale, ) + # freeze the scheduler's internal timestep + prev_step_index = None + if hasattr(self.scheduler, "_step_index"): + prev_step_index = self.scheduler._step_index + # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( torch.from_numpy(noise_pred), @@ -887,6 +910,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) latents_view_denoised = scheduler_output.prev_sample.numpy() + # reset the scheduler's internal timestep + if prev_step_index is not None: + self.scheduler._step_index = prev_step_index + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised count[:, :, h_start:h_end, w_start:w_end] += 1