
add none option to inversion menu

Sean Sube 2023-02-21 23:50:27 -06:00
parent 7ad8385c5b
commit e8b5ff250d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 565 additions and 244 deletions


@@ -22,6 +22,28 @@
"name": "diffusion-unstable-ink-dream-v6",
"source": "civitai://5796",
"format": "safetensors"
},
{
"name": "stable-diffusion-onnx-v1-5",
"source": "runwayml/stable-diffusion-v1-5",
"inversions": [
{
"name": "line-art",
"source": "sd-concepts-library/line-art"
},
{
"name": "cubex",
"source": "sd-concepts-library/cubex"
},
{
"name": "birb",
"source": "sd-concepts-library/birb-style"
},
{
"name": "minecraft",
"source": "sd-concepts-library/minecraft-concept-art"
}
]
}
],
"correction": [],


@@ -10,8 +10,8 @@ from jsonschema import ValidationError, validate
from yaml import safe_load
from .correction_gfpgan import convert_correction_gfpgan
from .diffusion.original import convert_diffusion_original
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.original import convert_diffusion_original
from .diffusion.textual_inversion import convert_diffusion_textual_inversion
from .upscale_resrgan import convert_upscale_resrgan
from .utils import (
@@ -233,8 +233,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
for inversion in model.get("inversions", []):
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
convert_diffusion_textual_inversion(ctx, inversion_name, model["source"], inversion_source)
inversion_source = fetch_model(
ctx, f"{name}-inversion-{inversion_name}", inversion_source
)
convert_diffusion_textual_inversion(
ctx, inversion_name, model["source"], inversion_source
)
except Exception as e:
logger.error("error converting diffusion model %s: %s", name, e)
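
To make the new loop concrete, here is a small sketch of the names it produces for the cubex entry above. Only the f-string from this hunk and the inversion-{name} destination used by the textual inversion converter further down are assumed:

model_name = "stable-diffusion-onnx-v1-5"
inversion_name = "cubex"

# Cache name passed to fetch_model, per the f-string in this hunk.
fetch_name = f"{model_name}-inversion-{inversion_name}"
assert fetch_name == "stable-diffusion-onnx-v1-5-inversion-cubex"

# Output directory used by convert_diffusion_textual_inversion (see below).
dest_path = f"inversion-{inversion_name}"
assert dest_path == "inversion-cubex"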


@@ -1,94 +1,105 @@
from numpy import ndarray
from onnx import TensorProto, helper, load, numpy_helper, ModelProto, save_model
from typing import Dict, List, Tuple
from logging import getLogger
from typing import List, Tuple
from numpy import ndarray
from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model
logger = getLogger(__name__)
def load_lora(filename: str):
model = load(filename)
model = load(filename)
for weight in model.graph.initializer:
# print(weight.name, numpy_helper.to_array(weight).shape)
pass
for weight in model.graph.initializer:
# print(weight.name, numpy_helper.to_array(weight).shape)
pass
return model
return model
def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]) -> List[Tuple[TensorProto, ndarray]]:
total = 1 + sum(alphas)
def blend_loras(
base: ModelProto, weights: List[ModelProto], alphas: List[float]
) -> List[Tuple[TensorProto, ndarray]]:
total = 1 + sum(alphas)
results = []
results = []
for base_node in base.graph.initializer:
logger.info("blending initializer node %s", base_node.name)
base_weights = numpy_helper.to_array(base_node).copy()
for base_node in base.graph.initializer:
logger.info("blending initializer node %s", base_node.name)
base_weights = numpy_helper.to_array(base_node).copy()
for weight, alpha in zip(weights, alphas):
weight_node = next(iter([f for f in weight.graph.initializer if f.name == base_node.name]), None)
for weight, alpha in zip(weights, alphas):
weight_node = next(
iter([f for f in weight.graph.initializer if f.name == base_node.name]),
None,
)
if weight_node is not None:
base_weights += numpy_helper.to_array(weight_node) * alpha
else:
logger.warning("missing weights: %s in %s", base_node.name, weight.doc_string)
if weight_node is not None:
base_weights += numpy_helper.to_array(weight_node) * alpha
else:
logger.warning(
"missing weights: %s in %s", base_node.name, weight.doc_string
)
results.append((base_node, base_weights / total))
results.append((base_node, base_weights / total))
return results
return results
def convert_diffusion_lora(part: str):
lora_weights = [
f"diffusion-lora-jack/{part}/model.onnx",
f"diffusion-lora-taters/{part}/model.onnx",
]
lora_weights = [
f"diffusion-lora-jack/{part}/model.onnx",
f"diffusion-lora-taters/{part}/model.onnx",
]
base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
weights = [load_lora(f) for f in lora_weights]
alphas = [1 / len(weights)] * len(weights)
logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
weights = [load_lora(f) for f in lora_weights]
alphas = [1 / len(weights)] * len(weights)
logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)
result = blend_loras(base, weights, alphas)
logger.info("blended result keys: %s", len(result))
result = blend_loras(base, weights, alphas)
logger.info("blended result keys: %s", len(result))
del weights
del alphas
del weights
del alphas
tensors = []
for node, tensor in result:
logger.info("remaking tensor for %s", node.name)
tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))
tensors = []
for node, tensor in result:
logger.info("remaking tensor for %s", node.name)
tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))
del result
del result
graph = helper.make_graph(
base.graph.node,
base.graph.name,
base.graph.input,
base.graph.output,
tensors,
base.graph.doc_string,
base.graph.value_info,
base.graph.sparse_initializer,
)
model = helper.make_model(graph)
graph = helper.make_graph(
base.graph.node,
base.graph.name,
base.graph.input,
base.graph.output,
tensors,
base.graph.doc_string,
base.graph.value_info,
base.graph.sparse_initializer,
)
model = helper.make_model(graph)
del model.opset_import[:]
opset = model.opset_import.add()
opset.version = 14
del model.opset_import[:]
opset = model.opset_import.add()
opset.version = 14
save_model(
model,
f"/tmp/lora-{part}.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"/tmp/lora-{part}.tensors",
)
logger.info("saved model to %s and tensors to %s", f"/tmp/lora-{part}.onnx", f"/tmp/lora-{part}.tensors")
save_model(
model,
f"/tmp/lora-{part}.onnx",
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"/tmp/lora-{part}.tensors",
)
logger.info(
"saved model to %s and tensors to %s",
f"/tmp/lora-{part}.onnx",
f"/tmp/lora-{part}.tensors",
)
if __name__ == "__main__":
convert_diffusion_lora("unet")
convert_diffusion_lora("text_encoder")
convert_diffusion_lora("unet")
convert_diffusion_lora("text_encoder")


@@ -53,13 +53,13 @@ from transformers import (
CLIPVisionConfig,
)
from .diffusers import convert_diffusion_diffusers
from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
from .diffusers import convert_diffusion_diffusers
logger = getLogger(__name__)
class TrainingConfig():
class TrainingConfig:
"""
From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
"""
@@ -184,7 +184,9 @@ class TrainingConfig():
backup_dir = os.path.join(models_path, "backups")
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
config_file = os.path.join(models_path, "backups", f"db_config_{self.revision}.json")
config_file = os.path.join(
models_path, "backups", f"db_config_{self.revision}.json"
)
with open(config_file, "w") as outfile:
json.dump(self.__dict__, outfile, indent=4)
@@ -238,7 +240,9 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = shave_segments(
new_item, n_shave_prefix_segments=n_shave_prefix_segments
)
mapping.append({"old": old_item, "new": new_item})
@@ -253,7 +257,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = shave_segments(
new_item, n_shave_prefix_segments=n_shave_prefix_segments
)
mapping.append({"old": old_item, "new": new_item})
@@ -295,7 +301,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = shave_segments(
new_item, n_shave_prefix_segments=n_shave_prefix_segments
)
mapping.append({"old": old_item, "new": new_item})
@@ -303,7 +311,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
@@ -312,7 +325,9 @@ def assign_to_checkpoint(
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
assert isinstance(
paths, list
), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
@@ -324,7 +339,9 @@ def assign_to_checkpoint(
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
old_tensor = old_tensor.reshape(
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
)
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
@@ -335,7 +352,10 @@ def assign_to_checkpoint(
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
if (
attention_paths_to_split is not None
and new_path in attention_paths_to_split
):
continue
# Global renaming happens here
@@ -373,19 +393,29 @@ def create_unet_diffusers_config(original_config, image_size: int):
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [
unet_params.model_channels * mult for mult in unet_params.channel_mult
]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = (
"CrossAttnDownBlock2D"
if resolution in unet_params.attention_resolutions
else "DownBlock2D"
)
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = (
"CrossAttnUpBlock2D"
if resolution in unet_params.attention_resolutions
else "UpBlock2D"
)
up_block_types.append(block_type)
resolution //= 2
@@ -393,7 +423,9 @@ def create_unet_diffusers_config(original_config, image_size: int):
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params.use_linear_in_transformer
if "use_linear_in_transformer" in unet_params
else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
@@ -482,7 +514,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
else:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
@@ -493,33 +527,53 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
"time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
"time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
"time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
"conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
"conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
"conv_norm_out.weight": unet_state_dict["out.0.weight"],
"conv_norm_out.bias": unet_state_dict["out.0.bias"],
"conv_out.weight": unet_state_dict["out.2.weight"],
"conv_out.bias": unet_state_dict["out.2.bias"]}
new_checkpoint = {
"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
"time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
"time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
"time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
"conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
"conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
"conv_norm_out.weight": unet_state_dict["out.0.weight"],
"conv_norm_out.bias": unet_state_dict["out.0.bias"],
"conv_out.weight": unet_state_dict["out.2.weight"],
"conv_out.bias": unet_state_dict["out.2.bias"],
}
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
num_input_blocks = len(
{
".".join(layer.split(".")[:2])
for layer in unet_state_dict
if "input_blocks" in layer
}
)
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
num_middle_blocks = len(
{
".".join(layer.split(".")[:2])
for layer in unet_state_dict
if "middle_block" in layer
}
)
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
num_output_blocks = len(
{
".".join(layer.split(".")[:2])
for layer in unet_state_dict
if "output_blocks" in layer
}
)
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
@@ -530,29 +584,45 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
key
for key in input_blocks[i]
if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
new_checkpoint[
f"down_blocks.{block_id}.downsamplers.0.conv.weight"
] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
new_checkpoint[
f"down_blocks.{block_id}.downsamplers.0.conv.bias"
] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
paths,
new_checkpoint,
unet_state_dict,
additional_replacements=[meta_path],
config=config,
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
meta_path = {
"old": f"input_blocks.{i}.1",
"new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
paths,
new_checkpoint,
unet_state_dict,
additional_replacements=[meta_path],
config=config,
)
resnet_0 = middle_blocks[0]
@@ -568,7 +638,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
attentions_paths,
new_checkpoint,
unet_state_dict,
additional_replacements=[meta_path],
config=config,
)
for i in range(num_output_blocks):
@@ -586,25 +660,36 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
attentions = [
key for key in output_blocks[i] if f"output_blocks.{i}.1" in key
]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
paths,
new_checkpoint,
unet_state_dict,
additional_replacements=[meta_path],
config=config,
)
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
index = list(output_block_list.values()).index(
["conv.bias", "conv.weight"]
)
new_checkpoint[
f"up_blocks.{block_id}.upsamplers.0.conv.weight"
] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
new_checkpoint[
f"up_blocks.{block_id}.upsamplers.0.conv.bias"
] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
@@ -617,13 +702,27 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
paths,
new_checkpoint,
unet_state_dict,
additional_replacements=[meta_path],
config=config,
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
resnet_0_paths = renew_resnet_paths(
output_block_layers, n_shave_prefix_segments=1
)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_path = ".".join(
[
"up_blocks",
str(block_id),
"resnets",
str(layer_in_block_id),
path["new"],
]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
@@ -645,49 +744,75 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
else:
vae_state_dict[key] = checkpoint.get(key)
new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
"encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
"encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
"encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
"encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
"encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
"decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
"decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
"decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
"decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
"decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
"decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
"quant_conv.weight": vae_state_dict["quant_conv.weight"],
"quant_conv.bias": vae_state_dict["quant_conv.bias"],
"post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
"post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]}
new_checkpoint = {
"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
"encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
"encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
"encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
"encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
"encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
"decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
"decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
"decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
"decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
"decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
"decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
"quant_conv.weight": vae_state_dict["quant_conv.weight"],
"quant_conv.bias": vae_state_dict["quant_conv.bias"],
"post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
"post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"],
}
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
num_down_blocks = len(
{
".".join(layer.split(".")[:3])
for layer in vae_state_dict
if "encoder.down" in layer
}
)
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
num_up_blocks = len(
{
".".join(layer.split(".")[:3])
for layer in vae_state_dict
if "decoder.up" in layer
}
)
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
resnets = [
key
for key in down_blocks[i]
if f"down.{i}" in key and f"down.{i}.downsample" not in key
]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
new_checkpoint[
f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
new_checkpoint[
f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
@@ -696,31 +821,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
key
for key in up_blocks[block_id]
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
new_checkpoint[
f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
new_checkpoint[
f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
@@ -729,12 +874,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
assign_to_checkpoint(
paths,
new_checkpoint,
vae_state_dict,
additional_replacements=[meta_path],
config=config,
)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
@@ -769,14 +926,16 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
for i, hf_layer in enumerate(hf_layers):
if i != 0:
i += i
pt_layer = pt_layers[i: i + 2]
pt_layer = pt_layers[i : i + 2]
_copy_layer(hf_layer, pt_layer)
hf_model = LDMBertModel(config).eval()
# copy embeds
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
hf_model.model.embed_positions.weight.data = (
checkpoint.transformer.pos_emb.emb.weight
)
# copy layer norm
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
@@ -799,9 +958,13 @@ def convert_ldm_clip_checkpoint(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
if key.find("text_model") == -1:
text_model_dict["text_model." + key[len("cond_stage_model.transformer."):]] = checkpoint[key]
text_model_dict[
"text_model." + key[len("cond_stage_model.transformer.") :]
] = checkpoint[key]
else:
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
text_model_dict[
key[len("cond_stage_model.transformer.") :]
] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
@@ -809,12 +972,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
textenc_conversion_lst = [
('cond_stage_model.model.positional_embedding',
"text_model.embeddings.position_embedding.weight"),
('cond_stage_model.model.token_embedding.weight',
"text_model.embeddings.token_embedding.weight"),
('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'),
('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias')
(
"cond_stage_model.model.positional_embedding",
"text_model.embeddings.position_embedding.weight",
),
(
"cond_stage_model.model.token_embedding.weight",
"text_model.embeddings.token_embedding.weight",
),
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
@@ -827,8 +994,14 @@ textenc_transformer_conversion_lst = [
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
(
"token_embedding.weight",
"transformer.text_model.embeddings.token_embedding.weight",
),
(
"positional_embedding",
"transformer.text_model.embeddings.position_embedding.weight",
),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
@@ -844,7 +1017,9 @@ def convert_paint_by_example_checkpoint(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
key
]
# load clip vision
model.model.load_state_dict(text_model_dict)
@@ -902,19 +1077,25 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder"
)
keys = list(checkpoint.keys())
text_model_dict = {}
if 'cond_stage_model.model.text_projection' in checkpoint:
d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
if "cond_stage_model.model.text_projection" in checkpoint:
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
else:
logger.debug("no projection shape found, setting to 1024")
d_model = 1024
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
text_model_dict[
"text_model.embeddings.position_ids"
] = text_model.text_model.embeddings.get_buffer("position_ids")
for key in keys:
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
if (
"resblocks.23" in key
): # Diffusers drops the final layer and only uses the penultimate layer
continue
if key in textenc_conversion_map:
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
@@ -922,18 +1103,34 @@ def convert_open_clip_checkpoint(checkpoint):
new_key = key[len("cond_stage_model.model.transformer.") :]
if new_key.endswith(".in_proj_weight"):
new_key = new_key[: -len(".in_proj_weight")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
new_key = textenc_pattern.sub(
lambda m: protected[re.escape(m.group(0))], new_key
)
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][
:d_model, :
]
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][
d_model : d_model * 2, :
]
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][
d_model * 2 :, :
]
elif new_key.endswith(".in_proj_bias"):
new_key = new_key[: -len(".in_proj_bias")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
new_key = textenc_pattern.sub(
lambda m: protected[re.escape(m.group(0))], new_key
)
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][
d_model : d_model * 2
]
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][
d_model * 2 :
]
else:
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
new_key = textenc_pattern.sub(
lambda m: protected[re.escape(m.group(0))], new_key
)
text_model_dict[new_key] = checkpoint[key]
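
One detail worth pulling out of convert_open_clip_checkpoint: OpenCLIP stores the attention projections as a single fused in_proj tensor, and the slicing above splits it into separate q/k/v projections along the first dimension. A small torch sketch of that split, with an illustrative d_model:

import torch

d_model = 8  # illustrative; the real value comes from the text_projection shape
in_proj_weight = torch.randn(3 * d_model, d_model)  # fused q/k/v rows

q = in_proj_weight[:d_model, :]
k = in_proj_weight[d_model : d_model * 2, :]
v = in_proj_weight[d_model * 2 :, :]

assert q.shape == k.shape == v.shape == (d_model, d_model)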
@@ -992,7 +1189,15 @@ def download_model(db_config: TrainingConfig, token):
siblings = repo_info.siblings
diffusion_dirs = ["text_encoder", "unet", "vae", "tokenizer", "scheduler", "feature_extractor", "safety_checker"]
diffusion_dirs = [
"text_encoder",
"unet",
"vae",
"tokenizer",
"scheduler",
"feature_extractor",
"safety_checker",
]
config_file = None
model_index = None
model_files = []
@@ -1031,9 +1236,9 @@ def download_model(db_config: TrainingConfig, token):
(x for x in model_files if "nonema" in x),
next(
(x for x in model_files if ".safetensors" in x),
model_files[0] if model_files else None
)
)
model_files[0] if model_files else None,
),
),
)
files_to_fetch = None
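
The nested next(...) calls above implement a three-step fallback when choosing which checkpoint file to download: prefer a file with "nonema" in its name, then any .safetensors file, then the first file found (or None if the list is empty). A self-contained illustration:

model_files = ["v1-5.ckpt", "v1-5.safetensors", "v1-5-nonema.ckpt"]

choice = next(
    (x for x in model_files if "nonema" in x),
    next(
        (x for x in model_files if ".safetensors" in x),
        model_files[0] if model_files else None,
    ),
)
assert choice == "v1-5-nonema.ckpt"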
@@ -1061,7 +1266,7 @@ def download_model(db_config: TrainingConfig, token):
filename=repo_file,
repo_type="model",
revision=repo_info.sha,
token=token
token=token,
)
replace_symlinks(out, db_config.model_dir)
dest = None
@@ -1074,7 +1279,9 @@ def download_model(db_config: TrainingConfig, token):
for diffusion_dir in diffusion_dirs:
if diffusion_dir in out:
out_model = db_config.pretrained_model_name_or_path
dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
dest = os.path.join(
db_config.pretrained_model_name_or_path, diffusion_dir
)
if not dest:
if ".ckpt" in out or ".safetensors" in out:
dest = os.path.join(db_config.model_dir, "src")
@@ -1095,9 +1302,11 @@ def get_config_path(
model_version: str = "v1",
train_type: str = "default",
config_base_name: str = "training",
prediction_type: str = "epsilon"
prediction_type: str = "epsilon",
):
train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
train_type = (
f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
)
parts = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
@@ -1106,21 +1315,20 @@ def get_config_path(
"..",
"models",
"configs",
f"{model_version}-{config_base_name}-{train_type}.yaml"
f"{model_version}-{config_base_name}-{train_type}.yaml",
)
return os.path.abspath(parts)
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
def get_config_file(
train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None
):
if config_file is not None:
return config_file
config_base_name = "training"
model_versions = {
"v1": "v1",
"v2": "v2"
}
model_versions = {"v1": "v1", "v2": "v2"}
train_types = {
"default": "default",
"unfrozen": "unfrozen",
@@ -1134,7 +1342,9 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", c
else:
model_train_type = train_types["default"]
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
return get_config_path(
model_version_name, model_train_type, config_base_name, prediction_type
)
def extract_checkpoint(
@@ -1182,8 +1392,9 @@ def extract_checkpoint(
msg = None
# Create empty config
db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
src=checkpoint_file)
db_config = TrainingConfig(
ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file
)
original_config_file = None
@@ -1221,9 +1432,13 @@ def extract_checkpoint(
else:
prediction_type = "epsilon"
original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)
original_config_file = get_config_file(
train_unfrozen, v2, prediction_type, config_file=config_file
)
logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
logger.info(
f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}"
)
db_config.resolution = image_size
db_config.lifetime_revision = revision
db_config.epoch = epoch
@@ -1233,12 +1448,18 @@ def extract_checkpoint(
# Use existing YAML if present
if checkpoint_file is not None:
config_check = checkpoint_file.replace(".ckpt", ".yaml") if ".ckpt" in checkpoint_file else checkpoint_file.replace(".safetensors", ".yaml")
config_check = (
checkpoint_file.replace(".ckpt", ".yaml")
if ".ckpt" in checkpoint_file
else checkpoint_file.replace(".safetensors", ".yaml")
)
if os.path.exists(config_check):
original_config_file = config_check
if original_config_file is None or not os.path.exists(original_config_file):
logger.warning("unable to select a config file: %s" % (original_config_file))
logger.warning(
"unable to select a config file: %s" % (original_config_file)
)
return
logger.debug("trying to load: %s", original_config_file)
@@ -1281,7 +1502,9 @@ def extract_checkpoint(
# Convert the UNet2DConditionModel model.
logger.info("converting UNet")
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config = create_unet_diffusers_config(
original_config, image_size=image_size
)
unet_config["upcast_attention"] = upcast_attention
unet = UNet2DConditionModel(**unet_config)
@@ -1297,22 +1520,30 @@ def extract_checkpoint(
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
if vae_file is None:
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
else:
vae_file = os.path.join(ctx.model_path, vae_file)
logger.debug("loading custom VAE: %s", vae_file)
vae_checkpoint = load_tensor(vae_file, map_location=map_location)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
vae_checkpoint, vae_config, first_stage=False
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model.
logger.info("converting text encoder")
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
text_model_type = original_config.model.params.cond_stage_config.target.split(
"."
)[-1]
if text_model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
tokenizer = CLIPTokenizer.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="tokenizer"
)
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@@ -1326,7 +1557,9 @@ def extract_checkpoint(
elif text_model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
@@ -1338,8 +1571,12 @@ def extract_checkpoint(
elif text_model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker"
)
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
@@ -1347,16 +1584,24 @@ def extract_checkpoint(
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor
feature_extractor=feature_extractor,
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
scheduler=scheduler)
pipe = LDMTextToImagePipeline(
vqvae=vae,
bert=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
)
except Exception:
logger.error("exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
logger.error(
"exception setting up output: %s",
traceback.format_exception(*sys.exc_info()),
)
pipe = None
if pipe is None or db_config is None:
@@ -1371,12 +1616,18 @@ def extract_checkpoint(
scheduler = db_config.scheduler
required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
if original_config_file is not None and os.path.exists(original_config_file):
logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
logger.debug(
"copying original config: %s -> %s",
original_config_file,
db_config.model_dir,
)
shutil.copy(original_config_file, db_config.model_dir)
basename = os.path.basename(original_config_file)
new_ex_path = os.path.join(db_config.model_dir, basename)
new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
logger.debug(
"copying model config to new name: %s -> %s", new_ex_path, new_name
)
if os.path.exists(new_name):
os.remove(new_name)
os.rename(new_ex_path, new_name)
@@ -1407,7 +1658,9 @@ def convert_diffusion_original(
source = source or model["source"]
dest = os.path.join(ctx.model_path, name)
logger.info("converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
logger.info(
"converting original Diffusers checkpoint %s: %s -> %s", name, source, dest
)
if os.path.exists(dest):
logger.info("ONNX pipeline already exists, skipping")
@@ -1420,8 +1673,18 @@ def convert_diffusion_original(
if os.path.exists(torch_path):
logger.info("torch pipeline already exists, reusing: %s", torch_path)
else:
logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae"))
logger.info(
"converting original Diffusers check to Torch model: %s -> %s",
source,
torch_path,
)
extract_checkpoint(
ctx,
torch_name,
source,
config_file=model.get("config"),
vae_file=model.get("vae"),
)
logger.info("converted original Diffusers checkpoint to Torch model")
# VAE has already been converted and will confuse HF repo lookup


@@ -1,24 +1,29 @@
from os import makedirs, path
from huggingface_hub.file_download import hf_hub_download
from transformers import CLIPTokenizer, CLIPTextModel
from torch.onnx import export
from logging import getLogger
from ..utils import ConversionContext
from os import makedirs, path
import torch
from huggingface_hub.file_download import hf_hub_download
from torch.onnx import export
from transformers import CLIPTextModel, CLIPTokenizer
from ..utils import ConversionContext
logger = getLogger(__name__)
def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
def convert_diffusion_textual_inversion(
context: ConversionContext, name: str, base_model: str, inversion: str
):
dest_path = path.join(context.model_path, f"inversion-{name}")
logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path)
logger.info(
"converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
)
if path.exists(dest_path):
logger.info("ONNX model already exists, skipping.")
return
makedirs(path.join(dest_path, "text_encoder"))
makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)
embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
@@ -71,9 +76,7 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
export(
text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
(
text_input.input_ids.to(dtype=torch.int32)
),
(text_input.input_ids.to(dtype=torch.int32)),
f=path.join(dest_path, "text_encoder", "model.onnx"),
input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
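
The hunk above only shows the imports and the final ONNX export call. The middle of this file is not part of the diff, but the usual diffusers textual-inversion recipe (a sketch of the standard approach, not necessarily this file's exact code) loads learned_embeds.bin and appends the trained token to the CLIP embeddings before export:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# The base_model value is an assumption, taken from the extras.json entry above.
base_model = "runwayml/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")

loaded = torch.load("learned_embeds.bin", map_location="cpu")
token, embeds = next(iter(loaded.items()))

# Register the new token and grow the embedding matrix by one row.
tokenizer.add_tokens(token)
text_encoder.resize_token_embeddings(len(tokenizer))
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds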


@@ -17,9 +17,9 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LMSDiscreteScheduler,
OnnxRuntimeModel,
PNDMScheduler,
StableDiffusionPipeline,
OnnxRuntimeModel,
)
try:


@@ -7,7 +7,7 @@ import { useStore } from 'zustand';
import { STALE_TIME } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js';
import { INVERSION_LABELS, MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js';
import { QueryList } from '../input/QueryList.js';
export function ModelControl() {
@@ -56,12 +56,13 @@ export function ModelControl() {
/>
<QueryList
id='inversion'
labels={MODEL_LABELS}
labels={INVERSION_LABELS}
name='Textual Inversion'
query={{
result: models,
selector: (result) => result.inversion,
}}
showEmpty={true}
value={params.inversion}
onChange={(inversion) => {
setModel({


@@ -20,6 +20,7 @@ export interface QueryListProps<T> {
value: string;
query: QueryListComplete | QueryListFilter<T>;
showEmpty?: boolean;
onChange?: (value: string) => void;
}
@@ -28,17 +29,25 @@ export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): que
return Reflect.has(query, 'selector');
}
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>, showEmpty: boolean): Array<string> {
if (hasFilter(query)) {
const data = mustExist(query.result.data);
return (query as QueryListFilter<unknown>).selector(data);
const selected = (query as QueryListFilter<unknown>).selector(data);
if (showEmpty) {
return ['', ...selected];
}
return selected;
} else {
return mustExist(query.result.data);
const data = Array.from(mustExist(query.result.data));
if (showEmpty) {
return ['', ...data];
}
return data;
}
}
export function QueryList<T>(props: QueryListProps<T>) {
const { labels, query, value } = props;
const { labels, query, showEmpty = false, value } = props;
const { result } = query;
function firstValidValue(): string {
@@ -52,7 +61,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
// update state when previous selection was invalid: https://github.com/ssube/onnx-web/issues/120
useEffect(() => {
if (result.status === 'success' && doesExist(result.data) && doesExist(props.onChange)) {
const data = filterQuery(query);
const data = filterQuery(query, showEmpty);
if (data.includes(value) === false) {
props.onChange(data[0]);
}
@@ -77,7 +86,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
// else: success
const labelID = `query-list-${props.id}-labels`;
const data = filterQuery(query);
const data = filterQuery(query, showEmpty);
return <FormControl>
<InputLabel id={labelID}>{props.name}</InputLabel>


@@ -32,6 +32,14 @@ export const MODEL_LABELS: Record<string, string> = {
'diffusion-unstable-ink-dream-v6': 'Unstable Ink Dream v6',
};
export const INVERSION_LABELS: Record<string, string> = {
'': 'None',
'inversion-cubex': 'Cubex',
'inversion-birb': 'Birb Style',
'inversion-line-art': 'Line Art',
'inversion-minecraft': 'Minecraft Concept',
};
export const PLATFORM_LABELS: Record<string, string> = {
amd: 'AMD GPU',
// eslint-disable-next-line id-blacklist