feat(api): skip regions on last timestep
This commit is contained in:
parent
4513fa3428
commit
09f600ab54
|
@ -557,6 +557,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
value = np.zeros_like(latents)
|
value = np.zeros_like(latents)
|
||||||
|
|
||||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||||
|
last = i == (len(self.scheduler.timesteps) - 1)
|
||||||
count.fill(0)
|
count.fill(0)
|
||||||
value.fill(0)
|
value.fill(0)
|
||||||
|
|
||||||
|
@ -603,97 +604,98 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
for r in range(len(regions)):
|
if not last:
|
||||||
top, left, bottom, right, weight, feather, prompt = regions[r]
|
for r, region in enumerate(regions):
|
||||||
logger.debug(
|
top, left, bottom, right, weight, feather, prompt = region
|
||||||
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
logger.debug(
|
||||||
top,
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
left,
|
top,
|
||||||
bottom,
|
left,
|
||||||
right,
|
bottom,
|
||||||
weight,
|
right,
|
||||||
feather,
|
weight,
|
||||||
prompt,
|
feather,
|
||||||
)
|
prompt,
|
||||||
|
|
||||||
# convert coordinates to latent space
|
|
||||||
h_start = top // LATENT_FACTOR
|
|
||||||
h_end = bottom // LATENT_FACTOR
|
|
||||||
w_start = left // LATENT_FACTOR
|
|
||||||
w_end = right // LATENT_FACTOR
|
|
||||||
|
|
||||||
# get the latents corresponding to the current view coordinates
|
|
||||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
|
||||||
logger.trace(
|
|
||||||
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
|
||||||
h_start,
|
|
||||||
h_end,
|
|
||||||
w_start,
|
|
||||||
w_end,
|
|
||||||
latents_for_region.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_region_input = (
|
|
||||||
np.concatenate([latents_for_region] * 2)
|
|
||||||
if do_classifier_free_guidance
|
|
||||||
else latents_for_region
|
|
||||||
)
|
|
||||||
latent_region_input = self.scheduler.scale_model_input(
|
|
||||||
torch.from_numpy(latent_region_input), t
|
|
||||||
)
|
|
||||||
latent_region_input = latent_region_input.cpu().numpy()
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
|
||||||
region_noise_pred = self.unet(
|
|
||||||
sample=latent_region_input,
|
|
||||||
timestep=timestep,
|
|
||||||
encoder_hidden_states=region_embeds[r],
|
|
||||||
)
|
|
||||||
region_noise_pred = region_noise_pred[0]
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
|
||||||
region_noise_pred, 2
|
|
||||||
)
|
|
||||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
|
||||||
region_noise_pred_text - region_noise_pred_uncond
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# convert coordinates to latent space
|
||||||
scheduler_output = self.scheduler.step(
|
h_start = top // LATENT_FACTOR
|
||||||
torch.from_numpy(region_noise_pred),
|
h_end = bottom // LATENT_FACTOR
|
||||||
t,
|
w_start = left // LATENT_FACTOR
|
||||||
torch.from_numpy(latents_for_region),
|
w_end = right // LATENT_FACTOR
|
||||||
**extra_step_kwargs,
|
|
||||||
)
|
|
||||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
|
||||||
|
|
||||||
if feather[0] > 0.0:
|
# get the latents corresponding to the current view coordinates
|
||||||
mask = make_tile_mask(
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
(h_end - h_start, w_end - w_start),
|
logger.trace(
|
||||||
(h_end - h_start, w_end - w_start),
|
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
||||||
feather[0],
|
h_start,
|
||||||
feather[1],
|
h_end,
|
||||||
|
w_start,
|
||||||
|
w_end,
|
||||||
|
latents_for_region.shape,
|
||||||
)
|
)
|
||||||
mask = np.expand_dims(mask, axis=0)
|
|
||||||
mask = np.repeat(mask, 4, axis=0)
|
|
||||||
mask = np.expand_dims(mask, axis=0)
|
|
||||||
else:
|
|
||||||
mask = 1
|
|
||||||
|
|
||||||
if weight >= 10.0:
|
# expand the latents if we are doing classifier free guidance
|
||||||
value[:, :, h_start:h_end, w_start:w_end] = (
|
latent_region_input = (
|
||||||
latents_region_denoised * mask
|
np.concatenate([latents_for_region] * 2)
|
||||||
|
if do_classifier_free_guidance
|
||||||
|
else latents_for_region
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] = mask
|
latent_region_input = self.scheduler.scale_model_input(
|
||||||
else:
|
torch.from_numpy(latent_region_input), t
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
|
||||||
latents_region_denoised * weight * mask
|
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
latent_region_input = latent_region_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
region_noise_pred = self.unet(
|
||||||
|
sample=latent_region_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=region_embeds[r],
|
||||||
|
)
|
||||||
|
region_noise_pred = region_noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||||
|
region_noise_pred, 2
|
||||||
|
)
|
||||||
|
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||||
|
region_noise_pred_text - region_noise_pred_uncond
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
scheduler_output = self.scheduler.step(
|
||||||
|
torch.from_numpy(region_noise_pred),
|
||||||
|
t,
|
||||||
|
torch.from_numpy(latents_for_region),
|
||||||
|
**extra_step_kwargs,
|
||||||
|
)
|
||||||
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
|
if feather[0] > 0.0:
|
||||||
|
mask = make_tile_mask(
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
(h_end - h_start, w_end - w_start),
|
||||||
|
feather[0],
|
||||||
|
feather[1],
|
||||||
|
)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
else:
|
||||||
|
mask = 1
|
||||||
|
|
||||||
|
if weight >= 10.0:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] = (
|
||||||
|
latents_region_denoised * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
|
else:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
|
latents_region_denoised * weight * mask
|
||||||
|
)
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
|
@ -388,6 +388,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
# 8. Denoising loop
|
# 8. Denoising loop
|
||||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
|
last = i == (len(timesteps) - 1)
|
||||||
count.fill(0)
|
count.fill(0)
|
||||||
value.fill(0)
|
value.fill(0)
|
||||||
|
|
||||||
|
@ -443,106 +444,107 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
for r in range(len(regions)):
|
if not last:
|
||||||
top, left, bottom, right, weight, feather, prompt = regions[r]
|
for r, region in enumerate(regions):
|
||||||
logger.debug(
|
top, left, bottom, right, weight, feather, prompt = region
|
||||||
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
logger.debug(
|
||||||
top,
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
left,
|
top,
|
||||||
bottom,
|
left,
|
||||||
right,
|
bottom,
|
||||||
weight,
|
right,
|
||||||
feather,
|
weight,
|
||||||
prompt,
|
feather,
|
||||||
)
|
prompt,
|
||||||
|
|
||||||
# convert coordinates to latent space
|
|
||||||
h_start = top // LATENT_FACTOR
|
|
||||||
h_end = bottom // LATENT_FACTOR
|
|
||||||
w_start = left // LATENT_FACTOR
|
|
||||||
w_end = right // LATENT_FACTOR
|
|
||||||
|
|
||||||
# get the latents corresponding to the current view coordinates
|
|
||||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
|
||||||
logger.trace(
|
|
||||||
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
|
||||||
h_start,
|
|
||||||
h_end,
|
|
||||||
w_start,
|
|
||||||
w_end,
|
|
||||||
latents_for_region.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_region_input = (
|
|
||||||
np.concatenate([latents_for_region] * 2)
|
|
||||||
if do_classifier_free_guidance
|
|
||||||
else latents_for_region
|
|
||||||
)
|
|
||||||
latent_region_input = self.scheduler.scale_model_input(
|
|
||||||
torch.from_numpy(latent_region_input), t
|
|
||||||
)
|
|
||||||
latent_region_input = latent_region_input.cpu().numpy()
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
|
||||||
region_noise_pred = self.unet(
|
|
||||||
sample=latent_region_input,
|
|
||||||
timestep=timestep,
|
|
||||||
encoder_hidden_states=region_embeds[r],
|
|
||||||
text_embeds=add_region_embeds[r],
|
|
||||||
time_ids=add_time_ids,
|
|
||||||
)
|
|
||||||
region_noise_pred = region_noise_pred[0]
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
|
||||||
region_noise_pred, 2
|
|
||||||
)
|
)
|
||||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
|
||||||
region_noise_pred_text - region_noise_pred_uncond
|
# convert coordinates to latent space
|
||||||
|
h_start = top // LATENT_FACTOR
|
||||||
|
h_end = bottom // LATENT_FACTOR
|
||||||
|
w_start = left // LATENT_FACTOR
|
||||||
|
w_end = right // LATENT_FACTOR
|
||||||
|
|
||||||
|
# get the latents corresponding to the current view coordinates
|
||||||
|
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||||
|
logger.trace(
|
||||||
|
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
|
||||||
|
h_start,
|
||||||
|
h_end,
|
||||||
|
w_start,
|
||||||
|
w_end,
|
||||||
|
latents_for_region.shape,
|
||||||
)
|
)
|
||||||
if guidance_rescale > 0.0:
|
|
||||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# expand the latents if we are doing classifier free guidance
|
||||||
region_noise_pred = rescale_noise_cfg(
|
latent_region_input = (
|
||||||
region_noise_pred,
|
np.concatenate([latents_for_region] * 2)
|
||||||
region_noise_pred_text,
|
if do_classifier_free_guidance
|
||||||
guidance_rescale=guidance_rescale,
|
else latents_for_region
|
||||||
|
)
|
||||||
|
latent_region_input = self.scheduler.scale_model_input(
|
||||||
|
torch.from_numpy(latent_region_input), t
|
||||||
|
)
|
||||||
|
latent_region_input = latent_region_input.cpu().numpy()
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
timestep = np.array([t], dtype=timestep_dtype)
|
||||||
|
region_noise_pred = self.unet(
|
||||||
|
sample=latent_region_input,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_hidden_states=region_embeds[r],
|
||||||
|
text_embeds=add_region_embeds[r],
|
||||||
|
time_ids=add_time_ids,
|
||||||
|
)
|
||||||
|
region_noise_pred = region_noise_pred[0]
|
||||||
|
|
||||||
|
# perform guidance
|
||||||
|
if do_classifier_free_guidance:
|
||||||
|
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||||
|
region_noise_pred, 2
|
||||||
)
|
)
|
||||||
|
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||||
|
region_noise_pred_text - region_noise_pred_uncond
|
||||||
|
)
|
||||||
|
if guidance_rescale > 0.0:
|
||||||
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
|
region_noise_pred = rescale_noise_cfg(
|
||||||
|
region_noise_pred,
|
||||||
|
region_noise_pred_text,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
scheduler_output = self.scheduler.step(
|
scheduler_output = self.scheduler.step(
|
||||||
torch.from_numpy(region_noise_pred),
|
torch.from_numpy(region_noise_pred),
|
||||||
t,
|
t,
|
||||||
torch.from_numpy(latents_for_region),
|
torch.from_numpy(latents_for_region),
|
||||||
**extra_step_kwargs,
|
**extra_step_kwargs,
|
||||||
)
|
)
|
||||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
if feather[0] > 0.0:
|
if feather[0] > 0.0:
|
||||||
mask = make_tile_mask(
|
mask = make_tile_mask(
|
||||||
(h_end - h_start, w_end - w_start),
|
(h_end - h_start, w_end - w_start),
|
||||||
(h_end - h_start, w_end - w_start),
|
(h_end - h_start, w_end - w_start),
|
||||||
feather[0],
|
feather[0],
|
||||||
feather[1],
|
feather[1],
|
||||||
)
|
)
|
||||||
mask = np.expand_dims(mask, axis=0)
|
mask = np.expand_dims(mask, axis=0)
|
||||||
mask = np.repeat(mask, 4, axis=0)
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
mask = np.expand_dims(mask, axis=0)
|
mask = np.expand_dims(mask, axis=0)
|
||||||
else:
|
else:
|
||||||
mask = 1
|
mask = 1
|
||||||
|
|
||||||
if weight >= 10.0:
|
if weight >= 10.0:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] = (
|
value[:, :, h_start:h_end, w_start:w_end] = (
|
||||||
latents_region_denoised * mask
|
latents_region_denoised * mask
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] = mask
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
else:
|
else:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
latents_region_denoised * weight * mask
|
latents_region_denoised * weight * mask
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
Loading…
Reference in New Issue