From e9b13754406929049bd30b69b63e88ce4bc000c7 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Fri, 6 Oct 2023 19:04:48 -0500
Subject: [PATCH] fix(api): write tests for embedding/inversion blending

---
 .../convert/diffusion/textual_inversion.py    | 297 +++++++++---------
 api/onnx_web/image/utils.py                   |   4 +-
 api/onnx_web/server/hacks.py                  |   8 +
 api/onnx_web/worker/worker.py                 |   8 +-
 api/tests/convert/test_utils.py               |  56 ++++
 api/tests/helpers.py                          |   3 +-
 onnx-web.code-workspace                      |   1 +
 7 files changed, 230 insertions(+), 147 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/textual_inversion.py b/api/onnx_web/convert/diffusion/textual_inversion.py
index 3eece453..0a5ed755 100644
--- a/api/onnx_web/convert/diffusion/textual_inversion.py
+++ b/api/onnx_web/convert/diffusion/textual_inversion.py
@@ -14,19 +14,155 @@ from ..utils import ConversionContext, load_tensor
 logger = getLogger(__name__)
 
 
+def detect_embedding_format(loaded_embeds) -> str:
+    keys: List[str] = list(loaded_embeds.keys())
+    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
+        logger.debug("detected Textual Inversion concept: %s", keys)
+        return "concept"
+    elif "emb_params" in keys:
+        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
+        return "parameters"
+    elif "string_to_token" in keys and "string_to_param" in keys:
+        logger.debug("detected Textual Inversion token embeddings: %s", keys)
+        return "embeddings"
+    else:
+        logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
+        return None
+
+
+def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
+    # separate token and the embeds
+    token = list(loaded_embeds.keys())[0]
+
+    layer = loaded_embeds[token].numpy().astype(dtype)
+    layer *= weight
+
+    if base_token in embeds:
+        embeds[base_token] += layer
+    else:
+        embeds[base_token] = layer
+
+    if token in embeds:
+        embeds[token] += layer
+    else:
+        embeds[token] = layer
+
+
+def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
+    emb_params = loaded_embeds["emb_params"]
+
+    num_tokens = emb_params.shape[0]
+    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
+
+    sum_layer = np.zeros(emb_params[0, :].shape)
+
+    for i in range(num_tokens):
+        token = f"{base_token}-{i}"
+        layer = emb_params[i, :].numpy().astype(dtype)
+        layer *= weight
+
+        sum_layer += layer
+        if token in embeds:
+            embeds[token] += layer
+        else:
+            embeds[token] = layer
+
+    # add base and sum tokens to embeds
+    if base_token in embeds:
+        embeds[base_token] += sum_layer
+    else:
+        embeds[base_token] = sum_layer
+
+    sum_token = f"{base_token}-all"
+    if sum_token in embeds:
+        embeds[sum_token] += sum_layer
+    else:
+        embeds[sum_token] = sum_layer
+
+
+def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
+    string_to_token = loaded_embeds["string_to_token"]
+    string_to_param = loaded_embeds["string_to_param"]
+
+    # separate token and embeds
+    token = list(string_to_token.keys())[0]
+    trained_embeds = string_to_param[token]
+
+    num_tokens = trained_embeds.shape[0]
+    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
+
+    sum_layer = np.zeros(trained_embeds[0, :].shape)
+
+    for i in range(num_tokens):
+        token = f"{base_token}-{i}"
+        layer = trained_embeds[i, :].numpy().astype(dtype)
+        layer *= weight
+
+        sum_layer += layer
+        if token in embeds:
+            embeds[token] += layer
+        else:
+            embeds[token] = layer
+
+    # add base and sum tokens to embeds
+    if base_token in embeds:
+        embeds[base_token] += sum_layer
+    else:
+        embeds[base_token] = sum_layer
+
+    sum_token = f"{base_token}-all"
+    if sum_token in embeds:
+        embeds[sum_token] += sum_layer
+    else:
+        embeds[sum_token] = sum_layer
+
+
+def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
+    # resize the token embeddings
+    # text_encoder.resize_token_embeddings(len(tokenizer))
+    embedding_node = [
+        n
+        for n in text_encoder.graph.initializer
+        if n.name == "text_model.embeddings.token_embedding.weight"
+    ][0]
+    base_weights = numpy_helper.to_array(embedding_node)
+
+    weights_dim = base_weights.shape[1]
+    zero_weights = np.zeros((num_added_tokens, weights_dim))
+    embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
+
+    for token, weights in embeds.items():
+        token_id = tokenizer.convert_tokens_to_ids(token)
+        logger.trace("embedding %s weights for token %s", weights.shape, token)
+        embedding_weights[token_id] = weights
+
+    # replace embedding_node
+    for i in range(len(text_encoder.graph.initializer)):
+        if (
+            text_encoder.graph.initializer[i].name
+            == "text_model.embeddings.token_embedding.weight"
+        ):
+            new_initializer = numpy_helper.from_array(
+                embedding_weights.astype(base_weights.dtype), embedding_node.name
+            )
+            logger.trace("new initializer data type: %s", new_initializer.data_type)
+            del text_encoder.graph.initializer[i]
+            text_encoder.graph.initializer.insert(i, new_initializer)
+
+
 @torch.no_grad()
 def blend_textual_inversions(
     server: ServerContext,
     text_encoder: ModelProto,
     tokenizer: CLIPTokenizer,
-    inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
+    embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
 ) -> Tuple[ModelProto, CLIPTokenizer]:
     # always load to CPU for blending
     device = torch.device("cpu")
     dtype = np.float32
     embeds = {}
 
