diff --git a/api/tests/chain/test_blend_img2img.py b/api/tests/chain/test_blend_img2img.py index 31aa27a3..9d6f71d9 100644 --- a/api/tests/chain/test_blend_img2img.py +++ b/api/tests/chain/test_blend_img2img.py @@ -3,10 +3,11 @@ import unittest from PIL import Image from onnx_web.chain.blend_img2img import BlendImg2ImgStage -from onnx_web.params import DeviceParams, ImageParams +from onnx_web.chain.result import StageResult +from onnx_web.params import ImageParams from onnx_web.server.context import ServerContext from onnx_web.worker.context import WorkerContext -from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models +from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models class BlendImg2ImgStageTests(unittest.TestCase): @@ -25,7 +26,7 @@ class BlendImg2ImgStageTests(unittest.TestCase): server = ServerContext(model_path="../models", output_path="../outputs") worker = WorkerContext( "test", - DeviceParams("cpu", "CPUProvider"), + test_device(), None, None, None, @@ -33,11 +34,14 @@ class BlendImg2ImgStageTests(unittest.TestCase): None, None, 0, + 0.1, + ) + sources = StageResult( + images=[ + Image.new("RGB", (64, 64), "black"), + ] ) - sources = [ - Image.new("RGB", (64, 64), "black"), - ] result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1) self.assertEqual(len(result), 1) - self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127)) + self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0)) diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py index 8a90d0c9..fa764554 100644 --- a/api/tests/chain/test_correct_codeformer.py +++ b/api/tests/chain/test_correct_codeformer.py @@ -1,21 +1,27 @@ import unittest from onnx_web.chain.correct_codeformer import CorrectCodeformerStage -from onnx_web.params import DeviceParams, HighresParams, UpscaleParams +from onnx_web.chain.result import StageResult +from onnx_web.params import HighresParams, UpscaleParams from onnx_web.server.context import ServerContext from onnx_web.server.hacks import apply_patches from onnx_web.worker.context import WorkerContext +from tests.helpers import ( + TEST_MODEL_CORRECTION_CODEFORMER, + test_device, + test_needs_models, +) class CorrectCodeformerStageTests(unittest.TestCase): + @test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER]) def test_empty(self): - """ - server = ServerContext() + server = ServerContext(model_path="../models", output_path="../outputs") apply_patches(server) worker = WorkerContext( "test", - DeviceParams("cpu", "CPUProvider"), + test_device(), None, None, None, @@ -23,11 +29,18 @@ class CorrectCodeformerStageTests(unittest.TestCase): None, None, 0, + 0.1, ) stage = CorrectCodeformerStage() sources = StageResult.empty() - result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) + result = stage.run( + worker, + None, + None, + None, + sources, + highres=HighresParams(False, 1, 0, 0), + upscale=UpscaleParams(""), + ) self.assertEqual(len(result), 0) - """ - pass diff --git a/api/tests/convert/diffusion/test_lora.py b/api/tests/convert/diffusion/test_lora.py index bcf19680..2b94b559 100644 --- a/api/tests/convert/diffusion/test_lora.py +++ b/api/tests/convert/diffusion/test_lora.py @@ -233,7 +233,7 @@ class BlendWeightsLoHATests(unittest.TestCase): i = 32 j = 4 k = 1 - l = 1 + l = 1 # NOQA p = 2 r = 4 diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index ae0c2842..4281adbc 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -12,10 +12,7 @@ from onnx_web.convert.utils import ( tuple_to_source, tuple_to_upscaling, ) -from tests.helpers import ( - TEST_MODEL_UPSCALING_SWINIR, - test_needs_models, -) +from tests.helpers import TEST_MODEL_UPSCALING_SWINIR, test_needs_models class ConversionContextTests(unittest.TestCase): diff --git a/api/tests/helpers.py b/api/tests/helpers.py index 64714819..e6c359ed 100644 --- a/api/tests/helpers.py +++ b/api/tests/helpers.py @@ -15,5 +15,6 @@ def test_device() -> DeviceParams: return DeviceParams("cpu", "CPUExecutionProvider") +TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth" TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"