
fix(api): support loading VAE from CKPT files

Sean Sube 2023-02-16 20:18:42 -06:00
parent 4b6be765a6
commit ca1b22d44d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 28 additions and 19 deletions
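In outline, the two diffs below route a custom VAE file through the same loader as the main checkpoint, so the VAE can be supplied as either a .safetensors or a .ckpt file. A minimal sketch of the resulting flow, with module paths assumed from the package layout (the extracted diff does not show the file names) and load_custom_vae as a hypothetical wrapper, not part of the commit:

import torch
from diffusers import AutoencoderKL

from onnx_web.convert.utils import load_tensor  # new helper, added in the second file below
from onnx_web.convert.diffusion_original import convert_ldm_vae_checkpoint  # changed in the first file below

def load_custom_vae(vae_file: str, vae_config: dict) -> AutoencoderKL:
    # load_tensor picks safetensors or torch.load based on the file extension,
    # so vae_file may now point at either a .safetensors or a .ckpt file
    vae_checkpoint = load_tensor(vae_file, map_location=torch.device("cpu"))
    # a standalone VAE file has no "first_stage_model." prefix on its keys
    converted = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted)
    return vae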

View File

@@ -55,7 +55,7 @@ from transformers import (
)
from .diffusion_stable import convert_diffusion_stable
from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
logger = getLogger(__name__)
@@ -634,13 +634,13 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
return new_checkpoint, has_ema
def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if use_key:
if first_stage:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
else:
@@ -1190,24 +1190,11 @@ def extract_checkpoint(
original_config_file = None
try:
checkpoint = None
map_location = torch.device("cpu")
# Try to determine if v1 or v2 model if we have a ckpt
logger.info("loading model from checkpoint")
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
os.environ["SAFETENSORS_FAST_GPU"] = "1"
try:
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
logger.warn("Failed to load as safetensors file, falling back to torch...", e)
checkpoint = torch.jit.load(checkpoint_file)
else:
logger.debug("loading ckpt")
checkpoint = torch.load(checkpoint_file, map_location=map_location)
checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
checkpoint = load_tensor(checkpoint_file, map_location=map_location)
rev_keys = ["db_global_step", "global_step"]
epoch_keys = ["db_epoch", "epoch"]
@@ -1315,8 +1302,8 @@ def extract_checkpoint(
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
else:
vae_file = os.path.join(ctx.model_path, vae_file)
vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
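The first_stage flag covers the two places VAE weights can come from: in a full Stable Diffusion checkpoint they sit under the "first_stage_model." prefix, which has to be stripped, while a standalone VAE file already uses bare keys. A rough sketch of that key handling with toy state dicts; the else branch is cut off in the hunk above, so the copy-through shown here is an assumption:

# illustrative key selection only, mirroring the loop in convert_ldm_vae_checkpoint
def select_vae_keys(checkpoint, first_stage=True):
    vae_key = "first_stage_model."
    vae_state_dict = {}
    for key in list(checkpoint.keys()):
        if first_stage:
            # full checkpoint: keep only the prefixed keys and strip the prefix
            if key.startswith(vae_key):
                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
        else:
            # standalone VAE file: assumed to copy every key through unchanged
            vae_state_dict[key] = checkpoint.get(key)
    return vae_state_dict

full_checkpoint = {
    "first_stage_model.encoder.conv_in.weight": 1,        # VAE weight inside a full checkpoint
    "model.diffusion_model.input_blocks.0.0.weight": 2,   # UNet weight, ignored here
}
standalone_vae = {"encoder.conv_in.weight": 1}            # same weight in a standalone VAE file

assert select_vae_keys(full_checkpoint) == select_vae_keys(standalone_vae, first_stage=False)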

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import requests
import safetensors
import torch
from tqdm.auto import tqdm
from yaml import safe_load
@@ -199,3 +200,24 @@ def remove_prefix(name, prefix):
return name[len(prefix) :]
return name
def load_tensor(name: str, map_location=None):
logger.info("loading model from checkpoint")
_, extension = path.splitext(name)
if extension.lower() == ".safetensors":
environ["SAFETENSORS_FAST_GPU"] = "1"
try:
logger.debug("loading safetensors")
checkpoint = safetensors.torch.load_file(name, device="cpu")
except Exception as e:
logger.warning(
"failed to load as safetensors file, falling back to torch", e
)
checkpoint = torch.jit.load(name)
else:
logger.debug("loading ckpt")
checkpoint = torch.load(name, map_location=map_location)
checkpoint = (
checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
)
return checkpoint
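From the caller's side the format difference disappears; a short usage example, assuming the helper lives in onnx_web.convert.utils and returns the loaded state dict, with made-up file paths:

import torch
from onnx_web.convert.utils import load_tensor

cpu = torch.device("cpu")

# classic pickle-based checkpoint: torch.load is used and any
# {"state_dict": ...} wrapper is unwrapped
vae_ckpt = load_tensor("models/custom-vae.ckpt", map_location=cpu)

# safetensors file: safetensors.torch.load_file is used on the CPU and
# map_location is not consulted
vae_safe = load_tensor("models/custom-vae.safetensors", map_location=cpu)

# either way the caller gets a flat state dict it can hand to
# convert_ldm_vae_checkpoint
print(len(vae_ckpt), len(vae_safe))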