From 33008531e976e6aa61750609d0e46efefc555ff4 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 19 Mar 2023 15:13:54 -0500
Subject: [PATCH] fix(api): load blending tensors onto CPU
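
Blending always happens on the CPU, so load LoRA and Textual Inversion
tensors through the shared load_tensor() helper with a CPU
map_location, instead of calling safetensors load_file() or
torch.load() directly at the call sites. load_tensor() now normalizes
the file extension before dispatching, falls back from safetensors to
Torch JIT to plain torch.load(), and unwraps the state_dict key
whenever one is present. Also export blend_loras and
blend_textual_inversions from the package root.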
---
 api/onnx_web/__init__.py                   |  2 ++
 api/onnx_web/convert/diffusion/lora.py     |  7 +++--
 .../convert/diffusion/textual_inversion.py | 12 +++++----
 api/onnx_web/convert/utils.py              | 27 ++++++++++---------
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py
index 21173d70..6797bc77 100644
--- a/api/onnx_web/__init__.py
+++ b/api/onnx_web/__init__.py
@@ -5,6 +5,8 @@ from .chain import (
     upscale_resrgan,
     upscale_stable_diffusion,
 )
+from .convert.diffusion.lora import blend_loras
+from .convert.diffusion.textual_inversion import blend_textual_inversions
 from .diffusers.load import get_latents_from_seed, load_pipeline, optimize_pipeline
 from .diffusers.run import (
     run_blend_pipeline,
diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index fd99afbc..6877c298 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -16,7 +16,7 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
 from ...server.context import ServerContext
-from ..utils import ConversionContext
+from ..utils import ConversionContext, load_tensor
 
 logger = getLogger(__name__)
 
@@ -61,8 +61,11 @@ def blend_loras(
     loras: List[Tuple[str, float]],
     model_type: Literal["text_encoder", "unet"],
 ):
+    # always load to CPU for blending
+    device = torch.device("cpu")
+
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_models = [load_file(name) for name, _weight in loras]
+    lora_models = [load_tensor(name, map_location=device) for name, _weight in loras]
 
     if model_type == "text_encoder":
         lora_prefix = "lora_te_"
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index d0b5d2fb..340a19fc 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -9,7 +9,7 @@ from onnx import ModelProto, load_model, numpy_helper, save_model
 from transformers import CLIPTokenizer
 
 from ...server.context import ServerContext
-from ..utils import ConversionContext
+from ..utils import ConversionContext, load_tensor
 
 logger = getLogger(__name__)
 
@@ -21,6 +21,8 @@ def blend_textual_inversions(
     tokenizer: CLIPTokenizer,
     inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
 ) -> Tuple[ModelProto, CLIPTokenizer]:
+    # always load to CPU for blending
+    device = torch.device("cpu")
     dtype = np.float
     embeds = {}
 
@@ -47,19 +49,19 @@ def blend_textual_inversions(
             with open(token_file, "r") as f:
                 token = base_token or f.read()
 
-            loaded_embeds = torch.load(embeds_file)
+            loaded_embeds = load_tensor(embeds_file, map_location=device)
 
             # separate token and the embeds
             trained_token = list(loaded_embeds.keys())[0]
-            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
+            layer = loaded_embeds[trained_token].numpy().astype(dtype)
             layer *= weight
 
             if trained_token in embeds:
                 embeds[token] += layer
             else:
                 embeds[token] = layer
         elif inversion_format == "embeddings":
-            loaded_embeds = torch.load(name)
+            loaded_embeds = load_tensor(name, map_location=device)
 
             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]
@@ -75,7 +77,7 @@ def blend_textual_inversions(
 
             for i in range(num_tokens):
                 token = f"{base_token}-{i}"
-                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
+                layer = trained_embeds[i, :].numpy().astype(dtype)
                 layer *= weight
                 sum_layer += layer
 
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 133765b9..b98640db 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -199,9 +199,11 @@ def remove_prefix(name: str, prefix: str) -> str:
 
 
 def load_tensor(name: str, map_location=None):
-    logger.info("loading model from checkpoint")
+    logger.debug("loading tensor: %s", name)
     _, extension = path.splitext(name)
-    if extension.lower() == ".safetensors":
+    extension = extension[1:].lower()
+
+    if extension == "safetensors":
         environ["SAFETENSORS_FAST_GPU"] = "1"
         try:
             logger.debug("loading safetensors")
@@ -209,7 +211,7 @@ def load_tensor(name: str, map_location=None):
         except Exception as e:
             try:
                 logger.warning(
-                    "failed to load as safetensors file, falling back to torch: %s", e
+                    "failed to load as safetensors file, falling back to Torch JIT: %s", e
                 )
                 checkpoint = torch.jit.load(name)
             except Exception as e:
@@ -216,17 +218,18 @@ def load_tensor(name: str, map_location=None):
                 logger.warning(
                     "failed to load with Torch JIT, falling back to PyTorch: %s", e
                 )
                 checkpoint = torch.load(name, map_location=map_location)
-                checkpoint = (
-                    checkpoint["state_dict"]
-                    if "state_dict" in checkpoint
-                    else checkpoint
-                )
-    else:
+    elif extension in ["", "bin", "ckpt", "pt"]:
         logger.debug("loading ckpt")
         checkpoint = torch.load(name, map_location=map_location)
-        checkpoint = (
-            checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
-        )
+    elif extension in ["onnx", "pt"]:
+        logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
+        checkpoint = torch.load(name, map_location=map_location)
+    else:
+        logger.warning("unknown tensor extension: %s", extension)
+        checkpoint = torch.load(name, map_location=map_location)
+
+    if "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
 
     return checkpoint