fix(api): pass correct text model type when converting v2 checkpoints (#360)
This commit is contained in:
parent
4eba9a6400
commit
2690eafe09
|
@ -13,7 +13,7 @@ from logging import getLogger
|
||||||
from os import mkdir, path
|
from os import mkdir, path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
from shutil import rmtree
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
|
@ -25,6 +25,9 @@ from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
StableDiffusionUpscalePipeline,
|
StableDiffusionUpscalePipeline,
|
||||||
)
|
)
|
||||||
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||||
|
download_from_original_stable_diffusion_ckpt,
|
||||||
|
)
|
||||||
from onnx import load_model, save_model
|
from onnx import load_model, save_model
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
|
@ -32,7 +35,7 @@ from ...diffusers.load import optimize_pipeline
|
||||||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -48,6 +51,47 @@ available_pipelines = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_version(
    checkpoint,
    size=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    """Inspect an original Stable Diffusion checkpoint and infer its version.

    Heuristically detects whether a checkpoint is a Stable Diffusion v2 model
    and builds the keyword options needed by diffusers'
    ``download_from_original_stable_diffusion_ckpt`` conversion helper.

    Args:
        checkpoint: the loaded checkpoint state dict (mapping of tensor names
            to tensors, possibly with extra metadata keys like ``global_step``).
        size: image size to use; when ``None`` it is inferred from
            ``global_step`` (512 for the SD2 base model, 768 otherwise).

    Returns:
        A ``(v2, opts)`` tuple: ``v2`` is True when the checkpoint looks like
        a Stable Diffusion v2 model, and ``opts`` holds the conversion kwargs
        (``extract_ema``, ``image_size``, ``model_type``, ``prediction_type``
        and optionally ``upcast_attention``).
    """
    if "global_step" in checkpoint:
        global_step = checkpoint["global_step"]
    else:
        # use the module logger rather than print so this shows up in the
        # conversion logs like every other message in this module
        logger.warning("global_step key not found in model")
        global_step = None

    if size is None:
        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
        # as it relies on a brittle global step parameter here
        size = 512 if global_step == 875000 else 768

    v2 = False
    opts: Dict[str, Union[bool, int, str]] = {
        "extract_ema": True,
        "image_size": size,
    }

    # SD2 models use a 1024-wide OpenCLIP text encoder; probe one of the
    # cross-attention weights to tell v2 apart from v1 (768-wide CLIP)
    key_name = (
        "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    )
    if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
        v2 = True
        if size != 512:
            # v2.1 needs to upcast attention
            logger.debug("setting upcast_attention")
            opts["upcast_attention"] = True

    if v2 and size != 512:
        # v2.x (768) models: OpenCLIP text encoder with v-prediction
        opts["model_type"] = "FrozenOpenCLIPEmbedder"
        opts["prediction_type"] = "v_prediction"
    else:
        # v1.x and the v2 base (512) model: CLIP text encoder with epsilon
        opts["model_type"] = "FrozenCLIPEmbedder"
        opts["prediction_type"] = "epsilon"

    return (v2, opts)
|
|
||||||
|
|
||||||
def convert_diffusion_diffusers_cnet(
|
def convert_diffusion_diffusers_cnet(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
source: str,
|
source: str,
|
||||||
|
@ -199,16 +243,18 @@ def convert_diffusion_diffusers(
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = model.get("name")
|
||||||
source = source or model.get("source")
|
source = source or model.get("source")
|
||||||
|
config = model.get("config", None)
|
||||||
single_vae = model.get("single_vae")
|
single_vae = model.get("single_vae")
|
||||||
replace_vae = model.get("vae")
|
replace_vae = model.get("vae")
|
||||||
pipe_type = model.get("pipeline", "txt2img")
|
pipe_type = model.get("pipeline", "txt2img")
|
||||||
pipe_config = model.get("config", None)
|
|
||||||
|
|
||||||
device = conversion.training_device
|
device = conversion.training_device
|
||||||
dtype = conversion.torch_dtype()
|
dtype = conversion.torch_dtype()
|
||||||
logger.debug("using Torch dtype %s for pipeline", dtype)
|
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||||
|
|
||||||
config_path = None if pipe_config is None else path.join(conversion.model_path, "config", pipe_config)
|
config_path = (
|
||||||
|
None if config is None else path.join(conversion.model_path, "config", config)
|
||||||
|
)
|
||||||
dest_path = path.join(conversion.model_path, name)
|
dest_path = path.join(conversion.model_path, name)
|
||||||
model_index = path.join(dest_path, "model_index.json")
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
|
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
|
||||||
|
@ -233,7 +279,7 @@ def convert_diffusion_diffusers(
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
pipe_class = available_pipelines.get(pipe_type)
|
pipe_class = available_pipelines.get(pipe_type)
|
||||||
pipe_args = {}
|
v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
|
||||||
|
|
||||||
if pipe_type == "inpaint":
|
if pipe_type == "inpaint":
|
||||||
pipe_args["num_in_channels"] = 9
|
pipe_args["num_in_channels"] = 9
|
||||||
|
@ -247,9 +293,10 @@ def convert_diffusion_diffusers(
|
||||||
).to(device)
|
).to(device)
|
||||||
elif path.exists(source) and path.isfile(source):
|
elif path.exists(source) and path.isfile(source):
|
||||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||||
pipeline = pipe_class.from_ckpt(
|
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||||
source,
|
source,
|
||||||
original_config_file=config_path,
|
original_config_file=config_path,
|
||||||
|
pipeline_class=pipe_class,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
**pipe_args,
|
**pipe_args,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
|
@ -49,7 +49,9 @@ def run_loopback(
|
||||||
# load img2img pipeline once
|
# load img2img pipeline once
|
||||||
pipe_type = params.get_valid_pipeline("img2img")
|
pipe_type = params.get_valid_pipeline("img2img")
|
||||||
if pipe_type == "controlnet":
|
if pipe_type == "controlnet":
|
||||||
logger.debug("controlnet pipeline cannot be used for loopback, switching to img2img")
|
logger.debug(
|
||||||
|
"controlnet pipeline cannot be used for loopback, switching to img2img"
|
||||||
|
)
|
||||||
pipe_type = "img2img"
|
pipe_type = "img2img"
|
||||||
|
|
||||||
logger.debug("using %s pipeline for loopback", pipe_type)
|
logger.debug("using %s pipeline for loopback", pipe_type)
|
||||||
|
|
Loading…
Reference in New Issue