diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index 2df3457b..92ae68dc 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -97,8 +97,7 @@ def convert_diffusion_diffusers(
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")

-    torch_half = "torch-fp16" in ctx.optimizations
-    torch_dtype = torch.float16 if torch_half else torch.float32
+    torch_dtype = ctx.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", torch_dtype)

     dest_path = path.join(ctx.model_path, name)
diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 4a8cd3db..c1e8cf11 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -62,6 +62,7 @@ def blend_loras(
 ):
     # always load to CPU for blending
     device = torch.device("cpu")
+    dtype = context.torch_dtype()

     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
@@ -88,11 +89,11 @@ def blend_loras(
                     "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
                 )

-                down_weight = lora_model[key].to(dtype=torch.float32)
-                up_weight = lora_model[up_key].to(dtype=torch.float32)
+                down_weight = lora_model[key].to(dtype=dtype)
+                up_weight = lora_model[up_key].to(dtype=dtype)

                 dim = down_weight.size()[0]
-                alpha = lora_model.get(alpha_key, dim).to(torch.float32).numpy()
+                alpha = lora_model.get(alpha_key, dim).to(dtype).numpy()

                 try:
                     if len(up_weight.size()) == 2:
@@ -203,7 +204,7 @@ def blend_loras(
            logger.trace("blended weight shape: %s", blended.shape)

            # replace the original initializer
-            updated_node = numpy_helper.from_array(blended, weight_node.name)
+            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), weight_node.name)
            del base_model.graph.initializer[weight_idx]
            base_model.graph.initializer.insert(weight_idx, updated_node)
        elif matmul_key in fixed_node_names:
@@ -232,7 +233,7 @@ def blend_loras(
            logger.trace("blended weight shape: %s", blended.shape)

            # replace the original initializer
-            updated_node = numpy_helper.from_array(blended, matmul_node.name)
+            updated_node = numpy_helper.from_array(blended.astype(base_weights.dtype), matmul_node.name)
            del base_model.graph.initializer[matmul_idx]
            base_model.graph.initializer.insert(matmul_idx, updated_node)
        else:
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 3650ee55..519ea3f3 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -22,7 +22,7 @@ def blend_textual_inversions(
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
-    dtype = np.float
+    dtype = context.numpy_dtype()
     embeds = {}

     for name, weight, base_token, inversion_format in inversions:
@@ -149,7 +149,7 @@ def blend_textual_inversions(
            == "text_model.embeddings.token_embedding.weight"
        ):
            new_initializer = numpy_helper.from_array(
-                embedding_weights.astype(np.float32), embedding_node.name
+                embedding_weights.astype(dtype), embedding_node.name
            )
            logger.trace("new initializer data type: %s", new_initializer.data_type)
            del text_encoder.graph.initializer[i]
diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py
index 1fa73ba2..fa264ca5 100644
--- a/api/onnx_web/server/context.py
+++ b/api/onnx_web/server/context.py
@@ -2,6 +2,9 @@ from logging import getLogger
 from os import environ, path
 from typing import List, Optional

+import torch
+import numpy as np
+
 from ..utils import get_boolean
 from .model_cache import ModelCache

@@ -77,3 +80,15 @@ class ServerContext:
             job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
             memory_limit=memory_limit,
         )
+
+    def torch_dtype(self):
+        if "torch-fp16" in self.optimizations:
+            return torch.float16
+        else:
+            return torch.float32
+
+    def numpy_dtype(self):
+        if "torch-fp16" in self.optimizations or "onnx-fp16" in self.optimizations:
+            return np.float16
+        else:
+            return np.float32
\ No newline at end of file
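Note: the dtype selection added to ServerContext above reduces to a small mapping from the optimization flags to framework dtypes. As the new helpers show, "torch-fp16" drives both the Torch dtype and the NumPy dtype, while "onnx-fp16" only switches the NumPy dtype. The standalone sketch below reproduces that mapping for reference; pick_dtypes is an illustrative name and not part of onnx-web, which exposes this logic as ServerContext.torch_dtype() and ServerContext.numpy_dtype().

# Standalone sketch of the dtype selection introduced in ServerContext.
# pick_dtypes is an illustrative helper, not an onnx-web API.
from typing import List, Tuple

import numpy as np
import torch


def pick_dtypes(optimizations: List[str]) -> Tuple[torch.dtype, np.dtype]:
    # "torch-fp16" switches the Torch side to half precision.
    torch_dtype = torch.float16 if "torch-fp16" in optimizations else torch.float32

    # Either fp16 flag switches the NumPy dtype used when writing ONNX initializers.
    if "torch-fp16" in optimizations or "onnx-fp16" in optimizations:
        numpy_dtype = np.dtype(np.float16)
    else:
        numpy_dtype = np.dtype(np.float32)

    return torch_dtype, numpy_dtype


if __name__ == "__main__":
    for opts in ([], ["onnx-fp16"], ["torch-fp16"]):
        print(opts, pick_dtypes(opts))

In the LoRA path, the same context is used twice: the LoRA tensors are cast to context.torch_dtype() before blending, and the blended result is cast back to the base initializer's dtype (blended.astype(base_weights.dtype)) so the updated ONNX initializer keeps the precision of the graph it replaces.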