feat(api): initial support for textual inversion embeddings from Civitai and others (#179)
This commit is contained in:
parent
1f3a5f6f3c
commit
46aac263d5
|
@ -234,11 +234,16 @@ def convert_models(ctx: ConversionContext, args, models: Models):
|
||||||
for inversion in model.get("inversions", []):
|
for inversion in model.get("inversions", []):
|
||||||
inversion_name = inversion["name"]
|
inversion_name = inversion["name"]
|
||||||
inversion_source = inversion["source"]
|
inversion_source = inversion["source"]
|
||||||
|
inversion_format = inversion.get("format", "huggingface")
|
||||||
inversion_source = fetch_model(
|
inversion_source = fetch_model(
|
||||||
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
ctx, f"{name}-inversion-{inversion_name}", inversion_source
|
||||||
)
|
)
|
||||||
convert_diffusion_textual_inversion(
|
convert_diffusion_textual_inversion(
|
||||||
ctx, inversion_name, model["source"], inversion_source
|
ctx,
|
||||||
|
inversion_name,
|
||||||
|
model["source"],
|
||||||
|
inversion_source,
|
||||||
|
inversion_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -13,7 +13,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_textual_inversion(
|
def convert_diffusion_textual_inversion(
|
||||||
context: ConversionContext, name: str, base_model: str, inversion: str
|
context: ConversionContext, name: str, base_model: str, inversion: str, format: str
|
||||||
):
|
):
|
||||||
dest_path = path.join(context.model_path, f"inversion-{name}")
|
dest_path = path.join(context.model_path, f"inversion-{name}")
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -26,12 +26,32 @@ def convert_diffusion_textual_inversion(
|
||||||
|
|
||||||
makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
|
makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
|
||||||
|
|
||||||
|
if format == "huggingface":
|
||||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
||||||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
||||||
|
|
||||||
with open(token_file, "r") as f:
|
with open(token_file, "r") as f:
|
||||||
token = f.read()
|
token = f.read()
|
||||||
|
|
||||||
|
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
||||||
|
|
||||||
|
# separate token and the embeds
|
||||||
|
trained_token = list(loaded_embeds.keys())[0]
|
||||||
|
embeds = loaded_embeds[trained_token]
|
||||||
|
elif format == "embeddings":
|
||||||
|
loaded_embeds = torch.load(inversion, map_location=context.map_location)
|
||||||
|
|
||||||
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
|
string_to_param = loaded_embeds["string_to_param"]
|
||||||
|
|
||||||
|
token = name
|
||||||
|
|
||||||
|
# separate token and embeds
|
||||||
|
trained_token = list(string_to_token.keys())[0]
|
||||||
|
embeds = string_to_param[trained_token]
|
||||||
|
|
||||||
|
logger.info("found embedding for token %s: %s", trained_token, embeds.shape)
|
||||||
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
|
@ -41,12 +61,6 @@ def convert_diffusion_textual_inversion(
|
||||||
subfolder="text_encoder",
|
subfolder="text_encoder",
|
||||||
)
|
)
|
||||||
|
|
||||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
|
||||||
|
|
||||||
# separate token and the embeds
|
|
||||||
trained_token = list(loaded_embeds.keys())[0]
|
|
||||||
embeds = loaded_embeds[trained_token]
|
|
||||||
|
|
||||||
# cast to dtype of text_encoder
|
# cast to dtype of text_encoder
|
||||||
dtype = text_encoder.get_input_embeddings().weight.dtype
|
dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||||
embeds.to(dtype)
|
embeds.to(dtype)
|
||||||
|
|
Loading…
Reference in New Issue