fix inpainting conversion via torch
This commit is contained in:
parent
c399430274
commit
22ff0be758
|
@ -1377,6 +1377,7 @@ def extract_checkpoint(
|
|||
conversion: ConversionContext,
|
||||
new_model_name: str,
|
||||
checkpoint_file: str,
|
||||
is_inpainting: Optional[bool] = False,
|
||||
scheduler_type="ddim",
|
||||
extract_ema=False,
|
||||
train_unfrozen=False,
|
||||
|
@ -1537,6 +1538,9 @@ def extract_checkpoint(
|
|||
original_config, image_size=image_size
|
||||
)
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
if is_inpainting:
|
||||
unet_config["in_channels "] = 9
|
||||
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint(
|
||||
|
@ -1685,6 +1689,7 @@ def convert_extract_checkpoint(
|
|||
conversion: ConversionContext,
|
||||
source: str,
|
||||
dest: str,
|
||||
is_inpainting: Optional[bool] = False,
|
||||
config_file: Optional[str] = None,
|
||||
vae_file: Optional[str] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
|
|
|
@ -312,9 +312,11 @@ def convert_diffusion_diffusers(
|
|||
v2, pipe_args = get_model_version(
|
||||
source, conversion.map_location, size=image_size, version=version
|
||||
)
|
||||
|
||||
|
||||
is_inpainting = False
|
||||
if pipe_type == "inpaint":
|
||||
pipe_args["num_in_channels"] = 9
|
||||
is_inpainting = True
|
||||
|
||||
if format == "safetensors":
|
||||
pipe_args["from_safetensors"] = True
|
||||
|
@ -334,6 +336,7 @@ def convert_diffusion_diffusers(
|
|||
conversion,
|
||||
source,
|
||||
f"{name}-torch",
|
||||
is_inpainting=is_inpainting
|
||||
config_file=config,
|
||||
vae_file=replace_vae,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue