more tests, apply lint
This commit is contained in:
parent
66dfa7206a
commit
f00bfe9bd0
|
@ -139,7 +139,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
)
|
)
|
||||||
elif block == "text_model" or simple:
|
elif block == "text_model" or simple:
|
||||||
match = next(
|
match = next(
|
||||||
node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul"
|
node
|
||||||
|
for node in remaining
|
||||||
|
if fix_node_name(node.name) == f"{root}_MatMul"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# search in order. one side has sparse indices, so they will not match.
|
# search in order. one side has sparse indices, so they will not match.
|
||||||
|
@ -172,6 +174,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
||||||
fixed[name] = value
|
fixed[name] = value
|
||||||
remaining.remove(match)
|
remaining.remove(match)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"SDXL LoRA key fixup matched %s keys, %s remaining",
|
||||||
|
len(fixed.keys()),
|
||||||
|
len(remaining),
|
||||||
|
)
|
||||||
|
|
||||||
return fixed
|
return fixed
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,8 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||||
|
|
|
@ -19,8 +19,8 @@ import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||||
|
|
|
@ -11,8 +11,8 @@ from typing import Any, Callable, List, Optional, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL
|
import PIL
|
||||||
import torch
|
import torch
|
||||||
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
|
|
||||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||||
|
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
|
||||||
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
|
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
|
||||||
from diffusers.schedulers import DDPMScheduler
|
from diffusers.schedulers import DDPMScheduler
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,44 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.server.hacks import apply_patches
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||||||
|
|
||||||
|
TEST_MODEL = "../models/correction-gfpgan-v1-3"
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectGFPGANStageTests(unittest.TestCase):
|
||||||
|
@test_needs_onnx_models([TEST_MODEL])
|
||||||
|
def test_empty(self):
|
||||||
|
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||||
|
apply_patches(server)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
stage = CorrectGFPGANStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
worker,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,41 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.chain.upscale_bsrgan import UpscaleBSRGANStage
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||||||
|
|
||||||
|
TEST_MODEL = "../models/upscaling-bsrgan-x4"
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleBSRGANStageTests(unittest.TestCase):
|
||||||
|
@test_needs_onnx_models([TEST_MODEL])
|
||||||
|
def test_empty(self):
|
||||||
|
stage = UpscaleBSRGANStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
),
|
||||||
|
ServerContext(
|
||||||
|
model_path="../models",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,52 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.chain.upscale_outpaint import UpscaleOutpaintStage
|
||||||
|
from onnx_web.params import Border, HighresParams, ImageParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_models
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleOutpaintStageTests(unittest.TestCase):
|
||||||
|
@test_needs_models(["../models/stable-diffusion-onnx-v1-inpainting"])
|
||||||
|
def test_empty(self):
|
||||||
|
stage = UpscaleOutpaintStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
),
|
||||||
|
ServerContext(
|
||||||
|
# model_path="../models",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
ImageParams(
|
||||||
|
"../models/stable-diffusion-onnx-v1-inpainting",
|
||||||
|
"inpaint",
|
||||||
|
"euler",
|
||||||
|
"test",
|
||||||
|
5.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams("stable-diffusion-onnx-v1-inpainting"),
|
||||||
|
border=Border.even(0),
|
||||||
|
dims=(),
|
||||||
|
tile_mask=Image.new("RGB", (64, 64)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,39 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.chain.upscale_resrgan import UpscaleRealESRGANStage
|
||||||
|
from onnx_web.params import HighresParams, StageParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||||||
|
|
||||||
|
TEST_MODEL = "../models/upscaling-real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleRealESRGANStageTests(unittest.TestCase):
|
||||||
|
@test_needs_onnx_models([TEST_MODEL])
|
||||||
|
def test_empty(self):
|
||||||
|
stage = UpscaleRealESRGANStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
),
|
||||||
|
ServerContext(model_path="../models"),
|
||||||
|
StageParams(),
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,41 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
from onnx_web.chain.upscale_swinir import UpscaleSwinIRStage
|
||||||
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
from tests.helpers import test_device, test_needs_onnx_models
|
||||||
|
|
||||||
|
TEST_MODEL = "../models/upscaling-swinir-real-large-x4"
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleSwinIRStageTests(unittest.TestCase):
|
||||||
|
@test_needs_onnx_models([TEST_MODEL])
|
||||||
|
def test_empty(self):
|
||||||
|
stage = UpscaleSwinIRStage()
|
||||||
|
sources = StageResult.empty()
|
||||||
|
result = stage.run(
|
||||||
|
WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
),
|
||||||
|
ServerContext(
|
||||||
|
# model_path="../models",
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
sources,
|
||||||
|
highres=HighresParams(False, 1, 0, 0),
|
||||||
|
upscale=UpscaleParams(TEST_MODEL),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
|
@ -11,6 +11,12 @@ def test_needs_models(models: List[str]):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_onnx_models(models: List[str]):
|
||||||
|
return skipUnless(
|
||||||
|
all([path.exists(f"{model}.onnx") for model in models]), "model does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_device() -> DeviceParams:
|
def test_device() -> DeviceParams:
|
||||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,55 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
self.assertEqual(output.size, (256, 256))
|
self.assertEqual(output.size, (256, 256))
|
||||||
# TODO: test contents of image
|
# TODO: test contents of image
|
||||||
|
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_batch(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
run_txt2img_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
"txt2img",
|
||||||
|
"ddim",
|
||||||
|
"an astronaut eating a hamburger",
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
batch=2,
|
||||||
|
),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
|
||||||
|
|
||||||
|
output = Image.open("../outputs/test-txt2img-batch-0.png")
|
||||||
|
self.assertEqual(output.size, (256, 256))
|
||||||
|
# TODO: test contents of image
|
||||||
|
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_highres(self):
|
def test_highres(self):
|
||||||
cancel = Value("L", 0)
|
cancel = Value("L", 0)
|
||||||
|
@ -108,6 +157,54 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
output = Image.open("../outputs/test-txt2img-highres.png")
|
output = Image.open("../outputs/test-txt2img-highres.png")
|
||||||
self.assertEqual(output.size, (512, 512))
|
self.assertEqual(output.size, (512, 512))
|
||||||
|
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
def test_highres_batch(self):
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
worker = WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
run_txt2img_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
"txt2img",
|
||||||
|
"ddim",
|
||||||
|
"an astronaut eating a hamburger",
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
batch=2,
|
||||||
|
),
|
||||||
|
Size(256, 256),
|
||||||
|
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(True, 2, 0, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
|
||||||
|
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
|
||||||
|
|
||||||
|
output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
|
||||||
|
self.assertEqual(output.size, (512, 512))
|
||||||
|
|
||||||
|
|
||||||
class TestImg2ImgPipeline(unittest.TestCase):
|
class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
|
|
|
@ -89,10 +89,8 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
)
|
)
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
sleep(2.0)
|
|
||||||
|
|
||||||
self.pool.submit("test", test_job)
|
self.pool.submit("test", test_job)
|
||||||
sleep(2.0)
|
sleep(5.0)
|
||||||
|
|
||||||
pending, _progress = self.pool.done("test")
|
pending, _progress = self.pool.done("test")
|
||||||
self.assertFalse(pending)
|
self.assertFalse(pending)
|
||||||
|
@ -121,12 +119,10 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
)
|
)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
sleep(2.0)
|
|
||||||
|
|
||||||
self.pool.submit("test", wait_job)
|
self.pool.submit("test", wait_job)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.done("test"), (True, None))
|
||||||
|
|
||||||
sleep(2.0)
|
sleep(5.0)
|
||||||
pending, _progress = self.pool.done("test")
|
pending, _progress = self.pool.done("test")
|
||||||
self.assertFalse(pending)
|
self.assertFalse(pending)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue