feat(api): look for an index file when checking for converted models (#222)
This commit is contained in:
parent
78005812f3
commit
843e2f1ff3
|
@ -65,7 +65,7 @@ def onnx_export(
|
||||||
)
|
)
|
||||||
|
|
||||||
if half:
|
if half:
|
||||||
logger.info("converting model to FP16 internally: %s", output_file)
|
logger.info("converting model to fp16 internally: %s", output_file)
|
||||||
infer_shapes_path(output_file)
|
infer_shapes_path(output_file)
|
||||||
base_model = load_model(output_file)
|
base_model = load_model(output_file)
|
||||||
opt_model = convert_float_to_float16(
|
opt_model = convert_float_to_float16(
|
||||||
|
@ -99,6 +99,7 @@ def convert_diffusion_diffusers(
|
||||||
|
|
||||||
dtype = torch.float32 # torch.float16 if ctx.half else torch.float32
|
dtype = torch.float32 # torch.float16 if ctx.half else torch.float32
|
||||||
dest_path = path.join(ctx.model_path, name)
|
dest_path = path.join(ctx.model_path, name)
|
||||||
|
model_index = path.join(dest_path, "model_index.json")
|
||||||
|
|
||||||
# diffusers go into a directory rather than .onnx file
|
# diffusers go into a directory rather than .onnx file
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -108,7 +109,7 @@ def convert_diffusion_diffusers(
|
||||||
if single_vae:
|
if single_vae:
|
||||||
logger.info("converting model with single VAE")
|
logger.info("converting model with single VAE")
|
||||||
|
|
||||||
if path.exists(dest_path):
|
if path.exists(dest_path) and path.exists(model_index):
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -1667,20 +1667,22 @@ def convert_diffusion_original(
|
||||||
name = model["name"]
|
name = model["name"]
|
||||||
source = source or model["source"]
|
source = source or model["source"]
|
||||||
|
|
||||||
dest = os.path.join(ctx.model_path, name)
|
dest_path = os.path.join(ctx.model_path, name)
|
||||||
|
dest_index = os.path.join(dest_path, "model_index.json")
|
||||||
logger.info(
|
logger.info(
|
||||||
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest
|
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.exists(dest):
|
if os.path.exists(dest_path) and os.path.exists(dest_index):
|
||||||
logger.info("ONNX pipeline already exists, skipping")
|
logger.info("ONNX pipeline already exists, skipping")
|
||||||
return
|
return
|
||||||
|
|
||||||
torch_name = name + "-torch"
|
torch_name = name + "-torch"
|
||||||
torch_path = os.path.join(ctx.cache_path, torch_name)
|
torch_path = os.path.join(ctx.cache_path, torch_name)
|
||||||
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
working_name = os.path.join(ctx.cache_path, torch_name, "working")
|
||||||
|
model_index = os.path.join(working_name, "model_index.json")
|
||||||
|
|
||||||
if os.path.exists(torch_path):
|
if os.path.exists(torch_path) and os.path.exists(model_index):
|
||||||
logger.info("torch pipeline already exists, reusing: %s", torch_path)
|
logger.info("torch pipeline already exists, reusing: %s", torch_path)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -26,11 +26,15 @@ def convert_diffusion_textual_inversion(
|
||||||
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
|
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
|
||||||
)
|
)
|
||||||
|
|
||||||
if path.exists(dest_path):
|
encoder_path = path.join(dest_path, "text_encoder")
|
||||||
|
encoder_model = path.join(encoder_path, "model.onnx")
|
||||||
|
tokenizer_path = path.join(dest_path, "tokenizer")
|
||||||
|
|
||||||
|
if path.exists(dest_path) and path.exists(encoder_model) and path.exists(tokenizer_path):
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
return
|
return
|
||||||
|
|
||||||
makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
|
makedirs(encoder_path, exist_ok=True)
|
||||||
|
|
||||||
if format == "concept":
|
if format == "concept":
|
||||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||||
|
@ -112,14 +116,14 @@ def convert_diffusion_textual_inversion(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("saving tokenizer for textual inversion")
|
logger.info("saving tokenizer for textual inversion")
|
||||||
tokenizer.save_pretrained(path.join(dest_path, "tokenizer"))
|
tokenizer.save_pretrained(tokenizer_path)
|
||||||
|
|
||||||
logger.info("saving text encoder for textual inversion")
|
logger.info("saving text encoder for textual inversion")
|
||||||
export(
|
export(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
(text_input.input_ids.to(dtype=torch.int32)),
|
(text_input.input_ids.to(dtype=torch.int32)),
|
||||||
f=path.join(dest_path, "text_encoder", "model.onnx"),
|
f=encoder_model,
|
||||||
input_names=["input_ids"],
|
input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
|
|
Loading…
Reference in New Issue