use regional noise sample, avoid very errors from large multipliers
This commit is contained in:
parent
997891b255
commit
408e3d725b
|
@ -465,32 +465,32 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = self.unet(
|
||||
region_noise_pred = self.unet(
|
||||
sample=latent_region_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=region_embeds[r],
|
||||
text_embeds=add_region_embeds[r],
|
||||
time_ids=add_time_ids,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
region_noise_pred = region_noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
region_noise_pred_uncond, region_noise_pred_text = np.split(region_noise_pred, 2)
|
||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||
region_noise_pred_text - region_noise_pred_uncond
|
||||
)
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(
|
||||
noise_pred,
|
||||
noise_pred_text,
|
||||
region_noise_pred = rescale_noise_cfg(
|
||||
region_noise_pred,
|
||||
region_noise_pred_text,
|
||||
guidance_rescale=guidance_rescale,
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred),
|
||||
torch.from_numpy(region_noise_pred),
|
||||
t,
|
||||
torch.from_numpy(latents_for_region),
|
||||
**extra_step_kwargs,
|
||||
|
@ -498,8 +498,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mult >= 1000.0:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mult
|
||||
count[:, :, h_start:h_end, w_start:w_end] = mult
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
|
||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
||||
|
|
Loading…
Reference in New Issue