1
0
Fork 0

feat(api): initial support for textual inversion embeddings from civitai and others (#179)

This commit is contained in:
Sean Sube 2023-03-01 19:09:51 -06:00
parent 1f3a5f6f3c
commit 46aac263d5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 31 additions and 12 deletions

View File

@@ -234,11 +234,16 @@ def convert_models(ctx: ConversionContext, args, models: Models):
         for inversion in model.get("inversions", []):
             inversion_name = inversion["name"]
             inversion_source = inversion["source"]
+            inversion_format = inversion.get("format", "huggingface")
             inversion_source = fetch_model(
                 ctx, f"{name}-inversion-{inversion_name}", inversion_source
             )
             convert_diffusion_textual_inversion(
-                ctx, inversion_name, model["source"], inversion_source
+                ctx,
+                inversion_name,
+                model["source"],
+                inversion_source,
+                inversion_format,
             )
     except Exception as e:

View File

@@ -13,7 +13,7 @@ logger = getLogger(__name__)
 @torch.no_grad()
 def convert_diffusion_textual_inversion(
-    context: ConversionContext, name: str, base_model: str, inversion: str
+    context: ConversionContext, name: str, base_model: str, inversion: str, format: str
 ):
     dest_path = path.join(context.model_path, f"inversion-{name}")
     logger.info(
@@ -26,12 +26,32 @@ def convert_diffusion_textual_inversion(
     makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)

-    embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
-    token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
-    with open(token_file, "r") as f:
-        token = f.read()
+    if format == "huggingface":
+        embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
+        token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
+        with open(token_file, "r") as f:
+            token = f.read()
+        loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
+        # separate token and the embeds
+        trained_token = list(loaded_embeds.keys())[0]
+        embeds = loaded_embeds[trained_token]
+    elif format == "embeddings":
+        loaded_embeds = torch.load(inversion, map_location=context.map_location)
+        string_to_token = loaded_embeds["string_to_token"]
+        string_to_param = loaded_embeds["string_to_param"]
+        token = name
+        # separate token and embeds
+        trained_token = list(string_to_token.keys())[0]
+        embeds = string_to_param[trained_token]
+
+    logger.info("found embedding for token %s: %s", trained_token, embeds.shape)

     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
@@ -41,12 +61,6 @@ def convert_diffusion_textual_inversion(
         subfolder="text_encoder",
     )

-    loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
-
-    # separate token and the embeds
-    trained_token = list(loaded_embeds.keys())[0]
-    embeds = loaded_embeds[trained_token]
-
     # cast to dtype of text_encoder
     dtype = text_encoder.get_input_embeddings().weight.dtype
     embeds.to(dtype)