1
0
Fork 0

make blend functions take tuples rather than split lists

This commit is contained in:
Sean Sube 2023-03-18 10:50:48 -05:00
parent 6cd0b4f7eb
commit 1f6105a8fe
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 64 additions and 56 deletions

View File

@ -289,7 +289,7 @@ def convert_models(ctx: ConversionContext, args, models: Models):
inversion_name = inversion["name"] inversion_name = inversion["name"]
inversion_source = inversion["source"] inversion_source = inversion["source"]
inversion_format = inversion.get("format", "embeddings") inversion_format = inversion.get("format", None)
inversion_source = fetch_model( inversion_source = fetch_model(
ctx, ctx,
f"{name}-inversion-{inversion_name}", f"{name}-inversion-{inversion_name}",
@ -303,10 +303,14 @@ def convert_models(ctx: ConversionContext, args, models: Models):
ctx, ctx,
blend_models["text_encoder"], blend_models["text_encoder"],
blend_models["tokenizer"], blend_models["tokenizer"],
[inversion_source], [
[inversion_format], (
base_token=inversion_token, inversion_source,
inversion_weights=[inversion_weight], inversion_weight,
inversion_token,
inversion_format,
)
],
) )
for lora in model.get("loras", []): for lora in model.get("loras", []):

View File

@ -1,7 +1,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Dict, List, Literal, Tuple from typing import Dict, List, Literal, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -57,25 +57,23 @@ def fix_node_name(key: str):
def blend_loras( def blend_loras(
context: ServerContext, context: ServerContext,
base_name: str, base_name: Union[str, ModelProto],
lora_names: List[str], loras: List[Tuple[str, float]],
dest_type: Literal["text_encoder", "unet"], model_type: Literal["text_encoder", "unet"],
lora_weights: "np.NDArray[np.float64]" = None,
): ):
base_model = base_name if isinstance(base_name, ModelProto) else load(base_name) base_model = base_name if isinstance(base_name, ModelProto) else load(base_name)
lora_models = [load_file(name) for name in lora_names]
lora_count = len(lora_models) lora_count = len(loras)
lora_models = [load_file(name) for name, _weight in loras]
lora_weights = lora_weights or (np.ones((lora_count)) / lora_count) lora_weights = lora_weights or (np.ones((lora_count)) / lora_count)
if dest_type == "text_encoder": if model_type == "text_encoder":
lora_prefix = "lora_te_" lora_prefix = "lora_te_"
else: else:
lora_prefix = f"lora_{dest_type}_" lora_prefix = f"lora_{model_type}_"
blended: Dict[str, np.ndarray] = {} blended: Dict[str, np.ndarray] = {}
for lora_name, lora_model, lora_weight in zip( for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
lora_names, lora_models, lora_weights
):
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight) logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
for key in lora_model.keys(): for key in lora_model.keys():
if ".lora_down" in key and lora_prefix in key: if ".lora_down" in key and lora_prefix in key:
@ -254,8 +252,8 @@ if __name__ == "__main__":
parser.add_argument("--base", type=str) parser.add_argument("--base", type=str)
parser.add_argument("--dest", type=str) parser.add_argument("--dest", type=str)
parser.add_argument("--type", type=str, choices=["text_encoder", "unet"]) parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
parser.add_argument("--lora_models", nargs="+", type=str) parser.add_argument("--lora_models", nargs="+", type=str, default=[])
parser.add_argument("--lora_weights", nargs="+", type=float) parser.add_argument("--lora_weights", nargs="+", type=float, default=[])
args = parser.parse_args() args = parser.parse_args()
logger.info( logger.info(
@ -265,10 +263,17 @@ if __name__ == "__main__":
args.lora_weights, args.lora_weights,
) )
default_weight = 1.0 / len(args.lora_models)
while len(args.lora_weights) < len(args.lora_models):
args.lora_weights.append(default_weight)
blend_model = blend_loras( blend_model = blend_loras(
context, args.base, args.lora_models, args.type, args.lora_weights context,
args.base,
list(zip(args.lora_models, args.lora_weights)),
args.type,
) )
if args.dest is None or args.dest == "" or args.dest == "ort": if args.dest is None or args.dest == "" or args.dest == ":load":
# convert to external data and save to memory # convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(blend_model) (bare_model, external_data) = buffer_external_data_tensors(blend_model)
logger.info("saved external data for %s nodes", len(external_data)) logger.info("saved external data for %s nodes", len(external_data))

View File

