1
0
Fork 0

update tests

This commit is contained in:
Sean Sube 2024-01-05 20:13:57 -06:00
parent 88f99ef6c2
commit 6b0b2e41a6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
23 changed files with 467 additions and 49 deletions

View File

@ -107,7 +107,7 @@ def download_progress(source: str, dest: str):
stream=True, stream=True,
allow_redirects=True, allow_redirects=True,
headers={ headers={
"User-Agent": "onnx-web-api", "User-Agent": "onnx-web-api", # TODO: add version
}, },
) )
if req.status_code != 200: if req.status_code != 200:
@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
logger.debug("loading tensor with Torch: %s", name) logger.debug("loading tensor with Torch: %s", name)
checkpoint = torch.load(name, map_location=map_location) checkpoint = torch.load(name, map_location=map_location)
except Exception: except Exception:
logger.exception( logger.exception("error loading with Torch, trying with Torch JIT: %s", name)
"error loading with Torch JIT, trying with Torch JIT: %s", name
)
checkpoint = torch.jit.load(name) checkpoint = torch.jit.load(name)
return checkpoint return checkpoint

View File

@ -81,6 +81,12 @@ class Size:
def __str__(self) -> str: def __str__(self) -> str:
return "%sx%s" % (self.width, self.height) return "%sx%s" % (self.width, self.height)
def __eq__(self, other: Any) -> bool:
if isinstance(other, Size):
return self.width == other.width and self.height == other.height
return False
def add_border(self, border: Border): def add_border(self, border: Border):
return Size( return Size(
border.left + self.width + border.right, border.left + self.width + border.right,

View File

@ -22,6 +22,13 @@ class JobType(str, Enum):
class Progress: class Progress:
"""
Generic counter with current and expected/final/total value. Can be used to count up or down.
Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty"
when the current value is zero.
"""
current: int current: int
total: int total: int

View File

@ -0,0 +1,50 @@
import unittest
from PIL import Image
from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage
from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class TestBlendDenoiseLocalStdStage(unittest.TestCase):
def test_run(self):
# Create a dummy image
image = Image.new("RGB", (64, 64), color="white")
# Create a dummy StageResult object
sources = StageResult.from_images(
[image],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0),
Size(64, 64),
)
],
)
# Create an instance of BlendDenoiseLocalStdStage
stage = BlendDenoiseLocalStdStage()
# Call the run method with dummy parameters
result = stage.run(
_worker=None,
_server=None,
_stage=None,
_params=None,
sources=sources,
strength=5,
range=4,
stage_source=None,
callback=None,
)
# Assert that the result is an instance of StageResult
self.assertIsInstance(result, StageResult)
# Assert that the result contains the denoised image
self.assertEqual(len(result), 1)
self.assertEqual(result.size(), Size(64, 64))
# Assert that the metadata is preserved
self.assertEqual(result.metadata, sources.metadata)

View File

@ -3,7 +3,8 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.result import StageResult from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class BlendGridStageTests(unittest.TestCase): class BlendGridStageTests(unittest.TestCase):
@ -15,9 +16,17 @@ class BlendGridStageTests(unittest.TestCase):
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "white"),
],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
] ]
* 4,
) )
result = stage.run(None, None, None, None, sources, height=2, width=2) result = stage.run(None, None, None, None, sources, height=2, width=2)
result.validate()
self.assertEqual(len(result), 5) self.assertEqual(len(result), 5)
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0)) self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0))

View File

@ -3,8 +3,8 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_img2img import BlendImg2ImgStage from onnx_web.chain.blend_img2img import BlendImg2ImgStage
from onnx_web.chain.result import StageResult from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams from onnx_web.params import ImageParams, Size
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
@ -39,9 +39,16 @@ class BlendImg2ImgStageTests(unittest.TestCase):
sources = StageResult( sources = StageResult(
images=[ images=[
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
] ],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
],
) )
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
result.validate()
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0)) self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0))

View File

