
more tests, apply lint

Sean Sube 2023-11-23 11:19:58 -06:00
parent 66dfa7206a
commit f00bfe9bd0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
12 changed files with 334 additions and 10 deletions


@@ -139,7 +139,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
             )
         elif block == "text_model" or simple:
             match = next(
-                node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul"
+                node
+                for node in remaining
+                if fix_node_name(node.name) == f"{root}_MatMul"
             )
         else:
             # search in order. one side has sparse indices, so they will not match.
@@ -172,6 +174,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]:
         fixed[name] = value
         remaining.remove(match)

+    logger.debug(
+        "SDXL LoRA key fixup matched %s keys, %s remaining",
+        len(fixed.keys()),
+        len(remaining),
+    )
+
     return fixed
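
Side note on the reformatted next(...) call above: splitting the generator expression across lines is purely a lint fix and does not change behavior. As a minimal standalone sketch (not project code, the list and names below are made up), the pattern returns the first element matching the predicate and raises StopIteration when nothing matches unless a default is supplied:

    remaining = ["up_blocks_0_MatMul", "down_blocks_1_Gemm"]
    root = "up_blocks_0"

    # First element that satisfies the predicate; raises StopIteration if none match.
    match = next(
        node
        for node in remaining
        if node == f"{root}_MatMul"
    )

    # Passing a default avoids the exception when no element matches.
    missing = next((node for node in remaining if node == "other_MatMul"), None)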


@@ -13,8 +13,8 @@ import numpy as np
 import PIL
 import torch
 from diffusers.configuration_utils import FrozenDict
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from diffusers.utils import PIL_INTERPOLATION, deprecate, logging


@@ -19,8 +19,8 @@ import numpy as np
 import PIL
 import torch
 from diffusers.configuration_utils import FrozenDict
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from diffusers.utils import PIL_INTERPOLATION, deprecate, logging


@@ -11,8 +11,8 @@ from typing import Any, Callable, List, Optional, Union
 import numpy as np
 import PIL
 import torch
-from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
 from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
 from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
 from diffusers.schedulers import DDPMScheduler


@@ -0,0 +1,44 @@
import unittest

from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.server.hacks import apply_patches
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models

TEST_MODEL = "../models/correction-gfpgan-v1-3"


class CorrectGFPGANStageTests(unittest.TestCase):
    @test_needs_onnx_models([TEST_MODEL])
    def test_empty(self):
        server = ServerContext(model_path="../models", output_path="../outputs")
        apply_patches(server)

        worker = WorkerContext(
            "test",
            test_device(),
            None,
            None,
            None,
            None,
            None,
            None,
            0,
            0.1,
        )
        stage = CorrectGFPGANStage()
        sources = StageResult.empty()
        result = stage.run(
            worker,
            None,
            None,
            None,
            sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams(TEST_MODEL),
        )

        self.assertEqual(len(result), 0)


@@ -0,0 +1,41 @@
import unittest

from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_bsrgan import UpscaleBSRGANStage
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models

TEST_MODEL = "../models/upscaling-bsrgan-x4"


class UpscaleBSRGANStageTests(unittest.TestCase):
    @test_needs_onnx_models([TEST_MODEL])
    def test_empty(self):
        stage = UpscaleBSRGANStage()
        sources = StageResult.empty()
        result = stage.run(
            WorkerContext(
                "test",
                test_device(),
                None,
                None,
                None,
                None,
                None,
                None,
                3,
                0.1,
            ),
            ServerContext(
                model_path="../models",
            ),
            None,
            None,
            sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams(TEST_MODEL),
        )

        self.assertEqual(len(result), 0)


@@ -0,0 +1,52 @@
import unittest

from PIL import Image

from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_outpaint import UpscaleOutpaintStage
from onnx_web.params import Border, HighresParams, ImageParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_models


class UpscaleOutpaintStageTests(unittest.TestCase):
    @test_needs_models(["../models/stable-diffusion-onnx-v1-inpainting"])
    def test_empty(self):
        stage = UpscaleOutpaintStage()
        sources = StageResult.empty()
        result = stage.run(
            WorkerContext(
                "test",
                test_device(),
                None,
                None,
                None,
                None,
                None,
                None,
                3,
                0.1,
            ),
            ServerContext(
                # model_path="../models",
            ),
            None,
            ImageParams(
                "../models/stable-diffusion-onnx-v1-inpainting",
                "inpaint",
                "euler",
                "test",
                5.0,
                1,
                1,
            ),
            sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams("stable-diffusion-onnx-v1-inpainting"),
            border=Border.even(0),
            dims=(),
            tile_mask=Image.new("RGB", (64, 64)),
        )

        self.assertEqual(len(result), 0)


@@ -0,0 +1,39 @@
import unittest

from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_resrgan import UpscaleRealESRGANStage
from onnx_web.params import HighresParams, StageParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models

TEST_MODEL = "../models/upscaling-real-esrgan-x4-v3"


class UpscaleRealESRGANStageTests(unittest.TestCase):
    @test_needs_onnx_models([TEST_MODEL])
    def test_empty(self):
        stage = UpscaleRealESRGANStage()
        sources = StageResult.empty()
        result = stage.run(
            WorkerContext(
                "test",
                test_device(),
                None,
                None,
                None,
                None,
                None,
                None,
                3,
                0.1,
            ),
            ServerContext(model_path="../models"),
            StageParams(),
            None,
            sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
        )

        self.assertEqual(len(result), 0)


@@ -0,0 +1,41 @@
import unittest

from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_swinir import UpscaleSwinIRStage
from onnx_web.params import HighresParams, UpscaleParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import test_device, test_needs_onnx_models

TEST_MODEL = "../models/upscaling-swinir-real-large-x4"


class UpscaleSwinIRStageTests(unittest.TestCase):
    @test_needs_onnx_models([TEST_MODEL])
    def test_empty(self):
        stage = UpscaleSwinIRStage()
        sources = StageResult.empty()
        result = stage.run(
            WorkerContext(
                "test",
                test_device(),
                None,
                None,
                None,
                None,
                None,
                None,
                3,
                0.1,
            ),
            ServerContext(
                # model_path="../models",
            ),
            None,
            None,
            sources,
            highres=HighresParams(False, 1, 0, 0),
            upscale=UpscaleParams(TEST_MODEL),
        )

        self.assertEqual(len(result), 0)


@@ -11,6 +11,12 @@ def test_needs_models(models: List[str]):
    )


def test_needs_onnx_models(models: List[str]):
    return skipUnless(
        all([path.exists(f"{model}.onnx") for model in models]), "model does not exist"
    )


def test_device() -> DeviceParams:
    return DeviceParams("cpu", "CPUExecutionProvider")
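
For reference, a hypothetical usage sketch of the helpers added above (the test class and model path are illustrative, not taken from this commit): the decorator skips a test when the corresponding .onnx file is missing on disk, and test_device() supplies a CPU-only DeviceParams for the worker.

    import unittest

    from tests.helpers import test_device, test_needs_onnx_models


    class ExampleTests(unittest.TestCase):
        @test_needs_onnx_models(["../models/some-model"])  # hypothetical model path
        def test_with_model(self):
            # Runs only when ../models/some-model.onnx exists on disk.
            device = test_device()
            self.assertIsNotNone(device)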


@@ -63,6 +63,55 @@ class TestTxt2ImgPipeline(unittest.TestCase):
        self.assertEqual(output.size, (256, 256))
        # TODO: test contents of image

    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
    def test_batch(self):
        cancel = Value("L", 0)
        logs = Queue()
        pending = Queue()
        progress = Queue()
        active = Value("L", 0)
        idle = Value("L", 0)

        worker = WorkerContext(
            "test",
            test_device(),
            cancel,
            logs,
            pending,
            progress,
            active,
            idle,
            3,
            0.1,
        )
        worker.start("test")

        run_txt2img_pipeline(
            worker,
            ServerContext(model_path="../models", output_path="../outputs"),
            ImageParams(
                TEST_MODEL_DIFFUSION_SD15,
                "txt2img",
                "ddim",
                "an astronaut eating a hamburger",
                3.0,
                1,
                1,
                batch=2,
            ),
            Size(256, 256),
            ["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
            UpscaleParams("test"),
            HighresParams(False, 1, 0, 0),
        )

        self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
        self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))

        output = Image.open("../outputs/test-txt2img-batch-0.png")
        self.assertEqual(output.size, (256, 256))
        # TODO: test contents of image

    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
    def test_highres(self):
        cancel = Value("L", 0)
@@ -108,6 +157,54 @@ class TestTxt2ImgPipeline(unittest.TestCase):
        output = Image.open("../outputs/test-txt2img-highres.png")
        self.assertEqual(output.size, (512, 512))

    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
    def test_highres_batch(self):
        cancel = Value("L", 0)
        logs = Queue()
        pending = Queue()
        progress = Queue()
        active = Value("L", 0)
        idle = Value("L", 0)

        worker = WorkerContext(
            "test",
            test_device(),
            cancel,
            logs,
            pending,
            progress,
            active,
            idle,
            3,
            0.1,
        )
        worker.start("test")

        run_txt2img_pipeline(
            worker,
            ServerContext(model_path="../models", output_path="../outputs"),
            ImageParams(
                TEST_MODEL_DIFFUSION_SD15,
                "txt2img",
                "ddim",
                "an astronaut eating a hamburger",
                3.0,
                1,
                1,
                batch=2,
            ),
            Size(256, 256),
            ["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
            UpscaleParams("test"),
            HighresParams(True, 2, 0, 0),
        )

        self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
        self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))

        output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
        self.assertEqual(output.size, (512, 512))


class TestImg2ImgPipeline(unittest.TestCase):
    @test_needs_models([TEST_MODEL_DIFFUSION_SD15])


@@ -89,10 +89,8 @@ class TestWorkerPool(unittest.TestCase):
             server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
         )
         self.pool.start(lock)
-        sleep(2.0)
         self.pool.submit("test", test_job)
-        sleep(2.0)
+        sleep(5.0)

         pending, _progress = self.pool.done("test")
         self.assertFalse(pending)
@@ -121,12 +119,10 @@ class TestWorkerPool(unittest.TestCase):
             server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
         )
         self.pool.start()
-        sleep(2.0)
         self.pool.submit("test", wait_job)
         self.assertEqual(self.pool.done("test"), (True, None))
-        sleep(2.0)
+        sleep(5.0)

         pending, _progress = self.pool.done("test")
         self.assertFalse(pending)
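
The longer sleep(5.0) above gives the worker process more time to finish before the test checks pool.done. A hypothetical alternative, not used in this commit, is to poll pool.done(name) with a deadline, relying only on the (pending, progress) return shape visible in these tests:

    from time import monotonic, sleep


    def wait_for_job(pool, name, timeout=5.0, interval=0.1):
        # Poll until the named job is no longer pending or the deadline passes.
        deadline = monotonic() + timeout
        while monotonic() < deadline:
            pending, progress = pool.done(name)
            if not pending:
                return pending, progress
            sleep(interval)

        return pool.done(name)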