From 59515193a1cd6ffc0027f08e652adb675f60eeeb Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 8 Nov 2023 22:00:32 -0600 Subject: [PATCH] feat(api): add edge feathering to region prompts --- api/onnx_web/chain/tile.py | 4 +-- api/onnx_web/diffusers/pipelines/panorama.py | 28 +++++++++++++------ .../diffusers/pipelines/panorama_xl.py | 21 ++++++-------- api/onnx_web/diffusers/utils.py | 6 ++-- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 4e2a4e3b..1e9fa38d 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -2,7 +2,7 @@ import itertools from enum import Enum from logging import getLogger from math import ceil -from typing import List, Optional, Protocol, Tuple +from typing import Any, List, Optional, Protocol, Tuple import numpy as np from PIL import Image @@ -92,7 +92,7 @@ def get_tile_grads( def make_tile_mask( - shape: np.ndarray, + shape: Any, tile: int, overlap: float, ) -> np.ndarray: diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 3ed6c67a..afb05b76 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from transformers import CLIPImageProcessor, CLIPTokenizer +from onnx_web.chain.tile import make_tile_mask + from ..utils import parse_regions logger = logging.get_logger(__name__) @@ -495,7 +497,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): # 3.b. Encode region prompts region_embeds: List[np.ndarray] = [] - for _top, _left, _bottom, _right, _mult, region_prompt in regions: + for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions: if region_prompt.endswith("+"): region_prompt = region_prompt[:-1] + " " + prompt @@ -597,14 +599,15 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): count[:, :, h_start:h_end, w_start:w_end] += 1 for r in range(len(regions)): - top, left, bottom, right, mult, prompt = regions[r] + top, left, bottom, right, weight, feather, prompt = regions[r] logger.debug( - "running region prompt: %s, %s, %s, %s, %s, %s", + "running region prompt: %s, %s, %s, %s, %s, %s, %s", top, left, bottom, right, - mult, + weight, + feather, prompt, ) @@ -656,14 +659,21 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): ) latents_region_denoised = scheduler_output.prev_sample.numpy() - if mult >= 10.0: - value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised - count[:, :, h_start:h_end, w_start:w_end] = 1 + if feather > 0.0: + mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather) + mask = np.repeat(mask, 4, axis=0) + mask = np.expand_dims(mask, axis=0) + else: + mask = 1 + + if weight >= 10.0: + value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask + count[:, :, h_start:h_end, w_start:w_end] = mask else: value[:, :, h_start:h_end, w_start:w_end] += ( - latents_region_denoised * mult + latents_region_denoised * weight * mask ) - count[:, :, h_start:h_end, w_start:w_end] += mult + count[:, :, h_start:h_end, w_start:w_end] += weight * mask # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = np.where(count > 0, value / count, value) diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index 444cfa3c..ee6d0413 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -309,7 +309,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = [] - for _top, _left, _bottom, _right, _mult, region_prompt in regions: + for _top, _left, _bottom, _right, _weight, region_prompt in regions: if region_prompt.endswith("+"): region_prompt = region_prompt[:-1] + " " + prompt @@ -444,14 +444,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix count[:, :, h_start:h_end, w_start:w_end] += 1 for r in range(len(regions)): - top, left, bottom, right, mult, prompt = regions[r] + top, left, bottom, right, weight, feather, prompt = regions[r] logger.debug( - "running region prompt: %s, %s, %s, %s, %s, %s", + "running region prompt: %s, %s, %s, %s, %s, %s, %s", top, left, bottom, right, - mult, + weight, + feather, prompt, ) @@ -512,25 +513,21 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix ) latents_region_denoised = scheduler_output.prev_sample.numpy() - # TODO: get feather settings from prompt - feather = 0.25 - tile = 128 - if feather > 0.0: - mask = make_tile_mask(latents_region_denoised, (tile, tile), feather) + mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather) mask = np.repeat(mask, 4, axis=0) mask = np.expand_dims(mask, axis=0) else: mask = 1 - if mult >= 10.0: + if weight >= 10.0: value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask count[:, :, h_start:h_end, w_start:w_end] = mask else: value[:, :, h_start:h_end, w_start:w_end] += ( - latents_region_denoised * mult * mask + latents_region_denoised * weight * mask ) - count[:, :, h_start:h_end, w_start:w_end] += mult * mask + count[:, :, h_start:h_end, w_start:w_end] += weight * mask # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = np.where(count > 0, value / count, value) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 45928a48..fa4b1dcd 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -21,7 +21,7 @@ CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") -REGION_TOKEN = compile(r"\]+)\>") +REGION_TOKEN = compile(r"\]+)\>") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") @@ -459,8 +459,8 @@ Region = Tuple[int, int, int, int, float, str] def parse_region_group(group) -> Region: - top, left, bottom, right, mult, prompt = group - return (int(top), int(left), int(bottom), int(right), float(mult), prompt) + top, left, bottom, right, weight, feather, prompt = group + return (int(top), int(left), int(bottom), int(right), float(weight), float(feather), prompt) def parse_regions(prompt: str) -> Tuple[str, List[Region]]: