fix inpainting conversion via torch
This commit is contained in:
parent
c399430274
commit
22ff0be758
|
@ -1377,6 +1377,7 @@ def extract_checkpoint(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
new_model_name: str,
|
new_model_name: str,
|
||||||
checkpoint_file: str,
|
checkpoint_file: str,
|
||||||
|
is_inpainting: Optional[bool] = False,
|
||||||
scheduler_type="ddim",
|
scheduler_type="ddim",
|
||||||
extract_ema=False,
|
extract_ema=False,
|
||||||
train_unfrozen=False,
|
train_unfrozen=False,
|
||||||
|
@ -1537,6 +1538,9 @@ def extract_checkpoint(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
)
|
)
|
||||||
unet_config["upcast_attention"] = upcast_attention
|
unet_config["upcast_attention"] = upcast_attention
|
||||||
|
if is_inpainting:
|
||||||
|
unet_config["in_channels "] = 9
|
||||||
|
|
||||||
unet = UNet2DConditionModel(**unet_config)
|
unet = UNet2DConditionModel(**unet_config)
|
||||||
|
|
||||||
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
|
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
|
||||||
|
@ -1685,6 +1689,7 @@ def convert_extract_checkpoint(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
source: str,
|
source: str,
|
||||||
dest: str,
|
dest: str,
|
||||||
|
is_inpainting: Optional[bool] = False,
|
||||||
config_file: Optional[str] = None,
|
config_file: Optional[str] = None,
|
||||||
vae_file: Optional[str] = None,
|
vae_file: Optional[str] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
|
|
|
@ -313,8 +313,10 @@ def convert_diffusion_diffusers(
|
||||||
source, conversion.map_location, size=image_size, version=version
|
source, conversion.map_location, size=image_size, version=version
|
||||||
)
|
)
|
||||||
|
|
||||||
|
is_inpainting = False
|
||||||
if pipe_type == "inpaint":
|
if pipe_type == "inpaint":
|
||||||
pipe_args["num_in_channels"] = 9
|
pipe_args["num_in_channels"] = 9
|
||||||
|
is_inpainting = True
|
||||||
|
|
||||||
if format == "safetensors":
|
if format == "safetensors":
|
||||||
pipe_args["from_safetensors"] = True
|
pipe_args["from_safetensors"] = True
|
||||||
|
@ -334,6 +336,7 @@ def convert_diffusion_diffusers(
|
||||||
conversion,
|
conversion,
|
||||||
source,
|
source,
|
||||||
f"{name}-torch",
|
f"{name}-torch",
|
||||||
|
is_inpainting=is_inpainting
|
||||||
config_file=config,
|
config_file=config,
|
||||||
vae_file=replace_vae,
|
vae_file=replace_vae,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue