
dump unet to torch before making cnet

This commit is contained in:
Sean Sube 2023-12-19 22:50:18 -06:00
parent 4a87fb2a31
commit c8a9734acf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 10 deletions
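The substance of the change: instead of building the ControlNet-capable UNet from the loaded UNet's config alone (diffusers' from_config only recreates the architecture, leaving the weights freshly initialized), the converted UNet is first dumped to a per-model cache directory with save_pretrained, and the CNet model is then loaded back from that dump so it starts from the converted weights. A minimal sketch of the round-trip, assuming unet is an already-loaded diffusers UNet, conversion.cache_path and name come from the conversion context, and UNet2DConditionModel_CNet is the project's ControlNet-capable UNet class, all as in the diff below:

from os import makedirs, path

# dump the converted UNet weights into a per-model cache directory
cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
makedirs(cnet_tmp, exist_ok=True)
unet.save_pretrained(cnet_tmp)

# reload the dump as the ControlNet-capable UNet, keeping the converted weights;
# low_cpu_mem_usage=False, as in the commit, avoids diffusers' meta-tensor fast-init path
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)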

View File

@@ -987,7 +987,7 @@ def convert_ldm_clip_checkpoint(checkpoint):
key[len("cond_stage_model.transformer.") :]
] = checkpoint[key]
-text_model.load_state_dict(text_model_dict)
+text_model.load_state_dict(text_model_dict, strict=False)
return text_model
@@ -1155,7 +1155,7 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict[new_key] = checkpoint[key]
-text_model.load_state_dict(text_model_dict)
+text_model.load_state_dict(text_model_dict, strict=False)
return text_model
@@ -1548,7 +1548,7 @@ def extract_checkpoint(
)
db_config.has_ema = has_ema
db_config.save()
-unet.load_state_dict(converted_unet_checkpoint)
+unet.load_state_dict(converted_unet_checkpoint, strict=False)
# Convert the VAE model.
logger.info("converting VAE")
@@ -1567,7 +1567,7 @@ def extract_checkpoint(
)
vae = AutoencoderKL(**vae_config)
-vae.load_state_dict(converted_vae_checkpoint)
+vae.load_state_dict(converted_vae_checkpoint, strict=False)
# Convert the text model.
logger.info("converting text encoder")

View File

@@ -10,7 +10,7 @@
###
from logging import getLogger
-from os import mkdir, path
+from os import path, makedirs
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Optional, Tuple, Union
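The import swap above pairs with the later hunks that replace mkdir(...) with makedirs(...): os.mkdir can only create the final path component and raises FileNotFoundError if a parent directory is missing, while os.makedirs creates any missing intermediate directories, and with exist_ok=True (used for the CNet cache below) it also tolerates a directory that already exists. A short illustration with a hypothetical path:

from os import makedirs

# mkdir("cache/model-cnet") would raise FileNotFoundError if "cache" does not exist
makedirs("cache/model-cnet", exist_ok=True)  # creates "cache" as well; no error if it already exists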
@@ -113,6 +113,7 @@ def get_model_version(
@torch.no_grad()
def convert_diffusion_diffusers_cnet(
conversion: ConversionContext,
+name: str,
source: str,
device: str,
output_path: Path,
@@ -124,13 +125,16 @@ def convert_diffusion_diffusers_cnet(
unet: Optional[Any] = None,
v2: Optional[bool] = False,
):
# CNet
if unet is not None:
logger.debug("creating CNet from existing UNet config")
-pipe_cnet = UNet2DConditionModel_CNet.from_config(unet.config)
+cnet_tmp = path.join(conversion.cache_path, f"{name}-cnet")
+makedirs(cnet_tmp, exist_ok=True)
+unet.save_pretrained(cnet_tmp)
+pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(cnet_tmp, low_cpu_mem_usage=False)
else:
logger.debug("loading CNet from pretrained UNet config")
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet", low_cpu_mem_usage=False)
pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
run_gc()
@@ -308,7 +312,7 @@ def collate_cnet(cnet_path):
# clean up existing tensor files
rmtree(cnet_dir)
-mkdir(cnet_dir)
+makedirs(cnet_dir)
# collate external tensor files into one
save_model(
@@ -548,6 +552,7 @@ def convert_diffusion_diffusers(
logger.debug("converting CNet from loaded UNet")
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
+name,
source,
device,
output_path,
@@ -568,6 +573,7 @@ def convert_diffusion_diffusers(
logger.info("loading and converting CNet from %s", cnet_source)
cnet_path = convert_diffusion_diffusers_cnet(
conversion,
+name,
cnet_source,
device,
output_path,
@@ -594,7 +600,7 @@ def convert_diffusion_diffusers(
# clean up existing tensor files
rmtree(unet_dir)
-mkdir(unet_dir)
+makedirs(unet_dir)
# collate external tensor files into one
save_model(