from logging import getLogger
from os import makedirs, path

import torch
from huggingface_hub.file_download import hf_hub_download
from torch.onnx import export
from transformers import CLIPTextModel, CLIPTokenizer

from ..utils import ConversionContext

|
# Module-level logger, scoped to this module's dotted import path.
logger = getLogger(__name__)


def convert_diffusion_textual_inversion(
    context: ConversionContext, name: str, base_model: str, inversion: str
):
    """Convert a Textual Inversion embedding into an ONNX text encoder.

    Downloads the learned embedding and its token from the Hugging Face hub,
    registers the token with the base model's tokenizer, injects the embedding
    into the text encoder's embedding table, and exports the patched text
    encoder to ONNX under ``{context.model_path}/inversion-{name}``.

    :param context: conversion settings (model path, torch map location, opset)
    :param name: output name; results are written to ``inversion-{name}``
    :param base_model: hub id of the base diffusion model whose text encoder
        and tokenizer are patched
    :param inversion: hub repo id containing ``learned_embeds.bin`` and
        ``token_identifier.txt``
    :raises ValueError: if the tokenizer already contains the inversion token
    """
    dest_path = path.join(context.model_path, f"inversion-{name}")
    logger.info(
        "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
    )

    # Conversion is idempotent: a previous run's output is reused as-is.
    if path.exists(dest_path):
        logger.info("ONNX model already exists, skipping.")
        return

    makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)

    embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
    token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")

    # FIX: strip surrounding whitespace so a trailing newline in the token file
    # does not become part of the registered token (which would never match
    # the token as typed in a prompt).
    with open(token_file, "r") as f:
        token = f.read().strip()

    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
    )
    text_encoder = CLIPTextModel.from_pretrained(
        base_model,
        subfolder="text_encoder",
    )

    # NOTE(security): torch.load is pickle-based and the file comes from a
    # user-selected hub repo — only convert inversions from trusted sources.
    loaded_embeds = torch.load(embeds_file, map_location=context.map_location)

    # separate token and the embeds
    trained_token = next(iter(loaded_embeds.keys()))
    embeds = loaded_embeds[trained_token]

    # cast to dtype of text_encoder
    # FIX: Tensor.to returns a new tensor (it is not in-place); the original
    # code discarded the result, so the cast never took effect.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # add the token in tokenizer
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # resize the token embeddings to make room for the new token
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

    # sample input used only to trace the graph for export; the axes are
    # declared dynamic below, so the prompt content does not matter
    text_input = tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    export(
        text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        (text_input.input_ids.to(dtype=torch.int32)),
        f=path.join(dest_path, "text_encoder", "model.onnx"),
        input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        do_constant_folding=True,
        opset_version=context.opset,
    )