fix(api): write tests for embedding/inversion blending
This commit is contained in:
parent
ebdfa78737
commit
e9b1375440
|
@ -14,55 +14,23 @@ from ..utils import ConversionContext, load_tensor
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def blend_textual_inversions(
|
||||
server: ServerContext,
|
||||
text_encoder: ModelProto,
|
||||
tokenizer: CLIPTokenizer,
|
||||
inversions: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
dtype = np.float32
|
||||
embeds = {}
|
||||
|
||||
for name, weight, base_token, inversion_format in inversions:
|
||||
if base_token is None:
|
||||
logger.debug("no base token provided, using name: %s", name)
|
||||
base_token = name
|
||||
|
||||
logger.info(
|
||||
"blending Textual Inversion %s with weight of %s for token %s",
|
||||
name,
|
||||
weight,
|
||||
base_token,
|
||||
)
|
||||
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
if inversion_format is None:
|
||||
def detect_embedding_format(loaded_embeds) -> str:
|
||||
keys: List[str] = list(loaded_embeds.keys())
|
||||
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
||||
logger.debug("detected Textual Inversion concept: %s", keys)
|
||||
inversion_format = "concept"
|
||||
return "concept"
|
||||
elif "emb_params" in keys:
|
||||
logger.debug(
|
||||
"detected Textual Inversion parameter embeddings: %s", keys
|
||||
)
|
||||
inversion_format = "parameters"
|
||||
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
|
||||
return "parameters"
|
||||
elif "string_to_token" in keys and "string_to_param" in keys:
|
||||
logger.debug("detected Textual Inversion token embeddings: %s", keys)
|
||||
inversion_format = "embeddings"
|
||||
return "embeddings"
|
||||
else:
|
||||
logger.error(
|
||||
"unknown Textual Inversion format, no recognized keys: %s", keys
|
||||
)
|
||||
continue
|
||||
logger.error("unknown Textual Inversion format, no recognized keys: %s", keys)
|
||||
return None
|
||||
|
||||
if inversion_format == "concept":
|
||||
|
||||
def blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight):
|
||||
# separate token and the embeds
|
||||
token = list(loaded_embeds.keys())[0]
|
||||
|
||||
|
@ -78,11 +46,13 @@ def blend_textual_inversions(
|
|||
embeds[token] += layer
|
||||
else:
|
||||
embeds[token] = layer
|
||||
elif inversion_format == "parameters":
|
||||
|
||||
|
||||
def blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight):
|
||||
emb_params = loaded_embeds["emb_params"]
|
||||
|
||||
num_tokens = emb_params.shape[0]
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
|
||||
|
||||
sum_layer = np.zeros(emb_params[0, :].shape)
|
||||
|
||||
|
@ -108,7 +78,9 @@ def blend_textual_inversions(
|
|||
embeds[sum_token] += sum_layer
|
||||
else:
|
||||
embeds[sum_token] = sum_layer
|
||||
elif inversion_format == "embeddings":
|
||||
|
||||
|
||||
def blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight):
|
||||
string_to_token = loaded_embeds["string_to_token"]
|
||||
string_to_param = loaded_embeds["string_to_param"]
|
||||
|
||||
|
@ -117,7 +89,7 @@ def blend_textual_inversions(
|
|||
trained_embeds = string_to_param[token]
|
||||
|
||||
num_tokens = trained_embeds.shape[0]
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, base_token)
|
||||
|
||||
sum_layer = np.zeros(trained_embeds[0, :].shape)
|
||||
|
||||
|
@ -143,23 +115,9 @@ def blend_textual_inversions(
|
|||
embeds[sum_token] += sum_layer
|
||||
else:
|
||||
embeds[sum_token] = sum_layer
|
||||
else:
|
||||
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
|
||||
|
||||
# add the tokens to the tokenizer
|
||||
logger.debug(
|
||||
"found embeddings for %s tokens: %s",
|
||||
len(embeds.keys()),
|
||||
list(embeds.keys()),
|
||||
)
|
||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
logger.trace("added %s tokens", num_added_tokens)
|
||||
|
||||
def blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens):
|
||||
# resize the token embeddings
|
||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
embedding_node = [
|
||||
|
@ -191,6 +149,59 @@ def blend_textual_inversions(
|
|||
del text_encoder.graph.initializer[i]
|
||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def blend_textual_inversions(
|
||||
server: ServerContext,
|
||||
text_encoder: ModelProto,
|
||||
tokenizer: CLIPTokenizer,
|
||||
embeddings: List[Tuple[str, float, Optional[str], Optional[str]]],
|
||||
) -> Tuple[ModelProto, CLIPTokenizer]:
|
||||
# always load to CPU for blending
|
||||
device = torch.device("cpu")
|
||||
dtype = np.float32
|
||||
embeds = {}
|
||||
|
||||
for name, weight, base_token, format in embeddings:
|
||||
if base_token is None:
|
||||
logger.debug("no base token provided, using name: %s", name)
|
||||
base_token = name
|
||||
|
||||
logger.info(
|
||||
"blending Textual Inversion %s with weight of %s for token %s",
|
||||
name,
|
||||
weight,
|
||||
base_token,
|
||||
)
|
||||
|
||||
loaded_embeds = load_tensor(name, map_location=device)
|
||||
if loaded_embeds is None:
|
||||
logger.warning("unable to load tensor")
|
||||
continue
|
||||
|
||||
if format is None:
|
||||
format = detect_embedding_format()
|
||||
|
||||
if format == "concept":
|
||||
blend_embedding_concept(embeds, loaded_embeds, dtype, base_token, weight)
|
||||
elif format == "parameters":
|
||||
blend_embedding_parameters(embeds, loaded_embeds, dtype, base_token, weight)
|
||||
elif format == "embeddings":
|
||||
blend_embedding_embeddings(embeds, loaded_embeds, dtype, base_token, weight)
|
||||
else:
|
||||
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||
|
||||
# add the tokens to the tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
"The tokenizer already contains the tokens. Please pass a different `token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
logger.trace("added %s tokens", num_added_tokens)
|
||||
|
||||
blend_embedding_node(text_encoder, tokenizer, embeds, num_added_tokens)
|
||||
|
||||
return (text_encoder, tokenizer)
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Tuple
|
||||
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from ..params import Border, Size
|
||||
|
@ -13,7 +15,7 @@ def expand_image(
|
|||
fill="white",
|
||||
noise_source=noise_source_histogram,
|
||||
mask_filter=mask_filter_none,
|
||||
):
|
||||
) -> Tuple[Image.Image, Image.Image, Image.Image, Tuple[int]]:
|
||||
size = Size(*source.size).add_border(expand)
|
||||
size = tuple(size)
|
||||
origin = (expand.left, expand.top)
|
||||
|
|
|
@ -2,8 +2,16 @@ import sys
|
|||
from functools import partial
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from optimum.onnxruntime.modeling_diffusion import (
|
||||
ORTModel,
|
||||
ORTStableDiffusionPipelineBase,
|
||||
)
|
||||
|
||||
from ..torch_before_ort import SessionOptions
|
||||
from ..utils import run_gc
|
||||
from .context import ServerContext
|
||||
|
||||
|
|
|
@ -27,10 +27,14 @@ MEMORY_ERRORS = [
|
|||
]
|
||||
|
||||
|
||||
def worker_main(worker: WorkerContext, server: ServerContext, *args):
|
||||
apply_patches(server)
|
||||
def worker_main(
|
||||
worker: WorkerContext, server: ServerContext, *args, exit=exit, patch=True
|
||||
):
|
||||
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
||||
|
||||
if patch:
|
||||
apply_patches(server)
|
||||
|
||||
logger.trace(
|
||||
"checking in from worker with providers: %s", get_available_providers()
|
||||
)
|
||||
|
|
|
@ -4,11 +4,19 @@ from onnx_web.convert.utils import (
|
|||
DEFAULT_OPSET,
|
||||
ConversionContext,
|
||||
download_progress,
|
||||
remove_prefix,
|
||||
resolve_tensor,
|
||||
source_format,
|
||||
tuple_to_correction,
|
||||
tuple_to_diffusion,
|
||||
tuple_to_source,
|
||||
tuple_to_upscaling,
|
||||
)
|
||||
from tests.helpers import (
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
TEST_MODEL_UPSCALING_SWINIR,
|
||||
test_needs_models,
|
||||
)
|
||||
|
||||
|
||||
class ConversionContextTests(unittest.TestCase):
|
||||
|
@ -182,3 +190,51 @@ class TupleToUpscalingTests(unittest.TestCase):
|
|||
self.assertEqual(source["scale"], 2)
|
||||
self.assertEqual(source["half"], True)
|
||||
self.assertEqual(source["opset"], 14)
|
||||
|
||||
|
||||
class SourceFormatTests(unittest.TestCase):
|
||||
def test_with_format(self):
|
||||
result = source_format({
|
||||
"format": "foo",
|
||||
})
|
||||
self.assertEqual(result, "foo")
|
||||
|
||||
def test_source_known_extension(self):
|
||||
result = source_format({
|
||||
"source": "foo.safetensors",
|
||||
})
|
||||
self.assertEqual(result, "safetensors")
|
||||
|
||||
def test_source_unknown_extension(self):
|
||||
result = source_format({
|
||||
"source": "foo.none"
|
||||
})
|
||||
self.assertEqual(result, None)
|
||||
|
||||
def test_incomplete_model(self):
|
||||
self.assertIsNone(source_format({}))
|
||||
|
||||
|
||||
class RemovePrefixTests(unittest.TestCase):
|
||||
def test_with_prefix(self):
|
||||
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
|
||||
|
||||
def test_without_prefix(self):
|
||||
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
||||
|
||||
|
||||
class LoadTorchTests(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class LoadTensorTests(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class ResolveTensorTests(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
||||
def test_resolve_existing(self):
|
||||
self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR)
|
||||
|
||||
def test_resolve_missing(self):
|
||||
self.assertIsNone(resolve_tensor("missing"))
|
||||
|
|
|
@ -14,3 +14,4 @@ def test_device() -> DeviceParams:
|
|||
|
||||
|
||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
||||
|
|
|
@ -90,6 +90,7 @@
|
|||
"spinalcase",
|
||||
"stabilityai",
|
||||
"stringcase",
|
||||
"swinir",
|
||||
"timestep",
|
||||
"timesteps",
|
||||
"tojson",
|
||||
|
|
Loading…
Reference in New Issue