
feat(api): add experimental region prompts to SDXL panorama

Sean Sube 2023-11-05 15:36:31 -06:00
parent 59e1a1a4c2
commit 5cf7a39be0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 105 additions and 0 deletions


@@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
)
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg

from ..utils import parse_regions

logger = logging.getLogger(__name__)
@@ -299,6 +301,33 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        )
        # 3.b. Encode region prompts
        regions = parse_regions(prompt)

        # each entry holds [prompt_embeds, negative_prompt_embeds,
        # pooled_prompt_embeds, negative_pooled_prompt_embeds] for one region
        region_embeds: List[List[Optional[np.ndarray]]] = []
        add_region_embeds: List[np.ndarray] = []

        for _top, _left, _bottom, _right, _mode, region_prompt in regions:
            # encode the region prompt from scratch; passing the base embeds
            # through would make _encode_prompt skip encoding entirely, and the
            # result must be a list because the CFG branch mutates element 0
            current_region_embeds = list(
                self._encode_prompt(
                    region_prompt,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    negative_prompt,
                )
            )

            if do_classifier_free_guidance:
                # batch the negative and region embeds together, negative first
                current_region_embeds[0] = np.concatenate(
                    (current_region_embeds[1], current_region_embeds[0]), axis=0
                )
                add_region_embeds.append(np.concatenate(
                    (current_region_embeds[3], current_region_embeds[2]), axis=0
                ))
            else:
                # keep indices aligned with region_embeds when guidance is off
                add_region_embeds.append(current_region_embeds[2])

            region_embeds.append(current_region_embeds)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps
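
The concatenation order in the encoding step above is load-bearing: the negative embeds go first, so that np.split(noise_pred, 2) in the denoising loop below recovers the unconditional half before the text-conditioned half. A minimal sketch of that round trip, using stand-in arrays and illustrative shapes only:

    import numpy as np

    # stand-ins for one region's embeds; SDXL-ish shapes, purely illustrative
    negative_embeds = np.zeros((1, 77, 2048))
    region_embeds = np.ones((1, 77, 2048))

    # batch both guidance passes together, negative (unconditional) first
    model_input = np.concatenate((negative_embeds, region_embeds), axis=0)  # (2, 77, 2048)

    # after the UNet, the same ordering lets the halves split back apart
    noise_pred_uncond, noise_pred_text = np.split(model_input, 2)
    assert noise_pred_uncond.sum() == 0.0
    assert noise_pred_text.sum() == 77 * 2048
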
@@ -330,6 +359,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                (negative_pooled_prompt_embeds, add_text_embeds), axis=0
            )
            add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)

        add_time_ids = np.repeat(
            add_time_ids, batch_size * num_images_per_prompt, axis=0
        )
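
For context on the add_time_ids handling above: SDXL conditions on a six-value micro-conditioning vector, which is doubled for classifier-free guidance and then repeated per generated image. A small shape walkthrough, with assumed sizes:

    import numpy as np

    batch_size, num_images_per_prompt = 1, 2

    # (orig_height, orig_width, crop_top, crop_left, target_height, target_width)
    add_time_ids = np.array([[1024, 1024, 0, 0, 1024, 1024]], dtype=np.float32)  # (1, 6)

    # doubled so the unconditional and text halves each get a copy
    add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)  # (2, 6)

    # repeated for every image being generated
    add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0)  # (4, 6)
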
@@ -400,6 +430,75 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1
            # run the region prompts after the sliding panorama views
            for r, (top, left, bottom, right, mode, region_prompt) in enumerate(regions):
                logger.debug(
                    "running region prompt: %s, %s, %s, %s, %s, %s",
                    top, left, bottom, right, mode, region_prompt,
                )

                # convert pixel coordinates to latent space
                h_start = top // 8
                h_end = bottom // 8
                w_start = left // 8
                w_end = right // 8

                # get the latents corresponding to the current region coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    np.concatenate([latents_for_view] * 2)
                    if do_classifier_free_guidance
                    else latents_for_view
                )
                latent_model_input = self.scheduler.scale_model_input(
                    torch.from_numpy(latent_model_input), t
                )
                latent_model_input = latent_model_input.cpu().numpy()

                # fetch this region's encoder hidden states and pooled text embeds
                region_prompt_embeds = region_embeds[r][0]
                region_text_embeds = add_region_embeds[r]

                # predict the noise residual
                timestep = np.array([t], dtype=timestep_dtype)
                noise_pred = self.unet(
                    sample=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=region_prompt_embeds,
                    text_embeds=region_text_embeds,
                    time_ids=add_time_ids,
                )
                noise_pred = noise_pred[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                    if guidance_rescale > 0.0:
                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        noise_pred = rescale_noise_cfg(
                            noise_pred,
                            noise_pred_text,
                            guidance_rescale=guidance_rescale,
                        )

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred),
                    t,
                    torch.from_numpy(latents_for_view),
                    **extra_step_kwargs,
                )
                latents_view_denoised = scheduler_output.prev_sample.numpy()

                if mode:
                    # replace mode: the region prompt overwrites the panorama views
                    value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] = 1
                else:
                    # blend mode: the region prompt is averaged with the panorama views
                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                    count[:, :, h_start:h_end, w_start:w_end] += 1
            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)
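
The final np.where line is the MultiDiffusion step itself (Eq. 5): every panorama view and every blend-mode region adds its denoised latents into value and bumps count, so the division yields a per-pixel mean, while replace-mode regions reset both buffers and win outright. A toy numpy example of the averaging, with made-up shapes and values:

    import numpy as np

    value = np.zeros((1, 4, 8, 8))
    count = np.zeros((1, 4, 8, 8))

    # two overlapping views contribute different predictions to the same pixels
    view_a = np.full((1, 4, 8, 4), 2.0)
    view_b = np.full((1, 4, 8, 4), 4.0)
    value[:, :, :, 0:4] += view_a
    count[:, :, :, 0:4] += 1
    value[:, :, :, 2:6] += view_b
    count[:, :, :, 2:6] += 1

    # Eq. 5: per-pixel mean over the views that covered each location
    latents = np.where(count > 0, value / count, value)
    assert latents[0, 0, 0, 3] == 3.0  # overlap averages 2.0 and 4.0
    assert latents[0, 0, 0, 1] == 2.0  # covered only by view_a
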


@@ -444,3 +444,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
        return parts[min(slice, len(parts) - 1)]
    else:
        return prompt

# (top, left, bottom, right, mode, prompt); a truthy mode replaces the
# panorama latents for that area instead of blending with them
Region = Tuple[int, int, int, int, bool, str]


def parse_regions(prompt: str) -> List[Region]:
    # experimental stub: no region syntax is parsed yet
    return []
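
parse_regions is left as a stub in this commit, so the prompt syntax for regions is not defined yet. Purely as an illustration of the plumbing, the sketch below invents a <region:top:left:bottom:right:mode:prompt> token and pulls it out of the base prompt; the token syntax and the parse_regions_sketch name are hypothetical, not part of this commit.

    import re
    from typing import List, Tuple

    # same alias as the commit above
    Region = Tuple[int, int, int, int, bool, str]

    # hypothetical syntax: <region:top:left:bottom:right:mode:prompt text>
    REGION_TOKEN = re.compile(r"<region:(\d+):(\d+):(\d+):(\d+):(true|false):([^>]+)>")

    def parse_regions_sketch(prompt: str) -> Tuple[str, List[Region]]:
        """Extract region tokens; returns the cleaned base prompt and the regions."""
        regions: List[Region] = []

        def collect(match: re.Match) -> str:
            top, left, bottom, right = (int(g) for g in match.group(1, 2, 3, 4))
            mode = match.group(5) == "true"
            regions.append((top, left, bottom, right, mode, match.group(6)))
            return ""  # remove the token from the base prompt

        cleaned = REGION_TOKEN.sub(collect, prompt).strip()
        return cleaned, regions

With that invented syntax, parse_regions_sketch("a forest <region:0:0:512:512:true:a castle>") returns ("a forest", [(0, 0, 512, 512, True, "a castle")]).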