1
0
Fork 0

start testing pipeline loading

This commit is contained in:
Sean Sube 2023-09-25 21:57:25 -05:00
parent fc02fa6be1
commit 23aa00d696
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 165 additions and 9 deletions

View File

@ -169,14 +169,15 @@ def load_pipeline(
run_gc([device]) run_gc([device])
logger.debug("loading new diffusion pipeline from %s", model) logger.debug("loading new diffusion pipeline from %s", model)
scheduler = scheduler_type.from_pretrained(
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch_dtype,
)
components = { components = {
"scheduler": scheduler_type.from_pretrained( "scheduler": scheduler,
model,
provider=device.ort_provider(),
sess_options=device.sess_options(),
subfolder="scheduler",
torch_dtype=torch_dtype,
)
} }
# shared components # shared components
@ -257,7 +258,7 @@ def load_pipeline(
patch_pipeline(server, pipe, pipeline_class, params) patch_pipeline(server, pipe, pipeline_class, params)
server.cache.set(ModelTypes.diffusion, pipe_key, pipe) server.cache.set(ModelTypes.diffusion, pipe_key, pipe)
server.cache.set(ModelTypes.scheduler, scheduler_key, components["scheduler"]) server.cache.set(ModelTypes.scheduler, scheduler_key, scheduler)
for vae in VAE_COMPONENTS: for vae in VAE_COMPONENTS:
if hasattr(pipe, vae): if hasattr(pipe, vae):

9
api/tests/helpers.py Normal file
View File

@ -0,0 +1,9 @@
from typing import List
def test_with_models(models: List[str]):
def wrapper(func):
# TODO: check if models exist
return func
return wrapper

View File

@ -1,18 +1,26 @@
import unittest import unittest
from os import path
import torch
from diffusers import DDIMScheduler from diffusers import DDIMScheduler
from onnx_web.diffusers.load import ( from onnx_web.diffusers.load import (
get_available_pipelines, get_available_pipelines,
get_pipeline_schedulers, get_pipeline_schedulers,
get_scheduler_name, get_scheduler_name,
load_controlnet,
load_text_encoders,
load_unet,
load_vae,
optimize_pipeline, optimize_pipeline,
patch_pipeline, patch_pipeline,
) )
from onnx_web.diffusers.patches.unet import UNetWrapper from onnx_web.diffusers.patches.unet import UNetWrapper
from onnx_web.diffusers.patches.vae import VAEWrapper from onnx_web.diffusers.patches.vae import VAEWrapper
from onnx_web.params import ImageParams from onnx_web.models.meta import NetworkModel, NetworkType
from onnx_web.params import DeviceParams, ImageParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from tests.helpers import test_with_models
from tests.mocks import MockPipeline from tests.mocks import MockPipeline
@ -130,3 +138,140 @@ class TestPatchPipeline(unittest.TestCase):
patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1)) patch_pipeline(server, pipeline, None, ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1))
self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper)) self.assertTrue(isinstance(pipeline.vae_decoder, VAEWrapper))
self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper)) self.assertTrue(isinstance(pipeline.vae_encoder, VAEWrapper))
class TestLoadControlNet(unittest.TestCase):
@unittest.skipUnless(path.exists("../models/control/canny.onnx"), "model does not exist")
def test_load_existing(self):
"""
Should load a model
"""
components = load_controlnet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("canny", "control")),
)
self.assertIn("controlnet", components)
def test_load_missing(self):
"""
Should throw
"""
components = {}
try:
components = load_controlnet(
ServerContext(),
DeviceParams("cpu", "CPUExecutionProvider"),
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1, control=NetworkModel("missing", "control")),
)
except:
self.assertNotIn("controlnet", components)
return
self.fail()
class TestLoadTextEncoders(unittest.TestCase):
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
def test_load_embeddings(self):
"""
Should add the token to tokenizer
Should increase the encoder dims
"""
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some embeddings
],
[],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_embeddings_xl(self):
pass
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), "model does not exist")
def test_load_loras(self):
components = load_text_encoders(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[],
[
# TODO: add some loras
],
torch.float32,
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("text_encoder", components)
def test_load_loras_xl(self):
pass
class TestLoadUnet(unittest.TestCase):
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/unet/model.onnx"), "model does not exist")
def test_load_unet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"unet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
def test_load_unet_loras_xl(self):
pass
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/cnet/model.onnx"), "model does not exist")
def test_load_cnet_loras(self):
components = load_unet(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
[
# TODO: add some loras
],
"cnet",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("unet", components)
class TestLoadVae(unittest.TestCase):
@unittest.skipUnless(lambda: path.exists("../models/upscaling-stable-diffusion-x4/vae/model.onnx"), "model does not exist")
def test_load_single(self):
"""
Should return single component
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/upscaling-stable-diffusion-x4",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertIn("vae", components)
self.assertNotIn("vae_decoder", components)
self.assertNotIn("vae_encoder", components)
@unittest.skipUnless(lambda: path.exists("../models/stable-diffusion-onnx-v1-5/vae_encoder/model.onnx"), "model does not exist")
def test_load_split(self):
"""
Should return split encoder/decoder
"""
components = load_vae(
ServerContext(model_path="../models"),
DeviceParams("cpu", "CPUExecutionProvider"),
"../models/stable-diffusion-onnx-v1-5",
ImageParams("test", "txt2img", "ddim", "test", 1.0, 10, 1),
)
self.assertNotIn("vae", components)
self.assertIn("vae_decoder", components)
self.assertIn("vae_encoder", components)

View File

@ -83,6 +83,7 @@
"scandir", "scandir",
"scipy", "scipy",
"scrollback", "scrollback",
"sdxl",
"sess", "sess",
"Singlestep", "Singlestep",
"spacy", "spacy",