
apply lint

Sean Sube 2023-03-15 19:27:29 -05:00
parent 506cf9f65f
commit 8e8e230ffd
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 72 additions and 24 deletions

File 1 of 4:

@@ -15,6 +15,7 @@ from onnx.external_data_helper import (
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
+from ...server.context import ServerContext
 from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -55,6 +56,7 @@ def fix_node_name(key: str):
 def blend_loras(
+    context: ServerContext,
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
@@ -236,6 +238,7 @@ def blend_loras(
 if __name__ == "__main__":
+    context = ConversionContext.from_environ()
     parser = ArgumentParser()
     parser.add_argument("--base", type=str)
     parser.add_argument("--dest", type=str)
@@ -251,7 +254,9 @@ if __name__ == "__main__":
         args.lora_weights,
     )
 
-    blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
+    blend_model = blend_loras(
+        context, args.base, args.lora_models, args.type, args.lora_weights
+    )
 
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
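
Note: after this commit, blend_loras takes the conversion context as its first positional argument, and the CLI entry point builds one with ConversionContext.from_environ(). A minimal sketch of the new call shape; the import paths and model paths here are assumptions for illustration, not taken from this diff:

    # hypothetical usage of the new signature; module paths are assumed
    from onnx_web.convert.diffusion.lora import blend_loras
    from onnx_web.convert.utils import ConversionContext

    context = ConversionContext.from_environ()
    blended = blend_loras(
        context,
        "./models/stable-diffusion-v1-5",       # base_name (placeholder path)
        ["./models/lora/example.safetensors"],  # lora_names (placeholder paths)
        "unet",                                 # dest_type: "text_encoder" or "unet"
        [1.0],                                  # lora_weights
    )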

File 2 of 4:

@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
-from torch.onnx import export
 from transformers import CLIPTokenizer
 
 from ...server.context import ServerContext
@@ -25,11 +24,15 @@ def blend_textual_inversions(
     inversion_weights: Optional[List[float]] = None,
     base_tokens: Optional[List[str]] = None,
 ) -> Tuple[ModelProto, CLIPTokenizer]:
-    dtype = np.float # TODO: fixed type, which one?
-    # prev: text_encoder.get_input_embeddings().weight.dtype
+    dtype = np.float
     embeds = {}
 
-    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
+    for name, format, weight, base_token in zip(
+        inversion_names,
+        inversion_formats,
+        inversion_weights,
+        base_tokens or inversion_names,
+    ):
         logger.info("blending Textual Inversion %s with weight of %s", name, weight)
         if format == "concept":
             embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
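
Note: np.float is a deprecated alias for the builtin float and was removed in NumPy 1.24, so dtype = np.float raises AttributeError on current NumPy; the TODO in the removed line flags that the intended type was never settled. A sketch of one plausible fix, assuming float32 to match the astype(np.float32) cast applied when the embedding initializer is rebuilt later in this function:

    dtype = np.float32  # assumption: match the float32 cast used for the rebuilt initializer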
@@ -74,7 +77,9 @@
             raise ValueError(f"unknown Textual Inversion format: {format}")
 
     # add the tokens to the tokenizer
-    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    logger.info(
+        "found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
+    )
     num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
     if num_added_tokens == 0:
         raise ValueError(
@@ -85,7 +90,11 @@
     # resize the token embeddings
     # text_encoder.resize_token_embeddings(len(tokenizer))
-    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_node = [
+        n
+        for n in text_encoder.graph.initializer
+        if n.name == "text_model.embeddings.token_embedding.weight"
+    ][0]
     embedding_weights = numpy_helper.to_array(embedding_node)
     weights_dim = embedding_weights.shape[1]
@@ -94,15 +103,18 @@
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
-        logger.debug(
-            "embedding %s weights for token %s", weights.shape, token
-        )
+        logger.debug("embedding %s weights for token %s", weights.shape, token)
         embedding_weights[token_id] = weights
 
     # replace embedding_node
     for i in range(len(text_encoder.graph.initializer)):
-        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+        if (
+            text_encoder.graph.initializer[i].name
+            == "text_model.embeddings.token_embedding.weight"
+        ):
+            new_initializer = numpy_helper.from_array(
+                embedding_weights.astype(np.float32), embedding_node.name
+            )
             logger.debug("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
             text_encoder.graph.initializer.insert(i, new_initializer)
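
Note: graph.initializer is a protobuf repeated field and does not support item assignment, which is why the code above deletes the old tensor and inserts the replacement at the same index. A standalone sketch of that swap on a toy model; the node and tensor names here are illustrative, not from onnx-web:

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper

    # one-node model with a single initializer named "w"
    graph = helper.make_graph(
        nodes=[helper.make_node("Identity", ["w"], ["out"])],
        name="toy",
        inputs=[],
        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [2])],
        initializer=[numpy_helper.from_array(np.zeros(2, dtype=np.float32), "w")],
    )
    model = helper.make_model(graph)

    # replace "w": delete the old entry, insert the new one at the same index
    new_weights = np.ones(2, dtype=np.float32)
    for i in range(len(model.graph.initializer)):
        if model.graph.initializer[i].name == "w":
            replacement = numpy_helper.from_array(new_weights, "w")
            del model.graph.initializer[i]
            model.graph.initializer.insert(i, replacement)
            break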

File 3 of 4:

@@ -221,7 +221,10 @@ def load_pipeline(
         inversion_names, inversion_weights = zip(*inversions)
         logger.debug("blending Textual Inversions from %s", inversion_names)
 
-        inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+        inversion_models = [
+            path.join(server.model_path, "inversion", f"{name}.ckpt")
+            for name in inversion_names
+        ]
         text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
         tokenizer = CLIPTokenizer.from_pretrained(
             model,
@@ -249,16 +252,33 @@ def load_pipeline(
     # test LoRA blending
     if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
-        lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
-        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+        lora_models = [
+            path.join(server.model_path, "lora", f"{name}.safetensors")
+            for name in lora_names
+        ]
+        logger.info(
+            "blending base model %s with LoRA models: %s", model, lora_models
+        )
 
         # blend and load text encoder
-        text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
-        blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        text_encoder = text_encoder or path.join(
+            model, "text_encoder", "model.onnx"
+        )
+        text_encoder = blend_loras(
+            server,
+            text_encoder,
+            lora_models,
+            "text_encoder",
+            lora_weights=lora_weights,
+        )
+        (text_encoder, text_encoder_data) = buffer_external_data_tensors(
+            text_encoder
+        )
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
-        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+        text_encoder_opts.add_external_initializers(
+            list(text_encoder_names), list(text_encoder_values)
+        )
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder.SerializeToString(),
@@ -268,7 +288,13 @@
         )
 
         # blend and load unet
-        blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+        blended_unet = blend_loras(
+            server,
+            path.join(model, "unet", "model.onnx"),
+            lora_models,
+            "unet",
+            lora_weights=lora_weights,
+        )
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
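
Note: blend_loras returns the blended model as an in-memory ModelProto; rather than writing it to disk, load_pipeline splits the large tensors out with buffer_external_data_tensors (an onnx-web helper, shown above) and registers them with the session as external initializers. A sketch of the same pattern against onnxruntime directly, where blended stands in for the ModelProto returned by blend_loras:

    from onnxruntime import InferenceSession, SessionOptions

    # split large tensors out of the model; yields (bare model, [(name, OrtValue), ...])
    (bare_model, external_data) = buffer_external_data_tensors(blended)
    names, values = zip(*external_data)

    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))

    # the bare model is now small enough to serialize; ORT resolves the
    # missing weights from the OrtValues registered above
    session = InferenceSession(
        bare_model.SerializeToString(),
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )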

File 4 of 4:

@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile, Pattern
+from re import Pattern, compile
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -132,7 +132,9 @@ def expand_prompt(
     return prompt_embeds
 
 
-def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(
+    prompt: str, pattern: Pattern[str]
+) -> Tuple[str, List[Tuple[str, float]]]:
     """
     TODO: replace with Arpeggio
     """
@@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, Lis
         name, weight = next_match.groups()
         tokens.append((name, float(weight)))
         # remove this match and look for another
-        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        remaining_prompt = (
+            remaining_prompt[: next_match.start()]
+            + remaining_prompt[next_match.end() :]
+        )
         next_match = pattern.search(remaining_prompt)
 
     return (remaining_prompt, tokens)
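
Note: get_tokens_from_prompt strips every pattern match out of the prompt and returns the cleaned prompt plus one (name, weight) pair per match, so the compiled pattern must capture exactly two groups: a name and a numeric weight. A usage sketch with an assumed LoRA-token regex; the actual pattern onnx-web compiles is defined elsewhere in this module:

    from re import compile

    lora_pattern = compile(r"<lora:(\w+):([\d.]+)>")  # assumed pattern, for illustration

    prompt, tokens = get_tokens_from_prompt(
        "an astronaut <lora:example:1.2> riding a horse", lora_pattern
    )
    # prompt == "an astronaut  riding a horse"
    # tokens == [("example", 1.2)]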