diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 08df9462..603f66e7 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -17,6 +17,7 @@ from ..utils import ( ServerContext, ) +import numpy as np import torch logger = getLogger(__name__) @@ -62,12 +63,15 @@ def upscale_stable_diffusion( logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) pipeline = load_stable_diffusion(ctx, upscale) - generator = torch.manual_seed(params.seed) - seed = generator.initial_seed() + + if upscale.format == 'onnx': + generator = np.random.default_rng(params.seed) + else: + generator = torch.manual_seed(params.seed) return pipeline( params.prompt, source, - generator=torch.manual_seed(seed), + generator=generator, num_inference_steps=params.steps, ).images[0] diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 93c983af..8bb5fd46 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -3,10 +3,17 @@ from diffusers import ( OnnxRuntimeModel, StableDiffusionUpscalePipeline, ) +from diffusers.pipeline_utils import ( + ImagePipelineOutput, +) from logging import getLogger +from PIL import Image from typing import ( Any, + Callable, List, + Optional, + Union, ) import numpy as np @@ -15,6 +22,27 @@ import torch logger = getLogger(__name__) +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, Image.Image): + image = [image] + + if isinstance(image[0], Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + return image + class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): def __init__( @@ -32,10 +60,129 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): def __call__( self, - *args, - **kwargs, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, Image.Image, List[Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, ): - super().__call__(*args, **kwargs) + # 1. Check inputs + self.check_inputs(prompt, image, noise_level, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = preprocess(image) + image = image.cpu() #.numpy() + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + print('text embedding dtype', text_embeddings.dtype) + + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = generator.random(size=image.shape, dtype=text_embeddings.dtype) + noise = torch.from_numpy(noise).to(device) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = np.concatenate([image] * batch_multiplier * num_images_per_prompt) + noise_level = np.concatenate([noise_level] * image.shape[0]) + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels # TODO: config + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: # TODO: config + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = np.concatenate([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=np.float32) + image = self.decode_latents(latents.float()) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1