flip and fix h/v coords for regions
This commit is contained in:
parent
911f87f7ec
commit
05f63a32b7
|
@ -450,29 +450,30 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
|
logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mult, prompt)
|
||||||
|
|
||||||
# convert coordinates to latent space
|
# convert coordinates to latent space
|
||||||
h_start = top // 8
|
h_start = left // 8
|
||||||
h_end = bottom // 8
|
h_end = right // 8
|
||||||
w_start = left // 8
|
w_start = top // 8
|
||||||
w_end = right // 8
|
w_end = bottom // 8
|
||||||
|
|
||||||
# get the latents corresponding to the current view coordinates
|
# get the latents corresponding to the current region coordinates
|
||||||
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
logger.trace("region latent shape: %s", latents_for_region.shape)
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = (
|
latent_region_input = (
|
||||||
np.concatenate([latents_for_view] * 2)
|
np.concatenate([latents_for_region] * 2)
|
||||||
if do_classifier_free_guidance
|
if do_classifier_free_guidance
|
||||||
else latents_for_view
|
else latents_for_region
|
||||||
)
|
)
|
||||||
latent_model_input = self.scheduler.scale_model_input(
|
latent_region_input = self.scheduler.scale_model_input(
|
||||||
torch.from_numpy(latent_model_input), t
|
torch.from_numpy(latent_region_input), t
|
||||||
)
|
)
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
latent_region_input = latent_region_input.cpu().numpy()
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
sample=latent_model_input,
|
sample=latent_region_input,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_hidden_states=region_embeds[i],
|
encoder_hidden_states=region_embeds[i],
|
||||||
text_embeds=add_region_embeds[i],
|
text_embeds=add_region_embeds[i],
|
||||||
|
@ -498,12 +499,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
scheduler_output = self.scheduler.step(
|
scheduler_output = self.scheduler.step(
|
||||||
torch.from_numpy(noise_pred),
|
torch.from_numpy(noise_pred),
|
||||||
t,
|
t,
|
||||||
torch.from_numpy(latents_for_view),
|
torch.from_numpy(latents_for_region),
|
||||||
**extra_step_kwargs,
|
**extra_step_kwargs,
|
||||||
)
|
)
|
||||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised * mult
|
value[:, :, h_start:h_end, w_start:w_end] += latents_region_denoised * mult
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
count[:, :, h_start:h_end, w_start:w_end] += mult
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
|
|
Loading…
Reference in New Issue