1
0
Fork 0
onnx-web/api/tests/convert/test_utils.py

245 lines
7.6 KiB
Python
Raw Normal View History

2023-09-15 00:35:48 +00:00
import os
import tempfile
import unittest

from onnx_web.convert.utils import (
    DEFAULT_OPSET,
    ConversionContext,
    download_progress,
    remove_prefix,
    resolve_tensor,
    source_format,
    tuple_to_correction,
    tuple_to_diffusion,
    tuple_to_source,
    tuple_to_upscaling,
)
from tests.helpers import (
    TEST_MODEL_UPSCALING_SWINIR,
    test_needs_models,
)
2023-09-15 00:35:48 +00:00
class ConversionContextTests(unittest.TestCase):
def test_from_environ(self):
2023-11-20 05:18:57 +00:00
context = ConversionContext.from_environ()
self.assertEqual(context.opset, DEFAULT_OPSET)
2023-09-15 00:35:48 +00:00
def test_map_location(self):
2023-11-20 05:18:57 +00:00
context = ConversionContext.from_environ()
self.assertEqual(context.map_location.type, "cpu")
2023-09-15 00:35:48 +00:00
class DownloadProgressTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
self.assertEqual(path, "/tmp/example-dot-com")
class TupleToSourceTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_basic_tuple(self):
source = tuple_to_source(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
2023-11-20 05:18:57 +00:00
def test_basic_list(self):
source = tuple_to_source(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
2023-11-20 05:18:57 +00:00
def test_basic_dict(self):
source = tuple_to_source(["foo", "bar"])
source["bin"] = "bin"
2023-11-20 05:18:57 +00:00
# make sure this is returned as-is with extra fields
second = tuple_to_source(source)
2023-11-20 05:18:57 +00:00
self.assertEqual(source, second)
self.assertIn("bin", second)
class TupleToCorrectionTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_basic_tuple(self):
source = tuple_to_correction(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_correction(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_correction(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_correction(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_correction(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_correction(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_correction(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToDiffusionTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_basic_tuple(self):
source = tuple_to_diffusion(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_diffusion(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_diffusion(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_diffusion(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_single_vae_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_diffusion(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["single_vae"], True)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class TupleToUpscalingTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_basic_tuple(self):
source = tuple_to_upscaling(("foo", "bar"))
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_list(self):
source = tuple_to_upscaling(["foo", "bar"])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_basic_dict(self):
source = tuple_to_upscaling(["foo", "bar"])
source["bin"] = "bin"
# make sure this is returned with extra fields
second = tuple_to_source(source)
self.assertEqual(source, second)
self.assertIn("bin", second)
def test_scale_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_half_tuple(self):
source = tuple_to_upscaling(["foo", "bar", True])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_opset_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
def test_all_tuple(self):
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
self.assertEqual(source["name"], "foo")
self.assertEqual(source["source"], "bar")
self.assertEqual(source["scale"], 2)
self.assertEqual(source["half"], True)
self.assertEqual(source["opset"], 14)
class SourceFormatTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_with_format(self):
result = source_format(
{
"format": "foo",
}
)
self.assertEqual(result, "foo")
def test_source_known_extension(self):
result = source_format(
{
"source": "foo.safetensors",
}
)
self.assertEqual(result, "safetensors")
def test_source_unknown_extension(self):
result = source_format({"source": "foo.none"})
self.assertEqual(result, None)
def test_incomplete_model(self):
self.assertIsNone(source_format({}))
class RemovePrefixTests(unittest.TestCase):
2023-11-20 05:18:57 +00:00
def test_with_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "foo"), ".bar")
2023-11-20 05:18:57 +00:00
def test_without_prefix(self):
self.assertEqual(remove_prefix("foo.bar", "bin"), "foo.bar")
class LoadTorchTests(unittest.TestCase):
    """Placeholder suite (presumably for a torch-loading helper — TODO confirm); no test cases yet."""
class LoadTensorTests(unittest.TestCase):
    """Placeholder suite (presumably for a tensor-loading helper — TODO confirm); no test cases yet."""
class ResolveTensorTests(unittest.TestCase):
    """Tests for ``resolve_tensor``."""

    @test_needs_models([TEST_MODEL_UPSCALING_SWINIR])
    def test_resolve_existing(self):
        # NOTE(review): this relative path depends on the directory the suite
        # is run from — confirm against the test runner configuration
        resolved = resolve_tensor("../models/.cache/upscaling-swinir")
        self.assertEqual(resolved, TEST_MODEL_UPSCALING_SWINIR)

    def test_resolve_missing(self):
        # a name that does not resolve to any file yields None
        resolved = resolve_tensor("missing")
        self.assertIsNone(resolved)