2023-02-22 05:50:27 +00:00
|
|
|
from logging import getLogger
|
2023-02-22 04:50:59 +00:00
|
|
|
from os import makedirs, path
|
2023-03-15 22:14:52 +00:00
|
|
|
from typing import List, Optional, Tuple
|
2023-02-22 05:50:27 +00:00
|
|
|
|
2023-03-15 22:14:52 +00:00
|
|
|
import numpy as np
|
2023-02-22 05:50:27 +00:00
|
|
|
import torch
|
2023-03-15 22:14:52 +00:00
|
|
|
from onnx import ModelProto, load_model, numpy_helper, save_model
|
|
|
|
from transformers import CLIPTokenizer
|
2023-02-21 05:07:16 +00:00
|
|
|
|
2023-03-24 13:14:19 +00:00
|
|
|
from ...constants import ONNX_MODEL
|
2023-03-15 22:14:52 +00:00
|
|
|
from ...server.context import ServerContext
|
2023-03-19 20:13:54 +00:00
|
|
|
from ..utils import ConversionContext, load_tensor
|
2023-02-21 05:07:16 +00:00
|
|
|
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
2023-03-15 22:14:52 +00:00
|
|
|
@torch.no_grad()
def blend_textual_inversions(
    server: ServerContext,  # NOTE(review): unused in this body — presumably kept for API symmetry; confirm
    text_encoder: ModelProto,
    tokenizer: CLIPTokenizer,
    inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
    """Blend one or more Textual Inversion embeddings into an ONNX text encoder.

    Each entry in ``inversions`` is ``(name, weight, base_token, format)``:
    ``name`` is passed to ``load_tensor`` to load the embedding, ``weight``
    scales the loaded layers, ``base_token`` is the token prefix (defaults to
    ``name``), and ``format`` is one of ``"concept"``, ``"parameters"``, or
    ``"embeddings"`` (auto-detected from the tensor's keys when ``None``).

    New tokens are added to ``tokenizer`` and the text encoder's token
    embedding initializer is extended in place with the blended weights.

    Returns the modified ``(text_encoder, tokenizer)`` pair.

    Raises ``ValueError`` if an explicit format is unrecognized or if no new
    tokens could be added to the tokenizer.
    """
    # always load to CPU for blending
    device = torch.device("cpu")
    dtype = np.float32
    # token name -> accumulated embedding layer (np.ndarray); repeated tokens
    # across inversions are summed together
    embeds = {}

    for name, weight, base_token, inversion_format in inversions:
        if base_token is None:
            logger.debug("no base token provided, using name: %s", name)
            base_token = name

        logger.info(
            "blending Textual Inversion %s with weight of %s for token %s",
            name,
            weight,
            base_token,
        )

        loaded_embeds = load_tensor(name, map_location=device)
        if loaded_embeds is None:
            # best-effort: skip inversions that fail to load rather than abort
            logger.warning("unable to load tensor")
            continue

        # auto-detect the embedding format from the tensor's top-level keys
        if inversion_format is None:
            keys: List[str] = list(loaded_embeds.keys())
            if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
                # single "<token>" key: SD concept-style embedding
                logger.debug("detected Textual Inversion concept: %s", keys)
                inversion_format = "concept"
            elif "emb_params" in keys:
                logger.debug(
                    "detected Textual Inversion parameter embeddings: %s", keys
                )
                inversion_format = "parameters"
            elif "string_to_token" in keys and "string_to_param" in keys:
                logger.debug("detected Textual Inversion token embeddings: %s", keys)
                inversion_format = "embeddings"
            else:
                logger.error(
                    "unknown Textual Inversion format, no recognized keys: %s", keys
                )
                continue

        if inversion_format == "concept":
            # separate token and the embeds
            token = list(loaded_embeds.keys())[0]

            layer = loaded_embeds[token].numpy().astype(dtype)
            layer *= weight

            # register the layer under both the caller's base token and the
            # embedding's own "<token>" name
            if base_token in embeds:
                embeds[base_token] += layer
            else:
                embeds[base_token] = layer

            # NOTE(review): when base_token == token this adds the layer
            # twice to the same entry — confirm that is intended
            if token in embeds:
                embeds[token] += layer
            else:
                embeds[token] = layer
        elif inversion_format == "parameters":
            # NOTE(review): this branch mirrors the "embeddings" branch below
            # except for how the trained matrix is located; candidates for a
            # shared helper
            emb_params = loaded_embeds["emb_params"]

            num_tokens = emb_params.shape[0]
            logger.debug("generating %s layer tokens for %s", num_tokens, name)

            # running sum of all per-token layers, exposed as "<base>-all"
            # (np.zeros defaults to float64 here; the final graph write below
            # casts back to the initializer's dtype)
            sum_layer = np.zeros(emb_params[0, :].shape)

            # one token per embedding row: "<base>-0", "<base>-1", ...
            for i in range(num_tokens):
                token = f"{base_token}-{i}"
                layer = emb_params[i, :].numpy().astype(dtype)
                layer *= weight

                sum_layer += layer
                if token in embeds:
                    embeds[token] += layer
                else:
                    embeds[token] = layer

            # add base and sum tokens to embeds
            if base_token in embeds:
                embeds[base_token] += sum_layer
            else:
                embeds[base_token] = sum_layer

            sum_token = f"{base_token}-all"
            if sum_token in embeds:
                embeds[sum_token] += sum_layer
            else:
                embeds[sum_token] = sum_layer
        elif inversion_format == "embeddings":
            string_to_token = loaded_embeds["string_to_token"]
            string_to_param = loaded_embeds["string_to_param"]

            # separate token and embeds
            token = list(string_to_token.keys())[0]
            trained_embeds = string_to_param[token]

            num_tokens = trained_embeds.shape[0]
            logger.debug("generating %s layer tokens for %s", num_tokens, name)

            # running sum of all per-token layers, exposed as "<base>-all"
            sum_layer = np.zeros(trained_embeds[0, :].shape)

            # one token per embedding row: "<base>-0", "<base>-1", ...
            for i in range(num_tokens):
                token = f"{base_token}-{i}"
                layer = trained_embeds[i, :].numpy().astype(dtype)
                layer *= weight

                sum_layer += layer
                if token in embeds:
                    embeds[token] += layer
                else:
                    embeds[token] = layer

            # add base and sum tokens to embeds
            if base_token in embeds:
                embeds[base_token] += sum_layer
            else:
                embeds[base_token] = sum_layer

            sum_token = f"{base_token}-all"
            if sum_token in embeds:
                embeds[sum_token] += sum_layer
            else:
                embeds[sum_token] = sum_layer
        else:
            # an explicit (non-None) format that matched no branch is an error
            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")

    # add the tokens to the tokenizer
    logger.debug(
        "found embeddings for %s tokens: %s",
        len(embeds.keys()),
        list(embeds.keys()),
    )
    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
    if num_added_tokens == 0:
        # NOTE(review): `token` here is a leftover from the last loop
        # iteration and is unbound if every inversion was skipped — the
        # message may name the wrong token or raise NameError; verify
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # logger.trace is a project-specific logging level, not stdlib logging
    logger.trace("added %s tokens", num_added_tokens)

    # resize the token embeddings
    # text_encoder.resize_token_embeddings(len(tokenizer))
    # locate the token-embedding weight initializer in the ONNX graph
    embedding_node = [
        n
        for n in text_encoder.graph.initializer
        if n.name == "text_model.embeddings.token_embedding.weight"
    ][0]
    base_weights = numpy_helper.to_array(embedding_node)

    # extend the (vocab, dim) matrix with zero rows for the new tokens
    weights_dim = base_weights.shape[1]
    zero_weights = np.zeros((num_added_tokens, weights_dim))
    embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)

    # write each blended layer into the row for its (newly assigned) token id
    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.trace("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # replace embedding_node
    for i in range(len(text_encoder.graph.initializer)):
        if (
            text_encoder.graph.initializer[i].name
            == "text_model.embeddings.token_embedding.weight"
        ):
            # cast back to the original initializer's dtype before writing
            new_initializer = numpy_helper.from_array(
                embedding_weights.astype(base_weights.dtype), embedding_node.name
            )
            logger.trace("new initializer data type: %s", new_initializer.data_type)
            # delete + insert keeps the initializer at the same graph index
            del text_encoder.graph.initializer[i]
            text_encoder.graph.initializer.insert(i, new_initializer)

    return (text_encoder, tokenizer)
|
|
|
|
|
|
|
|
|
2023-03-01 14:26:40 +00:00
|
|
|
@torch.no_grad()
def convert_diffusion_textual_inversion(
    conversion: ConversionContext,
    name: str,
    base_model: str,
    inversion: str,
    inversion_format: str,
    base_token: Optional[str] = None,
    inversion_weight: Optional[float] = 1.0,
):
    """Convert a single Textual Inversion into an ONNX text encoder + tokenizer.

    Loads the base model's text encoder and tokenizer, blends the inversion
    into them via ``blend_textual_inversions``, and writes the results under
    ``<model_path>/inversion-<name>``. Skips conversion entirely when the
    destination already contains both outputs.
    """
    dest_path = path.join(conversion.model_path, f"inversion-{name}")
    logger.info(
        "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
    )

    # output layout: text_encoder/<ONNX_MODEL> and a tokenizer directory
    encoder_path = path.join(dest_path, "text_encoder")
    encoder_model = path.join(encoder_path, ONNX_MODEL)
    tokenizer_path = path.join(dest_path, "tokenizer")

    # skip when both artifacts from a previous run are present
    already_converted = (
        path.exists(dest_path)
        and path.exists(encoder_model)
        and path.exists(tokenizer_path)
    )
    if already_converted:
        logger.info("ONNX model already exists, skipping.")
        return

    makedirs(encoder_path, exist_ok=True)

    # start from the base model's encoder and tokenizer
    base_encoder = load_model(path.join(base_model, "text_encoder", ONNX_MODEL))
    base_tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")

    # blend this single inversion into the pair
    text_encoder, tokenizer = blend_textual_inversions(
        conversion,
        base_encoder,
        base_tokenizer,
        [(inversion, inversion_weight, base_token, inversion_format)],
    )

    logger.info("saving tokenizer for textual inversion")
    tokenizer.save_pretrained(tokenizer_path)

    logger.info("saving text encoder for textual inversion")
    save_model(text_encoder, f=encoder_model)

    logger.info("textual inversion saved to %s", dest_path)
|