@ -3,7 +3,8 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage from onnx_web.chain.blend_linear import BlendLinearStage
from onnx_web.chain.result import StageResult from onnx_web.chain.result import ImageMetadata, StageResult
from onnx_web.params import ImageParams, Size
class BlendLinearStageTests(unittest.TestCase): class BlendLinearStageTests(unittest.TestCase):
@ -12,12 +13,19 @@ class BlendLinearStageTests(unittest.TestCase):
sources = StageResult( sources = StageResult(
images=[ images=[
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
] ],
metadata=[
ImageMetadata(
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
Size(64, 64),
),
],
) )
stage_source = Image.new("RGB", (64, 64), "white") stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run( result = stage.run(
None, None, None, None, sources, alpha=0.5, stage_source=stage_source None, None, None, None, sources, alpha=0.5, stage_source=stage_source
) )
result.validate()
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127)) self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127))

View File

@ -23,5 +23,6 @@ class BlendMaskStageTests(unittest.TestCase):
stage_source=Image.new("RGBA", (64, 64)), stage_source=Image.new("RGBA", (64, 64)),
dims=(0, 0, SizeChart.auto), dims=(0, 0, SizeChart.auto),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -42,5 +42,6 @@ class CorrectCodeformerStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""), upscale=UpscaleParams(""),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -6,13 +6,14 @@ from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models from tests.helpers import test_device, test_needs_models
TEST_MODEL = "../models/correction-gfpgan-v1-3" TEST_MODEL_NAME = "correction-gfpgan-v1-3"
TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth"
class CorrectGFPGANStageTests(unittest.TestCase): class CorrectGFPGANStageTests(unittest.TestCase):
@test_needs_onnx_models([TEST_MODEL]) @test_needs_models([TEST_MODEL])
def test_empty(self): def test_empty(self):
server = ServerContext(model_path="../models", output_path="../outputs") server = ServerContext(model_path="../models", output_path="../outputs")
apply_patches(server) apply_patches(server)
@ -33,12 +34,13 @@ class CorrectGFPGANStageTests(unittest.TestCase):
sources = StageResult.empty() sources = StageResult.empty()
result = stage.run( result = stage.run(
worker, worker,
None, server,
None, None,
None, None,
sources, sources,
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL), upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -20,5 +20,6 @@ class ReduceCropStageTests(unittest.TestCase):
origin=Size(0, 0), origin=Size(0, 0),
size=Size(128, 128), size=Size(128, 128),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -24,5 +24,6 @@ class ReduceThumbnailStageTests(unittest.TestCase):
size=Size(128, 128), size=Size(128, 128),
stage_source=stage_source, stage_source=stage_source,
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -22,5 +22,6 @@ class SourceNoiseStageTests(unittest.TestCase):
size=Size(128, 128), size=Size(128, 128),
noise_source=noise_source_fill_edge, noise_source=noise_source_fill_edge,
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -22,5 +22,6 @@ class SourceS3StageTests(unittest.TestCase):
bucket="test", bucket="test",
source_keys=[], source_keys=[],
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -21,5 +21,6 @@ class SourceURLStageTests(unittest.TestCase):
size=Size(128, 128), size=Size(128, 128),
source_urls=[], source_urls=[],
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -37,5 +37,6 @@ class UpscaleBSRGANStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL), upscale=UpscaleParams(TEST_MODEL),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -18,5 +18,6 @@ class UpscaleHighresStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(""), upscale=UpscaleParams(""),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -46,5 +46,6 @@ class UpscaleOutpaintStageTests(unittest.TestCase):
dims=(), dims=(),
tile_mask=Image.new("RGB", (64, 64)), tile_mask=Image.new("RGB", (64, 64)),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -35,5 +35,6 @@ class UpscaleRealESRGANStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"), upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -35,5 +35,6 @@ class UpscaleSwinIRStageTests(unittest.TestCase):
highres=HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
upscale=UpscaleParams(TEST_MODEL), upscale=UpscaleParams(TEST_MODEL),
) )
result.validate()
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -1,9 +1,17 @@
import unittest import unittest
from os import path
from unittest import mock
from unittest.mock import MagicMock, patch
from onnx_web.convert.utils import ( from onnx_web.convert.utils import (
DEFAULT_OPSET, DEFAULT_OPSET,
ConversionContext, ConversionContext,
build_cache_paths,
download_progress, download_progress,
fix_diffusion_name,
get_first_exists,
load_tensor,
load_torch,
remove_prefix, remove_prefix,
resolve_tensor, resolve_tensor,
source_format, source_format,
@ -30,6 +38,32 @@ class DownloadProgressTests(unittest.TestCase):
path = download_progress("https://example.com", "/tmp/example-dot-com") path = download_progress("https://example.com", "/tmp/example-dot-com")
self.assertEqual(path, "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com")
@patch("onnx_web.convert.utils.Path")
@patch("onnx_web.convert.utils.requests")
@patch("onnx_web.convert.utils.shutil")
@patch("onnx_web.convert.utils.tqdm")
def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path):
source = "http://example.com/image.jpg"
dest = "/path/to/destination/image.jpg"
dest_path_mock = MagicMock()
mock_path.return_value.expanduser.return_value.resolve.return_value = (
dest_path_mock
)
dest_path_mock.exists.return_value = False
dest_path_mock.absolute.return_value = "test"
mock_requests.get.return_value.status_code = 200
mock_requests.get.return_value.headers.get.return_value = "1000"
mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock()
result = download_progress(source, dest)
mock_path.assert_called_once_with(dest)
dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
dest_path_mock.open.assert_called_once_with("wb")
mock_shutil.copyfileobj.assert_called_once()
self.assertEqual(result, str(dest_path_mock.absolute.return_value))
class TupleToSourceTests(unittest.TestCase): class TupleToSourceTests(unittest.TestCase):
def test_basic_tuple(self): def test_basic_tuple(self):
@ -221,14 +255,6 @@ class RemovePrefixTests(unittest.TestCase):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar") self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase):
pass
class LoadTensorTests(unittest.TestCase):
pass
class ResolveTensorTests(unittest.TestCase): class ResolveTensorTests(unittest.TestCase):
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR]) @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
def test_resolve_existing(self): def test_resolve_existing(self):
@ -239,3 +265,251 @@ class ResolveTensorTests(unittest.TestCase):
def test_resolve_missing(self): def test_resolve_missing(self):
self.assertIsNone(resolve_tensor("missing")) self.assertIsNone(resolve_tensor("missing"))
class LoadTorchTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
name = "model.pth"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.return_value = checkpoint
result = load_torch(name, map_location)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
mock_torch.load.assert_called_once_with(name, map_location=map_location)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
name = "model.pth"
checkpoint = MagicMock()
mock_torch.load.side_effect = Exception()
mock_torch.jit.load.return_value = checkpoint
result = load_torch(name)
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
mock_logger.exception.assert_called_once_with(
"error loading with Torch, trying with Torch JIT: %s", name
)
mock_torch.jit.load.assert_called_once_with(name)
self.assertEqual(result, checkpoint)
class LoadTensorTests(unittest.TestCase):
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger):
name = "model"
map_location = "cpu"
checkpoint = MagicMock()
mock_path.exists.return_value = True
mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")]
mock_torch.load.return_value = checkpoint
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_path.splitext.assert_called_once_with(name)
mock_path.exists.assert_called_once_with(name)
mock_torch.load.assert_called_once_with(name, map_location=map_location)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.environ")
@patch("onnx_web.convert.utils.safetensors")
def test_load_tensor_with_safetensors_extension(
self, mock_safetensors, mock_environ, mock_logger
):
name = "model.safetensors"
checkpoint = MagicMock()
mock_environ.__getitem__.return_value = "1"
mock_safetensors.torch.load_file.return_value = checkpoint
result = load_tensor(name)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger):
name = "model.pt"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
name = "model.onnx"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.warning.assert_called_once_with(
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
)
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger):
name = "model.xyz"
map_location = "cpu"
checkpoint = MagicMock()
mock_torch.load.side_effect = [checkpoint]
result = load_tensor(name, map_location)
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
mock_logger.warning.assert_called_once_with(
"unknown tensor type, falling back to PyTorch: %s", "xyz"
)
mock_torch.load.assert_has_calls(
[
mock.call(name, map_location=map_location),
]
)
self.assertEqual(result, checkpoint)
@patch("onnx_web.convert.utils.logger")
@patch("onnx_web.convert.utils.torch")
def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger):
name = "model"
map_location = "cpu"
mock_torch.load.side_effect = Exception()
with self.assertRaises(ValueError):
load_tensor(name, map_location)
class FixDiffusionNameTests(unittest.TestCase):
def test_fix_diffusion_name_with_valid_name(self):
name = "diffusion-model"
result = fix_diffusion_name(name)
self.assertEqual(result, name)
@patch("onnx_web.convert.utils.logger")
def test_fix_diffusion_name_with_invalid_name(self, logger):
name = "model"
expected_result = "diffusion-model"
result = fix_diffusion_name(name)
self.assertEqual(result, expected_result)
logger.warning.assert_called_once_with(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
class BuildCachePathsTests(unittest.TestCase):
def test_build_cache_paths_without_format(self):
name = "model.onnx"
client = "client1"
cache = "/path/to/cache"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client1", "model.onnx"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_format(self):
name = "model"
client = "client2"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client2", "model.onnx"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_existing_extension(self):
name = "model.pth"
client = "client3"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.pth"),
path.join("/path/to/cache/client3", "model.pth"),
]
self.assertEqual(result, expected_paths)
def test_build_cache_paths_with_empty_extension(self):
name = "model"
client = "client4"
cache = "/path/to/cache"
format = "onnx"
conversion = ConversionContext(cache_path=cache)
result = build_cache_paths(conversion, name, client, cache, format)
expected_paths = [
path.join("/path/to/cache", "model.onnx"),
path.join("/path/to/cache/client4", "model.onnx"),
]
self.assertEqual(result, expected_paths)
class GetFirstExistsTests(unittest.TestCase):
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.logger")
def test_get_first_exists_with_existing_path(self, mock_logger, mock_path):
paths = ["path1", "path2", "path3"]
mock_path.exists.side_effect = [False, True, False]
mock_path.return_value = MagicMock()
result = get_first_exists(paths)
mock_logger.debug.assert_called_once_with(
"model already exists in cache, skipping fetch: %s", "path2"
)
self.assertEqual(result, "path2")
@patch("onnx_web.convert.utils.path")
@patch("onnx_web.convert.utils.logger")
def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path):
paths = ["path1", "path2", "path3"]
mock_path.exists.return_value = False
result = get_first_exists(paths)
mock_logger.debug.assert_not_called()
self.assertIsNone(result)

