fix(api): better error handling around tensor loading
This commit is contained in:
parent
07622690dc
commit
8acc15f52e
|
@ -75,6 +75,10 @@ def blend_loras(
|
|||
blended: Dict[str, np.ndarray] = {}
|
||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||
if lora_model is None:
|
||||
logger.warning("unable to load tensor for LoRA")
|
||||
continue
|
||||
|
||||
for key in lora_model.keys():
|
||||
if ".lora_down" in key and lora_prefix in key:
|
||||
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
||||
|
|
|
@ -1410,6 +1410,9 @@ def extract_checkpoint(
|
|||
# Try to determine if v1 or v2 model if we have a ckpt
|
||||
logger.info("loading model from checkpoint")
|
||||
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
|
||||
if checkpoint is None:
|
||||
logger.warning("unable to load tensor")
|
||||
return
|
||||
|
||||
rev_keys = ["db_global_step", "global_step"]
|
||||
epoch_keys = ["db_epoch", "epoch"]
|
||||
|
|
|
@ -50,6 +50,9 @@ def blend_textual_inversions(
|
|||
token = base_token or f.read()
|
||||
|
||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
# separate token and the embeds
|
||||
trained_token = list(loaded_embeds.keys())[0]
|
||||
|
@ -62,6 +65,9 @@ def blend_textual_inversions(
|
|||
embeds[token] = layer
|
||||
elif inversion_format == "embeddings":
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
string_to_token = loaded_embeds["string_to_token"]
|
||||
string_to_param = loaded_embeds["string_to_param"]
|
||||
|
|
|
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
    """Load a tensor dict from a file, choosing the loader by extension.

    Dispatch rules:
    - no extension: load ``name`` directly if it exists; otherwise probe
      ``name.safetensors`` / ``.ckpt`` / ``.pt`` / ``.bin`` in turn and use
      the first that loads successfully
    - ``safetensors``: load with the safetensors library
    - ``bin`` / ``ckpt`` / ``pt``: load as a pickled PyTorch checkpoint
    - ``onnx``: warn and fall back to PyTorch loading
    - anything else: warn and fall back to PyTorch loading

    Parameters:
        name: path to the tensor file, with or without extension.
        map_location: forwarded to the PyTorch loaders.

    Returns:
        The loaded dict, unwrapped from a top-level ``state_dict`` key if
        one is present, or ``None`` when nothing could be loaded. Loader
        errors are logged and swallowed rather than raised, so callers must
        check for ``None``.
    """
    _, extension = path.splitext(name)
    extension = extension[1:].lower()

    checkpoint = None
    if extension == "":
        # if no extension was intentional, do not search for others
        if path.exists(name):
            # NOTE(review): one context line of the original is not visible
            # here (likely a debug log) — confirm against the full file
            checkpoint = torch.load(name, map_location=map_location)
        else:
            logger.debug("searching for tensors with known extensions")
            for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
                next_name = f"{name}.{next_extension}"
                # only recurse into files that actually exist, and stop at
                # the first one that loads successfully
                if path.exists(next_name):
                    checkpoint = load_tensor(next_name, map_location=map_location)
                    if checkpoint is not None:
                        break
    elif extension == "safetensors":
        logger.debug("loading safetensors")
        try:
            environ["SAFETENSORS_FAST_GPU"] = "1"
            checkpoint = safetensors.torch.load_file(name, device="cpu")
        except Exception as e:
            logger.warning("error loading safetensor: %s", e)
    elif extension in ["bin", "ckpt", "pt"]:
        logger.debug("loading pickle tensor")
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading pickle tensor: %s", e)
    elif extension == "onnx":
        # fix: the previous test `extension in ["onnx", "pt"]` could never
        # match "pt" — the branch above already handles it; test "onnx" alone
        logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading tensor: %s", e)
    else:
        logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
        try:
            checkpoint = load_torch(name, map_location=map_location)
        except Exception as e:
            logger.warning("error loading tensor: %s", e)

    # some checkpoints nest their weights under a "state_dict" key; unwrap
    # it, guarding against the None returned by a failed load
    if checkpoint is not None and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint
|
|
Loading…
Reference in New Issue