1
0
Fork 0

update tests

This commit is contained in:
Sean Sube 2024-01-05 20:13:57 -06:00
parent 88f99ef6c2
commit 6b0b2e41a6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
23 changed files with 467 additions and 49 deletions

View File

@ -107,7 +107,7 @@ def download_progress(source: str, dest: str):
stream=True,
allow_redirects=True,
headers={
"User-Agent": "onnx-web-api",
"User-Agent": "onnx-web-api", # TODO: add version
},
)
if req.status_code != 200:
@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
logger.debug("loading tensor with Torch: %s", name)
checkpoint = torch.load(name, map_location=map_location)
except Exception:
logger.exception(
"error loading with Torch JIT, trying with Torch JIT: %s", name
)
logger.exception("error loading with Torch, trying with Torch JIT: %s", name)
checkpoint = torch.jit.load(name)
return checkpoint

View File

@ -81,6 +81,12 @@ class Size:
def __str__(self) -> str:
return "%sx%s" % (self.width, self.height)
def __eq__(self, other: Any) -> bool:
if isinstance(other, Size):
return self.width == other.width and self.height == other.height
return False
def add_border(self, border: Border):
return Size(
border.left + self.width + border.right,

View File

@ -22,6 +22,13 @@ class JobType(str, Enum):
class Progress:
"""
Generic counter with current and expected/final/total value. Can be used to count up or down.
Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty"
when the current value is zero.
"""
current: int
total: int

View File

@ -0,0 +1,50 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class TestBlendDenoiseLocalStdStage(unittest.TestCase):
def test_run(self):
# Create a dummy image
image = Image.new("RGB", (64, 64), color="white")
# Create a dummy StageResult object
sources = StageResult.from_images(
[image],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0),
Size(64, 64),
)
],
)
# Create an instance of BlendDenoiseLocalStdStage
stage = BlendDenoiseLocalStdStage()
# Call the run method with dummy parameters
result = stage.run(
_worker=None,
_server=None,
_stage=None,
_params=None,
sources=sources,
strength=5,
range=4,
stage_source=None,
callback=None,
)
# Assert that the result is an instance of StageResult
self.assertIsInstance(result, StageResult)
# Assert that the result contains the denoised image
self.assertEqual(len(result), 1)
self.assertEqual(result.size(), Size(64, 64))
# Assert that the metadata is preserved
self.assertEqual(result.metadata, sources.metadata)

View File

@ -3,7 +3,8 @@ import unittest
from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class BlendGridStageTests(unittest.TestCase):
@ -15,9 +16,17 @@ class BlendGridStageTests(unittest.TestCase):
Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"),
],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
]
* 4,
)
result = stage.run(None, None, None, None, sources, height=2, width=2)
result.validate()
self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -3,8 +3,8 @@ import unittest
from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.chain.result import StageResult
from onnx_web.params import ImageParams
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
@ -39,9 +39,16 @@ class BlendImg2ImgStageTests(unittest.TestCase):
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
],
)
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
result.validate()
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0))

View File

@ -3,7 +3,8 @@ import unittest
from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class BlendLinearStageTests(unittest.TestCase):
@ -12,12 +13,19 @@ class BlendLinearStageTests(unittest.TestCase):
sources = StageResult(
images=[
Image.new("RGB", (64, 64), "black"),
]
],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
],
)
stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
)
result.validate()
self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -23,5 +23,6 @@ class BlendMaskStageTests(unittest.TestCase):
stage_source=Image.new("RGBA", (64, 64)),
dims=(0, 0, SizeChart.auto),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -42,5 +42,6 @@ class CorrectCodeformerStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -6,13 +6,14 @@ from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models
from tests.helpers import test_device, test_needs_models
TEST_MODEL = "../models/correction-gfpgan-v1-3"
TEST_MODEL_NAME = "correction-gfpgan-v1-3"
TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth"
class CorrectGFPGANStageTests(unittest.TestCase):
@test_needs_onnx_models([TEST_MODEL])
@test_needs_models([TEST_MODEL])
def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server)
@ -33,12 +34,13 @@ class CorrectGFPGANStageTests(unittest.TestCase):
sources = StageResult.empty()
result = stage.run(
worker,
None,
server,
None,
None,
sources,
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL),
upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -20,5 +20,6 @@ class ReduceCropStageTests(unittest.TestCase):
origin=Size(0, 0),
size=Size(128, 128),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -24,5 +24,6 @@ class ReduceThumbnailStageTests(unittest.TestCase):
size=Size(128, 128),
stage_source=stage_source,
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -22,5 +22,6 @@ class SourceNoiseStageTests(unittest.TestCase):
size=Size(128, 128),
noise_source=noise_source_fill_edge,
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -22,5 +22,6 @@ class SourceS3StageTests(unittest.TestCase):
bucket="test",
source_keys=[],
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -21,5 +21,6 @@ class SourceURLStageTests(unittest.TestCase):
size=Size(128, 128),
source_urls=[],
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -37,5 +37,6 @@ class UpscaleBSRGANStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -18,5 +18,6 @@ class UpscaleHighresStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -46,5 +46,6 @@ class UpscaleOutpaintStageTests(unittest.TestCase):
dims=(),
tile_mask=Image.new("RGB", (64, 64)),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -35,5 +35,6 @@ class UpscaleRealESRGANStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -35,5 +35,6 @@ class UpscaleSwinIRStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL),
)
result.validate()
self.assertEqual(len(result), 0)

View File

@ -1,9 +1,17 @@
import unittest
from os import path
from unittest import mock
from unittest.mock import MagicMock, patch
from onnx_web.convert.utils import (
DEFAULT_OPSET,
ConversionContext,
build_cache_paths,
download_progress,
fix_diffusion_name,
get_first_exists,
load_tensor,
load_torch,
remove_prefix,
resolve_tensor,
source_format,
@ -30,6 +38,32 @@ class DownloadProgressTests(unittest.TestCase):
path = download_progress("https://example.com", "/tmp/example-dot-com")
self.assertEqual(path, "/tmp/example-dot-com")
@patch("onnx_web.convert.utils.Path")
@patch("onnx_web.convert.utils.requests")
@patch("onnx_web.convert.utils.shutil")
@patch("onnx_web.convert.utils.tqdm")
def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path):
source = "http://example.com/image.jpg"
dest = "/path/to/destination/image.jpg"
dest_path_mock = MagicMock()
mock_path.return_value.expanduser.return_value.resolve.return_value = (
dest_path_mock
)
dest_path_mock.exists.return_value = False
dest_path_mock.absolute.return_value = "test"
mock_requests.get.return_value.status_code = 200
mock_requests.get.return_value.headers.get.return_value = "1000"
mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock()
result = download_progress(source, dest)
mock_path.assert_called_once_with(dest)
dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
dest_path_mock.open.assert_called_once_with("wb")
mock_shutil.copyfileobj.assert_called_once()
self.assertEqual(result, str(dest_path_mock.absolute.return_value))
class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self):
@ -221,14 +255,6 @@ class RemovePrefixTests(unittest.TestCase):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase):
pass
class LoadTensorTests(unittest.TestCase):
pass
class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self):
@ -239,3 +265,251 @@ class ResolveTensorTests(unittest.TestCase):
def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing"))
class LoadTorchTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
name = "model.pth"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.return_value = checkpoint
result = load_torch(name, map_location)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
mock_torch.load.assert_called_once_with(name, map_location=map_location)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
name = "model.pth"
checkpoint = MagicMock()
mock_torch.load.side_effect = Exception()
mock_torch.jit.load.return_value = checkpoint
result = load_torch(name)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
mock_logger.exception.assert_called_once_with(
"error loading with Torch, trying with Torch JIT: %s", name
)
mock_torch.jit.load.assert_called_once_with(name)
self.assertEqual(result, checkpoint)
class LoadTensorTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger):
name = "model"
map_location = "cpu"
checkpoint = MagicMock()
mock_path.exists.return_value = True
mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")]
mock_torch.load.return_value = checkpoint
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_path.splitext.assert_called_once_with(name)
mock_path.exists.assert_called_once_with(name)
mock_torch.load.assert_called_once_with(name, map_location=map_location)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.environ")
@patch("onnx_web.convert.utils.safetensors")
def test_load_tensor_with_safetensors_extension(
self, mock_safetensors, mock_environ, mock_logger
):
name = "model.safetensors"
checkpoint = MagicMock()
mock_environ.__getitem__.return_value = "1"
mock_safetensors.torch.load_file.return_value = checkpoint
result = load_tensor(name)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger):
name = "model.pt"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
name = "model.onnx"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.warning.assert_called_once_with(
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
)
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger):
name = "model.xyz"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.warning.assert_called_once_with(
"unknown tensor type, falling back to PyTorch: %s", "xyz"
)
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger):
name = "model"
map_location = "cpu"
mock_torch.load.side_effect = Exception()
with self.assertRaises(ValueError):
load_tensor(name, map_location)
class FixDiffusionNameTests(unittest.TestCase):
def test_fix_diffusion_name_with_valid_name(self):
name = "diffusion-model"
result = fix_diffusion_name(name)
self.assertEqual(result, name)
@patch("onnx_web.convert.utils.logger")
def test_fix_diffusion_name_with_invalid_name(self, logger):
name = "model"
expected_result = "diffusion-model"
result = fix_diffusion_name(name)
self.assertEqual(result, expected_result)
logger.warning.assert_called_once_with(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_without_format(self):
name = "model.onnx"
client = "client1"
cache = "/path/to/cache"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client1", "model.onnx"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_format(self):
name = "model"
client = "client2"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client2", "model.onnx"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_existing_extension(self):
name = "model.pth"
client = "client3"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.pth"),
path.join("/path/to/cache/client3", "model.pth"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_empty_extension(self):
name = "model"
client = "client4"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client4", "model.onnx"),
]
self.assertEqual(result, expected_paths)
class GetFirstExistsTests(unittest.TestCase):
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.logger")
def test_get_first_exists_with_existing_path(self, mock_logger, mock_path):
paths = ["path1", "path2", "path3"]
mock_path.exists.side_effect = [False, True, False]
mock_path.return_value = MagicMock()
result = get_first_exists(paths)
mock_logger.debug.assert_called_once_with(
"model already exists in cache, skipping fetch: %s", "path2"
)
self.assertEqual(result, "path2")
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.logger")
def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path):
paths = ["path1", "path2", "path3"]
mock_path.exists.return_value = False
result = get_first_exists(paths)
mock_logger.debug.assert_not_called()
self.assertIsNone(result)

