2023-09-15 00:35:48 +00:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
from onnx_web.chain.blend_img2img import BlendImg2ImgStage
|
2024-01-06 02:13:57 +00:00
|
|
|
from onnx_web.chain.result import ImageMetadata, StageResult
|
|
|
|
from onnx_web.params import ImageParams, Size
|
2023-09-15 00:35:48 +00:00
|
|
|
from onnx_web.server.context import ServerContext
|
|
|
|
from onnx_web.worker.context import WorkerContext
|
2023-11-22 01:11:04 +00:00
|
|
|
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
2023-09-15 00:35:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
class BlendImg2ImgStageTests(unittest.TestCase):
    """Tests for running BlendImg2ImgStage against a small in-memory source."""

    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
    def test_stage(self):
        # Run the stage once over a single black 64x64 RGB image, then check
        # that exactly one result comes back and its top-left pixel is black.
        blend_stage = BlendImg2ImgStage()

        # Parameters for the img2img run itself (passed positionally).
        run_params = ImageParams(
            TEST_MODEL_DIFFUSION_SD15,
            "txt2img",
            "euler-a",
            "an astronaut eating a hamburger",
            3.0,
            1,
            1,
        )

        server_ctx = ServerContext(model_path="../models", output_path="../outputs")

        # Only the name, the device and the two trailing numeric arguments are
        # supplied; the six positional slots in between are left as None.
        worker_ctx = WorkerContext(
            "test",
            test_device(),
            None, None, None,
            None, None, None,
            0,
            0.1,
        )

        # One source image paired with one metadata entry.
        source_image = Image.new("RGB", (64, 64), "black")
        source_metadata = ImageMetadata(
            ImageParams("test", "txt2img", "ddim", "test", 1.0, 25, 1),
            Size(64, 64),
        )
        stage_input = StageResult(
            images=[source_image],
            metadata=[source_metadata],
        )

        stage_output = blend_stage.run(
            worker_ctx,
            server_ctx,
            None,
            run_params,
            stage_input,
            strength=0.5,
            steps=1,
        )
        stage_output.validate()

        self.assertEqual(len(stage_output), 1)

        first_image = stage_output.as_images()[0]
        top_left_pixel = first_image.getpixel((0, 0))
        self.assertEqual(top_left_pixel, (0, 0, 0))
|