add more tests
This commit is contained in:
parent
e8d7d9a881
commit
783e8eab4b
|
@ -0,0 +1,227 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from onnx import GraphProto, ModelProto
|
||||
from onnx.numpy_helper import from_array, to_array
|
||||
|
||||
from onnx_web.convert.diffusion.textual_inversion import (
|
||||
blend_embedding_concept,
|
||||
blend_embedding_embeddings,
|
||||
blend_embedding_node,
|
||||
blend_embedding_parameters,
|
||||
blend_textual_inversions,
|
||||
detect_embedding_format,
|
||||
)
|
||||
|
||||
TEST_DIMS = (8, 8)
|
||||
TEST_DIMS_EMBEDS = (1, *TEST_DIMS)
|
||||
|
||||
TEST_MODEL_EMBEDS = {
|
||||
"string_to_token": {
|
||||
"test": 1,
|
||||
},
|
||||
"string_to_param": {
|
||||
"test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DetectEmbeddingFormatTests(unittest.TestCase):
|
||||
def test_concept(self):
|
||||
embedding = {
|
||||
"<test>": "test",
|
||||
}
|
||||
self.assertEqual(detect_embedding_format(embedding), "concept")
|
||||
|
||||
def test_parameters(self):
|
||||
embedding = {
|
||||
"emb_params": "test",
|
||||
}
|
||||
self.assertEqual(detect_embedding_format(embedding), "parameters")
|
||||
|
||||
def test_embeddings(self):
|
||||
embedding = {
|
||||
"string_to_token": "test",
|
||||
"string_to_param": "test",
|
||||
}
|
||||
self.assertEqual(detect_embedding_format(embedding), "embeddings")
|
||||
|
||||
def test_unknown(self):
|
||||
embedding = {
|
||||
"what_is_this": "test",
|
||||
}
|
||||
self.assertEqual(detect_embedding_format(embedding), None)
|
||||
|
||||
|
||||
class BlendEmbeddingConceptTests(unittest.TestCase):
|
||||
def test_existing_base_token(self):
|
||||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_concept(embeds, {
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
self.assertEqual(embeds["test"].mean(), 2)
|
||||
|
||||
def test_missing_base_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_concept(embeds, {
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
||||
def test_existing_token(self):
|
||||
embeds = {
|
||||
"<test>": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_concept(embeds, {
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["<test>", "test"])
|
||||
|
||||
def test_missing_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_concept(embeds, {
|
||||
"<test>": torch.from_numpy(np.ones(TEST_DIMS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["<test>", "test"])
|
||||
|
||||
|
||||
class BlendEmbeddingParametersTests(unittest.TestCase):
|
||||
def test_existing_base_token(self):
|
||||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_parameters(embeds, {
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
self.assertEqual(embeds["test"].mean(), 2)
|
||||
|
||||
def test_missing_base_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_parameters(embeds, {
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
||||
def test_existing_token(self):
|
||||
embeds = {
|
||||
"test": np.ones(TEST_DIMS_EMBEDS),
|
||||
}
|
||||
blend_embedding_parameters(embeds, {
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||
|
||||
def test_missing_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_parameters(embeds, {
|
||||
"emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)),
|
||||
}, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||
|
||||
|
||||
class BlendEmbeddingEmbeddingsTests(unittest.TestCase):
|
||||
def test_existing_base_token(self):
|
||||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
self.assertEqual(embeds["test"].mean(), 2)
|
||||
|
||||
def test_missing_base_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(embeds["test"].shape, TEST_DIMS)
|
||||
|
||||
def test_existing_token(self):
|
||||
embeds = {
|
||||
"test": np.ones(TEST_DIMS),
|
||||
}
|
||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||
|
||||
def test_missing_token(self):
|
||||
embeds = {}
|
||||
blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0)
|
||||
|
||||
keys = list(embeds.keys())
|
||||
keys.sort()
|
||||
|
||||
self.assertIn("test", embeds)
|
||||
self.assertEqual(keys, ["test", "test-0", "test-all"])
|
||||
|
||||
|
||||
class BlendEmbeddingNodeTests(unittest.TestCase):
|
||||
def test_expand_weights(self):
|
||||
weights = from_array(np.ones(TEST_DIMS))
|
||||
weights.name = "text_model.embeddings.token_embedding.weight"
|
||||
|
||||
model = ModelProto(graph=GraphProto(initializer=[
|
||||
weights,
|
||||
]))
|
||||
|
||||
embeds = {}
|
||||
blend_embedding_node(model, {
|
||||
'convert_tokens_to_ids': lambda t: t,
|
||||
}, embeds, 2)
|
||||
|
||||
result = to_array(model.graph.initializer[0])
|
||||
|
||||
self.assertEqual(len(model.graph.initializer), 1)
|
||||
self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8)
|
||||
|
||||
|
||||
class BlendTextualInversionsTests(unittest.TestCase):
|
||||
def test_blend_multi_concept(self):
|
||||
pass
|
||||
|
||||
def test_blend_multi_parameters(self):
|
||||
pass
|
||||
|
||||
def test_blend_multi_embeddings(self):
|
||||
pass
|
||||
|
||||
def test_blend_multi_mixed(self):
|
||||
pass
|
|
@ -0,0 +1,24 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.image.utils import expand_image
|
||||
from onnx_web.params import Border
|
||||
|
||||
|
||||
class ExpandImageTests(unittest.TestCase):
|
||||
def test_expand(self):
|
||||
result = expand_image(
|
||||
Image.new("RGB", (8, 8)),
|
||||
Image.new("RGB", (8, 8), "white"),
|
||||
Border.even(4),
|
||||
)
|
||||
self.assertEqual(result[0].size, (16, 16))
|
||||
|
||||
def test_masked(self):
|
||||
result = expand_image(
|
||||
Image.new("RGB", (8, 8), "red"),
|
||||
Image.new("RGB", (8, 8), "white"),
|
||||
Border.even(4),
|
||||
)
|
||||
self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0))
|
|
@ -0,0 +1,42 @@
|
|||
import unittest
|
||||
from multiprocessing import Queue, Value
|
||||
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from onnx_web.worker.worker import EXIT_INTERRUPT, worker_main
|
||||
from tests.helpers import test_device
|
||||
|
||||
|
||||
class WorkerMainTests(unittest.TestCase):
|
||||
def test_pending_exception_empty(self):
|
||||
pass
|
||||
|
||||
def test_pending_exception_interrupt(self):
|
||||
status = None
|
||||
|
||||
def exit(exit_status):
|
||||
status = exit_status
|
||||
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
pid = Value("L", False)
|
||||
idle = Value("L", False)
|
||||
|
||||
pending.close()
|
||||
# worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit)
|
||||
|
||||
self.assertEqual(status, EXIT_INTERRUPT)
|
||||
|
||||
def test_pending_exception_retry(self):
|
||||
pass
|
||||
|
||||
def test_pending_exception_value(self):
|
||||
pass
|
||||
|
||||
def test_pending_exception_other_memory(self):
|
||||
pass
|
||||
|
||||
def test_pending_exception_other_unknown(self):
|
||||
pass
|
Loading…
Reference in New Issue