From 6b0b2e41a68fb3a48800c5bdb8ad1ae1a6b1ae7c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 5 Jan 2024 20:13:57 -0600 Subject: [PATCH] update tests --- api/onnx_web/convert/utils.py | 6 +- api/onnx_web/params.py | 6 + api/onnx_web/worker/command.py | 7 + .../chain/test_blend_denoise_localstd.py | 50 +++ api/tests/chain/test_blend_grid.py | 13 +- api/tests/chain/test_blend_img2img.py | 15 +- api/tests/chain/test_blend_linear.py | 14 +- api/tests/chain/test_blend_mask.py | 1 + api/tests/chain/test_correct_codeformer.py | 1 + api/tests/chain/test_correct_gfpgan.py | 12 +- api/tests/chain/test_reduce_crop.py | 1 + api/tests/chain/test_reduce_thumbnail.py | 1 + api/tests/chain/test_source_noise.py | 1 + api/tests/chain/test_source_s3.py | 1 + api/tests/chain/test_source_url.py | 1 + api/tests/chain/test_upscale_bsrgan.py | 1 + api/tests/chain/test_upscale_highres.py | 1 + api/tests/chain/test_upscale_outpaint.py | 1 + api/tests/chain/test_upscale_resrgan.py | 1 + api/tests/chain/test_upscale_swinir.py | 1 + api/tests/convert/test_utils.py | 290 +++++++++++++++++- api/tests/test_diffusers/test_run.py | 21 +- api/tests/worker/test_pool.py | 70 ++++- 23 files changed, 467 insertions(+), 49 deletions(-) create mode 100644 api/tests/chain/test_blend_denoise_localstd.py diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 6b4dca35..69c7b37b 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -107,7 +107,7 @@ def download_progress(source: str, dest: str): stream=True, allow_redirects=True, headers={ - "User-Agent": "onnx-web-api", + "User-Agent": "onnx-web-api", # TODO: add version }, ) if req.status_code != 200: @@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]: logger.debug("loading tensor with Torch: %s", name) checkpoint = torch.load(name, map_location=map_location) except Exception: - logger.exception( - "error loading with Torch JIT, trying with Torch JIT: %s", name - ) + logger.exception("error loading with Torch, trying with Torch JIT: %s", name) checkpoint = torch.jit.load(name) return checkpoint diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 8e8c593a..b78f2421 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -81,6 +81,12 @@ class Size: def __str__(self) -> str: return "%sx%s" % (self.width, self.height) + def __eq__(self, other: Any) -> bool: + if isinstance(other, Size): + return self.width == other.width and self.height == other.height + + return False + def add_border(self, border: Border): return Size( border.left + self.width + border.right, diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 1ba72a2b..8e303652 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -22,6 +22,13 @@ class JobType(str, Enum): class Progress: + """ + Generic counter with current and expected/final/total value. Can be used to count up or down. + + Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty" + when the current value is zero. + """ + current: int total: int diff --git a/api/tests/chain/test_blend_denoise_localstd.py b/api/tests/chain/test_blend_denoise_localstd.py new file mode 100644 index 00000000..cb908bed --- /dev/null +++ b/api/tests/chain/test_blend_denoise_localstd.py @@ -0,0 +1,50 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size + + +class TestBlendDenoiseLocalStdStage(unittest.TestCase): + def test_run(self): + # Create a dummy image + image = Image.new("RGB", (64, 64), color="white") + + # Create a dummy StageResult object + sources = StageResult.from_images( + [image], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0), + Size(64, 64), + ) + ], + ) + + # Create an instance of BlendDenoiseLocalStdStage + stage = BlendDenoiseLocalStdStage() + + # Call the run method with dummy parameters + result = stage.run( + _worker=None, + _server=None, + _stage=None, + _params=None, + sources=sources, + strength=5, + range=4, + stage_source=None, + callback=None, + ) + + # Assert that the result is an instance of StageResult + self.assertIsInstance(result, StageResult) + + # Assert that the result contains the denoised image + self.assertEqual(len(result), 1) + self.assertEqual(result.size(), Size(64, 64)) + + # Assert that the metadata is preserved + self.assertEqual(result.metadata, sources.metadata) diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py index 0e6188b1..14620829 100644 --- a/api/tests/chain/test_blend_grid.py +++ b/api/tests/chain/test_blend_grid.py @@ -3,7 +3,8 @@ import unittest from PIL import Image from onnx_web.chain.blend_grid import BlendGridStage -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size class BlendGridStageTests(unittest.TestCase): @@ -15,9 +16,17 @@ class BlendGridStageTests(unittest.TestCase): Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "white"), + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), ] + * 4, ) result = stage.run(None, None, None, None, sources, height=2, width=2) + result.validate() self.assertEqual(len(result), 5) - self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0)) + self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py index 9d6f71d9..d7ab2211 100644 --- a/api/tests/chain/test_blend_img2img.py +++ b/api/tests/chain/test_blend_img2img.py @@ -3,8 +3,8 @@ import unittest from PIL import Image from onnx_web.chain.blend_img2img import BlendImg2ImgStage -from onnx_web.chain.result import StageResult -from onnx_web.params import ImageParams +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size from onnx_web.server.context import ServerContext from onnx_web.worker.context import WorkerContext from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models @@ -39,9 +39,16 @@ class BlendImg2ImgStageTests(unittest.TestCase): sources = StageResult( images=[ Image.new("RGB", (64, 64), "black"), - ] + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), + ], ) result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) + result.validate() self.assertEqual(len(result), 1) - self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0)) + self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py index 76a2715a..aab545fa 100644 --- a/api/tests/chain/test_blend_linear.py +++ b/api/tests/chain/test_blend_linear.py @@ -3,7 +3,8 @@ import unittest from PIL import Image from onnx_web.chain.blend_linear import BlendLinearStage -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import ImageMetadata, StageResult +from onnx_web.params import ImageParams, Size class BlendLinearStageTests(unittest.TestCase): @@ -12,12 +13,19 @@ class BlendLinearStageTests(unittest.TestCase): sources = StageResult( images=[ Image.new("RGB", (64, 64), "black"), - ] + ], + metadata=[ + ImageMetadata( + ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1), + Size(64, 64), + ), + ], ) stage_source = Image.new("RGB", (64, 64), "white") result = stage.run( None, None, None, None, sources, alpha=0.5, stage_source=stage_source ) + result.validate() self.assertEqual(len(result), 1) - self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127)) + self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127)) diff --git a/api/tests/chain/test_blend_mask.py b/api/tests/chain/test_blend_mask.py index f168fab9..6611e3a6 100644 --- a/api/tests/chain/test_blend_mask.py +++ b/api/tests/chain/test_blend_mask.py @@ -23,5 +23,6 @@ class BlendMaskStageTests(unittest.TestCase): stage_source=Image.new("RGBA", (64, 64)), dims=(0, 0, SizeChart.auto), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py index 1498beeb..e1ac41de 100644 --- a/api/tests/chain/test_correct_codeformer.py +++ b/api/tests/chain/test_correct_codeformer.py @@ -42,5 +42,6 @@ class CorrectCodeformerStageTests(unittest.TestCase): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(""), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_correct_gfpgan.py b/api/tests/chain/test_correct_gfpgan.py index 9f8b6cb3..a90449fb 100644 --- a/api/tests/chain/test_correct_gfpgan.py +++ b/api/tests/chain/test_correct_gfpgan.py @@ -6,13 +6,14 @@ from onnx_web.params import HighresParams, UpscaleParams from onnx_web.server.context import ServerContext from onnx_web.server.hacks import apply_patches from onnx_web.worker.context import WorkerContext -from tests.helpers import test_device, test_needs_onnx_models +from tests.helpers import test_device, test_needs_models -TEST_MODEL = "../models/correction-gfpgan-v1-3" +TEST_MODEL_NAME = "correction-gfpgan-v1-3" +TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth" class CorrectGFPGANStageTests(unittest.TestCase): - @test_needs_onnx_models([TEST_MODEL]) + @test_needs_models([TEST_MODEL]) def test_empty(self): server = ServerContext(model_path="../models", output_path="../outputs") apply_patches(server) @@ -33,12 +34,13 @@ class CorrectGFPGANStageTests(unittest.TestCase): sources = StageResult.empty() result = stage.run( worker, - None, + server, None, None, sources, highres=HighresParams(False, 1, 0, 0), - upscale=UpscaleParams(TEST_MODEL), + upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_crop.py b/api/tests/chain/test_reduce_crop.py index bfc7adc4..81629df2 100644 --- a/api/tests/chain/test_reduce_crop.py +++ b/api/tests/chain/test_reduce_crop.py @@ -20,5 +20,6 @@ class ReduceCropStageTests(unittest.TestCase): origin=Size(0, 0), size=Size(128, 128), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_thumbnail.py b/api/tests/chain/test_reduce_thumbnail.py index 8b129672..a162bf5e 100644 --- a/api/tests/chain/test_reduce_thumbnail.py +++ b/api/tests/chain/test_reduce_thumbnail.py @@ -24,5 +24,6 @@ class ReduceThumbnailStageTests(unittest.TestCase): size=Size(128, 128), stage_source=stage_source, ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py index 37c99bfa..40b0c437 100644 --- a/api/tests/chain/test_source_noise.py +++ b/api/tests/chain/test_source_noise.py @@ -22,5 +22,6 @@ class SourceNoiseStageTests(unittest.TestCase): size=Size(128, 128), noise_source=noise_source_fill_edge, ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py index 59bbb72f..8587859f 100644 --- a/api/tests/chain/test_source_s3.py +++ b/api/tests/chain/test_source_s3.py @@ -22,5 +22,6 @@ class SourceS3StageTests(unittest.TestCase): bucket="test", source_keys=[], ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py index 4d03dedb..59d3f990 100644 --- a/api/tests/chain/test_source_url.py +++ b/api/tests/chain/test_source_url.py @@ -21,5 +21,6 @@ class SourceURLStageTests(unittest.TestCase): size=Size(128, 128), source_urls=[], ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_bsrgan.py b/api/tests/chain/test_upscale_bsrgan.py index f93b800c..cca11f11 100644 --- a/api/tests/chain/test_upscale_bsrgan.py +++ b/api/tests/chain/test_upscale_bsrgan.py @@ -37,5 +37,6 @@ class UpscaleBSRGANStageTests(unittest.TestCase): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(TEST_MODEL), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py index 096eea54..788d3612 100644 --- a/api/tests/chain/test_upscale_highres.py +++ b/api/tests/chain/test_upscale_highres.py @@ -18,5 +18,6 @@ class UpscaleHighresStageTests(unittest.TestCase): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(""), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_outpaint.py b/api/tests/chain/test_upscale_outpaint.py index 261a8d45..19a67dd7 100644 --- a/api/tests/chain/test_upscale_outpaint.py +++ b/api/tests/chain/test_upscale_outpaint.py @@ -46,5 +46,6 @@ class UpscaleOutpaintStageTests(unittest.TestCase): dims=(), tile_mask=Image.new("RGB", (64, 64)), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_resrgan.py b/api/tests/chain/test_upscale_resrgan.py index f832767f..f464947e 100644 --- a/api/tests/chain/test_upscale_resrgan.py +++ b/api/tests/chain/test_upscale_resrgan.py @@ -35,5 +35,6 @@ class UpscaleRealESRGANStageTests(unittest.TestCase): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_swinir.py b/api/tests/chain/test_upscale_swinir.py index dfa9676e..ce23695d 100644 --- a/api/tests/chain/test_upscale_swinir.py +++ b/api/tests/chain/test_upscale_swinir.py @@ -35,5 +35,6 @@ class UpscaleSwinIRStageTests(unittest.TestCase): highres=HighresParams(False, 1, 0, 0), upscale=UpscaleParams(TEST_MODEL), ) + result.validate() self.assertEqual(len(result), 0) diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 34b0bf9b..da8bc550 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -1,9 +1,17 @@ import unittest +from os import path +from unittest import mock +from unittest.mock import MagicMock, patch from onnx_web.convert.utils import ( DEFAULT_OPSET, ConversionContext, + build_cache_paths, download_progress, + fix_diffusion_name, + get_first_exists, + load_tensor, + load_torch, remove_prefix, resolve_tensor, source_format, @@ -30,6 +38,32 @@ class DownloadProgressTests(unittest.TestCase): path = download_progress("https://example.com", "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com") + @patch("onnx_web.convert.utils.Path") + @patch("onnx_web.convert.utils.requests") + @patch("onnx_web.convert.utils.shutil") + @patch("onnx_web.convert.utils.tqdm") + def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path): + source = "http://example.com/image.jpg" + dest = "/path/to/destination/image.jpg" + + dest_path_mock = MagicMock() + mock_path.return_value.expanduser.return_value.resolve.return_value = ( + dest_path_mock + ) + dest_path_mock.exists.return_value = False + dest_path_mock.absolute.return_value = "test" + mock_requests.get.return_value.status_code = 200 + mock_requests.get.return_value.headers.get.return_value = "1000" + mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock() + + result = download_progress(source, dest) + + mock_path.assert_called_once_with(dest) + dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True) + dest_path_mock.open.assert_called_once_with("wb") + mock_shutil.copyfileobj.assert_called_once() + self.assertEqual(result, str(dest_path_mock.absolute.return_value)) + class TupleToSourceTests(unittest.TestCase): def test_basic_tuple(self): @@ -221,14 +255,6 @@ class RemovePrefixTests(unittest.TestCase): self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") -class LoadTorchTests(unittest.TestCase): - pass - - -class LoadTensorTests(unittest.TestCase): - pass - - class ResolveTensorTests(unittest.TestCase): @test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) def test_resolve_existing(self): @@ -239,3 +265,251 @@ class ResolveTensorTests(unittest.TestCase): def test_resolve_missing(self): self.assertIsNone(resolve_tensor("missing")) + + +class LoadTorchTests(unittest.TestCase): + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_torch_with_torch_load(self, mock_torch, mock_logger): + name = "model.pth" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.return_value = checkpoint + + result = load_torch(name, map_location) + + mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) + mock_torch.load.assert_called_once_with(name, map_location=map_location) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger): + name = "model.pth" + checkpoint = MagicMock() + mock_torch.load.side_effect = Exception() + mock_torch.jit.load.return_value = checkpoint + + result = load_torch(name) + + mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name) + mock_logger.exception.assert_called_once_with( + "error loading with Torch, trying with Torch JIT: %s", name + ) + mock_torch.jit.load.assert_called_once_with(name) + self.assertEqual(result, checkpoint) + + +class LoadTensorTests(unittest.TestCase): + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger): + name = "model" + map_location = "cpu" + checkpoint = MagicMock() + mock_path.exists.return_value = True + mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")] + mock_torch.load.return_value = checkpoint + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_path.splitext.assert_called_once_with(name) + mock_path.exists.assert_called_once_with(name) + mock_torch.load.assert_called_once_with(name, map_location=map_location) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.environ") + @patch("onnx_web.convert.utils.safetensors") + def test_load_tensor_with_safetensors_extension( + self, mock_safetensors, mock_environ, mock_logger + ): + name = "model.safetensors" + checkpoint = MagicMock() + mock_environ.__getitem__.return_value = "1" + mock_safetensors.torch.load_file.return_value = checkpoint + + result = load_tensor(name) + + mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu") + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger): + name = "model.pt" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_torch.load.assert_has_calls( + [ + mock.call(name, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger): + name = "model.onnx" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.warning.assert_called_once_with( + "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx" + ) + mock_torch.load.assert_has_calls( + [ + mock.call(name, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger): + name = "model.xyz" + map_location = "cpu" + checkpoint = MagicMock() + mock_torch.load.side_effect = [checkpoint] + + result = load_tensor(name, map_location) + + mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)]) + mock_logger.warning.assert_called_once_with( + "unknown tensor type, falling back to PyTorch: %s", "xyz" + ) + mock_torch.load.assert_has_calls( + [ + mock.call(name, map_location=map_location), + ] + ) + self.assertEqual(result, checkpoint) + + @patch("onnx_web.convert.utils.logger") + @patch("onnx_web.convert.utils.torch") + def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger): + name = "model" + map_location = "cpu" + mock_torch.load.side_effect = Exception() + + with self.assertRaises(ValueError): + load_tensor(name, map_location) + + +class FixDiffusionNameTests(unittest.TestCase): + def test_fix_diffusion_name_with_valid_name(self): + name = "diffusion-model" + result = fix_diffusion_name(name) + self.assertEqual(result, name) + + @patch("onnx_web.convert.utils.logger") + def test_fix_diffusion_name_with_invalid_name(self, logger): + name = "model" + expected_result = "diffusion-model" + result = fix_diffusion_name(name) + + self.assertEqual(result, expected_result) + logger.warning.assert_called_once_with( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + + +class BuildCachePathsTests(unittest.TestCase): + def test_build_cache_paths_without_format(self): + name = "model.onnx" + client = "client1" + cache = "/path/to/cache" + + conversion = ConversionContext(cache_path=cache) + result = build_cache_paths(conversion, name, client, cache) + + expected_paths = [ + path.join("/path/to/cache", "model.onnx"), + path.join("/path/to/cache/client1", "model.onnx"), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_format(self): + name = "model" + client = "client2" + cache = "/path/to/cache" + format = "onnx" + + conversion = ConversionContext(cache_path=cache) + result = build_cache_paths(conversion, name, client, cache, format) + + expected_paths = [ + path.join("/path/to/cache", "model.onnx"), + path.join("/path/to/cache/client2", "model.onnx"), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_existing_extension(self): + name = "model.pth" + client = "client3" + cache = "/path/to/cache" + format = "onnx" + + conversion = ConversionContext(cache_path=cache) + result = build_cache_paths(conversion, name, client, cache, format) + + expected_paths = [ + path.join("/path/to/cache", "model.pth"), + path.join("/path/to/cache/client3", "model.pth"), + ] + self.assertEqual(result, expected_paths) + + def test_build_cache_paths_with_empty_extension(self): + name = "model" + client = "client4" + cache = "/path/to/cache" + format = "onnx" + + conversion = ConversionContext(cache_path=cache) + result = build_cache_paths(conversion, name, client, cache, format) + + expected_paths = [ + path.join("/path/to/cache", "model.onnx"), + path.join("/path/to/cache/client4", "model.onnx"), + ] + self.assertEqual(result, expected_paths) + + +class GetFirstExistsTests(unittest.TestCase): + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.logger") + def test_get_first_exists_with_existing_path(self, mock_logger, mock_path): + paths = ["path1", "path2", "path3"] + mock_path.exists.side_effect = [False, True, False] + mock_path.return_value = MagicMock() + + result = get_first_exists(paths) + + mock_logger.debug.assert_called_once_with( + "model already exists in cache, skipping fetch: %s", "path2" + ) + self.assertEqual(result, "path2") + + @patch("onnx_web.convert.utils.path") + @patch("onnx_web.convert.utils.logger") + def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path): + paths = ["path1", "path2", "path3"] + mock_path.exists.return_value = False + + result = get_first_exists(paths) + + mock_logger.debug.assert_not_called() + self.assertIsNone(result) diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 1796d978..a2904a2c 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -78,9 +78,10 @@ class TestTxt2ImgPipeline(unittest.TestCase): ) self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) - output = Image.open("../outputs/test-txt2img-basic.png") - self.assertEqual(output.size, (256, 256)) - # TODO: test contents of image + + with Image.open("../outputs/test-txt2img-basic.png") as output: + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_batch(self): @@ -126,9 +127,9 @@ class TestTxt2ImgPipeline(unittest.TestCase): self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) - output = Image.open("../outputs/test-txt2img-batch-0.png") - self.assertEqual(output.size, (256, 256)) - # TODO: test contents of image + with Image.open("../outputs/test-txt2img-batch-0.png") as output: + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_highres(self): @@ -172,8 +173,8 @@ class TestTxt2ImgPipeline(unittest.TestCase): ) self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) - output = Image.open("../outputs/test-txt2img-highres.png") - self.assertEqual(output.size, (512, 512)) + with Image.open("../outputs/test-txt2img-highres.png") as output: + self.assertEqual(output.size, (512, 512)) @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_highres_batch(self): @@ -219,8 +220,8 @@ class TestTxt2ImgPipeline(unittest.TestCase): self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) - output = Image.open("../outputs/test-txt2img-highres-batch-0.png") - self.assertEqual(output.size, (512, 512)) + with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output: + self.assertEqual(output.size, (512, 512)) class TestImg2ImgPipeline(unittest.TestCase): diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 721fe87d..1c8a2b83 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -7,20 +7,29 @@ from onnx_web.params import DeviceParams from onnx_web.server.context import ServerContext from onnx_web.worker.command import JobStatus from onnx_web.worker.pool import DevicePoolExecutor +from tests.helpers import test_device TEST_JOIN_TIMEOUT = 0.2 lock = Event() -def test_job(*args, **kwargs): +def lock_job(*args, **kwargs): lock.wait() -def wait_job(*args, **kwargs): +def sleep_job(*args, **kwargs): sleep(0.5) +def progress_job(worker, *args, **kwargs): + worker.set_progress(1) + + +def fail_job(*args, **kwargs): + raise RuntimeError("job failed") + + class TestWorkerPool(unittest.TestCase): # lock: Optional[Event] pool: Optional[DevicePoolExecutor] @@ -38,20 +47,20 @@ class TestWorkerPool(unittest.TestCase): self.pool.start() def test_fake_worker(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() self.assertEqual(len(self.pool.workers), 1) def test_cancel_pending(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() - self.pool.submit("test", "test", wait_job, lock=lock) + self.pool.submit("test", "test", sleep_job, lock=lock) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertTrue(self.pool.cancel("test")) @@ -61,7 +70,7 @@ class TestWorkerPool(unittest.TestCase): pass def test_next_device(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() @@ -83,7 +92,7 @@ class TestWorkerPool(unittest.TestCase): """ TODO: flaky """ - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor( @@ -92,21 +101,21 @@ class TestWorkerPool(unittest.TestCase): lock.clear() self.pool.start(lock) - self.pool.submit("test", "test", test_job) + self.pool.submit("test", "test", lock_job) sleep(5.0) status, _progress = self.pool.status("test") self.assertEqual(status, JobStatus.RUNNING) def test_done_pending(self): - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start(lock) - self.pool.submit("test1", "test", test_job) - self.pool.submit("test2", "test", test_job) + self.pool.submit("test1", "test", lock_job) + self.pool.submit("test2", "test", lock_job) self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None)) lock.set() @@ -115,14 +124,14 @@ class TestWorkerPool(unittest.TestCase): """ TODO: flaky """ - device = DeviceParams("cpu", "CPUProvider") + device = test_device() server = ServerContext() self.pool = DevicePoolExecutor( server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) self.pool.start() - self.pool.submit("test", "test", wait_job) + self.pool.submit("test", "test", sleep_job) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) sleep(5.0) @@ -137,3 +146,38 @@ class TestWorkerPool(unittest.TestCase): def test_running_status(self): pass + + def test_progress_update(self): + pass + + def test_progress_finished(self): + device = test_device() + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + + self.pool.submit("test", "test", progress_job) + sleep(5.0) + + status, progress = self.pool.status("test") + self.assertEqual(status, JobStatus.SUCCESS) + self.assertEqual(progress.steps.current, 1) + + def test_progress_failed(self): + device = test_device() + server = ServerContext() + + self.pool = DevicePoolExecutor( + server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 + ) + self.pool.start() + + self.pool.submit("test", "test", fail_job) + sleep(5.0) + + status, progress = self.pool.status("test") + self.assertEqual(status, JobStatus.FAILED) + self.assertEqual(progress.steps.current, 0)