523 lines
18 KiB
Python
523 lines
18 KiB
Python
import unittest
|
|
from os import path
|
|
from unittest import mock
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from onnx_web.constants import ONNX_MODEL
|
|
from onnx_web.convert.utils import (
|
|
DEFAULT_OPSET,
|
|
ConversionContext,
|
|
build_cache_paths,
|
|
download_progress,
|
|
fix_diffusion_name,
|
|
get_first_exists,
|
|
load_tensor,
|
|
load_torch,
|
|
remove_prefix,
|
|
resolve_tensor,
|
|
source_format,
|
|
tuple_to_correction,
|
|
tuple_to_diffusion,
|
|
tuple_to_source,
|
|
tuple_to_upscaling,
|
|
)
|
|
from tests.helpers import TEST_MODEL_UPSCALING_SWINIR, test_needs_models
|
|
|
|
|
|
class ConversionContextTests(unittest.TestCase):
    """Checks for a ConversionContext built from the process environment."""

    def test_from_environ(self):
        # the environment-derived context is expected to use DEFAULT_OPSET
        ctx = ConversionContext.from_environ()
        self.assertEqual(DEFAULT_OPSET, ctx.opset)

    def test_map_location(self):
        # NOTE(review): presumably the test environment has no GPU configured,
        # so tensors are mapped onto the CPU
        ctx = ConversionContext.from_environ()
        device_type = ctx.map_location.type
        self.assertEqual(device_type, "cpu")
|
|
|
|
|
|
class DownloadProgressTests(unittest.TestCase):
    """Tests for download_progress."""

    def test_download_example(self):
        # NOTE(review): nothing is mocked here, so this presumably performs a
        # real download from example.com into /tmp and needs network access.
        # The return value is stored in ``result`` rather than ``path`` so it
        # does not shadow the ``os.path`` module imported by this file.
        result = download_progress("https://example.com", "/tmp/example-dot-com")
        self.assertEqual(result, "/tmp/example-dot-com")

    @patch("onnx_web.convert.utils.Path")
    @patch("onnx_web.convert.utils.requests")
    @patch("onnx_web.convert.utils.shutil")
    @patch("onnx_web.convert.utils.tqdm")
    def test_download_progress(self, mock_tqdm, mock_shutil, mock_requests, mock_path):
        """Download to a destination that does not exist yet.

        HTTP, the filesystem path and the progress bar are all mocked.
        """
        source = "http://example.com/image.jpg"
        dest = "/path/to/destination/image.jpg"

        # Path(dest).expanduser().resolve() yields this mock destination
        dest_path_mock = MagicMock()
        mock_path.return_value.expanduser.return_value.resolve.return_value = (
            dest_path_mock
        )
        dest_path_mock.exists.return_value = False
        dest_path_mock.absolute.return_value = "test"

        # successful response advertising a 1000-byte body
        mock_requests.get.return_value.status_code = 200
        mock_requests.get.return_value.headers.get.return_value = "1000"
        mock_tqdm.wrapattr.return_value.__enter__.return_value = MagicMock()

        result = download_progress(source, dest)

        mock_path.assert_called_once_with(dest)
        dest_path_mock.parent.mkdir.assert_called_once_with(parents=True, exist_ok=True)
        dest_path_mock.open.assert_called_once_with("wb")
        mock_shutil.copyfileobj.assert_called_once()
        self.assertEqual(result, str(dest_path_mock.absolute.return_value))
|
|
|
|
|
|
class TupleToSourceTests(unittest.TestCase):
    """Tests for normalizing legacy (name, source) models with tuple_to_source."""

    def _check_name_and_source(self, model):
        """Assert the two positional entries became the name and source keys."""
        self.assertEqual(model["name"], "foo")
        self.assertEqual(model["source"], "bar")

    def test_basic_tuple(self):
        self._check_name_and_source(tuple_to_source(("foo", "bar")))

    def test_basic_list(self):
        self._check_name_and_source(tuple_to_source(["foo", "bar"]))

    def test_basic_dict(self):
        model = tuple_to_source(["foo", "bar"])
        model["bin"] = "bin"

        # a dict input, extra keys included, should come back unchanged
        again = tuple_to_source(model)

        self.assertEqual(model, again)
        self.assertIn("bin", again)