View File

@ -78,7 +78,8 @@ class TestTxt2ImgPipeline(unittest.TestCase):
)
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
output = Image.open("../outputs/test-txt2img-basic.png")
with Image.open("../outputs/test-txt2img-basic.png") as output:
self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@ -126,7 +127,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
output = Image.open("../outputs/test-txt2img-batch-0.png")
with Image.open("../outputs/test-txt2img-batch-0.png") as output:
self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@ -172,7 +173,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
)
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
output = Image.open("../outputs/test-txt2img-highres.png")
with Image.open("../outputs/test-txt2img-highres.png") as output:
self.assertEqual(output.size, (512, 512))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
@ -219,7 +220,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output:
self.assertEqual(output.size, (512, 512))

View File

@ -7,20 +7,29 @@ from onnx_web.params import DeviceParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobStatus
from onnx_web.worker.pool import DevicePoolExecutor
from tests.helpers import test_device
TEST_JOIN_TIMEOUT = 0.2
lock = Event()
def test_job(*args, **kwargs):
def lock_job(*args, **kwargs):
lock.wait()
def wait_job(*args, **kwargs):
def sleep_job(*args, **kwargs):
sleep(0.5)
def progress_job(worker, *args, **kwargs):
worker.set_progress(1)
def fail_job(*args, **kwargs):
raise RuntimeError("job failed")
class TestWorkerPool(unittest.TestCase):
# lock: Optional[Event]
pool: Optional[DevicePoolExecutor]
@ -38,20 +47,20 @@ class TestWorkerPool(unittest.TestCase):
self.pool.start()
def test_fake_worker(self):
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.assertEqual(len(self.pool.workers), 1)
def test_cancel_pending(self):
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.pool.submit("test", "test", wait_job, lock=lock)
self.pool.submit("test", "test", sleep_job, lock=lock)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
self.assertTrue(self.pool.cancel("test"))
@ -61,7 +70,7 @@ class TestWorkerPool(unittest.TestCase):
pass
def test_next_device(self):
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
@ -83,7 +92,7 @@ class TestWorkerPool(unittest.TestCase):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
@ -92,21 +101,21 @@ class TestWorkerPool(unittest.TestCase):
lock.clear()
self.pool.start(lock)
self.pool.submit("test", "test", test_job)
self.pool.submit("test", "test", lock_job)
sleep(5.0)
status, _progress = self.pool.status("test")
self.assertEqual(status, JobStatus.RUNNING)
def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock)
self.pool.submit("test1", "test", test_job)
self.pool.submit("test2", "test", test_job)
self.pool.submit("test1", "test", lock_job)
self.pool.submit("test2", "test", lock_job)
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
lock.set()
@ -115,14 +124,14 @@ class TestWorkerPool(unittest.TestCase):
"""
TODO: flaky
"""
device = DeviceParams("cpu", "CPUProvider")
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", "test", wait_job)
self.pool.submit("test", "test", sleep_job)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
sleep(5.0)
@ -137,3 +146,38 @@ class TestWorkerPool(unittest.TestCase):
def test_running_status(self):
pass
def test_progress_update(self):
pass
def test_progress_finished(self):
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", "test", progress_job)
sleep(5.0)
status, progress = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS)
self.assertEqual(progress.steps.current, 1)
def test_progress_failed(self):
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", "test", fail_job)
sleep(5.0)
status, progress = self.pool.status("test")
self.assertEqual(status, JobStatus.FAILED)
self.assertEqual(progress.steps.current, 0)