-    for name, weight, base_token, inversion_format in inversions:
+    for name, weight, base_token, format in embeddings:
         if base_token is None:
             logger.debug("no base token provided, using name: %s", name)
             base_token = name
@@ -43,153 +179,28 @@ def blend_textual_inversions(
             logger.warning("unable to load tensor")
             continue
 
-        if inversion_format is None:
-            keys: List[str] = list(loaded_embeds.keys())
-            if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
-                logger.debug("detected Textual Inversion concept: %s", keys)
-                inversion_format = "concept"
-            elif "emb_params" in keys:
-                logger.debug(
-                    "detected Textual Inversion parameter embeddings: %s", keys
-                )
-                inversion_format = "parameters"
-            elif "string_to_token" in keys and "string_to_param" in keys:
-                logger.debug("detected Textual Inversion token embeddings: %s", keys)
-                inversion_format = "embeddings"
-            else:
-                logger.error(
-                    "unknown Textual Inversion format, no recognized keys: %s", keys
-                )
-                continue
+        if format is None:
+            format = detect_embedding_format(loaded_embeds)
 
-        if inversion_format == "concept":
-            # separate token and the embeds
-            token = list(loaded_embeds.keys())[0]
-
-            layer = loaded_embeds[token].numpy().astype(dtype)
-            layer *= weight
-
-            if base_token in embeds:
-                embeds[base_token] += layer
-            else:
-                embeds[base_token] = layer
-
-            if token in embeds:
-                embeds[token] += layer
-            else:
-                embeds[token] = layer
-        elif inversion_format == "parameters":
-            emb_params = loaded_embeds["emb_params"]
-
-            num_tokens = emb_params.shape[0]
-            logger.debug("generating %s layer tokens for %s", num_tokens, name)
-
-            sum_layer = np.zeros(emb_params[0, :].shape)
-
-            for i in range(num_tokens):
-                token = f"{base_token}-{i}"
-                layer = emb_params[i, :].numpy().astype(dtype)
-                layer *= weight
-
-                sum_layer += layer
-                if token in embeds:
-                    embeds[token] += layer
-                else:
-                    embeds[token] = layer
-
-            # add base and sum tokens to embeds
-            if base_token in embeds:
-                embeds[base_token] += sum_layer
-            else:
-                embeds[base_token] = sum_layer
-
-            sum_token = f"{base_token}-all"
-            if sum_token in embeds:
-                embeds[sum_token] += sum_layer
-            else:
-                embeds[sum_token] = sum_layer
-        elif inversion_format == "embeddings":
-            string_to_token = loaded_embeds["string_to_token"]
-            string_to_param = loaded_embeds["string_to_param"]
-
-            # separate token and embeds
-            token = list(string_to_token.keys())[0]
-            trained_embeds = string_to_param[token]
-
-            num_tokens = trained_embeds.shape[0]
-            logger.debug("generating %s layer tokens for %s", num_tokens, name)
-
-            sum_layer = np.zeros(trained_embeds[0, :].shape)
-
-            for i in range(num_tokens):
-                token = f"{base_token}-{i}"
-                layer = trained_embeds[i, :].numpy().astype(dtype)
-                layer *= weight
-
-                sum_layer += layer
-                if token in embeds:
-                    embeds[token] += layer
-                else:
-                    embeds[token] = layer
-
-            # add base and sum tokens to embeds
-            if base_token in embeds:
-                embeds[base_token] += sum_layer
-            else:
-                embeds[base_token] = sum_layer
-
-            sum_token = f"{base_token}-all"
-            if sum_token in embeds:
-                embeds[sum_token] += sum_layer
-            else:
-                embeds[sum_token] = sum_layer
+        if format == "concept":
+            blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
+        elif format == "parameters":
+            blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
+        elif format == "embeddings":
+            blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
         else:
-            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
+            raise ValueError(f"unknown Textual Inversion format: {format}")
 
-        # add the tokens to the tokenizer
-        logger.debug(
-            "found embeddings for %s tokens: %s",
-            len(embeds.keys()),
-            list(embeds.keys()),
+    # add the tokens to the tokenizer
+    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
+    if num_added_tokens == 0:
+        raise ValueError(
+            "The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
         )
-        num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
-            )
-        logger.trace("added %s tokens", num_added_tokens)
+    logger.trace("added %s tokens", num_added_tokens)
 
-        # resize the token embeddings
-        # text_encoder.resize_token_embeddings(len(tokenizer))
-        embedding_node = [
-            n
-            for n in text_encoder.graph.initializer
-            if n.name == "text_model.embeddings.token_embedding.weight"
-        ][0]
-        base_weights = numpy_helper.to_array(embedding_node)
-
-        weights_dim = base_weights.shape[1]
-        zero_weights = np.zeros((num_added_tokens, weights_dim))
-        embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
-
-        for token, weights in embeds.items():
-            token_id = tokenizer.convert_tokens_to_ids(token)
-            logger.trace("embedding %s weights for token %s", weights.shape, token)
-            embedding_weights[token_id] = weights
-
-        # replace embedding_node
-        for i in range(len(text_encoder.graph.initializer)):
-            if (
-                text_encoder.graph.initializer[i].name
-                == "text_model.embeddings.token_embedding.weight"
-            ):
-                new_initializer = numpy_helper.from_array(
-                    embedding_weights.astype(base_weights.dtype), embedding_node.name
-                )
-                logger.trace("new initializer data type: %s", new_initializer.data_type)
-                del text_encoder.graph.initializer[i]
-                text_encoder.graph.initializer.insert(i, new_initializer)
+    blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
 
     return (text_encoder, tokenizer)
diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py
index 80972080..4e2f3a7a 100644
--- a/api/onnx_web/image/utils.py
+++ b/api/onnx_web/image/utils.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 from PIL import Image, ImageChops
 
 from ..params import Border, Size
@@ -13,7 +15,7 @@ def expand_image(
     fill="white",
     noise_source=noise_source_histogram,
     mask_filter=mask_filter_none,
-):
+) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
     size = Size(*source.size).add_border(expand)
     size = tuple(size)
     origin = (expand.left, expand.top)
diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py
index b59bb73a..f51b51f4 100644
--- a/api/onnx_web/server/hacks.py
+++ b/api/onnx_web/server/hacks.py
@@ -2,8 +2,16 @@ import sys
 from functools import partial
 from logging import getLogger
 from os import path
+from pathlib import Path
+from typing import Dict, Optional, Union
 from urllib.parse import urlparse
 
+from optimum.onnxruntime.modeling_diffusion import (
+    ORTModel,
+    ORTStableDiffusionPipelineBase,
+)
+
+from ..torch_before_ort import SessionOptions
 from ..utils import run_gc
 from .context import ServerContext
 
diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py
index 5377c42a..a55ba4a2 100644
--- a/api/onnx_web/worker/worker.py
+++ b/api/onnx_web/worker/worker.py
@@ -27,10 +27,14 @@ MEMORY_ERRORS = [
 ]
 
 
-def worker_main(worker: WorkerContext, server: ServerContext, *args):
-    apply_patches(server)
+def worker_main(
+    worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
+):
     setproctitle("onnx-web worker: %s" % (worker.device.device))
 
+    if patch:
+        apply_patches(server)
+
     logger.trace(
         "checking in from worker with providers: %s", get_available_providers()
     )
diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py
index 45c8fccc..f08f0d0c 100644
--- a/api/tests/convert/test_utils.py
+++ b/api/tests/convert/test_utils.py
@@ -4,11 +4,19 @@ from onnx_web.convert.utils import (
     DEFAULT_OPSET,
     ConversionContext,
     download_progress,
+    remove_prefix,
+    resolve_tensor,
+    source_format,
     tuple_to_correction,
     tuple_to_diffusion,
     tuple_to_source,
     tuple_to_upscaling,
 )
