1
0
Fork 0

fix(api): support loading VAE from CKPT files

This commit is contained in:
Sean Sube 2023-02-16 20:18:42 -06:00
parent 4b6be765a6
commit ca1b22d44d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 28 additions and 19 deletions

View File

@ -55,7 +55,7 @@ from transformers import (
) )
from .diffusion_stable import convert_diffusion_stable from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
logger = getLogger(__name__) logger = getLogger(__name__)
@ -634,13 +634,13 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
return new_checkpoint, has_ema return new_checkpoint, has_ema
def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True): def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
# extract state dict for VAE # extract state dict for VAE
vae_state_dict = {} vae_state_dict = {}
vae_key = "first_stage_model." vae_key = "first_stage_model."
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
for key in keys: for key in keys:
if use_key: if first_stage:
if key.startswith(vae_key): if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
else: else:
@ -1190,24 +1190,11 @@ def extract_checkpoint(
original_config_file = None original_config_file = None
try: try:
checkpoint = None
map_location = torch.device("cpu") map_location = torch.device("cpu")
# Try to determine if v1 or v2 model if we have a ckpt # Try to determine if v1 or v2 model if we have a ckpt
logger.info("loading model from checkpoint") logger.info("loading model from checkpoint")
_, extension = os.path.splitext(checkpoint_file) checkpoint = load_tensor(checkpoint_file, map_location=map_location)
if extension.lower() == ".safetensors":
os.environ["SAFETENSORS_FAST_GPU"] = "1"
try:
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
checkpoint = torch.jit.load(checkpoint_file)
else:
logger.debug("loading ckpt")
checkpoint = torch.load(checkpoint_file, map_location=map_location)
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
rev_keys = ["db_global_step", "global_step"] rev_keys = ["db_global_step", "global_step"]
epoch_keys = ["db_epoch", "epoch"] epoch_keys = ["db_epoch", "epoch"]
@ -1315,8 +1302,8 @@ def extract_checkpoint(
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
else: else:
vae_file = os.path.join(ctx.model_path, vae_file) vae_file = os.path.join(ctx.model_path, vae_file)
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu") vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False) converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)

View File

@ -6,6 +6,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import requests import requests
import safetensors
import torch import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from yaml import safe_load from yaml import safe_load
@ -199,3 +200,24 @@ def remove_prefix(name, prefix):
return name[len(prefix) :] return name[len(prefix) :]
return name return name
def load_tensor(name: str, map_location=None):
    """Load a model checkpoint from a ``.safetensors`` or torch pickle file.

    The loader is selected by the file extension of ``name``.

    Args:
        name: Path of the checkpoint file to load.
        map_location: Forwarded to ``torch.load`` for files that are not
            ``.safetensors``; unused by the safetensors loader, which always
            loads onto the CPU.

    Returns:
        The loaded checkpoint. For files read with ``torch.load``, the nested
        ``"state_dict"`` entry is returned instead when one is present.
    """
    # Import the submodule explicitly: a bare `import safetensors` does not
    # guarantee that `safetensors.torch` is available as an attribute.
    import safetensors.torch

    logger.info("loading model from checkpoint")
    _, extension = path.splitext(name)
    if extension.lower() == ".safetensors":
        environ["SAFETENSORS_FAST_GPU"] = "1"
        try:
            logger.debug("loading safetensors")
            checkpoint = safetensors.torch.load_file(name, device="cpu")
        except Exception as e:
            # Best-effort fallback kept from the original; the exception is
            # now interpolated with %s so the logging call itself is valid.
            logger.warning(
                "failed to load as safetensors file, falling back to torch: %s", e
            )
            checkpoint = torch.jit.load(name)
    else:
        logger.debug("loading ckpt")
        checkpoint = torch.load(name, map_location=map_location)
        # NOTE(review): unwrapping is applied only to torch.load results,
        # matching the pre-refactor code path — confirm this is intended.
        checkpoint = (
            checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
        )

    # Callers assign the result of this function, so it must be returned.
    return checkpoint