fix(api): better logging when converting textual inversions
This commit is contained in:
parent
46aac263d5
commit
9a0d2051fb
|
@ -49,6 +49,8 @@ def convert_diffusion_textual_inversion(
|
|||
# separate token and embeds
|
||||
trained_token = list(string_to_token.keys())[0]
|
||||
embeds = string_to_param[trained_token]
|
||||
else:
|
||||
raise ValueError(f"unknown textual inversion format: {format}")
|
||||
|
||||
logger.info("found embedding for token %s: %s", trained_token, embeds.shape)
|
||||
|
||||
|
@ -88,8 +90,10 @@ def convert_diffusion_textual_inversion(
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
logger.info("saving tokenizer for textual inversion")
|
||||
tokenizer.save_pretrained(path.join(dest_path, "tokenizer"))
|
||||
|
||||
logger.info("saving text encoder for textual inversion")
|
||||
export(
|
||||
text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
|
@ -103,3 +107,5 @@ def convert_diffusion_textual_inversion(
|
|||
do_constant_folding=True,
|
||||
opset_version=context.opset,
|
||||
)
|
||||
|
||||
logger.info("textual inversion saved to %s", dest_path)
|
||||
|
|
Loading…
Reference in New Issue