add regions to non-XL panorama, add feathering to SDXL regions
This commit is contained in:
parent
7c67d595fb
commit
f564bb3f65
|
@ -91,6 +91,31 @@ def get_tile_grads(
|
|||
return (grad_x, grad_y)
|
||||
|
||||
|
||||
def make_tile_mask(
    shape: np.ndarray,
    tile: int,
    overlap: float,
) -> np.ndarray:
    """Build a 2D feathering mask for blending overlapping tiles.

    Weights ramp linearly from 0 at the tile edges up to 1 in the tile
    interior, with the ramp width controlled by ``overlap``. The mask is
    separable: the same 1D ramp is applied along both axes.

    :param shape: reference array, kept for backward compatibility with
        earlier callers — the returned mask is always ``(tile, tile)``,
        regardless of this argument's dimensions. (The previous
        implementation indexed ``shape[:, :, 0]`` and crashed whenever the
        array's spatial dims differed from ``tile``, e.g. when passed
        NCHW latents.)
    :param tile: tile size in pixels along each axis.
    :param overlap: fraction of the tile (0.0-1.0) that overlaps its
        neighbor; larger values produce a wider feathered border.
    :return: a ``(tile, tile)`` float array of per-pixel blend weights.
    """
    import logging  # local import so the helper stands alone

    # width of the non-overlapping interior of the tile
    adj_tile = int(float(tile) * (1.0 - overlap))

    # sort gradient breakpoints so np.interp receives ascending x-values
    p1 = adj_tile
    p2 = tile - adj_tile
    points = [0, min(p1, p2), max(p1, p2), tile]

    # build gradients: ramp 0 -> 1 across the overlap band, flat 1 inside
    grad_x, grad_y = [0, 1, 1, 0], [0, 1, 1, 0]
    logging.getLogger(__name__).debug(
        "tile gradients: %s, %s, %s", points, grad_x, grad_y
    )

    mult_x = [np.interp(i, points, grad_x) for i in range(tile)]
    mult_y = [np.interp(i, points, grad_y) for i in range(tile)]

    # separable 2D mask as an outer product of the per-axis ramps —
    # numerically identical to the old ones-matrix double-transpose
    # broadcast, but independent of the reference array's shape
    return np.outer(mult_y, mult_x)
|
||||
|
||||
|
||||
def blend_tiles(
|
||||
tiles: List[Tuple[int, int, Image.Image]],
|
||||
scale: int,
|
||||
|
@ -109,7 +134,7 @@ def blend_tiles(
|
|||
value = np.zeros(scaled_size)
|
||||
|
||||
for left, top, tile_image in tiles:
|
||||
# histogram equalization
|
||||
# TODO: histogram equalization
|
||||
equalized = np.array(tile_image).astype(np.float32)
|
||||
mask = np.ones_like(equalized[:, :, 0])
|
||||
|
||||
|
|
|
@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
|
|||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
from ..utils import parse_regions
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
|
@ -479,6 +481,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
prompt, regions = parse_regions(prompt)
|
||||
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
num_images_per_prompt,
|
||||
|
@ -488,6 +492,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 3.b. Encode region prompts
|
||||
region_embeds: List[np.ndarray] = []
|
||||
|
||||
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
|
||||
if region_prompt.endswith("+"):
|
||||
region_prompt = region_prompt[:-1] + " " + prompt
|
||||
|
||||
region_prompt_embeds = self._encode_prompt(
|
||||
region_prompt,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
region_embeds.append(region_prompt_embeds)
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_dtype = prompt_embeds.dtype
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
|
@ -576,6 +596,75 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
|
||||
for r in range(len(regions)):
|
||||
top, left, bottom, right, mult, prompt = regions[r]
|
||||
logger.debug(
|
||||
"running region prompt: %s, %s, %s, %s, %s, %s",
|
||||
top,
|
||||
left,
|
||||
bottom,
|
||||
right,
|
||||
mult,
|
||||
prompt,
|
||||
)
|
||||
|
||||
# convert coordinates to latent space
|
||||
h_start = top // 8
|
||||
h_end = bottom // 8
|
||||
w_start = left // 8
|
||||
w_end = right // 8
|
||||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_region_input = (
|
||||
np.concatenate([latents_for_region] * 2)
|
||||
if do_classifier_free_guidance
|
||||
else latents_for_region
|
||||
)
|
||||
latent_region_input = self.scheduler.scale_model_input(
|
||||
torch.from_numpy(latent_region_input), t
|
||||
)
|
||||
latent_region_input = latent_region_input.cpu().numpy()
|
||||
|
||||
# predict the noise residual
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
region_noise_pred = self.unet(
|
||||
sample=latent_region_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=region_embeds[r],
|
||||
)
|
||||
region_noise_pred = region_noise_pred[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
region_noise_pred_uncond, region_noise_pred_text = np.split(
|
||||
region_noise_pred, 2
|
||||
)
|
||||
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
|
||||
region_noise_pred_text - region_noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = self.scheduler.step(
|
||||
torch.from_numpy(region_noise_pred),
|
||||
t,
|
||||
torch.from_numpy(latents_for_region),
|
||||
**extra_step_kwargs,
|
||||
)
|
||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mult >= 10.0:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||
latents_region_denoised * mult
|
||||
)
|
||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
||||
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = np.where(count > 0, value / count, value)
|
||||
|
||||
|
|
|
@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
|
|||
)
|
||||
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
|
||||
|
||||
from onnx_web.chain.tile import make_tile_mask
|
||||
|
||||
from ..utils import parse_regions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -307,7 +309,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
region_embeds: List[np.ndarray] = []
|
||||
add_region_embeds: List[np.ndarray] = []
|
||||
|
||||
for _top, _left, _bottom, _right, _mode, region_prompt in regions:
|
||||
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
|
||||
if region_prompt.endswith("+"):
|
||||
region_prompt = region_prompt[:-1] + " " + prompt
|
||||
|
||||
(
|
||||
region_prompt_embeds,
|
||||
region_negative_prompt_embeds,
|
||||
|
@ -318,10 +323,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
|
@ -462,7 +463,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
logger.trace("region latent shape: %s", latents_for_region.shape)
|
||||
logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_region_input = (
|
||||
|
@ -511,14 +512,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
|
|||
)
|
||||
latents_region_denoised = scheduler_output.prev_sample.numpy()
|
||||
|
||||
if mult >= 1000.0:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] = 1
|
||||
# TODO: get feather settings from prompt
|
||||
feather = 0.25
|
||||
tile = 1024
|
||||
|
||||
if feather > 0.0:
|
||||
mask = make_tile_mask(latents_region_denoised, tile, feather)
|
||||
else:
|
||||
mask = 1
|
||||
|
||||
if mult >= 10.0:
|
||||
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
|
||||
count[:, :, h_start:h_end, w_start:w_end] = mask
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += (
|
||||
latents_region_denoised * mult
|
||||
latents_region_denoised * mult * mask
|
||||
)
|
||||
count[:, :, h_start:h_end, w_start:w_end] += mult
|
||||
count[:, :, h_start:h_end, w_start:w_end] += mult * mask
|
||||
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = np.where(count > 0, value / count, value)
|
||||
|
|
Loading…
Reference in New Issue