From 873276f1d0e1cc1e2df2750e298028023557f0bd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 16 Feb 2023 20:23:10 -0600 Subject: [PATCH] load checkpoint properly --- api/onnx_web/convert/diffusion_original.py | 1 + api/onnx_web/convert/utils.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py index 9dd80a51..64ab7be7 100644 --- a/api/onnx_web/convert/diffusion_original.py +++ b/api/onnx_web/convert/diffusion_original.py @@ -1302,6 +1302,7 @@ def extract_checkpoint( converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) else: vae_file = os.path.join(ctx.model_path, vae_file) + logger.debug("loading custom VAE file: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index a150b1b6..8493ca1a 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -221,3 +221,5 @@ def load_tensor(name: str, map_location=None): checkpoint = ( checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint ) + + return checkpoint \ No newline at end of file