
apply lint

Sean Sube 2023-03-14 23:32:47 -05:00
parent 91210ee236
commit 45166f281e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 90 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
 from argparse import ArgumentParser
 from logging import getLogger
-from typing import Dict, List, Literal, Tuple
 from os import path
+from typing import Dict, List, Literal, Tuple
 
 import numpy as np
 import torch
@@ -12,7 +12,7 @@ from onnx.external_data_helper import (
     set_external_data,
     write_external_data_tensors,
 )
-from onnxruntime import OrtValue, InferenceSession, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
 from onnx_web.convert.utils import ConversionContext
@@ -25,7 +25,9 @@ logger = getLogger(__name__)
 ###
-def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
+def buffer_external_data_tensors(
+    model: ModelProto,
+) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
     external_data = []
     for tensor in model.graph.initializer:
         name = tensor.name
@@ -74,17 +76,19 @@ def merge_lora(
     lora_prefix = f"lora_{dest_type}_"
     blended: Dict[str, np.ndarray] = {}
 
-    for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights):
+    for lora_name, lora_model, lora_weight in zip(
+        lora_names, lora_models, lora_weights
+    ):
         logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
 
         for key in lora_model.keys():
             if ".lora_down" in key and lora_prefix in key:
-                base_key = key[: key.index(".lora_down")].replace(
-                    lora_prefix, ""
-                )
+                base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
                 up_key = key.replace("lora_down", "lora_up")
                 alpha_key = key[: key.index("lora_down")] + "alpha"
-                logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key)
+                logger.info(
+                    "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
+                )
 
                 down_weight = lora_model[key].to(dtype=torch.float32)
                 up_weight = lora_model[up_key].to(dtype=torch.float32)
@@ -95,12 +99,22 @@ def merge_lora(
                 try:
                     if len(up_weight.size()) == 2:
                         # blend for nn.Linear
-                        logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        logger.info(
+                            "blending weights for Linear node: %s, %s, %s",
+                            down_weight.shape,
+                            up_weight.shape,
+                            alpha,
+                        )
                         weights = up_weight @ down_weight
-                        np_weights = (weights.numpy() * (alpha / dim))
+                        np_weights = weights.numpy() * (alpha / dim)
                     elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
                         # blend for nn.Conv2d 1x1
-                        logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        logger.info(
+                            "blending weights for Conv node: %s, %s, %s",
+                            down_weight.shape,
+                            up_weight.shape,
+                            alpha,
+                        )
                         weights = (
                             (
                                 up_weight.squeeze(3).squeeze(2)
@@ -109,10 +123,14 @@ def merge_lora(
                             .unsqueeze(2)
                             .unsqueeze(3)
                         )
-                        np_weights = (weights.numpy() * (alpha / dim))
+                        np_weights = weights.numpy() * (alpha / dim)
                     else:
                         # TODO: add support for Conv2d 3x3
-                        logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:])
+                        logger.warning(
+                            "unknown LoRA node type at %s: %s",
+                            base_key,
+                            up_weight.shape[-2:],
+                        )
                         continue
 
                     np_weights *= lora_weight
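
The Conv branch in this hunk leans on the fact that a 1x1 convolution is a matrix multiply over the channel dimension, so the LoRA up and down convolutions can be composed by squeezing away the trailing (1, 1) dims, multiplying, and unsqueezing them back. A minimal sketch checking that equivalence (not part of the commit; the shapes and rank are made up for illustration):

import torch
import torch.nn.functional as F

rank, in_ch, out_ch = 4, 8, 16  # LoRA rank and channel counts (illustrative)
down_weight = torch.randn(rank, in_ch, 1, 1)
up_weight = torch.randn(out_ch, rank, 1, 1)

# the composition used in merge_lora above
weights = (
    (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
    .unsqueeze(2)
    .unsqueeze(3)
)

# reference: actually run the two 1x1 convolutions back to back
x = torch.randn(1, in_ch, 32, 32)
composed = F.conv2d(F.conv2d(x, down_weight), up_weight)
direct = F.conv2d(x, weights)
print(torch.allclose(composed, direct, atol=1e-5))  # True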
@ -122,15 +140,13 @@ def merge_lora(
blended[base_key] = np_weights
except Exception:
logger.exception(
"error blending weights for key %s", base_key
)
logger.exception("error blending weights for key %s", base_key)
logger.info(
"updating %s of %s initializers: %s",
len(blended.keys()),
len(base_model.graph.initializer),
list(blended.keys())
list(blended.keys()),
)
fixed_initializer_names = [
@@ -138,17 +154,19 @@ def merge_lora(
     ]
     # logger.info("fixed initializer names: %s", fixed_initializer_names)
 
-    fixed_node_names = [
-        fix_node_name(node.name) for node in base_model.graph.node
-    ]
+    fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
     # logger.info("fixed node names: %s", fixed_node_names)
 
     for base_key, weights in blended.items():
         conv_key = base_key + "_Conv"
         matmul_key = base_key + "_MatMul"
-        logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names)
+        logger.info(
+            "key %s has conv: %s, matmul: %s",
+            base_key,
+            conv_key in fixed_node_names,
+            matmul_key in fixed_node_names,
+        )
 
         if conv_key in fixed_node_names:
             conv_idx = fixed_node_names.index(conv_key)
@@ -166,7 +184,11 @@ def merge_lora(
 
                 # blending
                 base_weights = numpy_helper.to_array(weight_node)
-                logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape)
+                logger.info(
+                    "found blended weights for conv: %s, %s",
+                    weights.shape,
+                    base_weights.shape,
+                )
 
                 blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
                 blended = np.expand_dims(blended, (2, 3))
@@ -191,7 +213,11 @@ def merge_lora(
 
                 # blending
                 base_weights = numpy_helper.to_array(matmul_node)
-                logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape)
+                logger.info(
+                    "found blended weights for matmul: %s, %s",
+                    weights.shape,
+                    base_weights.shape,
+                )
 
                 blended = base_weights + weights.transpose()
                 logger.info("blended weight shape: %s", blended.shape)
@@ -208,7 +234,7 @@ def merge_lora(
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
         len(fixed_node_names),
-        len(base_model.graph.node)
+        len(base_model.graph.node),
     )
 
     return base_model
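
Condensed, the arithmetic merge_lora performs for a Linear node is below; a minimal sketch, assuming illustrative shapes, with dim standing in for the LoRA rank as in the code above:

import numpy as np
import torch

out_features, in_features, dim = 16, 8, 4  # dim is the LoRA rank (illustrative)
alpha = 4.0
lora_weight = 0.8  # per-model blend weight

# ONNX MatMul initializers are stored transposed relative to the torch layout,
# hence the transpose() when blending, as in the MatMul branch above
base_weights = np.random.randn(in_features, out_features).astype(np.float32)
down_weight = torch.randn(dim, in_features)
up_weight = torch.randn(out_features, dim)

weights = up_weight @ down_weight  # (out_features, in_features)
np_weights = weights.numpy() * (alpha / dim)
np_weights *= lora_weight

blended = base_weights + np_weights.transpose()
print(blended.shape)  # (8, 16)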
@@ -219,11 +245,16 @@ if __name__ == "__main__":
     parser.add_argument("--base", type=str)
     parser.add_argument("--dest", type=str)
     parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
-    parser.add_argument("--lora_models", nargs='+', type=str)
-    parser.add_argument("--lora_weights", nargs='+', type=float)
+    parser.add_argument("--lora_models", nargs="+", type=str)
+    parser.add_argument("--lora_weights", nargs="+", type=float)
 
     args = parser.parse_args()
-    logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
+    logger.info(
+        "merging %s with %s with weights: %s",
+        args.lora_models,
+        args.base,
+        args.lora_weights,
+    )
 
     blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
     if args.dest is None or args.dest == "" or args.dest == "ort":
@@ -234,10 +265,18 @@ if __name__ == "__main__":
         external_names, external_values = zip(*external_data)
         opts = SessionOptions()
         opts.add_external_initializers(list(external_names), list(external_values))
 
-        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
-        logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
+        sess = InferenceSession(
+            bare_model.SerializeToString(),
+            sess_options=opts,
+            providers=["CPUExecutionProvider"],
+        )
+        logger.info(
+            "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
+        )
     else:
-        convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
+        convert_model_to_external_data(
+            blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
+        )
         bare_model = write_external_data_tensors(blend_model, args.dest)
         dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
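
The __main__ block depends on buffer_external_data_tensors, whose body is only partially visible in the first hunk. One plausible shape for it, sketched from the onnx and onnxruntime APIs the file imports (the "memory" location tag and the exact field handling are assumptions):

from typing import List, Tuple

from onnx import ModelProto, numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import OrtValue


def buffer_external_data_tensors(
    model: ModelProto,
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
    external_data = []
    for tensor in model.graph.initializer:
        name = tensor.name
        # copy the weights out of the protobuf as an OrtValue...
        external_data.append(
            (name, OrtValue.ortvalue_from_numpy(numpy_helper.to_array(tensor)))
        )
        # ...then mark the tensor external and drop its raw bytes, keeping
        # the ModelProto small enough to serialize (assumed detail)
        set_external_data(tensor, location="memory")
        tensor.ClearField("raw_data")
    return model, external_data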

View File

@@ -37,7 +37,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
 
-from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
+from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -118,7 +118,10 @@ def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
         name, weight = next_match.groups()
         loras.append(name)
         # remove this match and look for another
-        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        remaining_prompt = (
+            remaining_prompt[: next_match.start()]
+            + remaining_prompt[next_match.end() :]
+        )
         next_match = lora_expr.search(remaining_prompt)
 
     return (remaining_prompt, loras)
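
For context, the loop this hunk reformats can be read as a whole in the following self-contained sketch; the real lora_expr is defined elsewhere in this file, so the <lora:name:weight> token syntax assumed here is illustrative:

import re
from typing import List, Tuple

# assumed pattern; the real expression lives elsewhere in this module
lora_expr = re.compile(r"<lora:([-\w]+):([.\d]+)>")


def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
    remaining_prompt = prompt
    loras = []
    next_match = lora_expr.search(remaining_prompt)
    while next_match is not None:
        name, weight = next_match.groups()
        loras.append(name)
        # remove this match and look for another
        remaining_prompt = (
            remaining_prompt[: next_match.start()]
            + remaining_prompt[next_match.end() :]
        )
        next_match = lora_expr.search(remaining_prompt)

    return (remaining_prompt, loras)


print(get_loras_from_prompt("an astronaut <lora:detail:1.0> riding a horse"))
# ('an astronaut  riding a horse', ['detail'])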
@@ -244,15 +247,23 @@ def load_pipeline(
         )
         # test LoRA blending
-        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
+        lora_models = [
+            path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
+        ]
         logger.info("blending base model %s with LoRA models: %s", model, lora_models)
 
         # blend and load text encoder
-        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
-        (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        blended_text_encoder = merge_lora(
+            path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
+        )
+        (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
+            blended_text_encoder
+        )
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
-        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+        text_encoder_opts.add_external_initializers(
+            list(text_encoder_names), list(text_encoder_values)
+        )
 
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder_model.SerializeToString(),
@@ -262,7 +273,9 @@ def load_pipeline(
         )
 
         # blend and load unet
-        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
+        blended_unet = merge_lora(
+            path.join(model, "unet", "model.onnx"), lora_models, "unet"
+        )
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
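
Assembled from the hunks above, the unet half of load_pipeline's LoRA path looks roughly like this; the model paths, the components dict, and the provider argument are assumptions, since the last hunk is cut off mid-call:

from os import path

from diffusers import OnnxRuntimeModel
from onnxruntime import SessionOptions

from onnx_web.convert.diffusion.lora import buffer_external_data_tensors, merge_lora

components = {}
model = "../models/stable-diffusion-onnx-v1-5"  # hypothetical converted model dir
lora_models = ["../models/lora/example.safetensors"]  # hypothetical LoRA file

# blend and load unet, without writing the blended tensors to disk
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = SessionOptions()
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
components["unet"] = OnnxRuntimeModel(
    OnnxRuntimeModel.load_model(
        unet_model.SerializeToString(),
        provider="CPUExecutionProvider",  # assumed; the hunk ends before this arg
        sess_options=unet_opts,
    )
)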