start testing pipeline loading
This commit is contained in:
parent
fc02fa6be1
commit
23aa00d696
|
@ -169,14 +169,15 @@ def load_pipeline(
|
|||
run_gc([device])
|
||||
|
||||
logger.debug("loading new diffusion pipeline from %s", model)
|
||||
scheduler = scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
components = {
|
||||
"scheduler": scheduler_type.from_pretrained(
|
||||
model,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
"scheduler": scheduler,
|
||||
}
|
||||
|
||||
# shared components
|
||||
|
@ -257,7 +258,7 @@ def load_pipeline(
|
|||
patch_pipeline(server, pipe, pipeline_class, params)
|
||||
|
||||
server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
|
||||
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"])
|
||||
server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
|
||||
|
||||
for vae in VAE_COMPONENTS:
|
||||
if hasattr(pipe, vae):
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
from typing import List
|
||||
|
||||
|
||||
def test_with_models(models: List[str]):
|
||||
def wrapper(func):
|
||||
# TODO: check if models exist
|
||||
return func
|
||||
|
||||
return wrapper
|
|
@ -1,18 +1,26 @@
|
|||
import unittest
|
||||
from os import path
|
||||
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
from onnx_web.diffusers.load import (
|
||||
get_available_pipelines,
|
||||
get_pipeline_schedulers,
|
||||
get_scheduler_name,
|
||||
load_controlnet,
|
||||
load_text_encoders,
|
||||
load_unet,
|
||||
load_vae,
|
||||
optimize_pipeline,
|
||||
patch_pipeline,
|
||||
)
|
||||
from onnx_web.diffusers.patches.unet import UNetWrapper
|
||||
from onnx_web.diffusers.patches.vae import VAEWrapper
|
||||
from onnx_web.params import ImageParams
|
||||
from onnx_web.models.meta import NetworkModel, NetworkType
|
||||
from onnx_web.params import DeviceParams, ImageParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from tests.helpers import test_with_models
|
||||
from tests.mocks import MockPipeline
|
||||
|
||||
|
||||
|
@ -130,3 +138,140 @@ class TestPatchPipeline(unittest.TestCase):
|
|||
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
|
||||
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
|
||||
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
|
||||
|
||||
|
||||
class TestLoadControlNet(unittest.TestCase):
|
||||
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
|
||||
def test_load_existing(self):
|
||||
"""
|
||||
Should load a model
|
||||
"""
|
||||
components = load_controlnet(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
|
||||
)
|
||||
self.assertIn("controlnet", components)
|
||||
|
||||
def test_load_missing(self):
|
||||
"""
|
||||
Should throw
|
||||
"""
|
||||
components = {}
|
||||
try:
|
||||
components = load_controlnet(
|
||||
ServerContext(),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
|
||||
)
|
||||
except:
|
||||
self.assertNotIn("controlnet", components)
|
||||
return
|
||||
|
||||
self.fail()
|
||||
|
||||
|
||||
class TestLoadTextEncoders(unittest.TestCase):
|
||||
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
||||
def test_load_embeddings(self):
|
||||
"""
|
||||
Should add the token to tokenizer
|
||||
Should increase the encoder dims
|
||||
"""
|
||||
components = load_text_encoders(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/stable-diffusion-onnx-v1-5",
|
||||
[
|
||||
# TODO: add some embeddings
|
||||
],
|
||||
[],
|
||||
torch.float32,
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertIn("text_encoder", components)
|
||||
|
||||
def test_load_embeddings_xl(self):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
|
||||
def test_load_loras(self):
|
||||
components = load_text_encoders(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/stable-diffusion-onnx-v1-5",
|
||||
[],
|
||||
[
|
||||
# TODO: add some loras
|
||||
],
|
||||
torch.float32,
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertIn("text_encoder", components)
|
||||
|
||||
def test_load_loras_xl(self):
|
||||
pass
|
||||
|
||||
class TestLoadUnet(unittest.TestCase):
|
||||
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
|
||||
def test_load_unet_loras(self):
|
||||
components = load_unet(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/stable-diffusion-onnx-v1-5",
|
||||
[
|
||||
# TODO: add some loras
|
||||
],
|
||||
"unet",
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertIn("unet", components)
|
||||
|
||||
def test_load_unet_loras_xl(self):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
|
||||
def test_load_cnet_loras(self):
|
||||
components = load_unet(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/stable-diffusion-onnx-v1-5",
|
||||
[
|
||||
# TODO: add some loras
|
||||
],
|
||||
"cnet",
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertIn("unet", components)
|
||||
|
||||
|
||||
class TestLoadVae(unittest.TestCase):
|
||||
@unittest.skipUnless(lambda: path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
|
||||
def test_load_single(self):
|
||||
"""
|
||||
Should return single component
|
||||
"""
|
||||
components = load_vae(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/upscaling-stable-diffusion-x4",
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertIn("vae", components)
|
||||
self.assertNotIn("vae_decoder", components)
|
||||
self.assertNotIn("vae_encoder", components)
|
||||
|
||||
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
|
||||
def test_load_split(self):
|
||||
"""
|
||||
Should return split encoder/decoder
|
||||
"""
|
||||
components = load_vae(
|
||||
ServerContext(model_path="../models"),
|
||||
DeviceParams("cpu", "CPUExecutionProvider"),
|
||||
"../models/stable-diffusion-onnx-v1-5",
|
||||
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
|
||||
)
|
||||
self.assertNotIn("vae", components)
|
||||
self.assertIn("vae_decoder", components)
|
||||
self.assertIn("vae_encoder", components)
|
||||
|
|
|
@ -83,6 +83,7 @@
|
|||
"scandir",
|
||||
"scipy",
|
||||
"scrollback",
|
||||
"sdxl",
|
||||
"sess",
|
||||
"Singlestep",
|
||||
"spacy",
|
||||
|
|
Loading…
Reference in New Issue