1
0
Fork 0

fix(api): do not assume file extension for additional networks

This commit is contained in:
Sean Sube 2023-03-19 15:27:51 -05:00
parent 33008531e9
commit b797b3d616
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 40 additions and 22 deletions

View File

@@ -198,35 +198,53 @@ def remove_prefix(name: str, prefix: str) -> str:
return name
def load_tensor(name: str, map_location=None):
def load_torch(name: str, map_location=None) -> Optional[Dict]:
    """Load a tensor file with Torch, preferring the JIT loader.

    `torch.jit.load` is attempted first; if it raises for any reason, the
    failure is logged with its traceback and the file is re-read using the
    plain `torch.load` pickle loader instead.

    :param name: path to the tensor file
    :param map_location: forwarded to `torch.load` on the fallback path
    :return: the loaded checkpoint object
    """
    try:
        logger.debug("loading tensor with Torch JIT: %s", name)
        return torch.jit.load(name)
    except Exception:
        # JIT archives and pickled checkpoints share extensions, so a JIT
        # failure is expected for plain checkpoints; fall back quietly.
        logger.exception(
            "error loading with Torch JIT, falling back to Torch: %s", name
        )
        return torch.load(name, map_location=map_location)
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
if not path.exists(name):
logger.warning("tensor does not exist: %s", name)
return None
logger.debug("loading tensor: %s", name)
_, extension = path.splitext(name)
extension = extension[1:].lower()
if extension == "safetensors":
environ["SAFETENSORS_FAST_GPU"] = "1"
try:
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(name, device="cpu")
except Exception as e:
try:
logger.warning(
"failed to load as safetensors file, falling back to Torch JIT: %s", e
)
checkpoint = torch.jit.load(name)
except Exception as e:
logger.warning(
"failed to load with Torch JIT, falling back to PyTorch: %s", e
)
checkpoint = torch.load(name, map_location=map_location)
elif extension in ["", "bin", "ckpt", "pt"]:
logger.debug("loading ckpt")
checkpoint = torch.load(name, map_location=map_location)
elif extension in ["onnx", "pt"]:
logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
if extension == "":
# if no extension was intentional, do not search for others
if path.exists(name):
logger.debug("loading anonymous tensor")
checkpoint = torch.load(name, map_location=map_location)
else:
logger.warning("unknown tensor extension: %s", extension)
logger.debug("searching for tensors with known extensions")
for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
if checkpoint is not None:
break
elif extension == "safetensors":
environ["SAFETENSORS_FAST_GPU"] = "1"
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(name, device="cpu")
elif extension in ["bin", "ckpt", "pt"]:
logger.debug("loading pickle tensor")
checkpoint = torch.load(name, map_location=map_location)
elif extension in ["onnx", "pt"]:
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
checkpoint = torch.load(name, map_location=map_location)
else:
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
checkpoint = torch.load(name, map_location=map_location)
if "state_dict" in checkpoint:

View File

@@ -188,7 +188,7 @@ def load_pipeline(
inversion_names, inversion_weights = zip(*inversions)
inversion_models = [
path.join(server.model_path, "inversion", f"{name}.ckpt")
path.join(server.model_path, "inversion", name)
for name in inversion_names
]
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
@@ -226,7 +226,7 @@ def load_pipeline(
if loras is not None and len(loras) > 0:
lora_names, lora_weights = zip(*loras)
lora_models = [
path.join(server.model_path, "lora", f"{name}.safetensors")
path.join(server.model_path, "lora", name)
for name in lora_names
]
logger.info(