From 506cf9f65f8cbb5224ebd72377ec0004a9448585 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 15 Mar 2023 17:14:52 -0500
Subject: [PATCH] feat(api): blend Textual Inversions from prompt

---
 api/onnx_web/convert/diffusion/lora.py      |  45 ++--
 .../convert/diffusion/textual_inversion.py  | 195 ++++++++++--------
 api/onnx_web/diffusers/load.py              |  49 +++--
 3 files changed, 164 insertions(+), 125 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 47cda3fb..7c113d95 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -15,16 +15,11 @@ from onnx.external_data_helper import (
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
-from onnx_web.convert.utils import ConversionContext
+from ..utils import ConversionContext
 
 logger = getLogger(__name__)
 
-###
-# everything in this file is still super experimental and may not produce valid ONNX models
-###
-
-
 def buffer_external_data_tensors(
     model: ModelProto,
 ) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
@@ -32,7 +27,7 @@ def buffer_external_data_tensors(
     for tensor in model.graph.initializer:
         name = tensor.name
 
-        logger.info("externalizing tensor: %s", name)
+        logger.debug("externalizing tensor: %s", name)
         if tensor.HasField("raw_data"):
             npt = numpy_helper.to_array(tensor)
             orv = OrtValue.ortvalue_from_numpy(npt)
@@ -59,13 +54,13 @@ def fix_node_name(key: str):
     return fixed_name
 
 
-def merge_lora(
+def blend_loras(
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
    lora_weights: "np.NDArray[np.float64]" = None,
 ):
-    base_model = load(base_name)
+    base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
     lora_models = [load_file(name) for name in lora_names]
     lora_count = len(lora_models)
     lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
@@ -86,7 +81,7 @@ def merge_lora(
             up_key = key.replace("lora_down", "lora_up")
             alpha_key = key[: key.index("lora_down")] + "alpha"
 
-            logger.info(
+            logger.debug(
                 "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
             )
@@ -99,7 +94,7 @@ def merge_lora(
             try:
                 if len(up_weight.size()) == 2:
                     # blend for nn.Linear
-                    logger.info(
+                    logger.debug(
                         "blending weights for Linear node: %s, %s, %s",
                         down_weight.shape,
                         up_weight.shape,
@@ -109,7 +104,7 @@ def merge_lora(
                     np_weights = weights.numpy() * (alpha / dim)
                 elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
                     # blend for nn.Conv2d 1x1
-                    logger.info(
+                    logger.debug(
                         "blending weights for Conv node: %s, %s, %s",
                         down_weight.shape,
                         up_weight.shape,
@@ -161,7 +156,7 @@ def merge_lora(
         conv_key = base_key + "_Conv"
         matmul_key = base_key + "_MatMul"
 
-        logger.info(
+        logger.debug(
             "key %s has conv: %s, matmul: %s",
             base_key,
             conv_key in fixed_node_names,
@@ -171,20 +166,20 @@ def merge_lora(
         if conv_key in fixed_node_names:
             conv_idx = fixed_node_names.index(conv_key)
             conv_node = base_model.graph.node[conv_idx]
-            logger.info("found conv node: %s", conv_node.name)
+            logger.debug("found conv node: %s", conv_node.name)
 
             # find weight initializer
-            logger.info("conv inputs: %s", conv_node.input)
+            logger.debug("conv inputs: %s", conv_node.input)
             weight_name = [n for n in conv_node.input if ".weight" in n][0]
             weight_name = fix_initializer_name(weight_name)
 
             weight_idx = fixed_initializer_names.index(weight_name)
             weight_node = base_model.graph.initializer[weight_idx]
-            logger.info("found weight initializer: %s", weight_node.name)
logger.debug("found weight initializer: %s", weight_node.name) # blending base_weights = numpy_helper.to_array(weight_node) - logger.info( + logger.debug( "found blended weights for conv: %s, %s", weights.shape, base_weights.shape, @@ -192,7 +187,7 @@ def merge_lora( blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) blended = np.expand_dims(blended, (2, 3)) - logger.info("blended weight shape: %s", blended.shape) + logger.debug("blended weight shape: %s", blended.shape) # replace the original initializer updated_node = numpy_helper.from_array(blended, weight_node.name) @@ -201,33 +196,33 @@ def merge_lora( elif matmul_key in fixed_node_names: weight_idx = fixed_node_names.index(matmul_key) weight_node = base_model.graph.node[weight_idx] - logger.info("found matmul node: %s", weight_node.name) + logger.debug("found matmul node: %s", weight_node.name) # find the MatMul initializer - logger.info("matmul inputs: %s", weight_node.input) + logger.debug("matmul inputs: %s", weight_node.input) matmul_name = [n for n in weight_node.input if "MatMul" in n][0] matmul_idx = fixed_initializer_names.index(matmul_name) matmul_node = base_model.graph.initializer[matmul_idx] - logger.info("found matmul initializer: %s", matmul_node.name) + logger.debug("found matmul initializer: %s", matmul_node.name) # blending base_weights = numpy_helper.to_array(matmul_node) - logger.info( + logger.debug( "found blended weights for matmul: %s, %s", weights.shape, base_weights.shape, ) blended = base_weights + weights.transpose() - logger.info("blended weight shape: %s", blended.shape) + logger.debug("blended weight shape: %s", blended.shape) # replace the original initializer updated_node = numpy_helper.from_array(blended, matmul_node.name) del base_model.graph.initializer[matmul_idx] base_model.graph.initializer.insert(matmul_idx, updated_node) else: - logger.info("could not find any nodes for %s", base_key) + logger.warning("could not find any nodes for %s", base_key) logger.info( "node counts: %s -> %s, %s -> %s", @@ -256,7 +251,7 @@ if __name__ == "__main__": args.lora_weights, ) - blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights) + blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights) if args.dest is None or args.dest == "" or args.dest == "ort": # convert to external data and save to memory (bare_model, external_data) = buffer_external_data_tensors(blend_model) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index e05a9e14..3ffe8415 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -1,17 +1,115 @@ from logging import getLogger from os import makedirs, path -from typing import Optional +from typing import List, Optional, Tuple +import numpy as np import torch from huggingface_hub.file_download import hf_hub_download +from onnx import ModelProto, load_model, numpy_helper, save_model from torch.onnx import export -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTokenizer +from ...server.context import ServerContext from ..utils import ConversionContext logger = getLogger(__name__) +@torch.no_grad() +def blend_textual_inversions( + context: ServerContext, + text_encoder: Optional[ModelProto], + tokenizer: Optional[CLIPTokenizer], + inversion_names: List[str], + inversion_formats: List[str], + inversion_weights: Optional[List[float]] = None, + base_tokens: 
+) -> Tuple[ModelProto, CLIPTokenizer]:
+    dtype = np.float32  # TODO: fixed type, which one?
+    # prev: text_encoder.get_input_embeddings().weight.dtype
+    embeds = {}
+
+    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights or [1.0] * len(inversion_names), base_tokens or inversion_names):
+        logger.info("blending Textual Inversion %s with weight of %s", name, weight)
+        if format == "concept":
+            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
+            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
+
+            with open(token_file, "r") as f:
+                token = base_token or f.read()
+
+            loaded_embeds = torch.load(embeds_file)
+
+            # separate token and the embeds
+            trained_token = list(loaded_embeds.keys())[0]
+
+            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
+            layer *= weight
+            if token in embeds:
+                embeds[token] += layer
+            else:
+                embeds[token] = layer
+        elif format == "embeddings":
+            loaded_embeds = torch.load(name)
+
+            string_to_token = loaded_embeds["string_to_token"]
+            string_to_param = loaded_embeds["string_to_param"]
+
+            # separate token and embeds
+            trained_token = list(string_to_token.keys())[0]
+            trained_embeds = string_to_param[trained_token]
+
+            num_tokens = trained_embeds.shape[0]
+            logger.debug("generating %s layer tokens", num_tokens)
+
+            for i in range(num_tokens):
+                token = f"{base_token or name}-{i}"
+                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
+                layer *= weight
+                if token in embeds:
+                    embeds[token] += layer
+                else:
+                    embeds[token] = layer
+        else:
+            raise ValueError(f"unknown Textual Inversion format: {format}")
+
+    # add the tokens to the tokenizer
+    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
+    if num_added_tokens == 0:
+        raise ValueError(
+            f"The tokenizer already contains the tokens {list(embeds.keys())}. Please pass different `base_tokens` that are not already in the tokenizer."
+        )
+
+    logger.debug("added %s tokens", num_added_tokens)
+
+    # resize the token embeddings
+    # text_encoder.resize_token_embeddings(len(tokenizer))
+    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_weights = numpy_helper.to_array(embedding_node)
+
+    weights_dim = embedding_weights.shape[1]
+    zero_weights = np.zeros((num_added_tokens, weights_dim))
+    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)
+
+    for token, weights in embeds.items():
+        token_id = tokenizer.convert_tokens_to_ids(token)
+        logger.debug(
+            "embedding %s weights for token %s", weights.shape, token
+        )
+        embedding_weights[token_id] = weights
+
+    # replace embedding_node
+    for i in range(len(text_encoder.graph.initializer)):
+        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
+            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+            logger.debug("new initializer data type: %s", new_initializer.data_type)
+            del text_encoder.graph.initializer[i]
+            text_encoder.graph.initializer.insert(i, new_initializer)
+
+    return (text_encoder, tokenizer)
+
+
 @torch.no_grad()
 def convert_diffusion_textual_inversion(
     context: ConversionContext,
@@ -40,101 +138,28 @@ def convert_diffusion_textual_inversion(
     makedirs(encoder_path, exist_ok=True)
 
-    if format == "concept":
-        embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
-        token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
-
-        with open(token_file, "r") as f:
-            token = base_token or f.read()
-
-        loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
-
-        # separate token and the embeds
-        trained_token = list(loaded_embeds.keys())[0]
-        embeds = loaded_embeds[trained_token]
-    elif format == "embeddings":
-        loaded_embeds = torch.load(inversion, map_location=context.map_location)
-
-        string_to_token = loaded_embeds["string_to_token"]
-        string_to_param = loaded_embeds["string_to_param"]
-
-        # separate token and embeds
-        trained_token = list(string_to_token.keys())[0]
-        embeds = string_to_param[trained_token]
-
-        num_tokens = embeds.shape[0]
-        logger.info("generating %s layer tokens", num_tokens)
-        token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
-    else:
-        raise ValueError(f"unknown textual inversion format: {format}")
-
-    logger.info("found embeddings for token %s: %s", token, embeds.shape)
-
+    text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
     tokenizer = CLIPTokenizer.from_pretrained(
         base_model,
         subfolder="tokenizer",
     )
-    text_encoder = CLIPTextModel.from_pretrained(
-        base_model,
-        subfolder="text_encoder",
-    )
-
-    # cast to dtype of text_encoder
-    dtype = text_encoder.get_input_embeddings().weight.dtype
-    embeds.to(dtype)
-
-    # add the token in tokenizer
-    num_added_tokens = tokenizer.add_tokens(token)
-    if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
-        )
-
-    logger.info("added %s tokens", num_added_tokens)
-
-    # resize the token embeddings
-    text_encoder.resize_token_embeddings(len(tokenizer))
-
-    if len(embeds.shape) == 2:
-        # multiple vectors in embeds
-        for i in range(embeds.shape[0]):
-            layer_embeds = embeds[i]
-            layer_token = token[i]
-            logger.debug(
-                "embedding %s vector for layer %s", layer_embeds.shape, layer_token
-            )
-            token_id = tokenizer.convert_tokens_to_ids(layer_token)
-            text_encoder.get_input_embeddings().weight.data[token_id] = layer_embeds
-    else:
-        # get the id for the token and assign the embeds
-        token_id = tokenizer.convert_tokens_to_ids(token)
-        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
-
-    # conversion stuff
-    text_input = tokenizer(
-        "A sample prompt",
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
+    text_encoder, tokenizer = blend_textual_inversions(
+        context,
+        text_encoder,
+        tokenizer,
+        [inversion],
+        [format],
+        [1.0],
+        base_tokens=[base_token or name],
     )
 
     logger.info("saving tokenizer for textual inversion")
     tokenizer.save_pretrained(tokenizer_path)
 
     logger.info("saving text encoder for textual inversion")
-    export(
+    save_model(
         text_encoder,
-        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        (text_input.input_ids.to(dtype=torch.int32)),
         f=encoder_model,
-        input_names=["input_ids"],
-        output_names=["last_hidden_state", "pooler_output"],
-        dynamic_axes={
-            "input_ids": {0: "batch", 1: "sequence"},
-        },
-        do_constant_folding=True,
-        opset_version=context.opset,
     )
 
     logger.info("textual inversion saved to %s", dest_path)
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index f8fc5625..61a14c67 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -1,6 +1,5 @@
 from logging import getLogger
 from os import path
-from re import compile
 from typing import Any, List, Optional, Tuple
 
 import numpy as np
@@ -22,6 +21,7 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
 )
+from onnx import load_model
 from onnxruntime import SessionOptions
 from transformers import CLIPTokenizer
 
@@ -37,7 +37,8 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
 
-from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
+from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
+from ..convert.diffusion.textual_inversion import blend_textual_inversions
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -215,19 +216,36 @@ def load_pipeline(
             )
         }
 
+        text_encoder = None
         if inversions is not None and len(inversions) > 0:
-            inversion = "inversion-" + inversions[0][0]
-            logger.debug("loading Textual Inversion from %s", inversion)
-            # TODO: blend the inversion models
-            components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
-                path.join(server.model_path, inversion, "text_encoder"),
-                provider=device.ort_provider(),
-                sess_options=device.sess_options(),
+            inversion_names, inversion_weights = zip(*inversions)
+            logger.debug("blending Textual Inversions from %s", inversion_names)
+
+            inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+            text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
+            tokenizer = CLIPTokenizer.from_pretrained(
+                model,
+                subfolder="tokenizer",
             )
-            components["tokenizer"] = CLIPTokenizer.from_pretrained(
-                path.join(server.model_path, inversion, "tokenizer"),
+            text_encoder, tokenizer = blend_textual_inversions(
+                server,
+                text_encoder,
+                tokenizer,
+                inversion_models,
+                ["embeddings"] * len(inversion_names),
+                inversion_weights,
+                base_tokens=inversion_names,
             )
+
+            # should be pretty small and should not need external data
+            components["text_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    text_encoder.SerializeToString(),
+                    provider=device.ort_provider(),
+                )
+            )
+            components["tokenizer"] = tokenizer
+
         # test LoRA blending
         if loras is not None and len(loras) > 0:
             lora_names, lora_weights = zip(*loras)
@@ -235,21 +253,22 @@ def load_pipeline(
             logger.info("blending base model %s with LoRA models: %s", model, lora_models)
 
             # blend and load text encoder
-            blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
-            (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+            text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
+            blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
+            (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
             text_encoder_names, text_encoder_values = zip(*text_encoder_data)
             text_encoder_opts = SessionOptions()
             text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
             components["text_encoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
-                    text_encoder_model.SerializeToString(),
+                    text_encoder.SerializeToString(),
                     provider=device.ort_provider(),
                     sess_options=text_encoder_opts,
                 )
             )
 
             # blend and load unet
-            blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+            blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
             (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
             unet_names, unet_values = zip(*unet_data)
             unet_opts = SessionOptions()
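-- 

A minimal usage sketch for the new blend_textual_inversions API, following the
call pattern that load_pipeline uses in the diff above. The model directory and
embedding path are hypothetical, and the ServerContext construction is a
placeholder (its real arguments are not shown in this patch; the function body
does not use the context argument):

    from os import path

    from onnx import load_model
    from transformers import CLIPTokenizer

    from onnx_web.convert.diffusion.textual_inversion import blend_textual_inversions
    from onnx_web.server import ServerContext

    # placeholder context; the server normally wires this up with configured paths
    server = ServerContext()

    # hypothetical converted base model directory
    model = path.join("/models", "stable-diffusion-onnx-v1-5")

    # load the ONNX text encoder and tokenizer that will be blended
    text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder="tokenizer")

    # blend one embeddings-format Textual Inversion at half strength; its
    # trained vectors are registered as "night-0", "night-1", ... tokens
    text_encoder, tokenizer = blend_textual_inversions(
        server,
        text_encoder,
        tokenizer,
        [path.join("/models", "inversion", "night.ckpt")],  # hypothetical path
        ["embeddings"],
        [0.5],
        base_tokens=["night"],
    )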