From 22ff0be7589b2f35e285ec1cdaeb1567f7d7cd0a Mon Sep 17 00:00:00 2001 From: HoopyFreud Date: Sat, 10 Jun 2023 12:01:02 -0400 Subject: [PATCH] fix inpainting conversion via torch --- api/onnx_web/convert/diffusion/checkpoint.py | 5 +++++ api/onnx_web/convert/diffusion/diffusers.py | 5 ++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/diffusion/checkpoint.py b/api/onnx_web/convert/diffusion/checkpoint.py index b89cdda1..494a61ad 100644 --- a/api/onnx_web/convert/diffusion/checkpoint.py +++ b/api/onnx_web/convert/diffusion/checkpoint.py @@ -1377,6 +1377,7 @@ def extract_checkpoint( conversion: ConversionContext, new_model_name: str, checkpoint_file: str, + is_inpainting: Optional[bool] = False, scheduler_type="ddim", extract_ema=False, train_unfrozen=False, @@ -1537,6 +1538,9 @@ def extract_checkpoint( original_config, image_size=image_size ) unet_config["upcast_attention"] = upcast_attention + if is_inpainting: + unet_config["in_channels "] = 9 + unet = UNet2DConditionModel(**unet_config) converted_unet_checkpoint, has_ema = convert_ldm_unet_checkpoint( @@ -1685,6 +1689,7 @@ def convert_extract_checkpoint( conversion: ConversionContext, source: str, dest: str, + is_inpainting: Optional[bool] = False, config_file: Optional[str] = None, vae_file: Optional[str] = None, ) -> Tuple[bool, str]: diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 882e1b67..2792e582 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -312,9 +312,11 @@ def convert_diffusion_diffusers( v2, pipe_args = get_model_version( source, conversion.map_location, size=image_size, version=version ) - + + is_inpainting = False if pipe_type == "inpaint": pipe_args["num_in_channels"] = 9 + is_inpainting = True if format == "safetensors": pipe_args["from_safetensors"] = True @@ -334,6 +336,7 @@ def convert_diffusion_diffusers( conversion, source, f"{name}-torch", + is_inpainting=is_inpainting config_file=config, vae_file=replace_vae, )