diff --git a/api/tests/convert/diffusion/test_textual_inversion.py b/api/tests/convert/diffusion/test_textual_inversion.py new file mode 100644 index 00000000..246d53b4 --- /dev/null +++ b/api/tests/convert/diffusion/test_textual_inversion.py @@ -0,0 +1,227 @@ +import unittest + +import numpy as np +import torch +from onnx import GraphProto, ModelProto +from onnx.numpy_helper import from_array, to_array + +from onnx_web.convert.diffusion.textual_inversion import ( + blend_embedding_concept, + blend_embedding_embeddings, + blend_embedding_node, + blend_embedding_parameters, + blend_textual_inversions, + detect_embedding_format, +) + +TEST_DIMS = (8, 8) +TEST_DIMS_EMBEDS = (1, *TEST_DIMS) + +TEST_MODEL_EMBEDS = { + "string_to_token": { + "test": 1, + }, + "string_to_param": { + "test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, +} + + +class DetectEmbeddingFormatTests(unittest.TestCase): + def test_concept(self): + embedding = { + "": "test", + } + self.assertEqual(detect_embedding_format(embedding), "concept") + + def test_parameters(self): + embedding = { + "emb_params": "test", + } + self.assertEqual(detect_embedding_format(embedding), "parameters") + + def test_embeddings(self): + embedding = { + "string_to_token": "test", + "string_to_param": "test", + } + self.assertEqual(detect_embedding_format(embedding), "embeddings") + + def test_unknown(self): + embedding = { + "what_is_this": "test", + } + self.assertEqual(detect_embedding_format(embedding), None) + + +class BlendEmbeddingConceptTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_concept(embeds, { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_concept(embeds, { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "": np.ones(TEST_DIMS), + } + blend_embedding_concept(embeds, { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_concept(embeds, { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) + + +class BlendEmbeddingParametersTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_parameters(embeds, { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_parameters(embeds, { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS_EMBEDS), + } + blend_embedding_parameters(embeds, { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_parameters(embeds, { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + +class BlendEmbeddingEmbeddingsTests(unittest.TestCase): + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) + + def test_missing_base_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + def test_missing_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + + keys = list(embeds.keys()) + keys.sort() + + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) + + +class BlendEmbeddingNodeTests(unittest.TestCase): + def test_expand_weights(self): + weights = from_array(np.ones(TEST_DIMS)) + weights.name = "text_model.embeddings.token_embedding.weight" + + model = ModelProto(graph=GraphProto(initializer=[ + weights, + ])) + + embeds = {} + blend_embedding_node(model, { + 'convert_tokens_to_ids': lambda t: t, + }, embeds, 2) + + result = to_array(model.graph.initializer[0]) + + self.assertEqual(len(model.graph.initializer), 1) + self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8) + + +class BlendTextualInversionsTests(unittest.TestCase): + def test_blend_multi_concept(self): + pass + + def test_blend_multi_parameters(self): + pass + + def test_blend_multi_embeddings(self): + pass + + def test_blend_multi_mixed(self): + pass diff --git a/api/tests/image/test_utils.py b/api/tests/image/test_utils.py new file mode 100644 index 00000000..f3b10fd5 --- /dev/null +++ b/api/tests/image/test_utils.py @@ -0,0 +1,24 @@ +import unittest + +from PIL import Image + +from onnx_web.image.utils import expand_image +from onnx_web.params import Border + + +class ExpandImageTests(unittest.TestCase): + def test_expand(self): + result = expand_image( + Image.new("RGB", (8, 8)), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].size, (16, 16)) + + def test_masked(self): + result = expand_image( + Image.new("RGB", (8, 8), "red"), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0)) diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py new file mode 100644 index 00000000..06c0822d --- /dev/null +++ b/api/tests/worker/test_worker.py @@ -0,0 +1,42 @@ +import unittest +from multiprocessing import Queue, Value + +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from onnx_web.worker.worker import EXIT_INTERRUPT, worker_main +from tests.helpers import test_device + + +class WorkerMainTests(unittest.TestCase): + def test_pending_exception_empty(self): + pass + + def test_pending_exception_interrupt(self): + status = None + + def exit(exit_status): + status = exit_status + + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", False) + idle = Value("L", False) + + pending.close() + # worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + + self.assertEqual(status, EXIT_INTERRUPT) + + def test_pending_exception_retry(self): + pass + + def test_pending_exception_value(self): + pass + + def test_pending_exception_other_memory(self): + pass + + def test_pending_exception_other_unknown(self): + pass