fix(api): pass correct text model type when converting v2 checkpoints (#360)
This commit is contained in:
parent
4eba9a6400
commit
2690eafe09
|
@ -13,7 +13,7 @@ from logging import getLogger
|
|||
from os import mkdir, path
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import (
|
||||
|
@ -25,6 +25,9 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
download_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
from onnx import load_model, save_model
|
||||
|
||||
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
|
@ -32,7 +35,7 @@ from ...diffusers.load import optimize_pipeline
|
|||
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||
from ...models.cnet import UNet2DConditionModel_CNet
|
||||
from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
||||
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -48,6 +51,47 @@ available_pipelines = {
|
|||
}
|
||||
|
||||
|
||||
def get_model_version(
|
||||
checkpoint,
|
||||
size=None,
|
||||
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
|
||||
if "global_step" in checkpoint:
|
||||
global_step = checkpoint["global_step"]
|
||||
else:
|
||||
print("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
if size is None:
|
||||
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
|
||||
# as it relies on a brittle global step parameter here
|
||||
size = 512 if global_step == 875000 else 768
|
||||
|
||||
v2 = False
|
||||
opts = {
|
||||
"extract_ema": True,
|
||||
"image_size": size,
|
||||
}
|
||||
|
||||
key_name = (
|
||||
"model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
)
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
v2 = True
|
||||
if size != 512:
|
||||
# v2.1 needs to upcast attention
|
||||
logger.debug("setting upcast_attention")
|
||||
opts["upcast_attention"] = True
|
||||
|
||||
if v2 and size != 512:
|
||||
opts["model_type"] = "FrozenOpenCLIPEmbedder"
|
||||
opts["prediction_type"] = "v_prediction"
|
||||
else:
|
||||
opts["model_type"] = "FrozenCLIPEmbedder"
|
||||
opts["prediction_type"] = "epsilon"
|
||||
|
||||
return (v2, opts)
|
||||
|
||||
|
||||
def convert_diffusion_diffusers_cnet(
|
||||
conversion: ConversionContext,
|
||||
source: str,
|
||||
|
@ -199,16 +243,18 @@ def convert_diffusion_diffusers(
|
|||
"""
|
||||
name = model.get("name")
|
||||
source = source or model.get("source")
|
||||
config = model.get("config", None)
|
||||
single_vae = model.get("single_vae")
|
||||
replace_vae = model.get("vae")
|
||||
pipe_type = model.get("pipeline", "txt2img")
|
||||
pipe_config = model.get("config", None)
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||
|
||||
config_path = None if pipe_config is None else path.join(conversion.model_path, "config", pipe_config)
|
||||
config_path = (
|
||||
None if config is None else path.join(conversion.model_path, "config", config)
|
||||
)
|
||||
dest_path = path.join(conversion.model_path, name)
|
||||
model_index = path.join(dest_path, "model_index.json")
|
||||
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
|
||||
|
@ -233,7 +279,7 @@ def convert_diffusion_diffusers(
|
|||
return (False, dest_path)
|
||||
|
||||
pipe_class = available_pipelines.get(pipe_type)
|
||||
pipe_args = {}
|
||||
v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
|
||||
|
||||
if pipe_type == "inpaint":
|
||||
pipe_args["num_in_channels"] = 9
|
||||
|
@ -247,9 +293,10 @@ def convert_diffusion_diffusers(
|
|||
).to(device)
|
||||
elif path.exists(source) and path.isfile(source):
|
||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||
pipeline = pipe_class.from_ckpt(
|
||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||
source,
|
||||
original_config_file=config_path,
|
||||
pipeline_class=pipe_class,
|
||||
torch_dtype=dtype,
|
||||
**pipe_args,
|
||||
).to(device)
|
||||
|
|
|
@ -49,7 +49,9 @@ def run_loopback(
|
|||
# load img2img pipeline once
|
||||
pipe_type = params.get_valid_pipeline("img2img")
|
||||
if pipe_type == "controlnet":
|
||||
logger.debug("controlnet pipeline cannot be used for loopback, switching to img2img")
|
||||
logger.debug(
|
||||
"controlnet pipeline cannot be used for loopback, switching to img2img"
|
||||
)
|
||||
pipe_type = "img2img"
|
||||
|
||||
logger.debug("using %s pipeline for loopback", pipe_type)
|
||||
|
|
Loading…
Reference in New Issue