feat(api): add img2img mode to panorama pipeline
This commit is contained in:
parent
29c616d99e
commit
47d80b07b3
|
@ -15,13 +15,14 @@
|
|||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import deprecate, logging
|
||||
from diffusers.utils import deprecate, logging, PIL_INTERPOLATION
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
@ -30,6 +31,27 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess(image):
    """
    Convert an input image (or batch of images) into a single tensor.

    A `torch.Tensor` is returned untouched. A single `PIL.Image.Image` is treated as a one-element list. A list
    of PIL images is resized to the first image's size rounded down to a multiple of 64 on each side, stacked
    along a new leading axis, scaled from `[0, 255]` to `[-1.0, 1.0]` as float32, and has its last axis moved to
    position 1 before being wrapped in a tensor. A list of tensors is concatenated along dimension 0. Any other
    input is returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image

    if isinstance(image, PIL.Image.Image):
        image = [image]

    first = image[0]
    if isinstance(first, PIL.Image.Image):
        width, height = first.size
        # resize to integer multiple of 64
        width -= width % 64
        height -= height % 64

        frames = [
            np.array(frame.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
            for frame in image
        ]
        batch = np.array(np.concatenate(frames, axis=0)).astype(np.float32) / 255.0
        batch = batch.transpose(0, 3, 1, 2)
        image = torch.from_numpy(2.0 * batch - 1.0)
    elif isinstance(first, torch.Tensor):
        image = torch.cat(image, dim=0)

    return image
|
||||
|
||||
|
||||
class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||
vae_encoder: OnnxRuntimeModel
|
||||
vae_decoder: OnnxRuntimeModel
|
||||
|
@ -491,3 +513,236 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
|
|||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
def img2img(
    self,
    prompt: Union[str, List[str]] = None,
    image: Union[np.ndarray, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[np.random.RandomState] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
    callback_steps: int = 1,
):
    r"""
    Run the panorama pipeline in img2img mode, starting from an existing image instead of pure noise.

    The source image is encoded into latents, noised to the level selected by `strength`, and then denoised one
    view at a time; the per-view results are averaged where views overlap after every scheduler step.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
            `Image`, or tensor representing an image batch, that will be used as the starting point for the
            process. PIL images are resized by `preprocess` before being encoded.
        strength (`float`, *optional*, defaults to 0.8):
            How much to transform the source image, between 0 and 1. Larger values add more noise and run more
            of the `num_inference_steps` denoising steps.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference. The number of steps actually run is reduced by `strength`.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
            is less than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only passed to
            schedulers whose `step` method accepts it.
        generator (`np.random.RandomState`, *optional*):
            Random source used to sample the noise added to the encoded image. Defaults to the global
            `np.random` state. A `torch.Generator` is also accepted.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
            plain tuple.
        callback (`Callable`, *optional*):
            A function that will be called every `callback_steps` steps during inference. The function will be
            called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.

    Returns:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
        When returning a tuple, the first element is a list with the generated images, and the second element is a
        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
        (nsfw) content, according to the `safety_checker` (`None` when no safety checker is configured).

    Raises:
        ValueError: If `image` is missing or `strength` is outside `[0.0, 1.0]`.
    """
    if image is None:
        raise ValueError("`image` input cannot be undefined.")

    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    # Preprocess before measuring: `preprocess` may resize PIL input, and lists/tensors have no
    # `.height`/`.width`, so the working size is taken from the prepared batch instead.
    image = preprocess(image).cpu().numpy()
    height, width = image.shape[-2:]

    # check inputs. Raise error if not correct
    self.check_inputs(
        prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
    )

    # define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if generator is None:
        generator = np.random

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_embeds = self._encode_prompt(
        prompt,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # the latents follow the precision of the text embeddings
    latents_dtype = prompt_embeds.dtype
    image = image.astype(latents_dtype)

    # set timesteps
    self.scheduler.set_timesteps(num_inference_steps)

    # encode the init image into latents and scale the latents
    # NOTE: the latents are deliberately not multiplied by `init_noise_sigma` here; the noise level is
    # introduced by `scheduler.add_noise` below.
    latents = self.vae_encoder(sample=image)[0]
    latents = 0.18215 * latents

    # get the original timestep using init_timestep
    offset = self.scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    noise_timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
    noise_timesteps = np.array([noise_timesteps] * batch_size * num_images_per_prompt)

    # add noise to the encoded image at the selected timestep
    if isinstance(generator, torch.Generator):
        # NOTE(review): callers appear to build the panorama generator with `torch.manual_seed`,
        # which has no `randn` method — confirm what is actually passed as `generator`.
        noise = torch.randn(latents.shape, generator=generator, device=generator.device).cpu().numpy()
    else:
        noise = generator.randn(*latents.shape)
    noise = noise.astype(latents_dtype)

    latents = self.scheduler.add_noise(
        torch.from_numpy(latents), torch.from_numpy(noise), torch.from_numpy(noise_timesteps)
    )
    latents = latents.numpy()

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # use the timestep dtype declared by the ONNX UNet, defaulting to float
    timestep_dtype = next(
        (model_input.type for model_input in self.unet.model.get_inputs() if model_input.name == "timestep"),
        "tensor(float)",
    )
    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

    # only run the part of the schedule that matches the noise level added above
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    timesteps = self.scheduler.timesteps[t_start:]

    # panorama additions
    views = self.get_views(height, width)
    count = np.zeros_like(latents)
    value = np.zeros_like(latents)

    for i, t in enumerate(self.progress_bar(timesteps)):
        count.fill(0)
        value.fill(0)

        for h_start, h_end, w_start, w_end in views:
            # get the latents corresponding to the current view coordinates
            latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                np.concatenate([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view
            )
            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
            latent_model_input = latent_model_input.cpu().numpy()

            # predict the noise residual
            timestep = np.array([t], dtype=timestep_dtype)
            noise_pred = self.unet(
                sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred), t, torch.from_numpy(latents_for_view), **extra_step_kwargs
            )
            latents_view_denoised = scheduler_output.prev_sample.numpy()

            value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
            count[:, :, h_start:h_end, w_start:w_end] += 1

        # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
        latents = np.where(count > 0, value / count, value)

        # call the callback, if provided
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)

    latents = 1 / 0.18215 * latents
    # decode one sample at a time:
    # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
    image = np.concatenate(
        [self.vae_decoder(latent_sample=latents[idx : idx + 1])[0] for idx in range(latents.shape[0])]
    )

    image = np.clip(image / 2 + 0.5, 0, 1)
    image = image.transpose((0, 2, 3, 1))

    if self.safety_checker is not None:
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="np"
        ).pixel_values.astype(image.dtype)

        # the safety checker is also run one image at a time
        images, has_nsfw_concept = [], []
        for idx in range(image.shape[0]):
            image_i, has_nsfw_concept_i = self.safety_checker(
                clip_input=safety_checker_input[idx : idx + 1], images=image[idx : idx + 1]
            )
            images.append(image_i)
            has_nsfw_concept.append(has_nsfw_concept_i[0])
        image = np.concatenate(images)
    else:
        has_nsfw_concept = None

    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
|
|
@ -47,7 +47,9 @@ def run_loopback(
|
|||
loopback_progress = ChainProgress.from_progress(progress)
|
||||
|
||||
# load img2img pipeline once
|
||||
pipe_type = "lpw" if params.lpw() else "img2img"
|
||||
pipe_type = params.get_valid_pipeline("img2img")
|
||||
logger.debug("using %s pipeline for loopback", pipe_type)
|
||||
|
||||
pipe = pipeline or load_pipeline(
|
||||
server,
|
||||
pipe_type,
|
||||
|
@ -59,8 +61,7 @@ def run_loopback(
|
|||
)
|
||||
|
||||
def loopback_iteration(source: Image.Image):
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for loopback")
|
||||
if pipe_type in ["lpw", "panorama"]:
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
source,
|
||||
|
@ -76,7 +77,6 @@ def run_loopback(
|
|||
)
|
||||
return result.images[0]
|
||||
else:
|
||||
logger.debug("using img2img pipeline for loopback")
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
|
@ -134,7 +134,9 @@ def run_highres(
|
|||
)
|
||||
|
||||
# load img2img pipeline once
|
||||
pipe_type = "lpw" if params.lpw() else "img2img"
|
||||
pipe_type = params.get_valid_pipeline("img2img")
|
||||
logger.debug("using %s pipeline for highres", pipe_type)
|
||||
|
||||
highres_pipe = pipeline or load_pipeline(
|
||||
server,
|
||||
pipe_type,
|
||||
|
@ -172,8 +174,7 @@ def run_highres(
|
|||
callback=highres_progress,
|
||||
)
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for highres")
|
||||
if pipe_type in ["lpw", "panorama"]:
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = highres_pipe.img2img(
|
||||
tile,
|
||||
|
@ -189,7 +190,6 @@ def run_highres(
|
|||
)
|
||||
return result.images[0]
|
||||
else:
|
||||
logger.debug("using img2img pipeline for highres")
|
||||
rng = np.random.RandomState(params.seed)
|
||||
result = highres_pipe(
|
||||
params.prompt,
|
||||
|
@ -236,7 +236,9 @@ def run_txt2img_pipeline(
|
|||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
prompt_pairs, loras, inversions = parse_prompt(params)
|
||||
|
||||
pipe_type = params.pipeline # TODO: allow txt2img, panorama, etc, but filter out the others
|
||||
pipe_type = params.get_valid_pipeline("txt2img")
|
||||
logger.debug("using %s pipeline for txt2img", pipe_type)
|
||||
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
pipe_type,
|
||||
|
@ -248,8 +250,7 @@ def run_txt2img_pipeline(
|
|||
)
|
||||
progress = job.get_progress_callback()
|
||||
|
||||
if params.lpw():
|
||||
logger.debug("using LPW pipeline for txt2img")
|
||||
if pipe_type in ["lpw", "panorama"]:
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.text2img(
|
||||
params.prompt,
|
||||
|
@ -345,9 +346,10 @@ def run_img2img_pipeline(
|
|||
logger.debug("running source filter: %s", f.__name__)
|
||||
source = f(server, source)
|
||||
|
||||
pipe_type = params.get_valid_pipeline("img2img")
|
||||
pipe = load_pipeline(
|
||||
server,
|
||||
params.pipeline, # this is one of the only places this can actually vary between different pipelines
|
||||
pipe_type,
|
||||
params.model,
|
||||
params.scheduler,
|
||||
job.get_device(),
|
||||
|
@ -357,15 +359,17 @@ def run_img2img_pipeline(
|
|||
)
|
||||
|
||||
pipe_params = {}
|
||||
if params.pipeline == "controlnet":
|
||||
if pipe_type == "controlnet":
|
||||
pipe_params["controlnet_conditioning_scale"] = strength
|
||||
elif params.pipeline == "img2img":
|
||||
elif pipe_type == "img2img":
|
||||
pipe_params["strength"] = strength
|
||||
elif params.pipeline == "pix2pix":
|
||||
elif pipe_type == "panorama":
|
||||
pipe_params["strength"] = strength
|
||||
elif pipe_type == "pix2pix":
|
||||
pipe_params["image_guidance_scale"] = strength
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
if params.lpw():
|
||||
if pipe_type in ["lpw", "panorama"]:
|
||||
logger.debug("using LPW pipeline for img2img")
|
||||
rng = torch.manual_seed(params.seed)
|
||||
result = pipe.img2img(
|
||||
|
|
|
@ -220,6 +220,27 @@ class ImageParams:
|
|||
def do_cfg(self):
|
||||
return self.cfg > 1.0
|
||||
|
||||
def get_valid_pipeline(self, group: str, pipeline: str = None) -> str:
    """
    Pick the pipeline to load for a request in the given group.

    Args:
        group: The kind of request being run (`"img2img"`, `"inpaint"` or `"txt2img"`).
        pipeline: The requested pipeline; falls back to `self.pipeline` when empty.

    Returns:
        The requested pipeline when it equals the group or is one of the extra pipelines allowed for that
        group; otherwise the group name itself, after logging the mismatch at debug level.
    """
    pipeline = pipeline or self.pipeline

    # additional pipelines that may serve each group, beyond the group's own name
    # "lpw" is allowed for txt2img because the txt2img runner branches on it to call `text2img`
    allowed = {
        "img2img": ("controlnet", "lpw", "panorama", "pix2pix"),
        "inpaint": ("controlnet",),
        "txt2img": ("lpw", "panorama"),
    }

    # if the correct pipeline was already requested, or it is allowed for this group, simply use that
    if pipeline == group or pipeline in allowed.get(group, ()):
        return pipeline

    logger.debug("pipeline %s is not valid for %s", pipeline, group)
    return group
|
||||
|
||||
def lpw(self):
    """Report whether the configured pipeline is `"lpw"`."""
    is_lpw = self.pipeline == "lpw"
    return is_lpw
|
||||
|
||||
|
|
Loading…
Reference in New Issue