diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py
index fed65722..a0c92d07 100644
--- a/api/onnx_web/diffusers/pipelines/panorama_xl.py
+++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py
@@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
 )
 from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
 
+from ..utils import parse_regions
+
 logger = logging.getLogger(__name__)
 
 
@@ -299,6 +301,34 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
         )
 
+        # 3.b. Encode region prompts
+        regions = parse_regions(prompt)
+        region_embeds: List[List[Optional[np.ndarray]]] = []
+        add_region_embeds: List[np.ndarray] = []
+
+        for _top, _left, _bottom, _right, _mode, region_prompt in regions:
+            # encode each region prompt from scratch: passing the main
+            # prompt_embeds through would make _encode_prompt skip encoding
+            # and return the panorama prompt's embeddings unchanged
+            current_region_embeds = list(self._encode_prompt(
+                region_prompt,
+                num_images_per_prompt,
+                do_classifier_free_guidance,
+                negative_prompt,
+            ))
+
+            if do_classifier_free_guidance:
+                current_region_embeds[0] = np.concatenate(
+                    (current_region_embeds[1], current_region_embeds[0]), axis=0
+                )
+                add_region_embeds.append(np.concatenate(
+                    (current_region_embeds[3], current_region_embeds[2]), axis=0
+                ))
+            else:
+                add_region_embeds.append(current_region_embeds[2])
+
+            region_embeds.append(current_region_embeds)
+
 # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps)
         timesteps = self.scheduler.timesteps
@@ -330,6 +360,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                 (negative_pooled_prompt_embeds, add_text_embeds), axis=0
             )
             add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0)
+
         add_time_ids = np.repeat(
             add_time_ids, batch_size * num_images_per_prompt, axis=0
         )
@@ -400,6 +431,84 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                     value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                     count[:, :, h_start:h_end, w_start:w_end] += 1
 
+            # iterate with zip instead of an index to avoid shadowing the
+            # denoising loop variables
+            for region, current_embeds, current_add_embeds in zip(
+                regions, region_embeds, add_region_embeds
+            ):
+                top, left, bottom, right, mode, region_prompt = region
+                logger.debug(
+                    "running region prompt: %s, %s, %s, %s, %s, %s",
+                    top, left, bottom, right, mode, region_prompt,
+                )
+
+                # convert pixel coordinates to latent space
+                h_start = top // 8
+                h_end = bottom // 8
+                w_start = left // 8
+                w_end = right // 8
+
+                # get the latents corresponding to the current region
+                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = (
+                    np.concatenate([latents_for_view] * 2)
+                    if do_classifier_free_guidance
+                    else latents_for_view
+                )
+                latent_model_input = self.scheduler.scale_model_input(
+                    torch.from_numpy(latent_model_input), t
+                )
+                latent_model_input = latent_model_input.cpu().numpy()
+
+                # fetch the prompt and pooled embeddings for this region
+                region_prompt_embeds = current_embeds[0]
+                region_text_embeds = current_add_embeds
+
+                # predict the noise residual
+                timestep = np.array([t], dtype=timestep_dtype)
+                noise_pred = self.unet(
+                    sample=latent_model_input,
+                    timestep=timestep,
+                    encoder_hidden_states=region_prompt_embeds,
+                    text_embeds=region_text_embeds,
+                    time_ids=add_time_ids,
+                )
+                noise_pred = noise_pred[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                    if guidance_rescale > 0.0:
+                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                        noise_pred = rescale_noise_cfg(
+                            noise_pred,
+                            noise_pred_text,
+                            guidance_rescale=guidance_rescale,
+                        )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                scheduler_output = self.scheduler.step(
+                    torch.from_numpy(noise_pred),
+                    t,
+                    torch.from_numpy(latents_for_view),
+                    **extra_step_kwargs,
+                )
+                latents_view_denoised = scheduler_output.prev_sample.numpy()
+
+                if mode:
+                    # replace mode: overwrite the latents accumulated so far
+                    value[:, :, h_start:h_end, w_start:w_end] = latents_view_denoised
+                    count[:, :, h_start:h_end, w_start:w_end] = 1
+                else:
+                    # blend mode: average this region with the other views
+                    value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
+                    count[:, :, h_start:h_end, w_start:w_end] += 1
+
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
 
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index ab3c63c5..152b5bca 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -444,3 +444,11 @@ def slice_prompt(prompt: str, slice: int) -> str:
         return parts[min(slice, len(parts) - 1)]
     else:
         return prompt
+
+
+Region = Tuple[int, int, int, int, bool, str]
+
+
+def parse_regions(prompt: str) -> List[Region]:
+    # TODO: placeholder, the region token parser is not implemented yet
+    return []
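
Note: parse_regions lands as a stub that always returns an empty list, so the region
code paths above are effectively disabled until the parser is filled in. A minimal
sketch of what the parser might look like, assuming a hypothetical
<region:top:left:bottom:right:T|F:prompt> token syntax embedded in the panorama
prompt; the token format, the REGION_TOKEN pattern, and the T/F replace flag are all
assumptions for illustration, not part of this diff:

    import re
    from typing import List, Tuple

    Region = Tuple[int, int, int, int, bool, str]

    # hypothetical token syntax: <region:top:left:bottom:right:T|F:prompt>
    REGION_TOKEN = re.compile(r"<region:(\d+):(\d+):(\d+):(\d+):([TF]):([^>]+)>")


    def parse_regions(prompt: str) -> List[Region]:
        regions: List[Region] = []
        for match in REGION_TOKEN.finditer(prompt):
            top, left, bottom, right = (int(group) for group in match.groups()[:4])
            # T marks a replace-mode region, F a blend-mode region
            replace = match.group(5) == "T"
            regions.append((top, left, bottom, right, replace, match.group(6)))
        return regions

With that syntax, a prompt like "a city skyline <region:0:0:512:512:T:a castle tower>"
would denoise the top-left 512x512 pixels (64x64 in latent space, matching the // 8
conversion in the pipeline) with its own prompt and overwrite the panorama latents
there, while an F region would be averaged with the other views in the
MultiDiffusion step.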