apply lint
This commit is contained in:
parent
91210ee236
commit
45166f281e
|
@ -1,7 +1,7 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Dict, List, Literal, Tuple
|
|
||||||
from os import path
|
from os import path
|
||||||
|
from typing import Dict, List, Literal, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -12,7 +12,7 @@ from onnx.external_data_helper import (
|
||||||
set_external_data,
|
set_external_data,
|
||||||
write_external_data_tensors,
|
write_external_data_tensors,
|
||||||
)
|
)
|
||||||
from onnxruntime import OrtValue, InferenceSession, SessionOptions
|
from onnxruntime import InferenceSession, OrtValue, SessionOptions
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from onnx_web.convert.utils import ConversionContext
|
from onnx_web.convert.utils import ConversionContext
|
||||||
|
@ -25,7 +25,9 @@ logger = getLogger(__name__)
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
|
def buffer_external_data_tensors(
|
||||||
|
model: ModelProto,
|
||||||
|
) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
|
||||||
external_data = []
|
external_data = []
|
||||||
for tensor in model.graph.initializer:
|
for tensor in model.graph.initializer:
|
||||||
name = tensor.name
|
name = tensor.name
|
||||||
|
@ -74,17 +76,19 @@ def merge_lora(
|
||||||
lora_prefix = f"lora_{dest_type}_"
|
lora_prefix = f"lora_{dest_type}_"
|
||||||
|
|
||||||
blended: Dict[str, np.ndarray] = {}
|
blended: Dict[str, np.ndarray] = {}
|
||||||
for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights):
|
for lora_name, lora_model, lora_weight in zip(
|
||||||
|
lora_names, lora_models, lora_weights
|
||||||
|
):
|
||||||
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||||
for key in lora_model.keys():
|
for key in lora_model.keys():
|
||||||
if ".lora_down" in key and lora_prefix in key:
|
if ".lora_down" in key and lora_prefix in key:
|
||||||
base_key = key[: key.index(".lora_down")].replace(
|
base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
|
||||||
lora_prefix, ""
|
|
||||||
)
|
|
||||||
|
|
||||||
up_key = key.replace("lora_down", "lora_up")
|
up_key = key.replace("lora_down", "lora_up")
|
||||||
alpha_key = key[: key.index("lora_down")] + "alpha"
|
alpha_key = key[: key.index("lora_down")] + "alpha"
|
||||||
logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key)
|
logger.info(
|
||||||
|
"blending weights for keys: %s, %s, %s", key, up_key, alpha_key
|
||||||
|
)
|
||||||
|
|
||||||
down_weight = lora_model[key].to(dtype=torch.float32)
|
down_weight = lora_model[key].to(dtype=torch.float32)
|
||||||
up_weight = lora_model[up_key].to(dtype=torch.float32)
|
up_weight = lora_model[up_key].to(dtype=torch.float32)
|
||||||
|
@ -95,12 +99,22 @@ def merge_lora(
|
||||||
try:
|
try:
|
||||||
if len(up_weight.size()) == 2:
|
if len(up_weight.size()) == 2:
|
||||||
# blend for nn.Linear
|
# blend for nn.Linear
|
||||||
logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
|
logger.info(
|
||||||
|
"blending weights for Linear node: %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
weights = up_weight @ down_weight
|
weights = up_weight @ down_weight
|
||||||
np_weights = (weights.numpy() * (alpha / dim))
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
||||||
# blend for nn.Conv2d 1x1
|
# blend for nn.Conv2d 1x1
|
||||||
logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
|
logger.info(
|
||||||
|
"blending weights for Conv node: %s, %s, %s",
|
||||||
|
down_weight.shape,
|
||||||
|
up_weight.shape,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
weights = (
|
weights = (
|
||||||
(
|
(
|
||||||
up_weight.squeeze(3).squeeze(2)
|
up_weight.squeeze(3).squeeze(2)
|
||||||
|
@ -109,10 +123,14 @@ def merge_lora(
|
||||||
.unsqueeze(2)
|
.unsqueeze(2)
|
||||||
.unsqueeze(3)
|
.unsqueeze(3)
|
||||||
)
|
)
|
||||||
np_weights = (weights.numpy() * (alpha / dim))
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
else:
|
else:
|
||||||
# TODO: add support for Conv2d 3x3
|
# TODO: add support for Conv2d 3x3
|
||||||
logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:])
|
logger.warning(
|
||||||
|
"unknown LoRA node type at %s: %s",
|
||||||
|
base_key,
|
||||||
|
up_weight.shape[-2:],
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
np_weights *= lora_weight
|
np_weights *= lora_weight
|
||||||
|
@ -122,15 +140,13 @@ def merge_lora(
|
||||||
blended[base_key] = np_weights
|
blended[base_key] = np_weights
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception("error blending weights for key %s", base_key)
|
||||||
"error blending weights for key %s", base_key
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"updating %s of %s initializers: %s",
|
"updating %s of %s initializers: %s",
|
||||||
len(blended.keys()),
|
len(blended.keys()),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
list(blended.keys())
|
list(blended.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
fixed_initializer_names = [
|
fixed_initializer_names = [
|
||||||
|
@ -138,17 +154,19 @@ def merge_lora(
|
||||||
]
|
]
|
||||||
# logger.info("fixed initializer names: %s", fixed_initializer_names)
|
# logger.info("fixed initializer names: %s", fixed_initializer_names)
|
||||||
|
|
||||||
fixed_node_names = [
|
fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
|
||||||
fix_node_name(node.name) for node in base_model.graph.node
|
|
||||||
]
|
|
||||||
# logger.info("fixed node names: %s", fixed_node_names)
|
# logger.info("fixed node names: %s", fixed_node_names)
|
||||||
|
|
||||||
|
|
||||||
for base_key, weights in blended.items():
|
for base_key, weights in blended.items():
|
||||||
conv_key = base_key + "_Conv"
|
conv_key = base_key + "_Conv"
|
||||||
matmul_key = base_key + "_MatMul"
|
matmul_key = base_key + "_MatMul"
|
||||||
|
|
||||||
logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names)
|
logger.info(
|
||||||
|
"key %s has conv: %s, matmul: %s",
|
||||||
|
base_key,
|
||||||
|
conv_key in fixed_node_names,
|
||||||
|
matmul_key in fixed_node_names,
|
||||||
|
)
|
||||||
|
|
||||||
if conv_key in fixed_node_names:
|
if conv_key in fixed_node_names:
|
||||||
conv_idx = fixed_node_names.index(conv_key)
|
conv_idx = fixed_node_names.index(conv_key)
|
||||||
|
@ -166,7 +184,11 @@ def merge_lora(
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(weight_node)
|
base_weights = numpy_helper.to_array(weight_node)
|
||||||
logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape)
|
logger.info(
|
||||||
|
"found blended weights for conv: %s, %s",
|
||||||
|
weights.shape,
|
||||||
|
base_weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
|
@ -191,7 +213,11 @@ def merge_lora(
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(matmul_node)
|
base_weights = numpy_helper.to_array(matmul_node)
|
||||||
logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape)
|
logger.info(
|
||||||
|
"found blended weights for matmul: %s, %s",
|
||||||
|
weights.shape,
|
||||||
|
base_weights.shape,
|
||||||
|
)
|
||||||
|
|
||||||
blended = base_weights + weights.transpose()
|
blended = base_weights + weights.transpose()
|
||||||
logger.info("blended weight shape: %s", blended.shape)
|
logger.info("blended weight shape: %s", blended.shape)
|
||||||
|
@ -208,7 +234,7 @@ def merge_lora(
|
||||||
len(fixed_initializer_names),
|
len(fixed_initializer_names),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
len(fixed_node_names),
|
len(fixed_node_names),
|
||||||
len(base_model.graph.node)
|
len(base_model.graph.node),
|
||||||
)
|
)
|
||||||
|
|
||||||
return base_model
|
return base_model
|
||||||
|
@ -219,11 +245,16 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--base", type=str)
|
parser.add_argument("--base", type=str)
|
||||||
parser.add_argument("--dest", type=str)
|
parser.add_argument("--dest", type=str)
|
||||||
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
|
||||||
parser.add_argument("--lora_models", nargs='+', type=str)
|
parser.add_argument("--lora_models", nargs="+", type=str)
|
||||||
parser.add_argument("--lora_weights", nargs='+', type=float)
|
parser.add_argument("--lora_weights", nargs="+", type=float)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
|
logger.info(
|
||||||
|
"merging %s with %s with weights: %s",
|
||||||
|
args.lora_models,
|
||||||
|
args.base,
|
||||||
|
args.lora_weights,
|
||||||
|
)
|
||||||
|
|
||||||
blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
|
blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
|
||||||
if args.dest is None or args.dest == "" or args.dest == "ort":
|
if args.dest is None or args.dest == "" or args.dest == "ort":
|
||||||
|
@ -234,10 +265,18 @@ if __name__ == "__main__":
|
||||||
external_names, external_values = zip(*external_data)
|
external_names, external_values = zip(*external_data)
|
||||||
opts = SessionOptions()
|
opts = SessionOptions()
|
||||||
opts.add_external_initializers(list(external_names), list(external_values))
|
opts.add_external_initializers(list(external_names), list(external_values))
|
||||||
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
|
sess = InferenceSession(
|
||||||
logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
|
bare_model.SerializeToString(),
|
||||||
|
sess_options=opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
|
convert_model_to_external_data(
|
||||||
|
blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
|
||||||
|
)
|
||||||
bare_model = write_external_data_tensors(blend_model, args.dest)
|
bare_model = write_external_data_tensors(blend_model, args.dest)
|
||||||
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
||||||
|
|
||||||
|
|
|
@ -37,7 +37,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||||
|
|
||||||
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
|
from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
|
||||||
from ..params import DeviceParams, Size
|
from ..params import DeviceParams, Size
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -118,7 +118,10 @@ def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
|
||||||
name, weight = next_match.groups()
|
name, weight = next_match.groups()
|
||||||
loras.append(name)
|
loras.append(name)
|
||||||
# remove this match and look for another
|
# remove this match and look for another
|
||||||
remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
|
remaining_prompt = (
|
||||||
|
remaining_prompt[: next_match.start()]
|
||||||
|
+ remaining_prompt[next_match.end() :]
|
||||||
|
)
|
||||||
next_match = lora_expr.search(remaining_prompt)
|
next_match = lora_expr.search(remaining_prompt)
|
||||||
|
|
||||||
return (remaining_prompt, loras)
|
return (remaining_prompt, loras)
|
||||||
|
@ -244,15 +247,23 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# test LoRA blending
|
# test LoRA blending
|
||||||
lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
|
lora_models = [
|
||||||
|
path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
|
||||||
|
]
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
|
blended_text_encoder = merge_lora(
|
||||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
|
||||||
|
)
|
||||||
|
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
|
||||||
|
blended_text_encoder
|
||||||
|
)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = SessionOptions()
|
||||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
text_encoder_opts.add_external_initializers(
|
||||||
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
|
)
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder_model.SerializeToString(),
|
text_encoder_model.SerializeToString(),
|
||||||
|
@ -262,7 +273,9 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
|
blended_unet = merge_lora(
|
||||||
|
path.join(model, "unet", "model.onnx"), lora_models, "unet"
|
||||||
|
)
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = SessionOptions()
|
||||||
|
|
Loading…
Reference in New Issue