1
0
Fork 0

dump unet to torch before making cnet

This commit is contained in:
Sean Sube 2023-12-19 22:50:18 -06:00
parent 4a87fb2a31
commit c8a9734acf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 10 deletions

View File

@ -987,7 +987,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
key[len("cond_stage_model.transformer.") :] key[len("cond_stage_model.transformer.") :]
] = checkpoint[key] ] = checkpoint[key]
text_model.load_state_dict(text_model_dict) text_model.load_state_dict(text_model_dict, strict=False)
return text_model return text_model
@ -1155,7 +1155,7 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict[new_key] = checkpoint[key] text_model_dict[new_key] = checkpoint[key]
text_model.load_state_dict(text_model_dict) text_model.load_state_dict(text_model_dict, strict=False)
return text_model return text_model
@ -1548,7 +1548,7 @@ def extract_checkpoint(
) )
db_config.has_ema = has_ema db_config.has_ema = has_ema
db_config.save() db_config.save()
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint, strict=False)
# Convert the VAE model. # Convert the VAE model.
logger.info("converting VAE") logger.info("converting VAE")
@ -1567,7 +1567,7 @@ def extract_checkpoint(
) )
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint, strict=False)
# Convert the text model. # Convert the text model.
logger.info("converting text encoder") logger.info("converting text encoder")

View File

@ -10,7 +10,7 @@
### ###
from logging import getLogger from logging import getLogger
from os import mkdir, path from os import path, makedirs
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
@ -113,6 +113,7 @@ def get_model_version(
@torch.no_grad() @torch.no_grad()
def convert_diffusion_diffusers_cnet( def convert_diffusion_diffusers_cnet(
conversion: ConversionContext, conversion: ConversionContext,
name: str,
source: str, source: str,
device: str, device: str,
output_path: Path, output_path: Path,
@ -124,13 +125,16 @@ def convert_diffusion_diffusers_cnet(
unet: Optional[Any] = None, unet: Optional[Any] = None,
v2: Optional[bool] = False, v2: Optional[bool] = False,
): ):
# CNet
if unet is not None: if unet is not None:
logger.debug("creating CNet from existing UNet config") logger.debug("creating CNet from existing UNet config")
pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config) cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
makedirs(cnet_tmp, exist_ok=True)
unet.save_pretrained(cnet_tmp)
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)
else: else:
logger.debug("loading CNet from pretrained UNet config") logger.debug("loading CNet from pretrained UNet config")
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet") pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet", low_cpu_mem_usage=False)
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype) pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
run_gc() run_gc()
@ -308,7 +312,7 @@ def collate_cnet(cnet_path):
# clean up existing tensor files # clean up existing tensor files
rmtree(cnet_dir) rmtree(cnet_dir)
mkdir(cnet_dir) makedirs(cnet_dir)
# collate external tensor files into one # collate external tensor files into one
save_model( save_model(
@ -548,6 +552,7 @@ def convert_diffusion_diffusers(
logger.debug("converting CNet from loaded UNet") logger.debug("converting CNet from loaded UNet")
cnet_path = convert_diffusion_diffusers_cnet( cnet_path = convert_diffusion_diffusers_cnet(
conversion, conversion,
name,
source, source,
device, device,
output_path, output_path,
@ -568,6 +573,7 @@ def convert_diffusion_diffusers(
logger.info("loading and converting CNet from %s", cnet_source) logger.info("loading and converting CNet from %s", cnet_source)
cnet_path = convert_diffusion_diffusers_cnet( cnet_path = convert_diffusion_diffusers_cnet(
conversion, conversion,
name,
cnet_source, cnet_source,
device, device,
output_path, output_path,
@ -594,7 +600,7 @@ def convert_diffusion_diffusers(
# clean up existing tensor files # clean up existing tensor files
rmtree(unet_dir) rmtree(unet_dir)
mkdir(unet_dir) makedirs(unet_dir)
# collate external tensor files into one # collate external tensor files into one
save_model( save_model(