1
0
Fork 0

fix(api): be more careful with VAE patch flags, add margin to latents if needed

This commit is contained in:
Sean Sube 2023-11-25 23:18:57 -06:00
parent 57fc183b15
commit 93e3125e28
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 125 additions and 29 deletions

View File

@ -253,10 +253,11 @@ def load_pipeline(
for vae in VAE_COMPONENTS:
if hasattr(pipe, vae):
vae_model = getattr(pipe, vae)
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
if isinstance(vae_model, VAEWrapper):
vae_model.set_tiled(tiled=params.tiled_vae)
vae_model.set_window_size(
params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
# update panorama params
if params.is_panorama():

View File

@ -300,9 +300,7 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape != full_latents.shape and (
tile_latents.shape[2] < t or tile_latents.shape[3] < t
):
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]

View File

@ -1,8 +1,10 @@
from multiprocessing import Queue, Value
from os import path
from typing import List
from unittest import skipUnless
from onnx_web.params import DeviceParams
from onnx_web.worker.context import WorkerContext
def test_needs_models(models: List[str]):
@ -21,6 +23,29 @@ def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider")
def test_worker() -> WorkerContext:
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
return WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
TEST_MODEL_DIFFUSION_SD15_INPAINT = "../models/stable-diffusion-onnx-v1-inpainting"
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"

View File

@ -7,13 +7,29 @@ from PIL import Image
from onnx_web.diffusers.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
run_upscale_pipeline,
)
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams
from onnx_web.image.mask_filter import mask_filter_none
from onnx_web.image.noise_source import noise_source_uniform
from onnx_web.params import (
Border,
HighresParams,
ImageParams,
Size,
TileOrder,
UpscaleParams,
)
from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_DIFFUSION_SD15_INPAINT,
test_device,
test_needs_models,
test_worker,
)
TEST_PROMPT = "an astronaut eating a hamburger"
TEST_SCHEDULER = "ddim"
@ -213,25 +229,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker = test_worker()
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
@ -257,6 +255,80 @@ class TestImg2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_white(self):
worker = test_worker()
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "white")
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img",
TEST_SCHEDULER,
TEST_PROMPT,
3.0,
1,
1,
),
Size(*source.size),
["test-inpaint-white.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
mask,
Border.even(0),
noise_source_uniform,
mask_filter_none,
"white",
TileOrder.spiral,
False,
0.0,
)
self.assertTrue(path.exists("../outputs/test-inpaint-white.png"))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_black(self):
worker = test_worker()
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "black")
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img",
TEST_SCHEDULER,
TEST_PROMPT,
3.0,
1,
1,
),
Size(*source.size),
["test-inpaint-black.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
mask,
Border.even(0),
noise_source_uniform,
mask_filter_none,
"black",
TileOrder.spiral,
False,
0.0,
)
self.assertTrue(path.exists("../outputs/test-inpaint-black.png"))
class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self):