View File

@ -78,9 +78,10 @@ class TestTxt2ImgPipeline(unittest.TestCase):
) )
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
output = Image.open("../outputs/test-txt2img-basic.png")
self.assertEqual(output.size, (256, 256)) with Image.open("../outputs/test-txt2img-basic.png") as output:
# TODO: test contents of image self.assertEqual(output.size, (256, 256))
# TODO: test contents of image
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_batch(self): def test_batch(self):
@ -126,9 +127,9 @@ class TestTxt2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
output = Image.open("../outputs/test-txt2img-batch-0.png") with Image.open("../outputs/test-txt2img-batch-0.png") as output:
self.assertEqual(output.size, (256, 256)) self.assertEqual(output.size, (256, 256))
# TODO: test contents of image # TODO: test contents of image
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_highres(self): def test_highres(self):
@ -172,8 +173,8 @@ class TestTxt2ImgPipeline(unittest.TestCase):
) )
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
output = Image.open("../outputs/test-txt2img-highres.png") with Image.open("../outputs/test-txt2img-highres.png") as output:
self.assertEqual(output.size, (512, 512)) self.assertEqual(output.size, (512, 512))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_highres_batch(self): def test_highres_batch(self):
@ -219,8 +220,8 @@ class TestTxt2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
output = Image.open("../outputs/test-txt2img-highres-batch-0.png") with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output:
self.assertEqual(output.size, (512, 512)) self.assertEqual(output.size, (512, 512))
class TestImg2ImgPipeline(unittest.TestCase): class TestImg2ImgPipeline(unittest.TestCase):

