From bacce0ace9e67270dfe65b31a8ae18196c142009 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 29 Jan 2023 18:42:05 -0600 Subject: [PATCH] testing ONNX upscaling --- .../chain/upscale_stable_diffusion.py | 6 +-- api/onnx_web/convert.py | 38 +++++++++------- .../pipeline_onnx_stable_diffusion_upscale.py | 43 +++++++++++++------ api/onnx_web/logging.py | 1 - api/onnx_web/serve.py | 6 +-- onnx-web.code-workspace | 4 ++ 6 files changed, 58 insertions(+), 40 deletions(-) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 603f66e7..e4749918 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -63,11 +63,7 @@ def upscale_stable_diffusion( logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) pipeline = load_stable_diffusion(ctx, upscale) - - if upscale.format == 'onnx': - generator = np.random.default_rng(params.seed) - else: - generator = torch.manual_seed(params.seed) + generator = torch.manual_seed(params.seed) return pipeline( params.prompt, diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 2e390425..b90a4c00 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -253,7 +253,17 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si ) del pipeline.text_encoder + logger.info('UNET config: %s', pipeline.unet.config) + # UNET + if single_vae: + unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"] + # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"] + unet_scale = 4 + else: + unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"] + unet_scale = False + unet_in_channels = pipeline.unet.config.in_channels unet_sample_size = pipeline.unet.config.sample_size unet_path = output_path / "unet" / "model.onnx" @@ -265,13 +275,10 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si torch.randn(2).to(device=training_device, dtype=dtype), torch.randn(2, num_tokens, text_hidden_size).to( device=training_device, dtype=dtype), - # TODO: needs to be Int or Long for upscaling, Bool for regular - 4, - # False, + unet_scale, ), output_path=unet_path, - ordered_input_names=["sample", "timestep", - "encoder_hidden_states", "return_dict"], + ordered_input_names=unet_inputs, # has to be different from "sample" for correct tracing output_names=["out_sample"], dynamic_axes={ @@ -300,25 +307,26 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si del pipeline.unet if single_vae: + logger.info('VAE config: %s', pipeline.vae.config) + # SINGLE VAE vae_only = pipeline.vae - vae_in_channels = vae_only.config.in_channels - vae_sample_size = vae_only.config.sample_size - # need to get the raw tensor output (sample) from the encoder - vae_only.forward = lambda sample, return_dict: vae_only.encode( - sample, return_dict)[0].sample() + vae_latent_channels = vae_only.config.latent_channels + vae_out_channels = vae_only.config.out_channels + # forward only through the decoder part + vae_only.forward = vae_only.decode onnx_export( vae_only, model_args=( - torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to( + torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to( device=training_device, dtype=dtype), False, ), output_path=output_path / "vae" / "model.onnx", - ordered_input_names=["sample", "return_dict"], - output_names=["latent_sample"], + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], dynamic_axes={ - "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, }, opset=opset, ) @@ -435,7 +443,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si logger.info('exporting ONNX model') onnx_pipeline.save_pretrained(output_path) - logger.info("ONNX pipeline saved to", output_path) + logger.info("ONNX pipeline saved to %s", output_path) del pipeline del onnx_pipeline diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 8bb5fd46..17137e5a 100644 --- a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -21,6 +21,8 @@ import torch logger = getLogger(__name__) +num_channels_latents = 4 # self.vae.config.latent_channels +unet_in_channels = 7 # self.unet.config.in_channels def preprocess(image): if isinstance(image, torch.Tensor): @@ -68,7 +70,8 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + # generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -100,11 +103,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): timesteps = self.scheduler.timesteps # 5. Add noise to image - print('text embedding dtype', text_embeddings.dtype) + # print('text embedding dtype', text_embeddings.dtype) + text_embeddings_dtype = torch.float32 noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = generator.random(size=image.shape, dtype=text_embeddings.dtype) - noise = torch.from_numpy(noise).to(device) + noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 @@ -113,13 +116,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 6. Prepare latent variables height, width = image.shape[2:] - num_channels_latents = self.vae.config.latent_channels # TODO: config latents = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, width, - text_embeddings.dtype, + text_embeddings_dtype, device, generator, latents, @@ -127,10 +129,10 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 7. Check that sizes of image and latents match num_channels_image = image.shape[1] - if num_channels_latents + num_channels_image != self.unet.config.in_channels: # TODO: config + if num_channels_latents + num_channels_image != unet_in_channels: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f"Incorrect configuration settings! The config of `pipeline.unet` expects" + f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" `num_channels_image`: {num_channels_image} " f" = {num_channels_latents+num_channels_image}. Please verify the config of" " `pipeline.unet` or your `image` input." @@ -148,16 +150,23 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = np.concatenate([latent_model_input, image], dim=1) + latent_model_input = np.concatenate([latent_model_input, image], axis=1) + + # timestep to tensor + timestep = np.array([t], dtype=np.float32) # predict the noise residual + # print('noise pred unet', latent_model_input.dtype, text_embeddings.dtype, t, noise_level) noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level - ).sample + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=text_embeddings, + class_labels=noise_level + )[0] # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) # noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 @@ -171,7 +180,7 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 10. Post-processing # make sure the VAE is in float32 mode, as it overflows in float16 - self.vae.to(dtype=np.float32) + # self.vae.to(dtype=np.float32) image = self.decode_latents(latents.float()) # 11. Convert to PIL @@ -183,6 +192,12 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): return ImagePipelineOutput(images=image) + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae(latent_sample=latents)[0] + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + return image def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): batch_size = len(prompt) if isinstance(prompt, list) else 1 diff --git a/api/onnx_web/logging.py b/api/onnx_web/logging.py index b32f93bb..87117fa0 100644 --- a/api/onnx_web/logging.py +++ b/api/onnx_web/logging.py @@ -10,7 +10,6 @@ try: if path.exists(logging_path): with open(logging_path, 'r') as f: config_logging = safe_load(f) - # print('configuring logger', config_logging) dictConfig(config_logging) except Exception as err: print('error loading logging config: %s' % (err)) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index bc1198f9..7a36aa8f 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -575,11 +575,9 @@ def chain(): pipeline = ChainPipeline() for stage_data in data.get('stages', []): - logger.info('request stage: %s', stage_data) - callback = chain_stages[stage_data.get('type')] kwargs = stage_data.get('params', {}) - print('stage', callback.__name__, kwargs) + logger.info('request stage: %s, %s', callback.__name__, kwargs) stage = StageParams( stage_data.get('name', callback.__name__), @@ -587,12 +585,10 @@ def chain(): outscale=get_and_clamp_int(kwargs,'outscale', 1, 4), ) - # TODO: create Border from border if 'border' in kwargs: border = Border.even(int(kwargs.get('border'))) kwargs['border'] = border - # TODO: create Upscale from upscale if 'upscale' in kwargs: upscale = UpscaleParams(kwargs.get('upscale'), params.provider) kwargs['upscale'] = upscale diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index e73f770c..504a5b13 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -20,6 +20,7 @@ "denoise", "denoising", "directml", + "dtype", "ESRGAN", "ftfy", "gfpgan", @@ -44,6 +45,7 @@ "pndm", "pretrained", "protobuf", + "randn", "realesr", "resrgan", "RRDB", @@ -56,6 +58,8 @@ "spinalcase", "stabilityai", "stringcase", + "timestep", + "timesteps", "uncond", "unet", "untruncated",