fix(api): do not assume file extension for additional networks
This commit is contained in:
parent
33008531e9
commit
b797b3d616
|
@ -198,35 +198,53 @@ def remove_prefix(name: str, prefix: str) -> str:
|
|||
return name
|
||||
|
||||
|
||||
def load_tensor(name: str, map_location=None):
|
||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||
try:
|
||||
logger.debug(
|
||||
"loading tensor with Torch JIT: %s", name
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error loading with Torch JIT, falling back to Torch: %s", name
|
||||
)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||
if not path.exists(name):
|
||||
logger.warning("tensor does not exist: %s", name)
|
||||
return None
|
||||
|
||||
logger.debug("loading tensor: %s", name)
|
||||
_, extension = path.splitext(name)
|
||||
extension = extension[1:].lower()
|
||||
|
||||
if extension == "safetensors":
|
||||
if extension == "":
|
||||
# if no extension was intentional, do not search for others
|
||||
if path.exists(name):
|
||||
logger.debug("loading anonymous tensor")
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
else:
|
||||
logger.debug("searching for tensors with known extensions")
|
||||
for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
||||
checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
|
||||
if checkpoint is not None:
|
||||
break
|
||||
elif extension == "safetensors":
|
||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||
try:
|
||||
logger.debug("loading safetensors")
|
||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.warning(
|
||||
"failed to load as safetensors file, falling back to Torch JIT: %s", e
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to load with Torch JIT, falling back to PyTorch: %s", e
|
||||
)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
elif extension in ["", "bin", "ckpt", "pt"]:
|
||||
logger.debug("loading ckpt")
|
||||
logger.debug("loading safetensors")
|
||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||
elif extension in ["bin", "ckpt", "pt"]:
|
||||
logger.debug("loading pickle tensor")
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
elif extension in ["onnx", "pt"]:
|
||||
logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
|
||||
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
else:
|
||||
logger.warning("unknown tensor extension: %s", extension)
|
||||
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
|
|
|
@ -188,7 +188,7 @@ def load_pipeline(
|
|||
inversion_names, inversion_weights = zip(*inversions)
|
||||
|
||||
inversion_models = [
|
||||
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
||||
path.join(server.model_path, "inversion", name)
|
||||
for name in inversion_names
|
||||
]
|
||||
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||
|
@ -226,7 +226,7 @@ def load_pipeline(
|
|||
if loras is not None and len(loras) > 0:
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [
|
||||
path.join(server.model_path, "lora", f"{name}.safetensors")
|
||||
path.join(server.model_path, "lora", name)
|
||||
for name in lora_names
|
||||
]
|
||||
logger.info(
|
||||
|
|
Loading…
Reference in New Issue