
fix(api): better logging when converting textual inversions

Sean Sube 2023-03-02 07:57:59 -06:00
parent 46aac263d5
commit 9a0d2051fb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 6 additions and 0 deletions


@@ -49,6 +49,8 @@ def convert_diffusion_textual_inversion(
         # separate token and embeds
         trained_token = list(string_to_token.keys())[0]
         embeds = string_to_param[trained_token]
     else:
         raise ValueError(f"unknown textual inversion format: {format}")
+
+    logger.info("found embedding for token %s: %s", trained_token, embeds.shape)
@@ -88,8 +90,10 @@ def convert_diffusion_textual_inversion(
         return_tensors="pt",
     )
+    logger.info("saving tokenizer for textual inversion")
     tokenizer.save_pretrained(path.join(dest_path, "tokenizer"))
+    logger.info("saving text encoder for textual inversion")
     export(
         text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
@@ -103,3 +107,5 @@ def convert_diffusion_textual_inversion(
         do_constant_folding=True,
         opset_version=context.opset,
     )
+
+    logger.info("textual inversion saved to %s", dest_path)