diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 7c113d95..0c9edd30 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -15,6 +15,7 @@ from onnx.external_data_helper import (
 from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
+from ...server.context import ServerContext
 from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -55,6 +56,7 @@ def fix_node_name(key: str):
 
 
 def blend_loras(
+    context: ServerContext,
     base_name: str,
     lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
@@ -236,6 +238,7 @@ def blend_loras(
 
 
 if __name__ == "__main__":
+    context = ConversionContext.from_environ()
     parser = ArgumentParser()
     parser.add_argument("--base", type=str)
    parser.add_argument("--dest", type=str)
@@ -251,7 +254,9 @@ if __name__ == "__main__":
         args.lora_weights,
     )
 
-    blend_model = blend_loras(args.base, args.lora_models, args.type, args.lora_weights)
+    blend_model = blend_loras(
+        context, args.base, args.lora_models, args.type, args.lora_weights
+    )
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 3ffe8415..2644282c 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from huggingface_hub.file_download import hf_hub_download
 from onnx import ModelProto, load_model, numpy_helper, save_model
-from torch.onnx import export
 from transformers import CLIPTokenizer
 
 from ...server.context import ServerContext
@@ -25,11 +24,15 @@ def blend_textual_inversions(
     inversion_weights: Optional[List[float]] = None,
     base_tokens: Optional[List[str]] = None,
 ) -> Tuple[ModelProto, CLIPTokenizer]:
-    dtype = np.float # TODO: fixed type, which one?
-    # prev: text_encoder.get_input_embeddings().weight.dtype
+    dtype = np.float32  # the ONNX text encoder stores fp32 token embeddings
     embeds = {}
 
-    for name, format, weight, base_token in zip(inversion_names, inversion_formats, inversion_weights, base_tokens or inversion_names):
+    for name, format, weight, base_token in zip(
+        inversion_names,
+        inversion_formats,
+        inversion_weights,
+        base_tokens or inversion_names,
+    ):
         logger.info("blending Textual Inversion %s with weight of %s", name, weight)
         if format == "concept":
             embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
@@ -64,7 +67,7 @@ def blend_textual_inversions(
 
             for i in range(num_tokens):
                 token = f"{base_token or name}-{i}"
-                layer = trained_embeds[i,:].cpu().numpy().astype(dtype)
+                layer = trained_embeds[i, :].cpu().numpy().astype(dtype)
                 layer *= weight
                 if token in embeds:
                     embeds[token] += layer
@@ -74,7 +77,9 @@ def blend_textual_inversions(
             raise ValueError(f"unknown Textual Inversion format: {format}")
 
     # add the tokens to the tokenizer
-    logger.info("found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys())
+    logger.info(
+        "found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
+    )
     num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
     if num_added_tokens == 0:
         raise ValueError(
@@ -85,7 +90,11 @@ def blend_textual_inversions(
 
     # resize the token embeddings
     # text_encoder.resize_token_embeddings(len(tokenizer))
-    embedding_node = [n for n in text_encoder.graph.initializer if n.name == "text_model.embeddings.token_embedding.weight"][0]
+    embedding_node = [
+        n
+        for n in text_encoder.graph.initializer
+        if n.name == "text_model.embeddings.token_embedding.weight"
+    ][0]
     embedding_weights = numpy_helper.to_array(embedding_node)
 
     weights_dim = embedding_weights.shape[1]
@@ -94,15 +103,18 @@ def blend_textual_inversions(
 
     for token, weights in embeds.items():
         token_id = tokenizer.convert_tokens_to_ids(token)
-        logger.debug(
-            "embedding %s weights for token %s", weights.shape, token
-        )
+        logger.debug("embedding %s weights for token %s", weights.shape, token)
         embedding_weights[token_id] = weights
 
     # replace embedding_node
     for i in range(len(text_encoder.graph.initializer)):
-        if text_encoder.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-            new_initializer = numpy_helper.from_array(embedding_weights.astype(np.float32), embedding_node.name)
+        if (
+            text_encoder.graph.initializer[i].name
+            == "text_model.embeddings.token_embedding.weight"
+        ):
+            new_initializer = numpy_helper.from_array(
+                embedding_weights.astype(np.float32), embedding_node.name
+            )
             logger.debug("new initializer data type: %s", new_initializer.data_type)
             del text_encoder.graph.initializer[i]
             text_encoder.graph.initializer.insert(i, new_initializer)
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 61a14c67..7236e120 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -221,7 +221,10 @@ def load_pipeline(
             inversion_names, inversion_weights = zip(*inversions)
 
             logger.debug("blending Textual Inversions from %s", inversion_names)
-            inversion_models = [path.join(server.model_path, "inversion", f"{name}.ckpt") for name in inversion_names]
+            inversion_models = [
+                path.join(server.model_path, "inversion", f"{name}.ckpt")
+                for name in inversion_names
+            ]
             text_encoder = load_model(path.join(model, "text_encoder", "model.onnx"))
             tokenizer = CLIPTokenizer.from_pretrained(
                 model,
@@ -249,16 +252,33 @@ def load_pipeline(
         # test LoRA blending
         if loras is not None and len(loras) > 0:
             lora_names, lora_weights = zip(*loras)
-            lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
-            logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+            lora_models = [
+                path.join(server.model_path, "lora", f"{name}.safetensors")
+                for name in lora_names
+            ]
+            logger.info(
+                "blending base model %s with LoRA models: %s", model, lora_models
+            )
 
             # blend and load text encoder
-            text_encoder = text_encoder or path.join(model, "text_encoder", "model.onnx")
-            blended_text_encoder = blend_loras(text_encoder, lora_models, "text_encoder", lora_weights=lora_weights)
-            (text_encoder, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+            text_encoder = text_encoder or path.join(
+                model, "text_encoder", "model.onnx"
+            )
+            text_encoder = blend_loras(
+                server,
+                text_encoder,
+                lora_models,
+                "text_encoder",
+                lora_weights=lora_weights,
+            )
+            (text_encoder, text_encoder_data) = buffer_external_data_tensors(
+                text_encoder
+            )
             text_encoder_names, text_encoder_values = zip(*text_encoder_data)
             text_encoder_opts = SessionOptions()
-            text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+            text_encoder_opts.add_external_initializers(
+                list(text_encoder_names), list(text_encoder_values)
+            )
             components["text_encoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
                     text_encoder.SerializeToString(),
@@ -268,7 +288,13 @@ def load_pipeline(
             )
 
             # blend and load unet
-            blended_unet = blend_loras(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights)
+            blended_unet = blend_loras(
+                server,
+                path.join(model, "unet", "model.onnx"),
+                lora_models,
+                "unet",
+                lora_weights=lora_weights,
+            )
             (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
             unet_names, unet_values = zip(*unet_data)
             unet_opts = SessionOptions()
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index e1d34303..6e8c261b 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -1,6 +1,6 @@
 from logging import getLogger
 from math import ceil
-from re import compile, Pattern
+from re import Pattern, compile
 from typing import List, Optional, Tuple
 
 import numpy as np
@@ -132,7 +132,9 @@ def expand_prompt(
     return prompt_embeds
 
 
-def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, List[Tuple[str, float]]]:
+def get_tokens_from_prompt(
+    prompt: str, pattern: Pattern[str]
+) -> Tuple[str, List[Tuple[str, float]]]:
     """
     TODO: replace with Arpeggio
     """
@@ -145,7 +147,10 @@ def get_tokens_from_prompt(prompt: str, pattern: Pattern[str]) -> Tuple[str, Lis
             name, weight = next_match.groups()
             tokens.append((name, float(weight)))
             # remove this match and look for another
-            remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+            remaining_prompt = (
+                remaining_prompt[: next_match.start()]
+                + remaining_prompt[next_match.end() :]
+            )
             next_match = pattern.search(remaining_prompt)
 
     return (remaining_prompt, tokens)
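
Note on the load_pipeline changes: both the text encoder and the unet branches repeat one pattern — strip the blended initializers out of the ModelProto, hand them to onnxruntime as external initializers, and build the session from the slimmed model bytes, so the blended model never touches disk. Below is a minimal sketch of that round trip, assuming a buffer_external_data_tensors helper shaped like the one this diff calls (the project's actual implementation may differ) and a hypothetical unet/model.onnx path:

    from onnx import load_model, numpy_helper
    from onnx.external_data_helper import set_external_data
    from onnxruntime import InferenceSession, OrtValue, SessionOptions

    def buffer_external_data_tensors(model):
        # move each inline tensor payload into an OrtValue and mark the
        # initializer as external, keeping the proto itself small
        external_data = []
        for tensor in model.graph.initializer:
            if tensor.HasField("raw_data"):
                npt = numpy_helper.to_array(tensor)
                external_data.append((tensor.name, OrtValue.ortvalue_from_numpy(npt)))
                # must be called while raw_data is still present; ORT later
                # matches these tensors by name, so the location tag is nominal
                set_external_data(tensor, location="in-memory")
                tensor.ClearField("raw_data")
        return (model, external_data)

    model = load_model("unet/model.onnx")  # hypothetical path
    (bare_model, external_data) = buffer_external_data_tensors(model)
    names, values = zip(*external_data)

    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))
    session = InferenceSession(
        bare_model.SerializeToString(),
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )

Buffering the tensors out before serializing also keeps SerializeToString under the 2 GB protobuf limit, which is the usual reason this pattern shows up with unet-sized models.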