From 37b173d0d1643393c542ec8663182d257dbfe04e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 17 Feb 2023 08:23:12 -0600 Subject: [PATCH] fix(api): unwrap state dict from VAE --- api/onnx_web/convert/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index c60cf63d..7ddfe5f2 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -221,7 +221,9 @@ def load_tensor(name: str, map_location=None): "failed to load with Torch JIT, falling back to PyTorch", e ) checkpoint = torch.load(name, map_location=map_location) - + checkpoint = ( + checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint + ) else: logger.debug("loading ckpt") checkpoint = torch.load(name, map_location=map_location)