
add none option to inversion menu

Sean Sube 2023-02-21 23:50:27 -06:00
parent 7ad8385c5b
commit e8b5ff250d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 565 additions and 244 deletions

View File

@@ -22,6 +22,28 @@
       "name": "diffusion-unstable-ink-dream-v6",
       "source": "civitai://5796",
       "format": "safetensors"
+    },
+    {
+      "name": "stable-diffusion-onnx-v1-5",
+      "source": "runwayml/stable-diffusion-v1-5",
+      "inversions": [
+        {
+          "name": "line-art",
+          "source": "sd-concepts-library/line-art"
+        },
+        {
+          "name": "cubex",
+          "source": "sd-concepts-library/cubex"
+        },
+        {
+          "name": "birb",
+          "source": "sd-concepts-library/birb-style"
+        },
+        {
+          "name": "minecraft",
+          "source": "sd-concepts-library/minecraft-concept-art"
+        }
+      ]
     }
   ],
   "correction": [],

View File

@@ -10,8 +10,8 @@ from jsonschema import ValidationError, validate
 from yaml import safe_load

 from .correction_gfpgan import convert_correction_gfpgan
-from .diffusion.original import convert_diffusion_original
 from .diffusion.diffusers import convert_diffusion_diffusers
+from .diffusion.original import convert_diffusion_original
 from .diffusion.textual_inversion import convert_diffusion_textual_inversion
 from .upscale_resrgan import convert_upscale_resrgan
 from .utils import (
@@ -233,8 +233,12 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 for inversion in model.get("inversions", []):
                     inversion_name = inversion["name"]
                     inversion_source = inversion["source"]
-                    inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source)
-                    convert_diffusion_textual_inversion(ctx, inversion_name, model["source"], inversion_source)
+                    inversion_source = fetch_model(
+                        ctx, f"{name}-inversion-{inversion_name}", inversion_source
+                    )
+                    convert_diffusion_textual_inversion(
+                        ctx, inversion_name, model["source"], inversion_source
+                    )
             except Exception as e:
                 logger.error("error converting diffusion model %s: %s", name, e)

View File

@@ -1,94 +1,105 @@
-from numpy import ndarray
-from onnx import TensorProto, helper, load, numpy_helper, ModelProto, save_model
-from typing import Dict, List, Tuple
 from logging import getLogger
+from typing import List, Tuple
+
+from numpy import ndarray
+from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model

 logger = getLogger(__name__)


 def load_lora(filename: str):
     model = load(filename)

     for weight in model.graph.initializer:
         # print(weight.name, numpy_helper.to_array(weight).shape)
         pass

     return model


-def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]) -> List[Tuple[TensorProto, ndarray]]:
+def blend_loras(
+    base: ModelProto, weights: List[ModelProto], alphas: List[float]
+) -> List[Tuple[TensorProto, ndarray]]:
     total = 1 + sum(alphas)
     results = []

     for base_node in base.graph.initializer:
         logger.info("blending initializer node %s", base_node.name)
         base_weights = numpy_helper.to_array(base_node).copy()

         for weight, alpha in zip(weights, alphas):
-            weight_node = next(iter([f for f in weight.graph.initializer if f.name == base_node.name]), None)
+            weight_node = next(
+                iter([f for f in weight.graph.initializer if f.name == base_node.name]),
+                None,
+            )

             if weight_node is not None:
                 base_weights += numpy_helper.to_array(weight_node) * alpha
             else:
-                logger.warning("missing weights: %s in %s", base_node.name, weight.doc_string)
+                logger.warning(
+                    "missing weights: %s in %s", base_node.name, weight.doc_string
+                )

         results.append((base_node, base_weights / total))

     return results


 def convert_diffusion_lora(part: str):
     lora_weights = [
         f"diffusion-lora-jack/{part}/model.onnx",
         f"diffusion-lora-taters/{part}/model.onnx",
     ]

     base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx")
     weights = [load_lora(f) for f in lora_weights]
     alphas = [1 / len(weights)] * len(weights)
     logger.info("blending LoRAs with alphas: %s, %s", weights, alphas)

     result = blend_loras(base, weights, alphas)
     logger.info("blended result keys: %s", len(result))

     del weights
     del alphas

     tensors = []
     for node, tensor in result:
         logger.info("remaking tensor for %s", node.name)
         tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor))

     del result

     graph = helper.make_graph(
         base.graph.node,
         base.graph.name,
         base.graph.input,
         base.graph.output,
         tensors,
         base.graph.doc_string,
         base.graph.value_info,
         base.graph.sparse_initializer,
     )
     model = helper.make_model(graph)

     del model.opset_import[:]
     opset = model.opset_import.add()
     opset.version = 14

     save_model(
         model,
         f"/tmp/lora-{part}.onnx",
         save_as_external_data=True,
         all_tensors_to_one_file=True,
         location=f"/tmp/lora-{part}.tensors",
     )
-    logger.info("saved model to %s and tensors to %s", f"/tmp/lora-{part}.onnx", f"/tmp/lora-{part}.tensors")
+    logger.info(
+        "saved model to %s and tensors to %s",
+        f"/tmp/lora-{part}.onnx",
+        f"/tmp/lora-{part}.tensors",
+    )


 if __name__ == "__main__":
     convert_diffusion_lora("unet")
     convert_diffusion_lora("text_encoder")

View File

@@ -53,13 +53,13 @@ from transformers import (
     CLIPVisionConfig,
 )

-from .diffusers import convert_diffusion_diffusers
 from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name
+from .diffusers import convert_diffusion_diffusers

 logger = getLogger(__name__)


-class TrainingConfig():
+class TrainingConfig:
     """
     From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py
     """
@@ -184,7 +184,9 @@ class TrainingConfig():
         backup_dir = os.path.join(models_path, "backups")
         if not os.path.exists(backup_dir):
             os.makedirs(backup_dir)
-        config_file = os.path.join(models_path, "backups", f"db_config_{self.revision}.json")
+        config_file = os.path.join(
+            models_path, "backups", f"db_config_{self.revision}.json"
+        )
         with open(config_file, "w") as outfile:
             json.dump(self.__dict__, outfile, indent=4)
@@ -238,7 +240,9 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("emb_layers.1", "time_emb_proj")
         new_item = new_item.replace("skip_connection", "conv_shortcut")

-        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+        new_item = shave_segments(
+            new_item, n_shave_prefix_segments=n_shave_prefix_segments
+        )

         mapping.append({"old": old_item, "new": new_item})
@@ -253,7 +257,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
     for old_item in old_list:
         new_item = old_item
         new_item = new_item.replace("nin_shortcut", "conv_shortcut")

-        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+        new_item = shave_segments(
+            new_item, n_shave_prefix_segments=n_shave_prefix_segments
+        )

         mapping.append({"old": old_item, "new": new_item})
@@ -295,7 +301,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
         new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

-        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+        new_item = shave_segments(
+            new_item, n_shave_prefix_segments=n_shave_prefix_segments
+        )

         mapping.append({"old": old_item, "new": new_item})
@@ -303,7 +311,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
 def assign_to_checkpoint(
-    paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+    paths,
+    checkpoint,
+    old_checkpoint,
+    attention_paths_to_split=None,
+    additional_replacements=None,
+    config=None,
 ):
     """
     This does the final conversion step: take locally converted weights and apply a global renaming
@@ -312,7 +325,9 @@ def assign_to_checkpoint(
     Assigns the weights to the new checkpoint.
     """
-    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+    assert isinstance(
+        paths, list
+    ), "Paths should be a list of dicts containing 'old' and 'new' keys."

     # Splits the attention layers into three variables.
     if attention_paths_to_split is not None:
@@ -324,7 +339,9 @@ def assign_to_checkpoint(
             num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

-            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+            old_tensor = old_tensor.reshape(
+                (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
+            )
             query, key, value = old_tensor.split(channels // num_heads, dim=1)

             checkpoint[path_map["query"]] = query.reshape(target_shape)
@@ -335,7 +352,10 @@ def assign_to_checkpoint(
         new_path = path["new"]

         # These have already been assigned
-        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+        if (
+            attention_paths_to_split is not None
+            and new_path in attention_paths_to_split
+        ):
             continue

         # Global renaming happens here
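
The reshape-and-split above is how a fused attention projection is broken into separate query, key, and value tensors. A standalone numpy sketch of the same slicing, using toy shapes (illustrative only, not a real checkpoint layout):

import numpy as np

# Toy fused QKV weight: 3 * channels rows, as in the tensor being split.
channels, num_heads = 8, 2
fused = np.arange(3 * channels * channels, dtype=np.float32).reshape(3 * channels, channels)

# Mirror assign_to_checkpoint: group by head, then split into thirds.
grouped = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
query, key, value = np.split(grouped, 3, axis=1)

# each of query/key/value reshapes back to the (channels, channels) target shape
print(query.reshape(channels, channels).shape)  # (8, 8)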
@@ -373,19 +393,29 @@ def create_unet_diffusers_config(original_config, image_size: int):
     unet_params = original_config.model.params.unet_config.params
     vae_params = original_config.model.params.first_stage_config.params.ddconfig

-    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
+    block_out_channels = [
+        unet_params.model_channels * mult for mult in unet_params.channel_mult
+    ]

     down_block_types = []
     resolution = 1
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
+        block_type = (
+            "CrossAttnDownBlock2D"
+            if resolution in unet_params.attention_resolutions
+            else "DownBlock2D"
+        )
         down_block_types.append(block_type)
         if i != len(block_out_channels) - 1:
             resolution *= 2

     up_block_types = []
     for i in range(len(block_out_channels)):
-        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
+        block_type = (
+            "CrossAttnUpBlock2D"
+            if resolution in unet_params.attention_resolutions
+            else "UpBlock2D"
+        )
         up_block_types.append(block_type)
         resolution //= 2
@@ -393,7 +423,9 @@ def create_unet_diffusers_config(original_config, image_size: int):
     head_dim = unet_params.num_heads if "num_heads" in unet_params else None
     use_linear_projection = (
-        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
+        unet_params.use_linear_in_transformer
+        if "use_linear_in_transformer" in unet_params
+        else False
     )
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
@@ -482,7 +514,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         for key in keys:
             if key.startswith("model.diffusion_model"):
                 flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
+                    flat_ema_key
+                )
     else:
         print(
             "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
@@ -493,33 +527,53 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
             if key.startswith(unet_key):
                 unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

-    new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
-                      "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
-                      "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
-                      "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
-                      "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
-                      "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
-                      "conv_norm_out.weight": unet_state_dict["out.0.weight"],
-                      "conv_norm_out.bias": unet_state_dict["out.0.bias"],
-                      "conv_out.weight": unet_state_dict["out.2.weight"],
-                      "conv_out.bias": unet_state_dict["out.2.bias"]}
+    new_checkpoint = {
+        "time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"],
+        "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"],
+        "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"],
+        "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"],
+        "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"],
+        "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"],
+        "conv_norm_out.weight": unet_state_dict["out.0.weight"],
+        "conv_norm_out.bias": unet_state_dict["out.0.bias"],
+        "conv_out.weight": unet_state_dict["out.2.weight"],
+        "conv_out.bias": unet_state_dict["out.2.bias"],
+    }

     # Retrieves the keys for the input blocks only
-    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    num_input_blocks = len(
+        {
+            ".".join(layer.split(".")[:2])
+            for layer in unet_state_dict
+            if "input_blocks" in layer
+        }
+    )
     input_blocks = {
         layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
         for layer_id in range(num_input_blocks)
     }

     # Retrieves the keys for the middle blocks only
-    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    num_middle_blocks = len(
+        {
+            ".".join(layer.split(".")[:2])
+            for layer in unet_state_dict
+            if "middle_block" in layer
+        }
+    )
     middle_blocks = {
         layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
         for layer_id in range(num_middle_blocks)
     }

     # Retrieves the keys for the output blocks only
-    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    num_output_blocks = len(
+        {
+            ".".join(layer.split(".")[:2])
+            for layer in unet_state_dict
+            if "output_blocks" in layer
+        }
+    )
     output_blocks = {
         layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
         for layer_id in range(num_output_blocks)
@@ -530,29 +584,45 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

         resnets = [
-            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+            key
+            for key in input_blocks[i]
+            if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
         ]
         attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
-                f"input_blocks.{i}.0.op.weight"
-            )
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
-                f"input_blocks.{i}.0.op.bias"
-            )
+            new_checkpoint[
+                f"down_blocks.{block_id}.downsamplers.0.conv.weight"
+            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
+            new_checkpoint[
+                f"down_blocks.{block_id}.downsamplers.0.conv.bias"
+            ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")

         paths = renew_resnet_paths(resnets)
-        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        meta_path = {
+            "old": f"input_blocks.{i}.0",
+            "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}",
+        }
         assign_to_checkpoint(
-            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            paths,
+            new_checkpoint,
+            unet_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
         )

         if len(attentions):
             paths = renew_attention_paths(attentions)
-            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+            meta_path = {
+                "old": f"input_blocks.{i}.1",
+                "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
+            }
             assign_to_checkpoint(
-                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                paths,
+                new_checkpoint,
+                unet_state_dict,
+                additional_replacements=[meta_path],
+                config=config,
             )

     resnet_0 = middle_blocks[0]
@@ -568,7 +638,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         attentions_paths = renew_attention_paths(attentions)
         meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
         assign_to_checkpoint(
-            attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+            attentions_paths,
+            new_checkpoint,
+            unet_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
         )

     for i in range(num_output_blocks):
@@ -586,25 +660,36 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
         if len(output_block_list) > 1:
             resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
-            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+            attentions = [
+                key for key in output_blocks[i] if f"output_blocks.{i}.1" in key
+            ]

             resnet_0_paths = renew_resnet_paths(resnets)
             paths = renew_resnet_paths(resnets)

-            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+            meta_path = {
+                "old": f"output_blocks.{i}.0",
+                "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}",
+            }
             assign_to_checkpoint(
-                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                paths,
+                new_checkpoint,
+                unet_state_dict,
+                additional_replacements=[meta_path],
+                config=config,
             )

             output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
             if ["conv.bias", "conv.weight"] in output_block_list.values():
-                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
-                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
-                    f"output_blocks.{i}.{index}.conv.weight"
-                ]
-                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
-                    f"output_blocks.{i}.{index}.conv.bias"
-                ]
+                index = list(output_block_list.values()).index(
+                    ["conv.bias", "conv.weight"]
+                )
+                new_checkpoint[
+                    f"up_blocks.{block_id}.upsamplers.0.conv.weight"
+                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
+                new_checkpoint[
+                    f"up_blocks.{block_id}.upsamplers.0.conv.bias"
+                ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]

                 # Clear attentions as they have been attributed above.
                 if len(attentions) == 2:
@@ -617,13 +702,27 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
                     "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                 }
                 assign_to_checkpoint(
-                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+                    paths,
+                    new_checkpoint,
+                    unet_state_dict,
+                    additional_replacements=[meta_path],
+                    config=config,
                 )
         else:
-            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+            resnet_0_paths = renew_resnet_paths(
+                output_block_layers, n_shave_prefix_segments=1
+            )
             for path in resnet_0_paths:
                 old_path = ".".join(["output_blocks", str(i), path["old"]])
-                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+                new_path = ".".join(
+                    [
+                        "up_blocks",
+                        str(block_id),
+                        "resnets",
+                        str(layer_in_block_id),
+                        path["new"],
+                    ]
+                )

                 new_checkpoint[new_path] = unet_state_dict[old_path]
@@ -645,49 +744,75 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         else:
             vae_state_dict[key] = checkpoint.get(key)

-    new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
-                      "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
-                      "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
-                      "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
-                      "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
-                      "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
-                      "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
-                      "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
-                      "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
-                      "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
-                      "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
-                      "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
-                      "quant_conv.weight": vae_state_dict["quant_conv.weight"],
-                      "quant_conv.bias": vae_state_dict["quant_conv.bias"],
-                      "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
-                      "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]}
+    new_checkpoint = {
+        "encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"],
+        "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"],
+        "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"],
+        "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"],
+        "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"],
+        "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"],
+        "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"],
+        "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"],
+        "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"],
+        "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"],
+        "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"],
+        "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"],
+        "quant_conv.weight": vae_state_dict["quant_conv.weight"],
+        "quant_conv.bias": vae_state_dict["quant_conv.bias"],
+        "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"],
+        "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"],
+    }

     # Retrieves the keys for the encoder down blocks only
-    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+    num_down_blocks = len(
+        {
+            ".".join(layer.split(".")[:3])
+            for layer in vae_state_dict
+            if "encoder.down" in layer
+        }
+    )
     down_blocks = {
-        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
+        for layer_id in range(num_down_blocks)
     }

     # Retrieves the keys for the decoder up blocks only
-    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+    num_up_blocks = len(
+        {
+            ".".join(layer.split(".")[:3])
+            for layer in vae_state_dict
+            if "decoder.up" in layer
+        }
+    )
     up_blocks = {
-        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
+        for layer_id in range(num_up_blocks)
     }

     for i in range(num_down_blocks):
-        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+        resnets = [
+            key
+            for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key
+        ]

         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
-                f"encoder.down.{i}.downsample.conv.weight"
-            )
-            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
-                f"encoder.down.{i}.downsample.conv.bias"
-            )
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
+            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
+            new_checkpoint[
+                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
+            ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")

         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
-        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )

     mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
     num_mid_res_blocks = 2
@@ -696,31 +821,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
-        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )

     mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
     paths = renew_vae_attention_paths(mid_attentions)
     meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
-    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        vae_state_dict,
+        additional_replacements=[meta_path],
+        config=config,
+    )
     conv_attn_to_linear(new_checkpoint)

     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i
         resnets = [
-            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+            key
+            for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
         ]

         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
-            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
-                f"decoder.up.{block_id}.upsample.conv.weight"
-            ]
-            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
-                f"decoder.up.{block_id}.upsample.conv.bias"
-            ]
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
+            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
+            new_checkpoint[
+                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
+            ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]

         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
-        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )

     mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
     num_mid_res_blocks = 2
@@ -729,12 +874,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True):
         paths = renew_vae_resnet_paths(resnets)
         meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
-        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+        assign_to_checkpoint(
+            paths,
+            new_checkpoint,
+            vae_state_dict,
+            additional_replacements=[meta_path],
+            config=config,
+        )

     mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
     paths = renew_vae_attention_paths(mid_attentions)
     meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
-    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    assign_to_checkpoint(
+        paths,
+        new_checkpoint,
+        vae_state_dict,
+        additional_replacements=[meta_path],
+        config=config,
+    )
     conv_attn_to_linear(new_checkpoint)

     return new_checkpoint
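
Nearly every hunk in convert_ldm_unet_checkpoint and convert_ldm_vae_checkpoint follows one pattern: build {"old", "new"} path mappings, then let assign_to_checkpoint rewrite keys in bulk. A simplified sketch of that renaming idea (the real function also handles the attention splits and extra replacements shown above):

# Simplified key renaming in the style of assign_to_checkpoint: each mapping
# moves a tensor from its old LDM name to the new diffusers name.
old_checkpoint = {"down.0.block.0.conv1.weight": [1.0]}
mappings = [{"old": "down.0.block", "new": "down_blocks.0.resnets"}]

new_checkpoint = {}
for old_key, tensor in old_checkpoint.items():
    new_key = old_key
    for m in mappings:
        new_key = new_key.replace(m["old"], m["new"])
    new_checkpoint[new_key] = tensor

print(new_checkpoint)  # {'down_blocks.0.resnets.0.conv1.weight': [1.0]}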
@@ -769,14 +926,16 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
     for i, hf_layer in enumerate(hf_layers):
         if i != 0:
             i += i
-        pt_layer = pt_layers[i: i + 2]
+        pt_layer = pt_layers[i : i + 2]
         _copy_layer(hf_layer, pt_layer)

     hf_model = LDMBertModel(config).eval()

     # copy embeds
     hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
-    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
+    hf_model.model.embed_positions.weight.data = (
+        checkpoint.transformer.pos_emb.emb.weight
+    )

     # copy layer norm
     _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
@@ -799,9 +958,13 @@ def convert_ldm_clip_checkpoint(checkpoint):
     for key in keys:
         if key.startswith("cond_stage_model.transformer"):
             if key.find("text_model") == -1:
-                text_model_dict["text_model." + key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+                text_model_dict[
+                    "text_model." + key[len("cond_stage_model.transformer.") :]
+                ] = checkpoint[key]
             else:
-                text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+                text_model_dict[
+                    key[len("cond_stage_model.transformer.") :]
+                ] = checkpoint[key]

     text_model.load_state_dict(text_model_dict)
@@ -809,12 +972,16 @@ def convert_ldm_clip_checkpoint(checkpoint):
 textenc_conversion_lst = [
-    ('cond_stage_model.model.positional_embedding',
-     "text_model.embeddings.position_embedding.weight"),
-    ('cond_stage_model.model.token_embedding.weight',
-     "text_model.embeddings.token_embedding.weight"),
-    ('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'),
-    ('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias')
+    (
+        "cond_stage_model.model.positional_embedding",
+        "text_model.embeddings.position_embedding.weight",
+    ),
+    (
+        "cond_stage_model.model.token_embedding.weight",
+        "text_model.embeddings.token_embedding.weight",
+    ),
+    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
+    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
 ]
 textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
@@ -827,8 +994,14 @@ textenc_transformer_conversion_lst = [
     (".c_proj.", ".fc2."),
     (".attn", ".self_attn"),
     ("ln_final.", "transformer.text_model.final_layer_norm."),
-    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
-    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+    (
+        "token_embedding.weight",
+        "transformer.text_model.embeddings.token_embedding.weight",
+    ),
+    (
+        "positional_embedding",
+        "transformer.text_model.embeddings.position_embedding.weight",
+    ),
 ]
 protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
 textenc_pattern = re.compile("|".join(protected.keys()))
@@ -844,7 +1017,9 @@ def convert_paint_by_example_checkpoint(checkpoint):
     for key in keys:
         if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
+            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
+                key
+            ]

     # load clip vision
     model.model.load_state_dict(text_model_dict)
@@ -902,19 +1077,25 @@ def convert_paint_by_example_checkpoint(checkpoint):
 def convert_open_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
+    text_model = CLIPTextModel.from_pretrained(
+        "stabilityai/stable-diffusion-2", subfolder="text_encoder"
+    )

     keys = list(checkpoint.keys())

     text_model_dict = {}

-    if 'cond_stage_model.model.text_projection' in checkpoint:
-        d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0])
+    if "cond_stage_model.model.text_projection" in checkpoint:
+        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
     else:
         logger.debug("no projection shape found, setting to 1024")
         d_model = 1024

-    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
+    text_model_dict[
+        "text_model.embeddings.position_ids"
+    ] = text_model.text_model.embeddings.get_buffer("position_ids")

     for key in keys:
-        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
+        if (
+            "resblocks.23" in key
+        ):  # Diffusers drops the final layer and only uses the penultimate layer
             continue
         if key in textenc_conversion_map:
             text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
@@ -922,18 +1103,34 @@ def convert_open_clip_checkpoint(checkpoint):
             new_key = key[len("cond_stage_model.model.transformer.") :]
             if new_key.endswith(".in_proj_weight"):
                 new_key = new_key[: -len(".in_proj_weight")]
-                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
-                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
-                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
-                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
+                new_key = textenc_pattern.sub(
+                    lambda m: protected[re.escape(m.group(0))], new_key
+                )
+                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][
+                    :d_model, :
+                ]
+                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][
+                    d_model : d_model * 2, :
+                ]
+                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][
+                    d_model * 2 :, :
+                ]
             elif new_key.endswith(".in_proj_bias"):
                 new_key = new_key[: -len(".in_proj_bias")]
-                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+                new_key = textenc_pattern.sub(
+                    lambda m: protected[re.escape(m.group(0))], new_key
+                )
                 text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
-                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
-                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
+                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][
+                    d_model : d_model * 2
+                ]
+                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][
+                    d_model * 2 :
+                ]
             else:
-                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
+                new_key = textenc_pattern.sub(
+                    lambda m: protected[re.escape(m.group(0))], new_key
+                )

                 text_model_dict[new_key] = checkpoint[key]
@@ -992,7 +1189,15 @@ def download_model(db_config: TrainingConfig, token):
     siblings = repo_info.siblings

-    diffusion_dirs = ["text_encoder", "unet", "vae", "tokenizer", "scheduler", "feature_extractor", "safety_checker"]
+    diffusion_dirs = [
+        "text_encoder",
+        "unet",
+        "vae",
+        "tokenizer",
+        "scheduler",
+        "feature_extractor",
+        "safety_checker",
+    ]
     config_file = None
     model_index = None
     model_files = []
@@ -1031,9 +1236,9 @@ def download_model(db_config: TrainingConfig, token):
             (x for x in model_files if "nonema" in x),
             next(
                 (x for x in model_files if ".safetensors" in x),
-                model_files[0] if model_files else None
-            )
-        )
+                model_files[0] if model_files else None,
+            ),
+        ),
     )

     files_to_fetch = None
@@ -1061,7 +1266,7 @@ def download_model(db_config: TrainingConfig, token):
             filename=repo_file,
             repo_type="model",
             revision=repo_info.sha,
-            token=token
+            token=token,
         )
         replace_symlinks(out, db_config.model_dir)
         dest = None
@@ -1074,7 +1279,9 @@ def download_model(db_config: TrainingConfig, token):
         for diffusion_dir in diffusion_dirs:
             if diffusion_dir in out:
                 out_model = db_config.pretrained_model_name_or_path
-                dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir)
+                dest = os.path.join(
+                    db_config.pretrained_model_name_or_path, diffusion_dir
+                )
         if not dest:
             if ".ckpt" in out or ".safetensors" in out:
                 dest = os.path.join(db_config.model_dir, "src")
@@ -1095,9 +1302,11 @@ def get_config_path(
     model_version: str = "v1",
     train_type: str = "default",
     config_base_name: str = "training",
-    prediction_type: str = "epsilon"
+    prediction_type: str = "epsilon",
 ):
-    train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
+    train_type = (
+        f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
+    )

     parts = os.path.join(
         os.path.dirname(os.path.realpath(__file__)),
@@ -1106,21 +1315,20 @@ def get_config_path(
         "..",
         "models",
         "configs",
-        f"{model_version}-{config_base_name}-{train_type}.yaml"
+        f"{model_version}-{config_base_name}-{train_type}.yaml",
     )
     return os.path.abspath(parts)


-def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None):
+def get_config_file(
+    train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None
+):
     if config_file is not None:
         return config_file

     config_base_name = "training"

-    model_versions = {
-        "v1": "v1",
-        "v2": "v2"
-    }
+    model_versions = {"v1": "v1", "v2": "v2"}
     train_types = {
         "default": "default",
         "unfrozen": "unfrozen",
@@ -1134,7 +1342,9 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", c
     else:
         model_train_type = train_types["default"]

-    return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
+    return get_config_path(
+        model_version_name, model_train_type, config_base_name, prediction_type
+    )


 def extract_checkpoint(
@@ -1182,8 +1392,9 @@ def extract_checkpoint(
     msg = None

     # Create empty config
-    db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type,
-                               src=checkpoint_file)
+    db_config = TrainingConfig(
+        ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file
+    )

     original_config_file = None
@@ -1221,9 +1432,13 @@ def extract_checkpoint(
         else:
             prediction_type = "epsilon"

-        original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file)
+        original_config_file = get_config_file(
+            train_unfrozen, v2, prediction_type, config_file=config_file
+        )

-    logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}")
+    logger.info(
+        f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}"
+    )
     db_config.resolution = image_size
     db_config.lifetime_revision = revision
     db_config.epoch = epoch
@@ -1233,12 +1448,18 @@ def extract_checkpoint(
         # Use existing YAML if present
        if checkpoint_file is not None:
-            config_check = checkpoint_file.replace(".ckpt", ".yaml") if ".ckpt" in checkpoint_file else checkpoint_file.replace(".safetensors", ".yaml")
+            config_check = (
+                checkpoint_file.replace(".ckpt", ".yaml")
+                if ".ckpt" in checkpoint_file
+                else checkpoint_file.replace(".safetensors", ".yaml")
+            )
             if os.path.exists(config_check):
                 original_config_file = config_check

         if original_config_file is None or not os.path.exists(original_config_file):
-            logger.warning("unable to select a config file: %s" % (original_config_file))
+            logger.warning(
+                "unable to select a config file: %s" % (original_config_file)
+            )
             return

         logger.debug("trying to load: %s", original_config_file)
@@ -1281,7 +1502,9 @@ def extract_checkpoint(
         # Convert the UNet2DConditionModel model.
         logger.info("converting UNet")
-        unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+        unet_config = create_unet_diffusers_config(
+            original_config, image_size=image_size
+        )
         unet_config["upcast_attention"] = upcast_attention
         unet = UNet2DConditionModel(**unet_config)
@@ -1297,22 +1520,30 @@ def extract_checkpoint(
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)

         if vae_file is None:
-            converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+                checkpoint, vae_config
+            )
         else:
             vae_file = os.path.join(ctx.model_path, vae_file)
             logger.debug("loading custom VAE: %s", vae_file)
             vae_checkpoint = load_tensor(vae_file, map_location=map_location)
-            converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False)
+            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+                vae_checkpoint, vae_config, first_stage=False
+            )

         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)

         # Convert the text model.
         logger.info("converting text encoder")
-        text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
+        text_model_type = original_config.model.params.cond_stage_config.target.split(
+            "."
+        )[-1]
         if text_model_type == "FrozenOpenCLIPEmbedder":
             text_model = convert_open_clip_checkpoint(checkpoint)
-            tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
+            tokenizer = CLIPTokenizer.from_pretrained(
+                "stabilityai/stable-diffusion-2", subfolder="tokenizer"
+            )
             pipe = StableDiffusionPipeline(
                 vae=vae,
                 text_encoder=text_model,
@@ -1326,7 +1557,9 @@ def extract_checkpoint(
         elif text_model_type == "PaintByExample":
             vision_model = convert_paint_by_example_checkpoint(checkpoint)
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+            feature_extractor = AutoFeatureExtractor.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker"
+            )
             pipe = PaintByExamplePipeline(
                 vae=vae,
                 image_encoder=vision_model,
@@ -1338,8 +1571,12 @@ def extract_checkpoint(
         elif text_model_type == "FrozenCLIPEmbedder":
             text_model = convert_ldm_clip_checkpoint(checkpoint)
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-            safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
-            feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
+            safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker"
+            )
+            feature_extractor = AutoFeatureExtractor.from_pretrained(
+                "CompVis/stable-diffusion-safety-checker"
+            )
             pipe = StableDiffusionPipeline(
                 vae=vae,
                 text_encoder=text_model,
@@ -1347,16 +1584,24 @@ def extract_checkpoint(
                 unet=unet,
                 scheduler=scheduler,
                 safety_checker=safety_checker,
-                feature_extractor=feature_extractor
+                feature_extractor=feature_extractor,
             )
         else:
             text_config = create_ldm_bert_config(original_config)
             text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
             tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-            pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet,
-                                          scheduler=scheduler)
+            pipe = LDMTextToImagePipeline(
+                vqvae=vae,
+                bert=text_model,
+                tokenizer=tokenizer,
+                unet=unet,
+                scheduler=scheduler,
+            )
     except Exception:
-        logger.error("exception setting up output: %s", traceback.format_exception(*sys.exc_info()))
+        logger.error(
+            "exception setting up output: %s",
+            traceback.format_exception(*sys.exc_info()),
+        )
        pipe = None

     if pipe is None or db_config is None:
@@ -1371,12 +1616,18 @@ def extract_checkpoint(
         scheduler = db_config.scheduler
         required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"]
         if original_config_file is not None and os.path.exists(original_config_file):
-            logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir)
+            logger.debug(
+                "copying original config: %s -> %s",
+                original_config_file,
+                db_config.model_dir,
+            )
             shutil.copy(original_config_file, db_config.model_dir)
             basename = os.path.basename(original_config_file)
             new_ex_path = os.path.join(db_config.model_dir, basename)
             new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml")
-            logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name)
+            logger.debug(
+                "copying model config to new name: %s -> %s", new_ex_path, new_name
+            )
             if os.path.exists(new_name):
                 os.remove(new_name)
             os.rename(new_ex_path, new_name)
@@ -1407,7 +1658,9 @@ def convert_diffusion_original(
     source = source or model["source"]

     dest = os.path.join(ctx.model_path, name)
-    logger.info("converting original Diffusers checkpoint %s: %s -> %s", name, source, dest)
+    logger.info(
+        "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest
+    )

     if os.path.exists(dest):
         logger.info("ONNX pipeline already exists, skipping")
@@ -1420,8 +1673,18 @@ def convert_diffusion_original(
     if os.path.exists(torch_path):
         logger.info("torch pipeline already exists, reusing: %s", torch_path)
     else:
-        logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path)
-        extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae"))
+        logger.info(
+            "converting original Diffusers check to Torch model: %s -> %s",
+            source,
+            torch_path,
+        )
+        extract_checkpoint(
+            ctx,
+            torch_name,
+            source,
+            config_file=model.get("config"),
+            vae_file=model.get("vae"),
+        )
         logger.info("converted original Diffusers checkpoint to Torch model")

     # VAE has already been converted and will confuse HF repo lookup
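
convert_diffusion_original therefore runs in two stages: extract_checkpoint first materializes a Torch-format pipeline from the .ckpt/.safetensors source, and that directory is then exported to ONNX. The YAML picked for a checkpoint comes from the name composition in get_config_path; a standalone sketch of just that string logic (same expressions as the hunk above, minus the package-relative path join):

# Reimplements only the filename logic of get_config_path, to show which
# YAML gets selected; the real function joins this onto models/configs.
def config_name(model_version="v1", train_type="default",
                config_base_name="training", prediction_type="epsilon"):
    train_type = train_type if prediction_type != "v_prediction" else f"{train_type}-v"
    return f"{model_version}-{config_base_name}-{train_type}.yaml"

print(config_name())                                      # v1-training-default.yaml
print(config_name("v2", prediction_type="v_prediction"))  # v2-training-default-v.yaml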

View File

@@ -1,24 +1,29 @@
-from os import makedirs, path
-from huggingface_hub.file_download import hf_hub_download
-from transformers import CLIPTokenizer, CLIPTextModel
-from torch.onnx import export
 from logging import getLogger
-from ..utils import ConversionContext
+from os import makedirs, path

 import torch
+from huggingface_hub.file_download import hf_hub_download
+from torch.onnx import export
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ..utils import ConversionContext

 logger = getLogger(__name__)


-def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str):
+def convert_diffusion_textual_inversion(
+    context: ConversionContext, name: str, base_model: str, inversion: str
+):
     dest_path = path.join(context.model_path, f"inversion-{name}")
-    logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path)
+    logger.info(
+        "converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path
+    )

     if path.exists(dest_path):
         logger.info("ONNX model already exists, skipping.")
+        return

-    makedirs(path.join(dest_path, "text_encoder"))
+    makedirs(path.join(dest_path, "text_encoder"), exist_ok=True)

     embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin")
     token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt")
@@ -71,9 +76,7 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b
     export(
         text_encoder,
         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
-        (
-            text_input.input_ids.to(dtype=torch.int32)
-        ),
+        (text_input.input_ids.to(dtype=torch.int32)),
         f=path.join(dest_path, "text_encoder", "model.onnx"),
         input_names=["input_ids"],
         output_names=["last_hidden_state", "pooler_output"],
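
The conversion pulls two files from each concept repository, learned_embeds.bin and token_identifier.txt, before patching and re-exporting the text encoder. A hedged sketch of just the download-and-inspect step (hf_hub_download is the real huggingface_hub API; the repo id is one of the inversions added to the models file earlier, and the 768-dim shape is only what SD v1-style embeddings typically look like):

import torch
from huggingface_hub.file_download import hf_hub_download

# Fetch the two files the converter reads for one concept.
embeds_file = hf_hub_download(
    repo_id="sd-concepts-library/line-art", filename="learned_embeds.bin"
)
token_file = hf_hub_download(
    repo_id="sd-concepts-library/line-art", filename="token_identifier.txt"
)

# learned_embeds.bin maps the trained token to its embedding vector.
loaded_embeds = torch.load(embeds_file, map_location="cpu")
with open(token_file) as f:
    token = f.read().strip()

embeds = next(iter(loaded_embeds.values()))
print(token, embeds.shape)  # e.g. <line-art> torch.Size([768])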

View File

@@ -17,9 +17,9 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     KDPM2DiscreteScheduler,
     LMSDiscreteScheduler,
+    OnnxRuntimeModel,
     PNDMScheduler,
     StableDiffusionPipeline,
-    OnnxRuntimeModel,
 )

 try:

View File

@@ -7,7 +7,7 @@ import { useStore } from 'zustand';

 import { STALE_TIME } from '../../config.js';
 import { ClientContext, StateContext } from '../../state.js';
-import { MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js';
+import { INVERSION_LABELS, MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js';
 import { QueryList } from '../input/QueryList.js';

 export function ModelControl() {
@@ -56,12 +56,13 @@ export function ModelControl() {
       />
       <QueryList
         id='inversion'
-        labels={MODEL_LABELS}
+        labels={INVERSION_LABELS}
         name='Textual Inversion'
         query={{
           result: models,
           selector: (result) => result.inversion,
         }}
+        showEmpty={true}
         value={params.inversion}
         onChange={(inversion) => {
           setModel({

View File

@@ -20,6 +20,7 @@ export interface QueryListProps<T> {
   value: string;

   query: QueryListComplete | QueryListFilter<T>;
+  showEmpty?: boolean;

   onChange?: (value: string) => void;
 }
@@ -28,17 +29,25 @@ export function hasFilter<T>(query: QueryListComplete | QueryListFilter<T>): que
   return Reflect.has(query, 'selector');
 }

-export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>): Array<string> {
+export function filterQuery<T>(query: QueryListComplete | QueryListFilter<T>, showEmpty: boolean): Array<string> {
   if (hasFilter(query)) {
     const data = mustExist(query.result.data);
-    return (query as QueryListFilter<unknown>).selector(data);
+    const selected = (query as QueryListFilter<unknown>).selector(data);
+    if (showEmpty) {
+      return ['', ...selected];
+    }
+    return selected;
   } else {
-    return mustExist(query.result.data);
+    const data = Array.from(mustExist(query.result.data));
+    if (showEmpty) {
+      return ['', ...data];
+    }
+    return data;
   }
 }

 export function QueryList<T>(props: QueryListProps<T>) {
-  const { labels, query, value } = props;
+  const { labels, query, showEmpty = false, value } = props;
   const { result } = query;

   function firstValidValue(): string {
@@ -52,7 +61,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
   // update state when previous selection was invalid: https://github.com/ssube/onnx-web/issues/120
   useEffect(() => {
     if (result.status === 'success' && doesExist(result.data) && doesExist(props.onChange)) {
-      const data = filterQuery(query);
+      const data = filterQuery(query, showEmpty);
       if (data.includes(value) === false) {
         props.onChange(data[0]);
       }
@@ -77,7 +86,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
   // else: success
   const labelID = `query-list-${props.id}-labels`;

-  const data = filterQuery(query);
+  const data = filterQuery(query, showEmpty);

   return <FormControl>
     <InputLabel id={labelID}>{props.name}</InputLabel>

View File

@@ -32,6 +32,14 @@ export const MODEL_LABELS: Record<string, string> = {
   'diffusion-unstable-ink-dream-v6': 'Unstable Ink Dream v6',
 };

+export const INVERSION_LABELS: Record<string, string> = {
+  '': 'None',
+  'inversion-cubex': 'Cubex',
+  'inversion-birb': 'Birb Style',
+  'inversion-line-art': 'Line Art',
+  'inversion-minecraft': 'Minecraft Concept',
+};
+
 export const PLATFORM_LABELS: Record<string, string> = {
   amd: 'AMD GPU',
   // eslint-disable-next-line id-blacklist