+from tests.helpers import (
+    TEST_MODEL_DIFFUSION_SD15,
+    TEST_MODEL_UPSCALING_SWINIR,
+    test_needs_models,
+)
 
 
 class ConversionContextTests(unittest.TestCase):
@@ -182,3 +190,51 @@ class TupleToUpscalingTests(unittest.TestCase):
         self.assertEqual(source["scale"], 2)
         self.assertEqual(source["half"], True)
         self.assertEqual(source["opset"], 14)
+
+
+class SourceFormatTests(unittest.TestCase):
+    def test_with_format(self):
+        result = source_format({
+            "format": "foo",
+        })
+        self.assertEqual(result, "foo")
+
+    def test_source_known_extension(self):
+        result = source_format({
+            "source": "foo.safetensors",
+        })
+        self.assertEqual(result, "safetensors")
+
+    def test_source_unknown_extension(self):
+        result = source_format({
+            "source": "foo.none"
+        })
+        self.assertEqual(result, None)
+
+    def test_incomplete_model(self):
+        self.assertIsNone(source_format({}))
+
+
+class RemovePrefixTests(unittest.TestCase):
+    def test_with_prefix(self):
+        self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
+
+    def test_without_prefix(self):
+        self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
+
+
+class LoadTorchTests(unittest.TestCase):
+    pass
+
+
+class LoadTensorTests(unittest.TestCase):
+    pass
+
+
+class ResolveTensorTests(unittest.TestCase):
+    @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
+    def test_resolve_existing(self):
+        self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
+
+    def test_resolve_missing(self):
+        self.assertIsNone(resolve_tensor("missing"))
diff --git a/api/tests/helpers.py b/api/tests/helpers.py
index 586ecbd8..3b6716b2 100644
--- a/api/tests/helpers.py
+++ b/api/tests/helpers.py
@@ -13,4 +13,5 @@ def test_device() -> DeviceParams:
     return DeviceParams("cpu", "CPUExecutionProvider")
 
 
-TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
\ No newline at end of file
+TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
+TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index d3965a5f..9674ca2f 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -90,6 +90,7 @@
         "spinalcase",
         "stabilityai",
         "stringcase",
+        "swinir",
         "timestep",
         "timesteps",
         "tojson",