1
0
Fork 0

feat(api): add edge feathering to region prompts

This commit is contained in:
Sean Sube 2023-11-08 22:00:32 -06:00
parent 633e078036
commit 59515193a1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 33 additions and 26 deletions

View File

@ -2,7 +2,7 @@ import itertools
from enum import Enum from enum import Enum
from logging import getLogger from logging import getLogger
from math import ceil from math import ceil
from typing import List, Optional, Protocol, Tuple from typing import Any, List, Optional, Protocol, Tuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -92,7 +92,7 @@ def get_tile_grads(
def make_tile_mask( def make_tile_mask(
shape: np.ndarray, shape: Any,
tile: int, tile: int,
overlap: float, overlap: float,
) -> np.ndarray: ) -> np.ndarray:

View File

@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from onnx_web.chain.tile import make_tile_mask
from ..utils import parse_regions from ..utils import parse_regions
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -495,7 +497,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# 3.b. Encode region prompts # 3.b. Encode region prompts
region_embeds: List[np.ndarray] = [] region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mult, region_prompt in regions: for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"): if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt region_prompt = region_prompt[:-1] + " " + prompt
@ -597,14 +599,15 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)): for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[r] top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug( logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s", "running region prompt: %s, %s, %s, %s, %s, %s, %s",
top, top,
left, left,
bottom, bottom,
right, right,
mult, weight,
feather,
prompt, prompt,
) )
@ -656,14 +659,21 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 10.0: if feather > 0.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
count[:, :, h_start:h_end, w_start:w_end] = 1 mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
count[:, :, h_start:h_end, w_start:w_end] = mask
else: else:
value[:, :, h_start:h_end, w_start:w_end] += ( value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult latents_region_denoised * weight * mask
) )
count[:, :, h_start:h_end, w_start:w_end] += mult count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value) latents = np.where(count > 0, value / count, value)

View File

@ -309,7 +309,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
region_embeds: List[np.ndarray] = [] region_embeds: List[np.ndarray] = []
add_region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mult, region_prompt in regions: for _top, _left, _bottom, _right, _weight, region_prompt in regions:
if region_prompt.endswith("+"): if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt region_prompt = region_prompt[:-1] + " " + prompt
@ -444,14 +444,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)): for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[r] top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug( logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s", "running region prompt: %s, %s, %s, %s, %s, %s, %s",
top, top,
left, left,
bottom, bottom,
right, right,
mult, weight,
feather,
prompt, prompt,
) )
@ -512,25 +513,21 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
# TODO: get feather settings from prompt
feather = 0.25
tile = 128
if feather > 0.0: if feather > 0.0:
mask = make_tile_mask(latents_region_denoised, (tile, tile), feather) mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
mask = np.repeat(mask, 4, axis=0) mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0) mask = np.expand_dims(mask, axis=0)
else: else:
mask = 1 mask = 1
if mult >= 10.0: if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
count[:, :, h_start:h_end, w_start:w_end] = mask count[:, :, h_start:h_end, w_start:w_end] = mask
else: else:
value[:, :, h_start:h_end, w_start:w_end] += ( value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult * mask latents_region_denoised * weight * mask
) )
count[:, :, h_start:h_end, w_start:w_end] += mult * mask count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value) latents = np.where(count > 0, value / count, value)

View File

@ -21,7 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>") INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>") REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>")
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -459,8 +459,8 @@ Region = Tuple[int, int, int, int, float, str]
def parse_region_group(group) -> Region: def parse_region_group(group) -> Region:
top, left, bottom, right, mult, prompt = group top, left, bottom, right, weight, feather, prompt = group
return (int(top), int(left), int(bottom), int(right), float(mult), prompt) return (int(top), int(left), int(bottom), int(right), float(weight), float(feather), prompt)
def parse_regions(prompt: str) -> Tuple[str, List[Region]]: def parse_regions(prompt: str) -> Tuple[str, List[Region]]: