fix(api): fallback to PyTorch if tensors fail to load with JIT
This commit is contained in:
parent
005650a9a2
commit
b3c8fce16b
|
@ -211,10 +211,17 @@ def load_tensor(name: str, map_location=None):
|
|||
logger.debug("loading safetensors")
|
||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to load as safetensors file, falling back to torch", e
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
try:
|
||||
logger.warning(
|
||||
"failed to load as safetensors file, falling back to torch", e
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to load with Torch JIT, falling back to PyTorch", e
|
||||
)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
||||
else:
|
||||
logger.debug("loading ckpt")
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
|
Loading…
Reference in New Issue