1
0
Fork 0

fix(api): continue adding tests, fix bugs encountered

This commit is contained in:
Sean Sube 2023-09-28 18:45:04 -05:00
parent 898d76e4a5
commit 047e58c916
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 526 additions and 52 deletions

View File

@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
def tuple_to_correction(model: Union[ModelDict, LegacyModel]): def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1 scale = rest.pop(0) if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]): def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
single_vae = rest[0] if len(rest) > 0 else False single_vae = rest.pop(0) if len(rest) > 0 else False
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple): if isinstance(model, list) or isinstance(model, tuple):
name, source, *rest = model name, source, *rest = model
scale = rest[0] if len(rest) > 0 else 1 scale = rest.pop(0) if len(rest) > 0 else 1
half = rest[0] if len(rest) > 0 else False half = rest.pop(0) if len(rest) > 0 else False
opset = rest[0] if len(rest) > 0 else None opset = rest.pop(0) if len(rest) > 0 else None
return { return {
"name": name, "name": name,
@ -298,6 +298,7 @@ def onnx_export(
half=False, half=False,
external_data=False, external_data=False,
v2=False, v2=False,
op_block_list=None,
): ):
""" """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@ -316,8 +317,7 @@ def onnx_export(
opset_version=opset, opset_version=opset,
) )
op_block_list = None if v2 and op_block_list is None:
if v2:
op_block_list = ["Attention", "MultiHeadAttention"] op_block_list = ["Attention", "MultiHeadAttention"]
if half: if half:

View File

@ -97,6 +97,7 @@ def run_txt2img_pipeline(
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
logger.trace("saving output image %s: %s", output, image.size)
dest = save_image( dest = save_image(
server, server,
output, output,

View File

@ -47,7 +47,7 @@ def source_filter_noise(
source: Image.Image, source: Image.Image,
strength: float = 0.5, strength: float = 0.5,
): ):
noise = noise_source_histogram(source, source.size) noise = noise_source_histogram(source, source.size, (0, 0))
return ImageChops.blend(source, noise, strength) return ImageChops.blend(source, noise, strength)

View File

@ -25,6 +25,7 @@ class WorkerContext:
idle: "Value[bool]" idle: "Value[bool]"
timeout: float timeout: float
retries: int retries: int
initial_retries: int
def __init__( def __init__(
self, self,
@ -37,6 +38,7 @@ class WorkerContext:
active_pid: "Value[int]", active_pid: "Value[int]",
idle: "Value[bool]", idle: "Value[bool]",
retries: int, retries: int,
timeout: float,
): ):
self.job = None self.job = None
self.name = name self.name = name
@ -48,12 +50,13 @@ class WorkerContext:
self.active_pid = active_pid self.active_pid = active_pid
self.last_progress = None self.last_progress = None
self.idle = idle self.idle = idle
self.initial_retries = retries
self.retries = retries self.retries = retries
self.timeout = 1.0 self.timeout = timeout
def start(self, job: str) -> None: def start(self, job: str) -> None:
self.job = job self.job = job
self.retries = 3 self.retries = self.initial_retries
self.set_cancel(cancel=False) self.set_cancel(cancel=False)
self.set_idle(idle=False) self.set_idle(idle=False)

View File

@ -86,15 +86,15 @@ class DevicePoolExecutor:
self.logs = Queue(self.max_pending_per_worker) self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock() self.rlock = Lock()
def start(self) -> None: def start(self, *args) -> None:
self.create_health_worker() self.create_health_worker()
self.create_logger_worker() self.create_logger_worker()
self.create_progress_worker() self.create_progress_worker()
for device in self.devices: for device in self.devices:
self.create_device_worker(device) self.create_device_worker(device, *args)
def create_device_worker(self, device: DeviceParams) -> None: def create_device_worker(self, device: DeviceParams, *args) -> None:
name = device.device name = device.device
# always recreate queues # always recreate queues
@ -125,15 +125,16 @@ class DevicePoolExecutor:
active_pid=current, active_pid=current,
idle=self.worker_idle[name], idle=self.worker_idle[name],
retries=self.server.worker_retries, retries=self.server.worker_retries,
timeout=self.progress_interval,
) )
self.context[name] = context self.context[name] = context
worker = Process( worker = Process(
name=f"onnx-web worker: {name}", name=f"onnx-web worker: {name}",
target=worker_main, target=worker_main,
args=(context, self.server), args=(context, self.server, *args),
daemon=True,
) )
worker.daemon = True
self.workers[name] = worker self.workers[name] = worker
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)

