diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py
index c9b4f103..ee2731a1 100644
--- a/api/onnx_web/convert/diffusion/checkpoint.py
+++ b/api/onnx_web/convert/diffusion/checkpoint.py
@@ -987,7 +987,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
                 key[len("cond_stage_model.transformer.") :]
             ] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    text_model.load_state_dict(text_model_dict, strict=False)
 
     return text_model
 
@@ -1155,7 +1155,7 @@ def convert_open_clip_checkpoint(checkpoint):
 
             text_model_dict[new_key] = checkpoint[key]
 
-    text_model.load_state_dict(text_model_dict)
+    text_model.load_state_dict(text_model_dict, strict=False)
 
     return text_model
 
@@ -1548,7 +1548,7 @@ def extract_checkpoint(
         )
         db_config.has_ema = has_ema
         db_config.save()
-        unet.load_state_dict(converted_unet_checkpoint)
+        unet.load_state_dict(converted_unet_checkpoint, strict=False)
 
         # Convert the VAE model.
         logger.info("converting VAE")
@@ -1567,7 +1567,7 @@
         )
         vae = AutoencoderKL(**vae_config)
 
-        vae.load_state_dict(converted_vae_checkpoint)
+        vae.load_state_dict(converted_vae_checkpoint, strict=False)
 
         # Convert the text model.
         logger.info("converting text encoder")
diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py
index 07b18d10..b6f43e2c 100644
--- a/api/onnx_web/convert/diffusion/diffusion.py
+++ b/api/onnx_web/convert/diffusion/diffusion.py
@@ -10,7 +10,7 @@
 ###
 
 from logging import getLogger
-from os import mkdir, path
+from os import path, makedirs
 from pathlib import Path
 from shutil import rmtree
 from typing import Any, Dict, Optional, Tuple, Union
@@ -113,6 +113,7 @@ def get_model_version(
 @torch.no_grad()
 def convert_diffusion_diffusers_cnet(
     conversion: ConversionContext,
+    name: str,
     source: str,
     device: str,
     output_path: Path,
@@ -124,13 +125,16 @@ def convert_diffusion_diffusers_cnet(
     unet: Optional[Any] = None,
     v2: Optional[bool] = False,
 ):
-    # CNet
     if unet is not None:
         logger.debug("creating CNet from existing UNet config")
-        pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
+        cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
+        makedirs(cnet_tmp, exist_ok=True)
+
+        unet.save_pretrained(cnet_tmp)
+        pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)
     else:
         logger.debug("loading CNet from pretrained UNet config")
-        pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
+        pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet", low_cpu_mem_usage=False)
 
     pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
     run_gc()
@@ -308,7 +312,7 @@ def collate_cnet(cnet_path):
 
     # clean up existing tensor files
     rmtree(cnet_dir)
-    mkdir(cnet_dir)
+    makedirs(cnet_dir)
 
     # collate external tensor files into one
     save_model(
@@ -548,6 +552,7 @@ def convert_diffusion_diffusers(
            logger.debug("converting CNet from loaded UNet")
            cnet_path = convert_diffusion_diffusers_cnet(
                conversion,
+                name,
                source,
                device,
                output_path,
@@ -568,6 +573,7 @@ def convert_diffusion_diffusers(
            logger.info("loading and converting CNet from %s", cnet_source)
            cnet_path = convert_diffusion_diffusers_cnet(
                conversion,
+                name,
                cnet_source,
                device,
                output_path,
@@ -594,7 +600,7 @@ def convert_diffusion_diffusers(
 
    # clean up existing tensor files
    rmtree(unet_dir)
-    mkdir(unet_dir)
+    makedirs(unet_dir)
 
    # collate external tensor files into one
    save_model(
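
For context on the main behavioral change above: passing `strict=False` to `load_state_dict` lets PyTorch load a checkpoint whose keys do not exactly match the module's state dict, reporting the mismatches instead of raising a `RuntimeError`. Below is a minimal sketch of that behavior, not part of the patch, using a throwaway `nn.Linear` purely for illustration:

```python
import torch
from torch import nn

# Minimal sketch (not part of the patch): nn.Module.load_state_dict with
# strict=False skips the exact-key check and reports mismatches instead of
# raising a RuntimeError.
model = nn.Linear(4, 2)

partial = {"weight": torch.zeros(2, 4)}  # "bias" is deliberately left out

result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []
```

The `os.mkdir` to `os.makedirs` swap is in the same spirit: `makedirs` creates any missing intermediate directories (and with `exist_ok=True` tolerates a directory that already exists), whereas `mkdir` raises `FileNotFoundError` if the parent path is missing.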