fix devices, make subdir
This commit is contained in:
parent
3f4b3fa322
commit
3dfaef041e
|
@ -1,4 +1,4 @@
|
||||||
from os import mkdir, path
|
from os import mkdirs, path
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
@ -12,13 +12,13 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
|
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
|
||||||
cache_path = path.join(context.cache_path, f"inversion-{name}")
|
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||||
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, cache_path)
|
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path)
|
||||||
|
|
||||||
if path.exists(cache_path):
|
if path.exists(dest_path):
|
||||||
logger.info("ONNX model already exists, skipping.")
|
logger.info("ONNX model already exists, skipping.")
|
||||||
|
|
||||||
mkdir(cache_path)
|
mkdirs(path.join(dest_path, "text_encoder"))
|
||||||
|
|
||||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||||
|
@ -29,11 +29,11 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
).to(context.training_device)
|
)
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
text_encoder = CLIPTextModel.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
).to(context.training_device)
|
)
|
||||||
|
|
||||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||||
|
|
||||||
|
@ -72,9 +72,9 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
|
||||||
text_encoder,
|
text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||||
(
|
(
|
||||||
text_input.input_ids.to(device=context.training_device, dtype=torch.int32)
|
text_input.input_ids.to(dtype=torch.int32)
|
||||||
),
|
),
|
||||||
f=path.join(cache_path, "text_encoder", "model.onnx"),
|
f=path.join(dest_path, "text_encoder", "model.onnx"),
|
||||||
input_names=["input_ids"],
|
input_names=["input_ids"],
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
output_names=["last_hidden_state", "pooler_output"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
|
|
Loading…
Reference in New Issue