|
|
|
|
|
|
class TupleToCorrectionTests(unittest.TestCase):
    """Tests for tuple_to_correction, which normalizes legacy correction models."""

    def test_basic_tuple(self):
        source = tuple_to_correction(("foo", "bar"))
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_basic_list(self):
        source = tuple_to_correction(["foo", "bar"])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_basic_dict(self):
        source = tuple_to_correction(["foo", "bar"])
        source["bin"] = "bin"

        # make sure a dict is returned with its extra fields; this must go
        # through tuple_to_correction (not tuple_to_source) so the function
        # under test is the one exercised
        second = tuple_to_correction(source)

        self.assertEqual(source, second)
        self.assertIn("bin", second)

    def test_scale_tuple(self):
        source = tuple_to_correction(["foo", "bar", 2])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_half_tuple(self):
        # NOTE(review): per test_all_tuple the third entry is the scale, so this
        # input probably does not set "half"; only name/source are checked
        source = tuple_to_correction(["foo", "bar", True])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_opset_tuple(self):
        # NOTE(review): per test_all_tuple the third entry is the scale, so this
        # input probably does not set "opset"; only name/source are checked
        source = tuple_to_correction(["foo", "bar", 14])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_all_tuple(self):
        source = tuple_to_correction(["foo", "bar", 2, True, 14])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")
        # trailing positions map to scale, half and opset, in that order
        self.assertEqual(source["scale"], 2)
        self.assertEqual(source["half"], True)
        self.assertEqual(source["opset"], 14)
|
|
|
|
|
|
class TupleToDiffusionTests(unittest.TestCase):
    """Tests for converting legacy diffusion model tuples/lists into dicts."""

    def _assert_foo_bar(self, model):
        """Assert the first two positional entries became name and source."""
        self.assertEqual(model["name"], "foo")
        self.assertEqual(model["source"], "bar")

    def test_basic_tuple(self):
        self._assert_foo_bar(tuple_to_diffusion(("foo", "bar")))

    def test_basic_list(self):
        self._assert_foo_bar(tuple_to_diffusion(["foo", "bar"]))

    def test_basic_dict(self):
        converted = tuple_to_diffusion(["foo", "bar"])
        converted["bin"] = "bin"

        # passing the dict back in must keep the extra key intact
        round_trip = tuple_to_diffusion(converted)

        self.assertEqual(converted, round_trip)
        self.assertIn("bin", round_trip)

    def test_single_vae_tuple(self):
        self._assert_foo_bar(tuple_to_diffusion(["foo", "bar", True]))

    def test_half_tuple(self):
        # NOTE(review): identical input to test_single_vae_tuple
        self._assert_foo_bar(tuple_to_diffusion(["foo", "bar", True]))

    def test_opset_tuple(self):
        self._assert_foo_bar(tuple_to_diffusion(["foo", "bar", 14]))

    def test_all_tuple(self):
        model = tuple_to_diffusion(["foo", "bar", True, True, 14])

        self._assert_foo_bar(model)
        # trailing positions map to single_vae, half and opset, in that order
        self.assertEqual(model["single_vae"], True)
        self.assertEqual(model["half"], True)
        self.assertEqual(model["opset"], 14)
