1
0
Fork 0

fix devices, make subdir

This commit is contained in:
Sean Sube 2023-02-21 22:49:34 -06:00
parent 3f4b3fa322
commit 3dfaef041e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 9 deletions

View File

@ -1,4 +1,4 @@
from os import mkdir, path
from os import mkdirs, path
from huggingface_hub.file_download import hf_hub_download
from transformers import CLIPTokenizer, CLIPTextModel
from torch.onnx import export
@ -12,13 +12,13 @@ logger = getLogger(__name__)
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
cache_path = path.join(context.cache_path, f"inversion-{name}")
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path)
if path.exists(cache_path):
if path.exists(dest_path):
logger.info("ONNX model already exists, skipping.")
mkdir(cache_path)
mkdirs(path.join(dest_path, "text_encoder"))
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
@ -29,11 +29,11 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
tokenizer = CLIPTokenizer.from_pretrained(
base_model,
subfolder="tokenizer",
).to(context.training_device)
)
text_encoder = CLIPTextModel.from_pretrained(
base_model,
subfolder="text_encoder",
).to(context.training_device)
)
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
@ -72,9 +72,9 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
(
text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
text_input.input_ids.to(dtype=torch.int32)
),
f=path.join(cache_path, "text_encoder", "model.onnx"),
f=path.join(dest_path, "text_encoder", "model.onnx"),
input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={