update pipeline tests
This commit is contained in:
parent
b6aed0542c
commit
a02523c54c
|
@ -3,10 +3,11 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
||||||
from onnx_web.params import DeviceParams, ImageParams
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import ImageParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_needs_models
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||||
|
|
||||||
|
|
||||||
class BlendImg2ImgStageTests(unittest.TestCase):
|
class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
|
@ -25,7 +26,7 @@ class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
server = ServerContext(model_path="../models", output_path="../outputs")
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
DeviceParams("cpu", "CPUProvider"),
|
test_device(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -33,11 +34,14 @@ class BlendImg2ImgStageTests(unittest.TestCase):
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
0,
|
0,
|
||||||
|
0.1,
|
||||||
)
|
)
|
||||||
sources = [
|
sources = StageResult(
|
||||||
|
images=[
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
]
|
]
|
||||||
|
)
|
||||||
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
result = stage.run(worker, server, None, params, sources, strength=0.5, steps=1)
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].getpixel((0, 0)), (127, 127, 127))
|
self.assertEqual(result.as_image()[0].getpixel((0, 0)), (0, 0, 0))
|
||||||
|
|
|
@ -1,21 +1,27 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
|
from onnx_web.chain.correct_codeformer import CorrectCodeformerStage
|
||||||
from onnx_web.params import DeviceParams, HighresParams, UpscaleParams
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.server.hacks import apply_patches
|
from onnx_web.server.hacks import apply_patches
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import (
|
||||||
|
TEST_MODEL_CORRECTION_CODEFORMER,
|
||||||
|
test_device,
|
||||||
|
test_needs_models,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CorrectCodeformerStageTests(unittest.TestCase):
|
class CorrectCodeformerStageTests(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_CORRECTION_CODEFORMER])
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
"""
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
server = ServerContext()
|
|
||||||
apply_patches(server)
|
apply_patches(server)
|
||||||
|
|
||||||
worker = WorkerContext(
|
worker = WorkerContext(
|
||||||
"test",
|
"test",
|
||||||
DeviceParams("cpu", "CPUProvider"),
|
test_device(),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
@ -23,11 +29,18 @@ class CorrectCodeformerStageTests(unittest.TestCase):
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
0,
|
0,
|
||||||
|
0.1,
|
||||||
)
|
)
|
||||||
stage = CorrectCodeformerStage()
|
stage = CorrectCodeformerStage()
|
||||||
sources = StageResult.empty()
|
sources = StageResult.empty()
|
||||||
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
result = stage.run(
|
||||||
|
worker,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(""),
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
|
@ -233,7 +233,7 @@ class BlendWeightsLoHATests(unittest.TestCase):
|
||||||
i = 32
|
i = 32
|
||||||
j = 4
|
j = 4
|
||||||
k = 1
|
k = 1
|
||||||
l = 1
|
l = 1 # NOQA
|
||||||
p = 2
|
p = 2
|
||||||
r = 4
|
r = 4
|
||||||
|
|
||||||
|
|
|
@ -12,10 +12,7 @@ from onnx_web.convert.utils import (
|
||||||
tuple_to_source,
|
tuple_to_source,
|
||||||
tuple_to_upscaling,
|
tuple_to_upscaling,
|
||||||
)
|
)
|
||||||
from tests.helpers import (
|
from tests.helpers import TEST_MODEL_UPSCALING_SWINIR, test_needs_models
|
||||||
TEST_MODEL_UPSCALING_SWINIR,
|
|
||||||
test_needs_models,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConversionContextTests(unittest.TestCase):
|
class ConversionContextTests(unittest.TestCase):
|
||||||
|
|
|
@ -15,5 +15,6 @@ def test_device() -> DeviceParams:
|
||||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
||||||
|
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
|
||||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||||
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
||||||
|
|
Loading…
Reference in New Issue