|
|
|
|
|
|
class TupleToUpscalingTests(unittest.TestCase):
    """Tests for tuple_to_upscaling, which normalizes legacy upscaling models."""

    def test_basic_tuple(self):
        source = tuple_to_upscaling(("foo", "bar"))
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_basic_list(self):
        source = tuple_to_upscaling(["foo", "bar"])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_basic_dict(self):
        source = tuple_to_upscaling(["foo", "bar"])
        source["bin"] = "bin"

        # make sure a dict is returned with its extra fields; this must go
        # through tuple_to_upscaling (not tuple_to_source) so the function
        # under test is the one exercised
        second = tuple_to_upscaling(source)

        self.assertEqual(source, second)
        self.assertIn("bin", second)

    def test_scale_tuple(self):
        source = tuple_to_upscaling(["foo", "bar", 2])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_half_tuple(self):
        # NOTE(review): per test_all_tuple the third entry is the scale, so this
        # input probably does not set "half"; only name/source are checked
        source = tuple_to_upscaling(["foo", "bar", True])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_opset_tuple(self):
        # NOTE(review): per test_all_tuple the third entry is the scale, so this
        # input probably does not set "opset"; only name/source are checked
        source = tuple_to_upscaling(["foo", "bar", 14])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")

    def test_all_tuple(self):
        source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
        self.assertEqual(source["name"], "foo")
        self.assertEqual(source["source"], "bar")
        # trailing positions map to scale, half and opset, in that order
        self.assertEqual(source["scale"], 2)
        self.assertEqual(source["half"], True)
        self.assertEqual(source["opset"], 14)
|
|
|
|
|
|
class SourceFormatTests(unittest.TestCase):
    """Tests for source_format.

    The format comes from an explicit "format" key or is inferred from the
    "source" file extension.
    """

    def test_with_format(self):
        self.assertEqual(source_format({"format": "foo"}), "foo")

    def test_source_known_extension(self):
        inferred = source_format({"source": "foo.safetensors"})
        self.assertEqual(inferred, "safetensors")

    def test_source_unknown_extension(self):
        inferred = source_format({"source": "foo.none"})
        self.assertEqual(inferred, None)

    def test_incomplete_model(self):
        # neither "format" nor "source" is present
        empty_model = {}
        self.assertIsNone(source_format(empty_model))
|
|
|
|
|
|
class RemovePrefixTests(unittest.TestCase):
    """Tests for remove_prefix."""

    def test_with_prefix(self):
        stripped = remove_prefix("foo.bar", "foo")
        self.assertEqual(stripped, ".bar")

    def test_without_prefix(self):
        # a prefix that does not match leaves the input untouched
        unchanged = remove_prefix("foo.bar", "bin")
        self.assertEqual(unchanged, "foo.bar")
|
|
|
|
|
|
class ResolveTensorTests(unittest.TestCase):
    """Tests for resolve_tensor's lookup of tensor files on disk."""

    @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
    def test_resolve_existing(self):
        # NOTE(review): presumably test_needs_models guards this on the SwinIR
        # model having been downloaded into the cache
        resolved = resolve_tensor("../models/.cache/upscaling-swinir")
        self.assertEqual(resolved, TEST_MODEL_UPSCALING_SWINIR)

    def test_resolve_missing(self):
        unresolved = resolve_tensor("missing")
        self.assertIsNone(unresolved)
|
|
|
|
|
|
# Fake PyTorch checkpoint filename shared by several tests in this module.
TORCH_MODEL = "model.pth"


class LoadTorchTests(unittest.TestCase):
    """Tests for load_torch: torch.load first, then torch.jit.load if that fails."""

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_torch_with_torch_load(self, torch_mock, logger_mock):
        expected = MagicMock()
        torch_mock.load.return_value = expected

        actual = load_torch(TORCH_MODEL, "cpu")

        logger_mock.debug.assert_called_once_with(
            "loading tensor with Torch: %s", TORCH_MODEL
        )
        torch_mock.load.assert_called_once_with(TORCH_MODEL, map_location="cpu")
        self.assertEqual(actual, expected)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_torch_with_torch_jit_load(self, torch_mock, logger_mock):
        expected = MagicMock()
        # make the primary loader fail so the JIT fallback is exercised
        torch_mock.load.side_effect = Exception()
        torch_mock.jit.load.return_value = expected

        actual = load_torch(TORCH_MODEL)

        logger_mock.debug.assert_called_once_with(
            "loading tensor with Torch: %s", TORCH_MODEL
        )
        logger_mock.exception.assert_called_once_with(
            "error loading with Torch, trying with Torch JIT: %s", TORCH_MODEL
        )
        torch_mock.jit.load.assert_called_once_with(TORCH_MODEL)
        self.assertEqual(actual, expected)
