1
0
Fork 0

fix inpainting conversion via torch

This commit is contained in:
HoopyFreud 2023-06-10 12:01:02 -04:00
parent c399430274
commit 22ff0be758
2 changed files with 9 additions and 1 deletions

View File

@ -1377,6 +1377,7 @@ def extract_checkpoint(
conversion: ConversionContext, conversion: ConversionContext,
new_model_name: str, new_model_name: str,
checkpoint_file: str, checkpoint_file: str,
is_inpainting: Optional[bool] = False,
scheduler_type="ddim", scheduler_type="ddim",
extract_ema=False, extract_ema=False,
train_unfrozen=False, train_unfrozen=False,
@ -1537,6 +1538,9 @@ def extract_checkpoint(
original_config, image_size=image_size original_config, image_size=image_size
) )
unet_config["upcast_attention"] = upcast_attention unet_config["upcast_attention"] = upcast_attention
if is_inpainting:
unet_config["in_channels "] = 9
unet = UNet2DConditionModel(**unet_config) unet = UNet2DConditionModel(**unet_config)
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint( converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
@ -1685,6 +1689,7 @@ def convert_extract_checkpoint(
conversion: ConversionContext, conversion: ConversionContext,
source: str, source: str,
dest: str, dest: str,
is_inpainting: Optional[bool] = False,
config_file: Optional[str] = None, config_file: Optional[str] = None,
vae_file: Optional[str] = None, vae_file: Optional[str] = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:

View File

@ -313,8 +313,10 @@ def convert_diffusion_diffusers(
source, conversion.map_location, size=image_size, version=version source, conversion.map_location, size=image_size, version=version
) )
is_inpainting = False
if pipe_type == "inpaint": if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9 pipe_args["num_in_channels"] = 9
is_inpainting = True
if format == "safetensors": if format == "safetensors":
pipe_args["from_safetensors"] = True pipe_args["from_safetensors"] = True
@ -334,6 +336,7 @@ def convert_diffusion_diffusers(
conversion, conversion,
source, source,
f"{name}-torch", f"{name}-torch",
is_inpainting=is_inpainting
config_file=config, config_file=config,
vae_file=replace_vae, vae_file=replace_vae,
) )