1
0
Fork 0

fix(api): fallback to PyTorch if tensors fail to load with JIT

This commit is contained in:
Sean Sube 2023-02-17 07:39:48 -06:00
parent 005650a9a2
commit b3c8fce16b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 11 additions and 4 deletions

View File

@ -211,10 +211,17 @@ def load_tensor(name: str, map_location=None):
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(name, device="cpu")
except Exception as e:
logger.warning(
"failed to load as safetensors file, falling back to torch", e
)
checkpoint = torch.jit.load(name)
try:
logger.warning(
"failed to load as safetensors file, falling back to torch", e
)
checkpoint = torch.jit.load(name)
except Exception as e:
logger.warning(
"failed to load with Torch JIT, falling back to PyTorch", e
)
checkpoint = torch.load(name, map_location=map_location)
else:
logger.debug("loading ckpt")
checkpoint = torch.load(name, map_location=map_location)