|
|
|
|
|
|
# Debug log format the tests expect load_tensor to emit with the tensor name.
LOAD_TENSOR_LOG = "loading tensor: %s"


class LoadTensorTests(unittest.TestCase):
    """Tests for load_tensor's dispatch on file extension, with loaders mocked."""

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.path")
    @patch("onnx_web.convert.utils.torch")
    def test_load_tensor_with_no_extension(self, torch_mock, path_mock, logger_mock):
        tensor_name = "model"
        device = "cpu"
        loaded = MagicMock()

        path_mock.exists.return_value = True
        # canned results for successive splitext calls; only one call is expected
        path_mock.splitext.side_effect = [("model", ""), ("model", ".safetensors")]
        torch_mock.load.return_value = loaded

        outcome = load_tensor(tensor_name, device)

        logger_mock.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, tensor_name)])
        path_mock.splitext.assert_called_once_with(tensor_name)
        path_mock.exists.assert_called_once_with(tensor_name)
        torch_mock.load.assert_called_once_with(tensor_name, map_location=device)
        self.assertEqual(outcome, loaded)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.environ")
    @patch("onnx_web.convert.utils.safetensors")
    def test_load_tensor_with_safetensors_extension(
        self, safetensors_mock, environ_mock, logger_mock
    ):
        tensor_name = "model.safetensors"
        loaded = MagicMock()
        environ_mock.__getitem__.return_value = "1"
        safetensors_mock.torch.load_file.return_value = loaded

        outcome = load_tensor(tensor_name)

        logger_mock.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, tensor_name)])
        # with no map_location passed, the load is expected to target the CPU
        safetensors_mock.torch.load_file.assert_called_once_with(
            tensor_name, device="cpu"
        )
        self.assertEqual(outcome, loaded)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_tensor_with_pickle_extension(self, torch_mock, logger_mock):
        tensor_name = "model.pt"
        device = "cpu"
        loaded = MagicMock()
        # a one-item side_effect list fails loudly if torch.load is called twice
        torch_mock.load.side_effect = [loaded]

        outcome = load_tensor(tensor_name, device)

        logger_mock.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, tensor_name)])
        torch_mock.load.assert_has_calls(
            [mock.call(tensor_name, map_location=device)]
        )
        self.assertEqual(outcome, loaded)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_tensor_with_onnx_extension(self, torch_mock, logger_mock):
        device = "cpu"
        loaded = MagicMock()
        torch_mock.load.side_effect = [loaded]

        outcome = load_tensor(ONNX_MODEL, device)

        logger_mock.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, ONNX_MODEL)])
        logger_mock.warning.assert_called_once_with(
            "tensor has ONNX extension, attempting to use PyTorch anyways: %s", "onnx"
        )
        torch_mock.load.assert_has_calls(
            [mock.call(ONNX_MODEL, map_location=device)]
        )
        self.assertEqual(outcome, loaded)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_tensor_with_unknown_extension(self, torch_mock, logger_mock):
        tensor_name = "model.xyz"
        device = "cpu"
        loaded = MagicMock()
        torch_mock.load.side_effect = [loaded]

        outcome = load_tensor(tensor_name, device)

        logger_mock.debug.assert_has_calls([mock.call(LOAD_TENSOR_LOG, tensor_name)])
        logger_mock.warning.assert_called_once_with(
            "unknown tensor type, falling back to PyTorch: %s", "xyz"
        )
        torch_mock.load.assert_has_calls(
            [mock.call(tensor_name, map_location=device)]
        )
        self.assertEqual(outcome, loaded)

    @patch("onnx_web.convert.utils.logger")
    @patch("onnx_web.convert.utils.torch")
    def test_load_tensor_with_error_loading_tensor(self, torch_mock, logger_mock):
        # a failing torch.load is expected to surface as a ValueError
        torch_mock.load.side_effect = Exception()

        with self.assertRaises(ValueError):
            load_tensor("model", "cpu")
