diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index b98640db..aaff53d2 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -198,35 +198,53 @@ def remove_prefix(name: str, prefix: str) -> str: return name -def load_tensor(name: str, map_location=None): +def load_torch(name: str, map_location=None) -> Optional[Dict]: + try: + logger.debug( + "loading tensor with Torch JIT: %s", name + ) + checkpoint = torch.jit.load(name) + except Exception: + logger.exception( + "error loading with Torch JIT, falling back to Torch: %s", name + ) + checkpoint = torch.load(name, map_location=map_location) + + return checkpoint + + +def load_tensor(name: str, map_location=None) -> Optional[Dict]: + if not path.exists(name): + logger.warning("tensor does not exist: %s", name) + return None + logger.debug("loading tensor: %s", name) _, extension = path.splitext(name) extension = extension[1:].lower() - if extension == "safetensors": + if extension == "": + # if no extension was intentional, do not search for others + if path.exists(name): + logger.debug("loading anonymous tensor") + checkpoint = torch.load(name, map_location=map_location) + else: + logger.debug("searching for tensors with known extensions") + for try_extension in ["safetensors", "ckpt", "pt", "bin"]: + checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location) + if checkpoint is not None: + break + elif extension == "safetensors": environ["SAFETENSORS_FAST_GPU"] = "1" - try: - logger.debug("loading safetensors") - checkpoint = safetensors.torch.load_file(name, device="cpu") - except Exception as e: - try: - logger.warning( - "failed to load as safetensors file, falling back to Torch JIT: %s", e - ) - checkpoint = torch.jit.load(name) - except Exception as e: - logger.warning( - "failed to load with Torch JIT, falling back to PyTorch: %s", e - ) - checkpoint = torch.load(name, map_location=map_location) - elif extension in ["", "bin", "ckpt", "pt"]: - logger.debug("loading ckpt") + logger.debug("loading safetensors") + checkpoint = safetensors.torch.load_file(name, device="cpu") + elif extension in ["bin", "ckpt", "pt"]: + logger.debug("loading pickle tensor") checkpoint = torch.load(name, map_location=map_location) elif extension in ["onnx", "pt"]: - logger.warning("unknown tensor extension, may be ONNX model: %s", extension) + logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension) checkpoint = torch.load(name, map_location=map_location) else: - logger.warning("unknown tensor extension: %s", extension) + logger.warning("unknown tensor type, falling back to PyTorch: %s", extension) checkpoint = torch.load(name, map_location=map_location) if "state_dict" in checkpoint: diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index c96bb31c..474d300a 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -188,7 +188,7 @@ def load_pipeline( inversion_names, inversion_weights = zip(*inversions) inversion_models = [ - path.join(server.model_path, "inversion", f"{name}.ckpt") + path.join(server.model_path, "inversion", name) for name in inversion_names ] text_encoder = load_model(path.join(model, "text_encoder", "model.onnx")) @@ -226,7 +226,7 @@ def load_pipeline( if loras is not None and len(loras) > 0: lora_names, lora_weights = zip(*loras) lora_models = [ - path.join(server.model_path, "lora", f"{name}.safetensors") + path.join(server.model_path, "lora", name) for name in lora_names ] logger.info(