From f564bb3f6557f20e98b864220e0797395b725398 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 8 Nov 2023 18:51:31 -0600
Subject: [PATCH] add regions to non-XL panorama, add feathering to SDXL regions

---
 api/onnx_web/chain/tile.py                   | 27 +++++-
 api/onnx_web/diffusers/pipelines/panorama.py | 89 +++++++++++++++++++
 .../diffusers/pipelines/panorama_xl.py       | 32 ++++---
 3 files changed, 136 insertions(+), 12 deletions(-)

diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py
index 44f62671..8258ae7a 100644
--- a/api/onnx_web/chain/tile.py
+++ b/api/onnx_web/chain/tile.py
@@ -91,6 +91,31 @@ def get_tile_grads(
     return (grad_x, grad_y)
 
 
+def make_tile_mask(
+    shape: np.ndarray,
+    tile: int,
+    overlap: float,
+) -> np.ndarray:
+    mask = np.ones_like(shape[:, :, 0])
+    adj_tile = int(float(tile) * (1.0 - overlap))
+
+    # sort gradient points
+    p1 = adj_tile
+    p2 = (tile - adj_tile)
+    points = [0, min(p1, p2), max(p1, p2), tile]
+
+    # build gradients
+    grad_x, grad_y = [0, 1, 1, 0], [0, 1, 1, 0]
+    logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
+
+    mult_x = [np.interp(i, points, grad_x) for i in range(tile)]
+    mult_y = [np.interp(i, points, grad_y) for i in range(tile)]
+
+    mask = ((mask * mult_x).T * mult_y).T
+
+    return mask
+
+
 def blend_tiles(
     tiles: List[Tuple[int, int, Image.Image]],
     scale: int,
@@ -109,7 +134,7 @@ def blend_tiles(
     value = np.zeros(scaled_size)
 
     for left, top, tile_image in tiles:
-        # histogram equalization
+        # TODO: histogram equalization
        equalized = np.array(tile_image).astype(np.float32)
         mask = np.ones_like(equalized[:, :, 0])
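
Note on make_tile_mask: the mask it returns ramps from 0 at the tile border up to 1 in the interior and back down, with the ramp width set by overlap, and the 2-D result is effectively the outer product of two 1-D ramps. A minimal sketch of the same construction with illustrative values (tile=8 and overlap=0.25 are assumptions for readability; the patch itself calls it with tile=1024 and feather=0.25 from panorama_xl.py):

    import numpy as np

    # illustrative values, not taken from the patch
    tile = 8
    overlap = 0.25

    adj_tile = int(float(tile) * (1.0 - overlap))  # 6
    points = [0, min(adj_tile, tile - adj_tile), max(adj_tile, tile - adj_tile), tile]  # [0, 2, 6, 8]
    ramp = np.interp(np.arange(tile), points, [0.0, 1.0, 1.0, 0.0])

    # equivalent to ((ones * mult_x).T * mult_y).T in make_tile_mask
    mask = np.outer(ramp, ramp)
    print(ramp)  # [0.  0.5 1.  1.  1.  1.  1.  0.5]

Multiplying a tile's latents by this mask, and accumulating the mask itself into the blend weights, makes overlapping tiles cross-fade instead of meeting at hard seams.
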
diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index 99e283a1..3ed6c67a 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
 from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPImageProcessor, CLIPTokenizer
 
+from ..utils import parse_regions
+
 logger = logging.get_logger(__name__)
 
@@ -479,6 +481,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
+        prompt, regions = parse_regions(prompt)
+
         prompt_embeds = self._encode_prompt(
             prompt,
             num_images_per_prompt,
             do_classifier_free_guidance,
             negative_prompt,
@@ -488,6 +492,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
         )
 
+        # 3.b. Encode region prompts
+        region_embeds: List[np.ndarray] = []
+
+        for _top, _left, _bottom, _right, _mult, region_prompt in regions:
+            if region_prompt.endswith("+"):
+                region_prompt = region_prompt[:-1] + " " + prompt
+
+            region_prompt_embeds = self._encode_prompt(
+                region_prompt,
+                num_images_per_prompt,
+                do_classifier_free_guidance,
+                negative_prompt,
+            )
+
+            region_embeds.append(region_prompt_embeds)
+
         # get the initial random noise unless the user supplied it
         latents_dtype = prompt_embeds.dtype
         latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
@@ -576,6 +596,75 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                 count[:, :, h_start:h_end, w_start:w_end] += 1
 
+            for r in range(len(regions)):
+                top, left, bottom, right, mult, prompt = regions[r]
+                logger.debug(
+                    "running region prompt: %s, %s, %s, %s, %s, %s",
+                    top,
+                    left,
+                    bottom,
+                    right,
+                    mult,
+                    prompt,
+                )
+
+                # convert coordinates to latent space
+                h_start = top // 8
+                h_end = bottom // 8
+                w_start = left // 8
+                w_end = right // 8
+
+                # get the latents corresponding to the current view coordinates
+                latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
+                logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
+
+                # expand the latents if we are doing classifier free guidance
+                latent_region_input = (
+                    np.concatenate([latents_for_region] * 2)
+                    if do_classifier_free_guidance
+                    else latents_for_region
+                )
+                latent_region_input = self.scheduler.scale_model_input(
+                    torch.from_numpy(latent_region_input), t
+                )
+                latent_region_input = latent_region_input.cpu().numpy()
+
+                # predict the noise residual
+                timestep = np.array([t], dtype=timestep_dtype)
+                region_noise_pred = self.unet(
+                    sample=latent_region_input,
+                    timestep=timestep,
+                    encoder_hidden_states=region_embeds[r],
+                )
+                region_noise_pred = region_noise_pred[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    region_noise_pred_uncond, region_noise_pred_text = np.split(
+                        region_noise_pred, 2
+                    )
+                    region_noise_pred = region_noise_pred_uncond + guidance_scale * (
+                        region_noise_pred_text - region_noise_pred_uncond
+                    )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                scheduler_output = self.scheduler.step(
+                    torch.from_numpy(region_noise_pred),
+                    t,
+                    torch.from_numpy(latents_for_region),
+                    **extra_step_kwargs,
+                )
+                latents_region_denoised = scheduler_output.prev_sample.numpy()
+
+                if mult >= 10.0:
+                    value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
+                    count[:, :, h_start:h_end, w_start:w_end] = 1
+                else:
+                    value[:, :, h_start:h_end, w_start:w_end] += (
+                        latents_region_denoised * mult
+                    )
+                    count[:, :, h_start:h_end, w_start:w_end] += mult
+
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
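
Note on the region pass: it reuses the MultiDiffusion accumulator that the panorama views already write into. Every contributor adds weight * latents into value and the weight into count, and the step ends with the weighted average value / count (Eq. 5 of the MultiDiffusion paper). A toy sketch of that arithmetic, with assumed shapes and an assumed region weight, neither taken from the patch:

    import numpy as np

    base = np.zeros((1, 1, 4, 4))   # stand-in for the denoised view latents
    region = np.ones((1, 1, 2, 2))  # stand-in for the denoised region latents
    mult = 3.0                      # assumed region weight

    value = np.zeros((1, 1, 4, 4))
    count = np.zeros((1, 1, 4, 4))

    # panorama views contribute with weight 1
    value += base
    count += 1

    # the region contributes with weight mult over its window
    value[:, :, 0:2, 0:2] += region * mult
    count[:, :, 0:2, 0:2] += mult

    latents = np.where(count > 0, value / count, value)
    print(latents[0, 0])  # region cells average to 0.75, the rest stay 0.0

The mult >= 10.0 branch skips the average entirely: it overwrites value and count, so the region replaces the panorama output in its window instead of blending with it. A region prompt ending in "+" is also expanded with the base prompt before encoding.
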
diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py
index 91e8b040..61817b15 100644
--- a/api/onnx_web/diffusers/pipelines/panorama_xl.py
+++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py
@@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
 )
 from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
 
+from onnx_web.chain.tile import make_tile_mask
+
 from ..utils import parse_regions
 
 logger = logging.getLogger(__name__)
@@ -307,7 +309,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
         region_embeds: List[np.ndarray] = []
         add_region_embeds: List[np.ndarray] = []
 
-        for _top, _left, _bottom, _right, _mode, region_prompt in regions:
+        for _top, _left, _bottom, _right, _mult, region_prompt in regions:
+            if region_prompt.endswith("+"):
+                region_prompt = region_prompt[:-1] + " " + prompt
+
             (
                 region_prompt_embeds,
                 region_negative_prompt_embeds,
@@ -318,10 +323,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                 num_images_per_prompt,
                 do_classifier_free_guidance,
                 negative_prompt,
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=pooled_prompt_embeds,
-                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             )
 
             if do_classifier_free_guidance:
@@ -462,7 +463,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
 
                 # get the latents corresponding to the current view coordinates
                 latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
-                logger.trace("region latent shape: %s", latents_for_region.shape)
+                logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
 
                 # expand the latents if we are doing classifier free guidance
                 latent_region_input = (
@@ -511,14 +512,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
                 )
                 latents_region_denoised = scheduler_output.prev_sample.numpy()
 
-                if mult >= 1000.0:
-                    value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
-                    count[:, :, h_start:h_end, w_start:w_end] = 1
+                # TODO: get feather settings from prompt
+                feather = 0.25
+                tile = 1024
+
+                if feather > 0.0:
+                    mask = make_tile_mask(latents_region_denoised, tile, feather)
+                else:
+                    mask = 1
+
+                if mult >= 10.0:
+                    value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
+                    count[:, :, h_start:h_end, w_start:w_end] = mask
                 else:
                     value[:, :, h_start:h_end, w_start:w_end] += (
-                        latents_region_denoised * mult
+                        latents_region_denoised * mult * mask
                     )
-                    count[:, :, h_start:h_end, w_start:w_end] += mult
+                    count[:, :, h_start:h_end, w_start:w_end] += mult * mask
 
             # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
             latents = np.where(count > 0, value / count, value)
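
Note on the feathering: make_tile_mask sizes its mask from shape[:, :, 0], which fits an HxWxC image, and builds its ramps over tile pixel indices, but here it is handed NCHW region latents whose spatial size is 1/8 of the pixel tile, so the hard-coded tile = 1024 and the latent window do not appear to line up as written (the TODO suggests this path is still in flux). A sketch of the same feather built at latent resolution instead; latent_tile and the outer-product form are assumptions for illustration, not the patch's code:

    import numpy as np

    feather = 0.25
    latent_tile = 1024 // 8  # a 1024px tile spans 128 latent cells (VAE scale 8)

    # same ramp construction as make_tile_mask, at latent scale
    adj = int(float(latent_tile) * (1.0 - feather))
    points = [0, min(adj, latent_tile - adj), max(adj, latent_tile - adj), latent_tile]
    ramp = np.interp(np.arange(latent_tile), points, [0.0, 1.0, 1.0, 0.0])

    # 2-D feather; leading axes broadcast over batch and channel
    mask = np.outer(ramp, ramp)[None, None, :, :]

    latents_region_denoised = np.zeros((1, 4, latent_tile, latent_tile), dtype=np.float32)
    feathered = latents_region_denoised * mask  # (1,4,128,128) * (1,1,128,128)

Because count accumulates the same mask, the feathered edges blend back into the surrounding panorama latents instead of cutting the region off at a hard border.
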