fix(api): do not assume file extension for additional networks
This commit is contained in:
parent
33008531e9
commit
b797b3d616
|
@ -198,35 +198,53 @@ def remove_prefix(name: str, prefix: str) -> str:
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def load_tensor(name: str, map_location=None):
|
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
"loading tensor with Torch JIT: %s", name
|
||||||
|
)
|
||||||
|
checkpoint = torch.jit.load(name)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"error loading with Torch JIT, falling back to Torch: %s", name
|
||||||
|
)
|
||||||
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
if not path.exists(name):
|
||||||
|
logger.warning("tensor does not exist: %s", name)
|
||||||
|
return None
|
||||||
|
|
||||||
logger.debug("loading tensor: %s", name)
|
logger.debug("loading tensor: %s", name)
|
||||||
_, extension = path.splitext(name)
|
_, extension = path.splitext(name)
|
||||||
extension = extension[1:].lower()
|
extension = extension[1:].lower()
|
||||||
|
|
||||||
if extension == "safetensors":
|
if extension == "":
|
||||||
|
# if no extension was intentional, do not search for others
|
||||||
|
if path.exists(name):
|
||||||
|
logger.debug("loading anonymous tensor")
|
||||||
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
else:
|
||||||
|
logger.debug("searching for tensors with known extensions")
|
||||||
|
for try_extension in ["safetensors", "ckpt", "pt", "bin"]:
|
||||||
|
checkpoint = load_tensor(f"{name}.{try_extension}", map_location=map_location)
|
||||||
|
if checkpoint is not None:
|
||||||
|
break
|
||||||
|
elif extension == "safetensors":
|
||||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
try:
|
logger.debug("loading safetensors")
|
||||||
logger.debug("loading safetensors")
|
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
elif extension in ["bin", "ckpt", "pt"]:
|
||||||
except Exception as e:
|
logger.debug("loading pickle tensor")
|
||||||
try:
|
|
||||||
logger.warning(
|
|
||||||
"failed to load as safetensors file, falling back to Torch JIT: %s", e
|
|
||||||
)
|
|
||||||
checkpoint = torch.jit.load(name)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
"failed to load with Torch JIT, falling back to PyTorch: %s", e
|
|
||||||
)
|
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
|
||||||
elif extension in ["", "bin", "ckpt", "pt"]:
|
|
||||||
logger.debug("loading ckpt")
|
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
elif extension in ["onnx", "pt"]:
|
elif extension in ["onnx", "pt"]:
|
||||||
logger.warning("unknown tensor extension, may be ONNX model: %s", extension)
|
logger.warning("tensor has ONNX extension, falling back to PyTorch: %s", extension)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown tensor extension: %s", extension)
|
logger.warning("unknown tensor type, falling back to PyTorch: %s", extension)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
|
||||||
if "state_dict" in checkpoint:
|
if "state_dict" in checkpoint:
|
||||||
|
|
|
@ -188,7 +188,7 @@ def load_pipeline(
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
|
|
||||||
inversion_models = [
|
inversion_models = [
|
||||||
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
path.join(server.model_path, "inversion", name)
|
||||||
for name in inversion_names
|
for name in inversion_names
|
||||||
]
|
]
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||||
|
@ -226,7 +226,7 @@ def load_pipeline(
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [
|
lora_models = [
|
||||||
path.join(server.model_path, "lora", f"{name}.safetensors")
|
path.join(server.model_path, "lora", name)
|
||||||
for name in lora_names
|
for name in lora_names
|
||||||
]
|
]
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
Loading…
Reference in New Issue