View File

@ -27,7 +27,7 @@ MEMORY_ERRORS = [
] ]
def worker_main(worker: WorkerContext, server: ServerContext): def worker_main(worker: WorkerContext, server: ServerContext, *args):
apply_patches(server) apply_patches(server)
setproctitle("onnx-web worker: %s" % (worker.device.device)) setproctitle("onnx-web worker: %s" % (worker.device.device))

View File

@ -7,6 +7,8 @@ from onnx_web.chain.tile import (
generate_tile_spiral, generate_tile_spiral,
get_tile_grads, get_tile_grads,
needs_tile, needs_tile,
process_tile_grid,
process_tile_spiral,
) )
from onnx_web.params import Size from onnx_web.params import Size
@ -95,3 +97,31 @@ class TestGenerateTileSpiral(unittest.TestCase):
self.assertEqual(len(tiles), 225) self.assertEqual(len(tiles), 225)
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
class TestProcessTileGrid(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_grid(source, 32, 1, [])
self.assertEqual(blend.size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_grid(source, 32, 1, [])
self.assertEqual(blend.size, (72, 72))
class TestProcessTileSpiral(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_spiral(source, 32, 1, [])
self.assertEqual(blend.size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_spiral(source, 32, 1, [])
self.assertEqual(blend.size, (72, 72))

View File

@ -1,6 +1,14 @@
import unittest import unittest
from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress from onnx_web.convert.utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
tuple_to_correction,
tuple_to_diffusion,
tuple_to_source,
tuple_to_upscaling,
)
class ConversionContextTests(unittest.TestCase): class ConversionContextTests(unittest.TestCase):
@ -17,3 +25,160 @@ class DownloadProgressTests(unittest.TestCase):
def test_download_example(self): def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")]) path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com")
class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned as-is with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
class TupleToCorrectionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToDiffusionTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_diffusion(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToUpscalingTests(unittest.TestCase):
def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)

View File

@ -1,9 +1,16 @@
from os import path
from typing import List from typing import List
from unittest import skipUnless
from onnx_web.params import DeviceParams
def test_with_models(models: List[str]): def test_needs_models(models: List[str]):
def wrapper(func): return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
# TODO: check if models exist
return func
return wrapper
def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider")
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"

View File

View File

@ -0,0 +1,33 @@
import unittest
from PIL import Image
from onnx_web.image.mask_filter import (
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
mask_filter_none,
)
class MaskFilterNoneTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_none(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
class MaskFilterGaussianMultiplyTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
self.assertEqual(result.size, dims)
class MaskFilterGaussianScreenTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
mask = Image.new("RGB", dims)
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
self.assertEqual(result.size, dims)

View File

@ -0,0 +1,37 @@
import unittest
from PIL import Image
from onnx_web.image.source_filter import (
source_filter_gaussian,
source_filter_noise,
source_filter_none,
)
from onnx_web.server.context import ServerContext
class SourceFilterNoneTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_none(server, source)
self.assertEqual(result.size, dims)
class SourceFilterGaussianTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_gaussian(server, source)
self.assertEqual(result.size, dims)
class SourceFilterNoiseTests(unittest.TestCase):
def test_basic(self):
dims = (64, 64)
server = ServerContext()
source = Image.new("RGB", dims)
result = source_filter_noise(server, source)
self.assertEqual(result.size, dims)

View File

@ -1,5 +1,6 @@
import unittest import unittest
from onnx_web.server.context import ServerContext
from onnx_web.server.load import ( from onnx_web.server.load import (
get_available_platforms, get_available_platforms,
get_config_params, get_config_params,
@ -14,6 +15,8 @@ from onnx_web.server.load import (
get_source_filters, get_source_filters,
get_upscaling_models, get_upscaling_models,
get_wildcard_data, get_wildcard_data,
load_extras,
load_models,
) )
@ -81,3 +84,13 @@ class SourceFilterTests(unittest.TestCase):
def test_before_setup(self): def test_before_setup(self):
filters = get_source_filters() filters = get_source_filters()
self.assertIsNotNone(filters) self.assertIsNotNone(filters)
class LoadExtrasTests(unittest.TestCase):
def test_default_extras(self):
server = ServerContext(extra_models=["../models/extras.json"])
load_extras(server)
class LoadModelsTests(unittest.TestCase):
def test_default_models(self):
server = ServerContext(model_path="../models")
load_models(server)

View File

@ -17,10 +17,9 @@ from onnx_web.diffusers.load import (
) )
from onnx_web.diffusers.patches.unet import UNetWrapper from onnx_web.diffusers.patches.unet import UNetWrapper
from onnx_web.diffusers.patches.vae import VAEWrapper from onnx_web.diffusers.patches.vae import VAEWrapper
from onnx_web.models.meta import NetworkModel, NetworkType from onnx_web.models.meta import NetworkModel
from onnx_web.params import DeviceParams, ImageParams from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from tests.helpers import test_with_models
from tests.mocks import MockPipeline from tests.mocks import MockPipeline

View File

@ -0,0 +1,171 @@
import unittest
from multiprocessing import Queue, Value
from os import path
from PIL import Image
from onnx_web.diffusers.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
class TestTxt2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(256, 256),
["test-txt2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
1.0,
)
self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
)
self.assertTrue(path.exists("../outputs/test-upscale.png"))
class TestBlendPipeline(unittest.TestCase):
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test")
source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white")
run_blend_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),
[source, source],
mask,
)
self.assertTrue(path.exists("../outputs/test-blend.png"))

