feat(api): blend Textual Inversions from prompt
This commit is contained in:
parent
973ad0f682
commit
506cf9f65f
|
@ -15,16 +15,11 @@ from onnx.external_data_helper import (
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from onnx_web.convert.utils import ConversionContext
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
###
|
|
||||||
# everything in this file is still super experimental and may not produce valid ONNX models
|
|
||||||
###
|
|
||||||
|
|
||||||
|
|
||||||
def buffer_external_data_tensors(
|
def buffer_external_data_tensors(
|
||||||
model: ModelProto,
|
model: ModelProto,
|
||||||
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
|
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
|
||||||
|
@ -32,7 +27,7 @@ def buffer_external_data_tensors(
|
||||||
for tensor in model.graph.initializer:
|
for tensor in model.graph.initializer:
|
||||||
name = tensor.name
|
name = tensor.name
|
||||||
|
|
||||||
logger.info("externalizing tensor: %s", name)
|
logger.debug("externalizing tensor: %s", name)
|
||||||
if tensor.HasField("raw_data"):
|
if tensor.HasField("raw_data"):
|
||||||
npt = numpy_helper.to_array(tensor)
|
npt = numpy_helper.to_array(tensor)
|
||||||
orv = OrtValue.ortvalue_from_numpy(npt)
|
orv = OrtValue.ortvalue_from_numpy(npt)
|
||||||
|
@ -59,13 +54,13 @@ def fix_node_name(key: str):
|
||||||
return fixed_name
|
return fixed_name
|
||||||
|
|
||||||
|
|
||||||
def merge_lora(
|
def blend_loras(
|
||||||
base_name: str,
|
base_name: str,
|
||||||
lora_names: List[str],
|
lora_names: List[str],
|
||||||
dest_type: Literal["text_encoder", "unet"],
|
dest_type: Literal["text_encoder", "unet"],
|
||||||
lora_weights: "np.NDArray[np.float64]" = None,
|
lora_weights: "np.NDArray[np.float64]" = None,
|
||||||
):
|
):
|
||||||
base_model = load(base_name)
|
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
|
||||||
lora_models = [load_file(name) for name in lora_names]
|
lora_models = [load_file(name) for name in lora_names]
|
||||||
lora_count = len(lora_models)
|
lora_count = len(lora_models)
|
||||||
lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
|
lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
|
||||||
|
@ -86,7 +81,7 @@ def merge_lora(
|
||||||
|
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||||
logger.info(
|
logger.debug(
|
||||||
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
|
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,7 +94,7 @@ def merge_lora(
|
||||||
try:
|
try:
|
||||||
if len(up_weight.size()) == 2:
|
if len(up_weight.size()) == 2:
|
||||||
# blend for nn.Linear
|
# blend for nn.Linear
|
||||||
logger.info(
|
logger.debug(
|
||||||
"blending weights for Linear node: %s, %s, %s",
|
"blending weights for Linear node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
|
@ -109,7 +104,7 @@ def merge_lora(
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
||||||
# blend for nn.Conv2d 1x1
|
# blend for nn.Conv2d 1x1
|
||||||
logger.info(
|
logger.debug(
|
||||||
"blending weights for Conv node: %s, %s, %s",
|
"blending weights for Conv node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
|
@ -161,7 +156,7 @@ def merge_lora(
|
||||||
conv_key = base_key + "_Conv"
|
conv_key = base_key + "_Conv"
|
||||||
matmul_key = base_key + "_MatMul"
|
matmul_key = base_key + "_MatMul"
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
"key %s has conv: %s, matmul: %s",
|
"key %s has conv: %s, matmul: %s",
|
||||||
base_key,
|
base_key,
|
||||||
conv_key in fixed_node_names,
|
conv_key in fixed_node_names,
|
||||||
|
@ -171,20 +166,20 @@ def merge_lora(
|
||||||
if conv_key in fixed_node_names:
|
if conv_key in fixed_node_names:
|
||||||
conv_idx = fixed_node_names.index(conv_key)
|
conv_idx = fixed_node_names.index(conv_key)
|
||||||
conv_node = base_model.graph.node[conv_idx]
|
conv_node = base_model.graph.node[conv_idx]
|
||||||
logger.info("found conv node: %s", conv_node.name)
|
logger.debug("found conv node: %s", conv_node.name)
|
||||||
|
|
||||||
# find weight initializer
|
# find weight initializer
|
||||||
logger.info("conv inputs: %s", conv_node.input)
|
logger.debug("conv inputs: %s", conv_node.input)
|
||||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||||
weight_name = fix_initializer_name(weight_name)
|
weight_name = fix_initializer_name(weight_name)
|
||||||
|
|
||||||
weight_idx = fixed_initializer_names.index(weight_name)
|
weight_idx = fixed_initializer_names.index(weight_name)
|
||||||
weight_node = base_model.graph.initializer[weight_idx]
|
weight_node = base_model.graph.initializer[weight_idx]
|
||||||
logger.info("found weight initializer: %s", weight_node.name)
|
logger.debug("found weight initializer: %s", weight_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(weight_node)
|
base_weights = numpy_helper.to_array(weight_node)
|
||||||
logger.info(
|
logger.debug(
|
||||||
"found blended weights for conv: %s, %s",
|
"found blended weights for conv: %s, %s",
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
base_weights.shape,
|
||||||
|
@ -192,7 +187,7 @@ def merge_lora(
|
||||||
|
|
||||||
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
logger.info("blended weight shape: %s", blended.shape)
|
logger.debug("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended, weight_node.name)
|
updated_node = numpy_helper.from_array(blended, weight_node.name)
|
||||||
|
@ -201,33 +196,33 @@ def merge_lora(
|
||||||
elif matmul_key in fixed_node_names:
|
elif matmul_key in fixed_node_names:
|
||||||
weight_idx = fixed_node_names.index(matmul_key)
|
weight_idx = fixed_node_names.index(matmul_key)
|
||||||
weight_node = base_model.graph.node[weight_idx]
|
weight_node = base_model.graph.node[weight_idx]
|
||||||
logger.info("found matmul node: %s", weight_node.name)
|
logger.debug("found matmul node: %s", weight_node.name)
|
||||||
|
|
||||||
# find the MatMul initializer
|
# find the MatMul initializer
|
||||||
logger.info("matmul inputs: %s", weight_node.input)
|
logger.debug("matmul inputs: %s", weight_node.input)
|
||||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||||
|
|
||||||
matmul_idx = fixed_initializer_names.index(matmul_name)
|
matmul_idx = fixed_initializer_names.index(matmul_name)
|
||||||
matmul_node = base_model.graph.initializer[matmul_idx]
|
matmul_node = base_model.graph.initializer[matmul_idx]
|
||||||
logger.info("found matmul initializer: %s", matmul_node.name)
|
logger.debug("found matmul initializer: %s", matmul_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(matmul_node)
|
base_weights = numpy_helper.to_array(matmul_node)
|
||||||
logger.info(
|
logger.debug(
|
||||||
"found blended weights for matmul: %s, %s",
|
"found blended weights for matmul: %s, %s",
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
base_weights.shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
blended = base_weights + weights.transpose()
|
blended = base_weights + weights.transpose()
|
||||||
logger.info("blended weight shape: %s", blended.shape)
|
logger.debug("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended, matmul_node.name)
|
updated_node = numpy_helper.from_array(blended, matmul_node.name)
|
||||||
del base_model.graph.initializer[matmul_idx]
|
del base_model.graph.initializer[matmul_idx]
|
||||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||||
else:
|
else:
|
||||||
logger.info("could not find any nodes for %s", base_key)
|
logger.warning("could not find any nodes for %s", base_key)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"node counts: %s -> %s, %s -> %s",
|
"node counts: %s -> %s, %s -> %s",
|
||||||
|
@ -256,7 +251,7 @@ if __name__ == "__main__":
|
||||||
args.lora_weights,
|
args.lora_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
|
blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
|
||||||
if args.dest is None or args.dest == "" or args.dest == "ort":
|
if args.dest is None or args.dest == "" or args.dest == "ort":
|
||||||
# convert to external data and save to memory
|
# convert to external data and save to memory
|
||||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||||
|
|
|
@ -1,17 +1,115 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import makedirs, path
|
from os import makedirs, path
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def blend_textual_inversions(
    context: ServerContext,
    text_encoder: Optional[ModelProto],
    tokenizer: Optional[CLIPTokenizer],
    inversion_names: List[str],
    inversion_formats: List[str],
    inversion_weights: Optional[List[float]] = None,
    base_tokens: Optional[List[str]] = None,
) -> Tuple[ModelProto, CLIPTokenizer]:
    """Blend Textual Inversion embeddings into an ONNX text encoder and its tokenizer.

    Each inversion is loaded, scaled by its weight, and accumulated per token.
    The resulting tokens are added to ``tokenizer`` and their vectors are written
    into the ``text_model.embeddings.token_embedding.weight`` initializer of
    ``text_encoder``. Both objects are modified in place and also returned.

    Args:
        context: Server context. Not read by this function.
        text_encoder: ONNX text encoder model. Annotated as optional, but it is
            dereferenced unconditionally, so it must be provided.
        tokenizer: Tokenizer matching the text encoder. Must be provided, for the
            same reason as ``text_encoder``.
        inversion_names: For the ``"concept"`` format, a Hugging Face Hub repo id;
            for the ``"embeddings"`` format, a path passed to ``torch.load``.
        inversion_formats: One of ``"concept"`` or ``"embeddings"`` per inversion.
        inversion_weights: Scale factor per inversion. Defaults to ``1.0`` for
            every inversion when omitted.
        base_tokens: Token name per inversion. Defaults to ``inversion_names``.
            The ``"embeddings"`` format appends ``-<index>`` for each vector.

    Returns:
        The ``(text_encoder, tokenizer)`` pair after blending.

    Raises:
        ValueError: If a format is unknown, or if none of the tokens could be
            added to the tokenizer.
        IndexError: If the text encoder has no token embedding initializer.
    """
    # `np.float` was only an alias for the builtin `float` (float64) and was
    # removed in NumPy 1.24, so name the concrete type explicitly.
    # TODO: fixed type, which one?
    # prev: text_encoder.get_input_embeddings().weight.dtype
    dtype = np.float64
    embedding_name = "text_model.embeddings.token_embedding.weight"
    embeds = {}

    if inversion_weights is None:
        # zip() cannot iterate None; treat a missing list as full strength
        inversion_weights = [1.0] * len(inversion_names)

    for name, inversion_format, weight, base_token in zip(
        inversion_names,
        inversion_formats,
        inversion_weights,
        base_tokens or inversion_names,
    ):
        logger.info("blending Textual Inversion %s with weight of %s", name, weight)
        if inversion_format == "concept":
            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")

            with open(token_file, "r") as f:
                token = base_token or f.read()

            # NOTE(review): torch.load unpickles the file; only load embeddings
            # from sources that are trusted.
            # The tensors are converted to NumPy on the CPU below, so map them
            # there directly; this also lets GPU-saved files load on CPU hosts.
            loaded_embeds = torch.load(embeds_file, map_location="cpu")

            # separate token and the embeds
            trained_token = list(loaded_embeds.keys())[0]

            layer = loaded_embeds[trained_token].cpu().numpy().astype(dtype)
            layer *= weight
            # accumulate under the token that is actually used as the key
            if token in embeds:
                embeds[token] += layer
            else:
                embeds[token] = layer
        elif inversion_format == "embeddings":
            # NOTE(review): same unpickling caveat as the concept format above.
            loaded_embeds = torch.load(name, map_location="cpu")

            string_to_token = loaded_embeds["string_to_token"]
            string_to_param = loaded_embeds["string_to_param"]

            # separate token and embeds
            trained_token = list(string_to_token.keys())[0]
            trained_embeds = string_to_param[trained_token]

            num_tokens = trained_embeds.shape[0]
            logger.debug("generating %s layer tokens", num_tokens)

            # one token per vector, suffixed with the vector index
            for i in range(num_tokens):
                token = f"{base_token or name}-{i}"
                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
                layer *= weight
                if token in embeds:
                    embeds[token] += layer
                else:
                    embeds[token] = layer
        else:
            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")

    if not embeds:
        # nothing was requested, so there is nothing to blend
        logger.debug("no Textual Inversions to blend")
        return (text_encoder, tokenizer)

    # add the tokens to the tokenizer
    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
    tokens = list(embeds.keys())
    num_added_tokens = tokenizer.add_tokens(tokens)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the tokens {tokens}. Please pass a different `token` that is not already in the tokenizer."
        )

    logger.debug("added %s tokens", num_added_tokens)

    # resize the token embeddings
    # text_encoder.resize_token_embeddings(len(tokenizer))
    initializers = text_encoder.graph.initializer
    embedding_idx = [
        i for i, node in enumerate(initializers) if node.name == embedding_name
    ][0]
    embedding_node = initializers[embedding_idx]
    embedding_weights = numpy_helper.to_array(embedding_node)

    # append one zeroed row per newly added token
    weights_dim = embedding_weights.shape[1]
    zero_weights = np.zeros((num_added_tokens, weights_dim))
    embedding_weights = np.concatenate((embedding_weights, zero_weights), axis=0)

    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.debug("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # replace embedding_node
    # NOTE(review): the initializer is always written back as float32,
    # whatever its original data type was - confirm for fp16 models.
    new_initializer = numpy_helper.from_array(
        embedding_weights.astype(np.float32), embedding_node.name
    )
    logger.debug("new initializer data type: %s", new_initializer.data_type)
    del initializers[embedding_idx]
    initializers.insert(embedding_idx, new_initializer)

    return (text_encoder, tokenizer)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_textual_inversion(
|
def convert_diffusion_textual_inversion(
|
||||||
context: ConversionContext,
|
context: ConversionContext,
|
||||||
|
@ -40,101 +138,28 @@ def convert_diffusion_textual_inversion(
|
||||||
|
|
||||||
makedirs(encoder_path, exist_ok=True)
|
makedirs(encoder_path, exist_ok=True)
|
||||||
|
|
||||||
if format == "concept":
|
text_encoder = load_model(path.join(base_model, "text_encoder", "model.onnx"))
|
||||||
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
|
|
||||||
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
|
|
||||||
|
|
||||||
with open(token_file, "r") as f:
|
|
||||||
token = base_token or f.read()
|
|
||||||
|
|
||||||
loaded_embeds = torch.load(embeds_file, map_location=context.map_location)
|
|
||||||
|
|
||||||
# separate token and the embeds
|
|
||||||
trained_token = list(loaded_embeds.keys())[0]
|
|
||||||
embeds = loaded_embeds[trained_token]
|
|
||||||
elif format == "embeddings":
|
|
||||||
loaded_embeds = torch.load(inversion, map_location=context.map_location)
|
|
||||||
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
|
||||||
|
|
||||||
# separate token and embeds
|
|
||||||
trained_token = list(string_to_token.keys())[0]
|
|
||||||
embeds = string_to_param[trained_token]
|
|
||||||
|
|
||||||
num_tokens = embeds.shape[0]
|
|
||||||
logger.info("generating %s layer tokens", num_tokens)
|
|
||||||
token = [f"{base_token or name}-{i}" for i in range(num_tokens)]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown textual inversion format: {format}")
|
|
||||||
|
|
||||||
logger.info("found embeddings for token %s: %s", token, embeds.shape)
|
|
||||||
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
)
|
)
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
base_model,
|
context,
|
||||||
subfolder="text_encoder",
|
text_encoder,
|
||||||
)
|
tokenizer,
|
||||||
|
[inversion],
|
||||||
# cast to dtype of text_encoder
|
[format],
|
||||||
dtype = text_encoder.get_input_embeddings().weight.dtype
|
[1.0],
|
||||||
embeds.to(dtype)
|
base_token=(base_token or name),
|
||||||
|
|
||||||
# add the token in tokenizer
|
|
||||||
num_added_tokens = tokenizer.add_tokens(token)
|
|
||||||
if num_added_tokens == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("added %s tokens", num_added_tokens)
|
|
||||||
|
|
||||||
# resize the token embeddings
|
|
||||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
|
||||||
|
|
||||||
if len(embeds.shape) == 2:
|
|
||||||
# multiple vectors in embeds
|
|
||||||
for i in range(embeds.shape[0]):
|
|
||||||
layer_embeds = embeds[i]
|
|
||||||
layer_token = token[i]
|
|
||||||
logger.debug(
|
|
||||||
"embedding %s vector for layer %s", layer_embeds.shape, layer_token
|
|
||||||
)
|
|
||||||
token_id = tokenizer.convert_tokens_to_ids(layer_token)
|
|
||||||
text_encoder.get_input_embeddings().weight.data[token_id] = layer_embeds
|
|
||||||
else:
|
|
||||||
# get the id for the token and assign the embeds
|
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
|
||||||
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
|
||||||
|
|
||||||
# conversion stuff
|
|
||||||
text_input = tokenizer(
|
|
||||||
"A sample prompt",
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("saving tokenizer for textual inversion")
|
logger.info("saving tokenizer for textual inversion")
|
||||||
tokenizer.save_pretrained(tokenizer_path)
|
tokenizer.save_pretrained(tokenizer_path)
|
||||||
|
|
||||||
logger.info("saving text encoder for textual inversion")
|
logger.info("saving text encoder for textual inversion")
|
||||||
export(
|
save_model(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
|
||||||
(text_input.input_ids.to(dtype=torch.int32)),
|
|
||||||
f=encoder_model,
|
f=encoder_model,
|
||||||
input_names=["input_ids"],
|
|
||||||
output_names=["last_hidden_state", "pooler_output"],
|
|
||||||
dynamic_axes={
|
|
||||||
"input_ids": {0: "batch", 1: "sequence"},
|
|
||||||
},
|
|
||||||
do_constant_folding=True,
|
|
||||||
opset_version=context.opset,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("textual inversion saved to %s", dest_path)
|
logger.info("textual inversion saved to %s", dest_path)
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from re import compile
|
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -22,6 +21,7 @@ from diffusers import (
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
from onnx import load_model
|
||||||
from onnxruntime import SessionOptions
|
from onnxruntime import SessionOptions
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
@ -37,7 +37,8 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||||
|
|
||||||
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
|
from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
|
||||||
|
from ..convert.diffusion.textual_inversion import blend_textual_inversions
|
||||||
from ..params import DeviceParams, Size
|
from ..params import DeviceParams, Size
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -215,19 +216,36 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
text_encoder = None
|
||||||
if inversions is not None and len(inversions) > 0:
|
if inversions is not None and len(inversions) > 0:
|
||||||
inversion = "inversion-" + inversions[0][0]
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
logger.debug("loading Textual Inversion from %s", inversion)
|
logger.debug("blending Textual Inversions from %s", inversion_names)
|
||||||
# TODO: blend the inversion models
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(
|
inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
|
||||||
path.join(server.model_path, inversion, "text_encoder"),
|
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||||
provider=device.ort_provider(),
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
sess_options=device.sess_options(),
|
model,
|
||||||
|
subfolder="tokenizer",
|
||||||
)
|
)
|
||||||
components["tokenizer"] = CLIPTokenizer.from_pretrained(
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
path.join(server.model_path, inversion, "tokenizer"),
|
server,
|
||||||
|
text_encoder,
|
||||||
|
tokenizer,
|
||||||
|
inversion_models,
|
||||||
|
["embeddings"] * len(inversion_names),
|
||||||
|
inversion_weights,
|
||||||
|
base_tokens=inversion_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# should be pretty small and should not need external data
|
||||||
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
components["tokenizer"] = tokenizer
|
||||||
|
|
||||||
# test LoRA blending
|
# test LoRA blending
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
|
@ -235,21 +253,22 @@ def load_pipeline(
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights)
|
text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
|
||||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
|
||||||
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = SessionOptions()
|
||||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder_model.SerializeToString(),
|
text_encoder.SerializeToString(),
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = SessionOptions()
|
||||||
|
|
Loading…
Reference in New Issue