feat(api): add experimental region prompts to SDXL panorama
This commit is contained in:
parent
59e1a1a4c2
commit
5cf7a39be0
|
@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
|||
)
|
||||
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
|
||||
|
||||
from ..utils import parse_regions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -299,6 +301,33 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
# 3.b. Encode region prompts
|
||||
regions = parse_regions(prompt)
|
||||
region_embeds: List[Tuple[List[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]] = []
|
||||
add_region_embeds: List[np.ndarray] = []
|
||||
|
||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||
current_region_embeds = self._encode_prompt(
|
||||
region_prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
current_region_embeds[0] = np.concatenate(
|
||||
(current_region_embeds[1], current_region_embeds[0]), axis=0
|
||||
)
|
||||
add_region_embeds.append(np.concatenate(
|
||||
(current_region_embeds[3], current_region_embeds[2]), axis=0
|
||||
))
|
||||
|
||||
region_embeds.append(current_region_embeds)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
@ -330,6 +359,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
(negative_pooled_prompt_embeds, add_text_embeds), axis=0
|
||||
)
|
||||
add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
|
||||
|
||||
add_time_ids = np.repeat(
|
||||
add_time_ids, batch_size * num_images_per_prompt, axis=0
|
||||
)
|
||||
|
@ -400,6 +430,75 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
|
||||
for i in range(len(regions)):
|
||||
top, left, bottom, right, mode, prompt = regions[i]
|
||||
print("running region prompt", top, left, bottom, right, mode, prompt)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // 8
|
||||
h_end = bottom // 8
|
||||
w_start = left // 8
|
||||
w_end = right // 8
|
||||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
np.concatenate([latents_for_view] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents_for_view
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
torch.from_numpy(latent_model_input), t
|
||||
)
|
||||
latent_model_input = latent_model_input.cpu().numpy()
|
||||
|
||||
# fetch region embeds
|
||||
region_1 = region_embeds[i][0]
|
||||
region_2 = add_region_embeds[i]
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=region_1,
|
||||
text_embeds=region_2,
|
||||
time_ids=add_time_ids,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
if guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(
|
||||
noise_pred,
|
||||
noise_pred_text,
|
||||
guidance_rescale=guidance_rescale,
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred),
|
||||
t,
|
||||
torch.from_numpy(latents_for_view),
|
||||
**extra_step_kwargs,
|
||||
)
|
||||
latents_view_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mode:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = np.where(count > 0, value / count, value)
|
||||
|
||||
|
|
|
@ -444,3 +444,9 @@ def slice_prompt(prompt: str, slice: int) -> str:
|
|||
return parts[min(slice, len(parts) - 1)]
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
Region = Tuple[int, int, int, int, bool, str]
|
||||
|
||||
def parse_regions(prompt: str) -> List[Region]:
|
||||
return []
|
||||
|
|
Loading…
Reference in New Issue