update tests
This commit is contained in:
parent
88f99ef6c2
commit
6b0b2e41a6
|
@ -107,7 +107,7 @@ def download_progress(source: str, dest: str):
|
||||||
stream=True,
|
stream=True,
|
||||||
allow_redirects=True,
|
allow_redirects=True,
|
||||||
headers={
|
headers={
|
||||||
"User-Agent": "onnx-web-api",
|
"User-Agent": "onnx-web-api", # TODO: add version
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if req.status_code != 200:
|
if req.status_code != 200:
|
||||||
|
@ -226,9 +226,7 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
logger.debug("loading tensor with Torch: %s", name)
|
logger.debug("loading tensor with Torch: %s", name)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception("error loading with Torch, trying with Torch JIT: %s", name)
|
||||||
"error loading with Torch JIT, trying with Torch JIT: %s", name
|
|
||||||
)
|
|
||||||
checkpoint = torch.jit.load(name)
|
checkpoint = torch.jit.load(name)
|
||||||
|
|
||||||
return checkpoint
|
return checkpoint
|
||||||
|
|
|
@ -81,6 +81,12 @@ class Size:
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return "%sx%s" % (self.width, self.height)
|
return "%sx%s" % (self.width, self.height)
|
||||||
|
|
||||||
|
def __eq__(self, other: Any) -> bool:
|
||||||
|
if isinstance(other, Size):
|
||||||
|
return self.width == other.width and self.height == other.height
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def add_border(self, border: Border):
|
def add_border(self, border: Border):
|
||||||
return Size(
|
return Size(
|
||||||
border.left + self.width + border.right,
|
border.left + self.width + border.right,
|
||||||
|
|
|
@ -22,6 +22,13 @@ class JobType(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class Progress:
|
class Progress:
|
||||||
|
"""
|
||||||
|
Generic counter with current and expected/final/total value. Can be used to count up or down.
|
||||||
|
|
||||||
|
Counter is considered "complete" when the current value is greater than or equal to the total value, and "empty"
|
||||||
|
when the current value is zero.
|
||||||
|
"""
|
||||||
|
|
||||||
current: int
|
current: int
|
||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.blend_denoise_localstd import BlendDenoiseLocalStdStage
|
||||||
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||||
|
from onnx_web.params import ImageParams, Size
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlendDenoiseLocalStdStage(unittest.TestCase):
|
||||||
|
def test_run(self):
|
||||||
|
# Create a dummy image
|
||||||
|
image = Image.new("RGB", (64, 64), color="white")
|
||||||
|
|
||||||
|
# Create a dummy StageResult object
|
||||||
|
sources = StageResult.from_images(
|
||||||
|
[image],
|
||||||
|
metadata=[
|
||||||
|
ImageMetadata(
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 5.0, 25, 0),
|
||||||
|
Size(64, 64),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create an instance of BlendDenoiseLocalStdStage
|
||||||
|
stage = BlendDenoiseLocalStdStage()
|
||||||
|
|
||||||
|
# Call the run method with dummy parameters
|
||||||
|
result = stage.run(
|
||||||
|
_worker=None,
|
||||||
|
_server=None,
|
||||||
|
_stage=None,
|
||||||
|
_params=None,
|
||||||
|
sources=sources,
|
||||||
|
strength=5,
|
||||||
|
range=4,
|
||||||
|
stage_source=None,
|
||||||
|
callback=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that the result is an instance of StageResult
|
||||||
|
self.assertIsInstance(result, StageResult)
|
||||||
|
|
||||||
|
# Assert that the result contains the denoised image
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result.size(), Size(64, 64))
|
||||||
|
|
||||||
|
# Assert that the metadata is preserved
|
||||||
|
self.assertEqual(result.metadata, sources.metadata)
|
|
@ -3,7 +3,8 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_grid import BlendGridStage
|
from onnx_web.chain.blend_grid import BlendGridStage
|
||||||
from onnx_web.chain.result import StageResult
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||||
|
from onnx_web.params import ImageParams, Size
|
||||||
|
|
||||||
|
|
||||||
class BlendGridStageTests(unittest.TestCase):
|
class BlendGridStageTests(unittest.TestCase):
|
||||||
|
@ -15,9 +16,17 @@ class BlendGridStageTests(unittest.TestCase):
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "white"),
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "white"),
|
||||||
|
],
|
||||||
|
metadata=[
|
||||||
|
ImageMetadata(
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||||
|
Size(64, 64),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
* 4,
|
||||||
)
|
)
|
||||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 5)
|
self.assertEqual(len(result), 5)
|
||||||
self.assertEqual(result.as_image()[-1].getpixel((0, 0)), (0, 0, 0))
|
self.assertEqual(result.as_images()[-1].getpixel((0, 0)), (0, 0, 0))
|
||||||
|
|
|
@ -3,8 +3,8 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||||
from onnx_web.chain.result import StageResult
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||||
from onnx_web.params import ImageParams
|
from onnx_web.params import ImageParams, Size
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||||
|
@ -39,9 +39,16 @@ class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
sources = StageResult(
|
sources = StageResult(
|
||||||
images=[
|
images=[
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
]
|
],
|
||||||
|
metadata=[
|
||||||
|
ImageMetadata(
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||||
|
Size(64, 64),
|
||||||
|
),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
|
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (0, 0, 0))
|
||||||
|
|
|
@ -3,7 +3,8 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||||
from onnx_web.chain.result import StageResult
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
||||||
|
from onnx_web.params import ImageParams, Size
|
||||||
|
|
||||||
|
|
||||||
class BlendLinearStageTests(unittest.TestCase):
|
class BlendLinearStageTests(unittest.TestCase):
|
||||||
|
@ -12,12 +13,19 @@ class BlendLinearStageTests(unittest.TestCase):
|
||||||
sources = StageResult(
|
sources = StageResult(
|
||||||
images=[
|
images=[
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
]
|
],
|
||||||
|
metadata=[
|
||||||
|
ImageMetadata(
|
||||||
|
ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
|
||||||
|
Size(64, 64),
|
||||||
|
),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
stage_source = Image.new("RGB", (64, 64), "white")
|
stage_source = Image.new("RGB", (64, 64), "white")
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
None, None, None, None, sources, alpha=0.5, stage_source=stage_source
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (127, 127, 127))
|
self.assertEqual(result.as_images()[0].getpixel((0, 0)), (127, 127, 127))
|
||||||
|
|
|
@ -23,5 +23,6 @@ class BlendMaskStageTests(unittest.TestCase):
|
||||||
stage_source=Image.new("RGBA", (64, 64)),
|
stage_source=Image.new("RGBA", (64, 64)),
|
||||||
dims=(0, 0, SizeChart.auto),
|
dims=(0, 0, SizeChart.auto),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -42,5 +42,6 @@ class CorrectCodeformerStageTests(unittest.TestCase):
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams(""),
|
upscale=UpscaleParams(""),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -6,13 +6,14 @@ from onnx_web.params import HighresParams, UpscaleParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.server.hacks import apply_patches
|
from onnx_web.server.hacks import apply_patches
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from tests.helpers import test_device, test_needs_onnx_models
|
from tests.helpers import test_device, test_needs_models
|
||||||
|
|
||||||
TEST_MODEL = "../models/correction-gfpgan-v1-3"
|
TEST_MODEL_NAME = "correction-gfpgan-v1-3"
|
||||||
|
TEST_MODEL = f"../models/.cache/{TEST_MODEL_NAME}.pth"
|
||||||
|
|
||||||
|
|
||||||
class CorrectGFPGANStageTests(unittest.TestCase):
|
class CorrectGFPGANStageTests(unittest.TestCase):
|
||||||
@test_needs_onnx_models([TEST_MODEL])
|
@test_needs_models([TEST_MODEL])
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
server = ServerContext(model_path="../models", output_path="../outputs")
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
|
@ -33,12 +34,13 @@ class CorrectGFPGANStageTests(unittest.TestCase):
|
||||||
sources = StageResult.empty()
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
worker,
|
worker,
|
||||||
None,
|
server,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
sources,
|
sources,
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams(TEST_MODEL),
|
upscale=UpscaleParams(TEST_MODEL_NAME, TEST_MODEL_NAME),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -20,5 +20,6 @@ class ReduceCropStageTests(unittest.TestCase):
|
||||||
origin=Size(0, 0),
|
origin=Size(0, 0),
|
||||||
size=Size(128, 128),
|
size=Size(128, 128),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -24,5 +24,6 @@ class ReduceThumbnailStageTests(unittest.TestCase):
|
||||||
size=Size(128, 128),
|
size=Size(128, 128),
|
||||||
stage_source=stage_source,
|
stage_source=stage_source,
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -22,5 +22,6 @@ class SourceNoiseStageTests(unittest.TestCase):
|
||||||
size=Size(128, 128),
|
size=Size(128, 128),
|
||||||
noise_source=noise_source_fill_edge,
|
noise_source=noise_source_fill_edge,
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -22,5 +22,6 @@ class SourceS3StageTests(unittest.TestCase):
|
||||||
bucket="test",
|
bucket="test",
|
||||||
source_keys=[],
|
source_keys=[],
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -21,5 +21,6 @@ class SourceURLStageTests(unittest.TestCase):
|
||||||
size=Size(128, 128),
|
size=Size(128, 128),
|
||||||
source_urls=[],
|
source_urls=[],
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -37,5 +37,6 @@ class UpscaleBSRGANStageTests(unittest.TestCase):
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams(TEST_MODEL),
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -18,5 +18,6 @@ class UpscaleHighresStageTests(unittest.TestCase):
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams(""),
|
upscale=UpscaleParams(""),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -46,5 +46,6 @@ class UpscaleOutpaintStageTests(unittest.TestCase):
|
||||||
dims=(),
|
dims=(),
|
||||||
tile_mask=Image.new("RGB", (64, 64)),
|
tile_mask=Image.new("RGB", (64, 64)),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -35,5 +35,6 @@ class UpscaleRealESRGANStageTests(unittest.TestCase):
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
|
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -35,5 +35,6 @@ class UpscaleSwinIRStageTests(unittest.TestCase):
|
||||||
highres=HighresParams(False, 1, 0, 0),
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
upscale=UpscaleParams(TEST_MODEL),
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
)
|
)
|
||||||
|
result.validate()
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from os import path
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from onnx_web.convert.utils import (
|
from onnx_web.convert.utils import (
|
||||||
DEFAULT_OPSET,
|
DEFAULT_OPSET,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
|
build_cache_paths,
|
||||||
download_progress,
|
download_progress,
|
||||||
|
fix_diffusion_name,
|
||||||
|
get_first_exists,
|
||||||
|
load_tensor,
|
||||||
|
load_torch,
|
||||||
remove_prefix,
|
remove_prefix,
|
||||||
resolve_tensor,
|
resolve_tensor,
|
||||||
source_format,
|
source_format,
|
||||||
|
@ -30,6 +38,32 @@ class DownloadProgressTests(unittest.TestCase):
|
||||||
path = download_progress("https://example.com", "/tmp/example-dot-com")
|
path = download_progress("https://example.com", "/tmp/example-dot-com")
|
||||||
self.assertEqual(path, "/tmp/example-dot-com")
|
self.assertEqual(path, "/tmp/example-dot-com")
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.Path")
|
||||||
|
@patch("onnx_web.convert.utils.requests")
|
||||||
|
@patch("onnx_web.convert.utils.shutil")
|
||||||
|
@patch("onnx_web.convert.utils.tqdm")
|
||||||
|
def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path):
|
||||||
|
source = "http://example.com/image.jpg"
|
||||||
|
dest = "/path/to/destination/image.jpg"
|
||||||
|
|
||||||
|
dest_path_mock = MagicMock()
|
||||||
|
mock_path.return_value.expanduser.return_value.resolve.return_value = (
|
||||||
|
dest_path_mock
|
||||||
|
)
|
||||||
|
dest_path_mock.exists.return_value = False
|
||||||
|
dest_path_mock.absolute.return_value = "test"
|
||||||
|
mock_requests.get.return_value.status_code = 200
|
||||||
|
mock_requests.get.return_value.headers.get.return_value = "1000"
|
||||||
|
mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock()
|
||||||
|
|
||||||
|
result = download_progress(source, dest)
|
||||||
|
|
||||||
|
mock_path.assert_called_once_with(dest)
|
||||||
|
dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
|
||||||
|
dest_path_mock.open.assert_called_once_with("wb")
|
||||||
|
mock_shutil.copyfileobj.assert_called_once()
|
||||||
|
self.assertEqual(result, str(dest_path_mock.absolute.return_value))
|
||||||
|
|
||||||
|
|
||||||
class TupleToSourceTests(unittest.TestCase):
|
class TupleToSourceTests(unittest.TestCase):
|
||||||
def test_basic_tuple(self):
|
def test_basic_tuple(self):
|
||||||
|
@ -221,14 +255,6 @@ class RemovePrefixTests(unittest.TestCase):
|
||||||
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
|
||||||
|
|
||||||
|
|
||||||
class LoadTorchTests(unittest.TestCase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LoadTensorTests(unittest.TestCase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ResolveTensorTests(unittest.TestCase):
|
class ResolveTensorTests(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
@test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
|
||||||
def test_resolve_existing(self):
|
def test_resolve_existing(self):
|
||||||
|
@ -239,3 +265,251 @@ class ResolveTensorTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_resolve_missing(self):
|
def test_resolve_missing(self):
|
||||||
self.assertIsNone(resolve_tensor("missing"))
|
self.assertIsNone(resolve_tensor("missing"))
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTorchTests(unittest.TestCase):
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_torch_with_torch_load(self, mock_torch, mock_logger):
|
||||||
|
name = "model.pth"
|
||||||
|
map_location = "cpu"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_torch.load.return_value = checkpoint
|
||||||
|
|
||||||
|
result = load_torch(name, map_location)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||||
|
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_torch_with_torch_jit_load(self, mock_torch, mock_logger):
|
||||||
|
name = "model.pth"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_torch.load.side_effect = Exception()
|
||||||
|
mock_torch.jit.load.return_value = checkpoint
|
||||||
|
|
||||||
|
result = load_torch(name)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_called_once_with("loading tensor with Torch: %s", name)
|
||||||
|
mock_logger.exception.assert_called_once_with(
|
||||||
|
"error loading with Torch, trying with Torch JIT: %s", name
|
||||||
|
)
|
||||||
|
mock_torch.jit.load.assert_called_once_with(name)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
class LoadTensorTests(unittest.TestCase):
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.path")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_tensor_with_no_extension(self, mock_torch, mock_path, mock_logger):
|
||||||
|
name = "model"
|
||||||
|
map_location = "cpu"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_path.exists.return_value = True
|
||||||
|
mock_path.splitext.side_effect = [("model", ""), ("model", ".safetensors")]
|
||||||
|
mock_torch.load.return_value = checkpoint
|
||||||
|
|
||||||
|
result = load_tensor(name, map_location)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||||
|
mock_path.splitext.assert_called_once_with(name)
|
||||||
|
mock_path.exists.assert_called_once_with(name)
|
||||||
|
mock_torch.load.assert_called_once_with(name, map_location=map_location)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.environ")
|
||||||
|
@patch("onnx_web.convert.utils.safetensors")
|
||||||
|
def test_load_tensor_with_safetensors_extension(
|
||||||
|
self, mock_safetensors, mock_environ, mock_logger
|
||||||
|
):
|
||||||
|
name = "model.safetensors"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_environ.__getitem__.return_value = "1"
|
||||||
|
mock_safetensors.torch.load_file.return_value = checkpoint
|
||||||
|
|
||||||
|
result = load_tensor(name)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||||
|
mock_safetensors.torch.load_file.assert_called_once_with(name, device="cpu")
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_tensor_with_pickle_extension(self, mock_torch, mock_logger):
|
||||||
|
name = "model.pt"
|
||||||
|
map_location = "cpu"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_torch.load.side_effect = [checkpoint]
|
||||||
|
|
||||||
|
result = load_tensor(name, map_location)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||||
|
mock_torch.load.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call(name, map_location=map_location),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_tensor_with_onnx_extension(self, mock_torch, mock_logger):
|
||||||
|
name = "model.onnx"
|
||||||
|
map_location = "cpu"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_torch.load.side_effect = [checkpoint]
|
||||||
|
|
||||||
|
result = load_tensor(name, map_location)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||||
|
mock_logger.warning.assert_called_once_with(
|
||||||
|
"tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
|
||||||
|
)
|
||||||
|
mock_torch.load.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call(name, map_location=map_location),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_tensor_with_unknown_extension(self, mock_torch, mock_logger):
|
||||||
|
name = "model.xyz"
|
||||||
|
map_location = "cpu"
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
mock_torch.load.side_effect = [checkpoint]
|
||||||
|
|
||||||
|
result = load_tensor(name, map_location)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_has_calls([mock.call("loading tensor: %s", name)])
|
||||||
|
mock_logger.warning.assert_called_once_with(
|
||||||
|
"unknown tensor type, falling back to PyTorch: %s", "xyz"
|
||||||
|
)
|
||||||
|
mock_torch.load.assert_has_calls(
|
||||||
|
[
|
||||||
|
mock.call(name, map_location=map_location),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(result, checkpoint)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
@patch("onnx_web.convert.utils.torch")
|
||||||
|
def test_load_tensor_with_error_loading_tensor(self, mock_torch, mock_logger):
|
||||||
|
name = "model"
|
||||||
|
map_location = "cpu"
|
||||||
|
mock_torch.load.side_effect = Exception()
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
load_tensor(name, map_location)
|
||||||
|
|
||||||
|
|
||||||
|
class FixDiffusionNameTests(unittest.TestCase):
|
||||||
|
def test_fix_diffusion_name_with_valid_name(self):
|
||||||
|
name = "diffusion-model"
|
||||||
|
result = fix_diffusion_name(name)
|
||||||
|
self.assertEqual(result, name)
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
def test_fix_diffusion_name_with_invalid_name(self, logger):
|
||||||
|
name = "model"
|
||||||
|
expected_result = "diffusion-model"
|
||||||
|
result = fix_diffusion_name(name)
|
||||||
|
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
logger.warning.assert_called_once_with(
|
||||||
|
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BuildCachePathsTests(unittest.TestCase):
|
||||||
|
def test_build_cache_paths_without_format(self):
|
||||||
|
name = "model.onnx"
|
||||||
|
client = "client1"
|
||||||
|
cache = "/path/to/cache"
|
||||||
|
|
||||||
|
conversion = ConversionContext(cache_path=cache)
|
||||||
|
result = build_cache_paths(conversion, name, client, cache)
|
||||||
|
|
||||||
|
expected_paths = [
|
||||||
|
path.join("/path/to/cache", "model.onnx"),
|
||||||
|
path.join("/path/to/cache/client1", "model.onnx"),
|
||||||
|
]
|
||||||
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
def test_build_cache_paths_with_format(self):
|
||||||
|
name = "model"
|
||||||
|
client = "client2"
|
||||||
|
cache = "/path/to/cache"
|
||||||
|
format = "onnx"
|
||||||
|
|
||||||
|
conversion = ConversionContext(cache_path=cache)
|
||||||
|
result = build_cache_paths(conversion, name, client, cache, format)
|
||||||
|
|
||||||
|
expected_paths = [
|
||||||
|
path.join("/path/to/cache", "model.onnx"),
|
||||||
|
path.join("/path/to/cache/client2", "model.onnx"),
|
||||||
|
]
|
||||||
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
def test_build_cache_paths_with_existing_extension(self):
|
||||||
|
name = "model.pth"
|
||||||
|
client = "client3"
|
||||||
|
cache = "/path/to/cache"
|
||||||
|
format = "onnx"
|
||||||
|
|
||||||
|
conversion = ConversionContext(cache_path=cache)
|
||||||
|
result = build_cache_paths(conversion, name, client, cache, format)
|
||||||
|
|
||||||
|
expected_paths = [
|
||||||
|
path.join("/path/to/cache", "model.pth"),
|
||||||
|
path.join("/path/to/cache/client3", "model.pth"),
|
||||||
|
]
|
||||||
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
def test_build_cache_paths_with_empty_extension(self):
|
||||||
|
name = "model"
|
||||||
|
client = "client4"
|
||||||
|
cache = "/path/to/cache"
|
||||||
|
format = "onnx"
|
||||||
|
|
||||||
|
conversion = ConversionContext(cache_path=cache)
|
||||||
|
result = build_cache_paths(conversion, name, client, cache, format)
|
||||||
|
|
||||||
|
expected_paths = [
|
||||||
|
path.join("/path/to/cache", "model.onnx"),
|
||||||
|
path.join("/path/to/cache/client4", "model.onnx"),
|
||||||
|
]
|
||||||
|
self.assertEqual(result, expected_paths)
|
||||||
|
|
||||||
|
|
||||||
|
class GetFirstExistsTests(unittest.TestCase):
|
||||||
|
@patch("onnx_web.convert.utils.path")
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
def test_get_first_exists_with_existing_path(self, mock_logger, mock_path):
|
||||||
|
paths = ["path1", "path2", "path3"]
|
||||||
|
mock_path.exists.side_effect = [False, True, False]
|
||||||
|
mock_path.return_value = MagicMock()
|
||||||
|
|
||||||
|
result = get_first_exists(paths)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_called_once_with(
|
||||||
|
"model already exists in cache, skipping fetch: %s", "path2"
|
||||||
|
)
|
||||||
|
self.assertEqual(result, "path2")
|
||||||
|
|
||||||
|
@patch("onnx_web.convert.utils.path")
|
||||||
|
@patch("onnx_web.convert.utils.logger")
|
||||||
|
def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path):
|
||||||
|
paths = ["path1", "path2", "path3"]
|
||||||
|
mock_path.exists.return_value = False
|
||||||
|
|
||||||
|
result = get_first_exists(paths)
|
||||||
|
|
||||||
|
mock_logger.debug.assert_not_called()
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
|
@ -78,7 +78,8 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png"))
|
||||||
output = Image.open("../outputs/test-txt2img-basic.png")
|
|
||||||
|
with Image.open("../outputs/test-txt2img-basic.png") as output:
|
||||||
self.assertEqual(output.size, (256, 256))
|
self.assertEqual(output.size, (256, 256))
|
||||||
# TODO: test contents of image
|
# TODO: test contents of image
|
||||||
|
|
||||||
|
@ -126,7 +127,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
|
||||||
|
|
||||||
output = Image.open("../outputs/test-txt2img-batch-0.png")
|
with Image.open("../outputs/test-txt2img-batch-0.png") as output:
|
||||||
self.assertEqual(output.size, (256, 256))
|
self.assertEqual(output.size, (256, 256))
|
||||||
# TODO: test contents of image
|
# TODO: test contents of image
|
||||||
|
|
||||||
|
@ -172,7 +173,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png"))
|
||||||
output = Image.open("../outputs/test-txt2img-highres.png")
|
with Image.open("../outputs/test-txt2img-highres.png") as output:
|
||||||
self.assertEqual(output.size, (512, 512))
|
self.assertEqual(output.size, (512, 512))
|
||||||
|
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
@ -219,7 +220,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
|
||||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
|
||||||
|
|
||||||
output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
|
with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output:
|
||||||
self.assertEqual(output.size, (512, 512))
|
self.assertEqual(output.size, (512, 512))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -7,20 +7,29 @@ from onnx_web.params import DeviceParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.command import JobStatus
|
from onnx_web.worker.command import JobStatus
|
||||||
from onnx_web.worker.pool import DevicePoolExecutor
|
from onnx_web.worker.pool import DevicePoolExecutor
|
||||||
|
from tests.helpers import test_device
|
||||||
|
|
||||||
TEST_JOIN_TIMEOUT = 0.2
|
TEST_JOIN_TIMEOUT = 0.2
|
||||||
|
|
||||||
lock = Event()
|
lock = Event()
|
||||||
|
|
||||||
|
|
||||||
def test_job(*args, **kwargs):
|
def lock_job(*args, **kwargs):
|
||||||
lock.wait()
|
lock.wait()
|
||||||
|
|
||||||
|
|
||||||
def wait_job(*args, **kwargs):
|
def sleep_job(*args, **kwargs):
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def progress_job(worker, *args, **kwargs):
|
||||||
|
worker.set_progress(1)
|
||||||
|
|
||||||
|
|
||||||
|
def fail_job(*args, **kwargs):
|
||||||
|
raise RuntimeError("job failed")
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerPool(unittest.TestCase):
|
class TestWorkerPool(unittest.TestCase):
|
||||||
# lock: Optional[Event]
|
# lock: Optional[Event]
|
||||||
pool: Optional[DevicePoolExecutor]
|
pool: Optional[DevicePoolExecutor]
|
||||||
|
@ -38,20 +47,20 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
def test_fake_worker(self):
|
def test_fake_worker(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
self.assertEqual(len(self.pool.workers), 1)
|
self.assertEqual(len(self.pool.workers), 1)
|
||||||
|
|
||||||
def test_cancel_pending(self):
|
def test_cancel_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
self.pool.submit("test", "test", wait_job, lock=lock)
|
self.pool.submit("test", "test", sleep_job, lock=lock)
|
||||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
self.assertTrue(self.pool.cancel("test"))
|
self.assertTrue(self.pool.cancel("test"))
|
||||||
|
@ -61,7 +70,7 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_next_device(self):
|
def test_next_device(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
@ -83,7 +92,7 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
TODO: flaky
|
TODO: flaky
|
||||||
"""
|
"""
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(
|
self.pool = DevicePoolExecutor(
|
||||||
|
@ -92,21 +101,21 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
|
|
||||||
lock.clear()
|
lock.clear()
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
self.pool.submit("test", "test", test_job)
|
self.pool.submit("test", "test", lock_job)
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
|
||||||
status, _progress = self.pool.status("test")
|
status, _progress = self.pool.status("test")
|
||||||
self.assertEqual(status, JobStatus.RUNNING)
|
self.assertEqual(status, JobStatus.RUNNING)
|
||||||
|
|
||||||
def test_done_pending(self):
|
def test_done_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
|
|
||||||
self.pool.submit("test1", "test", test_job)
|
self.pool.submit("test1", "test", lock_job)
|
||||||
self.pool.submit("test2", "test", test_job)
|
self.pool.submit("test2", "test", lock_job)
|
||||||
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
lock.set()
|
lock.set()
|
||||||
|
@ -115,14 +124,14 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
TODO: flaky
|
TODO: flaky
|
||||||
"""
|
"""
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = test_device()
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(
|
self.pool = DevicePoolExecutor(
|
||||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
)
|
)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
self.pool.submit("test", "test", wait_job)
|
self.pool.submit("test", "test", sleep_job)
|
||||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
@ -137,3 +146,38 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
|
|
||||||
def test_running_status(self):
|
def test_running_status(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def test_progress_update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_progress_finished(self):
|
||||||
|
device = test_device()
|
||||||
|
server = ServerContext()
|
||||||
|
|
||||||
|
self.pool = DevicePoolExecutor(
|
||||||
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
|
)
|
||||||
|
self.pool.start()
|
||||||
|
|
||||||
|
self.pool.submit("test", "test", progress_job)
|
||||||
|
sleep(5.0)
|
||||||
|
|
||||||
|
status, progress = self.pool.status("test")
|
||||||
|
self.assertEqual(status, JobStatus.SUCCESS)
|
||||||
|
self.assertEqual(progress.steps.current, 1)
|
||||||
|
|
||||||
|
def test_progress_failed(self):
|
||||||
|
device = test_device()
|
||||||
|
server = ServerContext()
|
||||||
|
|
||||||
|
self.pool = DevicePoolExecutor(
|
||||||
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
|
)
|
||||||
|
self.pool.start()
|
||||||
|
|
||||||
|
self.pool.submit("test", "test", fail_job)
|
||||||
|
sleep(5.0)
|
||||||
|
|
||||||
|
status, progress = self.pool.status("test")
|
||||||
|
self.assertEqual(status, JobStatus.FAILED)
|
||||||
|
self.assertEqual(progress.steps.current, 0)
|
||||||
|
|
Loading…
Reference in New Issue