diff --git a/api/onnx_web/convert/diffusion_original.py b/api/onnx_web/convert/diffusion_original.py
index 7c7dcf3b..9dd80a51 100644
--- a/api/onnx_web/convert/diffusion_original.py
+++ b/api/onnx_web/convert/diffusion_original.py
@@ -55,7 +55,7 @@ from transformers import (
 )
 
 from .diffusion_stable import convert_diffusion_stable
-from .utils import ConversionContext, ModelDict, load_yaml, sanitize_name
+from .utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
 
 logger = getLogger(__name__)
 
@@ -634,13 +634,13 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint, has_ema
 
 
-def convert_ldm_vae_checkpoint(checkpoint, config, use_key=True):
+def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
     # extract state dict for VAE
     vae_state_dict = {}
     vae_key = "first_stage_model."
     keys = list(checkpoint.keys())
     for key in keys:
-        if use_key:
+        if first_stage:
             if key.startswith(vae_key):
                 vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
         else:
@@ -1190,24 +1190,11 @@ def extract_checkpoint(
     original_config_file = None
 
     try:
-        checkpoint = None
         map_location = torch.device("cpu")
 
         # Try to determine if v1 or v2 model if we have a ckpt
         logger.info("loading model from checkpoint")
-        _, extension = os.path.splitext(checkpoint_file)
-        if extension.lower() == ".safetensors":
-            os.environ["SAFETENSORS_FAST_GPU"] = "1"
-            try:
-                logger.debug("loading safetensors")
-                checkpoint = safetensors.torch.load_file(checkpoint_file, device="cpu")
-            except Exception as e:
-                logger.warn("Failed to load as safetensors file, falling back to torch...", e)
-                checkpoint = torch.jit.load(checkpoint_file)
-        else:
-            logger.debug("loading ckpt")
-            checkpoint = torch.load(checkpoint_file, map_location=map_location)
-        checkpoint = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
+        checkpoint = load_tensor(checkpoint_file, map_location=map_location)
 
         rev_keys = ["db_global_step", "global_step"]
         epoch_keys = ["db_epoch", "epoch"]
@@ -1315,8 +1302,8 @@ def extract_checkpoint(
             converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
         else:
             vae_file = os.path.join(ctx.model_path, vae_file)
-            vae_checkpoint = safetensors.torch.load_file(vae_file, device="cpu")
-            converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, use_key=False)
+            vae_checkpoint = load_tensor(vae_file, map_location=map_location)
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 1062737a..a150b1b6 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
 import requests
+import safetensors
 import torch
 from tqdm.auto import tqdm
 from yaml import safe_load
@@ -199,3 +200,26 @@ def remove_prefix(name, prefix):
         return name[len(prefix) :]
 
     return name
+
+
+def load_tensor(name: str, map_location=None):
+    logger.info("loading model from checkpoint")
+    _, extension = path.splitext(name)
+    if extension.lower() == ".safetensors":
+        environ["SAFETENSORS_FAST_GPU"] = "1"
+        try:
+            logger.debug("loading safetensors")
+            checkpoint = safetensors.torch.load_file(name, device="cpu")
+        except Exception as e:
+            logger.warning(
+                "failed to load as safetensors file, falling back to torch: %s", e
+            )
+            checkpoint = torch.jit.load(name)
+    else:
+        logger.debug("loading ckpt")
+        checkpoint = torch.load(name, map_location=map_location)
+    # some checkpoints nest their weights under a "state_dict" key
+    checkpoint = (
+        checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
+    )
+    return checkpoint