1
0
Fork 0

fix inpainting conversion via torch

This commit is contained in:
HoopyFreud 2023-06-10 12:01:02 -04:00
parent c399430274
commit 22ff0be758
2 changed files with 9 additions and 1 deletions

View File

@ -1377,6 +1377,7 @@ def extract_checkpoint(
conversion: ConversionContext,
new_model_name: str,
checkpoint_file: str,
is_inpainting: Optional[bool] = False,
scheduler_type="ddim",
extract_ema=False,
train_unfrozen=False,
@ -1537,6 +1538,9 @@ def extract_checkpoint(
original_config, image_size=image_size
)
unet_config["upcast_attention"] = upcast_attention
if is_inpainting:
unet_config["in_channels "] = 9
unet = UNet2DConditionModel(**unet_config)
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
@ -1685,6 +1689,7 @@ def convert_extract_checkpoint(
conversion: ConversionContext,
source: str,
dest: str,
is_inpainting: Optional[bool] = False,
config_file: Optional[str] = None,
vae_file: Optional[str] = None,
) -> Tuple[bool, str]:

View File

@ -312,9 +312,11 @@ def convert_diffusion_diffusers(
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)
is_inpainting = False
if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9
is_inpainting = True
if format == "safetensors":
pipe_args["from_safetensors"] = True
@ -334,6 +336,7 @@ def convert_diffusion_diffusers(
conversion,
source,
f"{name}-torch",
is_inpainting=is_inpainting
config_file=config,
vae_file=replace_vae,
)