fix(api): support loading VAE from CKPT files
This commit is contained in:
parent
4b6be765a6
commit
ca1b22d44d
|
@ -55,7 +55,7 @@ from transformers import (
|
|||
)
|
||||
|
||||
from .diffusion_stable import convert_diffusion_stable
|
||||
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
|
||||
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -634,13 +634,13 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||
return new_checkpoint, has_ema
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
for key in keys:
|
||||
if use_key:
|
||||
if first_stage:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
else:
|
||||
|
@ -1190,24 +1190,11 @@ def extract_checkpoint(
|
|||
original_config_file = None
|
||||
|
||||
try:
|
||||
checkpoint = None
|
||||
map_location = torch.device("cpu")
|
||||
|
||||
# Try to determine if v1 or v2 model if we have a ckpt
|
||||
logger.info("loading model from checkpoint")
|
||||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if extension.lower() == ".safetensors":
|
||||
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||
try:
|
||||
logger.debug("loading safetensors")
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
except Exception as e:
|
||||
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
|
||||
checkpoint = torch.jit.load(checkpoint_file)
|
||||
else:
|
||||
logger.debug("loading ckpt")
|
||||
checkpoint = torch.load(checkpoint_file, map_location=map_location)
|
||||
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
|
||||
|
||||
rev_keys = ["db_global_step", "global_step"]
|
||||
epoch_keys = ["db_epoch", "epoch"]
|
||||
|
@ -1315,8 +1302,8 @@ def extract_checkpoint(
|
|||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
else:
|
||||
vae_file = os.path.join(ctx.model_path, vae_file)
|
||||
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
|
||||
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
import safetensors
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
from yaml import safe_load
|
||||
|
@ -199,3 +200,24 @@ def remove_prefix(name, prefix):
|
|||
return name[len(prefix) :]
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def load_tensor(name: str, map_location=None):
|
||||
logger.info("loading model from checkpoint")
|
||||
_, extension = path.splitext(name)
|
||||
if extension.lower() == ".safetensors":
|
||||
environ["SAFETENSORS_FAST_GPU"] = "1"
|
||||
try:
|
||||
logger.debug("loading safetensors")
|
||||
checkpoint = safetensors.torch.load_file(name, device="cpu")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"failed to load as safetensors file, falling back to torch", e
|
||||
)
|
||||
checkpoint = torch.jit.load(name)
|
||||
else:
|
||||
logger.debug("loading ckpt")
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
checkpoint = (
|
||||
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue