dump unet to torch before making cnet
This commit is contained in:
parent
4a87fb2a31
commit
c8a9734acf
|
@ -987,7 +987,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
key[len("cond_stage_model.transformer.") :]
|
key[len("cond_stage_model.transformer.") :]
|
||||||
] = checkpoint[key]
|
] = checkpoint[key]
|
||||||
|
|
||||||
text_model.load_state_dict(text_model_dict)
|
text_model.load_state_dict(text_model_dict, strict=False)
|
||||||
|
|
||||||
return text_model
|
return text_model
|
||||||
|
|
||||||
|
@ -1155,7 +1155,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
||||||
|
|
||||||
text_model_dict[new_key] = checkpoint[key]
|
text_model_dict[new_key] = checkpoint[key]
|
||||||
|
|
||||||
text_model.load_state_dict(text_model_dict)
|
text_model.load_state_dict(text_model_dict, strict=False)
|
||||||
|
|
||||||
return text_model
|
return text_model
|
||||||
|
|
||||||
|
@ -1548,7 +1548,7 @@ def extract_checkpoint(
|
||||||
)
|
)
|
||||||
db_config.has_ema = has_ema
|
db_config.has_ema = has_ema
|
||||||
db_config.save()
|
db_config.save()
|
||||||
unet.load_state_dict(converted_unet_checkpoint)
|
unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
||||||
|
|
||||||
# Convert the VAE model.
|
# Convert the VAE model.
|
||||||
logger.info("converting VAE")
|
logger.info("converting VAE")
|
||||||
|
@ -1567,7 +1567,7 @@ def extract_checkpoint(
|
||||||
)
|
)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint, strict=False)
|
||||||
|
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
logger.info("converting text encoder")
|
logger.info("converting text encoder")
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
###
|
###
|
||||||
|
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import mkdir, path
|
from os import path, makedirs
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Any, Dict, Optional, Tuple, Union
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
@ -113,6 +113,7 @@ def get_model_version(
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_diffusers_cnet(
|
def convert_diffusion_diffusers_cnet(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
source: str,
|
source: str,
|
||||||
device: str,
|
device: str,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
|
@ -124,13 +125,16 @@ def convert_diffusion_diffusers_cnet(
|
||||||
unet: Optional[Any] = None,
|
unet: Optional[Any] = None,
|
||||||
v2: Optional[bool] = False,
|
v2: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
# CNet
|
|
||||||
if unet is not None:
|
if unet is not None:
|
||||||
logger.debug("creating CNet from existing UNet config")
|
logger.debug("creating CNet from existing UNet config")
|
||||||
pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
|
cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
|
||||||
|
makedirs(cnet_tmp, exist_ok=True)
|
||||||
|
|
||||||
|
unet.save_pretrained(cnet_tmp)
|
||||||
|
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)
|
||||||
else:
|
else:
|
||||||
logger.debug("loading CNet from pretrained UNet config")
|
logger.debug("loading CNet from pretrained UNet config")
|
||||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
|
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet", low_cpu_mem_usage=False)
|
||||||
|
|
||||||
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
|
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
|
||||||
run_gc()
|
run_gc()
|
||||||
|
@ -308,7 +312,7 @@ def collate_cnet(cnet_path):
|
||||||
|
|
||||||
# clean up existing tensor files
|
# clean up existing tensor files
|
||||||
rmtree(cnet_dir)
|
rmtree(cnet_dir)
|
||||||
mkdir(cnet_dir)
|
makedirs(cnet_dir)
|
||||||
|
|
||||||
# collate external tensor files into one
|
# collate external tensor files into one
|
||||||
save_model(
|
save_model(
|
||||||
|
@ -548,6 +552,7 @@ def convert_diffusion_diffusers(
|
||||||
logger.debug("converting CNet from loaded UNet")
|
logger.debug("converting CNet from loaded UNet")
|
||||||
cnet_path = convert_diffusion_diffusers_cnet(
|
cnet_path = convert_diffusion_diffusers_cnet(
|
||||||
conversion,
|
conversion,
|
||||||
|
name,
|
||||||
source,
|
source,
|
||||||
device,
|
device,
|
||||||
output_path,
|
output_path,
|
||||||
|
@ -568,6 +573,7 @@ def convert_diffusion_diffusers(
|
||||||
logger.info("loading and converting CNet from %s", cnet_source)
|
logger.info("loading and converting CNet from %s", cnet_source)
|
||||||
cnet_path = convert_diffusion_diffusers_cnet(
|
cnet_path = convert_diffusion_diffusers_cnet(
|
||||||
conversion,
|
conversion,
|
||||||
|
name,
|
||||||
cnet_source,
|
cnet_source,
|
||||||
device,
|
device,
|
||||||
output_path,
|
output_path,
|
||||||
|
@ -594,7 +600,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
# clean up existing tensor files
|
# clean up existing tensor files
|
||||||
rmtree(unet_dir)
|
rmtree(unet_dir)
|
||||||
mkdir(unet_dir)
|
makedirs(unet_dir)
|
||||||
|
|
||||||
# collate external tensor files into one
|
# collate external tensor files into one
|
||||||
save_model(
|
save_model(
|
||||||
|
|
Loading…
Reference in New Issue