fix(api): support loading VAE from CKPT files
This commit is contained in:
parent
4b6be765a6
commit
ca1b22d44d
|
@ -55,7 +55,7 @@ from transformers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .diffusion_stable import convert_diffusion_stable
|
from .diffusion_stable import convert_diffusion_stable
|
||||||
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
|
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -634,13 +634,13 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||||
return new_checkpoint, has_ema
|
return new_checkpoint, has_ema
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
|
def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
|
||||||
# extract state dict for VAE
|
# extract state dict for VAE
|
||||||
vae_state_dict = {}
|
vae_state_dict = {}
|
||||||
vae_key = "first_stage_model."
|
vae_key = "first_stage_model."
|
||||||
keys = list(checkpoint.keys())
|
keys = list(checkpoint.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if use_key:
|
if first_stage:
|
||||||
if key.startswith(vae_key):
|
if key.startswith(vae_key):
|
||||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||||
else:
|
else:
|
||||||
|
@ -1190,24 +1190,11 @@ def extract_checkpoint(
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
checkpoint = None
|
|
||||||
map_location = torch.device("cpu")
|
map_location = torch.device("cpu")
|
||||||
|
|
||||||
# Try to determine if v1 or v2 model if we have a ckpt
|
# Try to determine if v1 or v2 model if we have a ckpt
|
||||||
logger.info("loading model from checkpoint")
|
logger.info("loading model from checkpoint")
|
||||||
_, extension = os.path.splitext(checkpoint_file)
|
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
|
||||||
if extension.lower() == ".safetensors":
|
|
||||||
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
|
||||||
try:
|
|
||||||
logger.debug("loading safetensors")
|
|
||||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
|
|
||||||
checkpoint = torch.jit.load(checkpoint_file)
|
|
||||||
else:
|
|
||||||
logger.debug("loading ckpt")
|
|
||||||
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
|
||||||
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
|
||||||
|
|
||||||
rev_keys = ["db_global_step", "global_step"]
|
rev_keys = ["db_global_step", "global_step"]
|
||||||
epoch_keys = ["db_epoch", "epoch"]
|
epoch_keys = ["db_epoch", "epoch"]
|
||||||
|
@ -1315,8 +1302,8 @@ def extract_checkpoint(
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||||
else:
|
else:
|
||||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||||
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
@ -199,3 +200,24 @@ def remove_prefix(name, prefix):
|
||||||
return name[len(prefix) :]
|
return name[len(prefix) :]
|
||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def load_tensor(name: str, map_location=None):
|
||||||
|
logger.info("loading model from checkpoint")
|
||||||
|
_, extension = path.splitext(name)
|
||||||
|
if extension.lower() == ".safetensors":
|
||||||
|
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||||
|
try:
|
||||||
|
logger.debug("loading safetensors")
|
||||||
|
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"failed to load as safetensors file, falling back to torch", e
|
||||||
|
)
|
||||||
|
checkpoint = torch.jit.load(name)
|
||||||
|
else:
|
||||||
|
logger.debug("loading ckpt")
|
||||||
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
checkpoint = (
|
||||||
|
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue