fix(api): continue adding tests, fix bugs encountered
This commit is contained in:
parent
898d76e4a5
commit
047e58c916
|
@ -133,9 +133,9 @@ def tuple_to_source(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
scale = rest[0] if len(rest) > 0 else 1
|
scale = rest.pop(0) if len(rest) > 0 else 1
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -151,9 +151,9 @@ def tuple_to_correction(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
single_vae = rest[0] if len(rest) > 0 else False
|
single_vae = rest.pop(0) if len(rest) > 0 else False
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -169,9 +169,9 @@ def tuple_to_diffusion(model: Union[ModelDict, LegacyModel]):
|
||||||
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
if isinstance(model, list) or isinstance(model, tuple):
|
if isinstance(model, list) or isinstance(model, tuple):
|
||||||
name, source, *rest = model
|
name, source, *rest = model
|
||||||
scale = rest[0] if len(rest) > 0 else 1
|
scale = rest.pop(0) if len(rest) > 0 else 1
|
||||||
half = rest[0] if len(rest) > 0 else False
|
half = rest.pop(0) if len(rest) > 0 else False
|
||||||
opset = rest[0] if len(rest) > 0 else None
|
opset = rest.pop(0) if len(rest) > 0 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
|
@ -298,6 +298,7 @@ def onnx_export(
|
||||||
half=False,
|
half=False,
|
||||||
external_data=False,
|
external_data=False,
|
||||||
v2=False,
|
v2=False,
|
||||||
|
op_block_list=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
|
@ -316,8 +317,7 @@ def onnx_export(
|
||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
op_block_list = None
|
if v2 and op_block_list is None:
|
||||||
if v2:
|
|
||||||
op_block_list = ["Attention", "MultiHeadAttention"]
|
op_block_list = ["Attention", "MultiHeadAttention"]
|
||||||
|
|
||||||
if half:
|
if half:
|
||||||
|
|
|
@ -97,6 +97,7 @@ def run_txt2img_pipeline(
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
logger.trace("saving output image %s: %s", output, image.size)
|
||||||
dest = save_image(
|
dest = save_image(
|
||||||
server,
|
server,
|
||||||
output,
|
output,
|
||||||
|
|
|
@ -47,7 +47,7 @@ def source_filter_noise(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
strength: float = 0.5,
|
strength: float = 0.5,
|
||||||
):
|
):
|
||||||
noise = noise_source_histogram(source, source.size)
|
noise = noise_source_histogram(source, source.size, (0, 0))
|
||||||
return ImageChops.blend(source, noise, strength)
|
return ImageChops.blend(source, noise, strength)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ class WorkerContext:
|
||||||
idle: "Value[bool]"
|
idle: "Value[bool]"
|
||||||
timeout: float
|
timeout: float
|
||||||
retries: int
|
retries: int
|
||||||
|
initial_retries: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -37,6 +38,7 @@ class WorkerContext:
|
||||||
active_pid: "Value[int]",
|
active_pid: "Value[int]",
|
||||||
idle: "Value[bool]",
|
idle: "Value[bool]",
|
||||||
retries: int,
|
retries: int,
|
||||||
|
timeout: float,
|
||||||
):
|
):
|
||||||
self.job = None
|
self.job = None
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -48,12 +50,13 @@ class WorkerContext:
|
||||||
self.active_pid = active_pid
|
self.active_pid = active_pid
|
||||||
self.last_progress = None
|
self.last_progress = None
|
||||||
self.idle = idle
|
self.idle = idle
|
||||||
|
self.initial_retries = retries
|
||||||
self.retries = retries
|
self.retries = retries
|
||||||
self.timeout = 1.0
|
self.timeout = timeout
|
||||||
|
|
||||||
def start(self, job: str) -> None:
|
def start(self, job: str) -> None:
|
||||||
self.job = job
|
self.job = job
|
||||||
self.retries = 3
|
self.retries = self.initial_retries
|
||||||
self.set_cancel(cancel=False)
|
self.set_cancel(cancel=False)
|
||||||
self.set_idle(idle=False)
|
self.set_idle(idle=False)
|
||||||
|
|
||||||
|
|
|
@ -86,15 +86,15 @@ class DevicePoolExecutor:
|
||||||
self.logs = Queue(self.max_pending_per_worker)
|
self.logs = Queue(self.max_pending_per_worker)
|
||||||
self.rlock = Lock()
|
self.rlock = Lock()
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self, *args) -> None:
|
||||||
self.create_health_worker()
|
self.create_health_worker()
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
self.create_progress_worker()
|
self.create_progress_worker()
|
||||||
|
|
||||||
for device in self.devices:
|
for device in self.devices:
|
||||||
self.create_device_worker(device)
|
self.create_device_worker(device, *args)
|
||||||
|
|
||||||
def create_device_worker(self, device: DeviceParams) -> None:
|
def create_device_worker(self, device: DeviceParams, *args) -> None:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
|
||||||
# always recreate queues
|
# always recreate queues
|
||||||
|
@ -125,15 +125,16 @@ class DevicePoolExecutor:
|
||||||
active_pid=current,
|
active_pid=current,
|
||||||
idle=self.worker_idle[name],
|
idle=self.worker_idle[name],
|
||||||
retries=self.server.worker_retries,
|
retries=self.server.worker_retries,
|
||||||
|
timeout=self.progress_interval,
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
|
|
||||||
worker = Process(
|
worker = Process(
|
||||||
name=f"onnx-web worker: {name}",
|
name=f"onnx-web worker: {name}",
|
||||||
target=worker_main,
|
target=worker_main,
|
||||||
args=(context, self.server),
|
args=(context, self.server, *args),
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
worker.daemon = True
|
|
||||||
self.workers[name] = worker
|
self.workers[name] = worker
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
|
|
|
@ -27,7 +27,7 @@ MEMORY_ERRORS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def worker_main(worker: WorkerContext, server: ServerContext):
|
def worker_main(worker: WorkerContext, server: ServerContext, *args):
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
setproctitle("onnx-web worker: %s" % (worker.device.device))
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,8 @@ from onnx_web.chain.tile import (
|
||||||
generate_tile_spiral,
|
generate_tile_spiral,
|
||||||
get_tile_grads,
|
get_tile_grads,
|
||||||
needs_tile,
|
needs_tile,
|
||||||
|
process_tile_grid,
|
||||||
|
process_tile_spiral,
|
||||||
)
|
)
|
||||||
from onnx_web.params import Size
|
from onnx_web.params import Size
|
||||||
|
|
||||||
|
@ -95,3 +97,31 @@ class TestGenerateTileSpiral(unittest.TestCase):
|
||||||
self.assertEqual(len(tiles), 225)
|
self.assertEqual(len(tiles), 225)
|
||||||
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessTileGrid(unittest.TestCase):
|
||||||
|
def test_grid_full(self):
|
||||||
|
source = Image.new("RGB", (64, 64))
|
||||||
|
blend = process_tile_grid(source, 32, 1, [])
|
||||||
|
|
||||||
|
self.assertEqual(blend.size, (64, 64))
|
||||||
|
|
||||||
|
def test_grid_partial(self):
|
||||||
|
source = Image.new("RGB", (72, 72))
|
||||||
|
blend = process_tile_grid(source, 32, 1, [])
|
||||||
|
|
||||||
|
self.assertEqual(blend.size, (72, 72))
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessTileSpiral(unittest.TestCase):
|
||||||
|
def test_grid_full(self):
|
||||||
|
source = Image.new("RGB", (64, 64))
|
||||||
|
blend = process_tile_spiral(source, 32, 1, [])
|
||||||
|
|
||||||
|
self.assertEqual(blend.size, (64, 64))
|
||||||
|
|
||||||
|
def test_grid_partial(self):
|
||||||
|
source = Image.new("RGB", (72, 72))
|
||||||
|
blend = process_tile_spiral(source, 32, 1, [])
|
||||||
|
|
||||||
|
self.assertEqual(blend.size, (72, 72))
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.convert.utils import DEFAULT_OPSET, ConversionContext, download_progress
|
from onnx_web.convert.utils import (
|
||||||
|
DEFAULT_OPSET,
|
||||||
|
ConversionContext,
|
||||||
|
download_progress,
|
||||||
|
tuple_to_correction,
|
||||||
|
tuple_to_diffusion,
|
||||||
|
tuple_to_source,
|
||||||
|
tuple_to_upscaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversionContextTests(unittest.TestCase):
|
class ConversionContextTests(unittest.TestCase):
|
||||||
|
@ -17,3 +25,160 @@ class DownloadProgressTests(unittest.TestCase):
|
||||||
def test_download_example(self):
|
def test_download_example(self):
|
||||||
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
||||||
self.assertEqual(path, "/tmp/example-dot-com")
|
self.assertEqual(path, "/tmp/example-dot-com")
|
||||||
|
|
||||||
|
|
||||||
|
class TupleToSourceTests(unittest.TestCase):
|
||||||
|
def test_basic_tuple(self):
|
||||||
|
source = tuple_to_source(("foo", "bar"))
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_list(self):
|
||||||
|
source = tuple_to_source(["foo", "bar"])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_dict(self):
|
||||||
|
source = tuple_to_source(["foo", "bar"])
|
||||||
|
source["bin"] = "bin"
|
||||||
|
|
||||||
|
# make sure this is returned as-is with extra fields
|
||||||
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
|
self.assertEqual(source, second)
|
||||||
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
|
|
||||||
|
class TupleToCorrectionTests(unittest.TestCase):
|
||||||
|
def test_basic_tuple(self):
|
||||||
|
source = tuple_to_correction(("foo", "bar"))
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_list(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar"])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_dict(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar"])
|
||||||
|
source["bin"] = "bin"
|
||||||
|
|
||||||
|
# make sure this is returned with extra fields
|
||||||
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
|
self.assertEqual(source, second)
|
||||||
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
|
def test_scale_tuple(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar", 2])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_half_tuple(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar", True])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_opset_tuple(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar", 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_all_tuple(self):
|
||||||
|
source = tuple_to_correction(["foo", "bar", 2, True, 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
self.assertEqual(source["scale"], 2)
|
||||||
|
self.assertEqual(source["half"], True)
|
||||||
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
|
class TupleToDiffusionTests(unittest.TestCase):
|
||||||
|
def test_basic_tuple(self):
|
||||||
|
source = tuple_to_diffusion(("foo", "bar"))
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_list(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar"])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_dict(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar"])
|
||||||
|
source["bin"] = "bin"
|
||||||
|
|
||||||
|
# make sure this is returned with extra fields
|
||||||
|
second = tuple_to_diffusion(source)
|
||||||
|
|
||||||
|
self.assertEqual(source, second)
|
||||||
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
|
def test_single_vae_tuple(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar", True])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_half_tuple(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar", True])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_opset_tuple(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar", 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_all_tuple(self):
|
||||||
|
source = tuple_to_diffusion(["foo", "bar", True, True, 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
self.assertEqual(source["single_vae"], True)
|
||||||
|
self.assertEqual(source["half"], True)
|
||||||
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
||||||
|
|
||||||
|
class TupleToUpscalingTests(unittest.TestCase):
|
||||||
|
def test_basic_tuple(self):
|
||||||
|
source = tuple_to_upscaling(("foo", "bar"))
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_list(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar"])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_basic_dict(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar"])
|
||||||
|
source["bin"] = "bin"
|
||||||
|
|
||||||
|
# make sure this is returned with extra fields
|
||||||
|
second = tuple_to_source(source)
|
||||||
|
|
||||||
|
self.assertEqual(source, second)
|
||||||
|
self.assertIn("bin", second)
|
||||||
|
|
||||||
|
def test_scale_tuple(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar", 2])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_half_tuple(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar", True])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_opset_tuple(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar", 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
|
||||||
|
def test_all_tuple(self):
|
||||||
|
source = tuple_to_upscaling(["foo", "bar", 2, True, 14])
|
||||||
|
self.assertEqual(source["name"], "foo")
|
||||||
|
self.assertEqual(source["source"], "bar")
|
||||||
|
self.assertEqual(source["scale"], 2)
|
||||||
|
self.assertEqual(source["half"], True)
|
||||||
|
self.assertEqual(source["opset"], 14)
|
||||||
|
|
|
@ -1,9 +1,16 @@
|
||||||
|
from os import path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from unittest import skipUnless
|
||||||
|
|
||||||
|
from onnx_web.params import DeviceParams
|
||||||
|
|
||||||
|
|
||||||
def test_with_models(models: List[str]):
|
def test_needs_models(models: List[str]):
|
||||||
def wrapper(func):
|
return skipUnless(all([path.exists(model) for model in models]), "model does not exist")
|
||||||
# TODO: check if models exist
|
|
||||||
return func
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
def test_device() -> DeviceParams:
|
||||||
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
||||||
|
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
|
@ -0,0 +1,33 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.image.mask_filter import (
|
||||||
|
mask_filter_gaussian_multiply,
|
||||||
|
mask_filter_gaussian_screen,
|
||||||
|
mask_filter_none,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskFilterNoneTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
mask = Image.new("RGB", dims)
|
||||||
|
result = mask_filter_none(mask, dims, (0, 0))
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskFilterGaussianMultiplyTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
mask = Image.new("RGB", dims)
|
||||||
|
result = mask_filter_gaussian_multiply(mask, dims, (0, 0))
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskFilterGaussianScreenTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
mask = Image.new("RGB", dims)
|
||||||
|
result = mask_filter_gaussian_screen(mask, dims, (0, 0))
|
||||||
|
self.assertEqual(result.size, dims)
|
|
@ -0,0 +1,37 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.image.source_filter import (
|
||||||
|
source_filter_gaussian,
|
||||||
|
source_filter_noise,
|
||||||
|
source_filter_none,
|
||||||
|
)
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterNoneTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_none(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterGaussianTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_gaussian(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceFilterNoiseTests(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
dims = (64, 64)
|
||||||
|
server = ServerContext()
|
||||||
|
source = Image.new("RGB", dims)
|
||||||
|
result = source_filter_noise(server, source)
|
||||||
|
self.assertEqual(result.size, dims)
|
|
@ -1,5 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.server.load import (
|
from onnx_web.server.load import (
|
||||||
get_available_platforms,
|
get_available_platforms,
|
||||||
get_config_params,
|
get_config_params,
|
||||||
|
@ -14,6 +15,8 @@ from onnx_web.server.load import (
|
||||||
get_source_filters,
|
get_source_filters,
|
||||||
get_upscaling_models,
|
get_upscaling_models,
|
||||||
get_wildcard_data,
|
get_wildcard_data,
|
||||||
|
load_extras,
|
||||||
|
load_models,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,3 +84,13 @@ class SourceFilterTests(unittest.TestCase):
|
||||||
def test_before_setup(self):
|
def test_before_setup(self):
|
||||||
filters = get_source_filters()
|
filters = get_source_filters()
|
||||||
self.assertIsNotNone(filters)
|
self.assertIsNotNone(filters)
|
||||||
|
|
||||||
|
class LoadExtrasTests(unittest.TestCase):
|
||||||
|
def test_default_extras(self):
|
||||||
|
server = ServerContext(extra_models=["../models/extras.json"])
|
||||||
|
load_extras(server)
|
||||||
|
|
||||||
|
class LoadModelsTests(unittest.TestCase):
|
||||||
|
def test_default_models(self):
|
||||||
|
server = ServerContext(model_path="../models")
|
||||||
|
load_models(server)
|
||||||
|
|
|
@ -17,10 +17,9 @@ from onnx_web.diffusers.load import (
|
||||||
)
|
)
|
||||||
from onnx_web.diffusers.patches.unet import UNetWrapper
|
from onnx_web.diffusers.patches.unet import UNetWrapper
|
||||||
from onnx_web.diffusers.patches.vae import VAEWrapper
|
from onnx_web.diffusers.patches.vae import VAEWrapper
|
||||||
from onnx_web.models.meta import NetworkModel, NetworkType
|
from onnx_web.models.meta import NetworkModel
|
||||||
from onnx_web.params import DeviceParams, ImageParams
|
from onnx_web.params import DeviceParams, ImageParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from tests.helpers import test_with_models
|
|
||||||
from tests.mocks import MockPipeline
|
from tests.mocks import MockPipeline
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,171 @@
|
||||||
|
import unittest
|
||||||
|
from multiprocessing import Queue, Value
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.diffusers.run import (
|
||||||
|
run_blend_pipeline,
|
||||||
|
run_img2img_pipeline,
|
||||||
|
run_txt2img_pipeline,
|
||||||
|
run_upscale_pipeline,
|
||||||
|
)
|
||||||
|
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||||
|
|
||||||
|
|
||||||
|
class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_basic(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
run_txt2img_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-txt2img.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img.png"))
|
||||||
|
|
||||||
|
class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_basic(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
|
run_img2img_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||||
|
["test-img2img.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
||||||
|
|
||||||
|
class TestUpscalePipeline(unittest.TestCase):
|
||||||
|
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
||||||
|
def test_basic(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
|
run_upscale_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
"../models/upscaling-stable-diffusion-x4", "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-upscale.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-upscale.png"))
|
||||||
|
|
||||||
|
class TestBlendPipeline(unittest.TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
source = Image.new("RGBA", (64, 64), "black")
|
||||||
|
mask = Image.new("RGBA", (64, 64), "white")
|
||||||
|
run_blend_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15, "txt2img", "ddim", "an astronaut eating a hamburger", 3.0, 1, 1),
|
||||||
|
Size(64, 64),
|
||||||
|
["test-blend.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
[source, source],
|
||||||
|
mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-blend.png"))
|
|
@ -9,16 +9,22 @@ from onnx_web.worker.pool import DevicePoolExecutor
|
||||||
|
|
||||||
TEST_JOIN_TIMEOUT = 0.2
|
TEST_JOIN_TIMEOUT = 0.2
|
||||||
|
|
||||||
def test_job(*args, lock: Event, **kwargs):
|
lock = Event()
|
||||||
|
|
||||||
|
|
||||||
|
def test_job(*args, **kwargs):
|
||||||
lock.wait()
|
lock.wait()
|
||||||
|
|
||||||
|
|
||||||
|
def wait_job(*args, **kwargs):
|
||||||
|
sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
class TestWorkerPool(unittest.TestCase):
|
class TestWorkerPool(unittest.TestCase):
|
||||||
lock: Optional[Event]
|
# lock: Optional[Event]
|
||||||
pool: Optional[DevicePoolExecutor]
|
pool: Optional[DevicePoolExecutor]
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.lock = Event()
|
|
||||||
self.pool = None
|
self.pool = None
|
||||||
|
|
||||||
def tearDown(self) -> None:
|
def tearDown(self) -> None:
|
||||||
|
@ -38,7 +44,17 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.assertEqual(len(self.pool.workers), 1)
|
self.assertEqual(len(self.pool.workers), 1)
|
||||||
|
|
||||||
def test_cancel_pending(self):
|
def test_cancel_pending(self):
|
||||||
pass
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
|
server = ServerContext()
|
||||||
|
|
||||||
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
|
self.pool.start()
|
||||||
|
|
||||||
|
self.pool.submit("test", wait_job, lock=lock)
|
||||||
|
self.assertEqual(self.pool.done("test"), (True, None))
|
||||||
|
|
||||||
|
self.assertTrue(self.pool.cancel("test"))
|
||||||
|
self.assertEqual(self.pool.done("test"), (False, None))
|
||||||
|
|
||||||
def test_cancel_running(self):
|
def test_cancel_running(self):
|
||||||
pass
|
pass
|
||||||
|
@ -61,48 +77,46 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
self.assertEqual(self.pool.get_next_device(needs_device=device2), 1)
|
||||||
|
|
||||||
def test_done_running(self):
|
def test_done_running(self):
|
||||||
"""
|
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
||||||
self.pool.start()
|
self.pool.start(lock)
|
||||||
|
sleep(2.0)
|
||||||
|
|
||||||
self.pool.submit("test", test_job, lock=self.lock)
|
self.pool.submit("test", test_job)
|
||||||
sleep(5.0)
|
sleep(2.0)
|
||||||
self.assertEqual(self.pool.done("test"), (False, None))
|
|
||||||
"""
|
pending, _progress = self.pool.done("test")
|
||||||
pass
|
self.assertFalse(pending)
|
||||||
|
|
||||||
def test_done_pending(self):
|
def test_done_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start(lock)
|
||||||
|
|
||||||
self.pool.submit("test1", test_job, lock=self.lock)
|
self.pool.submit("test1", test_job)
|
||||||
self.pool.submit("test2", test_job, lock=self.lock)
|
self.pool.submit("test2", test_job)
|
||||||
self.assertTrue(self.pool.done("test2"), (True, None))
|
self.assertTrue(self.pool.done("test2"), (True, None))
|
||||||
|
|
||||||
self.lock.set()
|
lock.set()
|
||||||
|
|
||||||
def test_done_finished(self):
|
def test_done_finished(self):
|
||||||
"""
|
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
server = ServerContext()
|
server = ServerContext()
|
||||||
|
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
sleep(2.0)
|
||||||
|
|
||||||
self.pool.submit("test", test_job, lock=self.lock)
|
self.pool.submit("test", wait_job)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.done("test"), (True, None))
|
||||||
|
|
||||||
self.lock.set()
|
sleep(2.0)
|
||||||
sleep(5.0)
|
pending, _progress = self.pool.done("test")
|
||||||
self.assertEqual(self.pool.done("test"), (False, None))
|
self.assertFalse(pending)
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_recycle_live(self):
|
def test_recycle_live(self):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue