diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index aaff53d2..c09607c0 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -214,10 +214,6 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]: def load_tensor(name: str, map_location=None) -> Optional[Dict]: - if not path.exists(name): - logger.warning("tensor does not exist: %s", name) - return None - logger.debug("loading tensor: %s", name) _, extension = path.splitext(name) extension = extension[1:].lower() @@ -239,13 +235,13 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]: checkpoint = safetensors.torch.load_file(name, device="cpu") elif extension in ["bin", "ckpt", "pt"]: logger.debug("loading pickle tensor") - checkpoint = torch.load(name, map_location=map_location) + checkpoint = load_torch(name, map_location=map_location) elif extension in ["onnx", "pt"]: logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension) - checkpoint = torch.load(name, map_location=map_location) + checkpoint = load_torch(name, map_location=map_location) else: logger.warning("unknown tensor type, falling back to PyTorch: %s", extension) - checkpoint = torch.load(name, map_location=map_location) + checkpoint = load_torch(name, map_location=map_location) if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"]