diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 6877c298..d75b1862 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -75,6 +75,10 @@ def blend_loras(
     blended: Dict[str, np.ndarray] = {}
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        if lora_model is None:
+            logger.warning("unable to load tensor for LoRA")
+            continue
+
         for key in lora_model.keys():
             if ".lora_down" in key and lora_prefix in key:
                 base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py
index c75a328e..005e1898 100644
--- a/api/onnx_web/convert/diffusion/original.py
+++ b/api/onnx_web/convert/diffusion/original.py
@@ -1410,6 +1410,9 @@ def extract_checkpoint(
     # Try to determine if v1 or v2 model if we have a ckpt
     logger.info("loading model from checkpoint")
     checkpoint = load_tensor(checkpoint_file, map_location=map_location)
+    if checkpoint is None:
+        logger.warning("unable to load tensor")
+        return
 
     rev_keys = ["db_global_step", "global_step"]
     epoch_keys = ["db_epoch", "epoch"]
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 340a19fc..ec02f603 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -50,6 +50,9 @@ def blend_textual_inversions(
                 token = base_token or f.read()
 
             loaded_embeds = load_tensor(embeds_file, map_location=device)
+            if loaded_embeds is None:
+                logger.warning("unable to load tensor")
+                continue
 
             # separate token and the embeds
             trained_token = list(loaded_embeds.keys())[0]
@@ -62,6 +65,9 @@
             embeds[token] = layer
         elif inversion_format == "embeddings":
             loaded_embeds = load_tensor(name, map_location=device)
+            if loaded_embeds is None:
+                logger.warning("unable to load tensor")
+                continue
 
             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index c09607c0..6e482d63 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -218,6 +218,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
     _, extension = path.splitext(name)
     extension = extension[1:].lower()
 
+    checkpoint = None
     if extension == "":
         # if no extension was intentional, do not search for others
         if path.exists(name):
@@ -225,25 +226,39 @@
             logger.debug("loading anonymous tensor")
             checkpoint = torch.load(name, map_location=map_location)
         else:
             logger.debug("searching for tensors with known extensions")
-            for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
-                checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
-                if checkpoint is not None:
-                    break
+            for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
+                next_name = f"{name}.{next_extension}"
+                if path.exists(next_name):
+                    checkpoint = load_tensor(next_name, map_location=map_location)
+                    if checkpoint is not None:
+                        break
     elif extension == "safetensors":
-        environ["SAFETENSORS_FAST_GPU"] = "1"
         logger.debug("loading safetensors")
-        checkpoint = safetensors.torch.load_file(name, device="cpu")
+        try:
+            environ["SAFETENSORS_FAST_GPU"] = "1"
+            checkpoint = safetensors.torch.load_file(name, device="cpu")
+        except Exception as e:
+            logger.warning("error loading safetensor: %s", e)
     elif extension in ["bin", "ckpt", "pt"]:
         logger.debug("loading pickle tensor")
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading pickle tensor: %s", e)
     elif extension in ["onnx", "pt"]:
         logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading tensor: %s", e)
     else:
         logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading tensor: %s", e)
 
-    if "state_dict" in checkpoint:
+    if checkpoint is not None and "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
     return checkpoint
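
The pattern this patch converges on is that load_tensor catches load errors, logs a warning, and returns None instead of raising, so every caller must branch on the result (continue or return in the hunks above). A minimal caller-side sketch of that contract; the checkpoint path is hypothetical and only illustrates the None check, it is not part of the diff:

    from onnx_web.convert.utils import load_tensor

    # hypothetical path, used only to show the None-on-failure contract
    checkpoint = load_tensor("/models/example.safetensors", map_location="cpu")

    if checkpoint is None:
        # the patched load_tensor logs the underlying error and returns None,
        # so callers skip or bail out rather than handling the exception themselves
        print("unable to load tensor, skipping")
    else:
        # on success the dict has already been unwrapped from "state_dict" if present
        print(f"loaded {len(checkpoint)} tensors")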