update tests
This commit is contained in:
parent
88f99ef6c2
commit
6b0b2e41a6
|
@ -107,7 +107,7 @@ def download_progress(source: str, dest: str):
|
|||
stream=True,
|
||||
allow_redirects=True,
|
||||
headers={
|
||||
"User-Agent": "onnx-web-api",
|
||||
"User-Agent": "onnx-web-api", # TODO: add version
|
||||
},
|
||||
)
|
||||
if req.status_code != 200:
|
||||
|
@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
|||
logger.debug("loading tensor with Torch: %s", name)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error loading with Torch JIT, trying with Torch JIT: %s", name
|
||||
)
|
||||
logger.exception("error loading with Torch, trying with Torch JIT: %s", name)
|
||||
checkpoint = torch.jit.load(name)
|
||||
|
||||
return checkpoint
|
||||
|
|
|
@ -81,6 +81,12 @@ class Size:
|
|||
def __str__(self) -> str:
|
||||
return "%sx%s" % (self.width, self.height)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if isinstance(other, Size):
|
||||
return self.width == other.width and self.height == other.height
|
||||
|
||||
return False
|
||||
|
||||
def add_border(self, border: Border):
|
||||
return Size(
|
||||
border.left + self.width + border.right,
|
||||
|
|
|
@ -22,6 +22,13 @@ class JobType(str, Enum):
|
|||
|
||||
|
||||
class Progress:
|
||||
"""
|
||||
Generic counter with current and expected/final/total value. Can be used to count up or down.
|
||||
|
||||
Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty"
|
||||
when the current value is zero.
|
||||
"""
|
||||
|
||||
current: int
|
||||
total: int
|
||||
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage
|
||||
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||
from onnx_web.params import ImageParams, Size
|
||||
|
||||
|
||||
class TestBlendDenoiseLocalStdStage(unittest.TestCase):
|
||||
def test_run(self):
|
||||
# Create a dummy image
|
||||
image = Image.new("RGB", (64, 64), color="white")
|
||||
|
||||
# Create a dummy StageResult object
|
||||
sources = StageResult.from_images(
|
||||
[image],
|
||||
metadata=[
|
||||
ImageMetadata(
|
||||
ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0),
|
||||
Size(64, 64),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Create an instance of BlendDenoiseLocalStdStage
|
||||
stage = BlendDenoiseLocalStdStage()
|
||||
|
||||
# Call the run method with dummy parameters
|
||||
result = stage.run(
|
||||
_worker=None,
|
||||
_server=None,
|
||||
_stage=None,
|
||||
_params=None,
|
||||
sources=sources,
|
||||
strength=5,
|
||||
range=4,
|
||||
stage_source=None,
|
||||
callback=None,
|
||||
)
|
||||
|
||||
# Assert that the result is an instance of StageResult
|
||||
self.assertIsInstance(result, StageResult)
|
||||
|
||||
# Assert that the result contains the denoised image
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result.size(), Size(64, 64))
|
||||
|
||||
# Assert that the metadata is preserved
|
||||
self.assertEqual(result.metadata, sources.metadata)
|
|
@ -3,7 +3,8 @@ import unittest
|
|||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_grid import BlendGridStage
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||
from onnx_web.params import ImageParams, Size
|
||||
|
||||
|
||||
class BlendGridStageTests(unittest.TestCase):
|
||||
|
@ -15,9 +16,17 @@ class BlendGridStageTests(unittest.TestCase):
|
|||
Image.new("RGB", (64, 64), "white"),
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
Image.new("RGB", (64, 64), "white"),
|
||||
],
|
||||
metadata=[
|
||||
ImageMetadata(
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||
Size(64, 64),
|
||||
),
|
||||
]
|
||||
* 4,
|
||||
)
|
||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 5)
|
||||
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
|
||||
self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0))
|
||||
|
|
|
@ -3,8 +3,8 @@ import unittest
|
|||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.params import ImageParams
|
||||
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||
from onnx_web.params import ImageParams, Size
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||
|
@ -39,9 +39,16 @@ class BlendImg2ImgStageTests(unittest.TestCase):
|
|||
sources = StageResult(
|
||||
images=[
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
]
|
||||
],
|
||||
metadata=[
|
||||
ImageMetadata(
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||
Size(64, 64),
|
||||
),
|
||||
],
|
||||
)
|
||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
|
||||
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0))
|
||||
|
|
|
@ -3,7 +3,8 @@ import unittest
|
|||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||
from onnx_web.params import ImageParams, Size
|
||||
|
||||
|
||||
class BlendLinearStageTests(unittest.TestCase):
|
||||
|
@ -12,12 +13,19 @@ class BlendLinearStageTests(unittest.TestCase):
|
|||
sources = StageResult(
|
||||
images=[
|
||||
Image.new("RGB", (64, 64), "black"),
|
||||
]
|
||||
],
|
||||
metadata=[
|
||||
ImageMetadata(
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||
Size(64, 64),
|
||||
),
|
||||
],
|
||||
)
|
||||
stage_source = Image.new("RGB", (64, 64), "white")
|
||||
result = stage.run(
|
||||
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
|
||||
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127))
|
||||
|
|
|
@ -23,5 +23,6 @@ class BlendMaskStageTests(unittest.TestCase):
|
|||
stage_source=Image.new("RGBA", (64, 64)),
|
||||
dims=(0, 0, SizeChart.auto),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -42,5 +42,6 @@ class CorrectCodeformerStageTests(unittest.TestCase):
|
|||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -6,13 +6,14 @@ from onnx_web.params import HighresParams, UpscaleParams
|
|||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.server.hacks import apply_patches
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_onnx_models
|
||||
from tests.helpers import test_device, test_needs_models
|
||||
|
||||
TEST_MODEL = "../models/correction-gfpgan-v1-3"
|
||||
TEST_MODEL_NAME = "correction-gfpgan-v1-3"
|
||||
TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth"
|
||||
|
||||
|
||||
class CorrectGFPGANStageTests(unittest.TestCase):
|
||||
@test_needs_onnx_models([TEST_MODEL])
|
||||
@test_needs_models([TEST_MODEL])
|
||||
def test_empty(self):
|
||||
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||
apply_patches(server)
|
||||
|
@ -33,12 +34,13 @@ class CorrectGFPGANStageTests(unittest.TestCase):
|
|||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
worker,
|
||||
None,
|
||||
server,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -20,5 +20,6 @@ class ReduceCropStageTests(unittest.TestCase):
|
|||
origin=Size(0, 0),
|
||||
size=Size(128, 128),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -24,5 +24,6 @@ class ReduceThumbnailStageTests(unittest.TestCase):
|
|||
size=Size(128, 128),
|
||||
stage_source=stage_source,
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -22,5 +22,6 @@ class SourceNoiseStageTests(unittest.TestCase):
|
|||
size=Size(128, 128),
|
||||
noise_source=noise_source_fill_edge,
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -22,5 +22,6 @@ class SourceS3StageTests(unittest.TestCase):
|
|||
bucket="test",
|
||||
source_keys=[],
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -21,5 +21,6 @@ class SourceURLStageTests(unittest.TestCase):
|
|||
size=Size(128, 128),
|
||||
source_urls=[],
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -37,5 +37,6 @@ class UpscaleBSRGANStageTests(unittest.TestCase):
|
|||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -18,5 +18,6 @@ class UpscaleHighresStageTests(unittest.TestCase):
|
|||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(""),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -46,5 +46,6 @@ class UpscaleOutpaintStageTests(unittest.TestCase):
|
|||
dims=(),
|
||||
tile_mask=Image.new("RGB", (64, 64)),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -35,5 +35,6 @@ class UpscaleRealESRGANStageTests(unittest.TestCase):
|
|||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -35,5 +35,6 @@ class UpscaleSwinIRStageTests(unittest.TestCase):
|
|||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
)
|
||||
result.validate()
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
import unittest
|
||||
from os import path
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from onnx_web.convert.utils import (
|
||||
DEFAULT_OPSET,
|
||||
ConversionContext,
|
||||
build_cache_paths,
|
||||
download_progress,
|
||||
fix_diffusion_name,
|
||||
get_first_exists,
|
||||
load_tensor,
|
||||
load_torch,
|
||||
remove_prefix,
|
||||
resolve_tensor,
|
||||
source_format,
|
||||
|
@ -30,6 +38,32 @@ class DownloadProgressTests(unittest.TestCase):
|
|||
path = download_progress("https://example.com", "/tmp/example-dot-com")
|
||||
self.assertEqual(path, "/tmp/example-dot-com")
|
||||
|
||||
@patch("onnx_web.convert.utils.Path")
|
||||
@patch("onnx_web.convert.utils.requests")
|
||||
@patch("onnx_web.convert.utils.shutil")
|
||||
@patch("onnx_web.convert.utils.tqdm")
|
||||
def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path):
|
||||
source = "http://example.com/image.jpg"
|
||||
dest = "/path/to/destination/image.jpg"
|
||||
|
||||
dest_path_mock = MagicMock()
|
||||
mock_path.return_value.expanduser.return_value.resolve.return_value = (
|
||||
dest_path_mock
|
||||
)
|
||||
dest_path_mock.exists.return_value = False
|
||||
dest_path_mock.absolute.return_value = "test"
|
||||
mock_requests.get.return_value.status_code = 200
|
||||
mock_requests.get.return_value.headers.get.return_value = "1000"
|
||||
mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
result = download_progress(source, dest)
|
||||
|
||||
mock_path.assert_called_once_with(dest)
|
||||
dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
|
||||
dest_path_mock.open.assert_called_once_with("wb")
|
||||
mock_shutil.copyfileobj.assert_called_once()
|
||||
self.assertEqual(result, str(dest_path_mock.absolute.return_value))
|
||||
|
||||
|
||||
class TupleToSourceTests(unittest.TestCase):
|
||||
def test_basic_tuple(self):
|
||||
|
@ -221,14 +255,6 @@ class RemovePrefixTests(unittest.TestCase):
|
|||
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
||||
|
||||
|
||||
class LoadTorchTests(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class LoadTensorTests(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class ResolveTensorTests(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
||||
def test_resolve_existing(self):
|
||||
|
@ -239,3 +265,251 @@ class ResolveTensorTests(unittest.TestCase):
|
|||
|
||||
def test_resolve_missing(self):
|
||||
self.assertIsNone(resolve_tensor("missing"))
|
||||
|
||||
|
||||
class LoadTorchTests(unittest.TestCase):
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
|
||||
name = "model.pth"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.return_value = checkpoint
|
||||
|
||||
result = load_torch(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
|
||||
name = "model.pth"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = Exception()
|
||||
mock_torch.jit.load.return_value = checkpoint
|
||||
|
||||
result = load_torch(name)
|
||||
|
||||
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||
mock_logger.exception.assert_called_once_with(
|
||||
"error loading with Torch, trying with Torch JIT: %s", name
|
||||
)
|
||||
mock_torch.jit.load.assert_called_once_with(name)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
|
||||
class LoadTensorTests(unittest.TestCase):
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.path")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger):
|
||||
name = "model"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_path.exists.return_value = True
|
||||
mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")]
|
||||
mock_torch.load.return_value = checkpoint
|
||||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_path.splitext.assert_called_once_with(name)
|
||||
mock_path.exists.assert_called_once_with(name)
|
||||
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.environ")
|
||||
@patch("onnx_web.convert.utils.safetensors")
|
||||
def test_load_tensor_with_safetensors_extension(
|
||||
self, mock_safetensors, mock_environ, mock_logger
|
||||
):
|
||||
name = "model.safetensors"
|
||||
checkpoint = MagicMock()
|
||||
mock_environ.__getitem__.return_value = "1"
|
||||
mock_safetensors.torch.load_file.return_value = checkpoint
|
||||
|
||||
result = load_tensor(name)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger):
|
||||
name = "model.pt"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = [checkpoint]
|
||||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_torch.load.assert_has_calls(
|
||||
[
|
||||
mock.call(name, map_location=map_location),
|
||||
]
|
||||
)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
|
||||
name = "model.onnx"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = [checkpoint]
|
||||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
||||
)
|
||||
mock_torch.load.assert_has_calls(
|
||||
[
|
||||
mock.call(name, map_location=map_location),
|
||||
]
|
||||
)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger):
|
||||
name = "model.xyz"
|
||||
map_location = "cpu"
|
||||
checkpoint = MagicMock()
|
||||
mock_torch.load.side_effect = [checkpoint]
|
||||
|
||||
result = load_tensor(name, map_location)
|
||||
|
||||
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"unknown tensor type, falling back to PyTorch: %s", "xyz"
|
||||
)
|
||||
mock_torch.load.assert_has_calls(
|
||||
[
|
||||
mock.call(name, map_location=map_location),
|
||||
]
|
||||
)
|
||||
self.assertEqual(result, checkpoint)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
@patch("onnx_web.convert.utils.torch")
|
||||
def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger):
|
||||
name = "model"
|
||||
map_location = "cpu"
|
||||
mock_torch.load.side_effect = Exception()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
load_tensor(name, map_location)
|
||||
|
||||
|
||||
class FixDiffusionNameTests(unittest.TestCase):
|
||||
def test_fix_diffusion_name_with_valid_name(self):
|
||||
name = "diffusion-model"
|
||||
result = fix_diffusion_name(name)
|
||||
self.assertEqual(result, name)
|
||||
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
def test_fix_diffusion_name_with_invalid_name(self, logger):
|
||||
name = "model"
|
||||
expected_result = "diffusion-model"
|
||||
result = fix_diffusion_name(name)
|
||||
|
||||
self.assertEqual(result, expected_result)
|
||||
logger.warning.assert_called_once_with(
|
||||
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
|
||||
name,
|
||||
)
|
||||
|
||||
|
||||
class BuildCachePathsTests(unittest.TestCase):
|
||||
def test_build_cache_paths_without_format(self):
|
||||
name = "model.onnx"
|
||||
client = "client1"
|
||||
cache = "/path/to/cache"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client1", "model.onnx"),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
def test_build_cache_paths_with_format(self):
|
||||
name = "model"
|
||||
client = "client2"
|
||||
cache = "/path/to/cache"
|
||||
format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client2", "model.onnx"),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
def test_build_cache_paths_with_existing_extension(self):
|
||||
name = "model.pth"
|
||||
client = "client3"
|
||||
cache = "/path/to/cache"
|
||||
format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.pth"),
|
||||
path.join("/path/to/cache/client3", "model.pth"),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
def test_build_cache_paths_with_empty_extension(self):
|
||||
name = "model"
|
||||
client = "client4"
|
||||
cache = "/path/to/cache"
|
||||
format = "onnx"
|
||||
|
||||
conversion = ConversionContext(cache_path=cache)
|
||||
result = build_cache_paths(conversion, name, client, cache, format)
|
||||
|
||||
expected_paths = [
|
||||
path.join("/path/to/cache", "model.onnx"),
|
||||
path.join("/path/to/cache/client4", "model.onnx"),
|
||||
]
|
||||
self.assertEqual(result, expected_paths)
|
||||
|
||||
|
||||
class GetFirstExistsTests(unittest.TestCase):
|
||||
@patch("onnx_web.convert.utils.path")
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
def test_get_first_exists_with_existing_path(self, mock_logger, mock_path):
|
||||
paths = ["path1", "path2", "path3"]
|
||||
mock_path.exists.side_effect = [False, True, False]
|
||||
mock_path.return_value = MagicMock()
|
||||
|
||||
result = get_first_exists(paths)
|
||||
|
||||
mock_logger.debug.assert_called_once_with(
|
||||
"model already exists in cache, skipping fetch: %s", "path2"
|
||||
)
|
||||
self.assertEqual(result, "path2")
|
||||
|
||||
@patch("onnx_web.convert.utils.path")
|
||||
@patch("onnx_web.convert.utils.logger")
|
||||
def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path):
|
||||
paths = ["path1", "path2", "path3"]
|
||||
mock_path.exists.return_value = False
|
||||
|
||||
result = get_first_exists(paths)
|
||||
|
||||
mock_logger.debug.assert_not_called()
|
||||
self.assertIsNone(result)
|
||||
|
|
|
@ -78,7 +78,8 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
|
||||
output = Image.open("../outputs/test-txt2img-basic.png")
|
||||
|
||||
with Image.open("../outputs/test-txt2img-basic.png") as output:
|
||||
self.assertEqual(output.size, (256, 256))
|
||||
# TODO: test contents of image
|
||||
|
||||
|
@ -126,7 +127,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
|
||||
|
||||
output = Image.open("../outputs/test-txt2img-batch-0.png")
|
||||
with Image.open("../outputs/test-txt2img-batch-0.png") as output:
|
||||
self.assertEqual(output.size, (256, 256))
|
||||
# TODO: test contents of image
|
||||
|
||||
|
@ -172,7 +173,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
|
||||
output = Image.open("../outputs/test-txt2img-highres.png")
|
||||
with Image.open("../outputs/test-txt2img-highres.png") as output:
|
||||
self.assertEqual(output.size, (512, 512))
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
|
@ -219,7 +220,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
|
||||
|
||||
output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
|
||||
with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output:
|
||||
self.assertEqual(output.size, (512, 512))
|
||||
|
||||
|
||||
|
|
|
@ -7,20 +7,29 @@ from onnx_web.params import DeviceParams
|
|||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobStatus
|
||||
from onnx_web.worker.pool import DevicePoolExecutor
|
||||
from tests.helpers import test_device
|
||||
|
||||
TEST_JOIN_TIMEOUT = 0.2
|
||||
|
||||
lock = Event()
|
||||
|
||||
|
||||
def test_job(*args, **kwargs):
|
||||
def lock_job(*args, **kwargs):
|
||||
lock.wait()
|
||||
|
||||
|
||||
def wait_job(*args, **kwargs):
|
||||
def sleep_job(*args, **kwargs):
|
||||
sleep(0.5)
|
||||
|
||||
|
||||
def progress_job(worker, *args, **kwargs):
|
||||
worker.set_progress(1)
|
||||
|
||||
|
||||
def fail_job(*args, **kwargs):
|
||||
raise RuntimeError("job failed")
|
||||
|
||||
|
||||
class TestWorkerPool(unittest.TestCase):
|
||||
# lock: Optional[Event]
|
||||
pool: Optional[DevicePoolExecutor]
|
||||
|
@ -38,20 +47,20 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool.start()
|
||||
|
||||
def test_fake_worker(self):
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start()
|
||||
self.assertEqual(len(self.pool.workers), 1)
|
||||
|
||||
def test_cancel_pending(self):
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start()
|
||||
|
||||
self.pool.submit("test", "test", wait_job, lock=lock)
|
||||
self.pool.submit("test", "test", sleep_job, lock=lock)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
|
||||
self.assertTrue(self.pool.cancel("test"))
|
||||
|
@ -61,7 +70,7 @@ class TestWorkerPool(unittest.TestCase):
|
|||
pass
|
||||
|
||||
def test_next_device(self):
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start()
|
||||
|
@ -83,7 +92,7 @@ class TestWorkerPool(unittest.TestCase):
|
|||
"""
|
||||
TODO: flaky
|
||||
"""
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(
|
||||
|
@ -92,21 +101,21 @@ class TestWorkerPool(unittest.TestCase):
|
|||
|
||||
lock.clear()
|
||||
self.pool.start(lock)
|
||||
self.pool.submit("test", "test", test_job)
|
||||
self.pool.submit("test", "test", lock_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, _progress = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.RUNNING)
|
||||
|
||||
def test_done_pending(self):
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start(lock)
|
||||
|
||||
self.pool.submit("test1", "test", test_job)
|
||||
self.pool.submit("test2", "test", test_job)
|
||||
self.pool.submit("test1", "test", lock_job)
|
||||
self.pool.submit("test2", "test", lock_job)
|
||||
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
||||
|
||||
lock.set()
|
||||
|
@ -115,14 +124,14 @@ class TestWorkerPool(unittest.TestCase):
|
|||
"""
|
||||
TODO: flaky
|
||||
"""
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
self.pool.submit("test", "test", wait_job)
|
||||
self.pool.submit("test", "test", sleep_job)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
|
||||
sleep(5.0)
|
||||
|
@ -137,3 +146,38 @@ class TestWorkerPool(unittest.TestCase):
|
|||
|
||||
def test_running_status(self):
|
||||
pass
|
||||
|
||||
def test_progress_update(self):
|
||||
pass
|
||||
|
||||
def test_progress_finished(self):
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
|
||||
self.pool.submit("test", "test", progress_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, progress = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.SUCCESS)
|
||||
self.assertEqual(progress.steps.current, 1)
|
||||
|
||||
def test_progress_failed(self):
|
||||
device = test_device()
|
||||
server = ServerContext()
|
||||
|
||||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
|
||||
self.pool.submit("test", "test", fail_job)
|
||||
sleep(5.0)
|
||||
|
||||
status, progress = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.FAILED)
|
||||
self.assertEqual(progress.steps.current, 0)
|
||||
|
|
Loading…
Reference in New Issue