add controlnet support to optimum converter
This commit is contained in:
parent
abeeddeeb2
commit
f6e6e31789
|
@ -833,6 +833,16 @@ def convert_diffusion_diffusers_optimum(
|
||||||
logger.debug("exporting torch model for %s: %s", source, temp_path)
|
logger.debug("exporting torch model for %s: %s", source, temp_path)
|
||||||
pipeline.save_pretrained(temp_path)
|
pipeline.save_pretrained(temp_path)
|
||||||
|
|
||||||
|
# config needed for ControlNet later
|
||||||
|
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
||||||
|
text_hidden_size = pipeline.text_encoder.config.hidden_size
|
||||||
|
unet_in_channels = pipeline.unet.config.in_channels
|
||||||
|
unet_sample_size = pipeline.unet.config.sample_size
|
||||||
|
|
||||||
|
# GC temporary pipeline
|
||||||
|
del pipeline
|
||||||
|
run_gc()
|
||||||
|
|
||||||
main_export(
|
main_export(
|
||||||
temp_path,
|
temp_path,
|
||||||
output=dest_path,
|
output=dest_path,
|
||||||
|
@ -868,4 +878,38 @@ def convert_diffusion_diffusers_optimum(
|
||||||
location=ONNX_WEIGHTS,
|
location=ONNX_WEIGHTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if conversion.control:
|
||||||
|
logger.debug("converting CNet from pretrained UNet")
|
||||||
|
cnet_path = convert_diffusion_diffusers_cnet(
|
||||||
|
conversion,
|
||||||
|
name,
|
||||||
|
temp_path,
|
||||||
|
device,
|
||||||
|
Path(dest_path),
|
||||||
|
dtype,
|
||||||
|
unet_in_channels,
|
||||||
|
unet_sample_size,
|
||||||
|
num_tokens,
|
||||||
|
text_hidden_size,
|
||||||
|
v2=v2,
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversion.half:
|
||||||
|
infer_shapes_path(cnet_path)
|
||||||
|
cnet = load_model(cnet_path)
|
||||||
|
opt_model = convert_float_to_float16(
|
||||||
|
cnet,
|
||||||
|
disable_shape_infer=True,
|
||||||
|
force_fp16_initializers=True,
|
||||||
|
keep_io_types=True,
|
||||||
|
op_block_list=["Attention", "MultiHeadAttention"],
|
||||||
|
)
|
||||||
|
save_model(
|
||||||
|
opt_model,
|
||||||
|
cnet_path,
|
||||||
|
save_as_external_data=True,
|
||||||
|
all_tensors_to_one_file=True,
|
||||||
|
location=ONNX_WEIGHTS,
|
||||||
|
)
|
||||||
|
|
||||||
return (True, dest_path)
|
return (True, dest_path)
|
||||||
|
|
Loading…
Reference in New Issue