1
0
Fork 0

add regions to non-XL panorama, add feathering to SDXL regions

This commit is contained in:
Sean Sube 2023-11-08 18:51:31 -06:00
parent 7c67d595fb
commit f564bb3f65
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 136 additions and 12 deletions

View File

@ -91,6 +91,31 @@ def get_tile_grads(
return (grad_x, grad_y)
def make_tile_mask(
    shape: np.ndarray,
    tile: int,
    overlap: float,
) -> np.ndarray:
    """
    Build a feathering mask for blending a tile into a larger image.

    The mask ramps linearly from 0 at the tile edges up to 1 in the
    interior; the ramp width is controlled by ``overlap``, the fraction
    of ``tile`` shared with neighboring tiles.

    NOTE(review): despite its name, ``shape`` is used as an array here —
    only a 2D slice of it is taken to size the mask. Presumably its
    trailing spatial extents match ``tile``; confirm against callers.
    """
    base = np.ones_like(shape[:, :, 0])

    # interior span of the tile that is not shared with a neighbor
    solid = int(float(tile) * (1.0 - overlap))

    # gradient breakpoints, ordered edge -> ramp -> plateau -> ramp -> edge
    lo, hi = sorted((solid, tile - solid))
    stops = [0, lo, hi, tile]
    ramp = [0, 1, 1, 0]
    logger.debug("tile gradients: %s, %s, %s", stops, ramp, ramp)

    # per-axis fade multipliers, sampled at every pixel position
    fade_x = np.interp(np.arange(tile), stops, ramp)
    fade_y = np.interp(np.arange(tile), stops, ramp)

    # apply the horizontal fade along the last axis, then the vertical
    # fade via a transpose round-trip
    return ((base * fade_x).T * fade_y).T
def blend_tiles( def blend_tiles(
tiles: List[Tuple[int, int, Image.Image]], tiles: List[Tuple[int, int, Image.Image]],
scale: int, scale: int,
@ -109,7 +134,7 @@ def blend_tiles(
value = np.zeros(scaled_size) value = np.zeros(scaled_size)
for left, top, tile_image in tiles: for left, top, tile_image in tiles:
# histogram equalization # TODO: histogram equalization
equalized = np.array(tile_image).astype(np.float32) equalized = np.array(tile_image).astype(np.float32)
mask = np.ones_like(equalized[:, :, 0]) mask = np.ones_like(equalized[:, :, 0])

View File

@ -26,6 +26,8 @@ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMSchedu
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
from transformers import CLIPImageProcessor, CLIPTokenizer from transformers import CLIPImageProcessor, CLIPTokenizer
from ..utils import parse_regions
logger = logging.get_logger(__name__)
@ -479,6 +481,8 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance. # corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 do_classifier_free_guidance = guidance_scale > 1.0
prompt, regions = parse_regions(prompt)
prompt_embeds = self._encode_prompt( prompt_embeds = self._encode_prompt(
prompt, prompt,
num_images_per_prompt, num_images_per_prompt,
@ -488,6 +492,22 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
) )
# 3.b. Encode region prompts
region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mult, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
region_prompt_embeds = self._encode_prompt(
region_prompt,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
)
region_embeds.append(region_prompt_embeds)
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
latents_dtype = prompt_embeds.dtype latents_dtype = prompt_embeds.dtype
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
@ -576,6 +596,75 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1 count[:, :, h_start:h_end, w_start:w_end] += 1
for r in range(len(regions)):
top, left, bottom, right, mult, prompt = regions[r]
logger.debug(
"running region prompt: %s, %s, %s, %s, %s, %s",
top,
left,
bottom,
right,
mult,
prompt,
)
# convert coordinates to latent space
h_start = top // 8
h_end = bottom // 8
w_start = left // 8
w_end = right // 8
# get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
# expand the latents if we are doing classifier free guidance
latent_region_input = (
np.concatenate([latents_for_region] * 2)
if do_classifier_free_guidance
else latents_for_region
)
latent_region_input = self.scheduler.scale_model_input(
torch.from_numpy(latent_region_input), t
)
latent_region_input = latent_region_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
region_noise_pred = self.unet(
sample=latent_region_input,
timestep=timestep,
encoder_hidden_states=region_embeds[r],
)
region_noise_pred = region_noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
region_noise_pred_uncond, region_noise_pred_text = np.split(
region_noise_pred, 2
)
region_noise_pred = region_noise_pred_uncond + guidance_scale * (
region_noise_pred_text - region_noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = self.scheduler.step(
torch.from_numpy(region_noise_pred),
t,
torch.from_numpy(latents_for_region),
**extra_step_kwargs,
)
latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised
count[:, :, h_start:h_end, w_start:w_end] = 1
else:
value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult
)
count[:, :, h_start:h_end, w_start:w_end] += mult
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value) latents = np.where(count > 0, value / count, value)

View File

@ -12,6 +12,8 @@ from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import (
) )
from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg from optimum.pipelines.diffusers.pipeline_utils import preprocess, rescale_noise_cfg
from onnx_web.chain.tile import make_tile_mask
from ..utils import parse_regions from ..utils import parse_regions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -307,7 +309,10 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
region_embeds: List[np.ndarray] = [] region_embeds: List[np.ndarray] = []
add_region_embeds: List[np.ndarray] = [] add_region_embeds: List[np.ndarray] = []
for _top, _left, _bottom, _right, _mode, region_prompt in regions: for _top, _left, _bottom, _right, _mult, region_prompt in regions:
if region_prompt.endswith("+"):
region_prompt = region_prompt[:-1] + " " + prompt
( (
region_prompt_embeds, region_prompt_embeds,
region_negative_prompt_embeds, region_negative_prompt_embeds,
@ -318,10 +323,6 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
num_images_per_prompt, num_images_per_prompt,
do_classifier_free_guidance, do_classifier_free_guidance,
negative_prompt, negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
) )
if do_classifier_free_guidance: if do_classifier_free_guidance:
@ -462,7 +463,7 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
# get the latents corresponding to the current view coordinates # get the latents corresponding to the current view coordinates
latents_for_region = latents[:, :, h_start:h_end, w_start:w_end] latents_for_region = latents[:, :, h_start:h_end, w_start:w_end]
logger.trace("region latent shape: %s", latents_for_region.shape) logger.trace("region latent shape: [:,:,%s:%s,%s:%s] -> %s", h_start, h_end, w_start, w_end, latents_for_region.shape)
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
latent_region_input = ( latent_region_input = (
@ -511,14 +512,23 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
) )
latents_region_denoised = scheduler_output.prev_sample.numpy() latents_region_denoised = scheduler_output.prev_sample.numpy()
if mult >= 1000.0: # TODO: get feather settings from prompt
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised feather = 0.25
count[:, :, h_start:h_end, w_start:w_end] = 1 tile = 1024
if feather > 0.0:
mask = make_tile_mask(latents_region_denoised, tile, feather)
else:
mask = 1
if mult >= 10.0:
value[:, :, h_start:h_end, w_start:w_end] = latents_region_denoised * mask
count[:, :, h_start:h_end, w_start:w_end] = mask
else: else:
value[:, :, h_start:h_end, w_start:w_end] += ( value[:, :, h_start:h_end, w_start:w_end] += (
latents_region_denoised * mult latents_region_denoised * mult * mask
) )
count[:, :, h_start:h_end, w_start:w_end] += mult count[:, :, h_start:h_end, w_start:w_end] += mult * mask
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = np.where(count > 0, value / count, value)