1
0
Fork 0

feat(api): add edge feathering to region prompts

This commit is contained in:
Sean Sube 2023-11-08 22:00:32 -06:00
parent 633e078036
commit 59515193a1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 33 additions and 26 deletions

View File

@ -2,7 +2,7 @@ import itertools
from enum import Enum
from logging import getLogger
from math import ceil
from typing import List, Optional, Protocol, Tuple
from typing import Any, List, Optional, Protocol, Tuple
import numpy as np
from PIL import Image
@ -92,7 +92,7 @@ def get_tile_grads(
def make_tile_mask(
shape: np.ndarray,
shape: Any,
tile: int,
overlap: float,
) -> np.ndarray:

View File

@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer
from onnx_web.chain.tile import make_tile_mask
from ..utils import parse_regions
logger = logging.get_logger(__name__)
@ -495,7 +497,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
for _top, _left, _bottom, _right, _weight, _feather, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
@ -597,14 +599,15 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[r]
top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s",
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
mult,
weight,
feather,
prompt,
)
@ -656,14 +659,21 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
count[:, :, h_start:h_end, w_start:w_end] = 1
if feather > 0.0:
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += mult
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)

View File

@ -309,7 +309,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
region_embeds: List[np.ndarray] = []
add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
for _top, _left, _bottom, _right, _weight, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
@ -444,14 +444,15 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[r]
top, left, bottom, right, weight, feather, prompt = regions[r]
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s",
"running region prompt: %s, %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
mult,
weight,
feather,
prompt,
)
@ -512,25 +513,21 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
# TODO: get feather settings from prompt
feather = 0.25
tile = 128
if feather > 0.0:
mask = make_tile_mask(latents_region_denoised, (tile, tile), feather)
mask = make_tile_mask((h_end - h_start, w_end - w_start), self.window, feather)
mask = np.repeat(mask, 4, axis=0)
mask = np.expand_dims(mask, axis=0)
else:
mask = 1
if mult >= 10.0:
if weight >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
count[:, :, h_start:h_end, w_start:w_end] = mask
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult * mask
latents_region_denoised * weight * mask
)
count[:, :, h_start:h_end, w_start:w_end] += mult * mask
count[:, :, h_start:h_end, w_start:w_end] += weight * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)

View File

@ -21,7 +21,7 @@ CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__")
# Region prompt token: <region:top:left:bottom:right:weight:feather:prompt>
#
# The four box coordinates are unsigned integers. Weight and feather are both
# converted with float() by parse_region_group, so both use the same
# signed-decimal character class as the LoRA/inversion weights; capturing the
# weight with a bare \d+ would reject fractional weights such as 0.5, which the
# previous 6-field form of this token accepted.
REGION_TOKEN = compile(
    r"\<region:(\d+):(\d+):(\d+):(\d+):(-?[\.|\d]+):(-?[\.|\d]+):([^\>]+)\>"
)
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
@ -459,8 +459,8 @@ Region = Tuple[int, int, int, int, float, str]
def parse_region_group(group) -> Region:
    """Convert the captured groups of one region token into a typed tuple.

    ``group`` must contain exactly seven items, in order: top, left, bottom,
    right, weight, feather, prompt. The four box edges are converted with
    ``int``, weight and feather with ``float``, and the prompt is passed
    through unchanged.

    Raises ``ValueError`` if ``group`` does not have exactly seven items or if
    a numeric field cannot be converted.
    """
    # NOTE(review): the Region alias shown in this hunk's header is still the
    # 6-field Tuple[int, int, int, int, float, str]; confirm it is widened to
    # include the feather float so the annotation matches this return value.
    top, left, bottom, right, weight, feather, prompt = group
    return (
        int(top),
        int(left),
        int(bottom),
        int(right),
        float(weight),
        float(feather),
        prompt,
    )
def parse_regions(prompt: str) -> Tuple[str, List[Region]]: