1
0
Fork 0

attempt to freeze scheduler timestep for panorama

This commit is contained in:
Sean Sube 2024-01-14 21:29:18 -06:00
parent 482d9b246d
commit 374efca776
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 54 additions and 0 deletions

View File

@ -612,6 +612,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
noise_pred_text - noise_pred_uncond noise_pred_text - noise_pred_uncond
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), torch.from_numpy(noise_pred),
@ -621,6 +626,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
latents_view_denoised = scheduler_output.prev_sample.numpy() latents_view_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
@ -686,6 +695,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
* (region_noise_pred_text - region_noise_pred_uncond) * (region_noise_pred_text - region_noise_pred_uncond)
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred), torch.from_numpy(region_noise_pred),
@ -695,6 +709,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
if feather[0] > 0.0: if feather[0] > 0.0:
mask = make_tile_mask( mask = make_tile_mask(
(h_end - h_start, w_end - w_start), (h_end - h_start, w_end - w_start),
@ -1027,6 +1045,11 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
noise_pred_text - noise_pred_uncond noise_pred_text - noise_pred_uncond
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), torch.from_numpy(noise_pred),
@ -1036,6 +1059,10 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
latents_view_denoised = scheduler_output.prev_sample.numpy() latents_view_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1

View File

@ -454,6 +454,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), torch.from_numpy(noise_pred),
@ -463,6 +468,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_view_denoised = scheduler_output.prev_sample.numpy() latents_view_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
@ -537,6 +546,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred), torch.from_numpy(region_noise_pred),
@ -546,6 +560,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
if feather[0] > 0.0: if feather[0] > 0.0:
mask = make_tile_mask( mask = make_tile_mask(
(h_end - h_start, w_end - w_start), (h_end - h_start, w_end - w_start),
@ -878,6 +896,11 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
) )
# freeze the scheduler's internal timestep
prev_step_index = None
if hasattr(self.scheduler, "_step_index"):
prev_step_index = self.scheduler._step_index
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), torch.from_numpy(noise_pred),
@ -887,6 +910,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_view_denoised = scheduler_output.prev_sample.numpy() latents_view_denoised = scheduler_output.prev_sample.numpy()
# reset the scheduler's internal timestep
if prev_step_index is not None:
self.scheduler._step_index = prev_step_index
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1