fix(api): better error handling around tensor loading
This commit is contained in:
parent
07622690dc
commit
8acc15f52e
|
@ -75,6 +75,10 @@ def blend_loras(
|
||||||
blended: Dict[str, np.ndarray] = {}
|
blended: Dict[str, np.ndarray] = {}
|
||||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||||
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||||
|
if lora_model is None:
|
||||||
|
logger.warning("unable to load tensor for LoRA")
|
||||||
|
continue
|
||||||
|
|
||||||
for key in lora_model.keys():
|
for key in lora_model.keys():
|
||||||
if ".lora_down" in key and lora_prefix in key:
|
if ".lora_down" in key and lora_prefix in key:
|
||||||
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
||||||
|
|
|
@ -1410,6 +1410,9 @@ def extract_checkpoint(
|
||||||
# Try to determine if v1 or v2 model if we have a ckpt
|
# Try to determine if v1 or v2 model if we have a ckpt
|
||||||
logger.info("loading model from checkpoint")
|
logger.info("loading model from checkpoint")
|
||||||
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
|
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
|
||||||
|
if checkpoint is None:
|
||||||
|
logger.warning("unable to load tensor")
|
||||||
|
return
|
||||||
|
|
||||||
rev_keys = ["db_global_step", "global_step"]
|
rev_keys = ["db_global_step", "global_step"]
|
||||||
epoch_keys = ["db_epoch", "epoch"]
|
epoch_keys = ["db_epoch", "epoch"]
|
||||||
|
|
|
@ -50,6 +50,9 @@ def blend_textual_inversions(
|
||||||
token = base_token or f.read()
|
token = base_token or f.read()
|
||||||
|
|
||||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||||
|
if loaded_embeds is None:
|
||||||
|
logger.warning("unable to load tensor")
|
||||||
|
continue
|
||||||
|
|
||||||
# separate token and the embeds
|
# separate token and the embeds
|
||||||
trained_token = list(loaded_embeds.keys())[0]
|
trained_token = list(loaded_embeds.keys())[0]
|
||||||
|
@ -62,6 +65,9 @@ def blend_textual_inversions(
|
||||||
embeds[token] = layer
|
embeds[token] = layer
|
||||||
elif inversion_format == "embeddings":
|
elif inversion_format == "embeddings":
|
||||||
loaded_embeds = load_tensor(name, map_location=device)
|
loaded_embeds = load_tensor(name, map_location=device)
|
||||||
|
if loaded_embeds is None:
|
||||||
|
logger.warning("unable to load tensor")
|
||||||
|
continue
|
||||||
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
string_to_param = loaded_embeds["string_to_param"]
|
||||||
|
|
|
@ -218,6 +218,7 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
_, extension = path.splitext(name)
|
_, extension = path.splitext(name)
|
||||||
extension = extension[1:].lower()
|
extension = extension[1:].lower()
|
||||||
|
|
||||||
|
checkpoint = None
|
||||||
if extension == "":
|
if extension == "":
|
||||||
# if no extension was intentional, do not search for others
|
# if no extension was intentional, do not search for others
|
||||||
if path.exists(name):
|
if path.exists(name):
|
||||||
|
@ -225,25 +226,39 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
else:
|
else:
|
||||||
logger.debug("searching for tensors with known extensions")
|
logger.debug("searching for tensors with known extensions")
|
||||||
for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
for next_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
||||||
checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
|
next_name = f"{name}.{next_extension}"
|
||||||
if checkpoint is not None:
|
if path.exists(next_name):
|
||||||
break
|
checkpoint = load_tensor(next_name, map_location=map_location)
|
||||||
|
if checkpoint is not None:
|
||||||
|
break
|
||||||
elif extension == "safetensors":
|
elif extension == "safetensors":
|
||||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
|
||||||
logger.debug("loading safetensors")
|
logger.debug("loading safetensors")
|
||||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
try:
|
||||||
|
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
|
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("error loading safetensor: %s", e)
|
||||||
elif extension in ["bin", "ckpt", "pt"]:
|
elif extension in ["bin", "ckpt", "pt"]:
|
||||||
logger.debug("loading pickle tensor")
|
logger.debug("loading pickle tensor")
|
||||||
checkpoint = load_torch(name, map_location=map_location)
|
try:
|
||||||
|
checkpoint = load_torch(name, map_location=map_location)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("error loading pickle tensor: %s", e)
|
||||||
elif extension in ["onnx", "pt"]:
|
elif extension in ["onnx", "pt"]:
|
||||||
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
||||||
checkpoint = load_torch(name, map_location=map_location)
|
try:
|
||||||
|
checkpoint = load_torch(name, map_location=map_location)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("error loading tensor: %s", e)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
|
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
|
||||||
checkpoint = load_torch(name, map_location=map_location)
|
try:
|
||||||
|
checkpoint = load_torch(name, map_location=map_location)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("error loading tensor: %s", e)
|
||||||
|
|
||||||
if "state_dict" in checkpoint:
|
if checkpoint is not None and "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
Loading…
Reference in New Issue