apply lint
This commit is contained in:
parent
506cf9f65f
commit
8e8e230ffd
|
@ -15,6 +15,7 @@ from onnx.external_data_helper import (
|
||||||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from ...server.context import ServerContext
|
||||||
from ..utils import ConversionContext
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -55,6 +56,7 @@ def fix_node_name(key: str):
|
||||||
|
|
||||||
|
|
||||||
def blend_loras(
|
def blend_loras(
|
||||||
|
context: ServerContext,
|
||||||
base_name: str,
|
base_name: str,
|
||||||
lora_names: List[str],
|
lora_names: List[str],
|
||||||
dest_type: Literal["text_encoder", "unet"],
|
dest_type: Literal["text_encoder", "unet"],
|
||||||
|
@ -236,6 +238,7 @@ def blend_loras(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
context = ConversionContext.from_environ()
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--base", type=str)
|
parser.add_argument("--base", type=str)
|
||||||
parser.add_argument("--dest", type=str)
|
parser.add_argument("--dest", type=str)
|
||||||
|
@ -251,7 +254,9 @@ if __name__ == "__main__":
|
||||||
args.lora_weights,
|
args.lora_weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
|
blend_model = blend_loras(
|
||||||
|
context, args.base, args.lora_models, args.type, args.lora_weights
|
||||||
|
)
|
||||||
if args.dest is None or args.dest == "" or args.dest == "ort":
|
if args.dest is None or args.dest == "" or args.dest == "ort":
|
||||||
# convert to external data and save to memory
|
# convert to external data and save to memory
|
||||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||||
|
|
|
@ -6,7 +6,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
from onnx import ModelProto, load_model, numpy_helper, save_model
|
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||||
from torch.onnx import export
|
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from ...server.context import ServerContext
|
from ...server.context import ServerContext
|
||||||
|
@ -25,11 +24,15 @@ def blend_textual_inversions(
|
||||||
inversion_weights: Optional[List[float]] = None,
|
inversion_weights: Optional[List[float]] = None,
|
||||||
base_tokens: Optional[List[str]] = None,
|
base_tokens: Optional[List[str]] = None,
|
||||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||||
dtype = np.float # TODO: fixed type, which one?
|
dtype = np.float
|
||||||
# prev: text_encoder.get_input_embeddings().weight.dtype
|
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
|
for name, format, weight, base_token in zip(
|
||||||
|
inversion_names,
|
||||||
|
inversion_formats,
|
||||||
|
inversion_weights,
|
||||||
|
base_tokens or inversion_names,
|
||||||
|
):
|
||||||
logger.info("blending Textual Inversion %s with weight of %s", name, weight)
|
logger.info("blending Textual Inversion %s with weight of %s", name, weight)
|
||||||
if format == "concept":
|
if format == "concept":
|
||||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
||||||
|
@ -64,7 +67,7 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
token = f"{base_token or name}-{i}"
|
token = f"{base_token or name}-{i}"
|
||||||
layer = trained_embeds[i,:].cpu().numpy().astype(dtype)
|
layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
|
||||||
layer *= weight
|
layer *= weight
|
||||||
if token in embeds:
|
if token in embeds:
|
||||||
embeds[token] += layer
|
embeds[token] += layer
|
||||||
|
@ -74,7 +77,9 @@ def blend_textual_inversions(
|
||||||
raise ValueError(f"unknown Textual Inversion format: {format}")
|
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||||
|
|
||||||
# add the tokens to the tokenizer
|
# add the tokens to the tokenizer
|
||||||
logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
|
logger.info(
|
||||||
|
"found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
|
||||||
|
)
|
||||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||||
if num_added_tokens == 0:
|
if num_added_tokens == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -85,7 +90,11 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
# resize the token embeddings
|
# resize the token embeddings
|
||||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
|
embedding_node = [
|
||||||
|
n
|
||||||
|
for n in text_encoder.graph.initializer
|
||||||
|
if n.name == "text_model.embeddings.token_embedding.weight"
|
||||||
|
][0]
|
||||||
embedding_weights = numpy_helper.to_array(embedding_node)
|
embedding_weights = numpy_helper.to_array(embedding_node)
|
||||||
|
|
||||||
weights_dim = embedding_weights.shape[1]
|
weights_dim = embedding_weights.shape[1]
|
||||||
|
@ -94,15 +103,18 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
for token, weights in embeds.items():
|
for token, weights in embeds.items():
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||||
logger.debug(
|
logger.debug("embedding %s weights for token %s", weights.shape, token)
|
||||||
"embedding %s weights for token %s", weights.shape, token
|
|
||||||
)
|
|
||||||
embedding_weights[token_id] = weights
|
embedding_weights[token_id] = weights
|
||||||
|
|
||||||
# replace embedding_node
|
# replace embedding_node
|
||||||
for i in range(len(text_encoder.graph.initializer)):
|
for i in range(len(text_encoder.graph.initializer)):
|
||||||
if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
|
if (
|
||||||
new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
|
text_encoder.graph.initializer[i].name
|
||||||
|
== "text_model.embeddings.token_embedding.weight"
|
||||||
|
):
|
||||||
|
new_initializer = numpy_helper.from_array(
|
||||||
|
embedding_weights.astype(np.float32), embedding_node.name
|
||||||
|
)
|
||||||
logger.debug("new initializer data type: %s", new_initializer.data_type)
|
logger.debug("new initializer data type: %s", new_initializer.data_type)
|
||||||
del text_encoder.graph.initializer[i]
|
del text_encoder.graph.initializer[i]
|
||||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
text_encoder.graph.initializer.insert(i, new_initializer)
|
||||||
|
|
|
@ -221,7 +221,10 @@ def load_pipeline(
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
logger.debug("blending Textual Inversions from %s", inversion_names)
|
logger.debug("blending Textual Inversions from %s", inversion_names)
|
||||||
|
|
||||||
inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
|
inversion_models = [
|
||||||
|
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
||||||
|
for name in inversion_names
|
||||||
|
]
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
@ -249,16 +252,33 @@ def load_pipeline(
|
||||||
# test LoRA blending
|
# test LoRA blending
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
lora_models = [
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
path.join(server.model_path, "lora", f"{name}.safetensors")
|
||||||
|
for name in lora_names
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"blending base model %s with LoRA models: %s", model, lora_models
|
||||||
|
)
|
||||||
|
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
|
text_encoder = text_encoder or path.join(
|
||||||
blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
|
model, "text_encoder", "model.onnx"
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
)
|
||||||
|
text_encoder = blend_loras(
|
||||||
|
server,
|
||||||
|
text_encoder,
|
||||||
|
lora_models,
|
||||||
|
"text_encoder",
|
||||||
|
lora_weights=lora_weights,
|
||||||
|
)
|
||||||
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
||||||
|
text_encoder
|
||||||
|
)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = SessionOptions()
|
||||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
text_encoder_opts.add_external_initializers(
|
||||||
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
|
)
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder.SerializeToString(),
|
text_encoder.SerializeToString(),
|
||||||
|
@ -268,7 +288,13 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
blended_unet = blend_loras(
|
||||||
|
server,
|
||||||
|
path.join(model, "unet", "model.onnx"),
|
||||||
|
lora_models,
|
||||||
|
"unet",
|
||||||
|
lora_weights=lora_weights,
|
||||||
|
)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = SessionOptions()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from re import compile, Pattern
|
from re import Pattern, compile
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -132,7 +132,9 @@ def expand_prompt(
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
|
def get_tokens_from_prompt(
|
||||||
|
prompt: str, pattern: Pattern[str]
|
||||||
|
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||||
"""
|
"""
|
||||||
TODO: replace with Arpeggio
|
TODO: replace with Arpeggio
|
||||||
"""
|
"""
|
||||||
|
@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, Lis
|
||||||
name, weight = next_match.groups()
|
name, weight = next_match.groups()
|
||||||
tokens.append((name, float(weight)))
|
tokens.append((name, float(weight)))
|
||||||
# remove this match and look for another
|
# remove this match and look for another
|
||||||
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
remaining_prompt = (
|
||||||
|
remaining_prompt[: next_match.start()]
|
||||||
|
+ remaining_prompt[next_match.end() :]
|
||||||
|
)
|
||||||
next_match = pattern.search(remaining_prompt)
|
next_match = pattern.search(remaining_prompt)
|
||||||
|
|
||||||
return (remaining_prompt, tokens)
|
return (remaining_prompt, tokens)
|
||||||
|
|
Loading…
Reference in New Issue