diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index c60cf63d..7ddfe5f2 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -221,7 +221,9 @@ def load_tensor(name: str, map_location=None): "failed to load with Torch JIT, falling back to PyTorch", e ) checkpoint = torch.load(name, map_location=map_location) - + checkpoint = ( + checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint + ) else: logger.debug("loading ckpt") checkpoint = torch.load(name, map_location=map_location)