From 65912c5a4adc49a1c09819abdf3d0d0597c527d2 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Nov 2023 23:18:57 -0600 Subject: [PATCH] apply lint to tests, test highres --- api/Makefile | 4 + api/onnx_web/chain/result.py | 8 + api/onnx_web/chain/tile.py | 7 +- api/onnx_web/main.py | 1 + api/tests/chain/test_blend_grid.py | 16 +- api/tests/chain/test_blend_img2img.py | 31 +- api/tests/chain/test_blend_linear.py | 14 +- api/tests/chain/test_correct_codeformer.py | 2 +- api/tests/chain/test_tile.py | 151 +++--- api/tests/chain/test_upscale_highres.py | 10 +- api/tests/convert/diffusion/test_lora.py | 70 ++- .../diffusion/test_textual_inversion.py | 356 ++++++++------ api/tests/convert/test_utils.py | 310 ++++++------ api/tests/helpers.py | 6 +- api/tests/image/test_mask_filter.py | 30 +- api/tests/image/test_source_filter.py | 36 +- api/tests/image/test_utils.py | 28 +- api/tests/mocks.py | 64 +-- api/tests/prompt/test_parser.py | 8 +- api/tests/server/test_load.py | 14 + api/tests/server/test_model_cache.py | 78 +-- api/tests/test_diffusers/test_load.py | 454 ++++++++++-------- api/tests/test_diffusers/test_run.py | 347 +++++++------ api/tests/test_diffusers/test_utils.py | 171 ++++--- api/tests/worker/test_pool.py | 176 +++---- api/tests/worker/test_worker.py | 241 ++++++---- 26 files changed, 1506 insertions(+), 1127 deletions(-) diff --git a/api/Makefile b/api/Makefile index db142073..579c8546 100644 --- a/api/Makefile +++ b/api/Makefile @@ -33,13 +33,17 @@ package-upload: lint-check: black --check onnx_web/ + black --check tests/ flake8 onnx_web + flake8 tests isort --check-only --skip __init__.py --filter-files onnx_web isort --check-only --skip __init__.py --filter-files tests lint-fix: black onnx_web/ + black tests/ flake8 onnx_web + flake8 tests isort --skip __init__.py --filter-files onnx_web isort --skip __init__.py --filter-files tests diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 3bc54e43..9bd7395d 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -19,6 +19,14 @@ class StageResult: def empty(): return StageResult(images=[]) + @staticmethod + def from_arrays(arrays: List[np.ndarray]): + return StageResult(arrays=arrays) + + @staticmethod + def from_images(images: List[Image.Image]): + return StageResult(images=images) + def __init__(self, arrays=None, images=None) -> None: if arrays is not None and images is not None: raise ValueError("stages must only return one type of result") diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index fb170736..f8fe5e8d 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -25,9 +25,7 @@ class TileCallback(Protocol): Definition for a tile job function. """ - def __call__( - self, image: Image.Image, dims: Tuple[int, int, int] - ) -> StageResult: + def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult: """ Run this stage against a single tile. """ @@ -319,6 +317,9 @@ def process_tile_stack( if mask: tile_mask = mask.crop((left, top, right, bottom)) + if isinstance(tile_stack, list): + tile_stack = StageResult.from_images(tile_stack) + for image_filter in filters: tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 4255242e..6de5dc39 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -48,6 +48,7 @@ def main(): # debug options if server.debug: import debugpy + debugpy.listen(5678) logger.warning("waiting for debugger") debugpy.wait_for_client() diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py index b1623019..0e6188b1 100644 --- a/api/tests/chain/test_blend_grid.py +++ b/api/tests/chain/test_blend_grid.py @@ -9,13 +9,15 @@ from onnx_web.chain.result import StageResult class BlendGridStageTests(unittest.TestCase): def test_stage(self): stage = BlendGridStage() - sources = StageResult(images=[ - Image.new("RGB", (64, 64), "black"), - Image.new("RGB", (64, 64), "white"), - Image.new("RGB", (64, 64), "black"), - Image.new("RGB", (64, 64), "white"), - ]) + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + Image.new("RGB", (64, 64), "black"), + Image.new("RGB", (64, 64), "white"), + ] + ) result = stage.run(None, None, None, None, sources, height=2, width=2) self.assertEqual(len(result), 5) - self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0)) \ No newline at end of file + self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py index 21b583f0..31aa27a3 100644 --- a/api/tests/chain/test_blend_img2img.py +++ b/api/tests/chain/test_blend_img2img.py @@ -6,21 +6,38 @@ from onnx_web.chain.blend_img2img import BlendImg2ImgStage from onnx_web.params import DeviceParams, ImageParams from onnx_web.server.context import ServerContext from onnx_web.worker.context import WorkerContext +from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models class BlendImg2ImgStageTests(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_stage(self): - """ stage = BlendImg2ImgStage() - params = ImageParams("runwayml/stable-diffusion-v1-5", "txt2img", "euler-a", "an astronaut eating a hamburger", 3.0, 1, 1) - server = ServerContext() - worker = WorkerContext("test", DeviceParams("cpu", "CPUProvider"), None, None, None, None, None, None, 0) + params = ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "euler-a", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ) + server = ServerContext(model_path="../models", output_path="../outputs") + worker = WorkerContext( + "test", + DeviceParams("cpu", "CPUProvider"), + None, + None, + None, + None, + None, + None, + 0, + ) sources = [ Image.new("RGB", (64, 64), "black"), ] result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) self.assertEqual(len(result), 1) - self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) - """ - pass \ No newline at end of file + self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127)) diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py index a983a2e1..76a2715a 100644 --- a/api/tests/chain/test_blend_linear.py +++ b/api/tests/chain/test_blend_linear.py @@ -9,11 +9,15 @@ from onnx_web.chain.result import StageResult class BlendLinearStageTests(unittest.TestCase): def test_stage(self): stage = BlendLinearStage() - sources = StageResult(images=[ - Image.new("RGB", (64, 64), "black"), - ]) + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + ] + ) stage_source = Image.new("RGB", (64, 64), "white") - result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source) + result = stage.run( + None, None, None, None, sources, alpha=0.5, stage_source=stage_source + ) self.assertEqual(len(result), 1) - self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127)) \ No newline at end of file + self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127)) diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py index 9cc24de0..8a90d0c9 100644 --- a/api/tests/chain/test_correct_codeformer.py +++ b/api/tests/chain/test_correct_codeformer.py @@ -30,4 +30,4 @@ class CorrectCodeformerStageTests(unittest.TestCase): self.assertEqual(len(result), 0) """ - pass \ No newline at end of file + pass diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index 6323c0bb..c27cb077 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -2,6 +2,7 @@ import unittest from PIL import Image +from onnx_web.chain.result import StageResult from onnx_web.chain.tile import ( complete_tile, generate_tile_grid, @@ -14,122 +15,126 @@ from onnx_web.params import Size class TestCompleteTile(unittest.TestCase): - def test_with_complete_tile(self): - partial = Image.new("RGB", (64, 64)) - output = complete_tile(partial, 64) + def test_with_complete_tile(self): + partial = Image.new("RGB", (64, 64)) + output = complete_tile(partial, 64) - self.assertEqual(output.size, (64, 64)) + self.assertEqual(output.size, (64, 64)) - def test_with_partial_tile(self): - partial = Image.new("RGB", (64, 32)) - output = complete_tile(partial, 64) + def test_with_partial_tile(self): + partial = Image.new("RGB", (64, 32)) + output = complete_tile(partial, 64) - self.assertEqual(output.size, (64, 64)) + self.assertEqual(output.size, (64, 64)) - def test_with_nothing(self): - output = complete_tile(None, 64) + def test_with_nothing(self): + output = complete_tile(None, 64) - self.assertIsNone(output) + self.assertIsNone(output) class TestNeedsTile(unittest.TestCase): - def test_with_undersized_source(self): - small = Image.new("RGB", (32, 32)) + def test_with_undersized_source(self): + small = Image.new("RGB", (32, 32)) - self.assertFalse(needs_tile(64, 64, source=small)) + self.assertFalse(needs_tile(64, 64, source=small)) - def test_with_oversized_source(self): - large = Image.new("RGB", (64, 64)) + def test_with_oversized_source(self): + large = Image.new("RGB", (64, 64)) - self.assertTrue(needs_tile(32, 32, source=large)) + self.assertTrue(needs_tile(32, 32, source=large)) - def test_with_undersized_size(self): - small = Size(32, 32) + def test_with_undersized_size(self): + small = Size(32, 32) - self.assertFalse(needs_tile(64, 64, size=small)) + self.assertFalse(needs_tile(64, 64, size=small)) - def test_with_oversized_source(self): - large = Size(64, 64) + def test_with_oversized_size(self): + large = Size(64, 64) - self.assertTrue(needs_tile(32, 32, size=large)) + self.assertTrue(needs_tile(32, 32, size=large)) - def test_with_nothing(self): - self.assertFalse(needs_tile(32, 32)) + def test_with_nothing(self): + self.assertFalse(needs_tile(32, 32)) class TestTileGrads(unittest.TestCase): - def test_center_tile(self): - grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64) + def test_center_tile(self): + grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64) - self.assertEqual(grad_x, [0, 1, 1, 0]) - self.assertEqual(grad_y, [0, 1, 1, 0]) + self.assertEqual(grad_x, [0, 1, 1, 0]) + self.assertEqual(grad_y, [0, 1, 1, 0]) - def test_vertical_edge_tile(self): - grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8) + def test_vertical_edge_tile(self): + grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8) - self.assertEqual(grad_x, [0, 1, 1, 0]) - self.assertEqual(grad_y, [1, 1, 1, 1]) + self.assertEqual(grad_x, [0, 1, 1, 0]) + self.assertEqual(grad_y, [1, 1, 1, 1]) - def test_horizontal_edge_tile(self): - grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64) + def test_horizontal_edge_tile(self): + grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64) - self.assertEqual(grad_x, [1, 1, 1, 1]) - self.assertEqual(grad_y, [0, 1, 1, 0]) + self.assertEqual(grad_x, [1, 1, 1, 1]) + self.assertEqual(grad_y, [0, 1, 1, 0]) class TestGenerateTileGrid(unittest.TestCase): - def test_grid_complete(self): - tiles = generate_tile_grid(16, 16, 8, 0.0) + def test_grid_complete(self): + tiles = generate_tile_grid(16, 16, 8, 0.0) - self.assertEqual(len(tiles), 4) - self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) + self.assertEqual(len(tiles), 4) + self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)]) - def test_grid_no_overlap(self): - tiles = generate_tile_grid(64, 64, 8, 0.0) + def test_grid_no_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.0) - self.assertEqual(len(tiles), 64) - self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) - self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) + self.assertEqual(len(tiles), 64) + self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) + self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)]) - def test_grid_50_overlap(self): - tiles = generate_tile_grid(64, 64, 8, 0.5) + def test_grid_50_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.5) - self.assertEqual(len(tiles), 225) - self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) - self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) + self.assertEqual(len(tiles), 256) + self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) + self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)]) class TestGenerateTileSpiral(unittest.TestCase): - def test_spiral_complete(self): - tiles = generate_tile_spiral(16, 16, 8, 0.0) + def test_spiral_complete(self): + tiles = generate_tile_spiral(16, 16, 8, 0.0) - self.assertEqual(len(tiles), 4) - self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) + self.assertEqual(len(tiles), 4) + self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) - def test_spiral_no_overlap(self): - tiles = generate_tile_spiral(64, 64, 8, 0.0) + def test_spiral_no_overlap(self): + tiles = generate_tile_spiral(64, 64, 8, 0.0) - self.assertEqual(len(tiles), 64) - self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) - self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) + self.assertEqual(len(tiles), 64) + self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) + self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) - def test_spiral_50_overlap(self): - tiles = generate_tile_spiral(64, 64, 8, 0.5) + def test_spiral_50_overlap(self): + tiles = generate_tile_spiral(64, 64, 8, 0.5) - self.assertEqual(len(tiles), 225) - self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) - self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) + self.assertEqual(len(tiles), 225) + self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) + self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) class TestProcessTileStack(unittest.TestCase): - def test_grid_full(self): - source = Image.new("RGB", (64, 64)) - blend = process_tile_stack(source, 32, 1, []) + def test_grid_full(self): + source = Image.new("RGB", (64, 64)) + blend = process_tile_stack( + StageResult(images=[source]), 32, 1, [], generate_tile_grid + ) - self.assertEqual(blend.size, (64, 64)) + self.assertEqual(blend[0].size, (64, 64)) - def test_grid_partial(self): - source = Image.new("RGB", (72, 72)) - blend = process_tile_stack(source, 32, 1, []) + def test_grid_partial(self): + source = Image.new("RGB", (72, 72)) + blend = process_tile_stack( + StageResult(images=[source]), 32, 1, [], generate_tile_grid + ) - self.assertEqual(blend.size, (72, 72)) + self.assertEqual(blend[0].size, (72, 72)) diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py index 72437fc8..096eea54 100644 --- a/api/tests/chain/test_upscale_highres.py +++ b/api/tests/chain/test_upscale_highres.py @@ -9,6 +9,14 @@ class UpscaleHighresStageTests(unittest.TestCase): def test_empty(self): stage = UpscaleHighresStage() sources = StageResult.empty() - result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) + result = stage.run( + None, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + ) self.assertEqual(len(result), 0) diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index 87a7fff0..bcf19680 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -6,7 +6,6 @@ from onnx import GraphProto, ModelProto, NodeProto from onnx.numpy_helper import from_array from onnx_web.convert.diffusion.lora import ( - blend_loras, blend_node_conv_gemm, blend_node_matmul, blend_weights_loha, @@ -33,7 +32,6 @@ class SumWeightsTests(unittest.TestCase): weights = sum_weights(np.zeros((4, 4)), np.ones((4, 4, 1, 1))) self.assertEqual(weights.shape, (4, 4, 1, 1)) - def test_3x3_kernel(self): """ weights = sum_weights(np.zeros((4, 4, 3, 3)), np.ones((4, 4))) @@ -53,14 +51,20 @@ class BufferExternalDataTensorTests(unittest.TestCase): ) (slim_model, external_weights) = buffer_external_data_tensors(model) - self.assertEqual(len(slim_model.graph.initializer), len(model.graph.initializer)) + self.assertEqual( + len(slim_model.graph.initializer), len(model.graph.initializer) + ) self.assertEqual(len(external_weights), 1) class FixInitializerKeyTests(unittest.TestCase): def test_fix_name(self): - inputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight"] - outputs = ["lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight"] + inputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight" + ] + outputs = [ + "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0_lora_down_weight" + ] for input, output in zip(inputs, outputs): self.assertEqual(fix_initializer_name(input), output) @@ -92,25 +96,37 @@ class FixXLNameTests(unittest.TestCase): nodes = { "input_block_proj.lora_down.weight": {}, } - fixed = fix_xl_names(nodes, [ - NodeProto(name="/down_blocks_proj/MatMul"), - ]) + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/down_blocks_proj/MatMul"), + ], + ) - self.assertEqual(fixed, { - "down_blocks_proj": nodes["input_block_proj.lora_down.weight"], - }) + self.assertEqual( + fixed, + { + "down_blocks_proj": nodes["input_block_proj.lora_down.weight"], + }, + ) def test_middle_block(self): nodes = { "middle_block_proj.lora_down.weight": {}, } - fixed = fix_xl_names(nodes, [ - NodeProto(name="/mid_blocks_proj/MatMul"), - ]) + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/mid_blocks_proj/MatMul"), + ], + ) - self.assertEqual(fixed, { - "mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"], - }) + self.assertEqual( + fixed, + { + "mid_blocks_proj": nodes["middle_block_proj.lora_down.weight"], + }, + ) def test_output_block(self): pass @@ -133,13 +149,19 @@ class FixXLNameTests(unittest.TestCase): nodes = { "output_block_proj_out.lora_down.weight": {}, } - fixed = fix_xl_names(nodes, [ - NodeProto(name="/up_blocks_proj_out/MatMul"), - ]) + fixed = fix_xl_names( + nodes, + [ + NodeProto(name="/up_blocks_proj_out/MatMul"), + ], + ) - self.assertEqual(fixed, { - "up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"], - }) + self.assertEqual( + fixed, + { + "up_blocks_proj_out": nodes["output_block_proj_out.lora_down.weight"], + }, + ) class KernelSliceTests(unittest.TestCase): @@ -250,6 +272,7 @@ class BlendWeightsLoHATests(unittest.TestCase): self.assertEqual(result.shape, (4, 4)) """ + class BlendWeightsLoRATests(unittest.TestCase): def test_blend_kernel_none(self): model = { @@ -260,7 +283,6 @@ class BlendWeightsLoRATests(unittest.TestCase): key, result = blend_weights_lora("foo.lora_down", "", model, torch.float32) self.assertEqual(result.shape, (4, 4)) - def test_blend_kernel_1x1(self): model = { "foo.lora_down": torch.from_numpy(np.ones((1, 4, 1, 1))), diff --git a/api/tests/convert/diffusion/test_textual_inversion.py b/api/tests/convert/diffusion/test_textual_inversion.py index 246d53b4..287907b6 100644 --- a/api/tests/convert/diffusion/test_textual_inversion.py +++ b/api/tests/convert/diffusion/test_textual_inversion.py @@ -10,7 +10,6 @@ from onnx_web.convert.diffusion.textual_inversion import ( blend_embedding_embeddings, blend_embedding_node, blend_embedding_parameters, - blend_textual_inversions, detect_embedding_format, ) @@ -18,210 +17,267 @@ TEST_DIMS = (8, 8) TEST_DIMS_EMBEDS = (1, *TEST_DIMS) TEST_MODEL_EMBEDS = { - "string_to_token": { + "string_to_token": { "test": 1, - }, - "string_to_param": { + }, + "string_to_param": { "test": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), - }, + }, } class DetectEmbeddingFormatTests(unittest.TestCase): - def test_concept(self): - embedding = { - "": "test", - } - self.assertEqual(detect_embedding_format(embedding), "concept") + def test_concept(self): + embedding = { + "": "test", + } + self.assertEqual(detect_embedding_format(embedding), "concept") - def test_parameters(self): - embedding = { - "emb_params": "test", - } - self.assertEqual(detect_embedding_format(embedding), "parameters") + def test_parameters(self): + embedding = { + "emb_params": "test", + } + self.assertEqual(detect_embedding_format(embedding), "parameters") - def test_embeddings(self): - embedding = { - "string_to_token": "test", - "string_to_param": "test", - } - self.assertEqual(detect_embedding_format(embedding), "embeddings") + def test_embeddings(self): + embedding = { + "string_to_token": "test", + "string_to_param": "test", + } + self.assertEqual(detect_embedding_format(embedding), "embeddings") - def test_unknown(self): - embedding = { - "what_is_this": "test", - } - self.assertEqual(detect_embedding_format(embedding), None) + def test_unknown(self): + embedding = { + "what_is_this": "test", + } + self.assertEqual(detect_embedding_format(embedding), None) class BlendEmbeddingConceptTests(unittest.TestCase): - def test_existing_base_token(self): - embeds = { - "test": np.ones(TEST_DIMS), - } - blend_embedding_concept(embeds, { - "": torch.from_numpy(np.ones(TEST_DIMS)), - }, np.float32, "test", 1.0) + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) - self.assertEqual(embeds["test"].mean(), 2) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) - def test_missing_base_token(self): - embeds = {} - blend_embedding_concept(embeds, { - "": torch.from_numpy(np.ones(TEST_DIMS)), - }, np.float32, "test", 1.0) + def test_missing_base_token(self): + embeds = {} + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) - def test_existing_token(self): - embeds = { - "": np.ones(TEST_DIMS), - } - blend_embedding_concept(embeds, { - "": torch.from_numpy(np.ones(TEST_DIMS)), - }, np.float32, "test", 1.0) + def test_existing_token(self): + embeds = { + "": np.ones(TEST_DIMS), + } + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["", "test"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) - def test_missing_token(self): - embeds = {} - blend_embedding_concept(embeds, { - "": torch.from_numpy(np.ones(TEST_DIMS)), - }, np.float32, "test", 1.0) + def test_missing_token(self): + embeds = {} + blend_embedding_concept( + embeds, + { + "": torch.from_numpy(np.ones(TEST_DIMS)), + }, + np.float32, + "test", + 1.0, + ) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["", "test"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["", "test"]) class BlendEmbeddingParametersTests(unittest.TestCase): - def test_existing_base_token(self): - embeds = { - "test": np.ones(TEST_DIMS), - } - blend_embedding_parameters(embeds, { - "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), - }, np.float32, "test", 1.0) + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) - self.assertEqual(embeds["test"].mean(), 2) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) - def test_missing_base_token(self): - embeds = {} - blend_embedding_parameters(embeds, { - "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), - }, np.float32, "test", 1.0) + def test_missing_base_token(self): + embeds = {} + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) - def test_existing_token(self): - embeds = { - "test": np.ones(TEST_DIMS_EMBEDS), - } - blend_embedding_parameters(embeds, { - "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), - }, np.float32, "test", 1.0) + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS_EMBEDS), + } + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["test", "test-0", "test-all"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) - def test_missing_token(self): - embeds = {} - blend_embedding_parameters(embeds, { - "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), - }, np.float32, "test", 1.0) + def test_missing_token(self): + embeds = {} + blend_embedding_parameters( + embeds, + { + "emb_params": torch.from_numpy(np.ones(TEST_DIMS_EMBEDS)), + }, + np.float32, + "test", + 1.0, + ) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["test", "test-0", "test-all"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) class BlendEmbeddingEmbeddingsTests(unittest.TestCase): - def test_existing_base_token(self): - embeds = { - "test": np.ones(TEST_DIMS), - } - blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + def test_existing_base_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) - self.assertEqual(embeds["test"].mean(), 2) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertEqual(embeds["test"].mean(), 2) - def test_missing_base_token(self): - embeds = {} - blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + def test_missing_base_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) - self.assertIn("test", embeds) - self.assertEqual(embeds["test"].shape, TEST_DIMS) + self.assertIn("test", embeds) + self.assertEqual(embeds["test"].shape, TEST_DIMS) - def test_existing_token(self): - embeds = { - "test": np.ones(TEST_DIMS), - } - blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + def test_existing_token(self): + embeds = { + "test": np.ones(TEST_DIMS), + } + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["test", "test-0", "test-all"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) - def test_missing_token(self): - embeds = {} - blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) + def test_missing_token(self): + embeds = {} + blend_embedding_embeddings(embeds, TEST_MODEL_EMBEDS, np.float32, "test", 1.0) - keys = list(embeds.keys()) - keys.sort() + keys = list(embeds.keys()) + keys.sort() - self.assertIn("test", embeds) - self.assertEqual(keys, ["test", "test-0", "test-all"]) + self.assertIn("test", embeds) + self.assertEqual(keys, ["test", "test-0", "test-all"]) class BlendEmbeddingNodeTests(unittest.TestCase): - def test_expand_weights(self): - weights = from_array(np.ones(TEST_DIMS)) - weights.name = "text_model.embeddings.token_embedding.weight" + def test_expand_weights(self): + weights = from_array(np.ones(TEST_DIMS)) + weights.name = "text_model.embeddings.token_embedding.weight" - model = ModelProto(graph=GraphProto(initializer=[ - weights, - ])) + model = ModelProto( + graph=GraphProto( + initializer=[ + weights, + ] + ) + ) - embeds = {} - blend_embedding_node(model, { - 'convert_tokens_to_ids': lambda t: t, - }, embeds, 2) + embeds = {} + blend_embedding_node( + model, + { + "convert_tokens_to_ids": lambda t: t, + }, + embeds, + 2, + ) - result = to_array(model.graph.initializer[0]) + result = to_array(model.graph.initializer[0]) - self.assertEqual(len(model.graph.initializer), 1) - self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8) + self.assertEqual(len(model.graph.initializer), 1) + self.assertEqual(result.shape, (10, 8)) # (8 + 2, 8) class BlendTextualInversionsTests(unittest.TestCase): - def test_blend_multi_concept(self): - pass + def test_blend_multi_concept(self): + pass - def test_blend_multi_parameters(self): - pass + def test_blend_multi_parameters(self): + pass - def test_blend_multi_embeddings(self): - pass + def test_blend_multi_embeddings(self): + pass - def test_blend_multi_mixed(self): - pass + def test_blend_multi_mixed(self): + pass diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index f08f0d0c..ae0c2842 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -13,7 +13,6 @@ from onnx_web.convert.utils import ( tuple_to_upscaling, ) from tests.helpers import ( - TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_UPSCALING_SWINIR, test_needs_models, ) @@ -21,220 +20,225 @@ from tests.helpers import ( class ConversionContextTests(unittest.TestCase): def test_from_environ(self): - context = ConversionContext.from_environ() - self.assertEqual(context.opset, DEFAULT_OPSET) + context = ConversionContext.from_environ() + self.assertEqual(context.opset, DEFAULT_OPSET) def test_map_location(self): - context = ConversionContext.from_environ() - self.assertEqual(context.map_location.type, "cpu") + context = ConversionContext.from_environ() + self.assertEqual(context.map_location.type, "cpu") class DownloadProgressTests(unittest.TestCase): - def test_download_example(self): - path = download_progress([("https://example.com", "/tmp/example-dot-com")]) - self.assertEqual(path, "/tmp/example-dot-com") + def test_download_example(self): + path = download_progress([("https://example.com", "/tmp/example-dot-com")]) + self.assertEqual(path, "/tmp/example-dot-com") class TupleToSourceTests(unittest.TestCase): - def test_basic_tuple(self): - source = tuple_to_source(("foo", "bar")) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_tuple(self): + source = tuple_to_source(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_list(self): - source = tuple_to_source(["foo", "bar"]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_list(self): + source = tuple_to_source(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_dict(self): - source = tuple_to_source(["foo", "bar"]) - source["bin"] = "bin" + def test_basic_dict(self): + source = tuple_to_source(["foo", "bar"]) + source["bin"] = "bin" - # make sure this is returned as-is with extra fields - second = tuple_to_source(source) + # make sure this is returned as-is with extra fields + second = tuple_to_source(source) - self.assertEqual(source, second) - self.assertIn("bin", second) + self.assertEqual(source, second) + self.assertIn("bin", second) class TupleToCorrectionTests(unittest.TestCase): - def test_basic_tuple(self): - source = tuple_to_correction(("foo", "bar")) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_tuple(self): + source = tuple_to_correction(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_list(self): - source = tuple_to_correction(["foo", "bar"]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_list(self): + source = tuple_to_correction(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_dict(self): - source = tuple_to_correction(["foo", "bar"]) - source["bin"] = "bin" + def test_basic_dict(self): + source = tuple_to_correction(["foo", "bar"]) + source["bin"] = "bin" - # make sure this is returned with extra fields - second = tuple_to_source(source) + # make sure this is returned with extra fields + second = tuple_to_source(source) - self.assertEqual(source, second) - self.assertIn("bin", second) + self.assertEqual(source, second) + self.assertIn("bin", second) - def test_scale_tuple(self): - source = tuple_to_correction(["foo", "bar", 2]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_scale_tuple(self): + source = tuple_to_correction(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_half_tuple(self): - source = tuple_to_correction(["foo", "bar", True]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_half_tuple(self): + source = tuple_to_correction(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_opset_tuple(self): - source = tuple_to_correction(["foo", "bar", 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_opset_tuple(self): + source = tuple_to_correction(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_all_tuple(self): - source = tuple_to_correction(["foo", "bar", 2, True, 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") - self.assertEqual(source["scale"], 2) - self.assertEqual(source["half"], True) - self.assertEqual(source["opset"], 14) + def test_all_tuple(self): + source = tuple_to_correction(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) class TupleToDiffusionTests(unittest.TestCase): - def test_basic_tuple(self): - source = tuple_to_diffusion(("foo", "bar")) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_tuple(self): + source = tuple_to_diffusion(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_list(self): - source = tuple_to_diffusion(["foo", "bar"]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_list(self): + source = tuple_to_diffusion(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_dict(self): - source = tuple_to_diffusion(["foo", "bar"]) - source["bin"] = "bin" + def test_basic_dict(self): + source = tuple_to_diffusion(["foo", "bar"]) + source["bin"] = "bin" - # make sure this is returned with extra fields - second = tuple_to_diffusion(source) + # make sure this is returned with extra fields + second = tuple_to_diffusion(source) - self.assertEqual(source, second) - self.assertIn("bin", second) + self.assertEqual(source, second) + self.assertIn("bin", second) - def test_single_vae_tuple(self): - source = tuple_to_diffusion(["foo", "bar", True]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_single_vae_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_half_tuple(self): - source = tuple_to_diffusion(["foo", "bar", True]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_half_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_opset_tuple(self): - source = tuple_to_diffusion(["foo", "bar", 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_opset_tuple(self): + source = tuple_to_diffusion(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_all_tuple(self): - source = tuple_to_diffusion(["foo", "bar", True, True, 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") - self.assertEqual(source["single_vae"], True) - self.assertEqual(source["half"], True) - self.assertEqual(source["opset"], 14) + def test_all_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["single_vae"], True) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) class TupleToUpscalingTests(unittest.TestCase): - def test_basic_tuple(self): - source = tuple_to_upscaling(("foo", "bar")) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_tuple(self): + source = tuple_to_upscaling(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_list(self): - source = tuple_to_upscaling(["foo", "bar"]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_basic_list(self): + source = tuple_to_upscaling(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_basic_dict(self): - source = tuple_to_upscaling(["foo", "bar"]) - source["bin"] = "bin" + def test_basic_dict(self): + source = tuple_to_upscaling(["foo", "bar"]) + source["bin"] = "bin" - # make sure this is returned with extra fields - second = tuple_to_source(source) + # make sure this is returned with extra fields + second = tuple_to_source(source) - self.assertEqual(source, second) - self.assertIn("bin", second) + self.assertEqual(source, second) + self.assertIn("bin", second) - def test_scale_tuple(self): - source = tuple_to_upscaling(["foo", "bar", 2]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_scale_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_half_tuple(self): - source = tuple_to_upscaling(["foo", "bar", True]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_half_tuple(self): + source = tuple_to_upscaling(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_opset_tuple(self): - source = tuple_to_upscaling(["foo", "bar", 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") + def test_opset_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") - def test_all_tuple(self): - source = tuple_to_upscaling(["foo", "bar", 2, True, 14]) - self.assertEqual(source["name"], "foo") - self.assertEqual(source["source"], "bar") - self.assertEqual(source["scale"], 2) - self.assertEqual(source["half"], True) - self.assertEqual(source["opset"], 14) + def test_all_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) class SourceFormatTests(unittest.TestCase): - def test_with_format(self): - result = source_format({ - "format": "foo", - }) - self.assertEqual(result, "foo") + def test_with_format(self): + result = source_format( + { + "format": "foo", + } + ) + self.assertEqual(result, "foo") - def test_source_known_extension(self): - result = source_format({ - "source": "foo.safetensors", - }) - self.assertEqual(result, "safetensors") + def test_source_known_extension(self): + result = source_format( + { + "source": "foo.safetensors", + } + ) + self.assertEqual(result, "safetensors") - def test_source_unknown_extension(self): - result = source_format({ - "source": "foo.none" - }) - self.assertEqual(result, None) + def test_source_unknown_extension(self): + result = source_format({"source": "foo.none"}) + self.assertEqual(result, None) - def test_incomplete_model(self): - self.assertIsNone(source_format({})) + def test_incomplete_model(self): + self.assertIsNone(source_format({})) class RemovePrefixTests(unittest.TestCase): - def test_with_prefix(self): - self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar") + def test_with_prefix(self): + self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar") - def test_without_prefix(self): - self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") + def test_without_prefix(self): + self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") class LoadTorchTests(unittest.TestCase): - pass + pass class LoadTensorTests(unittest.TestCase): - pass + pass class ResolveTensorTests(unittest.TestCase): - @test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) - def test_resolve_existing(self): - self.assertEqual(resolve_tensor("../models/.cache/upscaling-swinir"), TEST_MODEL_UPSCALING_SWINIR) + @test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) + def test_resolve_existing(self): + self.assertEqual( + resolve_tensor("../models/.cache/upscaling-swinir"), + TEST_MODEL_UPSCALING_SWINIR, + ) - def test_resolve_missing(self): - self.assertIsNone(resolve_tensor("missing")) + def test_resolve_missing(self): + self.assertIsNone(resolve_tensor("missing")) diff --git a/api/tests/helpers.py b/api/tests/helpers.py index 3b6716b2..64714819 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -6,11 +6,13 @@ from onnx_web.params import DeviceParams def test_needs_models(models: List[str]): - return skipUnless(all([path.exists(model) for model in models]), "model does not exist") + return skipUnless( + all([path.exists(model) for model in models]), "model does not exist" + ) def test_device() -> DeviceParams: - return DeviceParams("cpu", "CPUExecutionProvider") + return DeviceParams("cpu", "CPUExecutionProvider") TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" diff --git a/api/tests/image/test_mask_filter.py b/api/tests/image/test_mask_filter.py index 58b46c7c..e36470e4 100644 --- a/api/tests/image/test_mask_filter.py +++ b/api/tests/image/test_mask_filter.py @@ -10,24 +10,24 @@ from onnx_web.image.mask_filter import ( class MaskFilterNoneTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - mask = Image.new("RGB", dims) - result = mask_filter_none(mask, dims, (0, 0)) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_none(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) class MaskFilterGaussianMultiplyTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - mask = Image.new("RGB", dims) - result = mask_filter_gaussian_multiply(mask, dims, (0, 0)) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_multiply(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) class MaskFilterGaussianScreenTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - mask = Image.new("RGB", dims) - result = mask_filter_gaussian_screen(mask, dims, (0, 0)) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_screen(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) diff --git a/api/tests/image/test_source_filter.py b/api/tests/image/test_source_filter.py index 89e73924..fb44073e 100644 --- a/api/tests/image/test_source_filter.py +++ b/api/tests/image/test_source_filter.py @@ -11,27 +11,27 @@ from onnx_web.server.context import ServerContext class SourceFilterNoneTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - server = ServerContext() - source = Image.new("RGB", dims) - result = source_filter_none(server, source) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_none(server, source) + self.assertEqual(result.size, dims) class SourceFilterGaussianTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - server = ServerContext() - source = Image.new("RGB", dims) - result = source_filter_gaussian(server, source) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_gaussian(server, source) + self.assertEqual(result.size, dims) class SourceFilterNoiseTests(unittest.TestCase): - def test_basic(self): - dims = (64, 64) - server = ServerContext() - source = Image.new("RGB", dims) - result = source_filter_noise(server, source) - self.assertEqual(result.size, dims) + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_noise(server, source) + self.assertEqual(result.size, dims) diff --git a/api/tests/image/test_utils.py b/api/tests/image/test_utils.py index f3b10fd5..215bb10b 100644 --- a/api/tests/image/test_utils.py +++ b/api/tests/image/test_utils.py @@ -7,18 +7,18 @@ from onnx_web.params import Border class ExpandImageTests(unittest.TestCase): - def test_expand(self): - result = expand_image( - Image.new("RGB", (8, 8)), - Image.new("RGB", (8, 8), "white"), - Border.even(4), - ) - self.assertEqual(result[0].size, (16, 16)) + def test_expand(self): + result = expand_image( + Image.new("RGB", (8, 8)), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].size, (16, 16)) - def test_masked(self): - result = expand_image( - Image.new("RGB", (8, 8), "red"), - Image.new("RGB", (8, 8), "white"), - Border.even(4), - ) - self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0)) + def test_masked(self): + result = expand_image( + Image.new("RGB", (8, 8), "red"), + Image.new("RGB", (8, 8), "white"), + Border.even(4), + ) + self.assertEqual(result[0].getpixel((8, 8)), (255, 0, 0)) diff --git a/api/tests/mocks.py b/api/tests/mocks.py index f16ae22f..ef95d754 100644 --- a/api/tests/mocks.py +++ b/api/tests/mocks.py @@ -1,43 +1,43 @@ from typing import Any, Optional -class MockPipeline(): - # flags - slice_size: Optional[str] - vae_slicing: Optional[bool] - sequential_offload: Optional[bool] - model_offload: Optional[bool] - xformers: Optional[bool] +class MockPipeline: + # flags + slice_size: Optional[str] + vae_slicing: Optional[bool] + sequential_offload: Optional[bool] + model_offload: Optional[bool] + xformers: Optional[bool] - # stubs - _encode_prompt: Optional[Any] - unet: Optional[Any] - vae_decoder: Optional[Any] - vae_encoder: Optional[Any] + # stubs + _encode_prompt: Optional[Any] + unet: Optional[Any] + vae_decoder: Optional[Any] + vae_encoder: Optional[Any] - def __init__(self) -> None: - self.slice_size = None - self.vae_slicing = None - self.sequential_offload = None - self.model_offload = None - self.xformers = None + def __init__(self) -> None: + self.slice_size = None + self.vae_slicing = None + self.sequential_offload = None + self.model_offload = None + self.xformers = None - self._encode_prompt = None - self.unet = None - self.vae_decoder = None - self.vae_encoder = None + self._encode_prompt = None + self.unet = None + self.vae_decoder = None + self.vae_encoder = None - def enable_attention_slicing(self, slice_size: str = None): - self.slice_size = slice_size + def enable_attention_slicing(self, slice_size: str = None): + self.slice_size = slice_size - def enable_vae_slicing(self): - self.vae_slicing = True + def enable_vae_slicing(self): + self.vae_slicing = True - def enable_sequential_cpu_offload(self): - self.sequential_offload = True + def enable_sequential_cpu_offload(self): + self.sequential_offload = True - def enable_model_cpu_offload(self): - self.model_offload = True + def enable_model_cpu_offload(self): + self.model_offload = True - def enable_xformers_memory_efficient_attention(self): - self.xformers = True \ No newline at end of file + def enable_xformers_memory_efficient_attention(self): + self.xformers = True diff --git a/api/tests/prompt/test_parser.py b/api/tests/prompt/test_parser.py index b6b13a23..15c91d6c 100644 --- a/api/tests/prompt/test_parser.py +++ b/api/tests/prompt/test_parser.py @@ -13,7 +13,7 @@ class ParserTests(unittest.TestCase): str(["foo"]), str(PromptPhrase(["bar"], weight=1.5)), str(["bin"]), - ] + ], ) def test_multi_word_phrase(self): @@ -24,7 +24,7 @@ class ParserTests(unittest.TestCase): str(["foo", "bar"]), str(PromptPhrase(["middle", "words"], weight=1.5)), str(["bin", "bun"]), - ] + ], ) def test_nested_phrase(self): @@ -33,7 +33,7 @@ class ParserTests(unittest.TestCase): [str(i) for i in res], [ str(["foo"]), - str(PromptPhrase(["bar"], weight=(1.5 ** 3))), + str(PromptPhrase(["bar"], weight=(1.5**3))), str(["bin"]), - ] + ], ) diff --git a/api/tests/server/test_load.py b/api/tests/server/test_load.py index c32b9663..b04df9ef 100644 --- a/api/tests/server/test_load.py +++ b/api/tests/server/test_load.py @@ -25,71 +25,85 @@ class ConfigParamTests(unittest.TestCase): params = get_config_params() self.assertIsNotNone(params) + class AvailablePlatformTests(unittest.TestCase): def test_before_setup(self): platforms = get_available_platforms() self.assertIsNotNone(platforms) + class CorrectModelTests(unittest.TestCase): def test_before_setup(self): models = get_correction_models() self.assertIsNotNone(models) + class DiffusionModelTests(unittest.TestCase): def test_before_setup(self): models = get_diffusion_models() self.assertIsNotNone(models) + class NetworkModelTests(unittest.TestCase): def test_before_setup(self): models = get_network_models() self.assertIsNotNone(models) + class UpscalingModelTests(unittest.TestCase): def test_before_setup(self): models = get_upscaling_models() self.assertIsNotNone(models) + class WildcardDataTests(unittest.TestCase): def test_before_setup(self): wildcards = get_wildcard_data() self.assertIsNotNone(wildcards) + class ExtraStringsTests(unittest.TestCase): def test_before_setup(self): strings = get_extra_strings() self.assertIsNotNone(strings) + class ExtraHashesTests(unittest.TestCase): def test_before_setup(self): hashes = get_extra_hashes() self.assertIsNotNone(hashes) + class HighresMethodTests(unittest.TestCase): def test_before_setup(self): methods = get_highres_methods() self.assertIsNotNone(methods) + class MaskFilterTests(unittest.TestCase): def test_before_setup(self): filters = get_mask_filters() self.assertIsNotNone(filters) + class NoiseSourceTests(unittest.TestCase): def test_before_setup(self): sources = get_noise_sources() self.assertIsNotNone(sources) + class SourceFilterTests(unittest.TestCase): def test_before_setup(self): filters = get_source_filters() self.assertIsNotNone(filters) + class LoadExtrasTests(unittest.TestCase): def test_default_extras(self): server = ServerContext(extra_models=["../models/extras.json"]) load_extras(server) + class LoadModelsTests(unittest.TestCase): def test_default_models(self): server = ServerContext(model_path="../models") diff --git a/api/tests/server/test_model_cache.py b/api/tests/server/test_model_cache.py index 0e4839c9..c024b611 100644 --- a/api/tests/server/test_model_cache.py +++ b/api/tests/server/test_model_cache.py @@ -4,37 +4,37 @@ from onnx_web.server.model_cache import ModelCache class TestModelCache(unittest.TestCase): - def test_drop_existing(self): - cache = ModelCache(10) - cache.clear() - cache.set("foo", ("bar",), {}) - self.assertGreater(cache.size, 0) - self.assertEqual(cache.drop("foo", ("bar",)), 1) + def test_drop_existing(self): + cache = ModelCache(10) + cache.clear() + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bar",)), 1) - def test_drop_missing(self): - cache = ModelCache(10) - cache.clear() - cache.set("foo", ("bar",), {}) - self.assertGreater(cache.size, 0) - self.assertEqual(cache.drop("foo", ("bin",)), 0) + def test_drop_missing(self): + cache = ModelCache(10) + cache.clear() + cache.set("foo", ("bar",), {}) + self.assertGreater(cache.size, 0) + self.assertEqual(cache.drop("foo", ("bin",)), 0) - def test_get_existing(self): - cache = ModelCache(10) - cache.clear() - value = {} - cache.set("foo", ("bar",), value) - self.assertGreater(cache.size, 0) - self.assertIs(cache.get("foo", ("bar",)), value) + def test_get_existing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bar",)), value) - def test_get_missing(self): - cache = ModelCache(10) - cache.clear() - value = {} - cache.set("foo", ("bar",), value) - self.assertGreater(cache.size, 0) - self.assertIs(cache.get("foo", ("bin",)), None) + def test_get_missing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertGreater(cache.size, 0) + self.assertIs(cache.get("foo", ("bin",)), None) - """ + """ def test_set_existing(self): cache = ModelCache(10) cache.clear() @@ -48,16 +48,16 @@ class TestModelCache(unittest.TestCase): self.assertIs(cache.get("foo", ("bar",)), value) """ - def test_set_missing(self): - cache = ModelCache(10) - cache.clear() - value = {} - cache.set("foo", ("bar",), value) - self.assertIs(cache.get("foo", ("bar",)), value) + def test_set_missing(self): + cache = ModelCache(10) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertIs(cache.get("foo", ("bar",)), value) - def test_set_zero(self): - cache = ModelCache(0) - cache.clear() - value = {} - cache.set("foo", ("bar",), value) - self.assertEqual(cache.size, 0) + def test_set_zero(self): + cache = ModelCache(0) + cache.clear() + value = {} + cache.set("foo", ("bar",), value) + self.assertEqual(cache.size, 0) diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py index 8f7a3963..014f7aa0 100644 --- a/api/tests/test_diffusers/test_load.py +++ b/api/tests/test_diffusers/test_load.py @@ -24,253 +24,307 @@ from tests.mocks import MockPipeline class TestAvailablePipelines(unittest.TestCase): - def test_available_pipelines(self): - pipelines = get_available_pipelines() + def test_available_pipelines(self): + pipelines = get_available_pipelines() - self.assertIn("txt2img", pipelines) + self.assertIn("txt2img", pipelines) class TestPipelineSchedulers(unittest.TestCase): - def test_pipeline_schedulers(self): - schedulers = get_pipeline_schedulers() + def test_pipeline_schedulers(self): + schedulers = get_pipeline_schedulers() - self.assertIn("euler-a", schedulers) + self.assertIn("euler-a", schedulers) class TestSchedulerNames(unittest.TestCase): - def test_valid_name(self): - scheduler = get_scheduler_name(DDIMScheduler) + def test_valid_name(self): + scheduler = get_scheduler_name(DDIMScheduler) - self.assertEqual("ddim", scheduler) + self.assertEqual("ddim", scheduler) - def test_missing_names(self): - self.assertIsNone(get_scheduler_name("test")) + def test_missing_names(self): + self.assertIsNone(get_scheduler_name("test")) class TestOptimizePipeline(unittest.TestCase): - def test_auto_attention_slicing(self): - server = ServerContext( - optimizations=[ - "diffusers-attention-slicing-auto", - ], - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.slice_size, "auto") + def test_auto_attention_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-auto", + ], + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "auto") - def test_max_attention_slicing(self): - server = ServerContext( - optimizations=[ - "diffusers-attention-slicing-max", - ] - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.slice_size, "max") + def test_max_attention_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-attention-slicing-max", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.slice_size, "max") - def test_vae_slicing(self): - server = ServerContext( - optimizations=[ - "diffusers-vae-slicing", - ] - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.vae_slicing, True) + def test_vae_slicing(self): + server = ServerContext( + optimizations=[ + "diffusers-vae-slicing", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.vae_slicing, True) - def test_cpu_offload_sequential(self): - server = ServerContext( - optimizations=[ - "diffusers-cpu-offload-sequential", - ] - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.sequential_offload, True) + def test_cpu_offload_sequential(self): + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-sequential", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.sequential_offload, True) - def test_cpu_offload_model(self): - server = ServerContext( - optimizations=[ - "diffusers-cpu-offload-model", - ] - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.model_offload, True) + def test_cpu_offload_model(self): + server = ServerContext( + optimizations=[ + "diffusers-cpu-offload-model", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.model_offload, True) - def test_memory_efficient_attention(self): - server = ServerContext( - optimizations=[ - "diffusers-memory-efficient-attention", - ] - ) - pipeline = MockPipeline() - optimize_pipeline(server, pipeline) - self.assertEqual(pipeline.xformers, True) + def test_memory_efficient_attention(self): + server = ServerContext( + optimizations=[ + "diffusers-memory-efficient-attention", + ] + ) + pipeline = MockPipeline() + optimize_pipeline(server, pipeline) + self.assertEqual(pipeline.xformers, True) class TestPatchPipeline(unittest.TestCase): - def test_expand_not_lpw(self): - """ - server = ServerContext() - pipeline = MockPipeline() - patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) - self.assertEqual(pipeline._encode_prompt, expand_prompt) - """ - pass + def test_expand_not_lpw(self): + """ + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) + self.assertEqual(pipeline._encode_prompt, expand_prompt) + """ + pass - def test_unet_wrapper_not_xl(self): - server = ServerContext() - pipeline = MockPipeline() - patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) - self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) + def test_unet_wrapper_not_xl(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) - def test_unet_wrapper_xl(self): - server = ServerContext() - pipeline = MockPipeline() - patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1)) - self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) + def test_unet_wrapper_xl(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img-sdxl", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.unet, UNetWrapper)) - def test_vae_wrapper(self): - server = ServerContext() - pipeline = MockPipeline() - patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) - self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) - self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) + def test_vae_wrapper(self): + server = ServerContext() + pipeline = MockPipeline() + patch_pipeline( + server, + pipeline, + None, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) + self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) class TestLoadControlNet(unittest.TestCase): - @unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist") - def test_load_existing(self): - """ - Should load a model - """ - components = load_controlnet( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")), + @unittest.skipUnless( + path.exists("../models/control/canny.onnx"), "model does not exist" ) - self.assertIn("controlnet", components) + def test_load_existing(self): + """ + Should load a model + """ + components = load_controlnet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams( + "test", + "txt2img", + "ddim", + "test", + 1.0, + 10, + 1, + control=NetworkModel("canny", "control"), + ), + ) + self.assertIn("controlnet", components) - def test_load_missing(self): - """ - Should throw - """ - components = {} - try: - components = load_controlnet( - ServerContext(), - DeviceParams("cpu", "CPUExecutionProvider"), - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")), - ) - except: - self.assertNotIn("controlnet", components) - return + def test_load_missing(self): + """ + Should throw + """ + components = {} + try: + components = load_controlnet( + ServerContext(), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams( + "test", + "txt2img", + "ddim", + "test", + 1.0, + 10, + 1, + control=NetworkModel("missing", "control"), + ), + ) + except Exception: + self.assertNotIn("controlnet", components) + return - self.fail() + self.fail() class TestLoadTextEncoders(unittest.TestCase): - @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") - def test_load_embeddings(self): - """ - Should add the token to tokenizer - Should increase the encoder dims - """ - components = load_text_encoders( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/stable-diffusion-onnx-v1-5", - [ - # TODO: add some embeddings - ], - [], - torch.float32, - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), + "model does not exist", ) - self.assertIn("text_encoder", components) + def test_load_embeddings(self): + """ + Should add the token to tokenizer + Should increase the encoder dims + """ + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some embeddings + ], + [], + torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) - def test_load_embeddings_xl(self): - pass + def test_load_embeddings_xl(self): + pass - @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") - def test_load_loras(self): - components = load_text_encoders( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/stable-diffusion-onnx-v1-5", - [], - [ - # TODO: add some loras - ], - torch.float32, - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), + "model does not exist", ) - self.assertIn("text_encoder", components) + def test_load_loras(self): + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [], + [ + # TODO: add some loras + ], + torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) + + def test_load_loras_xl(self): + pass - def test_load_loras_xl(self): - pass class TestLoadUnet(unittest.TestCase): - @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist") - def test_load_unet_loras(self): - components = load_unet( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/stable-diffusion-onnx-v1-5", - [ - # TODO: add some loras - ], - "unet", - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), + "model does not exist", ) - self.assertIn("unet", components) + def test_load_unet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "unet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) - def test_load_unet_loras_xl(self): - pass + def test_load_unet_loras_xl(self): + pass - @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist") - def test_load_cnet_loras(self): - components = load_unet( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/stable-diffusion-onnx-v1-5", - [ - # TODO: add some loras - ], - "cnet", - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), + "model does not exist", ) - self.assertIn("unet", components) + def test_load_cnet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "cnet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) class TestLoadVae(unittest.TestCase): - @unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist") - def test_load_single(self): - """ - Should return single component - """ - components = load_vae( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/upscaling-stable-diffusion-x4", - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), + "model does not exist", ) - self.assertIn("vae", components) - self.assertNotIn("vae_decoder", components) - self.assertNotIn("vae_encoder", components) + def test_load_single(self): + """ + Should return single component + """ + components = load_vae( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/upscaling-stable-diffusion-x4", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("vae", components) + self.assertNotIn("vae_decoder", components) + self.assertNotIn("vae_encoder", components) - @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist") - def test_load_split(self): - """ - Should return split encoder/decoder - """ - components = load_vae( - ServerContext(model_path="../models"), - DeviceParams("cpu", "CPUExecutionProvider"), - "../models/stable-diffusion-onnx-v1-5", - ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + @unittest.skipUnless( + path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), + "model does not exist", ) - self.assertNotIn("vae", components) - self.assertIn("vae_decoder", components) - self.assertIn("vae_encoder", components) + def test_load_split(self): + """ + Should return split encoder/decoder + """ + components = load_vae( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertNotIn("vae", components) + self.assertIn("vae_decoder", components) + self.assertIn("vae_encoder", components) diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 5152a834..bb374838 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -17,155 +17,234 @@ from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_mod class TestTxt2ImgPipeline(unittest.TestCase): - @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) - def test_basic(self): - cancel = Value("L", 0) - logs = Queue() - pending = Queue() - progress = Queue() - active = Value("L", 0) - idle = Value("L", 0) + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) - worker = WorkerContext( - "test", - test_device(), - cancel, - logs, - pending, - progress, - active, - idle, - 3, - 0.1, - ) - worker.start("test") + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") - run_txt2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), - ImageParams( - TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), - Size(256, 256), - ["test-txt2img.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), - ) + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ), + Size(256, 256), + ["test-txt2img-basic.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) + output = Image.open("../outputs/test-txt2img-basic.png") + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image + + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_highres(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ), + Size(256, 256), + ["test-txt2img-highres.png"], + UpscaleParams("test"), + HighresParams(True, 2, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) + output = Image.open("../outputs/test-txt2img-highres.png") + self.assertEqual(output.size, (512, 512)) - self.assertTrue(path.exists("../outputs/test-txt2img.png")) class TestImg2ImgPipeline(unittest.TestCase): - @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) - def test_basic(self): - cancel = Value("L", 0) - logs = Queue() - pending = Queue() - progress = Queue() - active = Value("L", 0) - idle = Value("L", 0) + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) - worker = WorkerContext( - "test", - test_device(), - cancel, - logs, - pending, - progress, - active, - idle, - 3, - 0.1, - ) - worker.start("test") + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") - source = Image.new("RGB", (64, 64), "black") - run_img2img_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), - ImageParams( - TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), - ["test-img2img.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), - source, - 1.0, - ) + source = Image.new("RGB", (64, 64), "black") + run_img2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ), + ["test-img2img.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + 1.0, + ) + + self.assertTrue(path.exists("../outputs/test-img2img.png")) - self.assertTrue(path.exists("../outputs/test-img2img.png")) class TestUpscalePipeline(unittest.TestCase): - @test_needs_models(["../models/upscaling-stable-diffusion-x4"]) - def test_basic(self): - cancel = Value("L", 0) - logs = Queue() - pending = Queue() - progress = Queue() - active = Value("L", 0) - idle = Value("L", 0) + @test_needs_models(["../models/upscaling-stable-diffusion-x4"]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) - worker = WorkerContext( - "test", - test_device(), - cancel, - logs, - pending, - progress, - active, - idle, - 3, - 0.1, - ) - worker.start("test") + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") - source = Image.new("RGB", (64, 64), "black") - run_upscale_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), - ImageParams( - "../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), - Size(256, 256), - ["test-upscale.png"], - UpscaleParams("test"), - HighresParams(False, 1, 0, 0), - source, - ) + source = Image.new("RGB", (64, 64), "black") + run_upscale_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + "../models/upscaling-stable-diffusion-x4", + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ), + Size(256, 256), + ["test-upscale.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + ) + + self.assertTrue(path.exists("../outputs/test-upscale.png")) - self.assertTrue(path.exists("../outputs/test-upscale.png")) class TestBlendPipeline(unittest.TestCase): - def test_basic(self): - cancel = Value("L", 0) - logs = Queue() - pending = Queue() - progress = Queue() - active = Value("L", 0) - idle = Value("L", 0) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) - worker = WorkerContext( - "test", - test_device(), - cancel, - logs, - pending, - progress, - active, - idle, - 3, - 0.1, - ) - worker.start("test") + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") - source = Image.new("RGBA", (64, 64), "black") - mask = Image.new("RGBA", (64, 64), "white") - run_blend_pipeline( - worker, - ServerContext(model_path="../models", output_path="../outputs"), - ImageParams( - TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), - Size(64, 64), - ["test-blend.png"], - UpscaleParams("test"), - [source, source], - mask, - ) + source = Image.new("RGBA", (64, 64), "black") + mask = Image.new("RGBA", (64, 64), "white") + run_blend_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + ), + Size(64, 64), + ["test-blend.png"], + UpscaleParams("test"), + [source, source], + mask, + ) - self.assertTrue(path.exists("../outputs/test-blend.png")) + self.assertTrue(path.exists("../outputs/test-blend.png")) diff --git a/api/tests/test_diffusers/test_utils.py b/api/tests/test_diffusers/test_utils.py index a98647cb..0e576d8b 100644 --- a/api/tests/test_diffusers/test_utils.py +++ b/api/tests/test_diffusers/test_utils.py @@ -10,7 +10,6 @@ from onnx_web.diffusers.utils import ( get_loras_from_prompt, get_scaled_latents, get_tile_latents, - get_tokens_from_prompt, pop_random, slice_prompt, ) @@ -18,110 +17,128 @@ from onnx_web.params import Size class TestExpandIntervalRanges(unittest.TestCase): - def test_prompt_with_no_ranges(self): - prompt = "an astronaut eating a hamburger" - result = expand_interval_ranges(prompt) - self.assertEqual(prompt, result) + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual(prompt, result) + + def test_prompt_with_range(self): + prompt = "an astronaut-{1,4} eating a hamburger" + result = expand_interval_ranges(prompt) + self.assertEqual( + result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger" + ) - def test_prompt_with_range(self): - prompt = "an astronaut-{1,4} eating a hamburger" - result = expand_interval_ranges(prompt) - self.assertEqual(result, "an astronaut-1 astronaut-2 astronaut-3 eating a hamburger") class TestExpandAlternativeRanges(unittest.TestCase): - def test_prompt_with_no_ranges(self): - prompt = "an astronaut eating a hamburger" - result = expand_alternative_ranges(prompt) - self.assertEqual([prompt], result) + def test_prompt_with_no_ranges(self): + prompt = "an astronaut eating a hamburger" + result = expand_alternative_ranges(prompt) + self.assertEqual([prompt], result) + + def test_ranges_match(self): + prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)" + result = expand_alternative_ranges(prompt) + self.assertEqual( + result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"] + ) - def test_ranges_match(self): - prompt = "(an astronaut|a squirrel) eating (a hamburger|an acorn)" - result = expand_alternative_ranges(prompt) - self.assertEqual(result, ["an astronaut eating a hamburger", "a squirrel eating an acorn"]) class TestInversionsFromPrompt(unittest.TestCase): - def test_get_inversions(self): - prompt = " an astronaut eating an embedding" - result, tokens = get_inversions_from_prompt(prompt) + def test_get_inversions(self): + prompt = " an astronaut eating an embedding" + result, tokens = get_inversions_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating an embedding") + self.assertEqual(tokens, [("test", 1.0)]) - self.assertEqual(result, " an astronaut eating an embedding") - self.assertEqual(tokens, [("test", 1.0)]) class TestLoRAsFromPrompt(unittest.TestCase): - def test_get_loras(self): - prompt = " an astronaut eating a LoRA" - result, tokens = get_loras_from_prompt(prompt) + def test_get_loras(self): + prompt = " an astronaut eating a LoRA" + result, tokens = get_loras_from_prompt(prompt) + + self.assertEqual(result, " an astronaut eating a LoRA") + self.assertEqual(tokens, [("test", 1.0)]) - self.assertEqual(result, " an astronaut eating a LoRA") - self.assertEqual(tokens, [("test", 1.0)]) class TestLatentsFromSeed(unittest.TestCase): - def test_batch_size(self): - latents = get_latents_from_seed(1, Size(64, 64), batch=4) - self.assertEqual(latents.shape, (4, 4, 8, 8)) + def test_batch_size(self): + latents = get_latents_from_seed(1, Size(64, 64), batch=4) + self.assertEqual(latents.shape, (4, 4, 8, 8)) + + def test_consistency(self): + latents1 = get_latents_from_seed(1, Size(64, 64)) + latents2 = get_latents_from_seed(1, Size(64, 64)) + self.assertTrue(np.array_equal(latents1, latents2)) - def test_consistency(self): - latents1 = get_latents_from_seed(1, Size(64, 64)) - latents2 = get_latents_from_seed(1, Size(64, 64)) - self.assertTrue(np.array_equal(latents1, latents2)) class TestTileLatents(unittest.TestCase): - def test_full_tile(self): - partial = np.zeros((1, 1, 64, 64)) - full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) - self.assertEqual(full.shape, (1, 1, 8, 8)) + def test_full_tile(self): + partial = np.zeros((1, 1, 64, 64)) + full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) + self.assertEqual(full.shape, (1, 1, 8, 8)) - def test_contract_tile(self): - partial = np.zeros((1, 1, 64, 64)) - full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32)) - self.assertEqual(full.shape, (1, 1, 4, 4)) + def test_contract_tile(self): + partial = np.zeros((1, 1, 64, 64)) + full = get_tile_latents(partial, 1, (32, 32), (0, 0, 32)) + self.assertEqual(full.shape, (1, 1, 4, 4)) + + def test_expand_tile(self): + partial = np.zeros((1, 1, 32, 32)) + full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) + self.assertEqual(full.shape, (1, 1, 8, 8)) - def test_expand_tile(self): - partial = np.zeros((1, 1, 32, 32)) - full = get_tile_latents(partial, 1, (64, 64), (0, 0, 64)) - self.assertEqual(full.shape, (1, 1, 8, 8)) class TestScaledLatents(unittest.TestCase): - def test_scale_up(self): - latents = get_latents_from_seed(1, Size(16, 16)) - scaled = get_scaled_latents(1, Size(16, 16), scale=2) - self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0]) + def test_scale_up(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=2) + self.assertEqual(latents[0, 0, 0, 0], scaled[0, 0, 0, 0]) + + def test_scale_down(self): + latents = get_latents_from_seed(1, Size(16, 16)) + scaled = get_scaled_latents(1, Size(16, 16), scale=0.5) + self.assertEqual( + ( + latents[0, 0, 0, 0] + + latents[0, 0, 0, 1] + + latents[0, 0, 1, 0] + + latents[0, 0, 1, 1] + ) + / 4, + scaled[0, 0, 0, 0], + ) - def test_scale_down(self): - latents = get_latents_from_seed(1, Size(16, 16)) - scaled = get_scaled_latents(1, Size(16, 16), scale=0.5) - self.assertEqual(( - latents[0, 0, 0, 0] + - latents[0, 0, 0, 1] + - latents[0, 0, 1, 0] + - latents[0, 0, 1, 1] - ) / 4, scaled[0, 0, 0, 0]) class TestReplaceWildcards(unittest.TestCase): - pass + pass + class TestPopRandom(unittest.TestCase): - def test_pop(self): - items = ["1", "2", "3"] - pop_random(items) - self.assertEqual(len(items), 2) + def test_pop(self): + items = ["1", "2", "3"] + pop_random(items) + self.assertEqual(len(items), 2) + class TestRepairNaN(unittest.TestCase): - def test_unchanged(self): - pass + def test_unchanged(self): + pass + + def test_missing(self): + pass - def test_missing(self): - pass class TestSlicePrompt(unittest.TestCase): - def test_slice_no_delimiter(self): - slice = slice_prompt("foo", 1) - self.assertEqual(slice, "foo") + def test_slice_no_delimiter(self): + slice = slice_prompt("foo", 1) + self.assertEqual(slice, "foo") - def test_slice_within_range(self): - slice = slice_prompt("foo || bar", 1) - self.assertEqual(slice, " bar") + def test_slice_within_range(self): + slice = slice_prompt("foo || bar", 1) + self.assertEqual(slice, " bar") - def test_slice_outside_range(self): - slice = slice_prompt("foo || bar", 9) - self.assertEqual(slice, " bar") + def test_slice_outside_range(self): + slice = slice_prompt("foo || bar", 9) + self.assertEqual(slice, " bar") diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index d0a36982..3f6f13cd 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -13,122 +13,128 @@ lock = Event() def test_job(*args, **kwargs): - lock.wait() + lock.wait() def wait_job(*args, **kwargs): - sleep(0.5) + sleep(0.5) class TestWorkerPool(unittest.TestCase): - # lock: Optional[Event] - pool: Optional[DevicePoolExecutor] + # lock: Optional[Event] + pool: Optional[DevicePoolExecutor] - def setUp(self) -> None: - self.pool = None + def setUp(self) -> None: + self.pool = None - def tearDown(self) -> None: - if self.pool is not None: - self.pool.join() + def tearDown(self) -> None: + if self.pool is not None: + self.pool.join() - def test_no_devices(self): - server = ServerContext() - self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + def test_no_devices(self): + server = ServerContext() + self.pool = DevicePoolExecutor(server, [], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() - def test_fake_worker(self): - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() - self.assertEqual(len(self.pool.workers), 1) + def test_fake_worker(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + self.assertEqual(len(self.pool.workers), 1) - def test_cancel_pending(self): - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() + def test_cancel_pending(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() - self.pool.submit("test", wait_job, lock=lock) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", wait_job, lock=lock) + self.assertEqual(self.pool.done("test"), (True, None)) - self.assertTrue(self.pool.cancel("test")) - self.assertEqual(self.pool.done("test"), (False, None)) + self.assertTrue(self.pool.cancel("test")) + self.assertEqual(self.pool.done("test"), (False, None)) - def test_cancel_running(self): - pass + def test_cancel_running(self): + pass - def test_next_device(self): - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + def test_next_device(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() - self.assertEqual(self.pool.get_next_device(), 0) + self.assertEqual(self.pool.get_next_device(), 0) - def test_needs_device(self): - device1 = DeviceParams("cpu1", "CPUProvider") - device2 = DeviceParams("cpu2", "CPUProvider") - server = ServerContext() - self.pool = DevicePoolExecutor(server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + def test_needs_device(self): + device1 = DeviceParams("cpu1", "CPUProvider") + device2 = DeviceParams("cpu2", "CPUProvider") + server = ServerContext() + self.pool = DevicePoolExecutor( + server, [device1, device2], join_timeout=TEST_JOIN_TIMEOUT + ) + self.pool.start() - self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) + self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) - def test_done_running(self): - """ - TODO: flaky - """ - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() + def test_done_running(self): + """ + TODO: flaky + """ + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) - self.pool.start(lock) - sleep(2.0) + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start(lock) + sleep(2.0) - self.pool.submit("test", test_job) - sleep(2.0) + self.pool.submit("test", test_job) + sleep(2.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + pending, _progress = self.pool.done("test") + self.assertFalse(pending) - def test_done_pending(self): - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() + def test_done_pending(self): + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start(lock) + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start(lock) - self.pool.submit("test1", test_job) - self.pool.submit("test2", test_job) - self.assertTrue(self.pool.done("test2"), (True, None)) + self.pool.submit("test1", test_job) + self.pool.submit("test2", test_job) + self.assertTrue(self.pool.done("test2"), (True, None)) - lock.set() + lock.set() - def test_done_finished(self): - """ - TODO: flaky - """ - device = DeviceParams("cpu", "CPUProvider") - server = ServerContext() + def test_done_finished(self): + """ + TODO: flaky + """ + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) - self.pool.start() - sleep(2.0) + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + sleep(2.0) - self.pool.submit("test", wait_job) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", wait_job) + self.assertEqual(self.pool.done("test"), (True, None)) - sleep(2.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + sleep(2.0) + pending, _progress = self.pool.done("test") + self.assertFalse(pending) - def test_recycle_live(self): - pass + def test_recycle_live(self): + pass - def test_recycle_dead(self): - pass + def test_recycle_dead(self): + pass - def test_running_status(self): - pass \ No newline at end of file + def test_running_status(self): + pass diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py index 9f02d4e9..993c6d67 100644 --- a/api/tests/worker/test_worker.py +++ b/api/tests/worker/test_worker.py @@ -18,119 +18,194 @@ from tests.helpers import test_device def main_memory(_worker): - raise Exception(MEMORY_ERRORS[0]) + raise Exception(MEMORY_ERRORS[0]) + def main_retry(_worker): - raise RetryException() + raise RetryException() + def main_interrupt(_worker): - raise KeyboardInterrupt() + raise KeyboardInterrupt() class WorkerMainTests(unittest.TestCase): - def test_pending_exception_empty(self): - pass + def test_pending_exception_empty(self): + pass - def test_pending_exception_interrupt(self): - status = None + def test_pending_exception_interrupt(self): + status = None - def exit(exit_status): - nonlocal status - status = exit_status + def exit(exit_status): + nonlocal status + status = exit_status - job = JobCommand("test", "test", main_interrupt, [], {}) - cancel = Value("L", False) - logs = Queue() - pending = Queue() - progress = Queue() - pid = Value("L", getpid()) - idle = Value("L", False) + job = JobCommand("test", "test", main_interrupt, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) - pending.put(job) - worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) - self.assertEqual(status, EXIT_INTERRUPT) - pass + self.assertEqual(status, EXIT_INTERRUPT) + pass - def test_pending_exception_retry(self): - status = None + def test_pending_exception_retry(self): + status = None - def exit(exit_status): - nonlocal status - status = exit_status + def exit(exit_status): + nonlocal status + status = exit_status - job = JobCommand("test", "test", main_retry, [], {}) - cancel = Value("L", False) - logs = Queue() - pending = Queue() - progress = Queue() - pid = Value("L", getpid()) - idle = Value("L", False) + job = JobCommand("test", "test", main_retry, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) - pending.put(job) - worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) - self.assertEqual(status, EXIT_ERROR) - pass + self.assertEqual(status, EXIT_ERROR) + pass - def test_pending_exception_value(self): - status = None + def test_pending_exception_value(self): + status = None - def exit(exit_status): - nonlocal status - status = exit_status + def exit(exit_status): + nonlocal status + status = exit_status - cancel = Value("L", False) - logs = Queue() - pending = Queue() - progress = Queue() - pid = Value("L", getpid()) - idle = Value("L", False) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) - pending.close() - worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + pending.close() + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) - self.assertEqual(status, EXIT_ERROR) + self.assertEqual(status, EXIT_ERROR) - def test_pending_exception_other_memory(self): - status = None + def test_pending_exception_other_memory(self): + status = None - def exit(exit_status): - nonlocal status - status = exit_status + def exit(exit_status): + nonlocal status + status = exit_status - job = JobCommand("test", "test", main_memory, [], {}) - cancel = Value("L", False) - logs = Queue() - pending = Queue() - progress = Queue() - pid = Value("L", getpid()) - idle = Value("L", False) + job = JobCommand("test", "test", main_memory, [], {}) + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", getpid()) + idle = Value("L", False) - pending.put(job) - worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) + pending.put(job) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) - self.assertEqual(status, EXIT_MEMORY) + self.assertEqual(status, EXIT_MEMORY) + def test_pending_exception_other_unknown(self): + pass - def test_pending_exception_other_unknown(self): - pass + def test_pending_replaced(self): + status = None - def test_pending_replaced(self): - status = None + def exit(exit_status): + nonlocal status + status = exit_status - def exit(exit_status): - nonlocal status - status = exit_status + cancel = Value("L", False) + logs = Queue() + pending = Queue() + progress = Queue() + pid = Value("L", 0) + idle = Value("L", False) - cancel = Value("L", False) - logs = Queue() - pending = Queue() - progress = Queue() - pid = Value("L", 0) - idle = Value("L", False) - - worker_main(WorkerContext("test", test_device(), cancel, logs, pending, progress, pid, idle, 0, 0.0), ServerContext(), exit=exit) - - self.assertEqual(status, EXIT_REPLACED) + worker_main( + WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + pid, + idle, + 0, + 0.0, + ), + ServerContext(), + exit=exit, + ) + self.assertEqual(status, EXIT_REPLACED)