1
0
Fork 0

feat(api): skip regions on last timestep

This commit is contained in:
Sean Sube 2023-11-11 22:43:41 -06:00
parent 4513fa3428
commit 09f600ab54
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 182 additions and 178 deletions

View File

@ -557,6 +557,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
value = np.zeros_like(latents)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
last = i == (len(self.scheduler.timesteps) - 1)
count.fill(0)
value.fill(0)
@ -603,97 +604,98 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)):
top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)

View File

@ -388,6 +388,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
for i, t in enumerate(self.progress_bar(timesteps)):
last = i == (len(timesteps) - 1)
count.fill(0)
value.fill(0)
@ -443,106 +444,107 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)):
top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r],
time_ids=add_time_ids,
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
if not last:
for r, region in enumerate(regions):
top, left, bottom, right, weight, feather, prompt = region
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
weight,
feather,
prompt,
)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
# convert coordinates to latent space
h_start = top // LATENT_FACTOR
h_end = bottom // LATENT_FACTOR
w_start = left // LATENT_FACTOR
w_end = right // LATENT_FACTOR
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace(
"region latent shape: [:,:,%s:%s,%s:%s] -> %s",
h_start,
h_end,
w_start,
w_end,
latents_for_region.shape,
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
region_noise_pred = rescale_noise_cfg(
region_noise_pred,
region_noise_pred_text,
guidance_rescale=guidance_rescale,
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
text_embeds=add_region_embeds[r],
time_ids=add_time_ids,
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
)
if guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
region_noise_pred = rescale_noise_cfg(
region_noise_pred,
region_noise_pred_text,
guidance_rescale=guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if feather[0] > 0.0:
mask = make_tile_mask(
(h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather[0],
feather[1],
)
mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = (
latents_region_denoised * mask
)
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)