diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 14ffa703..acd9331d 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -1,4 +1,4 @@
-from os import mkdir, path
+from os import makedirs, path
 from huggingface_hub.file_download import hf_hub_download
 from transformers import CLIPTokenizer, CLIPTextModel
 from torch.onnx import export
@@ -12,13 +12,14 @@ logger = getLogger(__name__)
 
 
 def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
-    cache_path = path.join(context.cache_path, f"inversion-{name}")
-    logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
+    dest_path = path.join(context.model_path, f"inversion-{name}")
+    logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path)
 
-    if path.exists(cache_path):
+    if path.exists(dest_path):
         logger.info("ONNX model already exists, skipping.")
+        return
 
-    mkdir(cache_path)
+    makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
 
     embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
     token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
@@ -29,11 +29,11 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
-    ).to(context.training_device)
+    )
     text_encoder = CLIPTextModel.from_pretrained(
         base_model,
         subfolder="text_encoder",
-    ).to(context.training_device)
+    )
 
     loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
 
@@ -72,9 +72,9 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
         text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
         (
-            text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
+            text_input.input_ids.to(dtype=torch.int32)
         ),
-        f=path.join(cache_path, "text_encoder", "model.onnx"),
+        f=path.join(dest_path, "text_encoder", "model.onnx"),
         input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
         dynamic_axes={
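
Note on verifying the conversion (not part of the patch): the diff above exports the patched text encoder to dest_path/text_encoder/model.onnx with one int32 input named "input_ids" and two outputs, "last_hidden_state" and "pooler_output". A minimal sanity-check sketch follows, assuming onnxruntime is installed; the model directory, the base model id, and the "<token>" prompt string are hypothetical placeholders, not values from the patch.

    from os import path

    import numpy as np
    from onnxruntime import InferenceSession
    from transformers import CLIPTokenizer

    # hypothetical dest_path produced by convert_diffusion_textual_inversion
    model_dir = "models/inversion-example"

    # assumed base model; must match the tokenizer used during conversion
    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
    )

    session = InferenceSession(path.join(model_dir, "text_encoder", "model.onnx"))

    # "<token>" stands in for the trigger word from token_identifier.txt
    tokens = tokenizer(
        "a photo of <token>",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="np",
    )

    # inputs must be int32, matching the cast in the export call above
    input_ids = tokens.input_ids.astype(np.int32)
    last_hidden_state, pooler_output = session.run(None, {"input_ids": input_ids})
    print(last_hidden_state.shape)  # e.g. (1, 77, hidden_size)

If the export succeeded, the session loads without shape errors and the hidden states have the expected sequence length; a mismatch on "input_ids" dtype is the most likely failure, given the int32 workaround noted in the diff.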