From c8aad8554ed39a564b94c72a163e374b5a496e3c Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Tue, 21 Mar 2023 22:05:14 -0500
Subject: [PATCH] fix(api): convert back to model format after blending,
 convert samples as needed (#274)

---
 api/onnx_web/convert/diffusion/diffusers.py   | 18 +++++++++---------
 api/onnx_web/convert/diffusion/lora.py        |  2 +-
 .../convert/diffusion/textual_inversion.py    | 10 +++++-----
 api/onnx_web/diffusers/load.py                | 14 +++++++-------
 api/onnx_web/server/context.py                |  6 ------
 5 files changed, 22 insertions(+), 28 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index 92ae68dc..6c434a2d 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -97,8 +97,8 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
 
-    torch_dtype = ctx.torch_dtype()
-    logger.debug("using Torch dtype %s for pipeline", torch_dtype)
+    dtype = ctx.torch_dtype()
+    logger.debug("using Torch dtype %s for pipeline", dtype)
 
     dest_path = path.join(ctx.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
@@ -117,7 +117,7 @@
 
     pipeline = StableDiffusionPipeline.from_pretrained(
         source,
-        torch_dtype=torch_dtype,
+        torch_dtype=dtype,
         use_auth_token=ctx.token,
     ).to(ctx.training_device)
     output_path = Path(dest_path)
@@ -174,11 +174,11 @@
         pipeline.unet,
         model_args=(
             torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
-            torch.randn(2).to(device=ctx.training_device, dtype=torch_dtype),
+            torch.randn(2).to(device=ctx.training_device, dtype=dtype),
             torch.randn(2, num_tokens, text_hidden_size).to(
-                device=ctx.training_device, dtype=torch_dtype
+                device=ctx.training_device, dtype=dtype
             ),
             unet_scale,
         ),
@@ -230,7 +230,7 @@
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=torch_dtype),
+                ).to(device=ctx.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae" / "model.onnx",
@@ -255,7 +255,7 @@
             vae_encoder,
             model_args=(
                 torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
-                    device=ctx.training_device, dtype=torch_dtype
+                    device=ctx.training_device, dtype=dtype
                 ),
                 False,
             ),
@@ -279,7 +279,7 @@
             model_args=(
                 torch.randn(
                     1, vae_latent_channels, unet_sample_size, unet_sample_size
-                ).to(device=ctx.training_device, dtype=torch_dtype),
+                ).to(device=ctx.training_device, dtype=dtype),
                 False,
             ),
             output_path=output_path / "vae_decoder" / "model.onnx",
diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index c1e8cf11..da6f06b2 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -62,7 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.torch_dtype()
+    dtype = torch.float32
 
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 519ea3f3..f96c7a32 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = context.numpy_dtype()
+    dtype = np.float32
 
     embeds = {}
     for name, weight, base_token, inversion_format in inversions:
@@ -131,11 +131,11 @@
             for n in text_encoder.graph.initializer
             if n.name == "text_model.embeddings.token_embedding.weight"
         ][0]
-        embedding_weights = numpy_helper.to_array(embedding_node)
+        base_weights = numpy_helper.to_array(embedding_node)
 
-        weights_dim = embedding_weights.shape[1]
+        weights_dim = base_weights.shape[1]
         zero_weights = np.zeros((num_added_tokens, weights_dim))
-        embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
+        embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
 
         for token, weights in embeds.items():
             token_id = tokenizer.convert_tokens_to_ids(token)
@@ -149,7 +149,7 @@
                 == "text_model.embeddings.token_embedding.weight"
             ):
                 new_initializer = numpy_helper.from_array(
-                    embedding_weights.astype(dtype), embedding_node.name
+                    embedding_weights.astype(base_weights.dtype), embedding_node.name
                 )
                 logger.trace("new initializer data type: %s", new_initializer.data_type)
                 del text_encoder.graph.initializer[i]
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 08f3cbce..5a795794 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -360,15 +360,15 @@ class UNetWrapper(object):
         global timestep_dtype
         timestep_dtype = timestep.dtype
 
-        logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if "onnx-fp16" in self.server.optimizations:
-            logger.info("converting UNet sample to ONNX fp16")
-            sample = sample.astype(np.float16)
-            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
-        elif sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample to timestep dtype")
+        logger.trace("UNet parameter types: %s, %s, %s", sample.dtype, timestep.dtype, encoder_hidden_states.dtype)
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)
 
+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+
         return self.wrapped(
             sample=sample,
             timestep=timestep,
diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py
index fa264ca5..773abef9 100644
--- a/api/onnx_web/server/context.py
+++ b/api/onnx_web/server/context.py
@@ -86,9 +86,3 @@ class ServerContext:
             return torch.float16
         else:
             return torch.float32
-
-    def numpy_dtype(self):
-        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
-            return np.float16
-        else:
-            return np.float32
\ No newline at end of file
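
Note: the blending changes above share one pattern: do the math in float32 on the CPU, then convert back to whatever dtype the ONNX graph actually stores before writing the initializer, so an fp16 model stays fp16. A minimal sketch of that round-trip, using a hypothetical replace_embedding_weights helper (not part of this patch) around the same onnx.numpy_helper calls:

import numpy as np
from onnx import ModelProto, numpy_helper


def replace_embedding_weights(text_encoder: ModelProto, new_rows: np.ndarray) -> None:
    # locate the token embedding initializer in the text encoder graph
    index, node = next(
        (i, n)
        for i, n in enumerate(text_encoder.graph.initializer)
        if n.name == "text_model.embeddings.token_embedding.weight"
    )
    base_weights = numpy_helper.to_array(node)

    # blend in float32, regardless of what the graph stores (it may be fp16)
    blended = np.concatenate((base_weights.astype(np.float32), new_rows), axis=0)

    # convert back to the model's own format before writing the initializer
    new_initializer = numpy_helper.from_array(
        blended.astype(base_weights.dtype), node.name
    )
    del text_encoder.graph.initializer[index]
    text_encoder.graph.initializer.insert(index, new_initializer)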
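
The load.py change follows from the same mismatch on the sampling side: once the blended models keep their original format, ONNX Runtime will reject UNet inputs whose dtypes disagree. A rough standalone sketch of the alignment the wrapper performs, assuming wrapped is the ONNX UNet callable (names as in load.py, behavior slightly generalized):

import numpy as np


class UNetWrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, sample, timestep, encoder_hidden_states, **kwargs):
        # ONNX Runtime requires consistent input dtypes; follow the timestep
        if sample.dtype != timestep.dtype:
            sample = sample.astype(timestep.dtype)
        if encoder_hidden_states.dtype != timestep.dtype:
            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            **kwargs,
        )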