From b3c8fce16b73ee2d3da75c5a190fd33ff600e676 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 17 Feb 2023 07:39:48 -0600 Subject: [PATCH] fix(api): fallback to PyTorch if tensors fail to load with JIT --- api/onnx_web/convert/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 4661b846..c60cf63d 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -211,10 +211,17 @@ def load_tensor(name: str, map_location=None): logger.debug("loading safetensors") checkpoint = safetensors.torch.load_file(name, device="cpu") except Exception as e: - logger.warning( - "failed to load as safetensors file, falling back to torch", e - ) - checkpoint = torch.jit.load(name) + try: + logger.warning( + "failed to load as safetensors file, falling back to torch", e + ) + checkpoint = torch.jit.load(name) + except Exception as e: + logger.warning( + "failed to load with Torch JIT, falling back to PyTorch", e + ) + checkpoint = torch.load(name, map_location=map_location) + else: logger.debug("loading ckpt") checkpoint = torch.load(name, map_location=map_location)