1
0
Fork 0

fix(api): pass correct text model type when converting v2 checkpoints (#360)

This commit is contained in:
Sean Sube 2023-04-29 22:45:48 -05:00
parent 4eba9a6400
commit 2690eafe09
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 56 additions and 7 deletions

View File

@@ -13,7 +13,7 @@ from logging import getLogger
from os import mkdir, path
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
from diffusers import (
@@ -25,6 +25,9 @@ from diffusers import (
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
download_from_original_stable_diffusion_ckpt,
)
from onnx import load_model, save_model
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
@@ -32,7 +35,7 @@ from ...diffusers.load import optimize_pipeline
from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ..utils import ConversionContext, is_torch_2_0, onnx_export
from ..utils import ConversionContext, is_torch_2_0, load_tensor, onnx_export
logger = getLogger(__name__)
@@ -48,6 +51,47 @@ available_pipelines = {
}
def get_model_version(
    checkpoint,
    size=None,
) -> Tuple[bool, Dict[str, Union[bool, int, str]]]:
    """Inspect an original Stable Diffusion checkpoint and guess its version.

    Returns a tuple of ``(is_v2, options)`` where ``options`` carries the
    kwargs expected by the diffusers checkpoint conversion: ``extract_ema``,
    ``image_size``, ``model_type``, ``prediction_type``, and optionally
    ``upcast_attention``.

    ``checkpoint`` is the loaded state dict; ``size`` optionally forces the
    image size when the caller already knows it.
    """
    if "global_step" in checkpoint:
        global_step = checkpoint["global_step"]
    else:
        # use the module logger instead of print so this reaches the app logs
        logger.debug("global_step key not found in model")
        global_step = None

    if size is None:
        # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
        # as it relies on a brittle global step parameter here
        size = 512 if global_step == 875000 else 768

    v2 = False
    opts = {
        "extract_ema": True,
        "image_size": size,
    }

    key_name = (
        "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    )
    if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
        # a 1024-wide cross-attention key means an OpenCLIP (v2) text encoder
        v2 = True
        if size != 512:
            # v2.1 needs to upcast attention
            logger.debug("setting upcast_attention")
            opts["upcast_attention"] = True

    if v2:
        # every v2 checkpoint uses the OpenCLIP text encoder, including the
        # 512px base models; only the 768px models use v-prediction
        opts["model_type"] = "FrozenOpenCLIPEmbedder"
        opts["prediction_type"] = "v_prediction" if size != 512 else "epsilon"
    else:
        opts["model_type"] = "FrozenCLIPEmbedder"
        opts["prediction_type"] = "epsilon"

    return (v2, opts)
def convert_diffusion_diffusers_cnet(
conversion: ConversionContext,
source: str,
@@ -199,16 +243,18 @@ def convert_diffusion_diffusers(
"""
name = model.get("name")
source = source or model.get("source")
config = model.get("config", None)
single_vae = model.get("single_vae")
replace_vae = model.get("vae")
pipe_type = model.get("pipeline", "txt2img")
pipe_config = model.get("config", None)
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
config_path = None if pipe_config is None else path.join(conversion.model_path, "config", pipe_config)
config_path = (
None if config is None else path.join(conversion.model_path, "config", config)
)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
@@ -233,7 +279,7 @@ def convert_diffusion_diffusers(
return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type)
pipe_args = {}
v2, pipe_args = get_model_version(load_tensor(source, conversion.map_location))
if pipe_type == "inpaint":
pipe_args["num_in_channels"] = 9
@ -247,9 +293,10 @@ def convert_diffusion_diffusers(
).to(device)
elif path.exists(source) and path.isfile(source):
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = pipe_class.from_ckpt(
pipeline = download_from_original_stable_diffusion_ckpt(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
torch_dtype=dtype,
**pipe_args,
).to(device)

View File

@@ -49,7 +49,9 @@ def run_loopback(
# load img2img pipeline once
pipe_type = params.get_valid_pipeline("img2img")
if pipe_type == "controlnet":
logger.debug("controlnet pipeline cannot be used for loopback, switching to img2img")
logger.debug(
"controlnet pipeline cannot be used for loopback, switching to img2img"
)
pipe_type = "img2img"
logger.debug("using %s pipeline for loopback", pipe_type)