apply lint
This commit is contained in:
parent
506cf9f65f
commit
8e8e230ffd
|
@ -15,6 +15,7 @@ from onnx.external_data_helper import (
|
|||
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from ...server.context import ServerContext
|
||||
from ..utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -55,6 +56,7 @@ def fix_node_name(key: str):
|
|||
|
||||
|
||||
def blend_loras(
|
||||
context: ServerContext,
|
||||
base_name: str,
|
||||
lora_names: List[str],
|
||||
dest_type: Literal["text_encoder", "unet"],
|
||||
|
@ -236,6 +238,7 @@ def blend_loras(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
context = ConversionContext.from_environ()
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--base", type=str)
|
||||
parser.add_argument("--dest", type=str)
|
||||
|
@ -251,7 +254,9 @@ if __name__ == "__main__":
|
|||
args.lora_weights,
|
||||
)
|
||||
|
||||
blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
|
||||
blend_model = blend_loras(
|
||||
context, args.base, args.lora_models, args.type, args.lora_weights
|
||||
)
|
||||
if args.dest is None or args.dest == "" or args.dest == "ort":
|
||||
# convert to external data and save to memory
|
||||
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||
|
|
|
@ -6,7 +6,6 @@ import numpy as np
|
|||
import torch
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from onnx import ModelProto, load_model, numpy_helper, save_model
|
||||
from torch.onnx import export
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...server.context import ServerContext
|
||||
|
@ -25,11 +24,15 @@ def blend_textual_inversions(
|
|||
inversion_weights: Optional[List[float]] = None,
|
||||
base_tokens: Optional[List[str]] = None,
|
||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||
dtype = np.float # TODO: fixed type, which one?
|
||||
# prev: text_encoder.get_input_embeddings().weight.dtype
|
||||
dtype = np.float
|
||||
embeds = {}
|
||||
|
||||
for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
|
||||
for name, format, weight, base_token in zip(
|
||||
inversion_names,
|
||||
inversion_formats,
|
||||
inversion_weights,
|
||||
base_tokens or inversion_names,
|
||||
):
|
||||
logger.info("blending Textual Inversion %s with weight of %s", name, weight)
|
||||
if format == "concept":
|
||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
||||
|
@ -64,7 +67,7 @@ def blend_textual_inversions(
|
|||
|
||||
for i in range(num_tokens):
|
||||
token = f"{base_token or name}-{i}"
|
||||
layer = trained_embeds[i,:].cpu().numpy().astype(dtype)
|
||||
layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
|
||||
layer *= weight
|
||||
if token in embeds:
|
||||
embeds[token] += layer
|
||||
|
@ -74,7 +77,9 @@ def blend_textual_inversions(
|
|||
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||
|
||||
# add the tokens to the tokenizer
|
||||
logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
|
||||
logger.info(
|
||||
"found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
|
||||
)
|
||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
|
@ -85,7 +90,11 @@ def blend_textual_inversions(
|
|||
|
||||
# resize the token embeddings
|
||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
|
||||
embedding_node = [
|
||||
n
|
||||
for n in text_encoder.graph.initializer
|
||||
if n.name == "text_model.embeddings.token_embedding.weight"
|
||||
][0]
|
||||
embedding_weights = numpy_helper.to_array(embedding_node)
|
||||
|
||||
weights_dim = embedding_weights.shape[1]
|
||||
|
@ -94,15 +103,18 @@ def blend_textual_inversions(
|
|||
|
||||
for token, weights in embeds.items():
|
||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
logger.debug(
|
||||
"embedding %s weights for token %s", weights.shape, token
|
||||
)
|
||||
logger.debug("embedding %s weights for token %s", weights.shape, token)
|
||||
embedding_weights[token_id] = weights
|
||||
|
||||
# replace embedding_node
|
||||
for i in range(len(text_encoder.graph.initializer)):
|
||||
if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
|
||||
new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
|
||||
if (
|
||||
text_encoder.graph.initializer[i].name
|
||||
== "text_model.embeddings.token_embedding.weight"
|
||||
):
|
||||
new_initializer = numpy_helper.from_array(
|
||||
embedding_weights.astype(np.float32), embedding_node.name
|
||||
)
|
||||
logger.debug("new initializer data type: %s", new_initializer.data_type)
|
||||
del text_encoder.graph.initializer[i]
|
||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
||||
|
|
|
@ -221,7 +221,10 @@ def load_pipeline(
|
|||
inversion_names, inversion_weights = zip(*inversions)
|
||||
logger.debug("blending Textual Inversions from %s", inversion_names)
|
||||
|
||||
inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
|
||||
inversion_models = [
|
||||
path.join(server.model_path, "inversion", f"{name}.ckpt")
|
||||
for name in inversion_names
|
||||
]
|
||||
text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model,
|
||||
|
@ -249,16 +252,33 @@ def load_pipeline(
|
|||
# test LoRA blending
|
||||
if loras is not None and len(loras) > 0:
|
||||
lora_names, lora_weights = zip(*loras)
|
||||
lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
|
||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||
lora_models = [
|
||||
path.join(server.model_path, "lora", f"{name}.safetensors")
|
||||
for name in lora_names
|
||||
]
|
||||
logger.info(
|
||||
"blending base model %s with LoRA models: %s", model, lora_models
|
||||
)
|
||||
|
||||
# blend and load text encoder
|
||||
text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
|
||||
blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
|
||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||
text_encoder = text_encoder or path.join(
|
||||
model, "text_encoder", "model.onnx"
|
||||
)
|
||||
text_encoder = blend_loras(
|
||||
server,
|
||||
text_encoder,
|
||||
lora_models,
|
||||
"text_encoder",
|
||||
lora_weights=lora_weights,
|
||||
)
|
||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(
|
||||
text_encoder
|
||||
)
|
||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||
text_encoder_opts = SessionOptions()
|
||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||
text_encoder_opts.add_external_initializers(
|
||||
list(text_encoder_names), list(text_encoder_values)
|
||||
)
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
text_encoder.SerializeToString(),
|
||||
|
@ -268,7 +288,13 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# blend and load unet
|
||||
blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
|
||||
blended_unet = blend_loras(
|
||||
server,
|
||||
path.join(model, "unet", "model.onnx"),
|
||||
lora_models,
|
||||
"unet",
|
||||
lora_weights=lora_weights,
|
||||
)
|
||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||
unet_names, unet_values = zip(*unet_data)
|
||||
unet_opts = SessionOptions()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from math import ceil
|
||||
from re import compile, Pattern
|
||||
from re import Pattern, compile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
@ -132,7 +132,9 @@ def expand_prompt(
|
|||
return prompt_embeds
|
||||
|
||||
|
||||
def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
|
||||
def get_tokens_from_prompt(
|
||||
prompt: str, pattern: Pattern[str]
|
||||
) -> Tuple[str, List[Tuple[str, float]]]:
|
||||
"""
|
||||
TODO: replace with Arpeggio
|
||||
"""
|
||||
|
@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, Lis
|
|||
name, weight = next_match.groups()
|
||||
tokens.append((name, float(weight)))
|
||||
# remove this match and look for another
|
||||
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
||||
remaining_prompt = (
|
||||
remaining_prompt[: next_match.start()]
|
||||
+ remaining_prompt[next_match.end() :]
|
||||
)
|
||||
next_match = pattern.search(remaining_prompt)
|
||||
|
||||
return (remaining_prompt, tokens)
|
||||
|
|
Loading…
Reference in New Issue