From e8b5ff250dfb07815043ab13a49b7989ab817861 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 21 Feb 2023 23:50:27 -0600 Subject: [PATCH] add none option to inversion menu --- api/extras.json | 22 + api/onnx_web/convert/__main__.py | 10 +- api/onnx_web/convert/diffusion/lora.py | 137 +++-- api/onnx_web/convert/diffusion/original.py | 577 +++++++++++++----- .../convert/diffusion/textual_inversion.py | 27 +- api/onnx_web/diffusion/load.py | 2 +- gui/src/components/control/ModelControl.tsx | 5 +- gui/src/components/input/QueryList.tsx | 21 +- gui/src/strings.ts | 8 + 9 files changed, 565 insertions(+), 244 deletions(-) diff --git a/api/extras.json b/api/extras.json index 4371eeb6..05478b71 100644 --- a/api/extras.json +++ b/api/extras.json @@ -22,6 +22,28 @@ "name": "diffusion-unstable-ink-dream-v6", "source": "civitai://5796", "format": "safetensors" + }, + { + "name": "stable-diffusion-onnx-v1-5", + "source": "runwayml/stable-diffusion-v1-5", + "inversions": [ + { + "name": "line-art", + "source": "sd-concepts-library/line-art" + }, + { + "name": "cubex", + "source": "sd-concepts-library/cubex" + }, + { + "name": "birb", + "source": "sd-concepts-library/birb-style" + }, + { + "name": "minecraft", + "source": "sd-concepts-library/minecraft-concept-art" + } + ] } ], "correction": [], diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 09ce6e3d..2a670511 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -10,8 +10,8 @@ from jsonschema import ValidationError, validate from yaml import safe_load from .correction_gfpgan import convert_correction_gfpgan -from .diffusion.original import convert_diffusion_original from .diffusion.diffusers import convert_diffusion_diffusers +from .diffusion.original import convert_diffusion_original from .diffusion.textual_inversion import convert_diffusion_textual_inversion from .upscale_resrgan import convert_upscale_resrgan from .utils import ( @@ -233,8 +233,12 @@ def convert_models(ctx: ConversionContext, args, models: Models): for inversion in model.get("inversions", []): inversion_name = inversion["name"] inversion_source = inversion["source"] - inversion_source = fetch_model(ctx, f"{name}-inversion-{inversion_name}", inversion_source) - convert_diffusion_textual_inversion(ctx, inversion_name, model["source"], inversion_source) + inversion_source = fetch_model( + ctx, f"{name}-inversion-{inversion_name}", inversion_source + ) + convert_diffusion_textual_inversion( + ctx, inversion_name, model["source"], inversion_source + ) except Exception as e: logger.error("error converting diffusion model %s: %s", name, e) diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 1b97b430..47c09f59 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -1,94 +1,105 @@ -from numpy import ndarray -from onnx import TensorProto, helper, load, numpy_helper, ModelProto, save_model -from typing import Dict, List, Tuple from logging import getLogger +from typing import List, Tuple +from numpy import ndarray +from onnx import ModelProto, TensorProto, helper, load, numpy_helper, save_model logger = getLogger(__name__) def load_lora(filename: str): - model = load(filename) + model = load(filename) - for weight in model.graph.initializer: - # print(weight.name, numpy_helper.to_array(weight).shape) - pass + for weight in model.graph.initializer: + # print(weight.name, numpy_helper.to_array(weight).shape) + pass - return model 
+ return model -def blend_loras(base: ModelProto, weights: List[ModelProto], alphas: List[float]) -> List[Tuple[TensorProto, ndarray]]: - total = 1 + sum(alphas) +def blend_loras( + base: ModelProto, weights: List[ModelProto], alphas: List[float] +) -> List[Tuple[TensorProto, ndarray]]: + total = 1 + sum(alphas) - results = [] + results = [] - for base_node in base.graph.initializer: - logger.info("blending initializer node %s", base_node.name) - base_weights = numpy_helper.to_array(base_node).copy() + for base_node in base.graph.initializer: + logger.info("blending initializer node %s", base_node.name) + base_weights = numpy_helper.to_array(base_node).copy() - for weight, alpha in zip(weights, alphas): - weight_node = next(iter([f for f in weight.graph.initializer if f.name == base_node.name]), None) + for weight, alpha in zip(weights, alphas): + weight_node = next( + iter([f for f in weight.graph.initializer if f.name == base_node.name]), + None, + ) - if weight_node is not None: - base_weights += numpy_helper.to_array(weight_node) * alpha - else: - logger.warning("missing weights: %s in %s", base_node.name, weight.doc_string) + if weight_node is not None: + base_weights += numpy_helper.to_array(weight_node) * alpha + else: + logger.warning( + "missing weights: %s in %s", base_node.name, weight.doc_string + ) - results.append((base_node, base_weights / total)) + results.append((base_node, base_weights / total)) - return results + return results def convert_diffusion_lora(part: str): - lora_weights = [ - f"diffusion-lora-jack/{part}/model.onnx", - f"diffusion-lora-taters/{part}/model.onnx", - ] + lora_weights = [ + f"diffusion-lora-jack/{part}/model.onnx", + f"diffusion-lora-taters/{part}/model.onnx", + ] - base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx") - weights = [load_lora(f) for f in lora_weights] - alphas = [1 / len(weights)] * len(weights) - logger.info("blending LoRAs with alphas: %s, %s", weights, alphas) + base = load_lora(f"stable-diffusion-onnx-v1-5/{part}/model.onnx") + weights = [load_lora(f) for f in lora_weights] + alphas = [1 / len(weights)] * len(weights) + logger.info("blending LoRAs with alphas: %s, %s", weights, alphas) - result = blend_loras(base, weights, alphas) - logger.info("blended result keys: %s", len(result)) + result = blend_loras(base, weights, alphas) + logger.info("blended result keys: %s", len(result)) - del weights - del alphas + del weights + del alphas - tensors = [] - for node, tensor in result: - logger.info("remaking tensor for %s", node.name) - tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor)) + tensors = [] + for node, tensor in result: + logger.info("remaking tensor for %s", node.name) + tensors.append(helper.make_tensor(node.name, node.data_type, node.dims, tensor)) - del result + del result - graph = helper.make_graph( - base.graph.node, - base.graph.name, - base.graph.input, - base.graph.output, - tensors, - base.graph.doc_string, - base.graph.value_info, - base.graph.sparse_initializer, - ) - model = helper.make_model(graph) + graph = helper.make_graph( + base.graph.node, + base.graph.name, + base.graph.input, + base.graph.output, + tensors, + base.graph.doc_string, + base.graph.value_info, + base.graph.sparse_initializer, + ) + model = helper.make_model(graph) - del model.opset_import[:] - opset = model.opset_import.add() - opset.version = 14 + del model.opset_import[:] + opset = model.opset_import.add() + opset.version = 14 - save_model( - model, - f"/tmp/lora-{part}.onnx", - 
save_as_external_data=True, - all_tensors_to_one_file=True, - location=f"/tmp/lora-{part}.tensors", - ) - logger.info("saved model to %s and tensors to %s", f"/tmp/lora-{part}.onnx", f"/tmp/lora-{part}.tensors") + save_model( + model, + f"/tmp/lora-{part}.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location=f"/tmp/lora-{part}.tensors", + ) + logger.info( + "saved model to %s and tensors to %s", + f"/tmp/lora-{part}.onnx", + f"/tmp/lora-{part}.tensors", + ) if __name__ == "__main__": - convert_diffusion_lora("unet") - convert_diffusion_lora("text_encoder") \ No newline at end of file + convert_diffusion_lora("unet") + convert_diffusion_lora("text_encoder") diff --git a/api/onnx_web/convert/diffusion/original.py b/api/onnx_web/convert/diffusion/original.py index b75be023..c331ee74 100644 --- a/api/onnx_web/convert/diffusion/original.py +++ b/api/onnx_web/convert/diffusion/original.py @@ -53,13 +53,13 @@ from transformers import ( CLIPVisionConfig, ) -from .diffusers import convert_diffusion_diffusers from ..utils import ConversionContext, ModelDict, load_tensor, load_yaml, sanitize_name +from .diffusers import convert_diffusion_diffusers logger = getLogger(__name__) -class TrainingConfig(): +class TrainingConfig: """ From https://github.com/d8ahazard/sd_dreambooth_extension/blob/main/dreambooth/db_config.py """ @@ -184,7 +184,9 @@ class TrainingConfig(): backup_dir = os.path.join(models_path, "backups") if not os.path.exists(backup_dir): os.makedirs(backup_dir) - config_file = os.path.join(models_path, "backups", f"db_config_{self.revision}.json") + config_file = os.path.join( + models_path, "backups", f"db_config_{self.revision}.json" + ) with open(config_file, "w") as outfile: json.dump(self.__dict__, outfile, indent=4) @@ -238,7 +240,9 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("emb_layers.1", "time_emb_proj") new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -253,7 +257,9 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): for old_item in old_list: new_item = old_item new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -295,7 +301,9 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = shave_segments( + new_item, n_shave_prefix_segments=n_shave_prefix_segments + ) mapping.append({"old": old_item, "new": new_item}) @@ -303,7 +311,12 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, ): """ This does the final conversion step: take locally converted weights and apply a global renaming @@ -312,7 +325,9 @@ 
def assign_to_checkpoint( Assigns the weights to the new checkpoint. """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + assert isinstance( + paths, list + ), "Paths should be a list of dicts containing 'old' and 'new' keys." # Splits the attention layers into three variables. if attention_paths_to_split is not None: @@ -324,7 +339,9 @@ def assign_to_checkpoint( num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + old_tensor = old_tensor.reshape( + (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] + ) query, key, value = old_tensor.split(channels // num_heads, dim=1) checkpoint[path_map["query"]] = query.reshape(target_shape) @@ -335,7 +352,10 @@ def assign_to_checkpoint( new_path = path["new"] # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: + if ( + attention_paths_to_split is not None + and new_path in attention_paths_to_split + ): continue # Global renaming happens here @@ -373,19 +393,29 @@ def create_unet_diffusers_config(original_config, image_size: int): unet_params = original_config.model.params.unet_config.params vae_params = original_config.model.params.first_stage_config.params.ddconfig - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [ + unet_params.model_channels * mult for mult in unet_params.channel_mult + ] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = ( + "CrossAttnDownBlock2D" + if resolution in unet_params.attention_resolutions + else "DownBlock2D" + ) down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = ( + "CrossAttnUpBlock2D" + if resolution in unet_params.attention_resolutions + else "UpBlock2D" + ) up_block_types.append(block_type) resolution //= 2 @@ -393,7 +423,9 @@ def create_unet_diffusers_config(original_config, image_size: int): head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params.use_linear_in_transformer + if "use_linear_in_transformer" in unet_params + else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 @@ -482,7 +514,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False for key in keys: if key.startswith("model.diffusion_model"): flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop( + flat_ema_key + ) else: print( "In this conversion only the non-EMA weights are extracted. 
If you want to instead extract the EMA" @@ -493,33 +527,53 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if key.startswith(unet_key): unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {"time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"], - "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"], - "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"], - "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"], - "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"], - "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"], - "conv_norm_out.weight": unet_state_dict["out.0.weight"], - "conv_norm_out.bias": unet_state_dict["out.0.bias"], - "conv_out.weight": unet_state_dict["out.2.weight"], - "conv_out.bias": unet_state_dict["out.2.bias"]} + new_checkpoint = { + "time_embedding.linear_1.weight": unet_state_dict["time_embed.0.weight"], + "time_embedding.linear_1.bias": unet_state_dict["time_embed.0.bias"], + "time_embedding.linear_2.weight": unet_state_dict["time_embed.2.weight"], + "time_embedding.linear_2.bias": unet_state_dict["time_embed.2.bias"], + "conv_in.weight": unet_state_dict["input_blocks.0.0.weight"], + "conv_in.bias": unet_state_dict["input_blocks.0.0.bias"], + "conv_norm_out.weight": unet_state_dict["out.0.weight"], + "conv_norm_out.bias": unet_state_dict["out.0.bias"], + "conv_out.weight": unet_state_dict["out.2.weight"], + "conv_out.bias": unet_state_dict["out.2.bias"], + } # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + num_input_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "input_blocks" in layer + } + ) input_blocks = { layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) } # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + num_middle_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "middle_block" in layer + } + ) middle_blocks = { layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) } # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + num_output_blocks = len( + { + ".".join(layer.split(".")[:2]) + for layer in unet_state_dict + if "output_blocks" in layer + } + ) output_blocks = { layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] for layer_id in range(num_output_blocks) @@ -530,29 +584,45 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - 
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.weight" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.weight") + new_checkpoint[ + f"down_blocks.{block_id}.downsamplers.0.conv.bias" + ] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) if len(attentions): paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + meta_path = { + "old": f"input_blocks.{i}.1", + "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) resnet_0 = middle_blocks[0] @@ -568,7 +638,11 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False attentions_paths = renew_attention_paths(attentions) meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + attentions_paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) for i in range(num_output_blocks): @@ -586,25 +660,36 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False if len(output_block_list) > 1: resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + attentions = [ + key for key in output_blocks[i] if f"output_blocks.{i}.1" in key + ] resnet_0_paths = renew_resnet_paths(resnets) paths = renew_resnet_paths(resnets) - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}", + } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) output_block_list = {k: sorted(v) for k, v in output_block_list.items()} if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] + index = list(output_block_list.values()).index( + ["conv.bias", "conv.weight"] + ) + new_checkpoint[ + f"up_blocks.{block_id}.upsamplers.0.conv.weight" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"] + new_checkpoint[ + 
f"up_blocks.{block_id}.upsamplers.0.conv.bias" + ] = unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"] # Clear attentions as they have been attributed above. if len(attentions) == 2: @@ -617,13 +702,27 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + paths, + new_checkpoint, + unet_state_dict, + additional_replacements=[meta_path], + config=config, ) else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + resnet_0_paths = renew_resnet_paths( + output_block_layers, n_shave_prefix_segments=1 + ) for path in resnet_0_paths: old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + new_path = ".".join( + [ + "up_blocks", + str(block_id), + "resnets", + str(layer_in_block_id), + path["new"], + ] + ) new_checkpoint[new_path] = unet_state_dict[old_path] @@ -645,49 +744,75 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): else: vae_state_dict[key] = checkpoint.get(key) - new_checkpoint = {"encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"], - "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"], - "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"], - "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"], - "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"], - "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"], - "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"], - "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"], - "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"], - "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"], - "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"], - "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"], - "quant_conv.weight": vae_state_dict["quant_conv.weight"], - "quant_conv.bias": vae_state_dict["quant_conv.bias"], - "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"], - "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"]} + new_checkpoint = { + "encoder.conv_in.weight": vae_state_dict["encoder.conv_in.weight"], + "encoder.conv_in.bias": vae_state_dict["encoder.conv_in.bias"], + "encoder.conv_out.weight": vae_state_dict["encoder.conv_out.weight"], + "encoder.conv_out.bias": vae_state_dict["encoder.conv_out.bias"], + "encoder.conv_norm_out.weight": vae_state_dict["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": vae_state_dict["encoder.norm_out.bias"], + "decoder.conv_in.weight": vae_state_dict["decoder.conv_in.weight"], + "decoder.conv_in.bias": vae_state_dict["decoder.conv_in.bias"], + "decoder.conv_out.weight": vae_state_dict["decoder.conv_out.weight"], + "decoder.conv_out.bias": vae_state_dict["decoder.conv_out.bias"], + "decoder.conv_norm_out.weight": vae_state_dict["decoder.norm_out.weight"], + "decoder.conv_norm_out.bias": vae_state_dict["decoder.norm_out.bias"], + "quant_conv.weight": vae_state_dict["quant_conv.weight"], + "quant_conv.bias": vae_state_dict["quant_conv.bias"], + "post_quant_conv.weight": vae_state_dict["post_quant_conv.weight"], + "post_quant_conv.bias": vae_state_dict["post_quant_conv.bias"], + } # Retrieves the keys for the encoder down 
blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + num_down_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "encoder.down" in layer + } + ) down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] + for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + num_up_blocks = len( + { + ".".join(layer.split(".")[:3]) + for layer in vae_state_dict + if "decoder.up" in layer + } + ) up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] + for layer_id in range(num_up_blocks) } for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + resnets = [ + key + for key in down_blocks[i] + if f"down.{i}" in key and f"down.{i}.downsample" not in key + ] if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.weight" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight") + new_checkpoint[ + f"encoder.down_blocks.{i}.downsamplers.0.conv.bias" + ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias") paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] num_mid_res_blocks = 2 @@ -696,31 +821,51 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) for i in range(num_up_blocks): block_id = num_up_blocks - 1 - i resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + key + for key in up_blocks[block_id] + if f"up.{block_id}" in key and 
f"up.{block_id}.upsample" not in key ] if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.weight" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"] + new_checkpoint[ + f"decoder.up_blocks.{i}.upsamplers.0.conv.bias" + ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"] paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] num_mid_res_blocks = 2 @@ -729,12 +874,24 @@ def convert_ldm_vae_checkpoint(checkpoint, config, first_stage=True): paths = renew_vae_resnet_paths(resnets) meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] paths = renew_vae_attention_paths(mid_attentions) meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + config=config, + ) conv_attn_to_linear(new_checkpoint) return new_checkpoint @@ -769,14 +926,16 @@ def convert_ldm_bert_checkpoint(checkpoint, config): for i, hf_layer in enumerate(hf_layers): if i != 0: i += i - pt_layer = pt_layers[i: i + 2] + pt_layer = pt_layers[i : i + 2] _copy_layer(hf_layer, pt_layer) hf_model = LDMBertModel(config).eval() # copy embeds hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + hf_model.model.embed_positions.weight.data = ( + checkpoint.transformer.pos_emb.emb.weight + ) # copy layer norm _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) @@ -799,9 +958,13 @@ def convert_ldm_clip_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): if key.find("text_model") == -1: - text_model_dict["text_model." + key[len("cond_stage_model.transformer."):]] = checkpoint[key] + text_model_dict[ + "text_model." 
+ key[len("cond_stage_model.transformer.") :] + ] = checkpoint[key] else: - text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + text_model_dict[ + key[len("cond_stage_model.transformer.") :] + ] = checkpoint[key] text_model.load_state_dict(text_model_dict) @@ -809,12 +972,16 @@ def convert_ldm_clip_checkpoint(checkpoint): textenc_conversion_lst = [ - ('cond_stage_model.model.positional_embedding', - "text_model.embeddings.position_embedding.weight"), - ('cond_stage_model.model.token_embedding.weight', - "text_model.embeddings.token_embedding.weight"), - ('cond_stage_model.model.ln_final.weight', 'text_model.final_layer_norm.weight'), - ('cond_stage_model.model.ln_final.bias', 'text_model.final_layer_norm.bias') + ( + "cond_stage_model.model.positional_embedding", + "text_model.embeddings.position_embedding.weight", + ), + ( + "cond_stage_model.model.token_embedding.weight", + "text_model.embeddings.token_embedding.weight", + ), + ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), + ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), ] textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} @@ -827,8 +994,14 @@ textenc_transformer_conversion_lst = [ (".c_proj.", ".fc2."), (".attn", ".self_attn"), ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), + ( + "token_embedding.weight", + "transformer.text_model.embeddings.token_embedding.weight", + ), + ( + "positional_embedding", + "transformer.text_model.embeddings.position_embedding.weight", + ), ] protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) @@ -844,7 +1017,9 @@ def convert_paint_by_example_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[ + key + ] # load clip vision model.model.load_state_dict(text_model_dict) @@ -902,19 +1077,25 @@ def convert_paint_by_example_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + text_model = CLIPTextModel.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="text_encoder" + ) keys = list(checkpoint.keys()) text_model_dict = {} - if 'cond_stage_model.model.text_projection' in checkpoint: - d_model = int(checkpoint['cond_stage_model.model.text_projection'].shape[0]) + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) else: logger.debug("no projection shape found, setting to 1024") d_model = 1024 - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + text_model_dict[ + "text_model.embeddings.position_ids" + ] = text_model.text_model.embeddings.get_buffer("position_ids") for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + if ( + "resblocks.23" in key + ): # Diffusers drops the final layer and only uses the penultimate layer continue if key in textenc_conversion_map: 
text_model_dict[textenc_conversion_map[key]] = checkpoint[key] @@ -922,18 +1103,34 @@ def convert_open_clip_checkpoint(checkpoint): new_key = key[len("cond_stage_model.model.transformer.") :] if new_key.endswith(".in_proj_weight"): new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][ + :d_model, : + ] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][ + d_model : d_model * 2, : + ] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][ + d_model * 2 :, : + ] elif new_key.endswith(".in_proj_bias"): new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][ + d_model : d_model * 2 + ] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][ + d_model * 2 : + ] else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + new_key = textenc_pattern.sub( + lambda m: protected[re.escape(m.group(0))], new_key + ) text_model_dict[new_key] = checkpoint[key] @@ -992,7 +1189,15 @@ def download_model(db_config: TrainingConfig, token): siblings = repo_info.siblings - diffusion_dirs = ["text_encoder", "unet", "vae", "tokenizer", "scheduler", "feature_extractor", "safety_checker"] + diffusion_dirs = [ + "text_encoder", + "unet", + "vae", + "tokenizer", + "scheduler", + "feature_extractor", + "safety_checker", + ] config_file = None model_index = None model_files = [] @@ -1031,9 +1236,9 @@ def download_model(db_config: TrainingConfig, token): (x for x in model_files if "nonema" in x), next( (x for x in model_files if ".safetensors" in x), - model_files[0] if model_files else None - ) - ) + model_files[0] if model_files else None, + ), + ), ) files_to_fetch = None @@ -1061,7 +1266,7 @@ def download_model(db_config: TrainingConfig, token): filename=repo_file, repo_type="model", revision=repo_info.sha, - token=token + token=token, ) replace_symlinks(out, db_config.model_dir) dest = None @@ -1074,7 +1279,9 @@ def download_model(db_config: TrainingConfig, token): for diffusion_dir in diffusion_dirs: if diffusion_dir in out: out_model = db_config.pretrained_model_name_or_path - dest = os.path.join(db_config.pretrained_model_name_or_path, diffusion_dir) + dest = os.path.join( + db_config.pretrained_model_name_or_path, diffusion_dir + ) if not dest: if ".ckpt" in out or ".safetensors" in out: dest = os.path.join(db_config.model_dir, "src") @@ -1095,9 +1302,11 @@ def get_config_path( model_version: str = "v1", train_type: str = "default", config_base_name: str = "training", - prediction_type: str = "epsilon" + prediction_type: str = "epsilon", ): - train_type = f"{train_type}" if not prediction_type == "v_prediction" else 
f"{train_type}-v" + train_type = ( + f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v" + ) parts = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -1106,21 +1315,20 @@ def get_config_path( "..", "models", "configs", - f"{model_version}-{config_base_name}-{train_type}.yaml" + f"{model_version}-{config_base_name}-{train_type}.yaml", ) return os.path.abspath(parts) -def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None): +def get_config_file( + train_unfrozen=False, v2=False, prediction_type="epsilon", config_file=None +): if config_file is not None: return config_file config_base_name = "training" - model_versions = { - "v1": "v1", - "v2": "v2" - } + model_versions = {"v1": "v1", "v2": "v2"} train_types = { "default": "default", "unfrozen": "unfrozen", @@ -1134,7 +1342,9 @@ def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon", c else: model_train_type = train_types["default"] - return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type) + return get_config_path( + model_version_name, model_train_type, config_base_name, prediction_type + ) def extract_checkpoint( @@ -1182,8 +1392,9 @@ def extract_checkpoint( msg = None # Create empty config - db_config = TrainingConfig(ctx, model_name=new_model_name, scheduler=scheduler_type, - src=checkpoint_file) + db_config = TrainingConfig( + ctx, model_name=new_model_name, scheduler=scheduler_type, src=checkpoint_file + ) original_config_file = None @@ -1221,9 +1432,13 @@ def extract_checkpoint( else: prediction_type = "epsilon" - original_config_file = get_config_file(train_unfrozen, v2, prediction_type, config_file=config_file) + original_config_file = get_config_file( + train_unfrozen, v2, prediction_type, config_file=config_file + ) - logger.info(f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}") + logger.info( + f"Pred and size are {prediction_type} and {image_size}, using config: {original_config_file}" + ) db_config.resolution = image_size db_config.lifetime_revision = revision db_config.epoch = epoch @@ -1233,12 +1448,18 @@ def extract_checkpoint( # Use existing YAML if present if checkpoint_file is not None: - config_check = checkpoint_file.replace(".ckpt", ".yaml") if ".ckpt" in checkpoint_file else checkpoint_file.replace(".safetensors", ".yaml") + config_check = ( + checkpoint_file.replace(".ckpt", ".yaml") + if ".ckpt" in checkpoint_file + else checkpoint_file.replace(".safetensors", ".yaml") + ) if os.path.exists(config_check): original_config_file = config_check if original_config_file is None or not os.path.exists(original_config_file): - logger.warning("unable to select a config file: %s" % (original_config_file)) + logger.warning( + "unable to select a config file: %s" % (original_config_file) + ) return logger.debug("trying to load: %s", original_config_file) @@ -1281,7 +1502,9 @@ def extract_checkpoint( # Convert the UNet2DConditionModel model. 
logger.info("converting UNet") - unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config = create_unet_diffusers_config( + original_config, image_size=image_size + ) unet_config["upcast_attention"] = upcast_attention unet = UNet2DConditionModel(**unet_config) @@ -1297,22 +1520,30 @@ def extract_checkpoint( vae_config = create_vae_diffusers_config(original_config, image_size=image_size) if vae_file is None: - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + checkpoint, vae_config + ) else: vae_file = os.path.join(ctx.model_path, vae_file) logger.debug("loading custom VAE: %s", vae_file) vae_checkpoint = load_tensor(vae_file, map_location=map_location) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, first_stage=False) + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + vae_checkpoint, vae_config, first_stage=False + ) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) # Convert the text model. logger.info("converting text encoder") - text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + text_model_type = original_config.model.params.cond_stage_config.target.split( + "." + )[-1] if text_model_type == "FrozenOpenCLIPEmbedder": text_model = convert_open_clip_checkpoint(checkpoint) - tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="tokenizer" + ) pipe = StableDiffusionPipeline( vae=vae, text_encoder=text_model, @@ -1326,7 +1557,9 @@ def extract_checkpoint( elif text_model_type == "PaintByExample": vision_model = convert_paint_by_example_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) pipe = PaintByExamplePipeline( vae=vae, image_encoder=vision_model, @@ -1338,8 +1571,12 @@ def extract_checkpoint( elif text_model_type == "FrozenCLIPEmbedder": text_model = convert_ldm_clip_checkpoint(checkpoint) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") - feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker" + ) pipe = StableDiffusionPipeline( vae=vae, text_encoder=text_model, @@ -1347,16 +1584,24 @@ def extract_checkpoint( unet=unet, scheduler=scheduler, safety_checker=safety_checker, - feature_extractor=feature_extractor + feature_extractor=feature_extractor, ) else: text_config = create_ldm_bert_config(original_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") - pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, - scheduler=scheduler) + pipe = LDMTextToImagePipeline( + vqvae=vae, + bert=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) 
except Exception: - logger.error("exception setting up output: %s", traceback.format_exception(*sys.exc_info())) + logger.error( + "exception setting up output: %s", + traceback.format_exception(*sys.exc_info()), + ) pipe = None if pipe is None or db_config is None: @@ -1371,12 +1616,18 @@ def extract_checkpoint( scheduler = db_config.scheduler required_dirs = ["unet", "vae", "text_encoder", "scheduler", "tokenizer"] if original_config_file is not None and os.path.exists(original_config_file): - logger.debug("copying original config: %s -> %s", original_config_file, db_config.model_dir) + logger.debug( + "copying original config: %s -> %s", + original_config_file, + db_config.model_dir, + ) shutil.copy(original_config_file, db_config.model_dir) basename = os.path.basename(original_config_file) new_ex_path = os.path.join(db_config.model_dir, basename) new_name = os.path.join(db_config.model_dir, f"{db_config.model_name}.yaml") - logger.debug("copying model config to new name: %s -> %s", new_ex_path, new_name) + logger.debug( + "copying model config to new name: %s -> %s", new_ex_path, new_name + ) if os.path.exists(new_name): os.remove(new_name) os.rename(new_ex_path, new_name) @@ -1407,7 +1658,9 @@ def convert_diffusion_original( source = source or model["source"] dest = os.path.join(ctx.model_path, name) - logger.info("converting original Diffusers checkpoint %s: %s -> %s", name, source, dest) + logger.info( + "converting original Diffusers checkpoint %s: %s -> %s", name, source, dest + ) if os.path.exists(dest): logger.info("ONNX pipeline already exists, skipping") @@ -1420,8 +1673,18 @@ def convert_diffusion_original( if os.path.exists(torch_path): logger.info("torch pipeline already exists, reusing: %s", torch_path) else: - logger.info("converting original Diffusers check to Torch model: %s -> %s", source, torch_path) - extract_checkpoint(ctx, torch_name, source, config_file=model.get("config"), vae_file=model.get("vae")) + logger.info( + "converting original Diffusers check to Torch model: %s -> %s", + source, + torch_path, + ) + extract_checkpoint( + ctx, + torch_name, + source, + config_file=model.get("config"), + vae_file=model.get("vae"), + ) logger.info("converted original Diffusers checkpoint to Torch model") # VAE has already been converted and will confuse HF repo lookup diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index e4cfac27..cf8deb68 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -1,24 +1,29 @@ -from os import makedirs, path -from huggingface_hub.file_download import hf_hub_download -from transformers import CLIPTokenizer, CLIPTextModel -from torch.onnx import export from logging import getLogger - -from ..utils import ConversionContext +from os import makedirs, path import torch +from huggingface_hub.file_download import hf_hub_download +from torch.onnx import export +from transformers import CLIPTextModel, CLIPTokenizer + +from ..utils import ConversionContext logger = getLogger(__name__) -def convert_diffusion_textual_inversion(context: ConversionContext, name: str, base_model: str, inversion: str): +def convert_diffusion_textual_inversion( + context: ConversionContext, name: str, base_model: str, inversion: str +): dest_path = path.join(context.model_path, f"inversion-{name}") - logger.info("converting Textual Inversion: %s + %s -> %s", base_model, inversion, dest_path) + logger.info( + "converting Textual Inversion: 
%s + %s -> %s", base_model, inversion, dest_path + ) if path.exists(dest_path): logger.info("ONNX model already exists, skipping.") + return - makedirs(path.join(dest_path, "text_encoder")) + makedirs(path.join(dest_path, "text_encoder"), exist_ok=True) embeds_file = hf_hub_download(repo_id=inversion, filename="learned_embeds.bin") token_file = hf_hub_download(repo_id=inversion, filename="token_identifier.txt") @@ -71,9 +76,7 @@ def convert_diffusion_textual_inversion(context: ConversionContext, name: str, b export( text_encoder, # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files - ( - text_input.input_ids.to(dtype=torch.int32) - ), + (text_input.input_ids.to(dtype=torch.int32)), f=path.join(dest_path, "text_encoder", "model.onnx"), input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index 4f1e5dd2..946c85de 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -17,9 +17,9 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LMSDiscreteScheduler, + OnnxRuntimeModel, PNDMScheduler, StableDiffusionPipeline, - OnnxRuntimeModel, ) try: diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index a51e3966..41d157cf 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -7,7 +7,7 @@ import { useStore } from 'zustand'; import { STALE_TIME } from '../../config.js'; import { ClientContext, StateContext } from '../../state.js'; -import { MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js'; +import { INVERSION_LABELS, MODEL_LABELS, PLATFORM_LABELS } from '../../strings.js'; import { QueryList } from '../input/QueryList.js'; export function ModelControl() { @@ -56,12 +56,13 @@ export function ModelControl() { /> result.inversion, }} + showEmpty={true} value={params.inversion} onChange={(inversion) => { setModel({ diff --git a/gui/src/components/input/QueryList.tsx b/gui/src/components/input/QueryList.tsx index c66fab43..6230cce3 100644 --- a/gui/src/components/input/QueryList.tsx +++ b/gui/src/components/input/QueryList.tsx @@ -20,6 +20,7 @@ export interface QueryListProps { value: string; query: QueryListComplete | QueryListFilter; + showEmpty?: boolean; onChange?: (value: string) => void; } @@ -28,17 +29,25 @@ export function hasFilter(query: QueryListComplete | QueryListFilter): que return Reflect.has(query, 'selector'); } -export function filterQuery(query: QueryListComplete | QueryListFilter): Array { +export function filterQuery(query: QueryListComplete | QueryListFilter, showEmpty: boolean): Array { if (hasFilter(query)) { const data = mustExist(query.result.data); - return (query as QueryListFilter).selector(data); + const selected = (query as QueryListFilter).selector(data); + if (showEmpty) { + return ['', ...selected]; + } + return selected; } else { - return mustExist(query.result.data); + const data = Array.from(mustExist(query.result.data)); + if (showEmpty) { + return ['', ...data]; + } + return data; } } export function QueryList(props: QueryListProps) { - const { labels, query, value } = props; + const { labels, query, showEmpty = false, value } = props; const { result } = query; function firstValidValue(): string { @@ -52,7 +61,7 @@ export function QueryList(props: QueryListProps) { // update state when previous selection was invalid: 
https://github.com/ssube/onnx-web/issues/120
   useEffect(() => {
     if (result.status === 'success' && doesExist(result.data) && doesExist(props.onChange)) {
-      const data = filterQuery(query);
+      const data = filterQuery(query, showEmpty);
       if (data.includes(value) === false) {
         props.onChange(data[0]);
       }
@@ -77,7 +86,7 @@ export function QueryList<T>(props: QueryListProps<T>) {
   // else: success
   const labelID = `query-list-${props.id}-labels`;
 
-  const data = filterQuery(query);
+  const data = filterQuery(query, showEmpty);
 
   return {props.name}
 
diff --git a/gui/src/strings.ts b/gui/src/strings.ts
index 86c4171b..6e6dec3e 100644
--- a/gui/src/strings.ts
+++ b/gui/src/strings.ts
@@ -32,6 +32,14 @@ export const MODEL_LABELS: Record<string, string> = {
   'diffusion-unstable-ink-dream-v6': 'Unstable Ink Dream v6',
 };
 
+export const INVERSION_LABELS: Record<string, string> = {
+  '': 'None',
+  'inversion-cubex': 'Cubex',
+  'inversion-birb': 'Birb Style',
+  'inversion-line-art': 'Line Art',
+  'inversion-minecraft': 'Minecraft Concept',
+};
+
 export const PLATFORM_LABELS: Record<string, string> = {
   amd: 'AMD GPU',
   // eslint-disable-next-line id-blacklist