diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 7a1d6310..03441a12 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -344,7 +344,7 @@ def process_tile_stack( tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) if isinstance(tile_stack, list): - tile_stack = StageResult.from_images(tile_stack) + tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata) tiles.append((left, top, tile_stack.as_images())) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 9b9cfce2..8e32714f 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -396,7 +396,7 @@ def run_inpaint_pipeline( latents=latents, ) - for i, image, metadata in enumerate(zip(images.as_images(), images.metadata)): + for i, (image, metadata) in enumerate(zip(images.as_images(), images.metadata)): if full_res_inpaint: if is_debug(): save_image(server, "adjusted-output.png", image) diff --git a/api/tests/chain/test_blend_denoise_fastnlmeans.py b/api/tests/chain/test_blend_denoise_fastnlmeans.py new file mode 100644 index 00000000..14a42be8 --- /dev/null +++ b/api/tests/chain/test_blend_denoise_fastnlmeans.py @@ -0,0 +1,51 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_denoise_fastnlmeans import BlendDenoiseFastNLMeansStage +from onnx_web.chain.result import ImageMetadata, StageResult +from tests.helpers import test_params, test_size + + +class TestBlendDenoiseFastNLMeansStage(unittest.TestCase): + def test_run(self): + # Create a dummy image + size = test_size() + image = Image.new("RGB", (size.width, size.height), color="white") + + # Create a dummy StageResult object + sources = StageResult.from_images( + [image], + metadata=[ + ImageMetadata( + test_params(), + size, + ) + ], + ) + + # Create an instance of BlendDenoiseLocalStdStage + stage = BlendDenoiseFastNLMeansStage() + + # Call the run method with dummy parameters + result = stage.run( + _worker=None, + _server=None, + _stage=None, + _params=None, + sources=sources, + strength=5, + range=4, + stage_source=None, + callback=None, + ) + + # Assert that the result is an instance of StageResult + self.assertIsInstance(result, StageResult) + + # Assert that the result contains the denoised image + self.assertEqual(len(result), 1) + self.assertEqual(result.size(), size) + + # Assert that the metadata is preserved + self.assertEqual(result.metadata, sources.metadata) diff --git a/api/tests/image/test_ade_palette.py b/api/tests/image/test_ade_palette.py new file mode 100644 index 00000000..7749b689 --- /dev/null +++ b/api/tests/image/test_ade_palette.py @@ -0,0 +1,9 @@ +import unittest + +from onnx_web.image.ade_palette import ade_palette + + +class TestADEPalette(unittest.TestCase): + def test_palette_length(self): + palette = ade_palette() + self.assertEqual(len(palette), 150, "Palette length should be 150") diff --git a/api/tests/image/test_laion_face.py b/api/tests/image/test_laion_face.py new file mode 100644 index 00000000..0262f285 --- /dev/null +++ b/api/tests/image/test_laion_face.py @@ -0,0 +1,69 @@ +import unittest + +import numpy as np + +from onnx_web.image.laion_face import draw_pupils, generate_annotation, reverse_channels + + +class TestLaionFace(unittest.TestCase): + @unittest.skip + def test_draw_pupils(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + + # Create a dummy landmark list + class LandmarkList: + def __init__(self, landmarks): + self.landmark = landmarks + + # Create a dummy drawing spec + class DrawingSpec: + def __init__(self, color): + self.color = color + + # Create some dummy landmarks + landmarks = [ + # Add your landmarks here + ] + + # Create a dummy drawing spec + drawing_spec = DrawingSpec(color=(255, 0, 0)) # Red color + + # Call the draw_pupils function + draw_pupils(image, LandmarkList(landmarks), drawing_spec) + + self.assertNotEqual(np.sum(image), 0, "Image should be modified") + + @unittest.skip + def test_generate_annotation(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + + # Call the generate_annotation function + result = generate_annotation(image, max_faces=1, min_confidence=0.5) + + self.assertEqual( + result.shape, + image.shape, + "Result shape should be the same as the input image", + ) + self.assertNotEqual(np.sum(result), 0, "Result should not be all zeros") + + +class TestReverseChannels(unittest.TestCase): + def test_reverse_channels(self): + # Create a dummy image + image = np.zeros((100, 100, 3), dtype=np.uint8) + layer = np.ones((100, 100), dtype=np.uint8) + image[:, :, 0] = layer + + # Call the reverse_channels function + reversed_image = reverse_channels(image) + + self.assertEqual( + image.shape, reversed_image.shape, "Image shape should be the same" + ) + self.assertTrue( + np.array_equal(reversed_image[:, :, 2], layer), + "Channels should be reversed", + ) diff --git a/api/tests/test_utils.py b/api/tests/test_utils.py new file mode 100644 index 00000000..6db6838a --- /dev/null +++ b/api/tests/test_utils.py @@ -0,0 +1,99 @@ +import unittest + +from onnx_web.utils import ( + get_and_clamp_float, + get_and_clamp_int, + get_boolean, + get_from_list, + get_from_map, + get_list, + get_not_empty, + split_list, +) + + +class TestUtils(unittest.TestCase): + def test_split_list_empty(self): + self.assertEqual(split_list(""), []) + self.assertEqual(split_list(" "), []) + self.assertEqual(split_list(" , "), []) + + def test_split_list_single(self): + self.assertEqual(split_list("a"), ["a"]) + self.assertEqual(split_list(" a "), ["a"]) + self.assertEqual(split_list(" a, "), ["a"]) + self.assertEqual(split_list(" a , "), ["a"]) + + def test_split_list_multiple(self): + self.assertEqual(split_list("a,b"), ["a", "b"]) + self.assertEqual(split_list(" a , b "), ["a", "b"]) + self.assertEqual(split_list(" a, b "), ["a", "b"]) + self.assertEqual(split_list(" a ,b "), ["a", "b"]) + + def test_get_boolean_empty(self): + self.assertEqual(get_boolean({}, "key", False), False) + self.assertEqual(get_boolean({}, "key", True), True) + + def test_get_boolean_true(self): + self.assertEqual(get_boolean({"key": True}, "key", False), True) + self.assertEqual(get_boolean({"key": True}, "key", True), True) + + def test_get_boolean_false(self): + self.assertEqual(get_boolean({"key": False}, "key", False), False) + self.assertEqual(get_boolean({"key": False}, "key", True), False) + + def test_get_list_empty(self): + self.assertEqual(get_list({}, "key", ""), []) + self.assertEqual(get_list({}, "key", "a"), ["a"]) + + def test_get_list_exists(self): + self.assertEqual(get_list({"key": "a,b"}, "key", ""), ["a", "b"]) + self.assertEqual(get_list({"key": "a,b"}, "key", "c"), ["a", "b"]) + + def test_get_and_clamp_float_empty(self): + self.assertEqual(get_and_clamp_float({}, "key", 0.0, 1.0), 0.0) + self.assertEqual(get_and_clamp_float({}, "key", 1.0, 1.0), 1.0) + + def test_get_and_clamp_float_clamped(self): + self.assertEqual(get_and_clamp_float({"key": -1.0}, "key", 0.0, 1.0), 0.0) + self.assertEqual(get_and_clamp_float({"key": 2.0}, "key", 0.0, 1.0), 1.0) + + def test_get_and_clamp_float_normal(self): + self.assertEqual(get_and_clamp_float({"key": 0.5}, "key", 0.0, 1.0), 0.5) + + def test_get_and_clamp_int_empty(self): + self.assertEqual(get_and_clamp_int({}, "key", 0, 1), 1) + self.assertEqual(get_and_clamp_int({}, "key", 1, 1), 1) + + def test_get_and_clamp_int_clamped(self): + self.assertEqual(get_and_clamp_int({"key": 0}, "key", 1, 1), 1) + self.assertEqual(get_and_clamp_int({"key": 2}, "key", 1, 1), 1) + + def test_get_and_clamp_int_normal(self): + self.assertEqual(get_and_clamp_int({"key": 1}, "key", 0, 1), 1) + + def test_get_from_list_empty(self): + self.assertEqual(get_from_list({}, "key", ["a", "b"]), "a") + self.assertEqual(get_from_list({}, "key", ["a", "b"], "a"), "a") + + def test_get_from_list_exists(self): + self.assertEqual(get_from_list({"key": "a"}, "key", ["a", "b"]), "a") + self.assertEqual(get_from_list({"key": "b"}, "key", ["a", "b"]), "b") + + def test_get_from_list_invalid(self): + self.assertEqual(get_from_list({"key": "c"}, "key", ["a", "b"]), "a") + + def test_get_from_map_empty(self): + self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "a"), 1) + self.assertEqual(get_from_map({}, "key", {"a": 1, "b": 2}, "b"), 2) + + def test_get_from_map_exists(self): + self.assertEqual(get_from_map({"key": "a"}, "key", {"a": 1, "b": 2}, "a"), 1) + self.assertEqual(get_from_map({"key": "b"}, "key", {"a": 1, "b": 2}, "a"), 2) + + def test_get_not_empty_empty(self): + self.assertEqual(get_not_empty({}, "key", "a"), "a") + self.assertEqual(get_not_empty({"key": ""}, "key", "a"), "a") + + def test_get_not_empty_exists(self): + self.assertEqual(get_not_empty({"key": "b"}, "key", "a"), "b")