feat(api): convert CNet for existing diffusion models
This commit is contained in:
parent
2c75311fba
commit
0dd8272285
|
@ -34,138 +34,17 @@ from ..utils import ConversionContext, is_torch_2_0, onnx_export
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def convert_diffusion_diffusers_cnet(
|
||||||
def convert_diffusion_diffusers(
|
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
|
||||||
source: str,
|
source: str,
|
||||||
) -> Tuple[bool, str]:
|
device: str,
|
||||||
"""
|
output_path: Path,
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
dtype,
|
||||||
"""
|
unet_in_channels,
|
||||||
name = model.get("name")
|
unet_sample_size,
|
||||||
source = source or model.get("source")
|
num_tokens,
|
||||||
single_vae = model.get("single_vae")
|
text_hidden_size,
|
||||||
replace_vae = model.get("vae")
|
):
|
||||||
|
|
||||||
device = conversion.training_device
|
|
||||||
dtype = conversion.torch_dtype()
|
|
||||||
logger.debug("using Torch dtype %s for pipeline", dtype)
|
|
||||||
|
|
||||||
dest_path = path.join(conversion.model_path, name)
|
|
||||||
model_index = path.join(dest_path, "model_index.json")
|
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
|
||||||
logger.info(
|
|
||||||
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
|
|
||||||
)
|
|
||||||
|
|
||||||
if single_vae:
|
|
||||||
logger.info("converting model with single VAE")
|
|
||||||
|
|
||||||
if path.exists(dest_path) and path.exists(model_index):
|
|
||||||
# TODO: check if CNet has been converted
|
|
||||||
logger.info("ONNX model already exists, skipping")
|
|
||||||
return (False, dest_path)
|
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
|
||||||
source,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
use_auth_token=conversion.token,
|
|
||||||
).to(device)
|
|
||||||
output_path = Path(dest_path)
|
|
||||||
|
|
||||||
optimize_pipeline(conversion, pipeline)
|
|
||||||
|
|
||||||
# TEXT ENCODER
|
|
||||||
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
|
||||||
text_hidden_size = pipeline.text_encoder.config.hidden_size
|
|
||||||
text_input = pipeline.tokenizer(
|
|
||||||
"A sample prompt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=pipeline.tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
onnx_export(
|
|
||||||
pipeline.text_encoder,
|
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
|
||||||
model_args=(
|
|
||||||
text_input.input_ids.to(device=device, dtype=torch.int32),
|
|
||||||
None, # attention mask
|
|
||||||
None, # position ids
|
|
||||||
None, # output attentions
|
|
||||||
torch.tensor(True).to(device=device, dtype=torch.bool),
|
|
||||||
),
|
|
||||||
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
|
||||||
ordered_input_names=["input_ids"],
|
|
||||||
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
|
|
||||||
dynamic_axes={
|
|
||||||
"input_ids": {0: "batch", 1: "sequence"},
|
|
||||||
},
|
|
||||||
opset=conversion.opset,
|
|
||||||
half=conversion.half,
|
|
||||||
)
|
|
||||||
del pipeline.text_encoder
|
|
||||||
|
|
||||||
logger.debug("UNET config: %s", pipeline.unet.config)
|
|
||||||
|
|
||||||
# UNET
|
|
||||||
if single_vae:
|
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
|
||||||
unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
|
|
||||||
else:
|
|
||||||
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
|
||||||
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
|
|
||||||
|
|
||||||
if is_torch_2_0:
|
|
||||||
pipeline.unet.set_attn_processor(CrossAttnProcessor())
|
|
||||||
|
|
||||||
unet_in_channels = pipeline.unet.config.in_channels
|
|
||||||
unet_sample_size = pipeline.unet.config.sample_size
|
|
||||||
unet_path = output_path / "unet" / ONNX_MODEL
|
|
||||||
onnx_export(
|
|
||||||
pipeline.unet,
|
|
||||||
model_args=(
|
|
||||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
|
||||||
device=device, dtype=dtype
|
|
||||||
),
|
|
||||||
torch.randn(2).to(device=device, dtype=dtype),
|
|
||||||
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
|
|
||||||
unet_scale,
|
|
||||||
),
|
|
||||||
output_path=unet_path,
|
|
||||||
ordered_input_names=unet_inputs,
|
|
||||||
# has to be different from "sample" for correct tracing
|
|
||||||
output_names=["out_sample"],
|
|
||||||
dynamic_axes={
|
|
||||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
|
||||||
"timestep": {0: "batch"},
|
|
||||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
|
||||||
},
|
|
||||||
opset=conversion.opset,
|
|
||||||
half=conversion.half,
|
|
||||||
external_data=True,
|
|
||||||
)
|
|
||||||
unet_model_path = str(unet_path.absolute().as_posix())
|
|
||||||
unet_dir = path.dirname(unet_model_path)
|
|
||||||
unet = load_model(unet_model_path)
|
|
||||||
|
|
||||||
# clean up existing tensor files
|
|
||||||
rmtree(unet_dir)
|
|
||||||
mkdir(unet_dir)
|
|
||||||
|
|
||||||
# collate external tensor files into one
|
|
||||||
save_model(
|
|
||||||
unet,
|
|
||||||
unet_model_path,
|
|
||||||
save_as_external_data=True,
|
|
||||||
all_tensors_to_one_file=True,
|
|
||||||
location=ONNX_WEIGHTS,
|
|
||||||
convert_attribute=False,
|
|
||||||
)
|
|
||||||
del pipeline.unet
|
|
||||||
|
|
||||||
# CNet
|
# CNet
|
||||||
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
|
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet").to(
|
||||||
device=device, dtype=dtype
|
device=device, dtype=dtype
|
||||||
|
@ -288,6 +167,147 @@ def convert_diffusion_diffusers(
|
||||||
)
|
)
|
||||||
del pipe_cnet
|
del pipe_cnet
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_diffusion_diffusers(
|
||||||
|
conversion: ConversionContext,
|
||||||
|
model: Dict,
|
||||||
|
source: str,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
|
"""
|
||||||
|
name = model.get("name")
|
||||||
|
source = source or model.get("source")
|
||||||
|
single_vae = model.get("single_vae")
|
||||||
|
replace_vae = model.get("vae")
|
||||||
|
|
||||||
|
device = conversion.training_device
|
||||||
|
dtype = conversion.torch_dtype()
|
||||||
|
logger.debug("using Torch dtype %s for pipeline", dtype)
|
||||||
|
|
||||||
|
dest_path = path.join(conversion.model_path, name)
|
||||||
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
|
model_cnet = path.join(dest_path, "cnet", ONNX_MODEL)
|
||||||
|
|
||||||
|
# diffusers go into a directory rather than .onnx file
|
||||||
|
logger.info(
|
||||||
|
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
|
||||||
|
)
|
||||||
|
|
||||||
|
if single_vae:
|
||||||
|
logger.info("converting model with single VAE")
|
||||||
|
|
||||||
|
if path.exists(dest_path) and path.exists(model_index):
|
||||||
|
if not path.exists(model_cnet):
|
||||||
|
logger.info("ONNX model was converted without a ControlNet UNet, converting one")
|
||||||
|
convert_diffusion_diffusers_cnet(conversion, source, device, output_path, dtype, unet_in_channels, unet_sample_size, num_tokens, text_hidden_size)
|
||||||
|
return (True, dest_path)
|
||||||
|
else:
|
||||||
|
logger.info("ONNX model already exists, skipping")
|
||||||
|
return (False, dest_path)
|
||||||
|
|
||||||
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
|
source,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
use_auth_token=conversion.token,
|
||||||
|
).to(device)
|
||||||
|
output_path = Path(dest_path)
|
||||||
|
|
||||||
|
optimize_pipeline(conversion, pipeline)
|
||||||
|
|
||||||
|
# TEXT ENCODER
|
||||||
|
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
||||||
|
text_hidden_size = pipeline.text_encoder.config.hidden_size
|
||||||
|
text_input = pipeline.tokenizer(
|
||||||
|
"A sample prompt",
|
||||||
|
padding="max_length",
|
||||||
|
max_length=pipeline.tokenizer.model_max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
onnx_export(
|
||||||
|
pipeline.text_encoder,
|
||||||
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
|
model_args=(
|
||||||
|
text_input.input_ids.to(device=device, dtype=torch.int32),
|
||||||
|
None, # attention mask
|
||||||
|
None, # position ids
|
||||||
|
None, # output attentions
|
||||||
|
torch.tensor(True).to(device=device, dtype=torch.bool),
|
||||||
|
),
|
||||||
|
output_path=output_path / "text_encoder" / ONNX_MODEL,
|
||||||
|
ordered_input_names=["input_ids"],
|
||||||
|
output_names=["last_hidden_state", "pooler_output", "hidden_states"],
|
||||||
|
dynamic_axes={
|
||||||
|
"input_ids": {0: "batch", 1: "sequence"},
|
||||||
|
},
|
||||||
|
opset=conversion.opset,
|
||||||
|
half=conversion.half,
|
||||||
|
)
|
||||||
|
del pipeline.text_encoder
|
||||||
|
|
||||||
|
logger.debug("UNET config: %s", pipeline.unet.config)
|
||||||
|
|
||||||
|
# UNET
|
||||||
|
if single_vae:
|
||||||
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
|
||||||
|
unet_scale = torch.tensor(4).to(device=device, dtype=torch.long)
|
||||||
|
else:
|
||||||
|
unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
|
||||||
|
unet_scale = torch.tensor(False).to(device=device, dtype=torch.bool)
|
||||||
|
|
||||||
|
if is_torch_2_0:
|
||||||
|
pipeline.unet.set_attn_processor(CrossAttnProcessor())
|
||||||
|
|
||||||
|
unet_in_channels = pipeline.unet.config.in_channels
|
||||||
|
unet_sample_size = pipeline.unet.config.sample_size
|
||||||
|
unet_path = output_path / "unet" / ONNX_MODEL
|
||||||
|
onnx_export(
|
||||||
|
pipeline.unet,
|
||||||
|
model_args=(
|
||||||
|
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||||
|
device=device, dtype=dtype
|
||||||
|
),
|
||||||
|
torch.randn(2).to(device=device, dtype=dtype),
|
||||||
|
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
|
||||||
|
unet_scale,
|
||||||
|
),
|
||||||
|
output_path=unet_path,
|
||||||
|
ordered_input_names=unet_inputs,
|
||||||
|
# has to be different from "sample" for correct tracing
|
||||||
|
output_names=["out_sample"],
|
||||||
|
dynamic_axes={
|
||||||
|
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||||
|
"timestep": {0: "batch"},
|
||||||
|
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||||
|
},
|
||||||
|
opset=conversion.opset,
|
||||||
|
half=conversion.half,
|
||||||
|
external_data=True,
|
||||||
|
)
|
||||||
|
unet_model_path = str(unet_path.absolute().as_posix())
|
||||||
|
unet_dir = path.dirname(unet_model_path)
|
||||||
|
unet = load_model(unet_model_path)
|
||||||
|
|
||||||
|
# clean up existing tensor files
|
||||||
|
rmtree(unet_dir)
|
||||||
|
mkdir(unet_dir)
|
||||||
|
|
||||||
|
# collate external tensor files into one
|
||||||
|
save_model(
|
||||||
|
unet,
|
||||||
|
unet_model_path,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=ONNX_WEIGHTS,
|
||||||
|
convert_attribute=False,
|
||||||
|
)
|
||||||
|
del pipeline.unet
|
||||||
|
|
||||||
|
convert_diffusion_diffusers_cnet(conversion, source, device, output_path, dtype, unet_in_channels, unet_sample_size, num_tokens, text_hidden_size)
|
||||||
|
|
||||||
# VAE
|
# VAE
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
logger.debug("loading custom VAE: %s", replace_vae)
|
logger.debug("loading custom VAE: %s", replace_vae)
|
||||||
|
|
|
@ -64,7 +64,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
|
||||||
|
|
||||||
function noneLabel(): Maybe<string> {
|
function noneLabel(): Maybe<string> {
|
||||||
if (showNone) {
|
if (showNone) {
|
||||||
return t(`${labelKey}.none`);
|
return 'none';
|
||||||
}
|
}
|
||||||
|
|
||||||
return undefined;
|
return undefined;
|
||||||
|
|
Loading…
Reference in New Issue