dump unet to torch before making cnet
This commit is contained in:
parent
4a87fb2a31
commit
c8a9734acf
|
@ -987,7 +987,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
|||
key[len("cond_stage_model.transformer.") :]
|
||||
] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
text_model.load_state_dict(text_model_dict, strict=False)
|
||||
|
||||
return text_model
|
||||
|
||||
|
@ -1155,7 +1155,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||
|
||||
text_model_dict[new_key] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
text_model.load_state_dict(text_model_dict, strict=False)
|
||||
|
||||
return text_model
|
||||
|
||||
|
@ -1548,7 +1548,7 @@ def extract_checkpoint(
|
|||
)
|
||||
db_config.has_ema = has_ema
|
||||
db_config.save()
|
||||
unet.load_state_dict(converted_unet_checkpoint)
|
||||
unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
||||
|
||||
# Convert the VAE model.
|
||||
logger.info("converting VAE")
|
||||
|
@ -1567,7 +1567,7 @@ def extract_checkpoint(
|
|||
)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
vae.load_state_dict(converted_vae_checkpoint, strict=False)
|
||||
|
||||
# Convert the text model.
|
||||
logger.info("converting text encoder")
|
||||
|
|
|
@ -10,7 +10,7 @@
|
|||
###
|
||||
|
||||
from logging import getLogger
|
||||
from os import mkdir, path
|
||||
from os import path, makedirs
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
@ -113,6 +113,7 @@ def get_model_version(
|
|||
@torch.no_grad()
|
||||
def convert_diffusion_diffusers_cnet(
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
device: str,
|
||||
output_path: Path,
|
||||
|
@ -124,13 +125,16 @@ def convert_diffusion_diffusers_cnet(
|
|||
unet: Optional[Any] = None,
|
||||
v2: Optional[bool] = False,
|
||||
):
|
||||
# CNet
|
||||
if unet is not None:
|
||||
logger.debug("creating CNet from existing UNet config")
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
|
||||
cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
|
||||
makedirs(cnet_tmp, exist_ok=True)
|
||||
|
||||
unet.save_pretrained(cnet_tmp)
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)
|
||||
else:
|
||||
logger.debug("loading CNet from pretrained UNet config")
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet", low_cpu_mem_usage=False)
|
||||
|
||||
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
|
||||
run_gc()
|
||||
|
@ -308,7 +312,7 @@ def collate_cnet(cnet_path):
|
|||
|
||||
# clean up existing tensor files
|
||||
rmtree(cnet_dir)
|
||||
mkdir(cnet_dir)
|
||||
makedirs(cnet_dir)
|
||||
|
||||
# collate external tensor files into one
|
||||
save_model(
|
||||
|
@ -548,6 +552,7 @@ def convert_diffusion_diffusers(
|
|||
logger.debug("converting CNet from loaded UNet")
|
||||
cnet_path = convert_diffusion_diffusers_cnet(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
device,
|
||||
output_path,
|
||||
|
@ -568,6 +573,7 @@ def convert_diffusion_diffusers(
|
|||
logger.info("loading and converting CNet from %s", cnet_source)
|
||||
cnet_path = convert_diffusion_diffusers_cnet(
|
||||
conversion,
|
||||
name,
|
||||
cnet_source,
|
||||
device,
|
||||
output_path,
|
||||
|
@ -594,7 +600,7 @@ def convert_diffusion_diffusers(
|
|||
|
||||
# clean up existing tensor files
|
||||
rmtree(unet_dir)
|
||||
mkdir(unet_dir)
|
||||
makedirs(unet_dir)
|
||||
|
||||
# collate external tensor files into one
|
||||
save_model(
|
||||
|
|
Loading…
Reference in New Issue