feat(api): add edge feathering to region prompts
This commit is contained in:
parent
633e078036
commit
59515193a1
|
@ -2,7 +2,7 @@ import itertools
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import List, Optional, Protocol, Tuple
|
from typing import Any, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -92,7 +92,7 @@ def get_tile_grads(
|
||||||
|
|
||||||
|
|
||||||
def make_tile_mask(
|
def make_tile_mask(
|
||||||
shape: np.ndarray,
|
shape: Any,
|
||||||
tile: int,
|
tile: int,
|
||||||
overlap: float,
|
overlap: float,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
|
|
@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
|
||||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||||
|
|
||||||
|
from onnx_web.chain.tile import make_tile_mask
|
||||||
|
|
||||||
from ..utils import parse_regions
|
from ..utils import parse_regions
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
@ -495,7 +497,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
# 3.b. Encode region prompts
|
# 3.b. Encode region prompts
|
||||||
region_embeds: List[np.ndarray] = []
|
region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
|
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
|
||||||
if region_prompt.endswith("+"):
|
if region_prompt.endswith("+"):
|
||||||
region_prompt = region_prompt[:-1] + " " + prompt
|
region_prompt = region_prompt[:-1] + " " + prompt
|
||||||
|
|
||||||
|
@ -597,14 +599,15 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
for r in range(len(regions)):
|
for r in range(len(regions)):
|
||||||
top, left, bottom, right, mult, prompt = regions[r]
|
top, left, bottom, right, weight, feather, prompt = regions[r]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"running region prompt: %s, %s, %s, %s, %s, %s",
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
top,
|
top,
|
||||||
left,
|
left,
|
||||||
bottom,
|
bottom,
|
||||||
right,
|
right,
|
||||||
mult,
|
weight,
|
||||||
|
feather,
|
||||||
prompt,
|
prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -656,14 +659,21 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||||
)
|
)
|
||||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
if mult >= 10.0:
|
if feather > 0.0:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
|
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
|
mask = np.expand_dims(mask, axis=0)
|
||||||
|
else:
|
||||||
|
mask = 1
|
||||||
|
|
||||||
|
if weight >= 10.0:
|
||||||
|
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
|
||||||
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
else:
|
else:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
latents_region_denoised * mult
|
latents_region_denoised * weight * mask
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
|
@ -309,7 +309,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
region_embeds: List[np.ndarray] = []
|
region_embeds: List[np.ndarray] = []
|
||||||
add_region_embeds: List[np.ndarray] = []
|
add_region_embeds: List[np.ndarray] = []
|
||||||
|
|
||||||
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
|
for _top, _left, _bottom, _right, _weight, region_prompt in regions:
|
||||||
if region_prompt.endswith("+"):
|
if region_prompt.endswith("+"):
|
||||||
region_prompt = region_prompt[:-1] + " " + prompt
|
region_prompt = region_prompt[:-1] + " " + prompt
|
||||||
|
|
||||||
|
@ -444,14 +444,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||||
|
|
||||||
for r in range(len(regions)):
|
for r in range(len(regions)):
|
||||||
top, left, bottom, right, mult, prompt = regions[r]
|
top, left, bottom, right, weight, feather, prompt = regions[r]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"running region prompt: %s, %s, %s, %s, %s, %s",
|
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
|
||||||
top,
|
top,
|
||||||
left,
|
left,
|
||||||
bottom,
|
bottom,
|
||||||
right,
|
right,
|
||||||
mult,
|
weight,
|
||||||
|
feather,
|
||||||
prompt,
|
prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -512,25 +513,21 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
||||||
)
|
)
|
||||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||||
|
|
||||||
# TODO: get feather settings from prompt
|
|
||||||
feather = 0.25
|
|
||||||
tile = 128
|
|
||||||
|
|
||||||
if feather > 0.0:
|
if feather > 0.0:
|
||||||
mask = make_tile_mask(latents_region_denoised, (tile, tile), feather)
|
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
|
||||||
mask = np.repeat(mask, 4, axis=0)
|
mask = np.repeat(mask, 4, axis=0)
|
||||||
mask = np.expand_dims(mask, axis=0)
|
mask = np.expand_dims(mask, axis=0)
|
||||||
else:
|
else:
|
||||||
mask = 1
|
mask = 1
|
||||||
|
|
||||||
if mult >= 10.0:
|
if weight >= 10.0:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
|
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
|
||||||
count[:, :, h_start:h_end, w_start:w_end] = mask
|
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||||
else:
|
else:
|
||||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||||
latents_region_denoised * mult * mask
|
latents_region_denoised * weight * mask
|
||||||
)
|
)
|
||||||
count[:, :, h_start:h_end, w_start:w_end] += mult * mask
|
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
|
||||||
|
|
||||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||||
latents = np.where(count > 0, value / count, value)
|
latents = np.where(count > 0, value / count, value)
|
||||||
|
|
|
@ -21,7 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
|
||||||
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>")
|
REGION_TOKEN = compile(r"\<region:(\d+):(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):([^\>]+)\>")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
@ -459,8 +459,8 @@ Region = Tuple[int, int, int, int, float, str]
|
||||||
|
|
||||||
|
|
||||||
def parse_region_group(group) -> Region:
|
def parse_region_group(group) -> Region:
|
||||||
top, left, bottom, right, mult, prompt = group
|
top, left, bottom, right, weight, feather, prompt = group
|
||||||
return (int(top), int(left), int(bottom), int(right), float(mult), prompt)
|
return (int(top), int(left), int(bottom), int(right), float(weight), float(feather), prompt)
|
||||||
|
|
||||||
|
|
||||||
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
def parse_regions(prompt: str) -> Tuple[str, List[Region]]:
|
||||||
|
|
Loading…
Reference in New Issue