
feat(api): add img2img mode to panorama pipeline

Sean Sube 2023-04-27 07:22:00 -05:00
parent 29c616d99e
commit 47d80b07b3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 297 additions and 17 deletions
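A minimal usage sketch of the new mode, assuming a loaded OnnxStableDiffusionPanoramaPipeline instance named pipe (how the pipeline is constructed is outside this commit, and the prompt, file names, and parameter values below are illustrative):

import numpy as np
from PIL import Image

# assumes `pipe` is an already-loaded OnnxStableDiffusionPanoramaPipeline
init_image = Image.open("wide_landscape.png").convert("RGB")

result = pipe.img2img(
    prompt="a panoramic photo of a mountain lake at sunrise",
    image=init_image,
    strength=0.6,                   # how far to move away from the source image
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=np.random.RandomState(42),
)
result.images[0].save("panorama_out.png")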

View File

@@ -15,13 +15,14 @@
import inspect
from typing import Callable, List, Optional, Union

+import PIL
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTokenizer

from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
+from diffusers.utils import deprecate, logging, PIL_INTERPOLATION
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

@@ -30,6 +31,27 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__)

def preprocess(image):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
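As a concrete sketch of what preprocess produces (the input image here is made up): a 516x772 RGB image is snapped down to the nearest multiples of 64, rescaled from [0, 255] to [-1, 1], and rearranged to NCHW.

from PIL import Image
import torch

# illustrative check of the preprocess contract, using the function defined above
img = Image.new("RGB", (516, 772), color=(128, 64, 32))
tensor = preprocess(img)

assert isinstance(tensor, torch.Tensor)
assert tensor.shape == (1, 3, 768, 512)  # batch, channels, height, width
assert tensor.min() >= -1.0 and tensor.max() <= 1.0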

class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel

@@ -491,3 +513,236 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
    def img2img(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
One or a list of [numpy generator(s)](TODO) to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
        height = image.height
        width = image.width

        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        latents_dtype = prompt_embeds.dtype

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # prep image
        image = preprocess(image).cpu().numpy()
        image = image.astype(latents_dtype)

        # encode the init image into latents and scale the latents
        latents = self.vae_encoder(sample=image)[0]
        latents = 0.18215 * latents
        latents = latents * np.float64(self.scheduler.init_noise_sigma)

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
        timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

        # add noise to the initial latents using the timesteps
        noise = generator.randn(*latents.shape).astype(latents_dtype)
        latents = self.scheduler.add_noise(
            torch.from_numpy(latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
        )
        latents = latents.numpy()

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
        # panorama additions
        views = self.get_views(height, width)
        count = np.zeros_like(latents)
        value = np.zeros_like(latents)

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            count.fill(0)
            value.fill(0)

            for h_start, h_end, w_start, w_end in views:
                # get the latents corresponding to the current view coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

                # expand the latents if we are doing classifier free guidance
                latent_model_input = np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
                latent_model_input = latent_model_input.cpu().numpy()

                # predict the noise residual
                timestep = np.array([t], dtype=timestep_dtype)
                noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                noise_pred = noise_pred[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
                )
                latents_view_denoised = scheduler_output.prev_sample.numpy()

                value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result when using the half-precision vae decoder if batchsize>1
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )

        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)

            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
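The panorama-specific part of the denoising loop above is the view averaging: each step runs the UNet over overlapping latent windows and then merges them, which is Eq. 5 of the MultiDiffusion paper. A small self-contained sketch of that merge, with made-up sizes and stand-in scheduler outputs:

import numpy as np

# two overlapping 4x4 views over a 1x1x4x6 latent grid, merged by averaging
latents = np.zeros((1, 1, 4, 6), dtype=np.float32)
value = np.zeros_like(latents)
count = np.zeros_like(latents)

views = [(0, 4, 0, 4), (0, 4, 2, 6)]  # (h_start, h_end, w_start, w_end)
for view_id, (h_start, h_end, w_start, w_end) in enumerate(views):
    denoised = np.full((1, 1, 4, 4), float(view_id), dtype=np.float32)  # stand-in for the scheduler output
    value[:, :, h_start:h_end, w_start:w_end] += denoised
    count[:, :, h_start:h_end, w_start:w_end] += 1

merged = np.where(count > 0, value / count, value)
# columns 0-1 come only from view 0, columns 4-5 only from view 1,
# and the overlapping columns 2-3 are the average (0.5) of both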

View File

@@ -47,7 +47,9 @@ def run_loopback(
    loopback_progress = ChainProgress.from_progress(progress)

    # load img2img pipeline once
-    pipe_type = "lpw" if params.lpw() else "img2img"
+    pipe_type = params.get_valid_pipeline("img2img")
+    logger.debug("using %s pipeline for loopback", pipe_type)
    pipe = pipeline or load_pipeline(
        server,
        pipe_type,
@@ -59,8 +61,7 @@
    )

    def loopback_iteration(source: Image.Image):
-        if params.lpw():
-            logger.debug("using LPW pipeline for loopback")
+        if pipe_type in ["lpw", "panorama"]:
            rng = torch.manual_seed(params.seed)
            result = pipe.img2img(
                source,
@@ -76,7 +77,6 @@
            )
            return result.images[0]
        else:
-            logger.debug("using img2img pipeline for loopback")
            rng = np.random.RandomState(params.seed)
            result = pipe(
                params.prompt,
@@ -134,7 +134,9 @@ def run_highres(
    )

    # load img2img pipeline once
-    pipe_type = "lpw" if params.lpw() else "img2img"
+    pipe_type = params.get_valid_pipeline("img2img")
+    logger.debug("using %s pipeline for highres", pipe_type)
    highres_pipe = pipeline or load_pipeline(
        server,
        pipe_type,
@@ -172,8 +174,7 @@
        callback=highres_progress,
    )

-    if params.lpw():
-        logger.debug("using LPW pipeline for highres")
+    if pipe_type in ["lpw", "panorama"]:
        rng = torch.manual_seed(params.seed)
        result = highres_pipe.img2img(
            tile,
@@ -189,7 +190,6 @@
        )
        return result.images[0]
    else:
-        logger.debug("using img2img pipeline for highres")
        rng = np.random.RandomState(params.seed)
        result = highres_pipe(
            params.prompt,
@@ -236,7 +236,9 @@ def run_txt2img_pipeline(
    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
    prompt_pairs, loras, inversions = parse_prompt(params)

-    pipe_type = params.pipeline  # TODO: allow txt2img, panorama, etc, but filter out the others
+    pipe_type = params.get_valid_pipeline("txt2img")
+    logger.debug("using %s pipeline for txt2img", pipe_type)
    pipe = load_pipeline(
        server,
        pipe_type,
@@ -248,8 +250,7 @@
    )

    progress = job.get_progress_callback()
-    if params.lpw():
-        logger.debug("using LPW pipeline for txt2img")
+    if pipe_type in ["lpw", "panorama"]:
        rng = torch.manual_seed(params.seed)
        result = pipe.text2img(
            params.prompt,
@@ -345,9 +346,10 @@ def run_img2img_pipeline(
        logger.debug("running source filter: %s", f.__name__)
        source = f(server, source)

+    pipe_type = params.get_valid_pipeline("img2img")
    pipe = load_pipeline(
        server,
-        params.pipeline,  # this is one of the only places this can actually vary between different pipelines
+        pipe_type,
        params.model,
        params.scheduler,
        job.get_device(),
@@ -357,15 +359,17 @@
    )

    pipe_params = {}
-    if params.pipeline == "controlnet":
+    if pipe_type == "controlnet":
        pipe_params["controlnet_conditioning_scale"] = strength
-    elif params.pipeline == "img2img":
+    elif pipe_type == "img2img":
        pipe_params["strength"] = strength
-    elif params.pipeline == "pix2pix":
+    elif pipe_type == "panorama":
+        pipe_params["strength"] = strength
+    elif pipe_type == "pix2pix":
        pipe_params["image_guidance_scale"] = strength

    progress = job.get_progress_callback()
-    if params.lpw():
+    if pipe_type in ["lpw", "panorama"]:
        logger.debug("using LPW pipeline for img2img")
        rng = torch.manual_seed(params.seed)
        result = pipe.img2img(

View File

@@ -220,6 +220,27 @@ class ImageParams:
    def do_cfg(self):
        return self.cfg > 1.0

    def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
        pipeline = pipeline or self.pipeline

        # if the correct pipeline was already requested, simply use that
        if group == pipeline:
            return pipeline

        # otherwise, check for additional allowed pipelines
        if group == "img2img":
            if pipeline in ["controlnet", "lpw", "panorama", "pix2pix"]:
                return pipeline
        elif group == "inpaint":
            if pipeline in ["controlnet"]:
                return pipeline
        elif group == "txt2img":
            if pipeline in ["panorama"]:
                return pipeline

        logger.debug("pipeline %s is not valid for %s", pipeline, group)
        return group

    def lpw(self):
        return self.pipeline == "lpw"
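A quick sketch of how get_valid_pipeline resolves the user-selected pipeline for each request group (assuming params is an ImageParams whose pipeline attribute is "panorama"):

params.get_valid_pipeline("txt2img")   # -> "panorama": panorama can serve txt2img requests
params.get_valid_pipeline("img2img")   # -> "panorama": the new img2img mode added by this commit
params.get_valid_pipeline("inpaint")   # -> "inpaint": panorama is not valid here, so fall back to the group default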