lint LoRA code and extras, replace public paths with context ones
commit d88f13cbd7 (parent ce74183e97)
@@ -1,10 +1,5 @@
 {
   "diffusion": [
-    {
-      "name": "diffusion-ugly-sonic",
-      "source": "runwayml/stable-diffusion-v1-5",
-      "inversion": "sd-concepts-library/ugly-sonic"
-    },
     {
       "name": "diffusion-knollingcase",
       "source": "Aybeeceedee/knollingcase"
@@ -42,6 +37,10 @@
     {
       "name": "minecraft",
      "source": "sd-concepts-library/minecraft-concept-art"
+    },
+    {
+      "name": "ugly-sonic",
+      "source": "sd-concepts-library/ugly-sonic"
     }
   ]
 }
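The first file is an extras manifest listing additional models to convert: each entry has a name and a source, and some carry a textual-inversion concept. A minimal sketch of reading such a manifest, assuming the filename extras.json; the helper name and error handling are mine, only the "diffusion" key and entry fields come from the diff:

# Hypothetical loader for the manifest shown above; the filename is an
# assumption, the "diffusion" key and entry fields are from the diff.
import json

def load_extras(filename: str = "extras.json") -> list:
    with open(filename, "r") as f:
        extras = json.load(f)

    models = []
    for model in extras.get("diffusion", []):
        # every entry has a name and a source; "inversion" is optional
        models.append((model["name"], model["source"], model.get("inversion")))
    return models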
@@ -1,17 +1,24 @@
 from logging import getLogger
+from os import path
+from sys import argv
 from typing import List, Tuple
 
+import onnx.checker
+import torch
 from numpy import ndarray
 from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
-from sys import argv
 from safetensors import safe_open
 
-import torch
-import onnx.checker
+from ..utils import ConversionContext
 
 logger = getLogger(__name__)
 
 
+###
+# everything in this file is still super experimental and may not produce valid ONNX models
+###
+
 
 def load_lora(filename: str):
     model = load(filename)
@@ -51,13 +58,13 @@ def blend_loras(
     return results
 
 
-def convert_diffusion_lora(part: str):
+def convert_diffusion_lora(context: ConversionContext, component: str):
     lora_weights = [
-        f"diffusion-lora-jack/{part}/model.onnx",
-        f"diffusion-lora-taters/{part}/model.onnx",
+        f"diffusion-lora-jack/{component}/model.onnx",
+        f"diffusion-lora-taters/{component}/model.onnx",
     ]
 
-    base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
+    base = load_lora(f"stable-diffusion-onnx-v1-5/{component}/model.onnx")
     weights = [load_lora(f) for f in lora_weights]
     alphas = [1 / len(weights)] * len(weights)
     logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
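convert_diffusion_lora loads one ONNX model per LoRA and gives each an equal alpha of 1 / len(weights); the blend itself happens in blend_loras, which this hunk only names. A minimal sketch of the equal-alpha idea on plain arrays, with every name below hypothetical rather than taken from blend_loras:

# Equal-alpha blend of matching weight tensors; a sketch of the idea only,
# not the blend_loras implementation (which this diff does not show).
from typing import List

import numpy as np

def blend_equal(tensors: List[np.ndarray]) -> np.ndarray:
    alphas = [1 / len(tensors)] * len(tensors)  # same weighting as the diff
    return sum(alpha * t for alpha, t in zip(alphas, tensors))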
@@ -91,17 +98,19 @@ def convert_diffusion_lora(part: str):
     opset = model.opset_import.add()
     opset.version = 14
 
+    onnx_path = path.join(context.cache_path, f"lora-{component}.onnx")
+    tensor_path = path.join(context.cache_path, f"lora-{component}.tensors")
     save_model(
         model,
-        f"/tmp/lora-{part}.onnx",
+        onnx_path,
         save_as_external_data=True,
         all_tensors_to_one_file=True,
-        location=f"/tmp/lora-{part}.tensors",
+        location=tensor_path,
     )
     logger.info(
         "saved model to %s and tensors to %s",
-        f"/tmp/lora-{part}.onnx",
-        f"/tmp/lora-{part}.tensors",
+        onnx_path,
+        tensor_path,
     )
 
 
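The commit replaces the hard-coded /tmp paths with paths under context.cache_path, keeping the external-data save so large tensors land in a single side file. One caveat worth hedging: onnx documents save_model's location as a path for the external tensor file resolved relative to the model, so a bare filename is the conservative choice. A minimal sketch of the same save pattern, with model and cache_path assumed:

# Saving a large ONNX model with all tensors in one external file.
# A sketch: `model` is an onnx.ModelProto; cache_path is assumed.
# location is resolved relative to the model file, hence the bare filename.
from os import path

from onnx import save_model

def save_with_external_data(model, cache_path: str, component: str):
    save_model(
        model,
        path.join(cache_path, f"lora-{component}.onnx"),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"lora-{component}.tensors",
    )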
@@ -124,40 +133,59 @@ def merge_lora():
 
         for key in lora_model.keys():
             if "lora_down" in key:
-                lora_key = key[:key.index("lora_down")].replace("lora_unet_", "")
+                lora_key = key[: key.index("lora_down")].replace("lora_unet_", "")
                 if lora_key.startswith(base_key):
                     print("down for key:", base_key, lora_key)
 
                     up_key = key.replace("lora_down", "lora_up")
-                    alpha_key = key[:key.index("lora_down")] + 'alpha'
+                    alpha_key = key[: key.index("lora_down")] + "alpha"
 
                     down_weight = lora_model.get_tensor(key).to(dtype=torch.float32)
                     up_weight = lora_model.get_tensor(up_key).to(dtype=torch.float32)
 
                     dim = down_weight.size()[0]
                     alpha = lora_model.get(alpha_key).numpy() or dim
-                    scale = alpha / dim
 
                     np_vals = numpy_helper.to_array(base_node)
                     print(np_vals.shape, up_weight.shape, down_weight.shape)
 
-                    squoze = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                    squoze = (
+                        (
+                            up_weight.squeeze(3).squeeze(2)
+                            @ down_weight.squeeze(3).squeeze(2)
+                        )
+                        .unsqueeze(2)
+                        .unsqueeze(3)
+                    )
                     print(squoze.shape)
 
                     np_vals = np_vals + (alpha * squoze.numpy())
 
                     try:
                         if len(up_weight.size()) == 2:
-                            squoze = (up_weight @ down_weight)
+                            squoze = up_weight @ down_weight
                             print(squoze.shape)
                             np_vals = np_vals + (squoze.numpy() * (alpha / dim))
                         else:
-                            squoze = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                            squoze = (
+                                (
+                                    up_weight.squeeze(3).squeeze(2)
+                                    @ down_weight.squeeze(3).squeeze(2)
+                                )
+                                .unsqueeze(2)
+                                .unsqueeze(3)
+                            )
                             print(squoze.shape)
                             np_vals = np_vals + (alpha * squoze.numpy())
 
                         # retensor = numpy_helper.from_array(np_vals, base_node.name)
-                        retensor = helper.make_tensor(base_node.name, base_node.data_type, base_node.dim, np_vals, raw=True)
+                        retensor = helper.make_tensor(
+                            base_node.name,
+                            base_node.data_type,
+                            base_node.dim,
+                            np_vals,
+                            raw=True,
+                        )
                         print(retensor)
 
                         # TypeError: does not support assignment
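The arithmetic in merge_lora is the standard LoRA reconstruction: the delta for a base weight W is up @ down, scaled by alpha / dim, with 1x1 conv weights squeezed to 2D for the matmul and then restored. Note the diff's else branch multiplies by alpha alone, while the usual convention applies alpha / rank in both branches; this sketch follows the convention, on plain tensors and with illustrative names:

# Standard LoRA weight update: W' = W + (alpha / rank) * (up @ down).
# A sketch on plain tensors; names are illustrative, not from the diff,
# and alpha / rank is applied in both branches (the diff's 4D branch
# scales by alpha alone).
import torch

def apply_lora(
    base: torch.Tensor, up: torch.Tensor, down: torch.Tensor, alpha: float
) -> torch.Tensor:
    rank = down.size(0)  # the diff reads this as down_weight.size()[0]
    if up.dim() == 2:
        delta = up @ down
    else:
        # 1x1 conv weights: drop the trailing spatial dims, matmul, restore
        delta = (
            up.squeeze(3).squeeze(2) @ down.squeeze(3).squeeze(2)
        ).unsqueeze(2).unsqueeze(3)
    return base + (alpha / rank) * delta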
@@ -167,7 +195,6 @@ def merge_lora():
                     except Exception as e:
                         print(e)
 
-
         if retensor is None:
             print("no lora found for key", base_key)
             lora_nodes.append(base_node)
@@ -179,7 +206,6 @@ def merge_lora():
     onnx.checker.check_model(base_model)
 
 
-
 if __name__ == "__main__":
     convert_diffusion_lora("unet")
     convert_diffusion_lora("text_encoder")
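After this commit, convert_diffusion_lora takes a ConversionContext first, but the __main__ block above still passes only the component name. A hedged sketch of what a matching call inside this module might look like; how ConversionContext is constructed is an assumption, as its arguments are not shown in this diff:

# A sketch of a __main__ block matching the new signature; constructing
# ConversionContext with no arguments is an assumption.
if __name__ == "__main__":
    context = ConversionContext()
    convert_diffusion_lora(context, "unet")
    convert_diffusion_lora(context, "text_encoder")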