fix(api): fallback to PyTorch if tensors fail to load with JIT
This commit is contained in:
parent
005650a9a2
commit
b3c8fce16b
|
@ -211,10 +211,17 @@ def load_tensor(name: str, map_location=None):
|
||||||
logger.debug("loading safetensors")
|
logger.debug("loading safetensors")
|
||||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
try:
|
||||||
"failed to load as safetensors file, falling back to torch", e
|
logger.warning(
|
||||||
)
|
"failed to load as safetensors file, falling back to torch", e
|
||||||
checkpoint = torch.jit.load(name)
|
)
|
||||||
|
checkpoint = torch.jit.load(name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"failed to load with Torch JIT, falling back to PyTorch", e
|
||||||
|
)
|
||||||
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("loading ckpt")
|
logger.debug("loading ckpt")
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
|
Loading…
Reference in New Issue