
apply lint

Sean Sube 2023-03-15 19:27:29 -05:00
parent 506cf9f65f
commit 8e8e230ffd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 72 additions and 24 deletions

View File

@@ -15,6 +15,7 @@ from onnx.external_data_helper import (
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
+from ...server.context import ServerContext
 from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -55,6 +56,7 @@ def fix_node_name(key: str):
 
 
 def blend_loras(
+    context: ServerContext,
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
@@ -236,6 +238,7 @@ def blend_loras(
 if __name__ == "__main__":
+    context = ConversionContext.from_environ()
     parser = ArgumentParser()
     parser.add_argument("--base", type=str)
    parser.add_argument("--dest", type=str)
@@ -251,7 +254,9 @@ if __name__ == "__main__":
         args.lora_weights,
     )
 
-    blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
+    blend_model = blend_loras(
+        context, args.base, args.lora_models, args.type, args.lora_weights
+    )
 
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
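The tail of this hunk hands the blended model to buffer_external_data_tensors so the weights can stay in memory. A minimal sketch of the matching load path, assuming (as above) that the helper returns the stripped model plus (name, OrtValue) pairs; the load_blended wrapper and the provider choice are illustrative, not part of this commit:

from onnxruntime import InferenceSession, SessionOptions

def load_blended(bare_model, external_data):
    # external_data is assumed to be a list of (name, OrtValue) pairs,
    # matching the return value used above
    names, values = zip(*external_data)

    opts = SessionOptions()
    # register the buffered tensors so the session resolves the model's
    # external-data references from memory instead of the filesystem
    opts.add_external_initializers(list(names), list(values))

    return InferenceSession(
        bare_model.SerializeToString(),
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )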

View File

@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
-from torch.onnx import export
 from transformers import CLIPTokenizer
 
 from ...server.context import ServerContext
@@ -25,11 +24,15 @@ def blend_textual_inversions(
     inversion_weights: Optional[List[float]] = None,
     base_tokens: Optional[List[str]] = None,
 ) -> Tuple[ModelProto, CLIPTokenizer]:
-    dtype = np.float # TODO: fixed type, which one?
+    dtype = np.float
+    # prev: text_encoder.get_input_embeddings().weight.dtype
     embeds = {}
 
-    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
+    for name, format, weight, base_token in zip(
+        inversion_names,
+        inversion_formats,
+        inversion_weights,
+        base_tokens or inversion_names,
+    ):
         logger.info("blending Textual Inversion %s with weight of %s", name, weight)
 
         if format == "concept":
             embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
@@ -64,7 +67,7 @@ def blend_textual_inversions(
             for i in range(num_tokens):
                 token = f"{base_token or name}-{i}"
-                layer = trained_embeds[i,:].cpu().numpy().astype(dtype)
+                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
                 layer *= weight
 
                 if token in embeds:
                     embeds[token] += layer
@@ -74,7 +77,9 @@ def blend_textual_inversions(
             raise ValueError(f"unknown Textual Inversion format: {format}")
 
     # add the tokens to the tokenizer
-    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    logger.info(
+        "found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
+    )
     num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
     if num_added_tokens == 0:
         raise ValueError(
@@ -85,7 +90,11 @@ def blend_textual_inversions(
 
     # resize the token embeddings
     # text_encoder.resize_token_embeddings(len(tokenizer))
-    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_node = [
+        n
+        for n in text_encoder.graph.initializer
+        if n.name == "text_model.embeddings.token_embedding.weight"
+    ][0]
     embedding_weights = numpy_helper.to_array(embedding_node)
 
     weights_dim = embedding_weights.shape[1]
@@ -94,15 +103,18 @@ def blend_textual_inversions(
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
-        logger.debug(
-            "embedding %s weights for token %s", weights.shape, token
-        )
+        logger.debug("embedding %s weights for token %s", weights.shape, token)
         embedding_weights[token_id] = weights
 
     # replace embedding_node
     for i in range(len(text_encoder.graph.initializer)):
-        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+        if (
+            text_encoder.graph.initializer[i].name
+            == "text_model.embeddings.token_embedding.weight"
+        ):
+            new_initializer = numpy_helper.from_array(
+                embedding_weights.astype(np.float32), embedding_node.name
+            )
             logger.debug("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
             text_encoder.graph.initializer.insert(i, new_initializer)
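The delete-then-insert at the end of this hunk is the standard way to swap a named initializer in an ONNX graph, since the graph's repeated protobuf fields do not support item assignment. A self-contained sketch of the same technique, with hypothetical file paths; the initializer name is the one this file targets:

import numpy as np
from onnx import load_model, numpy_helper, save_model

NAME = "text_model.embeddings.token_embedding.weight"

model = load_model("text_encoder/model.onnx")  # hypothetical path
for i in range(len(model.graph.initializer)):
    if model.graph.initializer[i].name == NAME:
        weights = numpy_helper.to_array(model.graph.initializer[i])
        # ... modify weights here, e.g. overwrite rows with new embeddings ...
        new_node = numpy_helper.from_array(weights.astype(np.float32), NAME)
        # protobuf repeated message fields cannot be assigned by index,
        # so delete the old entry and insert the replacement in its place
        del model.graph.initializer[i]
        model.graph.initializer.insert(i, new_node)
        break

save_model(model, "model.blended.onnx")  # hypothetical path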

View File

@@ -221,7 +221,10 @@ def load_pipeline(
         inversion_names, inversion_weights = zip(*inversions)
         logger.debug("blending Textual Inversions from %s", inversion_names)
 
-        inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+        inversion_models = [
+            path.join(server.model_path, "inversion", f"{name}.ckpt")
+            for name in inversion_names
+        ]
         text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
         tokenizer = CLIPTokenizer.from_pretrained(
             model,
@@ -249,16 +252,33 @@ def load_pipeline(
     # test LoRA blending
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
-        lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
-        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+        lora_models = [
+            path.join(server.model_path, "lora", f"{name}.safetensors")
+            for name in lora_names
+        ]
+        logger.info(
+            "blending base model %s with LoRA models: %s", model, lora_models
+        )
 
         # blend and load text encoder
-        text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
-        blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        text_encoder = text_encoder or path.join(
+            model, "text_encoder", "model.onnx"
+        )
+        text_encoder = blend_loras(
+            server,
+            text_encoder,
+            lora_models,
+            "text_encoder",
+            lora_weights=lora_weights,
+        )
+        (text_encoder, text_encoder_data) = buffer_external_data_tensors(
+            text_encoder
+        )
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
-        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+        text_encoder_opts.add_external_initializers(
+            list(text_encoder_names), list(text_encoder_values)
+        )
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder.SerializeToString(),
@@ -268,7 +288,13 @@ def load_pipeline(
         )
 
         # blend and load unet
-        blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+        blended_unet = blend_loras(
+            server,
+            path.join(model, "unet", "model.onnx"),
+            lora_models,
+            "unet",
+            lora_weights=lora_weights,
+        )
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
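The unet branch ends the same way the text encoder branch does: serialize the stripped ModelProto and pass the buffered tensors through SessionOptions. A condensed sketch of that hand-off into diffusers, reusing the unet_* names from the hunk above; the provider string is an assumption, since the server normally selects its own:

from diffusers import OnnxRuntimeModel
from onnxruntime import SessionOptions

unet_opts = SessionOptions()
# attach the externalized tensors so the session can load them from memory
unet_opts.add_external_initializers(list(unet_names), list(unet_values))

components["unet"] = OnnxRuntimeModel(
    OnnxRuntimeModel.load_model(
        unet_model.SerializeToString(),
        provider="CPUExecutionProvider",  # assumption: the real provider varies
        sess_options=unet_opts,
    )
)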

View File

@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile, Pattern
+from re import Pattern, compile
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -132,7 +132,9 @@ def expand_prompt(
     return prompt_embeds
 
 
-def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(
+    prompt: str, pattern: Pattern[str]
+) -> Tuple[str, List[Tuple[str, float]]]:
     """
     TODO: replace with Arpeggio
     """
@@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
         name, weight = next_match.groups()
         tokens.append((name, float(weight)))
         # remove this match and look for another
-        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        remaining_prompt = (
+            remaining_prompt[: next_match.start()]
+            + remaining_prompt[next_match.end() :]
+        )
         next_match = pattern.search(remaining_prompt)
 
     return (remaining_prompt, tokens)
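get_tokens_from_prompt strips every match of the given pattern out of the prompt and returns the cleaned prompt along with (name, weight) pairs. A usage sketch with an assumed token syntax of the form <lora:name:weight>; the project's actual pattern may differ:

from re import compile

# assumed syntax, for illustration only
lora_pattern = compile(r"<lora:([^:>]+):([-+]?\d+(?:\.\d+)?)>")

prompt, tokens = get_tokens_from_prompt(
    "an astronaut riding a horse <lora:some-style:1.0>", lora_pattern
)
# prompt == "an astronaut riding a horse "
# tokens == [("some-style", 1.0)]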