diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py index 1cb0502c..79cd7358 100644 --- a/api/onnx_web/convert/diffusion/diffusers.py +++ b/api/onnx_web/convert/diffusion/diffusers.py @@ -52,42 +52,51 @@ available_pipelines = { def get_model_version( - checkpoint, - size=None, + source, + map_location, + size = None, + version = None, ) -> Tuple[bool, Dict[str, Union[bool, int, str]]]: - if "global_step" in checkpoint: - global_step = checkpoint["global_step"] - else: - print("global_step key not found in model") - global_step = None - - if size is None: - # NOTE: For stable diffusion 2 base one has to pass `image_size==512` - # as it relies on a brittle global step parameter here - size = 512 if global_step == 875000 else 768 - - v2 = False + v2 = version is not None and "v2" in version opts = { "extract_ema": True, - "image_size": size, } - key_name = ( - "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - ) - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - v2 = True - if size != 512: - # v2.1 needs to upcast attention - logger.debug("setting upcast_attention") - opts["upcast_attention"] = True + try: + checkpoint = load_tensor(source, map_location=map_location) - if v2 and size != 512: - opts["model_type"] = "FrozenOpenCLIPEmbedder" - opts["prediction_type"] = "v_prediction" - else: - opts["model_type"] = "FrozenCLIPEmbedder" - opts["prediction_type"] = "epsilon" + if "global_step" in checkpoint: + global_step = checkpoint["global_step"] + else: + print("global_step key not found in model") + global_step = None + + if size is None: + # NOTE: For stable diffusion 2 base one has to pass `image_size==512` + # as it relies on a brittle global step parameter here + size = 512 if global_step == 875000 else 768 + + opts["image_size"] = size + + key_name = ( + "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + ) + if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: + v2 = True + if size != 512: + # v2.1 needs to upcast attention + logger.debug("setting upcast_attention") + opts["upcast_attention"] = True + + if v2 and size != 512: + opts["model_type"] = "FrozenOpenCLIPEmbedder" + opts["prediction_type"] = "v_prediction" + else: + opts["model_type"] = "FrozenCLIPEmbedder" + opts["prediction_type"] = "epsilon" + except: + logger.debug("unable to load tensor for version check") + pass return (v2, opts) @@ -241,12 +250,14 @@ def convert_diffusion_diffusers( """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") - source = source or model.get("source") config = model.get("config", None) - single_vae = model.get("single_vae") - replace_vae = model.get("vae") + image_size = model.get("image_size", None) + name = model.get("name") pipe_type = model.get("pipeline", "txt2img") + single_vae = model.get("single_vae") + source = source or model.get("source") + replace_vae = model.get("vae") + version = model.get("version", None) device = conversion.training_device dtype = conversion.torch_dtype() @@ -279,7 +290,7 @@ def convert_diffusion_diffusers( return (False, dest_path) pipe_class = available_pipelines.get(pipe_type) - v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location)) + v2, pipe_args = get_model_version(source, conversion.map_location, size=image_size, version=version) if pipe_type == "inpaint": pipe_args["num_in_channels"] = 9 @@ -299,7 +310,7 @@ def convert_diffusion_diffusers( pipeline_class=pipe_class, torch_dtype=dtype, **pipe_args, - ).to(device) + ).to(device, torch_dtype=dtype) else: logger.warning("pipeline source not found or not recognized: %s", source) raise ValueError(f"pipeline source not found or not recognized: {source}") diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 63872732..23efae3c 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -268,6 +268,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]: except Exception as e: logger.warning("error loading tensor: %s", e) + if checkpoint is None: + raise ValueError("error loading tensor") + if checkpoint is not None and "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] diff --git a/api/schemas/extras.yaml b/api/schemas/extras.yaml index c1568244..3935d3a9 100644 --- a/api/schemas/extras.yaml +++ b/api/schemas/extras.yaml @@ -78,6 +78,8 @@ $defs: properties: config: type: string + image_size: + type: number inversions: type: array items: @@ -94,6 +96,13 @@ $defs: ] vae: type: string + version: + type: string + enum: [ + v1, + v2, + v2.1, + ] upscaling_model: allOf: