From 1f6105a8fe861f82c44b3e4f6132ae959aacb678 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sat, 18 Mar 2023 10:50:48 -0500
Subject: [PATCH] make blend functions take tuples rather than split lists

---
 api/onnx_web/convert/__main__.py               | 14 ++++--
 api/onnx_web/convert/diffusion/lora.py         | 37 +++++++-------
 .../convert/diffusion/textual_inversion.py     | 48 ++++++++++---------
 api/onnx_web/convert/utils.py                  |  2 +-
 api/onnx_web/diffusers/load.py                 | 11 ++---
 docs/converting-models.md                      |  8 ++--
 6 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index ac6ba257..a1b4c4fb 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -289,7 +289,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
             inversion_name = inversion["name"]
             inversion_source = inversion["source"]
-            inversion_format = inversion.get("format", "embeddings")
+            inversion_format = inversion.get("format", None)
             inversion_source = fetch_model(
                 ctx,
                 f"{name}-inversion-{inversion_name}",
@@ -303,10 +303,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
                 ctx,
                 blend_models["text_encoder"],
                 blend_models["tokenizer"],
-                [inversion_source],
-                [inversion_format],
-                base_token=inversion_token,
-                inversion_weights=[inversion_weight],
+                [
+                    (
+                        inversion_source,
+                        inversion_weight,
+                        inversion_token,
+                        inversion_format,
+                    )
+                ],
             )
 
         for lora in model.get("loras", []):
diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 6c869a18..3754c538 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -1,7 +1,7 @@
 from argparse import ArgumentParser
 from logging import getLogger
 from os import path
-from typing import Dict, List, Literal, Tuple
+from typing import Dict, List, Literal, Tuple, Union
 
 import numpy as np
 import torch
@@ -57,25 +57,22 @@ def fix_node_name(key: str):
 
 def blend_loras(
     context: ServerContext,
-    base_name: str,
-    lora_names: List[str],
-    dest_type: Literal["text_encoder", "unet"],
-    lora_weights: "np.NDArray[np.float64]" = None,
+    base_name: Union[str, ModelProto],
+    loras: List[Tuple[str, float]],
+    model_type: Literal["text_encoder", "unet"],
 ):
     base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
-    lora_models = [load_file(name) for name in lora_names]
-    lora_count = len(lora_models)
-    lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
+
+    lora_count = len(loras)
+    lora_models = [load_file(name) for name, _weight in loras]
 
-    if dest_type == "text_encoder":
+    if model_type == "text_encoder":
         lora_prefix = "lora_te_"
     else:
-        lora_prefix = f"lora_{dest_type}_"
+        lora_prefix = f"lora_{model_type}_"
 
     blended: Dict[str, np.ndarray] = {}
-    for lora_name, lora_model, lora_weight in zip(
-        lora_names, lora_models, lora_weights
-    ):
+    for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
         logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         for key in lora_model.keys():
             if ".lora_down" in key and lora_prefix in key:
@@ -254,8 +252,8 @@ if __name__ == "__main__":
     parser.add_argument("--base", type=str)
     parser.add_argument("--dest", type=str)
     parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
-    parser.add_argument("--lora_models", nargs="+", type=str)
-    parser.add_argument("--lora_weights", nargs="+", type=float)
+    parser.add_argument("--lora_models", nargs="+", type=str, default=[])
+    parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float, default=[]) args = parser.parse_args() logger.info( @@ -265,10 +263,17 @@ if __name__ == "__main__": args.lora_weights, ) + default_weight = 1.0 / len(args.lora_models) + while len(args.lora_weights) < len(args.lora_models): + args.lora_weights.append(default_weight) + blend_model = blend_loras( - context, args.base, args.lora_models, args.type, args.lora_weights + context, + args.base, + list(zip(args.lora_models, args.lora_weights)), + args.type, ) - if args.dest is None or args.dest == "" or args.dest == "ort": + if args.dest is None or args.dest == "" or args.dest == ":load": # convert to external data and save to memory (bare_model, external_data) = buffer_external_data_tensors(blend_model) logger.info("saved external data for %s nodes", len(external_data)) diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py index a13f5203..3fe67ab7 100644 --- a/api/onnx_web/convert/diffusion/textual_inversion.py +++ b/api/onnx_web/convert/diffusion/textual_inversion.py @@ -17,24 +17,30 @@ logger = getLogger(__name__) @torch.no_grad() def blend_textual_inversions( context: ServerContext, - text_encoder: Optional[ModelProto], - tokenizer: Optional[CLIPTokenizer], - inversion_names: List[str], - inversion_formats: List[str], - inversion_weights: Optional[List[float]] = None, - base_tokens: Optional[List[str]] = None, + text_encoder: ModelProto, + tokenizer: CLIPTokenizer, + inversions: List[Tuple[str, float, Optional[str], Optional[str]]], ) -> Tuple[ModelProto, CLIPTokenizer]: dtype = np.float embeds = {} - for name, format, weight, base_token in zip( - inversion_names, - inversion_formats, - inversion_weights, - base_tokens or inversion_names, - ): - logger.info("blending Textual Inversion %s with weight of %s", name, weight) + for name, weight, base_token, format in inversions: + if base_token is None: + base_token = name + + if format is None: + # TODO: detect concept format + format = "embeddings" + + logger.info( + "blending Textual Inversion %s with weight of %s for token %s", + name, + weight, + base_token, + ) + if format == "concept": + # TODO: this should be done in fetch, maybe embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin") token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") @@ -68,9 +74,10 @@ def blend_textual_inversions( sum_layer = np.zeros(trained_embeds[0, :].shape) for i in range(num_tokens): - token = f"{base_token or name}-{i}" + token = f"{base_token}-{i}" layer = trained_embeds[i, :].cpu().numpy().astype(dtype) layer *= weight + sum_layer += layer if token in embeds: embeds[token] += layer @@ -78,7 +85,7 @@ def blend_textual_inversions( embeds[token] = layer # add sum layer to embeds - sum_token = f"{base_token or name}-all" + sum_token = f"{base_token}-all" if sum_token in embeds: embeds[sum_token] += sum_layer else: @@ -170,19 +177,16 @@ def convert_diffusion_textual_inversion( context, text_encoder, tokenizer, - [inversion], - [format], - [weight], - base_token=(base_token or name), + [(inversion, weight, base_token, format)], ) - logger.info("saving tokenizer for textual inversion") + logger.info("saving tokenizer for Textual Inversion") tokenizer.save_pretrained(tokenizer_path) - logger.info("saving text encoder for textual inversion") + logger.info("saving text encoder for Textual Inversion") save_model( text_encoder, f=encoder_model, ) - logger.info("textual inversion saved to %s", dest_path) + 
logger.info("Textual Inversion saved to %s", dest_path) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index a06d6126..1cac9b9e 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -191,7 +191,7 @@ def load_yaml(file: str) -> Config: return Config(data) -def remove_prefix(name, prefix): +def remove_prefix(name: str, prefix: str) -> str: if name.startswith(prefix): return name[len(prefix) :] diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 7236e120..857874e6 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -234,10 +234,7 @@ def load_pipeline( server, text_encoder, tokenizer, - inversion_models, - ["embeddings"] * len(inversion_names), - inversion_weights, - base_tokens=inversion_names, + list(zip(inversion_models, inversion_weights, inversion_names)), ) # should be pretty small and should not need external data @@ -267,9 +264,8 @@ def load_pipeline( text_encoder = blend_loras( server, text_encoder, - lora_models, + list(zip(lora_models, lora_weights)), "text_encoder", - lora_weights=lora_weights, ) (text_encoder, text_encoder_data) = buffer_external_data_tensors( text_encoder @@ -291,9 +287,8 @@ def load_pipeline( blended_unet = blend_loras( server, path.join(model, "unet", "model.onnx"), - lora_models, + list(zip(lora_models, lora_weights)), "unet", - lora_weights=lora_weights, ) (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) unet_names, unet_values = zip(*unet_data) diff --git a/docs/converting-models.md b/docs/converting-models.md index 646204ba..65250f64 100644 --- a/docs/converting-models.md +++ b/docs/converting-models.md @@ -255,8 +255,8 @@ The base token, without any layer numbers, should be printed to the logs with th ```none [2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token : torch.Size([768]) [2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens -[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for textual inversion -[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for textual inversion +[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for Textual Inversion +[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for Textual Inversion ``` If you have set a custom token, that will be shown instead. 
@@ -272,8 +272,8 @@ The number of layers is shown in the server logs when the model is converted:
 
 ```none
 [2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token : torch.Size([768])
 [2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens
-[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for textual inversion
-[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for textual inversion
+[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for Textual Inversion
+[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for Textual Inversion
 ...
 [2023-03-08 04:58:06,378] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: generating 74 layer tokens
 [2023-03-08 04:58:06,379] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token ['goblin-0', 'goblin-1', 'goblin-2', 'goblin-3', 'goblin-4', 'goblin-5', 'goblin-6', 'gob
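
For reference, a minimal sketch of calling the blend functions with the new tuple-based signatures from this patch. The model paths, weights, and the `goblin` token below are hypothetical, and `server`, `text_encoder`, and `tokenizer` are assumed to have been created by the surrounding conversion code, as in `load_pipeline` above:

```python
# Sketch of the tuple-based API introduced by this patch; the paths, weights,
# and token below are hypothetical examples, not files from the repository.
from onnx_web.convert.diffusion.lora import blend_loras
from onnx_web.convert.diffusion.textual_inversion import blend_textual_inversions

# LoRAs are now passed as (path, weight) tuples instead of parallel lists.
blended_unet = blend_loras(
    server,  # a ServerContext, assumed to be set up by the caller
    "models/stable-diffusion/unet/model.onnx",  # hypothetical base model path
    [("loras/style-a.safetensors", 0.75), ("loras/style-b.safetensors", 0.25)],
    "unet",
)

# Textual Inversions are now (source, weight, base_token, format) tuples.
# base_token and format may be None: the name is reused as the token and the
# format falls back to "embeddings".
text_encoder, tokenizer = blend_textual_inversions(
    server,
    text_encoder,  # an onnx.ModelProto, loaded by the caller
    tokenizer,  # a transformers.CLIPTokenizer, loaded by the caller
    [("embeddings/goblin.bin", 1.0, "goblin", "embeddings")],
)
```

The CLI entry point in `lora.py` builds the same tuples from `--lora_models` and `--lora_weights`, padding the weights with an even `1.0 / len(models)` default when fewer weights than models are given.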