View File

@ -7,20 +7,29 @@ from onnx_web.params import DeviceParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobStatus from onnx_web.worker.command import JobStatus
from onnx_web.worker.pool import DevicePoolExecutor from onnx_web.worker.pool import DevicePoolExecutor
from tests.helpers import test_device
TEST_JOIN_TIMEOUT = 0.2 TEST_JOIN_TIMEOUT = 0.2
lock = Event() lock = Event()
def test_job(*args, **kwargs): def lock_job(*args, **kwargs):
lock.wait() lock.wait()
def wait_job(*args, **kwargs): def sleep_job(*args, **kwargs):
sleep(0.5) sleep(0.5)
def progress_job(worker, *args, **kwargs):
worker.set_progress(1)
def fail_job(*args, **kwargs):
raise RuntimeError("job failed")
class TestWorkerPool(unittest.TestCase): class TestWorkerPool(unittest.TestCase):
# lock: Optional[Event] # lock: Optional[Event]
pool: Optional[DevicePoolExecutor] pool: Optional[DevicePoolExecutor]
@ -38,20 +47,20 @@ class TestWorkerPool(unittest.TestCase):
self.pool.start() self.pool.start()
def test_fake_worker(self): def test_fake_worker(self):
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.assertEqual(len(self.pool.workers), 1) self.assertEqual(len(self.pool.workers), 1)
def test_cancel_pending(self): def test_cancel_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.pool.submit("test", "test", wait_job, lock=lock) self.pool.submit("test", "test", sleep_job, lock=lock)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
self.assertTrue(self.pool.cancel("test")) self.assertTrue(self.pool.cancel("test"))
@ -61,7 +70,7 @@ class TestWorkerPool(unittest.TestCase):
pass pass
def test_next_device(self): def test_next_device(self):
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
@ -83,7 +92,7 @@ class TestWorkerPool(unittest.TestCase):
""" """
TODO: flaky TODO: flaky
""" """
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor( self.pool = DevicePoolExecutor(
@ -92,21 +101,21 @@ class TestWorkerPool(unittest.TestCase):
lock.clear() lock.clear()
self.pool.start(lock) self.pool.start(lock)
self.pool.submit("test", "test", test_job) self.pool.submit("test", "test", lock_job)
sleep(5.0) sleep(5.0)
status, _progress = self.pool.status("test") status, _progress = self.pool.status("test")
self.assertEqual(status, JobStatus.RUNNING) self.assertEqual(status, JobStatus.RUNNING)
def test_done_pending(self): def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock) self.pool.start(lock)
self.pool.submit("test1", "test", test_job) self.pool.submit("test1", "test", lock_job)
self.pool.submit("test2", "test", test_job) self.pool.submit("test2", "test", lock_job)
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None)) self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
lock.set() lock.set()
@ -115,14 +124,14 @@ class TestWorkerPool(unittest.TestCase):
""" """
TODO: flaky TODO: flaky
""" """
device = DeviceParams("cpu", "CPUProvider") device = test_device()
server = ServerContext() server = ServerContext()
self.pool = DevicePoolExecutor( self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
) )
self.pool.start() self.pool.start()
self.pool.submit("test", "test", wait_job) self.pool.submit("test", "test", sleep_job)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
sleep(5.0) sleep(5.0)
@ -137,3 +146,38 @@ class TestWorkerPool(unittest.TestCase):
def test_running_status(self): def test_running_status(self):
pass pass
def test_progress_update(self):
pass
def test_progress_finished(self):
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", "test", progress_job)
sleep(5.0)
status, progress = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS)
self.assertEqual(progress.steps.current, 1)
def test_progress_failed(self):
device = test_device()
server = ServerContext()
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", "test", fail_job)
sleep(5.0)
status, progress = self.pool.status("test")
self.assertEqual(status, JobStatus.FAILED)
self.assertEqual(progress.steps.current, 0)