2023-09-14 03:03:39 +00:00
|
|
|
import unittest
|
2023-09-26 02:57:25 +00:00
|
|
|
from os import path
|
2023-12-27 02:25:04 +00:00
|
|
|
from unittest.mock import MagicMock
|
2023-09-14 03:03:39 +00:00
|
|
|
|
2023-09-26 02:57:25 +00:00
|
|
|
import torch
|
2023-12-27 02:21:34 +00:00
|
|
|
from diffusers import DDIMScheduler, OnnxRuntimeModel
|
2023-09-14 03:03:39 +00:00
|
|
|
|
2023-09-14 03:04:31 +00:00
|
|
|
from onnx_web.diffusers.load import (
|
|
|
|
get_available_pipelines,
|
|
|
|
get_pipeline_schedulers,
|
|
|
|
get_scheduler_name,
|
2023-09-26 02:57:25 +00:00
|
|
|
load_controlnet,
|
|
|
|
load_text_encoders,
|
|
|
|
load_unet,
|
|
|
|
load_vae,
|
2023-09-16 00:16:47 +00:00
|
|
|
optimize_pipeline,
|
|
|
|
patch_pipeline,
|
2023-09-14 03:04:31 +00:00
|
|
|
)
|
2023-09-16 00:16:47 +00:00
|
|
|
from onnx_web.diffusers.patches.unet import UNetWrapper
|
|
|
|
from onnx_web.diffusers.patches.vae import VAEWrapper
|
2023-09-28 23:45:04 +00:00
|
|
|
from onnx_web.models.meta import NetworkModel
|
2023-09-26 02:57:25 +00:00
|
|
|
from onnx_web.params import DeviceParams, ImageParams
|
2023-09-16 00:16:47 +00:00
|
|
|
from onnx_web.server.context import ServerContext
|
|
|
|
from tests.mocks import MockPipeline
|
2023-09-14 03:04:31 +00:00
|
|
|
|
2023-09-14 03:03:39 +00:00
|
|
|
|
|
|
|
class TestAvailablePipelines(unittest.TestCase):
    """Checks the collection returned by ``get_available_pipelines``."""

    def test_available_pipelines(self):
        """The ``txt2img`` pipeline should be among the available pipelines."""
        self.assertIn("txt2img", get_available_pipelines())
|
2023-09-14 03:03:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestPipelineSchedulers(unittest.TestCase):
    """Checks the collection returned by ``get_pipeline_schedulers``."""

    def test_pipeline_schedulers(self):
        """The ``euler-a`` scheduler should be among the known schedulers."""
        known = get_pipeline_schedulers()
        expected_name = "euler-a"
        self.assertIn(expected_name, known)
|
2023-09-14 03:03:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestSchedulerNames(unittest.TestCase):
    """Checks the values returned by ``get_scheduler_name``."""

    def test_valid_name(self):
        """``DDIMScheduler`` should be reported under the name ``ddim``."""
        self.assertEqual("ddim", get_scheduler_name(DDIMScheduler))

    def test_missing_names(self):
        """Looking up the string ``test`` should return ``None``."""
        name = get_scheduler_name("test")
        self.assertIsNone(name)
|
2023-09-14 03:03:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestOptimizePipeline(unittest.TestCase):
    """Checks which ``MockPipeline`` attribute each optimization flag sets."""

    def _optimized(self, flag):
        """
        Run ``optimize_pipeline`` on a fresh ``MockPipeline`` using a
        ``ServerContext`` whose only optimization is ``flag``, and return
        the pipeline for inspection.
        """
        context = ServerContext(optimizations=[flag])
        mock = MockPipeline()
        optimize_pipeline(context, mock)
        return mock

    def test_auto_attention_slicing(self):
        """The auto slicing flag should set ``slice_size`` to ``auto``."""
        mock = self._optimized("diffusers-attention-slicing-auto")
        self.assertEqual(mock.slice_size, "auto")

    def test_max_attention_slicing(self):
        """The max slicing flag should set ``slice_size`` to ``max``."""
        mock = self._optimized("diffusers-attention-slicing-max")
        self.assertEqual(mock.slice_size, "max")

    def test_vae_slicing(self):
        """The VAE slicing flag should set ``vae_slicing``."""
        mock = self._optimized("diffusers-vae-slicing")
        self.assertEqual(mock.vae_slicing, True)

    def test_cpu_offload_sequential(self):
        """The sequential offload flag should set ``sequential_offload``."""
        mock = self._optimized("diffusers-cpu-offload-sequential")
        self.assertEqual(mock.sequential_offload, True)

    def test_cpu_offload_model(self):
        """The model offload flag should set ``model_offload``."""
        mock = self._optimized("diffusers-cpu-offload-model")
        self.assertEqual(mock.model_offload, True)

    def test_memory_efficient_attention(self):
        """The memory efficient attention flag should set ``xformers``."""
        mock = self._optimized("diffusers-memory-efficient-attention")
        self.assertEqual(mock.xformers, True)
|
2023-09-14 03:03:39 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestPatchPipeline(unittest.TestCase):
    """Checks the wrappers that ``patch_pipeline`` installs on a pipeline."""

    def _patched_pipeline(self, pipeline_type):
        """
        Build a ``MockPipeline`` whose UNet is an ``OnnxRuntimeModel`` backed
        by a mocked session, run ``patch_pipeline`` on it with ``ImageParams``
        for ``pipeline_type``, and return the pipeline for inspection.
        """
        # the mocked session reports no inputs
        session = MagicMock()
        session.get_inputs.return_value = []

        server = ServerContext()
        pipeline = MockPipeline()
        pipeline.unet = OnnxRuntimeModel(model=session)

        patch_pipeline(
            server,
            pipeline,
            None,
            ImageParams("test", pipeline_type, "ddim", "test", 1.0, 10, 1),
        )
        return pipeline

    def test_expand_not_lpw(self):
        """Disabled test; reported as skipped instead of silently passing."""
        # The body below was previously disabled by wrapping it in a string
        # literal. It is preserved here for whoever restores the test.
        # NOTE(review): `expand_prompt` is not among the imports visible at
        # the top of this file - confirm where it should come from.
        #
        #   server = ServerContext()
        #   pipeline = MockPipeline()
        #   patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
        #   self.assertEqual(pipeline._encode_prompt, expand_prompt)
        self.skipTest("test body is disabled")

    def test_unet_wrapper_not_xl(self):
        """The UNet should be a ``UNetWrapper`` for the ``txt2img`` pipeline."""
        pipeline = self._patched_pipeline("txt2img")
        self.assertIsInstance(pipeline.unet, UNetWrapper)

    def test_unet_wrapper_xl(self):
        """The UNet should be a ``UNetWrapper`` for the ``txt2img-sdxl`` pipeline."""
        pipeline = self._patched_pipeline("txt2img-sdxl")
        self.assertIsInstance(pipeline.unet, UNetWrapper)

    def test_vae_wrapper(self):
        """Both VAE halves should be ``VAEWrapper`` instances after patching."""
        pipeline = self._patched_pipeline("txt2img")
        self.assertIsInstance(pipeline.vae_decoder, VAEWrapper)
        self.assertIsInstance(pipeline.vae_encoder, VAEWrapper)
|
2023-09-26 02:57:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLoadControlNet(unittest.TestCase):
    """Checks ``load_controlnet`` with an existing and a missing model."""

    @unittest.skipUnless(
        path.exists("../models/control/canny.onnx"), "model does not exist"
    )
    def test_load_existing(self):
        """
        Should load a model

        Skipped unless the canny control model exists on disk.
        """
        components = load_controlnet(
            ServerContext(model_path="../models"),
            DeviceParams("cpu", "CPUExecutionProvider"),
            ImageParams(
                "test",
                "txt2img",
                "ddim",
                "test",
                1.0,
                10,
                1,
                control=NetworkModel("canny", "control"),
            ),
        )
        self.assertIn("controlnet", components)

    def test_load_missing(self):
        """
        Should throw

        Fails if ``load_controlnet`` returns normally for the control model
        named ``missing``.
        """
        # NOTE(review): the broad ``Exception`` matches what this test has
        # always accepted; narrow it once the concrete type is confirmed.
        with self.assertRaises(Exception):
            load_controlnet(
                ServerContext(),
                DeviceParams("cpu", "CPUExecutionProvider"),
                ImageParams(
                    "test",
                    "txt2img",
                    "ddim",
                    "test",
                    1.0,
                    10,
                    1,
                    control=NetworkModel("missing", "control"),
                ),
            )
