
feat(api): skip regions on last timestep

Sean Sube 2023-11-11 22:43:41 -06:00
parent 4513fa3428
commit 09f600ab54
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 182 additions and 178 deletions
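Both panorama pipelines receive the same two-line change: the denoising loop now computes a `last` flag, and the per-region prompt pass is skipped on the final timestep, so the last scheduler step ends with a plain MultiDiffusion blend of the panorama views. A minimal sketch of the resulting control flow, using toy shapes and a stubbed denoising step rather than the pipelines' real tensors:

    import numpy as np

    # toy stand-ins for the scheduler timesteps and latent buffers in the diffs below
    timesteps = list(range(5))
    latents = np.zeros((1, 4, 8, 8), dtype=np.float32)
    count = np.zeros_like(latents)
    value = np.zeros_like(latents)

    for i, t in enumerate(timesteps):
        last = i == (len(timesteps) - 1)
        count.fill(0)
        value.fill(0)

        # ... per-view denoising accumulates into value/count here ...
        value += 1.0
        count += 1.0

        if not last:
            # region prompts only refine intermediate steps; skipping them on
            # the final timestep leaves the blended panorama latents untouched
            pass

        # take the MultiDiffusion step: average the overlapping contributions
        latents = np.where(count > 0, value / count, value)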


@@ -557,6 +557,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         value = np.zeros_like(latents)
 
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+            last = i == (len(self.scheduler.timesteps) - 1)
             count.fill(0)
             value.fill(0)
@@ -603,97 +604,98 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                 count[:, :, h_start:h_end, w_start:w_end] += 1
 
-            for r in range(len(regions)):
-                top, left, bottom, right, weight, feather, prompt = regions[r]
-                logger.debug(
-                    "running region prompt: %s, %s, %s, %s, %s, %s, %s",
-                    top,
-                    left,
-                    bottom,
-                    right,
-                    weight,
-                    feather,
-                    prompt,
-                )
-
-                # convert coordinates to latent space
-                h_start = top // LATENT_FACTOR
-                h_end = bottom // LATENT_FACTOR
-                w_start = left // LATENT_FACTOR
-                w_end = right // LATENT_FACTOR
-
-                # get the latents corresponding to the current view coordinates
-                latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
-                logger.trace(
-                    "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
-                    h_start,
-                    h_end,
-                    w_start,
-                    w_end,
-                    latents_for_region.shape,
-                )
-
-                # expand the latents if we are doing classifier free guidance
-                latent_region_input = (
-                    np.concatenate([latents_for_region] * 2)
-                    if do_classifier_free_guidance
-                    else latents_for_region
-                )
-                latent_region_input = self.scheduler.scale_model_input(
-                    torch.from_numpy(latent_region_input), t
-                )
-                latent_region_input = latent_region_input.cpu().numpy()
-
-                # predict the noise residual
-                timestep = np.array([t], dtype=timestep_dtype)
-                region_noise_pred = self.unet(
-                    sample=latent_region_input,
-                    timestep=timestep,
-                    encoder_hidden_states=region_embeds[r],
-                )
-                region_noise_pred = region_noise_pred[0]
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    region_noise_pred_uncond, region_noise_pred_text = np.split(
-                        region_noise_pred, 2
-                    )
-                    region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                        region_noise_pred_text - region_noise_pred_uncond
-                    )
-
-                # compute the previous noisy sample x_t -> x_t-1
-                scheduler_output = self.scheduler.step(
-                    torch.from_numpy(region_noise_pred),
-                    t,
-                    torch.from_numpy(latents_for_region),
-                    **extra_step_kwargs,
-                )
-                latents_region_denoised = scheduler_output.prev_sample.numpy()
-
-                if feather[0] > 0.0:
-                    mask = make_tile_mask(
-                        (h_end - h_start, w_end - w_start),
-                        (h_end - h_start, w_end - w_start),
-                        feather[0],
-                        feather[1],
-                    )
-                    mask = np.expand_dims(mask, axis=0)
-                    mask = np.repeat(mask, 4, axis=0)
-                    mask = np.expand_dims(mask, axis=0)
-                else:
-                    mask = 1
-
-                if weight >= 10.0:
-                    value[:, :, h_start:h_end, w_start:w_end] = (
-                        latents_region_denoised * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] = mask
-                else:
-                    value[:, :, h_start:h_end, w_start:w_end] += (
-                        latents_region_denoised * weight * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] += weight * mask
+            if not last:
+                for r, region in enumerate(regions):
+                    top, left, bottom, right, weight, feather, prompt = region
+                    logger.debug(
+                        "running region prompt: %s, %s, %s, %s, %s, %s, %s",
+                        top,
+                        left,
+                        bottom,
+                        right,
+                        weight,
+                        feather,
+                        prompt,
+                    )
+
+                    # convert coordinates to latent space
+                    h_start = top // LATENT_FACTOR
+                    h_end = bottom // LATENT_FACTOR
+                    w_start = left // LATENT_FACTOR
+                    w_end = right // LATENT_FACTOR
+
+                    # get the latents corresponding to the current view coordinates
+                    latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
+                    logger.trace(
+                        "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
+                        h_start,
+                        h_end,
+                        w_start,
+                        w_end,
+                        latents_for_region.shape,
+                    )
+
+                    # expand the latents if we are doing classifier free guidance
+                    latent_region_input = (
+                        np.concatenate([latents_for_region] * 2)
+                        if do_classifier_free_guidance
+                        else latents_for_region
+                    )
+                    latent_region_input = self.scheduler.scale_model_input(
+                        torch.from_numpy(latent_region_input), t
+                    )
+                    latent_region_input = latent_region_input.cpu().numpy()
+
+                    # predict the noise residual
+                    timestep = np.array([t], dtype=timestep_dtype)
+                    region_noise_pred = self.unet(
+                        sample=latent_region_input,
+                        timestep=timestep,
+                        encoder_hidden_states=region_embeds[r],
+                    )
+                    region_noise_pred = region_noise_pred[0]
+
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        region_noise_pred_uncond, region_noise_pred_text = np.split(
+                            region_noise_pred, 2
+                        )
+                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
+                            region_noise_pred_text - region_noise_pred_uncond
+                        )
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    scheduler_output = self.scheduler.step(
+                        torch.from_numpy(region_noise_pred),
+                        t,
+                        torch.from_numpy(latents_for_region),
+                        **extra_step_kwargs,
+                    )
+                    latents_region_denoised = scheduler_output.prev_sample.numpy()
+
+                    if feather[0] > 0.0:
+                        mask = make_tile_mask(
+                            (h_end - h_start, w_end - w_start),
+                            (h_end - h_start, w_end - w_start),
+                            feather[0],
+                            feather[1],
+                        )
+                        mask = np.expand_dims(mask, axis=0)
+                        mask = np.repeat(mask, 4, axis=0)
+                        mask = np.expand_dims(mask, axis=0)
+                    else:
+                        mask = 1
+
+                    if weight >= 10.0:
+                        value[:, :, h_start:h_end, w_start:w_end] = (
+                            latents_region_denoised * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] = mask
+                    else:
+                        value[:, :, h_start:h_end, w_start:w_end] += (
+                            latents_region_denoised * weight * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] += weight * mask
 
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
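Both hunks end each timestep with the same MultiDiffusion blend (Eq. 5 of https://arxiv.org/abs/2302.08113): every view and region adds its weighted, optionally feathered latents into `value` and the matching weights into `count`, and the division recovers a weighted average wherever anything contributed. A toy illustration of that accumulation, with shapes reduced for readability:

    import numpy as np

    value = np.zeros((1, 4, 4, 4), dtype=np.float32)
    count = np.zeros_like(value)

    # two overlapping "views" contribute denoised latents with unit weight
    a = np.full((1, 4, 4, 2), 2.0, dtype=np.float32)
    b = np.full((1, 4, 4, 2), 4.0, dtype=np.float32)
    value[:, :, :, 0:2] += a
    count[:, :, :, 0:2] += 1
    value[:, :, :, 1:3] += b
    count[:, :, :, 1:3] += 1

    # Eq. 5: the overlapping column averages to 3.0; untouched columns stay 0
    latents = np.where(count > 0, value / count, value)
    print(latents[0, 0, 0])  # [2. 3. 4. 0.]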


@@ -388,6 +388,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         for i, t in enumerate(self.progress_bar(timesteps)):
+            last = i == (len(timesteps) - 1)
             count.fill(0)
             value.fill(0)
@@ -443,106 +444,107 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                 value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                 count[:, :, h_start:h_end, w_start:w_end] += 1
 
-            for r in range(len(regions)):
-                top, left, bottom, right, weight, feather, prompt = regions[r]
-                logger.debug(
-                    "running region prompt: %s, %s, %s, %s, %s, %s, %s",
-                    top,
-                    left,
-                    bottom,
-                    right,
-                    weight,
-                    feather,
-                    prompt,
-                )
-
-                # convert coordinates to latent space
-                h_start = top // LATENT_FACTOR
-                h_end = bottom // LATENT_FACTOR
-                w_start = left // LATENT_FACTOR
-                w_end = right // LATENT_FACTOR
-
-                # get the latents corresponding to the current view coordinates
-                latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
-                logger.trace(
-                    "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
-                    h_start,
-                    h_end,
-                    w_start,
-                    w_end,
-                    latents_for_region.shape,
-                )
-
-                # expand the latents if we are doing classifier free guidance
-                latent_region_input = (
-                    np.concatenate([latents_for_region] * 2)
-                    if do_classifier_free_guidance
-                    else latents_for_region
-                )
-                latent_region_input = self.scheduler.scale_model_input(
-                    torch.from_numpy(latent_region_input), t
-                )
-                latent_region_input = latent_region_input.cpu().numpy()
-
-                # predict the noise residual
-                timestep = np.array([t], dtype=timestep_dtype)
-                region_noise_pred = self.unet(
-                    sample=latent_region_input,
-                    timestep=timestep,
-                    encoder_hidden_states=region_embeds[r],
-                    text_embeds=add_region_embeds[r],
-                    time_ids=add_time_ids,
-                )
-                region_noise_pred = region_noise_pred[0]
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    region_noise_pred_uncond, region_noise_pred_text = np.split(
-                        region_noise_pred, 2
-                    )
-                    region_noise_pred = region_noise_pred_uncond + guidance_scale * (
-                        region_noise_pred_text - region_noise_pred_uncond
-                    )
-
-                    if guidance_rescale > 0.0:
-                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                        region_noise_pred = rescale_noise_cfg(
-                            region_noise_pred,
-                            region_noise_pred_text,
-                            guidance_rescale=guidance_rescale,
-                        )
-
-                # compute the previous noisy sample x_t -> x_t-1
-                scheduler_output = self.scheduler.step(
-                    torch.from_numpy(region_noise_pred),
-                    t,
-                    torch.from_numpy(latents_for_region),
-                    **extra_step_kwargs,
-                )
-                latents_region_denoised = scheduler_output.prev_sample.numpy()
-
-                if feather[0] > 0.0:
-                    mask = make_tile_mask(
-                        (h_end - h_start, w_end - w_start),
-                        (h_end - h_start, w_end - w_start),
-                        feather[0],
-                        feather[1],
-                    )
-                    mask = np.expand_dims(mask, axis=0)
-                    mask = np.repeat(mask, 4, axis=0)
-                    mask = np.expand_dims(mask, axis=0)
-                else:
-                    mask = 1
-
-                if weight >= 10.0:
-                    value[:, :, h_start:h_end, w_start:w_end] = (
-                        latents_region_denoised * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] = mask
-                else:
-                    value[:, :, h_start:h_end, w_start:w_end] += (
-                        latents_region_denoised * weight * mask
-                    )
-                    count[:, :, h_start:h_end, w_start:w_end] += weight * mask
+            if not last:
+                for r, region in enumerate(regions):
+                    top, left, bottom, right, weight, feather, prompt = region
+                    logger.debug(
+                        "running region prompt: %s, %s, %s, %s, %s, %s, %s",
+                        top,
+                        left,
+                        bottom,
+                        right,
+                        weight,
+                        feather,
+                        prompt,
+                    )
+
+                    # convert coordinates to latent space
+                    h_start = top // LATENT_FACTOR
+                    h_end = bottom // LATENT_FACTOR
+                    w_start = left // LATENT_FACTOR
+                    w_end = right // LATENT_FACTOR
+
+                    # get the latents corresponding to the current view coordinates
+                    latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
+                    logger.trace(
+                        "region latent shape: [:,:,%s:%s,%s:%s] -> %s",
+                        h_start,
+                        h_end,
+                        w_start,
+                        w_end,
+                        latents_for_region.shape,
+                    )
+
+                    # expand the latents if we are doing classifier free guidance
+                    latent_region_input = (
+                        np.concatenate([latents_for_region] * 2)
+                        if do_classifier_free_guidance
+                        else latents_for_region
+                    )
+                    latent_region_input = self.scheduler.scale_model_input(
+                        torch.from_numpy(latent_region_input), t
+                    )
+                    latent_region_input = latent_region_input.cpu().numpy()
+
+                    # predict the noise residual
+                    timestep = np.array([t], dtype=timestep_dtype)
+                    region_noise_pred = self.unet(
+                        sample=latent_region_input,
+                        timestep=timestep,
+                        encoder_hidden_states=region_embeds[r],
+                        text_embeds=add_region_embeds[r],
+                        time_ids=add_time_ids,
+                    )
+                    region_noise_pred = region_noise_pred[0]
+
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        region_noise_pred_uncond, region_noise_pred_text = np.split(
+                            region_noise_pred, 2
+                        )
+                        region_noise_pred = region_noise_pred_uncond + guidance_scale * (
+                            region_noise_pred_text - region_noise_pred_uncond
+                        )
+
+                        if guidance_rescale > 0.0:
+                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                            region_noise_pred = rescale_noise_cfg(
+                                region_noise_pred,
+                                region_noise_pred_text,
+                                guidance_rescale=guidance_rescale,
+                            )
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    scheduler_output = self.scheduler.step(
+                        torch.from_numpy(region_noise_pred),
+                        t,
+                        torch.from_numpy(latents_for_region),
+                        **extra_step_kwargs,
+                    )
+                    latents_region_denoised = scheduler_output.prev_sample.numpy()
+
+                    if feather[0] > 0.0:
+                        mask = make_tile_mask(
+                            (h_end - h_start, w_end - w_start),
+                            (h_end - h_start, w_end - w_start),
+                            feather[0],
+                            feather[1],
+                        )
+                        mask = np.expand_dims(mask, axis=0)
+                        mask = np.repeat(mask, 4, axis=0)
+                        mask = np.expand_dims(mask, axis=0)
+                    else:
+                        mask = 1
+
+                    if weight >= 10.0:
+                        value[:, :, h_start:h_end, w_start:w_end] = (
+                            latents_region_denoised * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] = mask
+                    else:
+                        value[:, :, h_start:h_end, w_start:w_end] += (
+                            latents_region_denoised * weight * mask
+                        )
+                        count[:, :, h_start:h_end, w_start:w_end] += weight * mask
 
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
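The region write-back has two modes in both files: a `weight` of 10.0 or more replaces the blended latents for that rectangle outright, while smaller weights accumulate into the running average like any other view. The feather mask from `make_tile_mask` ramps the region's edges toward zero so it blends into the surrounding panorama; its real signature and falloff live elsewhere in this codebase, so the linear ramp below is only an assumed stand-in that produces the same `(1, 4, H, W)` layout the diffs expand to:

    import numpy as np

    def edge_ramp(n: int, feather_px: int) -> np.ndarray:
        # distance (in pixels, 1-based) from each index to the nearest edge
        dist = np.minimum(np.arange(n) + 1, n - np.arange(n))
        return np.minimum(dist / max(feather_px, 1), 1.0)

    def linear_feather_mask(height: int, width: int, feather_px: int) -> np.ndarray:
        # toy 2D mask that ramps 0 -> 1 over feather_px pixels at every edge
        return np.outer(
            edge_ramp(height, feather_px), edge_ramp(width, feather_px)
        ).astype(np.float32)

    # broadcast to the latent layout the diffs expand to: (1, 4, H, W)
    mask = linear_feather_mask(8, 8, feather_px=2)
    mask = np.expand_dims(mask, axis=0)
    mask = np.repeat(mask, 4, axis=0)
    mask = np.expand_dims(mask, axis=0)
    print(mask.shape)  # (1, 4, 8, 8)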