more tests, apply lint
This commit is contained in:
parent
66dfa7206a
commit
f00bfe9bd0
|
@ -139,7 +139,9 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
)
|
||||
elif block == "text_model" or simple:
|
||||
match = next(
|
||||
node for node in remaining if fix_node_name(node.name) == f"{root}_MatMul"
|
||||
node
|
||||
for node in remaining
|
||||
if fix_node_name(node.name) == f"{root}_MatMul"
|
||||
)
|
||||
else:
|
||||
# search in order. one side has sparse indices, so they will not match.
|
||||
|
@ -172,6 +174,12 @@ def fix_xl_names(keys: Dict[str, Any], nodes: List[NodeProto]) -> Dict[str, Any]
|
|||
fixed[name] = value
|
||||
remaining.remove(match)
|
||||
|
||||
logger.debug(
|
||||
"SDXL LoRA key fixup matched %s keys, %s remaining",
|
||||
len(fixed.keys()),
|
||||
len(remaining),
|
||||
)
|
||||
|
||||
return fixed
|
||||
|
||||
|
||||
|
|
|
@ -13,8 +13,8 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
|
|
|
@ -19,8 +19,8 @@ import numpy as np
|
|||
import PIL
|
||||
import torch
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
|
||||
|
|
|
@ -11,8 +11,8 @@ from typing import Any, Callable, List, Optional, Union
|
|||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionUpscalePipeline
|
||||
from diffusers.schedulers import DDPMScheduler
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.correct_gfpgan import CorrectGFPGANStage
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.params import HighresParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.server.hacks import apply_patches
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_onnx_models
|
||||
|
||||
TEST_MODEL = "../models/correction-gfpgan-v1-3"
|
||||
|
||||
|
||||
class CorrectGFPGANStageTests(unittest.TestCase):
|
||||
@test_needs_onnx_models([TEST_MODEL])
|
||||
def test_empty(self):
|
||||
server = ServerContext(model_path="../models", output_path="../outputs")
|
||||
apply_patches(server)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0,
|
||||
0.1,
|
||||
)
|
||||
stage = CorrectGFPGANStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
worker,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.upscale_bsrgan import UpscaleBSRGANStage
|
||||
from onnx_web.params import HighresParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_onnx_models
|
||||
|
||||
TEST_MODEL = "../models/upscaling-bsrgan-x4"
|
||||
|
||||
|
||||
class UpscaleBSRGANStageTests(unittest.TestCase):
|
||||
@test_needs_onnx_models([TEST_MODEL])
|
||||
def test_empty(self):
|
||||
stage = UpscaleBSRGANStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
3,
|
||||
0.1,
|
||||
),
|
||||
ServerContext(
|
||||
model_path="../models",
|
||||
),
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,52 @@
|
|||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.upscale_outpaint import UpscaleOutpaintStage
|
||||
from onnx_web.params import Border, HighresParams, ImageParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_models
|
||||
|
||||
|
||||
class UpscaleOutpaintStageTests(unittest.TestCase):
|
||||
@test_needs_models(["../models/stable-diffusion-onnx-v1-inpainting"])
|
||||
def test_empty(self):
|
||||
stage = UpscaleOutpaintStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
3,
|
||||
0.1,
|
||||
),
|
||||
ServerContext(
|
||||
# model_path="../models",
|
||||
),
|
||||
None,
|
||||
ImageParams(
|
||||
"../models/stable-diffusion-onnx-v1-inpainting",
|
||||
"inpaint",
|
||||
"euler",
|
||||
"test",
|
||||
5.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams("stable-diffusion-onnx-v1-inpainting"),
|
||||
border=Border.even(0),
|
||||
dims=(),
|
||||
tile_mask=Image.new("RGB", (64, 64)),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,39 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.upscale_resrgan import UpscaleRealESRGANStage
|
||||
from onnx_web.params import HighresParams, StageParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_onnx_models
|
||||
|
||||
TEST_MODEL = "../models/upscaling-real-esrgan-x4-v3"
|
||||
|
||||
|
||||
class UpscaleRealESRGANStageTests(unittest.TestCase):
|
||||
@test_needs_onnx_models([TEST_MODEL])
|
||||
def test_empty(self):
|
||||
stage = UpscaleRealESRGANStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
3,
|
||||
0.1,
|
||||
),
|
||||
ServerContext(model_path="../models"),
|
||||
StageParams(),
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams("upscaling-real-esrgan-x4-v3"),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -0,0 +1,41 @@
|
|||
import unittest
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.upscale_swinir import UpscaleSwinIRStage
|
||||
from onnx_web.params import HighresParams, UpscaleParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import test_device, test_needs_onnx_models
|
||||
|
||||
TEST_MODEL = "../models/upscaling-swinir-real-large-x4"
|
||||
|
||||
|
||||
class UpscaleSwinIRStageTests(unittest.TestCase):
|
||||
@test_needs_onnx_models([TEST_MODEL])
|
||||
def test_empty(self):
|
||||
stage = UpscaleSwinIRStage()
|
||||
sources = StageResult.empty()
|
||||
result = stage.run(
|
||||
WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
3,
|
||||
0.1,
|
||||
),
|
||||
ServerContext(
|
||||
# model_path="../models",
|
||||
),
|
||||
None,
|
||||
None,
|
||||
sources,
|
||||
highres=HighresParams(False, 1, 0, 0),
|
||||
upscale=UpscaleParams(TEST_MODEL),
|
||||
)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
|
@ -11,6 +11,12 @@ def test_needs_models(models: List[str]):
|
|||
)
|
||||
|
||||
|
||||
def test_needs_onnx_models(models: List[str]):
|
||||
return skipUnless(
|
||||
all([path.exists(f"{model}.onnx") for model in models]), "model does not exist"
|
||||
)
|
||||
|
||||
|
||||
def test_device() -> DeviceParams:
|
||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||
|
||||
|
|
|
@ -63,6 +63,55 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
self.assertEqual(output.size, (256, 256))
|
||||
# TODO: test contents of image
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_batch(self):
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
active = Value("L", 0)
|
||||
idle = Value("L", 0)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
active,
|
||||
idle,
|
||||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
batch=2,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-0.png"))
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-batch-1.png"))
|
||||
|
||||
output = Image.open("../outputs/test-txt2img-batch-0.png")
|
||||
self.assertEqual(output.size, (256, 256))
|
||||
# TODO: test contents of image
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_highres(self):
|
||||
cancel = Value("L", 0)
|
||||
|
@ -108,6 +157,54 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
output = Image.open("../outputs/test-txt2img-highres.png")
|
||||
self.assertEqual(output.size, (512, 512))
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_highres_batch(self):
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
active = Value("L", 0)
|
||||
idle = Value("L", 0)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
active,
|
||||
idle,
|
||||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
"txt2img",
|
||||
"ddim",
|
||||
"an astronaut eating a hamburger",
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
batch=2,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(True, 2, 0, 0),
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-0.png"))
|
||||
self.assertTrue(path.exists("../outputs/test-txt2img-highres-batch-1.png"))
|
||||
|
||||
output = Image.open("../outputs/test-txt2img-highres-batch-0.png")
|
||||
self.assertEqual(output.size, (512, 512))
|
||||
|
||||
|
||||
class TestImg2ImgPipeline(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
|
|
|
@ -89,10 +89,8 @@ class TestWorkerPool(unittest.TestCase):
|
|||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start(lock)
|
||||
sleep(2.0)
|
||||
|
||||
self.pool.submit("test", test_job)
|
||||
sleep(2.0)
|
||||
sleep(5.0)
|
||||
|
||||
pending, _progress = self.pool.done("test")
|
||||
self.assertFalse(pending)
|
||||
|
@ -121,12 +119,10 @@ class TestWorkerPool(unittest.TestCase):
|
|||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
sleep(2.0)
|
||||
|
||||
self.pool.submit("test", wait_job)
|
||||
self.assertEqual(self.pool.done("test"), (True, None))
|
||||
|
||||
sleep(2.0)
|
||||
sleep(5.0)
|
||||
pending, _progress = self.pool.done("test")
|
||||
self.assertFalse(pending)
|
||||
|
||||
|
|
Loading…
Reference in New Issue