|
2023-09-26 02:57:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLoadTextEncoders(unittest.TestCase):
    """Checks ``load_text_encoders`` against the SD v1.5 ONNX model on disk."""

    @unittest.skipUnless(
        path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
        "model does not exist",
    )
    def test_load_embeddings(self):
        """
        Intended to cover loading embeddings (token added to the tokenizer,
        encoder dims increased). No embeddings are supplied yet, so only the
        presence of the ``text_encoder`` component is asserted.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        embeddings = [
            # TODO: add some embeddings
        ]
        loras = []
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_text_encoders(
            server,
            device,
            "../models/stable-diffusion-onnx-v1-5",
            embeddings,
            loras,
            torch.float32,
            params,
        )
        self.assertIn("text_encoder", loaded)

    def test_load_embeddings_xl(self):
        """Placeholder with no assertions yet."""
        pass

    @unittest.skipUnless(
        path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"),
        "model does not exist",
    )
    def test_load_loras(self):
        """
        Intended to cover loading LoRAs. None are supplied yet, so only the
        presence of the ``text_encoder`` component is asserted.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        embeddings = []
        loras = [
            # TODO: add some loras
        ]
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_text_encoders(
            server,
            device,
            "../models/stable-diffusion-onnx-v1-5",
            embeddings,
            loras,
            torch.float32,
            params,
        )
        self.assertIn("text_encoder", loaded)

    def test_load_loras_xl(self):
        """Placeholder with no assertions yet."""
        pass
|
2023-09-26 02:57:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLoadUnet(unittest.TestCase):
    """Checks ``load_unet`` for the ``unet`` and ``cnet`` model directories."""

    @unittest.skipUnless(
        path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"),
        "model does not exist",
    )
    def test_load_unet_loras(self):
        """
        Load the ``unet`` variant with an empty LoRA list and expect a
        ``unet`` component in the result.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        loras = [
            # TODO: add some loras
        ]
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_unet(
            server,
            device,
            "../models/stable-diffusion-onnx-v1-5",
            loras,
            "unet",
            params,
        )
        self.assertIn("unet", loaded)

    def test_load_unet_loras_xl(self):
        """Placeholder with no assertions yet."""
        pass

    @unittest.skipUnless(
        path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"),
        "model does not exist",
    )
    def test_load_cnet_loras(self):
        """
        Load the ``cnet`` variant with an empty LoRA list and expect the
        result to still carry a ``unet`` component.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        loras = [
            # TODO: add some loras
        ]
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_unet(
            server,
            device,
            "../models/stable-diffusion-onnx-v1-5",
            loras,
            "cnet",
            params,
        )
        self.assertIn("unet", loaded)
|
2023-09-26 02:57:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class TestLoadVae(unittest.TestCase):
    """Checks which VAE components ``load_vae`` returns for two model layouts."""

    @unittest.skipUnless(
        path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"),
        "model does not exist",
    )
    def test_load_single(self):
        """
        Should return a single ``vae`` component and no split
        encoder/decoder entries.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_vae(
            server,
            device,
            "../models/upscaling-stable-diffusion-x4",
            params,
        )
        self.assertIn("vae", loaded)
        self.assertNotIn("vae_decoder", loaded)
        self.assertNotIn("vae_encoder", loaded)

    @unittest.skipUnless(
        path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"),
        "model does not exist",
    )
    def test_load_split(self):
        """
        Should return split ``vae_decoder``/``vae_encoder`` components and
        no combined ``vae`` entry.
        """
        server = ServerContext(model_path="../models")
        device = DeviceParams("cpu", "CPUExecutionProvider")
        params = ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)

        loaded = load_vae(
            server,
            device,
            "../models/stable-diffusion-onnx-v1-5",
            params,
        )
        self.assertNotIn("vae", loaded)
        self.assertIn("vae_decoder", loaded)
        self.assertIn("vae_encoder", loaded)
|