From baecb38343202153f9ddafa03e2ceec984d36003 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 5 Nov 2023 16:25:48 -0600 Subject: [PATCH] fix region parsing --- .../diffusers/pipelines/panorama_xl.py | 23 ++++++++++++------- api/onnx_web/diffusers/utils.py | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index 7c46326c..27fc4964 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -284,6 +284,8 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + prompt, regions = parse_regions(prompt) + # 3. Encode input prompt ( prompt_embeds, @@ -302,7 +304,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) # 3.b. Encode region prompts - regions = parse_regions(prompt) region_embeds: List[ Tuple[ List[np.ndarray], @@ -314,7 +315,12 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix add_region_embeds: List[np.ndarray] = [] for _top, _left, _bottom, _right, _mode, region_prompt in regions: - current_region_embeds = self._encode_prompt( + ( + region_prompt_embeds, + region_negative_prompt_embeds, + region_pooled_prompt_embeds, + region_negative_pooled_prompt_embeds, + ) = self._encode_prompt( region_prompt, num_images_per_prompt, do_classifier_free_guidance, @@ -326,16 +332,16 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) if do_classifier_free_guidance: - current_region_embeds[0] = np.concatenate( - (current_region_embeds[1], current_region_embeds[0]), axis=0 + region_prompt_embeds = np.concatenate( + (region_negative_prompt_embeds, region_prompt_embeds), axis=0 ) add_region_embeds.append( np.concatenate( - (current_region_embeds[3], current_region_embeds[2]), axis=0 + (region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds), axis=0 ) ) - region_embeds.append(current_region_embeds) + region_embeds.append(region_prompt_embeds) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -441,7 +447,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix for i in range(len(regions)): top, left, bottom, right, mode, prompt = regions[i] - print("running region prompt", top, left, bottom, right, mode, prompt) + logger.debug("running region prompt: %s, %s, %s, %s, %s, %s", top, left, bottom, right, mode, prompt) # convert coordinates to latent space h_start = top // 8 @@ -464,8 +470,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix latent_model_input = latent_model_input.cpu().numpy() # fetch region embeds - region_1 = region_embeds[i][0] + region_1 = region_embeds[i] region_2 = add_region_embeds[i] + logger.debug("region embeds shape: %s, %s", region_1.shape, region_2.shape) # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index d96e3d5e..5a5cfcd4 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -458,5 +458,5 @@ def slice_prompt(prompt: str, slice: int) -> str: Region = Tuple[int, int, int, int, Literal["add", "replace"], str] -def parse_regions(prompt: str) -> List[Region]: +def parse_regions(prompt: str) -> Tuple[str, List[Region]]: return get_tokens_from_prompt(prompt, REGION_TOKEN, lambda it: it)