1
0
Fork 0

fix(api): better error handling around tensor loading

This commit is contained in:
Sean Sube 2023-03-19 15:38:43 -05:00
parent 07622690dc
commit 8acc15f52e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 38 additions and 10 deletions

View File

@@ -75,6 +75,10 @@ def blend_loras(
blended: Dict[str, np.ndarray] = {}
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
if lora_model is None:
logger.warning("unable to load tensor for LoRA")
continue
for key in lora_model.keys():
if ".lora_down" in key and lora_prefix in key:
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")

View File

@@ -1410,6 +1410,9 @@ def extract_checkpoint(
# Try to determine if v1 or v2 model if we have a ckpt
logger.info("loading model from checkpoint")
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
if checkpoint is None:
logger.warning("unable to load tensor")
return
rev_keys = ["db_global_step", "global_step"]
epoch_keys = ["db_epoch", "epoch"]

View File

@@ -50,6 +50,9 @@ def blend_textual_inversions(
token = base_token or f.read()
loaded_embeds = load_tensor(embeds_file, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
# separate token and the embeds
trained_token = list(loaded_embeds.keys())[0]
@@ -62,6 +65,9 @@ def blend_textual_inversions(
embeds[token] = layer
elif inversion_format == "embeddings":
loaded_embeds = load_tensor(name, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]

View File

@@ -218,6 +218,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
_, extension = path.splitext(name)
extension = extension[1:].lower()
checkpoint = None
if extension == "":
# if no extension was intentional, do not search for others
if path.exists(name):
@@ -225,25 +226,39 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
checkpoint = torch.load(name, map_location=map_location)
else:
logger.debug("searching for tensors with known extensions")
for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
next_name = f"{name}.{next_extension}"
if path.exists(next_name):
checkpoint = load_tensor(next_name, map_location=map_location)
if checkpoint is not None:
break
elif extension == "safetensors":
environ["SAFETENSORS_FAST_GPU"] = "1"
logger.debug("loading safetensors")
try:
environ["SAFETENSORS_FAST_GPU"] = "1"
checkpoint = safetensors.torch.load_file(name, device="cpu")
except Exception as e:
logger.warning("error loading safetensor: %s", e)
elif extension in ["bin", "ckpt", "pt"]:
logger.debug("loading pickle tensor")
try:
checkpoint = load_torch(name, map_location=map_location)
except Exception as e:
logger.warning("error loading pickle tensor: %s", e)
elif extension in ["onnx", "pt"]:
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
try:
checkpoint = load_torch(name, map_location=map_location)
except Exception as e:
logger.warning("error loading tensor: %s", e)
else:
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
try:
checkpoint = load_torch(name, map_location=map_location)
except Exception as e:
logger.warning("error loading tensor: %s", e)
if "state_dict" in checkpoint:
if checkpoint is not None and "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
return checkpoint