
feat(api): add img2img mode to panorama pipeline

Sean Sube 2023-04-27 07:22:00 -05:00
parent 29c616d99e
commit 47d80b07b3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 297 additions and 17 deletions

View File

@@ -15,13 +15,14 @@
import inspect
from typing import Callable, List, Optional, Union
import PIL
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, logging
from diffusers.utils import deprecate, logging, PIL_INTERPOLATION
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -30,6 +31,27 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__)
def preprocess(image):
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 64 for x in (w, h))  # resize to integer multiple of 64

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
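# Illustrative note, not part of this commit: preprocess() snaps dimensions down to a multiple
# of 64, rescales pixel values to [-1, 1], and returns an NCHW torch tensor, for example:
#   preprocess(PIL.Image.new("RGB", (516, 260))).shape  # torch.Size([1, 3, 256, 512])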
class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
    vae_encoder: OnnxRuntimeModel
    vae_decoder: OnnxRuntimeModel
@@ -491,3 +513,236 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
    def img2img(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[np.ndarray, PIL.Image.Image] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[np.random.RandomState] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
    ):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
One or a list of [numpy generator(s)](TODO) to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`np.ndarray`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
        height = image.height
        width = image.width

        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        # get the latents dtype from the prompt embeddings
        latents_dtype = prompt_embeds.dtype

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # prep image
        image = preprocess(image).cpu().numpy()
        image = image.astype(latents_dtype)

        # encode the init image into latents and scale the latents
        init_latents = self.vae_encoder(sample=image)[0]
        init_latents = 0.18215 * init_latents
        init_latents = init_latents * np.float64(self.scheduler.init_noise_sigma)

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
        timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)

        # add noise to the init latents at the selected timestep
        noise = generator.randn(*init_latents.shape).astype(latents_dtype)
        init_latents = self.scheduler.add_noise(
            torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
        )
        init_latents = init_latents.numpy()

        # the noised init latents are the starting point for the panorama denoising loop
        latents = init_latents
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        timestep_dtype = next(
            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
        )
        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

        # panorama additions
        views = self.get_views(height, width)
        count = np.zeros_like(latents)
        value = np.zeros_like(latents)
        # start denoising at the timestep that matches the initial noise level chosen by `strength`
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
            count.fill(0)
            value.fill(0)

            for h_start, h_end, w_start, w_end in views:
                # get the latents corresponding to the current view coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
                )
                latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
                latent_model_input = latent_model_input.cpu().numpy()

                # predict the noise residual
                timestep = np.array([t], dtype=timestep_dtype)
                noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                noise_pred = noise_pred[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                scheduler_output = self.scheduler.step(
                    torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
                )
                latents_view_denoised = scheduler_output.prev_sample.numpy()

                value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
                count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = np.where(count > 0, value / count, value)

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)
        latents = 1 / 0.18215 * latents
        # image = self.vae_decoder(latent_sample=latents)[0]
        # it seems like there is a strange result when using the half-precision vae decoder if batch_size > 1
        image = np.concatenate(
            [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
        )

        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image), return_tensors="np"
            ).pixel_values.astype(image.dtype)

            images, has_nsfw_concept = [], []
            for i in range(image.shape[0]):
                image_i, has_nsfw_concept_i = self.safety_checker(
                    clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
                )
                images.append(image_i)
                has_nsfw_concept.append(has_nsfw_concept_i[0])
            image = np.concatenate(images)
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
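For reference, a minimal usage sketch of the new img2img entry point. This is an assumption-laden example, not part of the commit: it presumes the panorama pipeline has already been converted to ONNX and that from_pretrained resolves to the class above for that model directory; the model path, prompt, and file names are placeholders.

import numpy as np
from PIL import Image

pipe = OnnxStableDiffusionPanoramaPipeline.from_pretrained("./models/stable-diffusion-onnx")  # hypothetical path
source = Image.open("wide-source.png").convert("RGB")  # placeholder starting image
rng = np.random.RandomState(5)

result = pipe.img2img(
    prompt="a sweeping mountain panorama at sunset",
    image=source,
    strength=0.6,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=rng,
)
result.images[0].save("panorama-out.png")

The signature follows the existing ONNX img2img pipelines, so the familiar parameters (negative_prompt, eta, callback, and so on) behave the same way here.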

View File

@@ -47,7 +47,9 @@ def run_loopback(
    loopback_progress = ChainProgress.from_progress(progress)

    # load img2img pipeline once
    pipe_type = "lpw" if params.lpw() else "img2img"
    pipe_type = params.get_valid_pipeline("img2img")
    logger.debug("using %s pipeline for loopback", pipe_type)
    pipe = pipeline or load_pipeline(
        server,
        pipe_type,
@@ -59,8 +61,7 @@
    )

    def loopback_iteration(source: Image.Image):
        if params.lpw():
            logger.debug("using LPW pipeline for loopback")
        if pipe_type in ["lpw", "panorama"]:
            rng = torch.manual_seed(params.seed)
            result = pipe.img2img(
                source,
@@ -76,7 +77,6 @@ def run_loopback(
            )
            return result.images[0]
        else:
            logger.debug("using img2img pipeline for loopback")
            rng = np.random.RandomState(params.seed)
            result = pipe(
                params.prompt,
@@ -134,7 +134,9 @@ def run_highres(
    )

    # load img2img pipeline once
    pipe_type = "lpw" if params.lpw() else "img2img"
    pipe_type = params.get_valid_pipeline("img2img")
    logger.debug("using %s pipeline for highres", pipe_type)
    highres_pipe = pipeline or load_pipeline(
        server,
        pipe_type,
@@ -172,8 +174,7 @@
            callback=highres_progress,
        )

        if params.lpw():
            logger.debug("using LPW pipeline for highres")
        if pipe_type in ["lpw", "panorama"]:
            rng = torch.manual_seed(params.seed)
            result = highres_pipe.img2img(
                tile,
@@ -189,7 +190,6 @@ def run_highres(
            )
            return result.images[0]
        else:
            logger.debug("using img2img pipeline for highres")
            rng = np.random.RandomState(params.seed)
            result = highres_pipe(
                params.prompt,
@@ -236,7 +236,9 @@ def run_txt2img_pipeline(
    latents = get_latents_from_seed(params.seed, size, batch=params.batch)
    prompt_pairs, loras, inversions = parse_prompt(params)

    pipe_type = params.pipeline  # TODO: allow txt2img, panorama, etc, but filter out the others
    pipe_type = params.get_valid_pipeline("txt2img")
    logger.debug("using %s pipeline for txt2img", pipe_type)
    pipe = load_pipeline(
        server,
        pipe_type,
@@ -248,8 +250,7 @@
    )
    progress = job.get_progress_callback()

    if params.lpw():
        logger.debug("using LPW pipeline for txt2img")
    if pipe_type in ["lpw", "panorama"]:
        rng = torch.manual_seed(params.seed)
        result = pipe.text2img(
            params.prompt,
@@ -345,9 +346,10 @@ def run_img2img_pipeline(
        logger.debug("running source filter: %s", f.__name__)
        source = f(server, source)

    pipe_type = params.get_valid_pipeline("img2img")
    pipe = load_pipeline(
        server,
        params.pipeline,  # this is one of the only places this can actually vary between different pipelines
        pipe_type,
        params.model,
        params.scheduler,
        job.get_device(),
@@ -357,15 +359,17 @@
    )

    pipe_params = {}
    if params.pipeline == "controlnet":
    if pipe_type == "controlnet":
        pipe_params["controlnet_conditioning_scale"] = strength
    elif params.pipeline == "img2img":
    elif pipe_type == "img2img":
        pipe_params["strength"] = strength
    elif params.pipeline == "pix2pix":
    elif pipe_type == "panorama":
        pipe_params["strength"] = strength
    elif pipe_type == "pix2pix":
        pipe_params["image_guidance_scale"] = strength

    progress = job.get_progress_callback()

    if params.lpw():
    if pipe_type in ["lpw", "panorama"]:
        logger.debug("using LPW pipeline for img2img")
        rng = torch.manual_seed(params.seed)
        result = pipe.img2img(

View File

@@ -220,6 +220,27 @@ class ImageParams:
    def do_cfg(self):
        return self.cfg > 1.0

    def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
        pipeline = pipeline or self.pipeline

        # if the correct pipeline was already requested, simply use that
        if group == pipeline:
            return pipeline

        # otherwise, check for additional allowed pipelines
        if group == "img2img":
            if pipeline in ["controlnet", "lpw", "panorama", "pix2pix"]:
                return pipeline
        elif group == "inpaint":
            if pipeline in ["controlnet"]:
                return pipeline
        elif group == "txt2img":
            if pipeline in ["panorama"]:
                return pipeline

        logger.debug("pipeline %s is not valid for %s", pipeline, group)
        return group

    def lpw(self):
        return self.pipeline == "lpw"
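A quick sketch of how get_valid_pipeline resolves a requested pipeline against a task group, following the branches above. The construction of ImageParams is elided here since its arguments are not shown in this diff; assume params is an existing instance.

# with params.pipeline == "panorama"
params.get_valid_pipeline("img2img")   # -> "panorama" (allowed as an img2img pipeline)
params.get_valid_pipeline("txt2img")   # -> "panorama" (allowed as a txt2img pipeline)
params.get_valid_pipeline("inpaint")   # -> "inpaint"  (not allowed, falls back to the group)

# with params.pipeline == "pix2pix"
params.get_valid_pipeline("txt2img")   # -> "txt2img"  (only panorama may substitute for txt2img)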