diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 9ed7424f..52963916 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]): def tuple_to_correction(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - scale = rest[0] if len(rest) > 0 else 1 - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + scale = rest.pop(0) if len(rest) > 0 else 1 + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]): def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - single_vae = rest[0] if len(rest) > 0 else False - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + single_vae = rest.pop(0) if len(rest) > 0 else False + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): name, source, *rest = model - scale = rest[0] if len(rest) > 0 else 1 - half = rest[0] if len(rest) > 0 else False - opset = rest[0] if len(rest) > 0 else None + scale = rest.pop(0) if len(rest) > 0 else 1 + half = rest.pop(0) if len(rest) > 0 else False + opset = rest.pop(0) if len(rest) > 0 else None return { "name": name, @@ -298,6 +298,7 @@ def onnx_export( half=False, external_data=False, v2=False, + op_block_list=None, ): """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -316,8 +317,7 @@ def onnx_export( opset_version=opset, ) - op_block_list = None - if v2: + if v2 and op_block_list is None: op_block_list = ["Attention", "MultiHeadAttention"] if half: diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index b1f14b1d..a9c72d2f 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -97,6 +97,7 @@ def run_txt2img_pipeline( _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): + logger.trace("saving output image %s: %s", output, image.size) dest = save_image( server, output, diff --git a/api/onnx_web/image/source_filter.py b/api/onnx_web/image/source_filter.py index ea6e0d12..99c6b5eb 100644 --- a/api/onnx_web/image/source_filter.py +++ b/api/onnx_web/image/source_filter.py @@ -47,7 +47,7 @@ def source_filter_noise( source: Image.Image, strength: float = 0.5, ): - noise = noise_source_histogram(source, source.size) + noise = noise_source_histogram(source, source.size, (0, 0)) return ImageChops.blend(source, noise, strength) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2b35ff4c..a24613ed 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -25,6 +25,7 @@ class WorkerContext: idle: "Value[bool]" timeout: float retries: int + initial_retries: int def __init__( self, @@ -37,6 +38,7 @@ class WorkerContext: active_pid: "Value[int]", idle: "Value[bool]", retries: int, + timeout: float, ): self.job = None self.name = name @@ -48,12 +50,13 @@ class WorkerContext: self.active_pid = active_pid self.last_progress = None self.idle = idle + self.initial_retries = retries self.retries = retries - self.timeout = 1.0 + self.timeout = timeout def start(self, job: str) -> None: self.job = job - self.retries = 3 + self.retries = self.initial_retries self.set_cancel(cancel=False) self.set_idle(idle=False) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 833a04f8..3b0d32a8 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -86,15 +86,15 @@ class DevicePoolExecutor: self.logs = Queue(self.max_pending_per_worker) self.rlock = Lock() - def start(self) -> None: + def start(self, *args) -> None: self.create_health_worker() self.create_logger_worker() self.create_progress_worker() for device in self.devices: - self.create_device_worker(device) + self.create_device_worker(device, *args) - def create_device_worker(self, device: DeviceParams) -> None: + def create_device_worker(self, device: DeviceParams, *args) -> None: name = device.device # always recreate queues @@ -125,15 +125,16 @@ class DevicePoolExecutor: active_pid=current, idle=self.worker_idle[name], retries=self.server.worker_retries, + timeout=self.progress_interval, ) self.context[name] = context worker = Process( name=f"onnx-web worker: {name}", target=worker_main, - args=(context, self.server), + args=(context, self.server, *args), + daemon=True, ) - worker.daemon = True self.workers[name] = worker logger.debug("starting worker for device %s", device) diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 361c1150..5377c42a 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -27,7 +27,7 @@ MEMORY_ERRORS = [ ] -def worker_main(worker: WorkerContext, server: ServerContext): +def worker_main(worker: WorkerContext, server: ServerContext, *args): apply_patches(server) setproctitle("onnx-web worker: %s" % (worker.device.device)) diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index a613a719..7f599db2 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -7,6 +7,8 @@ from onnx_web.chain.tile import ( generate_tile_spiral, get_tile_grads, needs_tile, + process_tile_grid, + process_tile_spiral, ) from onnx_web.params import Size @@ -95,3 +97,31 @@ class TestGenerateTileSpiral(unittest.TestCase): self.assertEqual(len(tiles), 225) self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) + + +class TestProcessTileGrid(unittest.TestCase): + def test_grid_full(self): + source = Image.new("RGB", (64, 64)) + blend = process_tile_grid(source, 32, 1, []) + + self.assertEqual(blend.size, (64, 64)) + + def test_grid_partial(self): + source = Image.new("RGB", (72, 72)) + blend = process_tile_grid(source, 32, 1, []) + + self.assertEqual(blend.size, (72, 72)) + + +class TestProcessTileSpiral(unittest.TestCase): + def test_grid_full(self): + source = Image.new("RGB", (64, 64)) + blend = process_tile_spiral(source, 32, 1, []) + + self.assertEqual(blend.size, (64, 64)) + + def test_grid_partial(self): + source = Image.new("RGB", (72, 72)) + blend = process_tile_spiral(source, 32, 1, []) + + self.assertEqual(blend.size, (72, 72)) diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 755d6032..45c8fccc 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -1,6 +1,14 @@ import unittest -from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress +from onnx_web.convert.utils import ( + DEFAULT_OPSET, + ConversionContext, + download_progress, + tuple_to_correction, + tuple_to_diffusion, + tuple_to_source, + tuple_to_upscaling, +) class ConversionContextTests(unittest.TestCase): @@ -17,3 +25,160 @@ class DownloadProgressTests(unittest.TestCase): def test_download_example(self): path = download_progress([("https://example.com", "/tmp/example-dot-com")]) self.assertEqual(path, "/tmp/example-dot-com") + + +class TupleToSourceTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_source(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_source(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_source(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned as-is with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + +class TupleToCorrectionTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_correction(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_correction(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_correction(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_scale_tuple(self): + source = tuple_to_correction(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_correction(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_correction(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_correction(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) + + +class TupleToDiffusionTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_diffusion(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_diffusion(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_diffusion(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = tuple_to_diffusion(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_single_vae_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_diffusion(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_diffusion(["foo", "bar", True, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["single_vae"], True) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) + + +class TupleToUpscalingTests(unittest.TestCase): + def test_basic_tuple(self): + source = tuple_to_upscaling(("foo", "bar")) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_list(self): + source = tuple_to_upscaling(["foo", "bar"]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_basic_dict(self): + source = tuple_to_upscaling(["foo", "bar"]) + source["bin"] = "bin" + + # make sure this is returned with extra fields + second = tuple_to_source(source) + + self.assertEqual(source, second) + self.assertIn("bin", second) + + def test_scale_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_half_tuple(self): + source = tuple_to_upscaling(["foo", "bar", True]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_opset_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + + def test_all_tuple(self): + source = tuple_to_upscaling(["foo", "bar", 2, True, 14]) + self.assertEqual(source["name"], "foo") + self.assertEqual(source["source"], "bar") + self.assertEqual(source["scale"], 2) + self.assertEqual(source["half"], True) + self.assertEqual(source["opset"], 14) diff --git a/api/tests/helpers.py b/api/tests/helpers.py index 3fc5cc7b..586ecbd8 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -1,9 +1,16 @@ +from os import path from typing import List +from unittest import skipUnless + +from onnx_web.params import DeviceParams -def test_with_models(models: List[str]): - def wrapper(func): - # TODO: check if models exist - return func +def test_needs_models(models: List[str]): + return skipUnless(all([path.exists(model) for model in models]), "model does not exist") - return wrapper + +def test_device() -> DeviceParams: + return DeviceParams("cpu", "CPUExecutionProvider") + + +TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" \ No newline at end of file diff --git a/api/tests/image/__init__.py b/api/tests/image/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/tests/image/test_mask_filter.py b/api/tests/image/test_mask_filter.py new file mode 100644 index 00000000..58b46c7c --- /dev/null +++ b/api/tests/image/test_mask_filter.py @@ -0,0 +1,33 @@ +import unittest + +from PIL import Image + +from onnx_web.image.mask_filter import ( + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, +) + + +class MaskFilterNoneTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_none(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) + + +class MaskFilterGaussianMultiplyTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_multiply(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) + + +class MaskFilterGaussianScreenTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + mask = Image.new("RGB", dims) + result = mask_filter_gaussian_screen(mask, dims, (0, 0)) + self.assertEqual(result.size, dims) diff --git a/api/tests/image/test_source_filter.py b/api/tests/image/test_source_filter.py new file mode 100644 index 00000000..89e73924 --- /dev/null +++ b/api/tests/image/test_source_filter.py @@ -0,0 +1,37 @@ +import unittest + +from PIL import Image + +from onnx_web.image.source_filter import ( + source_filter_gaussian, + source_filter_noise, + source_filter_none, +) +from onnx_web.server.context import ServerContext + + +class SourceFilterNoneTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_none(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterGaussianTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_gaussian(server, source) + self.assertEqual(result.size, dims) + + +class SourceFilterNoiseTests(unittest.TestCase): + def test_basic(self): + dims = (64, 64) + server = ServerContext() + source = Image.new("RGB", dims) + result = source_filter_noise(server, source) + self.assertEqual(result.size, dims) diff --git a/api/tests/server/test_load.py b/api/tests/server/test_load.py index 67a5f4e4..c32b9663 100644 --- a/api/tests/server/test_load.py +++ b/api/tests/server/test_load.py @@ -1,5 +1,6 @@ import unittest +from onnx_web.server.context import ServerContext from onnx_web.server.load import ( get_available_platforms, get_config_params, @@ -14,6 +15,8 @@ from onnx_web.server.load import ( get_source_filters, get_upscaling_models, get_wildcard_data, + load_extras, + load_models, ) @@ -81,3 +84,13 @@ class SourceFilterTests(unittest.TestCase): def test_before_setup(self): filters = get_source_filters() self.assertIsNotNone(filters) + +class LoadExtrasTests(unittest.TestCase): + def test_default_extras(self): + server = ServerContext(extra_models=["../models/extras.json"]) + load_extras(server) + +class LoadModelsTests(unittest.TestCase): + def test_default_models(self): + server = ServerContext(model_path="../models") + load_models(server) diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py index 8a7720c7..8f7a3963 100644 --- a/api/tests/test_diffusers/test_load.py +++ b/api/tests/test_diffusers/test_load.py @@ -17,10 +17,9 @@ from onnx_web.diffusers.load import ( ) from onnx_web.diffusers.patches.unet import UNetWrapper from onnx_web.diffusers.patches.vae import VAEWrapper -from onnx_web.models.meta import NetworkModel, NetworkType +from onnx_web.models.meta import NetworkModel from onnx_web.params import DeviceParams, ImageParams from onnx_web.server.context import ServerContext -from tests.helpers import test_with_models from tests.mocks import MockPipeline diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py new file mode 100644 index 00000000..5152a834 --- /dev/null +++ b/api/tests/test_diffusers/test_run.py @@ -0,0 +1,171 @@ +import unittest +from multiprocessing import Queue, Value +from os import path + +from PIL import Image + +from onnx_web.diffusers.run import ( + run_blend_pipeline, + run_img2img_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) +from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models + + +class TestTxt2ImgPipeline(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), + Size(256, 256), + ["test-txt2img.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img.png")) + +class TestImg2ImgPipeline(unittest.TestCase): + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + run_img2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), + ["test-img2img.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + 1.0, + ) + + self.assertTrue(path.exists("../outputs/test-img2img.png")) + +class TestUpscalePipeline(unittest.TestCase): + @test_needs_models(["../models/upscaling-stable-diffusion-x4"]) + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + source = Image.new("RGB", (64, 64), "black") + run_upscale_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + "../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), + Size(256, 256), + ["test-upscale.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + source, + ) + + self.assertTrue(path.exists("../outputs/test-upscale.png")) + +class TestBlendPipeline(unittest.TestCase): + def test_basic(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + source = Image.new("RGBA", (64, 64), "black") + mask = Image.new("RGBA", (64, 64), "white") + run_blend_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1), + Size(64, 64), + ["test-blend.png"], + UpscaleParams("test"), + [source, source], + mask, + ) + + self.assertTrue(path.exists("../outputs/test-blend.png")) diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 2547512f..7ea73451 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -9,16 +9,22 @@ from onnx_web.worker.pool import DevicePoolExecutor TEST_JOIN_TIMEOUT = 0.2 -def test_job(*args, lock: Event, **kwargs): +lock = Event() + + +def test_job(*args, **kwargs): lock.wait() +def wait_job(*args, **kwargs): + sleep(0.5) + + class TestWorkerPool(unittest.TestCase): - lock: Optional[Event] + # lock: Optional[Event] pool: Optional[DevicePoolExecutor] def setUp(self) -> None: - self.lock = Event() self.pool = None def tearDown(self) -> None: @@ -38,7 +44,17 @@ class TestWorkerPool(unittest.TestCase): self.assertEqual(len(self.pool.workers), 1) def test_cancel_pending(self): - pass + device = DeviceParams("cpu", "CPUProvider") + server = ServerContext() + + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool.start() + + self.pool.submit("test", wait_job, lock=lock) + self.assertEqual(self.pool.done("test"), (True, None)) + + self.assertTrue(self.pool.cancel("test")) + self.assertEqual(self.pool.done("test"), (False, None)) def test_cancel_running(self): pass @@ -61,48 +77,46 @@ class TestWorkerPool(unittest.TestCase): self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) def test_done_running(self): - """ device = DeviceParams("cpu", "CPUProvider") server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) + self.pool.start(lock) + sleep(2.0) - self.pool.submit("test", test_job, lock=self.lock) - sleep(5.0) - self.assertEqual(self.pool.done("test"), (False, None)) - """ - pass + self.pool.submit("test", test_job) + sleep(2.0) + + pending, _progress = self.pool.done("test") + self.assertFalse(pending) def test_done_pending(self): device = DeviceParams("cpu", "CPUProvider") server = ServerContext() self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) - self.pool.start() + self.pool.start(lock) - self.pool.submit("test1", test_job, lock=self.lock) - self.pool.submit("test2", test_job, lock=self.lock) + self.pool.submit("test1", test_job) + self.pool.submit("test2", test_job) self.assertTrue(self.pool.done("test2"), (True, None)) - self.lock.set() + lock.set() def test_done_finished(self): - """ device = DeviceParams("cpu", "CPUProvider") server = ServerContext() - self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) + self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1) self.pool.start() + sleep(2.0) - self.pool.submit("test", test_job, lock=self.lock) + self.pool.submit("test", wait_job) self.assertEqual(self.pool.done("test"), (True, None)) - self.lock.set() - sleep(5.0) - self.assertEqual(self.pool.done("test"), (False, None)) - """ - pass + sleep(2.0) + pending, _progress = self.pool.done("test") + self.assertFalse(pending) def test_recycle_live(self): pass