From f00bfe9bd00fd767699ffc4e4a2300d5ed27ad5f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 23 Nov 2023 11:19:58 -0600 Subject: [PATCH] more tests, apply lint --- api/onnx_web/convert/diffusion/lora.py | 10 +- .../diffusers/pipelines/controlnet.py | 2 +- api/onnx_web/diffusers/pipelines/panorama.py | 2 +- api/onnx_web/diffusers/pipelines/upscale.py | 2 +- api/tests/chain/test_correct_gfpgan.py | 44 +++++++++ api/tests/chain/test_upscale_bsrgan.py | 41 ++++++++ api/tests/chain/test_upscale_outpaint.py | 52 ++++++++++ api/tests/chain/test_upscale_resrgan.py | 39 ++++++++ api/tests/chain/test_upscale_swinir.py | 41 ++++++++ api/tests/helpers.py | 6 ++ api/tests/test_diffusers/test_run.py | 97 +++++++++++++++++++ api/tests/worker/test_pool.py | 8 +- 12 files changed, 334 insertions(+), 10 deletions(-) create mode 100644 api/tests/chain/test_correct_gfpgan.py create mode 100644 api/tests/chain/test_upscale_bsrgan.py create mode 100644 api/tests/chain/test_upscale_outpaint.py create mode 100644 api/tests/chain/test_upscale_resrgan.py create mode 100644 api/tests/chain/test_upscale_swinir.py diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index b912b107..d10157ee 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -139,7 +139,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] ) elif block == "text_model" or simple: match = next( - node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul" + node + for node in remaining + if fix_node_name(node.name) == f"{root}_MatMul" ) else: # search in order. one side has sparse indices, so they will not match. @@ -172,6 +174,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any] fixed[name] = value remaining.remove(match) + logger.debug( + "SDXL LoRA key fixup matched %s keys, %s remaining", + len(fixed.keys()), + len(remaining), + ) + return fixed diff --git a/api/onnx_web/diffusers/pipelines/controlnet.py b/api/onnx_web/diffusers/pipelines/controlnet.py index bcee67bf..e0515dc5 100644 --- a/api/onnx_web/diffusers/pipelines/controlnet.py +++ b/api/onnx_web/diffusers/pipelines/controlnet.py @@ -13,8 +13,8 @@ import numpy as np import PIL import torch from diffusers.configuration_utils import FrozenDict -from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import PIL_INTERPOLATION, deprecate, logging diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 22473a51..11511fa7 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -19,8 +19,8 @@ import numpy as np import PIL import torch from diffusers.configuration_utils import FrozenDict -from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import PIL_INTERPOLATION, deprecate, logging diff --git a/api/onnx_web/diffusers/pipelines/upscale.py b/api/onnx_web/diffusers/pipelines/upscale.py index db961091..ff571f83 100644 --- a/api/onnx_web/diffusers/pipelines/upscale.py +++ b/api/onnx_web/diffusers/pipelines/upscale.py @@ -11,8 +11,8 @@ from typing import Any, Callable, List, Optional, Union import numpy as np import PIL import torch -from diffusers.pipelines.pipeline_utils import ImagePipelineOutput from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel +from diffusers.pipelines.pipeline_utils import ImagePipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline from diffusers.schedulers import DDPMScheduler diff --git a/api/tests/chain/test_correct_gfpgan.py b/api/tests/chain/test_correct_gfpgan.py new file mode 100644 index 00000000..9f8b6cb3 --- /dev/null +++ b/api/tests/chain/test_correct_gfpgan.py @@ -0,0 +1,44 @@ +import unittest + +from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.server.hacks import apply_patches +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/correction-gfpgan-v1-3" + + +class CorrectGFPGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + server = ServerContext(model_path="../models", output_path="../outputs") + apply_patches(server) + + worker = WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 0, + 0.1, + ) + stage = CorrectGFPGANStage() + sources = StageResult.empty() + result = stage.run( + worker, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_bsrgan.py b/api/tests/chain/test_upscale_bsrgan.py new file mode 100644 index 00000000..f93b800c --- /dev/null +++ b/api/tests/chain/test_upscale_bsrgan.py @@ -0,0 +1,41 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_bsrgan import UpscaleBSRGANStage +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-bsrgan-x4" + + +class UpscaleBSRGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleBSRGANStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext( + model_path="../models", + ), + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_outpaint.py b/api/tests/chain/test_upscale_outpaint.py new file mode 100644 index 00000000..0e524014 --- /dev/null +++ b/api/tests/chain/test_upscale_outpaint.py @@ -0,0 +1,52 @@ +import unittest + +from PIL import Image + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_outpaint import UpscaleOutpaintStage +from onnx_web.params import Border, HighresParams, ImageParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_models + + +class UpscaleOutpaintStageTests(unittest.TestCase): + @test_needs_models(["../models/stable-diffusion-onnx-v1-inpainting"]) + def test_empty(self): + stage = UpscaleOutpaintStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext( + # model_path="../models", + ), + None, + ImageParams( + "../models/stable-diffusion-onnx-v1-inpainting", + "inpaint", + "euler", + "test", + 5.0, + 1, + 1, + ), + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams("stable-diffusion-onnx-v1-inpainting"), + border=Border.even(0), + dims=(), + tile_mask=Image.new("RGB", (64, 64)), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_resrgan.py b/api/tests/chain/test_upscale_resrgan.py new file mode 100644 index 00000000..f832767f --- /dev/null +++ b/api/tests/chain/test_upscale_resrgan.py @@ -0,0 +1,39 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_resrgan import UpscaleRealESRGANStage +from onnx_web.params import HighresParams, StageParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-real-esrgan-x4-v3" + + +class UpscaleRealESRGANStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleRealESRGANStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext(model_path="../models"), + StageParams(), + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_upscale_swinir.py b/api/tests/chain/test_upscale_swinir.py new file mode 100644 index 00000000..e2bf69fe --- /dev/null +++ b/api/tests/chain/test_upscale_swinir.py @@ -0,0 +1,41 @@ +import unittest + +from onnx_web.chain.result import StageResult +from onnx_web.chain.upscale_swinir import UpscaleSwinIRStage +from onnx_web.params import HighresParams, UpscaleParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext +from tests.helpers import test_device, test_needs_onnx_models + +TEST_MODEL = "../models/upscaling-swinir-real-large-x4" + + +class UpscaleSwinIRStageTests(unittest.TestCase): + @test_needs_onnx_models([TEST_MODEL]) + def test_empty(self): + stage = UpscaleSwinIRStage() + sources = StageResult.empty() + result = stage.run( + WorkerContext( + "test", + test_device(), + None, + None, + None, + None, + None, + None, + 3, + 0.1, + ), + ServerContext( + # model_path="../models", + ), + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(TEST_MODEL), + ) + + self.assertEqual(len(result), 0) diff --git a/api/tests/helpers.py b/api/tests/helpers.py index e6c359ed..852ecd77 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -11,6 +11,12 @@ def test_needs_models(models: List[str]): ) +def test_needs_onnx_models(models: List[str]): + return skipUnless( + all([path.exists(f"{model}.onnx") for model in models]), "model does not exist" + ) + + def test_device() -> DeviceParams: return DeviceParams("cpu", "CPUExecutionProvider") diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index bb374838..59600fea 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -63,6 +63,55 @@ class TestTxt2ImgPipeline(unittest.TestCase): self.assertEqual(output.size, (256, 256)) # TODO: test contents of image + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_batch(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + batch=2, + ), + Size(256, 256), + ["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"], + UpscaleParams("test"), + HighresParams(False, 1, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) + + output = Image.open("../outputs/test-txt2img-batch-0.png") + self.assertEqual(output.size, (256, 256)) + # TODO: test contents of image + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_highres(self): cancel = Value("L", 0) @@ -108,6 +157,54 @@ class TestTxt2ImgPipeline(unittest.TestCase): output = Image.open("../outputs/test-txt2img-highres.png") self.assertEqual(output.size, (512, 512)) + @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) + def test_highres_batch(self): + cancel = Value("L", 0) + logs = Queue() + pending = Queue() + progress = Queue() + active = Value("L", 0) + idle = Value("L", 0) + + worker = WorkerContext( + "test", + test_device(), + cancel, + logs, + pending, + progress, + active, + idle, + 3, + 0.1, + ) + worker.start("test") + + run_txt2img_pipeline( + worker, + ServerContext(model_path="../models", output_path="../outputs"), + ImageParams( + TEST_MODEL_DIFFUSION_SD15, + "txt2img", + "ddim", + "an astronaut eating a hamburger", + 3.0, + 1, + 1, + batch=2, + ), + Size(256, 256), + ["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"], + UpscaleParams("test"), + HighresParams(True, 2, 0, 0), + ) + + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) + self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) + + output = Image.open("../outputs/test-txt2img-highres-batch-0.png") + self.assertEqual(output.size, (512, 512)) + class TestImg2ImgPipeline(unittest.TestCase): @test_needs_models([TEST_MODEL_DIFFUSION_SD15]) diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index 3f6f13cd..ea709156 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -89,10 +89,8 @@ class TestWorkerPool(unittest.TestCase): server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) self.pool.start(lock) - sleep(2.0) - self.pool.submit("test", test_job) - sleep(2.0) + sleep(5.0) pending, _progress = self.pool.done("test") self.assertFalse(pending) @@ -121,12 +119,10 @@ class TestWorkerPool(unittest.TestCase): server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) self.pool.start() - sleep(2.0) - self.pool.submit("test", wait_job) self.assertEqual(self.pool.done("test"), (True, None)) - sleep(2.0) + sleep(5.0) pending, _progress = self.pool.done("test") self.assertFalse(pending)