1
0
Fork 0

fix tests

This commit is contained in:
Sean Sube 2024-02-25 14:03:30 -06:00
parent 26e3405f40
commit a653f4421b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 156 additions and 78 deletions

View File

@ -47,9 +47,10 @@ class TestTxt2ImgPipeline(unittest.TestCase):
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
device = test_device()
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), device,
cancel, cancel,
logs, logs,
pending, pending,
@ -59,11 +60,14 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-txt2img-basic", "test", "test", run_txt2img_pipeline, [], {}
)
)
run_txt2img_pipeline( params = RequestParams(
worker, device,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
"txt2img", "txt2img",
@ -73,12 +77,18 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1, 1,
1, 1,
), ),
Size(256, 256), size=Size(256, 256),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
) )
self.assertTrue(path.exists("../outputs/test-txt2img-basic.png")) run_txt2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
)
self.assertTrue(path.exists("../outputs/test-txt2img-basic_0.png"))
with Image.open("../outputs/test-txt2img-basic.png") as output: with Image.open("../outputs/test-txt2img-basic.png") as output:
self.assertEqual(output.size, (256, 256)) self.assertEqual(output.size, (256, 256))
@ -93,9 +103,10 @@ class TestTxt2ImgPipeline(unittest.TestCase):
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
device = test_device()
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), device,
cancel, cancel,
logs, logs,
pending, pending,
@ -105,11 +116,14 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-txt2img-batch", "test", "test", run_txt2img_pipeline, [], {}
)
)
run_txt2img_pipeline( params = RequestParams(
worker, device,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
"txt2img", "txt2img",
@ -120,15 +134,21 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1, 1,
batch=2, batch=2,
), ),
Size(256, 256), size=Size(256, 256),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
) )
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png")) run_txt2img_pipeline(
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png")) worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
)
with Image.open("../outputs/test-txt2img-batch-0.png") as output: self.assertTrue(path.exists("../outputs/test-txt2img-batch_0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-batch_1.png"))
with Image.open("../outputs/test-txt2img-batch_0.png") as output:
self.assertEqual(output.size, (256, 256)) self.assertEqual(output.size, (256, 256))
# TODO: test contents of image # TODO: test contents of image
@ -141,9 +161,10 @@ class TestTxt2ImgPipeline(unittest.TestCase):
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
device = test_device()
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), device,
cancel, cancel,
logs, logs,
pending, pending,
@ -153,11 +174,14 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-txt2img-highres", "test", "test", run_txt2img_pipeline, [], {}
)
)
run_txt2img_pipeline( params = RequestParams(
worker, device,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
"txt2img", "txt2img",
@ -168,13 +192,19 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1, 1,
unet_tile=256, unet_tile=256,
), ),
Size(256, 256), size=Size(256, 256),
UpscaleParams("test", scale=2, outscale=2), upscale=UpscaleParams("test", scale=2, outscale=2),
HighresParams(True, 2, 0, 0), highres=HighresParams(True, 2, 0, 0),
) )
self.assertTrue(path.exists("../outputs/test-txt2img-highres.png")) run_txt2img_pipeline(
with Image.open("../outputs/test-txt2img-highres.png") as output: worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
)
self.assertTrue(path.exists("../outputs/test-txt2img-highres_0.png"))
with Image.open("../outputs/test-txt2img-highres_0.png") as output:
self.assertEqual(output.size, (512, 512)) self.assertEqual(output.size, (512, 512))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
@ -186,9 +216,10 @@ class TestTxt2ImgPipeline(unittest.TestCase):
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
device = test_device()
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), device,
cancel, cancel,
logs, logs,
pending, pending,
@ -198,11 +229,19 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-txt2img-highres-batch",
"test",
"test",
run_txt2img_pipeline,
[],
{},
)
)
run_txt2img_pipeline( params = RequestParams(
worker, device,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
"txt2img", "txt2img",
@ -213,15 +252,21 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1, 1,
batch=2, batch=2,
), ),
Size(256, 256), size=Size(256, 256),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(True, 2, 0, 0), highres=HighresParams(True, 2, 0, 0),
) )
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png")) run_txt2img_pipeline(
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png")) worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
)
with Image.open("../outputs/test-txt2img-highres-batch-0.png") as output: self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch_0.png"))
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch_1.png"))
with Image.open("../outputs/test-txt2img-highres-batch_0.png") as output:
self.assertEqual(output.size, (512, 512)) self.assertEqual(output.size, (512, 512))
@ -229,44 +274,53 @@ class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self): def test_basic(self):
worker = test_worker() worker = test_worker()
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand("test-img2img", "test", "test", run_txt2img_pipeline, [], {})
)
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline( params = RequestParams(
worker, test_device(),
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
"txt2img", "img2img",
TEST_SCHEDULER, TEST_SCHEDULER,
TEST_PROMPT, TEST_PROMPT,
3.0, 3.0,
1, 1,
1, 1,
), ),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
)
run_img2img_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
source, source,
1.0, 1.0,
) )
self.assertTrue(path.exists("../outputs/test-img2img.png")) self.assertTrue(path.exists("../outputs/test-img2img_0.png"))
class TestInpaintPipeline(unittest.TestCase): class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_white(self): def test_basic_white(self):
worker = test_worker() worker = test_worker()
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-inpaint-white", "test", "test", run_txt2img_pipeline, [], {}
)
)
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "white") mask = Image.new("RGB", (64, 64), "white")
run_inpaint_pipeline( params = RequestParams(
worker, test_device(),
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT, TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img", "inpaint",
TEST_SCHEDULER, TEST_SCHEDULER,
TEST_PROMPT, TEST_PROMPT,
3.0, 3.0,
@ -274,9 +328,15 @@ class TestInpaintPipeline(unittest.TestCase):
1, 1,
unet_tile=64, unet_tile=64,
), ),
Size(*source.size), size=Size(*source.size),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
)
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
source, source,
mask, mask,
Border.even(0), Border.even(0),
@ -288,21 +348,24 @@ class TestInpaintPipeline(unittest.TestCase):
0.0, 0.0,
) )
self.assertTrue(path.exists("../outputs/test-inpaint-white.png")) self.assertTrue(path.exists("../outputs/test-inpaint-white_0.png"))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_black(self): def test_basic_black(self):
worker = test_worker() worker = test_worker()
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) worker.start(
JobCommand(
"test-inpaint-black", "test", "test", run_txt2img_pipeline, [], {}
)
)
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "black")
run_inpaint_pipeline( params = RequestParams(
worker, test_device(),
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT, TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img", "inpaint",
TEST_SCHEDULER, TEST_SCHEDULER,
TEST_PROMPT, TEST_PROMPT,
3.0, 3.0,
@ -310,9 +373,15 @@ class TestInpaintPipeline(unittest.TestCase):
1, 1,
unet_tile=64, unet_tile=64,
), ),
Size(*source.size), size=Size(*source.size),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
)
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
source, source,
mask, mask,
Border.even(0), Border.even(0),
@ -324,7 +393,7 @@ class TestInpaintPipeline(unittest.TestCase):
0.0, 0.0,
) )
self.assertTrue(path.exists("../outputs/test-inpaint-black.png")) self.assertTrue(path.exists("../outputs/test-inpaint-black_0.png"))
class TestUpscalePipeline(unittest.TestCase): class TestUpscalePipeline(unittest.TestCase):
@ -337,9 +406,10 @@ class TestUpscalePipeline(unittest.TestCase):
active = Value("L", 0) active = Value("L", 0)
idle = Value("L", 0) idle = Value("L", 0)
device = test_device()
worker = WorkerContext( worker = WorkerContext(
"test", "test",
test_device(), device,
cancel, cancel,
logs, logs,
pending, pending,
@ -349,12 +419,13 @@ class TestUpscalePipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {})) worker.start(
JobCommand("test-upscale", "test", "test", run_upscale_pipeline, [], {})
)
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline( params = RequestParams(
worker, device,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams( ImageParams(
"../models/upscaling-stable-diffusion-x4", "../models/upscaling-stable-diffusion-x4",
"txt2img", "txt2img",
@ -364,13 +435,18 @@ class TestUpscalePipeline(unittest.TestCase):
1, 1,
1, 1,
), ),
Size(256, 256), size=Size(256, 256),
UpscaleParams("test"), upscale=UpscaleParams("test"),
HighresParams(False, 1, 0, 0), highres=HighresParams(False, 1, 0, 0),
)
run_upscale_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
params,
source, source,
) )
self.assertTrue(path.exists("../outputs/test-upscale.png")) self.assertTrue(path.exists("../outputs/test-upscale_0.png"))
class TestBlendPipeline(unittest.TestCase): class TestBlendPipeline(unittest.TestCase):
@ -395,7 +471,9 @@ class TestBlendPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {})) worker.start(
JobCommand("test-blend", "test", "test", run_blend_pipeline, [], {})
)
source = Image.new("RGBA", (64, 64), "black") source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white") mask = Image.new("RGBA", (64, 64), "white")
@ -422,4 +500,4 @@ class TestBlendPipeline(unittest.TestCase):
mask, mask,
) )
self.assertTrue(path.exists("../outputs/test-blend.png")) self.assertTrue(path.exists("../outputs/test-blend_0.png"))