diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index eeea47d4..38b61e50 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -169,14 +169,15 @@ def load_pipeline( run_gc([device]) logger.debug("loading new diffusion pipeline from %s", model) + scheduler = scheduler_type.from_pretrained( + model, + provider=device.ort_provider(), + sess_options=device.sess_options(), + subfolder="scheduler", + torch_dtype=torch_dtype, + ) components = { - "scheduler": scheduler_type.from_pretrained( - model, - provider=device.ort_provider(), - sess_options=device.sess_options(), - subfolder="scheduler", - torch_dtype=torch_dtype, - ) + "scheduler": scheduler, } # shared components @@ -257,7 +258,7 @@ def load_pipeline( patch_pipeline(server, pipe, pipeline_class, params) server.cache.set(ModelTypes.diffusion, pipe_key, pipe) - server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"]) + server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler) for vae in VAE_COMPONENTS: if hasattr(pipe, vae): diff --git a/api/tests/helpers.py b/api/tests/helpers.py new file mode 100644 index 00000000..3fc5cc7b --- /dev/null +++ b/api/tests/helpers.py @@ -0,0 +1,9 @@ +from typing import List + + +def test_with_models(models: List[str]): + def wrapper(func): + # TODO: check if models exist + return func + + return wrapper diff --git a/api/tests/test_diffusers/test_load.py b/api/tests/test_diffusers/test_load.py index beaab12c..7b811497 100644 --- a/api/tests/test_diffusers/test_load.py +++ b/api/tests/test_diffusers/test_load.py @@ -1,18 +1,26 @@ import unittest +from os import path +import torch from diffusers import DDIMScheduler from onnx_web.diffusers.load import ( get_available_pipelines, get_pipeline_schedulers, get_scheduler_name, + load_controlnet, + load_text_encoders, + load_unet, + load_vae, optimize_pipeline, patch_pipeline, ) from onnx_web.diffusers.patches.unet import UNetWrapper from 
onnx_web.diffusers.patches.vae import VAEWrapper -from onnx_web.params import ImageParams +from onnx_web.models.meta import NetworkModel, NetworkType +from onnx_web.params import DeviceParams, ImageParams from onnx_web.server.context import ServerContext +from tests.helpers import test_with_models from tests.mocks import MockPipeline @@ -130,3 +138,140 @@ class TestPatchPipeline(unittest.TestCase): patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) + + +class TestLoadControlNet(unittest.TestCase): + @unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist") + def test_load_existing(self): + """ + Should load a model + """ + components = load_controlnet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")), + ) + self.assertIn("controlnet", components) + + def test_load_missing(self): + """ + Should throw + """ + components = {} + try: + components = load_controlnet( + ServerContext(), + DeviceParams("cpu", "CPUExecutionProvider"), + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")), + ) + except Exception: + self.assertNotIn("controlnet", components) + return + + self.fail() + + +class TestLoadTextEncoders(unittest.TestCase): + @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") + def test_load_embeddings(self): + """ + Should add the token to tokenizer + Should increase the encoder dims + """ + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some embeddings + ], + [], + 
torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) + + def test_load_embeddings_xl(self): + pass + + @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist") + def test_load_loras(self): + components = load_text_encoders( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [], + [ + # TODO: add some loras + ], + torch.float32, + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("text_encoder", components) + + def test_load_loras_xl(self): + pass + +class TestLoadUnet(unittest.TestCase): + @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist") + def test_load_unet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "unet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) + + def test_load_unet_loras_xl(self): + pass + + @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist") + def test_load_cnet_loras(self): + components = load_unet( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + [ + # TODO: add some loras + ], + "cnet", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("unet", components) + + +class TestLoadVae(unittest.TestCase): + @unittest.skipUnless(path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist") + def test_load_single(self): + """ + Should return single component + """ + components = load_vae( + 
ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/upscaling-stable-diffusion-x4", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertIn("vae", components) + self.assertNotIn("vae_decoder", components) + self.assertNotIn("vae_encoder", components) + + @unittest.skipUnless(path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist") + def test_load_split(self): + """ + Should return split encoder/decoder + """ + components = load_vae( + ServerContext(model_path="../models"), + DeviceParams("cpu", "CPUExecutionProvider"), + "../models/stable-diffusion-onnx-v1-5", + ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1), + ) + self.assertNotIn("vae", components) + self.assertIn("vae_decoder", components) + self.assertIn("vae_encoder", components) diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 08f791d9..d3965a5f 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -83,6 +83,7 @@ "scandir", "scipy", "scrollback", + "sdxl", "sess", "Singlestep", "spacy",