1
0
Fork 0

use regional noise sample, avoid errors from very large multipliers

This commit is contained in:
Sean Sube 2023-11-05 22:48:07 -06:00
parent 997891b255
commit 408e3d725b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 11 additions and 11 deletions

View File

@ -465,32 +465,32 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# predict the noise residual # predict the noise residual
timestep = np.array([t], dtype=timestep_dtype) timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet( region_noise_pred = self.unet(
sample=latent_region_input, sample=latent_region_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=region_embeds[r], encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r], text_embeds=add_region_embeds[r],
time_ids=add_time_ids, time_ids=add_time_ids,
) )
noise_pred = noise_pred[0] region_noise_pred = region_noise_pred[0]
# perform guidance # perform guidance
if do_classifier_free_guidance: if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) region_noise_pred_uncond, region_noise_pred_text = np.split(region_noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * ( region_noise_pred = region_noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond region_noise_pred_text - region_noise_pred_uncond
) )
if guidance_rescale > 0.0: if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg( region_noise_pred = rescale_noise_cfg(
noise_pred, region_noise_pred,
noise_pred_text, region_noise_pred_text,
guidance_rescale=guidance_rescale, guidance_rescale=guidance_rescale,
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step( scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), torch.from_numpy(region_noise_pred),
t, t,
torch.from_numpy(latents_for_region), torch.from_numpy(latents_for_region),
**extra_step_kwargs, **extra_step_kwargs,
@ -498,8 +498,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 1000.0: if mult >= 1000.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
count[:, :, h_start:h_end, w_start:w_end] = mult count[:, :, h_start:h_end, w_start:w_end] = 1
else: else:
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] += mult count[:, :, h_start:h_end, w_start:w_end] += mult