@ -17,24 +17,30 @@ logger = getLogger(__name__)
@torch.no_grad() @torch.no_grad()
def blend_textual_inversions( def blend_textual_inversions(
context: ServerContext, context: ServerContext,
text_encoder: Optional[ModelProto], text_encoder: ModelProto,
tokenizer: Optional[CLIPTokenizer], tokenizer: CLIPTokenizer,
inversion_names: List[str], inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
inversion_formats: List[str],
inversion_weights: Optional[List[float]] = None,
base_tokens: Optional[List[str]] = None,
) -> Tuple[ModelProto, CLIPTokenizer]: ) -> Tuple[ModelProto, CLIPTokenizer]:
dtype = np.float dtype = np.float
embeds = {} embeds = {}
for name, format, weight, base_token in zip( for name, weight, base_token, format in inversions:
inversion_names, if base_token is None:
inversion_formats, base_token = name
inversion_weights,
base_tokens or inversion_names, if format is None:
): # TODO: detect concept format
logger.info("blending Textual Inversion %s with weight of %s", name, weight) format = "embeddings"
logger.info(
"blending Textual Inversion %s with weight of %s for token %s",
name,
weight,
base_token,
)
if format == "concept": if format == "concept":
# TODO: this should be done in fetch, maybe
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin") embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
@ -68,9 +74,10 @@ def blend_textual_inversions(
sum_layer = np.zeros(trained_embeds[0, :].shape) sum_layer = np.zeros(trained_embeds[0, :].shape)
for i in range(num_tokens): for i in range(num_tokens):
token = f"{base_token or name}-{i}" token = f"{base_token}-{i}"
layer = trained_embeds[i, :].cpu().numpy().astype(dtype) layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
layer *= weight layer *= weight
sum_layer += layer sum_layer += layer
if token in embeds: if token in embeds:
embeds[token] += layer embeds[token] += layer
@ -78,7 +85,7 @@ def blend_textual_inversions(
embeds[token] = layer embeds[token] = layer
# add sum layer to embeds # add sum layer to embeds
sum_token = f"{base_token or name}-all" sum_token = f"{base_token}-all"
if sum_token in embeds: if sum_token in embeds:
embeds[sum_token] += sum_layer embeds[sum_token] += sum_layer
else: else:
@ -170,19 +177,16 @@ def convert_diffusion_textual_inversion(
context, context,
text_encoder, text_encoder,
tokenizer, tokenizer,
[inversion], [(inversion, weight, base_token, format)],
[format],
[weight],
base_token=(base_token or name),
) )
logger.info("saving tokenizer for textual inversion") logger.info("saving tokenizer for Textual Inversion")
tokenizer.save_pretrained(tokenizer_path) tokenizer.save_pretrained(tokenizer_path)
logger.info("saving text encoder for textual inversion") logger.info("saving text encoder for Textual Inversion")
save_model( save_model(
text_encoder, text_encoder,
f=encoder_model, f=encoder_model,
) )
logger.info("textual inversion saved to %s", dest_path) logger.info("Textual Inversion saved to %s", dest_path)

View File

@ -191,7 +191,7 @@ def load_yaml(file: str) -> Config:
return Config(data) return Config(data)
def remove_prefix(name, prefix): def remove_prefix(name: str, prefix: str) -> str:
if name.startswith(prefix): if name.startswith(prefix):
return name[len(prefix) :] return name[len(prefix) :]

View File

@ -234,10 +234,7 @@ def load_pipeline(
server, server,
text_encoder, text_encoder,
tokenizer, tokenizer,
inversion_models, list(zip(inversion_models, inversion_weights, inversion_names)),
["embeddings"] * len(inversion_names),
inversion_weights,
base_tokens=inversion_names,
) )
# should be pretty small and should not need external data # should be pretty small and should not need external data
@ -267,9 +264,8 @@ def load_pipeline(
text_encoder = blend_loras( text_encoder = blend_loras(
server, server,
text_encoder, text_encoder,
lora_models, list(zip(lora_models, lora_weights)),
"text_encoder", "text_encoder",
lora_weights=lora_weights,
) )
(text_encoder, text_encoder_data) = buffer_external_data_tensors( (text_encoder, text_encoder_data) = buffer_external_data_tensors(
text_encoder text_encoder
@ -291,9 +287,8 @@ def load_pipeline(
blended_unet = blend_loras( blended_unet = blend_loras(
server, server,
path.join(model, "unet", "model.onnx"), path.join(model, "unet", "model.onnx"),
lora_models, list(zip(lora_models, lora_weights)),
"unet", "unet",
lora_weights=lora_weights,
) )
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet) (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data) unet_names, unet_values = zip(*unet_data)

View File

@ -255,8 +255,8 @@ The base token, without any layer numbers, should be printed to the logs with th
```none ```none
[2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token <concept>: torch.Size([768]) [2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token <concept>: torch.Size([768])
[2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens [2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens
[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for textual inversion [2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for Textual Inversion
[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for textual inversion [2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for Textual Inversion
``` ```
If you have set a custom token, that will be shown instead. If more than one token has been added, they will be If you have set a custom token, that will be shown instead. If more than one token has been added, they will be
@ -272,8 +272,8 @@ The number of layers is shown in the server logs when the model is converted:
```none ```none
[2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token <concept>: torch.Size([768]) [2023-03-08 04:54:00,234] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token <concept>: torch.Size([768])
[2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens [2023-03-08 04:54:01,624] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: added 1 tokens
[2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for textual inversion [2023-03-08 04:54:01,814] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving tokenizer for Textual Inversion
[2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for textual inversion [2023-03-08 04:54:01,859] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: saving text encoder for Textual Inversion
... ...
[2023-03-08 04:58:06,378] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: generating 74 layer tokens [2023-03-08 04:58:06,378] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: generating 74 layer tokens
[2023-03-08 04:58:06,379] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token ['goblin-0', 'goblin-1', 'goblin-2', 'goblin-3', 'goblin-4', 'goblin-5', 'goblin-6', 'gob [2023-03-08 04:58:06,379] INFO: MainProcess MainThread onnx_web.convert.diffusion.textual_inversion: found embedding for token ['goblin-0', 'goblin-1', 'goblin-2', 'goblin-3', 'goblin-4', 'goblin-5', 'goblin-6', 'gob