
feat(api): blend Textual Inversions from prompt

Sean Sube 2023-03-15 17:14:52 -05:00
parent 973ad0f682
commit 506cf9f65f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 164 additions and 125 deletions

View File

@@ -15,16 +15,11 @@ from onnx.external_data_helper import (
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
-from onnx_web.convert.utils import ConversionContext
+from ..utils import ConversionContext

 logger = getLogger(__name__)

-###
-# everything in this file is still super experimental and may not produce valid ONNX models
-###

 def buffer_external_data_tensors(
     model: ModelProto,
 ) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
@@ -32,7 +27,7 @@ def buffer_external_data_tensors(
     for tensor in model.graph.initializer:
         name = tensor.name
-        logger.info("externalizing tensor: %s", name)
+        logger.debug("externalizing tensor: %s", name)
         if tensor.HasField("raw_data"):
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
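For context on how the `(bare_model, external_data)` pair is consumed: the stripped model is serialized to bytes, and the buffered `OrtValue`s are registered on a `SessionOptions` so ONNX Runtime can resolve the initializers by name. A minimal sketch using only the imports already present in this file (`session_from_buffered_model` is a hypothetical helper, not part of the commit):

    def session_from_buffered_model(bare_model, external_data, provider="CPUExecutionProvider"):
        # register the buffered initializers so ORT can resolve them by name
        names, values = zip(*external_data)
        opts = SessionOptions()
        opts.add_external_initializers(list(names), list(values))
        # the serialized graph no longer embeds the raw tensor data
        return InferenceSession(
            bare_model.SerializeToString(),
            sess_options=opts,
            providers=[provider],
        )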
@@ -59,13 +54,13 @@ def fix_node_name(key: str):
     return fixed_name

-def merge_lora(
+def blend_loras(
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
-    base_model = load(base_name)
+    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_file(name) for name in lora_names]
     lora_count = len(lora_models)
     lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
@@ -86,7 +81,7 @@ def merge_lora(
            up_key = key.replace("lora_down", "lora_up")
            alpha_key = key[: key.index("lora_down")] + "alpha"
-            logger.info(
+            logger.debug(
                "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
            )
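The key arithmetic above relies on the kohya-style naming convention for LoRA tensors, where each module contributes a down projection, an up projection, and a scalar alpha. An illustrative example (the key itself is hypothetical):

    key = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight"
    up_key = key.replace("lora_down", "lora_up")
    # -> "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight"
    alpha_key = key[: key.index("lora_down")] + "alpha"
    # -> "lora_unet_down_blocks_0_attentions_0_proj_in.alpha"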
@@ -99,7 +94,7 @@ def merge_lora(
            try:
                if len(up_weight.size()) == 2:
                    # blend for nn.Linear
-                    logger.info(
+                    logger.debug(
                        "blending weights for Linear node: %s, %s, %s",
                        down_weight.shape,
                        up_weight.shape,
@@ -109,7 +104,7 @@ def merge_lora(
                    np_weights = weights.numpy() * (alpha / dim)
                elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
                    # blend for nn.Conv2d 1x1
-                    logger.info(
+                    logger.debug(
                        "blending weights for Conv node: %s, %s, %s",
                        down_weight.shape,
                        up_weight.shape,
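Both branches compute the same low-rank update, W += (up @ down) * (alpha / dim); the Conv 1x1 case only squeezes the two trailing spatial dimensions first. A standalone sketch of the delta, assuming torch tensors as returned by safetensors and the usual convention that the rank is the first dimension of the down weight:

    import torch

    def lora_delta(up_weight, down_weight, alpha):
        dim = down_weight.shape[0]  # LoRA rank (assumed convention)
        if len(up_weight.size()) == 2:
            # nn.Linear: plain matrix product, (out, rank) @ (rank, in)
            weights = up_weight @ down_weight
        elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
            # nn.Conv2d 1x1: drop the 1x1 spatial dims, multiply, restore them
            squeezed = up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
            weights = squeezed.unsqueeze(2).unsqueeze(3)
        else:
            raise ValueError("unsupported LoRA weight shape")
        return weights.numpy() * (alpha / dim)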
@@ -161,7 +156,7 @@ def merge_lora(
        conv_key = base_key + "_Conv"
        matmul_key = base_key + "_MatMul"
-        logger.info(
+        logger.debug(
            "key %s has conv: %s, matmul: %s",
            base_key,
            conv_key in fixed_node_names,
@@ -171,20 +166,20 @@ def merge_lora(
        if conv_key in fixed_node_names:
            conv_idx = fixed_node_names.index(conv_key)
            conv_node = base_model.graph.node[conv_idx]
-            logger.info("found conv node: %s", conv_node.name)
+            logger.debug("found conv node: %s", conv_node.name)

            # find weight initializer
-            logger.info("conv inputs: %s", conv_node.input)
+            logger.debug("conv inputs: %s", conv_node.input)
            weight_name = [n for n in conv_node.input if ".weight" in n][0]
            weight_name = fix_initializer_name(weight_name)
            weight_idx = fixed_initializer_names.index(weight_name)
            weight_node = base_model.graph.initializer[weight_idx]
-            logger.info("found weight initializer: %s", weight_node.name)
+            logger.debug("found weight initializer: %s", weight_node.name)

            # blending
            base_weights = numpy_helper.to_array(weight_node)
-            logger.info(
+            logger.debug(
                "found blended weights for conv: %s, %s",
                weights.shape,
                base_weights.shape,
@@ -192,7 +187,7 @@ def merge_lora(
            blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
            blended = np.expand_dims(blended, (2, 3))
-            logger.info("blended weight shape: %s", blended.shape)
+            logger.debug("blended weight shape: %s", blended.shape)

            # replace the original initializer
            updated_node = numpy_helper.from_array(blended, weight_node.name)
@@ -201,33 +196,33 @@ def merge_lora(
        elif matmul_key in fixed_node_names:
            weight_idx = fixed_node_names.index(matmul_key)
            weight_node = base_model.graph.node[weight_idx]
-            logger.info("found matmul node: %s", weight_node.name)
+            logger.debug("found matmul node: %s", weight_node.name)

            # find the MatMul initializer
-            logger.info("matmul inputs: %s", weight_node.input)
+            logger.debug("matmul inputs: %s", weight_node.input)
            matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
            matmul_idx = fixed_initializer_names.index(matmul_name)
            matmul_node = base_model.graph.initializer[matmul_idx]
-            logger.info("found matmul initializer: %s", matmul_node.name)
+            logger.debug("found matmul initializer: %s", matmul_node.name)

            # blending
            base_weights = numpy_helper.to_array(matmul_node)
-            logger.info(
+            logger.debug(
                "found blended weights for matmul: %s, %s",
                weights.shape,
                base_weights.shape,
            )

            blended = base_weights + weights.transpose()
-            logger.info("blended weight shape: %s", blended.shape)
+            logger.debug("blended weight shape: %s", blended.shape)

            # replace the original initializer
            updated_node = numpy_helper.from_array(blended, matmul_node.name)
            del base_model.graph.initializer[matmul_idx]
            base_model.graph.initializer.insert(matmul_idx, updated_node)
        else:
-            logger.info("could not find any nodes for %s", base_key)
+            logger.warning("could not find any nodes for %s", base_key)

    logger.info(
        "node counts: %s -> %s, %s -> %s",
@@ -256,7 +251,7 @@ if __name__ == "__main__":
        args.lora_weights,
    )

-    blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
+    blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
    if args.dest is None or args.dest == "" or args.dest == "ort":
        # convert to external data and save to memory
        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
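Taken together, a hedged end-to-end sketch of the renamed entry point, mirroring the __main__ block above (file names and weights are illustrative):

    blended = blend_loras(
        "unet/model.onnx",
        ["style-a.safetensors", "style-b.safetensors"],
        "unet",
        lora_weights=np.array([0.75, 0.25]),
    )
    (bare_model, external_data) = buffer_external_data_tensors(blended)
    session = session_from_buffered_model(bare_model, external_data)  # helper sketched earlier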

View File

@@ -1,17 +1,115 @@
 from logging import getLogger
 from os import makedirs, path
-from typing import Optional
+from typing import List, Optional, Tuple

+import numpy as np
 import torch
 from huggingface_hub.file_download import hf_hub_download
+from onnx import ModelProto, load_model, numpy_helper, save_model
 from torch.onnx import export
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTokenizer

+from ...server.context import ServerContext
 from ..utils import ConversionContext

 logger = getLogger(__name__)
+@torch.no_grad()
+def blend_textual_inversions(
+    context: ServerContext,
+    text_encoder: Optional[ModelProto],
+    tokenizer: Optional[CLIPTokenizer],
+    inversion_names: List[str],
+    inversion_formats: List[str],
+    inversion_weights: Optional[List[float]] = None,
+    base_tokens: Optional[List[str]] = None,
+) -> Tuple[ModelProto, CLIPTokenizer]:
+    dtype = np.float  # TODO: fixed type, which one?
+    # prev: text_encoder.get_input_embeddings().weight.dtype
+    embeds = {}
+
+    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
+        logger.info("blending Textual Inversion %s with weight of %s", name, weight)
+        if format == "concept":
+            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
+            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
+
+            with open(token_file, "r") as f:
+                token = base_token or f.read()
+
+            loaded_embeds = torch.load(embeds_file)
+
+            # separate token and the embeds
+            trained_token = list(loaded_embeds.keys())[0]
+
+            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
+            layer *= weight
+
+            if trained_token in embeds:
+                embeds[token] += layer
+            else:
+                embeds[token] = layer
+        elif format == "embeddings":
+            loaded_embeds = torch.load(name)
+
+            string_to_token = loaded_embeds["string_to_token"]
+            string_to_param = loaded_embeds["string_to_param"]
+
+            # separate token and embeds
+            trained_token = list(string_to_token.keys())[0]
+            trained_embeds = string_to_param[trained_token]
+
+            num_tokens = trained_embeds.shape[0]
+            logger.debug("generating %s layer tokens", num_tokens)
+            for i in range(num_tokens):
+                token = f"{base_token or name}-{i}"
+                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
+                layer *= weight
+
+                if token in embeds:
+                    embeds[token] += layer
+                else:
+                    embeds[token] = layer
+        else:
+            raise ValueError(f"unknown Textual Inversion format: {format}")
+
+    # add the tokens to the tokenizer
+    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
+    if num_added_tokens == 0:
+        raise ValueError(
+            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
+        )
+
+    logger.debug("added %s tokens", num_added_tokens)
+
+    # resize the token embeddings
+    # text_encoder.resize_token_embeddings(len(tokenizer))
+    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_weights = numpy_helper.to_array(embedding_node)
+
+    weights_dim = embedding_weights.shape[1]
+    zero_weights = np.zeros((num_added_tokens, weights_dim))
+    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
+
+    for token, weights in embeds.items():
+        token_id = tokenizer.convert_tokens_to_ids(token)
+        logger.debug("embedding %s weights for token %s", weights.shape, token)
+        embedding_weights[token_id] = weights
+
+    # replace embedding_node
+    for i in range(len(text_encoder.graph.initializer)):
+        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
+            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+            logger.debug("new initializer data type: %s", new_initializer.data_type)
+            del text_encoder.graph.initializer[i]
+            text_encoder.graph.initializer.insert(i, new_initializer)
+
+    return (text_encoder, tokenizer)
 @torch.no_grad()
 def convert_diffusion_textual_inversion(
     context: ConversionContext,
@@ -40,101 +138,28 @@ def convert_diffusion_textual_inversion(
     makedirs(encoder_path, exist_ok=True)

-    if format == "concept":
-        embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
-        token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
-
-        with open(token_file, "r") as f:
-            token = base_token or f.read()
-
-        loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
-
-        # separate token and the embeds
-        trained_token = list(loaded_embeds.keys())[0]
-        embeds = loaded_embeds[trained_token]
-    elif format == "embeddings":
-        loaded_embeds = torch.load(inversion, map_location=context.map_location)
-
-        string_to_token = loaded_embeds["string_to_token"]
-        string_to_param = loaded_embeds["string_to_param"]
-
-        # separate token and embeds
-        trained_token = list(string_to_token.keys())[0]
-        embeds = string_to_param[trained_token]
-
-        num_tokens = embeds.shape[0]
-        logger.info("generating %s layer tokens", num_tokens)
-        token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
-    else:
-        raise ValueError(f"unknown textual inversion format: {format}")
-
-    logger.info("found embeddings for token %s: %s", token, embeds.shape)
-
+    text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
     )
-    text_encoder = CLIPTextModel.from_pretrained(
-        base_model,
-        subfolder="text_encoder",
-    )
-
-    # cast to dtype of text_encoder
-    dtype = text_encoder.get_input_embeddings().weight.dtype
-    embeds.to(dtype)
-
-    # add the token in tokenizer
-    num_added_tokens = tokenizer.add_tokens(token)
-    if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
-        )
-
-    logger.info("added %s tokens", num_added_tokens)
-
-    # resize the token embeddings
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    if len(embeds.shape) == 2:
-        # multiple vectors in embeds
-        for i in range(embeds.shape[0]):
-            layer_embeds = embeds[i]
-            layer_token = token[i]
-            logger.debug(
-                "embedding %s vector for layer %s", layer_embeds.shape, layer_token
-            )
-            token_id = tokenizer.convert_tokens_to_ids(layer_token)
-            text_encoder.get_input_embeddings().weight.data[token_id] = layer_embeds
-    else:
-        # get the id for the token and assign the embeds
-        token_id = tokenizer.convert_tokens_to_ids(token)
-        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
-
-    # conversion stuff
-    text_input = tokenizer(
-        "A sample prompt",
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
+    text_encoder, tokenizer = blend_textual_inversions(
+        context,
+        text_encoder,
+        tokenizer,
+        [inversion],
+        [format],
+        [1.0],
+        base_token=(base_token or name),
     )

     logger.info("saving tokenizer for textual inversion")
     tokenizer.save_pretrained(tokenizer_path)

     logger.info("saving text encoder for textual inversion")
-    export(
+    save_model(
         text_encoder,
-        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        (text_input.input_ids.to(dtype=torch.int32)),
         f=encoder_model,
-        input_names=["input_ids"],
-        output_names=["last_hidden_state", "pooler_output"],
-        dynamic_axes={
-            "input_ids": {0: "batch", 1: "sequence"},
-        },
-        do_constant_folding=True,
-        opset_version=context.opset,
     )

     logger.info("textual inversion saved to %s", dest_path)

View File

@@ -1,6 +1,5 @@
 from logging import getLogger
 from os import path
-from re import compile
 from typing import Any, List, Optional, Tuple

 import numpy as np
@@ -22,6 +21,7 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
 )
+from onnx import load_model
 from onnxruntime import SessionOptions
 from transformers import CLIPTokenizer
@@ -37,7 +37,8 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler

-from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
+from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
+from ..convert.diffusion.textual_inversion import blend_textual_inversions
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -215,18 +216,35 @@ def load_pipeline(
        )
    }

+    text_encoder = None
    if inversions is not None and len(inversions) > 0:
-        inversion = "inversion-" + inversions[0][0]
-        logger.debug("loading Textual Inversion from %s", inversion)
-        # TODO: blend the inversion models
-        components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
-            path.join(server.model_path, inversion, "text_encoder"),
-            provider=device.ort_provider(),
-        )
-        components["tokenizer"] = CLIPTokenizer.from_pretrained(
-            path.join(server.model_path, inversion, "tokenizer"),
-        )
+        inversion_names, inversion_weights = zip(*inversions)
+        logger.debug("blending Textual Inversions from %s", inversion_names)
+
+        inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+        text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
+        tokenizer = CLIPTokenizer.from_pretrained(
+            model,
+            subfolder="tokenizer",
+        )
+        text_encoder, tokenizer = blend_textual_inversions(
+            server,
+            text_encoder,
+            tokenizer,
+            inversion_models,
+            ["embeddings"] * len(inversion_names),
+            inversion_weights,
+            base_tokens=inversion_names,
+        )
+
+        # should be pretty small and should not need external data
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=device.sess_options(),
+            )
+        )
+        components["tokenizer"] = tokenizer
    # test LoRA blending
    if loras is not None and len(loras) > 0:
@@ -235,21 +253,22 @@ def load_pipeline(
        logger.info("blending base model %s with LoRA models: %s", model, lora_models)

        # blend and load text encoder
-        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
-        (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
+        blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
+        (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
        text_encoder_opts = SessionOptions()
        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
-                text_encoder_model.SerializeToString(),
+                text_encoder.SerializeToString(),
                provider=device.ort_provider(),
                sess_options=text_encoder_opts,
            )
        )

        # blend and load unet
-        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+        blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
        unet_names, unet_values = zip(*unet_data)
        unet_opts = SessionOptions()