fix(api): write tests for embedding/inversion blending
This commit is contained in:
parent
ebdfa78737
commit
e9b1375440
|
@ -14,19 +14,155 @@ from ..utils import ConversionContext, load_tensor
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_embedding_format(loaded_embeds) -> Optional[str]:
    """
    Guess the Textual Inversion format from the top-level keys of a loaded embedding.

    Args:
        loaded_embeds: mapping-like object exposing ``keys()``.

    Returns:
        ``"concept"`` when there is exactly one key and it is wrapped in angle
        brackets, ``"parameters"`` when an ``emb_params`` key is present,
        ``"embeddings"`` when both ``string_to_token`` and ``string_to_param``
        are present, or ``None`` (after logging an error) when no pattern matches.
    """
    keys: List[str] = list(loaded_embeds.keys())

    if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
        logger.debug("detected Textual Inversion concept: %s", keys)
        return "concept"

    if "emb_params" in keys:
        logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
        return "parameters"

    if "string_to_token" in keys and "string_to_param" in keys:
        logger.debug("detected Textual Inversion token embeddings: %s", keys)
        return "embeddings"

    # nothing matched: report it and let the caller decide how to handle None
    logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend a single-token concept embedding into ``embeds``, in place.

    The first key of ``loaded_embeds`` is taken as the concept token. Its value
    is converted with ``.numpy().astype(dtype)``, scaled by ``weight``, and then
    accumulated under both ``base_token`` and the concept token.

    Args:
        embeds: dict of token -> array, updated in place.
        loaded_embeds: mapping whose first key is the concept token; the value
            must expose ``.numpy()`` (presumably a torch tensor — confirm
            against the loader).
        dtype: numpy dtype for the converted layer.
        base_token: additional token name that receives the same layer.
        weight: scalar multiplier applied to the layer.
    """
    # separate token and the embeds
    token = list(loaded_embeds.keys())[0]

    # astype() returns a new array, so scaling never touches the loaded data
    layer = loaded_embeds[token].numpy().astype(dtype)
    layer *= weight

    # NOTE(review): when base_token equals the concept token the layer is
    # accumulated twice under that key — confirm whether that is intended.
    for key in (base_token, token):
        if key in embeds:
            embeds[key] += layer
        else:
            # each key gets its own copy; sharing one array would let a later
            # in-place += on one key silently modify the other
            embeds[key] = layer.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend an ``emb_params`` embedding into ``embeds``, in place.

    Each row ``i`` of ``loaded_embeds["emb_params"]`` is converted with
    ``.numpy().astype(dtype)``, scaled by ``weight``, and accumulated under the
    token ``f"{base_token}-{i}"``. The sum of all scaled rows is accumulated
    under both ``base_token`` and ``f"{base_token}-all"``.

    Args:
        embeds: dict of token -> array, updated in place.
        loaded_embeds: mapping with an ``emb_params`` entry that supports
            ``.shape``, ``[i, :]`` indexing and ``.numpy()`` on each row
            (presumably a 2-D torch tensor — confirm against the loader).
        dtype: numpy dtype for each converted row.
        base_token: prefix for the per-row tokens and name of the summed token.
        weight: scalar multiplier applied to every row.
    """
    emb_params = loaded_embeds["emb_params"]

    num_tokens = emb_params.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(emb_params[0, :].shape)

    for i in range(num_tokens):
        token = f"{base_token}-{i}"
        layer = emb_params[i, :].numpy().astype(dtype)
        layer *= weight

        sum_layer += layer
        if token in embeds:
            embeds[token] += layer
        else:
            embeds[token] = layer

    # add base and sum tokens to embeds
    for key in (base_token, f"{base_token}-all"):
        if key in embeds:
            embeds[key] += sum_layer
        else:
            # store an independent copy: if both keys shared one array, the
            # next blend using this base token would add its sum to each of
            # them twice through the in-place += above
            embeds[key] = sum_layer.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
    """
    Blend a ``string_to_token`` / ``string_to_param`` embedding into ``embeds``, in place.

    The first key of ``loaded_embeds["string_to_token"]`` selects the trained
    embedding in ``loaded_embeds["string_to_param"]``. Each row ``i`` of that
    embedding is converted with ``.numpy().astype(dtype)``, scaled by
    ``weight``, and accumulated under ``f"{base_token}-{i}"``. The sum of all
    scaled rows is accumulated under both ``base_token`` and
    ``f"{base_token}-all"``.

    Args:
        embeds: dict of token -> array, updated in place.
        loaded_embeds: mapping with ``string_to_token`` and ``string_to_param``
            entries; the selected parameter must support ``.shape``, ``[i, :]``
            indexing and ``.numpy()`` on each row (presumably a 2-D torch
            tensor — confirm against the loader).
        dtype: numpy dtype for each converted row.
        base_token: prefix for the per-row tokens and name of the summed token.
        weight: scalar multiplier applied to every row.
    """
    string_to_token = loaded_embeds["string_to_token"]
    string_to_param = loaded_embeds["string_to_param"]

    # separate token and embeds
    trained_token = list(string_to_token.keys())[0]
    trained_embeds = string_to_param[trained_token]

    num_tokens = trained_embeds.shape[0]
    logger.debug("generating %s layer tokens for %s", num_tokens, base_token)

    sum_layer = np.zeros(trained_embeds[0, :].shape)

    for i in range(num_tokens):
        layer_token = f"{base_token}-{i}"
        layer = trained_embeds[i, :].numpy().astype(dtype)
        layer *= weight

        sum_layer += layer
        if layer_token in embeds:
            embeds[layer_token] += layer
        else:
            embeds[layer_token] = layer

    # add base and sum tokens to embeds
    for key in (base_token, f"{base_token}-all"):
        if key in embeds:
            embeds[key] += sum_layer
        else:
            # store an independent copy: if both keys shared one array, the
            # next blend using this base token would add its sum to each of
            # them twice through the in-place += above
            embeds[key] = sum_layer.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
    """
    Write the blended token weights into the text encoder's token embedding initializer.

    The initializer named ``text_model.embeddings.token_embedding.weight`` is
    read as an array, extended with ``num_added_tokens`` zero-filled rows, and
    each entry of ``embeds`` is written to the row given by
    ``tokenizer.convert_tokens_to_ids``. The result is cast back to the
    original array dtype and replaces the initializer at the same position.
    ``text_encoder`` is modified in place; nothing is returned.
    """
    node_name = "text_model.embeddings.token_embedding.weight"
    initializers = text_encoder.graph.initializer

    # locate the existing embedding table (first initializer with that name)
    embedding_node = [node for node in initializers if node.name == node_name][0]
    base_weights = numpy_helper.to_array(embedding_node)

    # grow the table by one zeroed row per token added to the tokenizer
    padding = np.zeros((num_added_tokens, base_weights.shape[1]))
    embedding_weights = np.concatenate((base_weights, padding), axis=0)

    for token, weights in embeds.items():
        token_id = tokenizer.convert_tokens_to_ids(token)
        logger.trace("embedding %s weights for token %s", weights.shape, token)
        embedding_weights[token_id] = weights

    # swap the rebuilt table in at the position of the old initializer
    for index in range(len(initializers)):
        if initializers[index].name != node_name:
            continue

        new_initializer = numpy_helper.from_array(
            embedding_weights.astype(base_weights.dtype), embedding_node.name
        )
        logger.trace("new initializer data type: %s", new_initializer.data_type)
        del initializers[index]
        initializers.insert(index, new_initializer)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def blend_textual_inversions(
|
def blend_textual_inversions(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
text_encoder: ModelProto,
|
text_encoder: ModelProto,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||||
# always load to CPU for blending
|
# always load to CPU for blending
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = np.float32
|
dtype = np.float32
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
for name, weight, base_token, inversion_format in inversions:
|
for name, weight, base_token, format in embeddings:
|
||||||
if base_token is None:
|
if base_token is None:
|
||||||
logger.debug("no base token provided, using name: %s", name)
|
logger.debug("no base token provided, using name: %s", name)
|
||||||
base_token = name
|
base_token = name
|
||||||
|
@ -43,153 +179,28 @@ def blend_textual_inversions(
|
||||||
logger.warning("unable to load tensor")
|
logger.warning("unable to load tensor")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if inversion_format is None:
|
if format is None:
|
||||||
keys: List[str] = list(loaded_embeds.keys())
|
format = detect_embedding_format()
|
||||||
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
|
||||||
logger.debug("detected Textual Inversion concept: %s", keys)
|
|
||||||
inversion_format = "concept"
|
|
||||||
elif "emb_params" in keys:
|
|
||||||
logger.debug(
|
|
||||||
"detected Textual Inversion parameter embeddings: %s", keys
|
|
||||||
)
|
|
||||||
inversion_format = "parameters"
|
|
||||||
elif "string_to_token" in keys and "string_to_param" in keys:
|
|
||||||
logger.debug("detected Textual Inversion token embeddings: %s", keys)
|
|
||||||
inversion_format = "embeddings"
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"unknown Textual Inversion format, no recognized keys: %s", keys
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if inversion_format == "concept":
|
if format == "concept":
|
||||||
# separate token and the embeds
|
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
token = list(loaded_embeds.keys())[0]
|
elif format == "parameters":
|
||||||
|
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
layer = loaded_embeds[token].numpy().astype(dtype)
|
elif format == "embeddings":
|
||||||
layer *= weight
|
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
|
||||||
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = layer
|
|
||||||
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
elif inversion_format == "parameters":
|
|
||||||
emb_params = loaded_embeds["emb_params"]
|
|
||||||
|
|
||||||
num_tokens = emb_params.shape[0]
|
|
||||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
|
||||||
|
|
||||||
sum_layer = np.zeros(emb_params[0, :].shape)
|
|
||||||
|
|
||||||
for i in range(num_tokens):
|
|
||||||
token = f"{base_token}-{i}"
|
|
||||||
layer = emb_params[i, :].numpy().astype(dtype)
|
|
||||||
layer *= weight
|
|
||||||
|
|
||||||
sum_layer += layer
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
|
|
||||||
# add base and sum tokens to embeds
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = sum_layer
|
|
||||||
|
|
||||||
sum_token = f"{base_token}-all"
|
|
||||||
if sum_token in embeds:
|
|
||||||
embeds[sum_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[sum_token] = sum_layer
|
|
||||||
elif inversion_format == "embeddings":
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
|
||||||
string_to_param = loaded_embeds["string_to_param"]
|
|
||||||
|
|
||||||
# separate token and embeds
|
|
||||||
token = list(string_to_token.keys())[0]
|
|
||||||
trained_embeds = string_to_param[token]
|
|
||||||
|
|
||||||
num_tokens = trained_embeds.shape[0]
|
|
||||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
|
||||||
|
|
||||||
sum_layer = np.zeros(trained_embeds[0, :].shape)
|
|
||||||
|
|
||||||
for i in range(num_tokens):
|
|
||||||
token = f"{base_token}-{i}"
|
|
||||||
layer = trained_embeds[i, :].numpy().astype(dtype)
|
|
||||||
layer *= weight
|
|
||||||
|
|
||||||
sum_layer += layer
|
|
||||||
if token in embeds:
|
|
||||||
embeds[token] += layer
|
|
||||||
else:
|
|
||||||
embeds[token] = layer
|
|
||||||
|
|
||||||
# add base and sum tokens to embeds
|
|
||||||
if base_token in embeds:
|
|
||||||
embeds[base_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[base_token] = sum_layer
|
|
||||||
|
|
||||||
sum_token = f"{base_token}-all"
|
|
||||||
if sum_token in embeds:
|
|
||||||
embeds[sum_token] += sum_layer
|
|
||||||
else:
|
|
||||||
embeds[sum_token] = sum_layer
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
|
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||||
|
|
||||||
# add the tokens to the tokenizer
|
# add the tokens to the tokenizer
|
||||||
logger.debug(
|
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||||
"found embeddings for %s tokens: %s",
|
if num_added_tokens == 0:
|
||||||
len(embeds.keys()),
|
raise ValueError(
|
||||||
list(embeds.keys()),
|
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
|
||||||
)
|
)
|
||||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
|
||||||
if num_added_tokens == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.trace("added %s tokens", num_added_tokens)
|
logger.trace("added %s tokens", num_added_tokens)
|
||||||
|
|
||||||
# resize the token embeddings
|
blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
|
||||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
|
||||||
embedding_node = [
|
|
||||||
n
|
|
||||||
for n in text_encoder.graph.initializer
|
|
||||||
if n.name == "text_model.embeddings.token_embedding.weight"
|
|
||||||
][0]
|
|
||||||
base_weights = numpy_helper.to_array(embedding_node)
|
|
||||||
|
|
||||||
weights_dim = base_weights.shape[1]
|
|
||||||
zero_weights = np.zeros((num_added_tokens, weights_dim))
|
|
||||||
embedding_weights = np.concatenate((base_weights, zero_weights), axis=0)
|
|
||||||
|
|
||||||
for token, weights in embeds.items():
|
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
|
||||||
logger.trace("embedding %s weights for token %s", weights.shape, token)
|
|
||||||
embedding_weights[token_id] = weights
|
|
||||||
|
|
||||||
# replace embedding_node
|
|
||||||
for i in range(len(text_encoder.graph.initializer)):
|
|
||||||
if (
|
|
||||||
text_encoder.graph.initializer[i].name
|
|
||||||
== "text_model.embeddings.token_embedding.weight"
|
|
||||||
):
|
|
||||||
new_initializer = numpy_helper.from_array(
|
|
||||||
embedding_weights.astype(base_weights.dtype), embedding_node.name
|
|
||||||
)
|
|
||||||
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
|
||||||
del text_encoder.graph.initializer[i]
|
|
||||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
|
||||||
|
|
||||||
return (text_encoder, tokenizer)
|
return (text_encoder, tokenizer)
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
from ..params import Border, Size
|
from ..params import Border, Size
|
||||||
|
@ -13,7 +15,7 @@ def expand_image(
|
||||||
fill="white",
|
fill="white",
|
||||||
noise_source=noise_source_histogram,
|
noise_source=noise_source_histogram,
|
||||||
mask_filter=mask_filter_none,
|
mask_filter=mask_filter_none,
|
||||||
):
|
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
|
||||||
size = Size(*source.size).add_border(expand)
|
size = Size(*source.size).add_border(expand)
|
||||||
size = tuple(size)
|
size = tuple(size)
|
||||||
origin = (expand.left, expand.top)
|
origin = (expand.left, expand.top)
|
||||||
|
|
|
@ -2,8 +2,16 @@ import sys
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from optimum.onnxruntime.modeling_diffusion import (
|
||||||
|
ORTModel,
|
||||||
|
ORTStableDiffusionPipelineBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..torch_before_ort import SessionOptions
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
|
|
||||||
|
|
|
@ -27,10 +27,14 @@ MEMORY_ERRORS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def worker_main(worker: WorkerContext, server: ServerContext, *args):
|
def worker_main(
|
||||||
apply_patches(server)
|
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
|
||||||
|
):
|
||||||
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
||||||
|
|
||||||
|
if patch:
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"checking in from worker with providers: %s", get_available_providers()
|
"checking in from worker with providers: %s", get_available_providers()
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,11 +4,19 @@ from onnx_web.convert.utils import (
|
||||||
DEFAULT_OPSET,
|
DEFAULT_OPSET,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
download_progress,
|
download_progress,
|
||||||
|
remove_prefix,
|
||||||
|
resolve_tensor,
|
||||||
|
source_format,
|
||||||
tuple_to_correction,
|
tuple_to_correction,
|
||||||
tuple_to_diffusion,
|
tuple_to_diffusion,
|
||||||
tuple_to_source,
|
tuple_to_source,
|
||||||
tuple_to_upscaling,
|
tuple_to_upscaling,
|
||||||
)
|
)
|
||||||
|
from tests.helpers import (
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
TEST_MODEL_UPSCALING_SWINIR,
|
||||||
|
test_needs_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversionContextTests(unittest.TestCase):
|
class ConversionContextTests(unittest.TestCase):
|
||||||
|
@ -182,3 +190,51 @@ class TupleToUpscalingTests(unittest.TestCase):
|
||||||
self.assertEqual(source["scale"], 2)
|
self.assertEqual(source["scale"], 2)
|
||||||
self.assertEqual(source["half"], True)
|
self.assertEqual(source["half"], True)
|
||||||
self.assertEqual(source["opset"], 14)
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFormatTests(unittest.TestCase):
    """Check source_format with an explicit format, file extensions, and an empty model."""

    def test_with_format(self):
        model = {"format": "foo"}
        self.assertEqual(source_format(model), "foo")

    def test_source_known_extension(self):
        model = {"source": "foo.safetensors"}
        self.assertEqual(source_format(model), "safetensors")

    def test_source_unknown_extension(self):
        model = {"source": "foo.none"}
        self.assertEqual(source_format(model), None)

    def test_incomplete_model(self):
        model = {}
        self.assertIsNone(source_format(model))
|
||||||
|
|
||||||
|
|
||||||
|
class RemovePrefixTests(unittest.TestCase):
    """Check remove_prefix with a matching and a non-matching prefix."""

    def test_with_prefix(self):
        stripped = remove_prefix("foo.bar", "foo")
        self.assertEqual(stripped, ".bar")

    def test_without_prefix(self):
        untouched = remove_prefix("foo.bar", "bin")
        self.assertEqual(untouched, "foo.bar")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTorchTests(unittest.TestCase):
    """Empty placeholder suite; no test cases are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTensorTests(unittest.TestCase):
    """Empty placeholder suite; no test cases are defined yet."""
|
||||||
|
|
||||||
|
|
||||||
|
class ResolveTensorTests(unittest.TestCase):
    """Check resolve_tensor for a cached model name and for a missing one."""

    @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
    def test_resolve_existing(self):
        resolved = resolve_tensor("../models/.cache/upscaling-swinir")
        self.assertEqual(resolved, TEST_MODEL_UPSCALING_SWINIR)

    def test_resolve_missing(self):
        resolved = resolve_tensor("missing")
        self.assertIsNone(resolved)
|
||||||
|
|
|
@ -14,3 +14,4 @@ def test_device() -> DeviceParams:
|
||||||
|
|
||||||
|
|
||||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||||
|
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
||||||
|
|
|
@ -90,6 +90,7 @@
|
||||||
"spinalcase",
|
"spinalcase",
|
||||||
"stabilityai",
|
"stabilityai",
|
||||||
"stringcase",
|
"stringcase",
|
||||||
|
"swinir",
|
||||||
"timestep",
|
"timestep",
|
||||||
"timesteps",
|
"timesteps",
|
||||||
"tojson",
|
"tojson",
|
||||||
|
|
Loading…
Reference in New Issue