diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index 549868f6..1cb0502c 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -13,7 +13,7 @@ from logging import getLogger
 from os import mkdir, path
 from pathlib import Path
 from shutil import rmtree
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from diffusers import (
@@ -25,6 +25,9 @@ from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionUpscalePipeline,
 )
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt,
+)
 from onnx import load_model, save_model
 
 from ...constants import ONNX_MODEL, ONNX_WEIGHTS
@@ -32,7 +35,7 @@ from ...diffusers.load import optimize_pipeline
 from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
 from ...diffusers.version_safe_diffusers import AttnProcessor
 from ...models.cnet import UNet2DConditionModel_CNet
-from ..utils import ConversionContext, is_torch_2_0, onnx_export
+from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
 
 logger = getLogger(__name__)
 
@@ -48,6 +51,47 @@ available_pipelines = {
 }
 
 
+def get_model_version(
+    checkpoint,
+    size=None,
+) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
+    if "global_step" in checkpoint:
+        global_step = checkpoint["global_step"]
+    else:
+        print("global_step key not found in model")
+        global_step = None
+
+    if size is None:
+        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
+        # as it relies on a brittle global step parameter here
+        size = 512 if global_step == 875000 else 768
+
+    v2 = False
+    opts = {
+        "extract_ema": True,
+        "image_size": size,
+    }
+
+    key_name = (
+        "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
+    )
+    if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
+        v2 = True
+        if size != 512:
+            # v2.1 needs to upcast attention
+            logger.debug("setting upcast_attention")
+            opts["upcast_attention"] = True
+
+    if v2 and size != 512:
+        opts["model_type"] = "FrozenOpenCLIPEmbedder"
+        opts["prediction_type"] = "v_prediction"
+    else:
+        opts["model_type"] = "FrozenCLIPEmbedder"
+        opts["prediction_type"] = "epsilon"
+
+    return (v2, opts)
+
+
 def convert_diffusion_diffusers_cnet(
     conversion: ConversionContext,
     source: str,
@@ -199,16 +243,18 @@ def convert_diffusion_diffusers(
     """
     name = model.get("name")
     source = source or model.get("source")
+    config = model.get("config", None)
     single_vae = model.get("single_vae")
     replace_vae = model.get("vae")
     pipe_type = model.get("pipeline", "txt2img")
-    pipe_config = model.get("config", None)
 
     device = conversion.training_device
     dtype = conversion.torch_dtype()
     logger.debug("using Torch dtype %s for pipeline", dtype)
 
-    config_path = None if pipe_config is None else path.join(conversion.model_path, "config", pipe_config)
+    config_path = (
+        None if config is None else path.join(conversion.model_path, "config", config)
+    )
     dest_path = path.join(conversion.model_path, name)
     model_index = path.join(dest_path, "model_index.json")
     model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
@@ -233,7 +279,7 @@ def convert_diffusion_diffusers(
         return (False, dest_path)
 
     pipe_class = available_pipelines.get(pipe_type)
-    pipe_args = {}
+    v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
 
     if pipe_type == "inpaint":
         pipe_args["num_in_channels"] = 9
@@ -247,9 +293,10 @@ def convert_diffusion_diffusers(
         ).to(device)
     elif path.exists(source) and path.isfile(source):
         logger.debug("loading pipeline from SD checkpoint: %s", source)
-        pipeline = pipe_class.from_ckpt(
+        pipeline = download_from_original_stable_diffusion_ckpt(
             source,
             original_config_file=config_path,
+            pipeline_class=pipe_class,
             torch_dtype=dtype,
             **pipe_args,
         ).to(device)
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 92f0c8fc..2a50a930 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -49,7 +49,9 @@ def run_loopback(
     # load img2img pipeline once
     pipe_type = params.get_valid_pipeline("img2img")
    if pipe_type == "controlnet":
-        logger.debug("controlnet pipeline cannot be used for loopback, switching to img2img")
+        logger.debug(
+            "controlnet pipeline cannot be used for loopback, switching to img2img"
+        )
         pipe_type = "img2img"
 
     logger.debug("using %s pipeline for loopback", pipe_type)