1
0
Fork 0

feat(api): look for an index file when checking for converted models (#222)

This commit is contained in:
Sean Sube 2023-03-07 23:40:04 -06:00
parent 78005812f3
commit 843e2f1ff3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 17 additions and 10 deletions

View File

@ -65,7 +65,7 @@ def onnx_export(
)
if half:
logger.info("converting model to FP16 internally: %s", output_file)
logger.info("converting model to fp16 internally: %s", output_file)
infer_shapes_path(output_file)
base_model = load_model(output_file)
opt_model = convert_float_to_float16(
@ -99,6 +99,7 @@ def convert_diffusion_diffusers(
dtype = torch.float32 # torch.float16 if ctx.half else torch.float32
dest_path = path.join(ctx.model_path, name)
model_index = path.join(dest_path, "model_index.json")
# diffusers go into a directory rather than .onnx file
logger.info(
@ -108,7 +109,7 @@ def convert_diffusion_diffusers(
if single_vae:
logger.info("converting model with single VAE")
if path.exists(dest_path):
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping")
return

View File

@ -1667,20 +1667,22 @@ def convert_diffusion_original(
name = model["name"]
source = source or model["source"]
dest = os.path.join(ctx.model_path, name)
dest_path = os.path.join(ctx.model_path, name)
dest_index = os.path.join(dest_path, "model_index.json")
logger.info(
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest_path
)
if os.path.exists(dest):
if os.path.exists(dest_path) and os.path.exists(dest_index):
logger.info("ONNX pipeline already exists, skipping")
return
torch_name = name + "-torch"
torch_path = os.path.join(ctx.cache_path, torch_name)
working_name = os.path.join(ctx.cache_path, torch_name, "working")
model_index = os.path.join(working_name, "model_index.json")
if os.path.exists(torch_path):
if os.path.exists(torch_path) and os.path.exists(model_index):
logger.info("torch pipeline already exists, reusing: %s", torch_path)
else:
logger.info(

View File

@ -26,11 +26,15 @@ def convert_diffusion_textual_inversion(
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
)
if path.exists(dest_path):
encoder_path = path.join(dest_path, "text_encoder")
encoder_model = path.join(encoder_path, "model.onnx")
tokenizer_path = path.join(dest_path, "tokenizer")
if path.exists(dest_path) and path.exists(encoder_model) and path.exists(tokenizer_path):
logger.info("ONNX model already exists, skipping.")
return
makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
makedirs(encoder_path, exist_ok=True)
if format == "concept":
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
@ -112,14 +116,14 @@ def convert_diffusion_textual_inversion(
)
logger.info("saving tokenizer for textual inversion")
tokenizer.save_pretrained(path.join(dest_path, "tokenizer"))
tokenizer.save_pretrained(tokenizer_path)
logger.info("saving text encoder for textual inversion")
export(
text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
(text_input.input_ids.to(dtype=torch.int32)),
f=path.join(dest_path, "text_encoder", "model.onnx"),
f=encoder_model,
input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={