
fix(api): write tests for embedding/inversion blending

Sean Sube 2023-10-06 19:04:48 -05:00
parent ebdfa78737
commit e9b1375440
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 230 additions and 147 deletions

View File

@@ -14,55 +14,23 @@ from ..utils import ConversionContext, load_tensor
 logger = getLogger(__name__)
 
-@torch.no_grad()
-def blend_textual_inversions(
-    server: ServerContext,
-    text_encoder: ModelProto,
-    tokenizer: CLIPTokenizer,
-    inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
-) -> Tuple[ModelProto, CLIPTokenizer]:
-    # always load to CPU for blending
-    device = torch.device("cpu")
-    dtype = np.float32
-    embeds = {}
-
-    for name, weight, base_token, inversion_format in inversions:
-        if base_token is None:
-            logger.debug("no base token provided, using name: %s", name)
-            base_token = name
-
-        logger.info(
-            "blending Textual Inversion %s with weight of %s for token %s",
-            name,
-            weight,
-            base_token,
-        )
-
-        loaded_embeds = load_tensor(name, map_location=device)
-        if loaded_embeds is None:
-            logger.warning("unable to load tensor")
-            continue
-
-        if inversion_format is None:
-            keys: List[str] = list(loaded_embeds.keys())
-            if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
-                logger.debug("detected Textual Inversion concept: %s", keys)
-                inversion_format = "concept"
-            elif "emb_params" in keys:
-                logger.debug(
-                    "detected Textual Inversion parameter embeddings: %s", keys
-                )
-                inversion_format = "parameters"
-            elif "string_to_token" in keys and "string_to_param" in keys:
-                logger.debug("detected Textual Inversion token embeddings: %s", keys)
-                inversion_format = "embeddings"
-            else:
-                logger.error(
-                    "unknown Textual Inversion format, no recognized keys: %s", keys
-                )
-                continue
-
-        if inversion_format == "concept":
-            # separate token and the embeds
-            token = list(loaded_embeds.keys())[0]
+def detect_embedding_format(loaded_embeds) -> str:
+    keys: List[str] = list(loaded_embeds.keys())
+    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
+        logger.debug("detected Textual Inversion concept: %s", keys)
+        return "concept"
+    elif "emb_params" in keys:
+        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
+        return "parameters"
+    elif "string_to_token" in keys and "string_to_param" in keys:
+        logger.debug("detected Textual Inversion token embeddings: %s", keys)
+        return "embeddings"
+    else:
+        logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
+        return None
+
+
+def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
+    # separate token and the embeds
+    token = list(loaded_embeds.keys())[0]
@@ -78,11 +46,13 @@ def blend_textual_inversions(
         embeds[token] += layer
     else:
         embeds[token] = layer
 
-        elif inversion_format == "parameters":
+
+def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
     emb_params = loaded_embeds["emb_params"]
     num_tokens = emb_params.shape[0]
-    logger.debug("generating %s layer tokens for %s", num_tokens, name)
+    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
 
     sum_layer = np.zeros(emb_params[0, :].shape)
@@ -108,7 +78,9 @@ def blend_textual_inversions(
         embeds[sum_token] += sum_layer
     else:
         embeds[sum_token] = sum_layer
 
-        elif inversion_format == "embeddings":
+
+def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
     string_to_token = loaded_embeds["string_to_token"]
     string_to_param = loaded_embeds["string_to_param"]
@@ -117,7 +89,7 @@ def blend_textual_inversions(
     trained_embeds = string_to_param[token]
     num_tokens = trained_embeds.shape[0]
-    logger.debug("generating %s layer tokens for %s", num_tokens, name)
+    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
 
     sum_layer = np.zeros(trained_embeds[0, :].shape)
@@ -143,23 +115,9 @@ def blend_textual_inversions(
         embeds[sum_token] += sum_layer
     else:
         embeds[sum_token] = sum_layer
 
-        else:
-            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
-
-    # add the tokens to the tokenizer
-    logger.debug(
-        "found embeddings for %s tokens: %s",
-        len(embeds.keys()),
-        list(embeds.keys()),
-    )
-    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
-    if num_added_tokens == 0:
-        raise ValueError(
-            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
-        )
-
-    logger.trace("added %s tokens", num_added_tokens)
+
+def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
     # resize the token embeddings
     # text_encoder.resize_token_embeddings(len(tokenizer))
     embedding_node = [
@@ -191,6 +149,59 @@ def blend_textual_inversions(
         del text_encoder.graph.initializer[i]
         text_encoder.graph.initializer.insert(i, new_initializer)
 
+
+@torch.no_grad()
+def blend_textual_inversions(
+    server: ServerContext,
+    text_encoder: ModelProto,
+    tokenizer: CLIPTokenizer,
+    embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
+) -> Tuple[ModelProto, CLIPTokenizer]:
+    # always load to CPU for blending
+    device = torch.device("cpu")
+    dtype = np.float32
+    embeds = {}
+
+    for name, weight, base_token, format in embeddings:
+        if base_token is None:
+            logger.debug("no base token provided, using name: %s", name)
+            base_token = name
+
+        logger.info(
+            "blending Textual Inversion %s with weight of %s for token %s",
+            name,
+            weight,
+            base_token,
+        )
+
+        loaded_embeds = load_tensor(name, map_location=device)
+        if loaded_embeds is None:
+            logger.warning("unable to load tensor")
+            continue
+
+        if format is None:
+            format = detect_embedding_format(loaded_embeds)
+
+        if format == "concept":
+            blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
+        elif format == "parameters":
+            blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
+        elif format == "embeddings":
+            blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
+        else:
+            raise ValueError(f"unknown Textual Inversion format: {format}")
+
+    # add the tokens to the tokenizer
+    num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
+    if num_added_tokens == 0:
+        raise ValueError(
+            "The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
+        )
+
+    logger.trace("added %s tokens", num_added_tokens)
+
+    blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
+
     return (text_encoder, tokenizer)
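
For reference, a minimal sketch of the three checkpoint layouts that detect_embedding_format distinguishes. The tensor shapes and token names below are hypothetical placeholders, and the function is assumed to be imported from the module above:

import torch

# "concept": a single <token> key mapped to one embedding tensor
concept = {"<my-style>": torch.zeros(768)}

# "parameters": an emb_params tensor with one row per layer token
parameters = {"emb_params": torch.zeros(2, 768)}

# "embeddings": parallel string_to_token / string_to_param mappings
embeddings = {
    "string_to_token": {"*": 265},
    "string_to_param": {"*": torch.zeros(2, 768)},
}

assert detect_embedding_format(concept) == "concept"
assert detect_embedding_format(parameters) == "parameters"
assert detect_embedding_format(embeddings) == "embeddings"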

View File

@@ -1,3 +1,5 @@
+from typing import Tuple
+
 from PIL import Image, ImageChops
 
 from ..params import Border, Size
@@ -13,7 +15,7 @@ def expand_image(
     fill="white",
     noise_source=noise_source_histogram,
     mask_filter=mask_filter_none,
-):
+) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
     size = Size(*source.size).add_border(expand)
     size = tuple(size)
     origin = (expand.left, expand.top)
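
The new annotation makes the four-part return value explicit for callers. A hedged usage sketch, assuming the leading positional parameters are the source image, mask, and Border, and that Border takes left, right, top, and bottom widths (neither is shown in this diff):

from PIL import Image

source = Image.new("RGB", (512, 512))
mask = Image.new("L", (512, 512))

# expand the canvas by 64px on every side, then unpack the annotated tuple
full_source, full_mask, full_noise, size = expand_image(
    source, mask, Border(64, 64, 64, 64)
)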

View File

@@ -2,8 +2,16 @@ import sys
 from functools import partial
 from logging import getLogger
 from os import path
+from pathlib import Path
+from typing import Dict, Optional, Union
 from urllib.parse import urlparse
 
+from optimum.onnxruntime.modeling_diffusion import (
+    ORTModel,
+    ORTStableDiffusionPipelineBase,
+)
+
+from ..torch_before_ort import SessionOptions
 from ..utils import run_gc
 from .context import ServerContext
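
These imports suggest the server patches optimum's ONNX pipeline loading so that its own SessionOptions are applied. A hypothetical sketch of that monkey-patching pattern; the wrapped method and its session_options keyword are assumptions, not taken from this diff:

# keep a reference to the original loader before replacing it
original_load_model = ORTModel.load_model

def patched_load_model(*args, **kwargs):
    # inject server-level SessionOptions when the caller passed none
    kwargs.setdefault("session_options", SessionOptions())
    return original_load_model(*args, **kwargs)

ORTModel.load_model = patched_load_model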

View File

@@ -27,10 +27,14 @@ MEMORY_ERRORS = [
 ]
 
-def worker_main(worker: WorkerContext, server: ServerContext, *args):
-    apply_patches(server)
+def worker_main(
+    worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
+):
     setproctitle("onnx-web worker: %s" % (worker.device.device))
 
+    if patch:
+        apply_patches(server)
+
     logger.trace(
         "checking in from worker with providers: %s", get_available_providers()
     )
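
Threading exit and patch through as keyword arguments lets tests drive the worker loop without touching globals or killing the test process. A minimal sketch, where worker and server stand in for test doubles defined elsewhere:

exit_codes = []

def fake_exit(code):
    # record the exit code instead of terminating the test process
    exit_codes.append(code)

worker_main(worker, server, exit=fake_exit, patch=False)
assert len(exit_codes) == 1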

View File

@@ -4,11 +4,19 @@ from onnx_web.convert.utils import (
     DEFAULT_OPSET,
     ConversionContext,
     download_progress,
+    remove_prefix,
+    resolve_tensor,
+    source_format,
     tuple_to_correction,
     tuple_to_diffusion,
     tuple_to_source,
     tuple_to_upscaling,
 )
+from tests.helpers import (
+    TEST_MODEL_DIFFUSION_SD15,
+    TEST_MODEL_UPSCALING_SWINIR,
+    test_needs_models,
+)
 
 class ConversionContextTests(unittest.TestCase):
@@ -182,3 +190,51 @@ class TupleToUpscalingTests(unittest.TestCase):
         self.assertEqual(source["scale"], 2)
         self.assertEqual(source["half"], True)
         self.assertEqual(source["opset"], 14)
+
+
+class SourceFormatTests(unittest.TestCase):
+    def test_with_format(self):
+        result = source_format({
+            "format": "foo",
+        })
+        self.assertEqual(result, "foo")
+
+    def test_source_known_extension(self):
+        result = source_format({
+            "source": "foo.safetensors",
+        })
+        self.assertEqual(result, "safetensors")
+
+    def test_source_unknown_extension(self):
+        result = source_format({
+            "source": "foo.none"
+        })
+        self.assertEqual(result, None)
+
+    def test_incomplete_model(self):
+        self.assertIsNone(source_format({}))
+
+
+class RemovePrefixTests(unittest.TestCase):
+    def test_with_prefix(self):
+        self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
+
+    def test_without_prefix(self):
+        self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
+
+
+class LoadTorchTests(unittest.TestCase):
+    pass
+
+
+class LoadTensorTests(unittest.TestCase):
+    pass
+
+
+class ResolveTensorTests(unittest.TestCase):
+    @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
+    def test_resolve_existing(self):
+        self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
+
+    def test_resolve_missing(self):
+        self.assertIsNone(resolve_tensor("missing"))

View File

@@ -14,3 +14,4 @@ def test_device() -> DeviceParams:
 
 TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
+TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"

View File

@@ -90,6 +90,7 @@
     "spinalcase",
     "stabilityai",
     "stringcase",
+    "swinir",
     "timestep",
     "timesteps",
     "tojson",