1
0
Fork 0

use regional noise sample, avoid very errors from large multipliers

This commit is contained in:
Sean Sube 2023-11-05 22:48:07 -06:00
parent 997891b255
commit 408e3d725b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 11 additions and 11 deletions

View File

@ -465,32 +465,32 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = self.unet(
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r],
time_ids=add_time_ids,
)
noise_pred = noise_pred[0]
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
region_noise_pred_uncond, region_noise_pred_text = np.split(region_noise_pred, 2)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
region_noise_pred = rescale_noise_cfg(
region_noise_pred,
region_noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred),
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
@ -498,8 +498,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 1000.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] = mult
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
count[:, :, h_start:h_end, w_start:w_end] = 1
else:
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
count[:, :, h_start:h_end, w_start:w_end] += mult