|
|
|
|
|
|
class FixDiffusionNameTests(unittest.TestCase):
    """Tests for enforcing the ``diffusion-`` prefix on model names."""

    def test_fix_diffusion_name_with_valid_name(self):
        already_prefixed = "diffusion-model"
        self.assertEqual(fix_diffusion_name(already_prefixed), already_prefixed)

    @patch("onnx_web.convert.utils.logger")
    def test_fix_diffusion_name_with_invalid_name(self, logger_mock):
        bare_name = "model"

        fixed = fix_diffusion_name(bare_name)

        # the prefix is added and the mismatch is reported as a warning
        self.assertEqual(fixed, "diffusion-model")
        logger_mock.warning.assert_called_once_with(
            "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
            bare_name,
        )
|
|
|
|
|
|
# Root cache directory shared by the build_cache_paths tests.
CACHE_PATH = "/path/to/cache"


class BuildCachePathsTests(unittest.TestCase):
    """Tests for build_cache_paths.

    Each test expects two candidates: one under CACHE_PATH and one under the
    per-client subdirectory of CACHE_PATH.
    """

    def test_build_cache_paths_without_format(self):
        ctx = ConversionContext(cache_path=CACHE_PATH)

        actual = build_cache_paths(ctx, ONNX_MODEL, "client1", CACHE_PATH)

        self.assertEqual(
            actual,
            [
                path.join(CACHE_PATH, ONNX_MODEL),
                path.join("/path/to/cache/client1", ONNX_MODEL),
            ],
        )

    def test_build_cache_paths_with_format(self):
        ctx = ConversionContext(cache_path=CACHE_PATH)

        # an extension-less name gets the format appended as its extension
        actual = build_cache_paths(ctx, "model", "client2", CACHE_PATH, "onnx")

        self.assertEqual(
            actual,
            [
                path.join(CACHE_PATH, ONNX_MODEL),
                path.join("/path/to/cache/client2", ONNX_MODEL),
            ],
        )

    def test_build_cache_paths_with_existing_extension(self):
        ctx = ConversionContext(cache_path=CACHE_PATH)

        # a name that already has an extension keeps it despite the format
        actual = build_cache_paths(ctx, TORCH_MODEL, "client3", CACHE_PATH, "onnx")

        self.assertEqual(
            actual,
            [
                path.join(CACHE_PATH, TORCH_MODEL),
                path.join("/path/to/cache/client3", TORCH_MODEL),
            ],
        )

    def test_build_cache_paths_with_empty_extension(self):
        ctx = ConversionContext(cache_path=CACHE_PATH)

        actual = build_cache_paths(ctx, "model", "client4", CACHE_PATH, "onnx")

        self.assertEqual(
            actual,
            [
                path.join(CACHE_PATH, ONNX_MODEL),
                path.join("/path/to/cache/client4", ONNX_MODEL),
            ],
        )
|
|
|
|
|
|
class GetFirstExistsTests(unittest.TestCase):
    """Tests for get_first_exists: the first existing path, or None if none exist."""

    @patch("onnx_web.convert.utils.path")
    @patch("onnx_web.convert.utils.logger")
    def test_get_first_exists_with_existing_path(self, mock_logger, mock_path):
        paths = ["path1", "path2", "path3"]
        # only the second candidate exists; ``path`` is a module and is only
        # used through ``exists``, so no return_value setup is needed
        mock_path.exists.side_effect = [False, True, False]

        result = get_first_exists(paths)

        mock_logger.debug.assert_called_once_with(
            "model already exists in cache, skipping fetch: %s", "path2"
        )
        self.assertEqual(result, "path2")

    @patch("onnx_web.convert.utils.path")
    @patch("onnx_web.convert.utils.logger")
    def test_get_first_exists_with_no_existing_path(self, mock_logger, mock_path):
        paths = ["path1", "path2", "path3"]
        mock_path.exists.return_value = False

        result = get_first_exists(paths)

        # nothing found: no cache-hit log and a None result
        mock_logger.debug.assert_not_called()
        self.assertIsNone(result)
|