From b851c234fe1de84e8504ba8675f712e283aca404 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Sep 2023 19:35:48 -0500 Subject: [PATCH] more tests --- api/Makefile | 1 + api/onnx_web/chain/base.py | 2 +- api/onnx_web/chain/source_noise.py | 4 +- api/onnx_web/chain/source_url.py | 4 +- api/onnx_web/convert/diffusion/lora.py | 6 +- api/tests/chain/test_blend_grid.py | 21 +++ api/tests/chain/test_blend_img2img.py | 26 ++++ api/tests/chain/test_blend_mask.py | 1 + api/tests/chain/test_correct_codeformer.py | 33 ++++ api/tests/chain/test_reduce_thumbnail.py | 28 ++++ api/tests/chain/test_source_noise.py | 25 +++ api/tests/chain/test_source_s3.py | 25 +++ api/tests/chain/test_source_url.py | 24 +++ api/tests/chain/test_tile.py | 7 +- api/tests/convert/__init__.py | 0 api/tests/convert/diffusion/__init__.py | 0 api/tests/convert/diffusion/test_lora.py | 170 +++++++++++++++++++++ api/tests/convert/test_utils.py | 19 +++ api/tests/server/test_load.py | 83 ++++++++++ 19 files changed, 470 insertions(+), 9 deletions(-) create mode 100644 api/tests/chain/test_blend_grid.py create mode 100644 api/tests/chain/test_blend_img2img.py create mode 100644 api/tests/chain/test_correct_codeformer.py create mode 100644 api/tests/chain/test_reduce_thumbnail.py create mode 100644 api/tests/chain/test_source_noise.py create mode 100644 api/tests/chain/test_source_s3.py create mode 100644 api/tests/chain/test_source_url.py create mode 100644 api/tests/convert/__init__.py create mode 100644 api/tests/convert/diffusion/__init__.py create mode 100644 api/tests/convert/diffusion/test_lora.py create mode 100644 api/tests/convert/test_utils.py create mode 100644 api/tests/server/test_load.py diff --git a/api/Makefile b/api/Makefile index 780fa7d1..b5c90421 100644 --- a/api/Makefile +++ b/api/Makefile @@ -21,6 +21,7 @@ test: python -m coverage run -m unittest discover -v -s tests/ python -m coverage html -i python -m coverage xml -i + python -m coverage report -i package: package-dist package-upload diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 9b554de3..b67bf40b 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -95,7 +95,7 @@ class ChainPipeline: def outputs(self, params: ImageParams, sources: int): outputs = sources for callback, _params, kwargs in self.stages: - outputs += callback.outputs(kwargs.get("params", params), outputs) + outputs = callback.outputs(kwargs.get("params", params), outputs) return outputs diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 738e9878..89e65abc 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Callable, List +from typing import Callable, List, Optional from PIL import Image @@ -22,7 +22,7 @@ class SourceNoiseStage(BaseStage): *, size: Size, noise_source: Callable, - stage_source: Image.Image, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> List[Image.Image]: logger.info("generating image from noise source") diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 8b1683f8..33e5ac78 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -1,6 +1,6 @@ from io import BytesIO from logging import getLogger -from typing import List +from typing import List, Optional import requests from PIL import Image @@ -23,7 +23,7 @@ class SourceURLStage(BaseStage): sources: List[Image.Image], *, source_urls: List[str], - stage_source: Image.Image, + stage_source: Optional[Image.Image] = None, **kwargs, ) -> List[Image.Image]: logger.info("loading image from URL source") diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index e2dc0058..3d3fed7c 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import torch -from onnx import ModelProto, load, numpy_helper +from onnx import ModelProto, NodeProto, load, numpy_helper from onnx.checker import check_model from onnx.external_data_helper import ( convert_model_to_external_data, @@ -39,7 +39,7 @@ def sum_weights(a: np.ndarray, b: np.ndarray) -> np.ndarray: lr = a if kernel == (1, 1): - lr = np.expand_dims(lr, axis=(2, 3)) + lr = np.expand_dims(lr, axis=(2, 3)) # TODO: generate axis return hr + lr @@ -78,7 +78,7 @@ def fix_node_name(key: str): return fixed_name -def fix_xl_names(keys: Dict[str, Any], nodes: List[Any]): +def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]): fixed = {} for key, value in keys.items(): diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py new file mode 100644 index 00000000..8244df5e --- /dev/null +++ b/api/tests/chain/test_blend_grid.py @@ -0,0 +1,21 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_grid import BlendGridStage +from onnx_web.chain.blend_linear import BlendLinearStage + + +class BlendGridStageTests(unittest.TestCase): + def test_stage(self): + stage = BlendGridStage() + sources = [ + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + ] + result = stage.run(None, None, None, None, sources, height=2, width=2) + + self.assertEqual(len(result), 5) + self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0)) \ No newline at end of file diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py new file mode 100644 index 00000000..21b583f0 --- /dev/null +++ b/api/tests/chain/test_blend_img2img.py @@ -0,0 +1,26 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_img2img import BlendImg2ImgStage +from onnx_web.params import DeviceParams, ImageParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext + + +class BlendImg2ImgStageTests(unittest.TestCase): + def test_stage(self): + """ + stage = BlendImg2ImgStage() + params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1) + server = ServerContext() + worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0) + sources = [ + Image.new("RGB", (64, 64), "black"), + ] + result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) + """ + pass \ No newline at end of file diff --git a/api/tests/chain/test_blend_mask.py b/api/tests/chain/test_blend_mask.py index 410249fe..cf70535f 100644 --- a/api/tests/chain/test_blend_mask.py +++ b/api/tests/chain/test_blend_mask.py @@ -1,4 +1,5 @@ import unittest + from PIL import Image from onnx_web.chain.blend_mask import BlendMaskStage diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py new file mode 100644 index 00000000..8203e876 --- /dev/null +++ b/api/tests/chain/test_correct_codeformer.py @@ -0,0 +1,33 @@ +import unittest + +from onnx_web.chain.correct_codeformer import CorrectCodeformerStage +from onnx_web.params import DeviceParams, HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.server.hacks import apply_patches +from onnx_web.worker.context import WorkerContext + + +class CorrectCodeformerStageTests(unittest.TestCase): + def test_empty(self): + """ + server = ServerContext() + apply_patches(server) + + worker = WorkerContext( + "test", + DeviceParams("cpu", "CPUProvider"), + None, + None, + None, + None, + None, + None, + 0, + ) + stage = CorrectCodeformerStage() + sources = [] + result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) + + self.assertEqual(len(result), 0) + """ + pass \ No newline at end of file diff --git a/api/tests/chain/test_reduce_thumbnail.py b/api/tests/chain/test_reduce_thumbnail.py new file mode 100644 index 00000000..14cb12a7 --- /dev/null +++ b/api/tests/chain/test_reduce_thumbnail.py @@ -0,0 +1,28 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.reduce_crop import ReduceCropStage +from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class ReduceThumbnailStageTests(unittest.TestCase): + def test_empty(self): + stage_source = Image.new("RGB", (64, 64)) + stage = ReduceThumbnailStage() + sources = [] + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + stage_source=stage_source, + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py new file mode 100644 index 00000000..8187a751 --- /dev/null +++ b/api/tests/chain/test_source_noise.py @@ -0,0 +1,25 @@ +import unittest + +from onnx_web.chain.source_noise import SourceNoiseStage +from onnx_web.image.noise_source import noise_source_fill_edge +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceNoiseStageTests(unittest.TestCase): + def test_empty(self): + stage = SourceNoiseStage() + sources = [] + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + noise_source=noise_source_fill_edge, + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py new file mode 100644 index 00000000..aad37c5b --- /dev/null +++ b/api/tests/chain/test_source_s3.py @@ -0,0 +1,25 @@ +import unittest + +from onnx_web.chain.source_s3 import SourceS3Stage +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceS3StageTests(unittest.TestCase): + def test_empty(self): + stage = SourceS3Stage() + sources = [] + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + bucket="test", + source_keys=[], + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py new file mode 100644 index 00000000..1f185b7b --- /dev/null +++ b/api/tests/chain/test_source_url.py @@ -0,0 +1,24 @@ +import unittest + +from onnx_web.chain.source_url import SourceURLStage +from onnx_web.params import HighresParams, Size, UpscaleParams + + +class SourceURLStageTests(unittest.TestCase): + def test_empty(self): + stage = SourceURLStage() + sources = [] + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + origin=Size(0, 0), + size=Size(128, 128), + source_urls=[], + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index a006a34e..a613a719 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -2,7 +2,12 @@ import unittest from PIL import Image -from onnx_web.chain.tile import complete_tile, generate_tile_spiral, get_tile_grads, needs_tile +from onnx_web.chain.tile import ( + complete_tile, + generate_tile_spiral, + get_tile_grads, + needs_tile, +) from onnx_web.params import Size diff --git a/api/tests/convert/__init__.py b/api/tests/convert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/convert/diffusion/__init__.py b/api/tests/convert/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py new file mode 100644 index 00000000..cbc64979 --- /dev/null +++ b/api/tests/convert/diffusion/test_lora.py @@ -0,0 +1,170 @@ +import unittest + +import numpy as np +from onnx import GraphProto, ModelProto, NodeProto +from onnx.numpy_helper import from_array + +from onnx_web.convert.diffusion.lora import ( + blend_loras, + buffer_external_data_tensors, + fix_initializer_name, + fix_node_name, + fix_xl_names, + interp_to_match, + kernel_slice, + sum_weights, +) + + +class SumWeightsTests(unittest.TestCase): + def test_same_shape(self): + weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4)) + + def test_1x1_kernel(self): + weights = sum_weights(np.zeros((4, 4, 1, 1)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + + weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + + + def test_3x3_kernel(self): + """ + weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4))) + self.assertEqual(weights.shape, (4, 4, 1, 1)) + """ + pass + + +class BufferExternalDataTensorTests(unittest.TestCase): + def test_basic_external(self): + model = ModelProto( + graph=GraphProto( + initializer=[ + from_array(np.zeros((4, 4))), + ], + ) + ) + (slim_model, external_weights) = buffer_external_data_tensors(model) + + self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer)) + self.assertEqual(len(external_weights), 1) + + +class FixInitializerKeyTests(unittest.TestCase): + def test_fix_name(self): + inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"] + outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"] + + for input, output in zip(inputs, outputs): + self.assertEqual(fix_initializer_name(input), output) + + +class FixNodeNameTests(unittest.TestCase): + def test_fix_name(self): + inputs = [ + "lora_unet/up_blocks/3/attentions/2/transformer_blocks/0/attn2_to_out/0.lora_down.weight", + "_prefix", + ] + outputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight", + "prefix", + ] + + for input, output in zip(inputs, outputs): + self.assertEqual(fix_node_name(input), output) + + +class FixXLNameTests(unittest.TestCase): + def test_empty(self): + nodes = {} + fixed = fix_xl_names(nodes, []) + + self.assertEqual(fixed, {}) + + def test_input_block(self): + nodes = { + "input_block_proj.lora_down.weight": {}, + } + fixed = fix_xl_names(nodes, [ + NodeProto(name="/down_blocks_proj/MatMul"), + ]) + + self.assertEqual(fixed, { + "down_blocks_proj": nodes["input_block_proj.lora_down.weight"], + }) + + def test_middle_block(self): + nodes = { + "middle_block_proj.lora_down.weight": {}, + } + fixed = fix_xl_names(nodes, [ + NodeProto(name="/mid_blocks_proj/MatMul"), + ]) + + self.assertEqual(fixed, { + "mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"], + }) + + def test_output_block(self): + pass + + def test_text_model(self): + pass + + def test_unknown_block(self): + pass + + def test_unmatched_block(self): + nodes = { + "lora_unet.input_block.lora_down.weight": {}, + } + fixed = fix_xl_names(nodes, [""]) + + self.assertEqual(fixed, nodes) + + def test_output_projection(self): + nodes = { + "output_block_proj_o.lora_down.weight": {}, + } + fixed = fix_xl_names(nodes, [ + NodeProto(name="/up_blocks_proj_o/MatMul"), + ]) + + self.assertEqual(fixed, { + "up_blocks_proj_out": nodes["output_block_proj_o.lora_down.weight"], + }) + + +class KernelSliceTests(unittest.TestCase): + def test_within_kernel(self): + self.assertEqual( + kernel_slice(1, 1, (3, 3, 3, 3)), + (1, 1), + ) + + def test_outside_kernel(self): + self.assertEqual( + kernel_slice(9, 9, (3, 3, 3, 3)), + (2, 2), + ) + +class BlendLoRATests(unittest.TestCase): + pass + +class InterpToMatchTests(unittest.TestCase): + def test_same_shape(self): + ref = np.zeros((4, 4)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_one_dim(self): + ref = np.zeros((4, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) + + def test_different_both_dims(self): + ref = np.zeros((2, 2)) + resize = np.zeros((4, 4)) + self.assertEqual(interp_to_match(ref, resize).shape, (4, 4)) diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py new file mode 100644 index 00000000..755d6032 --- /dev/null +++ b/api/tests/convert/test_utils.py @@ -0,0 +1,19 @@ +import unittest + +from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress + + +class ConversionContextTests(unittest.TestCase): + def test_from_environ(self): + context = ConversionContext.from_environ() + self.assertEqual(context.opset, DEFAULT_OPSET) + + def test_map_location(self): + context = ConversionContext.from_environ() + self.assertEqual(context.map_location.type, "cpu") + + +class DownloadProgressTests(unittest.TestCase): + def test_download_example(self): + path = download_progress([("https://example.com", "/tmp/example-dot-com")]) + self.assertEqual(path, "/tmp/example-dot-com") diff --git a/api/tests/server/test_load.py b/api/tests/server/test_load.py new file mode 100644 index 00000000..67a5f4e4 --- /dev/null +++ b/api/tests/server/test_load.py @@ -0,0 +1,83 @@ +import unittest + +from onnx_web.server.load import ( + get_available_platforms, + get_config_params, + get_correction_models, + get_diffusion_models, + get_extra_hashes, + get_extra_strings, + get_highres_methods, + get_mask_filters, + get_network_models, + get_noise_sources, + get_source_filters, + get_upscaling_models, + get_wildcard_data, +) + + +class ConfigParamTests(unittest.TestCase): + def test_before_setup(self): + params = get_config_params() + self.assertIsNotNone(params) + +class AvailablePlatformTests(unittest.TestCase): + def test_before_setup(self): + platforms = get_available_platforms() + self.assertIsNotNone(platforms) + +class CorrectModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_correction_models() + self.assertIsNotNone(models) + +class DiffusionModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_diffusion_models() + self.assertIsNotNone(models) + +class NetworkModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_network_models() + self.assertIsNotNone(models) + +class UpscalingModelTests(unittest.TestCase): + def test_before_setup(self): + models = get_upscaling_models() + self.assertIsNotNone(models) + +class WildcardDataTests(unittest.TestCase): + def test_before_setup(self): + wildcards = get_wildcard_data() + self.assertIsNotNone(wildcards) + +class ExtraStringsTests(unittest.TestCase): + def test_before_setup(self): + strings = get_extra_strings() + self.assertIsNotNone(strings) + +class ExtraHashesTests(unittest.TestCase): + def test_before_setup(self): + hashes = get_extra_hashes() + self.assertIsNotNone(hashes) + +class HighresMethodTests(unittest.TestCase): + def test_before_setup(self): + methods = get_highres_methods() + self.assertIsNotNone(methods) + +class MaskFilterTests(unittest.TestCase): + def test_before_setup(self): + filters = get_mask_filters() + self.assertIsNotNone(filters) + +class NoiseSourceTests(unittest.TestCase): + def test_before_setup(self): + sources = get_noise_sources() + self.assertIsNotNone(sources) + +class SourceFilterTests(unittest.TestCase): + def test_before_setup(self): + filters = get_source_filters() + self.assertIsNotNone(filters)