
fix(api): write tests for embedding/inversion blending

Sean Sube 2023-10-06 19:04:48 -05:00
parent ebdfa78737
commit e9b1375440
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 230 additions and 147 deletions


@@ -14,55 +14,23 @@ from ..utils import ConversionContext, load_tensor
logger = getLogger(__name__)
@torch.no_grad()
def blend_textual_inversions(
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float32
embeds = {}
for name, weight, base_token, inversion_format in inversions:
if base_token is None:
logger.debug("no base token provided, using name: %s", name)
base_token = name
logger.info(
"blending Textual Inversion %s with weight of %s for token %s",
name,
weight,
base_token,
)
loaded_embeds = load_tensor(name, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
if inversion_format is None:
def detect_embedding_format(loaded_embeds) -> str:
keys: List[str] = list(loaded_embeds.keys())
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys)
inversion_format = "concept"
return "concept"
elif "emb_params" in keys:
logger.debug(
"detected Textual Inversion parameter embeddings: %s", keys
)
inversion_format = "parameters"
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
return "parameters"
elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion token embeddings: %s", keys)
inversion_format = "embeddings"
return "embeddings"
else:
logger.error(
"unknown Textual Inversion format, no recognized keys: %s", keys
)
continue
logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
return None
if inversion_format == "concept":
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
# separate token and the embeds
token = list(loaded_embeds.keys())[0]
@@ -78,11 +46,13 @@ def blend_textual_inversions(
embeds[token] += layer
else:
embeds[token] = layer
elif inversion_format == "parameters":
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
emb_params = loaded_embeds["emb_params"]
num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(emb_params[0, :].shape)
@@ -108,7 +78,9 @@ def blend_textual_inversions(
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
elif inversion_format == "embeddings":
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"]
@@ -117,7 +89,7 @@ def blend_textual_inversions(
trained_embeds = string_to_param[token]
num_tokens = trained_embeds.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
sum_layer = np.zeros(trained_embeds[0, :].shape)
@@ -143,23 +115,9 @@ def blend_textual_inversions(
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
else:
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
# add the tokens to the tokenizer
logger.debug(
"found embeddings for %s tokens: %s",
len(embeds.keys()),
list(embeds.keys()),
)
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
)
logger.trace("added %s tokens", num_added_tokens)
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
# resize the token embeddings
# text_encoder.resize_token_embeddings(len(tokenizer))
embedding_node = [
@@ -191,6 +149,59 @@ def blend_textual_inversions(
del text_encoder.graph.initializer[i]
text_encoder.graph.initializer.insert(i, new_initializer)
@torch.no_grad()
def blend_textual_inversions(
server: ServerContext,
text_encoder: ModelProto,
tokenizer: CLIPTokenizer,
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
) -> Tuple[ModelProto, CLIPTokenizer]:
# always load to CPU for blending
device = torch.device("cpu")
dtype = np.float32
embeds = {}
for name, weight, base_token, format in embeddings:
if base_token is None:
logger.debug("no base token provided, using name: %s", name)
base_token = name
logger.info(
"blending Textual Inversion %s with weight of %s for token %s",
name,
weight,
base_token,
)
loaded_embeds = load_tensor(name, map_location=device)
if loaded_embeds is None:
logger.warning("unable to load tensor")
continue
if format is None:
format = detect_embedding_format(loaded_embeds)
if format == "concept":
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "parameters":
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
elif format == "embeddings":
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
else:
raise ValueError(f"unknown Textual Inversion format: {format}")
# add the tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
if num_added_tokens == 0:
raise ValueError(
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
)
logger.trace("added %s tokens", num_added_tokens)
blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
return (text_encoder, tokenizer)
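
With format detection and the per-format blending split into helpers, each piece can now be tested without building a full text encoder. A minimal sketch of tests for detect_embedding_format, assuming it is importable from the converter module changed in this diff (the exact import path is an assumption, since the file name is not visible here):

import unittest

# assumed import path; not shown in this diff view
from onnx_web.convert.diffusion.textual_inversion import detect_embedding_format

class DetectEmbeddingFormatTests(unittest.TestCase):
    def test_concept(self):
        # a single "<token>" key marks a Textual Inversion concept
        self.assertEqual(detect_embedding_format({"<cat-toy>": None}), "concept")

    def test_parameters(self):
        # an "emb_params" key marks parameter embeddings
        self.assertEqual(detect_embedding_format({"emb_params": None}), "parameters")

    def test_embeddings(self):
        # paired string_to_token/string_to_param keys mark token embeddings
        self.assertEqual(
            detect_embedding_format({"string_to_token": None, "string_to_param": None}),
            "embeddings",
        )

    def test_unknown(self):
        self.assertIsNone(detect_embedding_format({"other": None}))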


@@ -1,3 +1,5 @@
from typing import Tuple
from PIL import Image, ImageChops
from ..params import Border, Size
@@ -13,7 +15,7 @@ def expand_image(
fill="white",
noise_source=noise_source_histogram,
mask_filter=mask_filter_none,
):
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
size = Size(*source.size).add_border(expand)
size = tuple(size)
origin = (expand.left, expand.top)
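
One note on the new annotation: in typing, Tuple[int] describes a one-element tuple, so a two-element origin such as (left, top) would be spelled Tuple[int, int], or Tuple[int, ...] for variable length. A quick illustration:

from typing import Tuple

origin: Tuple[int, int] = (0, 0)      # fixed two-element tuple, e.g. (left, top)
lengths: Tuple[int, ...] = (1, 2, 3)  # homogeneous tuple of any length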


@@ -2,8 +2,16 @@ import sys
from functools import partial
from logging import getLogger
from os import path
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from optimum.onnxruntime.modeling_diffusion import (
ORTModel,
ORTStableDiffusionPipelineBase,
)
from ..torch_before_ort import SessionOptions
from ..utils import run_gc
from .context import ServerContext
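
The new optimum imports suggest apply_patches now wraps ORT model loading as well. The general shape of such a monkey-patch, with hypothetical names throughout (the actual wrapper body lies outside this hunk):

def patch_ort_loading(server: ServerContext):
    # hypothetical wrapper: inject server-wide SessionOptions before
    # delegating to the original optimum loader; assumes from_pretrained
    # accepts a session_options keyword in this optimum version
    original = ORTStableDiffusionPipelineBase.from_pretrained

    def wrapped(*args, session_options=None, **kwargs):
        session_options = session_options or SessionOptions()
        return original(*args, session_options=session_options, **kwargs)

    ORTStableDiffusionPipelineBase.from_pretrained = wrapped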


@@ -27,10 +27,14 @@ MEMORY_ERRORS = [
]
def worker_main(worker: WorkerContext, server: ServerContext, *args):
apply_patches(server)
def worker_main(
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
):
setproctitle("onnx-web worker: %s" % (worker.device.device))
if patch:
apply_patches(server)
logger.trace(
"checking in from worker with providers: %s", get_available_providers()
)
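
Turning exit and patch into keyword parameters gives tests injectable seams: a test can pass patch=False to skip apply_patches and a recording callable for exit, so a worker failure never terminates the test process. A rough sketch, assuming worker_main returns through the injected exit once its job queue is exhausted (the loop body is not shown in this hunk):

from unittest.mock import MagicMock

def test_worker_exit_is_injected():
    exits = []
    worker = MagicMock()   # stand-ins only; real contexts need more setup
    server = MagicMock()

    # patch=False keeps the test hermetic; exits.append records the exit
    # code instead of killing the process
    worker_main(worker, server, exit=exits.append, patch=False)

    assert len(exits) == 1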


@@ -4,11 +4,19 @@ from onnx_web.convert.utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
remove_prefix,
resolve_tensor,
source_format,
tuple_to_correction,
tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling,
)
from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_UPSCALING_SWINIR,
test_needs_models,
)
class ConversionContextTests(unittest.TestCase):
@@ -182,3 +190,51 @@ class TupleToUpscalingTests(unittest.TestCase):
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class SourceFormatTests(unittest.TestCase):
def test_with_format(self):
result = source_format({
"format": "foo",
})
self.assertEqual(result, "foo")
def test_source_known_extension(self):
result = source_format({
"source": "foo.safetensors",
})
self.assertEqual(result, "safetensors")
def test_source_unknown_extension(self):
result = source_format({
"source": "foo.none"
})
self.assertEqual(result, None)
def test_incomplete_model(self):
self.assertIsNone(source_format({}))
class RemovePrefixTests(unittest.TestCase):
def test_with_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
def test_without_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase):
pass
class LoadTensorTests(unittest.TestCase):
pass
class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self):
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing"))
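
The test_needs_models helper used above guards tests behind the presence of local model files; its implementation is not part of this diff. A plausible sketch, assuming it skips unless every listed path exists:

import unittest
from os import path

def test_needs_models(models):
    # skip the decorated test unless all required model files exist locally
    return unittest.skipUnless(
        all(path.exists(model) for model in models),
        "model does not exist",
    )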


@@ -14,3 +14,4 @@ def test_device() -> DeviceParams:
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"


@@ -90,6 +90,7 @@
"spinalcase",
"stabilityai",
"stringcase",
"swinir",
"timestep",
"timesteps",
"tojson",