1
0
Fork 0

fix(api): better error handling around tensor loading

This commit is contained in:
Sean Sube 2023-03-19 15:38:43 -05:00
parent 07622690dc
commit 8acc15f52e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 38 additions and 10 deletions

View File

@@ -75,6 +75,10 @@ def blend_loras(
     blended: Dict[str, np.ndarray] = {}
     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        if lora_model is None:
+            logger.warning("unable to load tensor for LoRA")
+            continue
+
         for key in lora_model.keys():
             if ".lora_down" in key and lora_prefix in key:
                 base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")

View File

@@ -1410,6 +1410,9 @@ def extract_checkpoint(
     # Try to determine if v1 or v2 model if we have a ckpt
     logger.info("loading model from checkpoint")
     checkpoint = load_tensor(checkpoint_file, map_location=map_location)
+    if checkpoint is None:
+        logger.warning("unable to load tensor")
+        return
     rev_keys = ["db_global_step", "global_step"]
     epoch_keys = ["db_epoch", "epoch"]

View File

@@ -50,6 +50,9 @@ def blend_textual_inversions(
                 token = base_token or f.read()
             loaded_embeds = load_tensor(embeds_file, map_location=device)
+            if loaded_embeds is None:
+                logger.warning("unable to load tensor")
+                continue

             # separate token and the embeds
             trained_token = list(loaded_embeds.keys())[0]
@@ -62,6 +65,9 @@ def blend_textual_inversions(
                 embeds[token] = layer
         elif inversion_format == "embeddings":
             loaded_embeds = load_tensor(name, map_location=device)
+            if loaded_embeds is None:
+                logger.warning("unable to load tensor")
+                continue
             string_to_token = loaded_embeds["string_to_token"]
             string_to_param = loaded_embeds["string_to_param"]

View File

@@ -218,6 +218,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
     _, extension = path.splitext(name)
     extension = extension[1:].lower()
+    checkpoint = None
     if extension == "":
         # if no extension was intentional, do not search for others
         if path.exists(name):
@@ -225,25 +226,39 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
             checkpoint = torch.load(name, map_location=map_location)
         else:
             logger.debug("searching for tensors with known extensions")
-            for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
-                checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
-                if checkpoint is not None:
-                    break
+            for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
+                next_name = f"{name}.{next_extension}"
+                if path.exists(next_name):
+                    checkpoint = load_tensor(next_name, map_location=map_location)
+                    if checkpoint is not None:
+                        break
     elif extension == "safetensors":
-        environ["SAFETENSORS_FAST_GPU"] = "1"
         logger.debug("loading safetensors")
-        checkpoint = safetensors.torch.load_file(name, device="cpu")
+        try:
+            environ["SAFETENSORS_FAST_GPU"] = "1"
+            checkpoint = safetensors.torch.load_file(name, device="cpu")
+        except Exception as e:
+            logger.warning("error loading safetensor: %s", e)
     elif extension in ["bin", "ckpt", "pt"]:
         logger.debug("loading pickle tensor")
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading pickle tensor: %s", e)
     elif extension in ["onnx", "pt"]:
         logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading tensor: %s", e)
     else:
         logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
-        checkpoint = load_torch(name, map_location=map_location)
+        try:
+            checkpoint = load_torch(name, map_location=map_location)
+        except Exception as e:
+            logger.warning("error loading tensor: %s", e)
-    if "state_dict" in checkpoint:
+    if checkpoint is not None and "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
     return checkpoint