onnx-web/api/onnx_web/convert/diffusion/textual_inversion.py
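"""
Blend Textual Inversion embeddings into an ONNX CLIP text encoder and its tokenizer.
"""
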
from logging import getLogger
from os import makedirs, path
from typing import List, Optional, Tuple

import numpy as np
import torch
from onnx import ModelProto, load_model, numpy_helper, save_model
from transformers import CLIPTokenizer

from ...constants import ONNX_MODEL
from ...server.context import ServerContext
from ..utils import ConversionContext, load_tensor

logger = getLogger(__name__)

def detect_embedding_format(loaded_embeds) -> Optional[str]:
    keys: List[str] = list(loaded_embeds.keys())
    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
        logger.debug("detected Textual Inversion concept: %s", keys)
        return "concept"
    elif "emb_params" in keys:
        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
        return "parameters"
    elif "string_to_token" in keys and "string_to_param" in keys:
        logger.debug("detected Textual Inversion token embeddings: %s", keys)
        return "embeddings"
    else:
        logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
        return None
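

# "concept" checkpoints hold a single tensor keyed by a "<token>" name, as in
# the learned_embeds.bin files published by the Hugging Face concepts library;
# the weighted embedding is accumulated under both the base token and the
# original "<token>" key.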
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
    # separate token and the embeds
    token = list(loaded_embeds.keys())[0]
    layer = loaded_embeds[token].numpy().astype(dtype)
    layer *= weight

    if base_token in embeds:
        embeds[base_token] += layer
    else:
        embeds[base_token] = layer

    if token in embeds:
        embeds[token] += layer
    else:
        embeds[token] = layer
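

# "parameters" checkpoints store a single "emb_params" tensor of shape
# (num_tokens, dim); each row becomes its own numbered token, plus a summed
# layer stored under both the base token and a "{base_token}-all" token.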
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
    emb_params = loaded_embeds["emb_params"]
    num_tokens = emb_params.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(emb_params[0, :].shape)

    for i in range(num_tokens):
        token = f"{base_token}-{i}"
        layer = emb_params[i, :].numpy().astype(dtype)
        layer *= weight

        sum_layer += layer

        if token in embeds:
            embeds[token] += layer
        else:
            embeds[token] = layer

    # add base and sum tokens to embeds
    if base_token in embeds:
        embeds[base_token] += sum_layer
    else:
        embeds[base_token] = sum_layer

    sum_token = f"{base_token}-all"
    if sum_token in embeds:
        embeds[sum_token] += sum_layer
    else:
        embeds[sum_token] = sum_layer
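

# "embeddings" checkpoints use the string_to_token/string_to_param layout
# produced by the original textual-inversion trainer; blending works the same
# way as the "parameters" format, one numbered token per trained vector plus
# the summed base and "-all" tokens.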
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
    string_to_token = loaded_embeds["string_to_token"]
    string_to_param = loaded_embeds["string_to_param"]

    # separate token and embeds
    token = list(string_to_token.keys())[0]
    trained_embeds = string_to_param[token]
    num_tokens = trained_embeds.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(trained_embeds[0, :].shape)

    for i in range(num_tokens):
        token = f"{base_token}-{i}"
        layer = trained_embeds[i, :].numpy().astype(dtype)
        layer *= weight

        sum_layer += layer

        if token in embeds:
            embeds[token] += layer
        else:
            embeds[token] = layer

    # add base and sum tokens to embeds
    if base_token in embeds:
        embeds[base_token] += sum_layer
    else:
        embeds[base_token] = sum_layer

    sum_token = f"{base_token}-all"
    if sum_token in embeds:
        embeds[sum_token] += sum_layer
    else:
        embeds[sum_token] = sum_layer
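

# the blended vectors are written back into the ONNX graph directly: the token
# embedding initializer is extended with zero rows for the new tokens, then
# each new row is filled in from the blended embeds dict by token id.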
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
    # resize the token embeddings
    # text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_node = [
        n
        for n in text_encoder.graph.initializer
        if n.name == "text_model.embeddings.token_embedding.weight"
    ][0]
    base_weights = numpy_helper.to_array(embedding_node)

    weights_dim = base_weights.shape[1]
    zero_weights = np.zeros((num_added_tokens, weights_dim))
    embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)

    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.trace("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # replace embedding_node
    for i in range(len(text_encoder.graph.initializer)):
        if (
            text_encoder.graph.initializer[i].name
            == "text_model.embeddings.token_embedding.weight"
        ):
            new_initializer = numpy_helper.from_array(
                embedding_weights.astype(base_weights.dtype), embedding_node.name
            )
            logger.trace("new initializer data type: %s", new_initializer.data_type)
            del text_encoder.graph.initializer[i]
            text_encoder.graph.initializer.insert(i, new_initializer)
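

# blend one or more Textual Inversion checkpoints into a text encoder and
# tokenizer; each entry is a (name, weight, base_token, format) tuple, and a
# missing format falls back to detect_embedding_format.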
@torch.no_grad()
def blend_textual_inversions(
    server: ServerContext,
    text_encoder: ModelProto,
    tokenizer: CLIPTokenizer,
    embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
    # always load to CPU for blending
    device = torch.device("cpu")
    dtype = np.float32
    embeds = {}

    for name, weight, base_token, format in embeddings:
        if base_token is None:
            logger.debug("no base token provided, using name: %s", name)
            base_token = name

        logger.info(
            "blending Textual Inversion %s with weight of %s for token %s",
            name,
            weight,
            base_token,
        )

        loaded_embeds = load_tensor(name, map_location=device)
        if loaded_embeds is None:
            logger.warning("unable to load tensor")
            continue

        if format is None:
            format = detect_embedding_format(loaded_embeds)

        if format == "concept":
            blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
        elif format == "parameters":
            blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
        elif format == "embeddings":
            blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
        else:
            raise ValueError(f"unknown Textual Inversion format: {format}")

    # add the tokens to the tokenizer
    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
    if num_added_tokens == 0:
        raise ValueError(
            "The tokenizer already contains the tokens. Please pass a different "
            "`token` that is not already in the tokenizer."
        )

    logger.trace("added %s tokens", num_added_tokens)

    blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)

    return (text_encoder, tokenizer)
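

# convert a single Textual Inversion against a base ONNX model, writing the
# patched text encoder and tokenizer to {model_path}/inversion-{name}; the
# conversion is skipped when that output already exists.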
@torch.no_grad()
def convert_diffusion_textual_inversion(
    conversion: ConversionContext,
    name: str,
    base_model: str,
    inversion: str,
    inversion_format: str,
    base_token: Optional[str] = None,
    inversion_weight: Optional[float] = 1.0,
):
    dest_path = path.join(conversion.model_path, f"inversion-{name}")
    logger.info(
        "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
    )

    encoder_path = path.join(dest_path, "text_encoder")
    encoder_model = path.join(encoder_path, ONNX_MODEL)
    tokenizer_path = path.join(dest_path, "tokenizer")

    if (
        path.exists(dest_path)
        and path.exists(encoder_model)
        and path.exists(tokenizer_path)
    ):
        logger.info("ONNX model already exists, skipping.")
        return

    makedirs(encoder_path, exist_ok=True)

    text_encoder = load_model(path.join(base_model, "text_encoder", ONNX_MODEL))
    tokenizer = CLIPTokenizer.from_pretrained(
        base_model,
        subfolder="tokenizer",
    )
    text_encoder, tokenizer = blend_textual_inversions(
        conversion,
        text_encoder,
        tokenizer,
        [(inversion, inversion_weight, base_token, inversion_format)],
    )

    logger.info("saving tokenizer for textual inversion")
    tokenizer.save_pretrained(tokenizer_path)

    logger.info("saving text encoder for textual inversion")
    save_model(
        text_encoder,
        f=encoder_model,
    )

    logger.info("textual inversion saved to %s", dest_path)