View File

@ -9,16 +9,22 @@ from onnx_web.worker.pool import DevicePoolExecutor
TEST_JOIN_TIMEOUT = 0.2 TEST_JOIN_TIMEOUT = 0.2
def test_job(*args, lock: Event, **kwargs): lock = Event()
def test_job(*args, **kwargs):
lock.wait() lock.wait()
def wait_job(*args, **kwargs):
sleep(0.5)
class TestWorkerPool(unittest.TestCase): class TestWorkerPool(unittest.TestCase):
lock: Optional[Event] # lock: Optional[Event]
pool: Optional[DevicePoolExecutor] pool: Optional[DevicePoolExecutor]
def setUp(self) -> None: def setUp(self) -> None:
self.lock = Event()
self.pool = None self.pool = None
def tearDown(self) -> None: def tearDown(self) -> None:
@ -38,7 +44,17 @@ class TestWorkerPool(unittest.TestCase):
self.assertEqual(len(self.pool.workers), 1) self.assertEqual(len(self.pool.workers), 1)
def test_cancel_pending(self): def test_cancel_pending(self):
pass device = DeviceParams("cpu", "CPUProvider")
server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.pool.submit("test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None))
self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None))
def test_cancel_running(self): def test_cancel_running(self):
pass pass
@ -61,48 +77,46 @@ class TestWorkerPool(unittest.TestCase):
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1) self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
def test_done_running(self): def test_done_running(self):
"""
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool.start() self.pool.start(lock)
sleep(2.0)
self.pool.submit("test", test_job, lock=self.lock) self.pool.submit("test", test_job)
sleep(5.0) sleep(2.0)
self.assertEqual(self.pool.done("test"), (False, None))
""" pending, _progress = self.pool.done("test")
pass self.assertFalse(pending)
def test_done_pending(self): def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start(lock)
self.pool.submit("test1", test_job, lock=self.lock) self.pool.submit("test1", test_job)
self.pool.submit("test2", test_job, lock=self.lock) self.pool.submit("test2", test_job)
self.assertTrue(self.pool.done("test2"), (True, None)) self.assertTrue(self.pool.done("test2"), (True, None))
self.lock.set() lock.set()
def test_done_finished(self): def test_done_finished(self):
"""
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
self.pool.start() self.pool.start()
sleep(2.0)
self.pool.submit("test", test_job, lock=self.lock) self.pool.submit("test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None)) self.assertEqual(self.pool.done("test"), (True, None))
self.lock.set() sleep(2.0)
sleep(5.0) pending, _progress = self.pool.done("test")
self.assertEqual(self.pool.done("test"), (False, None)) self.assertFalse(pending)
"""
pass
def test_recycle_live(self): def test_recycle_live(self):
pass pass