1
0
Fork 0

fix(api): be more careful with VAE patch flags, add margin to latents if needed

This commit is contained in:
Sean Sube 2023-11-25 23:18:57 -06:00
parent 57fc183b15
commit 93e3125e28
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 125 additions and 29 deletions

View File

@ -64,4 +64,4 @@ def shape_mode(arr: np.ndarray) -> str:
elif arr.shape[-1] == 4: elif arr.shape[-1] == 4:
return "RGBA" return "RGBA"
raise ValueError("unknown image format") raise ValueError("unknown image format")

View File

@ -253,10 +253,11 @@ def load_pipeline(
for vae in VAE_COMPONENTS: for vae in VAE_COMPONENTS:
if hasattr(pipe, vae): if hasattr(pipe, vae):
vae_model = getattr(pipe, vae) vae_model = getattr(pipe, vae)
vae_model.set_tiled(tiled=params.tiled_vae) if isinstance(vae_model, VAEWrapper):
vae_model.set_window_size( vae_model.set_tiled(tiled=params.tiled_vae)
params.vae_tile // LATENT_FACTOR, params.vae_overlap vae_model.set_window_size(
) params.vae_tile // LATENT_FACTOR, params.vae_overlap
)
# update panorama params # update panorama params
if params.is_panorama(): if params.is_panorama():

View File

@ -300,9 +300,7 @@ def get_tile_latents(
tile_latents = full_latents[:, :, y:yt, x:xt] tile_latents = full_latents[:, :, y:yt, x:xt]
if tile_latents.shape != full_latents.shape and ( if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
tile_latents.shape[2] < t or tile_latents.shape[3] < t
):
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0]) extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
extra_latents[ extra_latents[
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3] :, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]

View File

@ -1,8 +1,10 @@
from multiprocessing import Queue, Value
from os import path from os import path
from typing import List from typing import List
from unittest import skipUnless from unittest import skipUnless
from onnx_web.params import DeviceParams from onnx_web.params import DeviceParams
from onnx_web.worker.context import WorkerContext
def test_needs_models(models: List[str]): def test_needs_models(models: List[str]):
@ -21,6 +23,29 @@ def test_device() -> DeviceParams:
return DeviceParams("cpu", "CPUExecutionProvider") return DeviceParams("cpu", "CPUExecutionProvider")
def test_worker() -> WorkerContext:
cancel = Value("L", 0)
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
return WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth" TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5" TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
TEST_MODEL_DIFFUSION_SD15_INPAINT = "../models/stable-diffusion-onnx-v1-inpainting"
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth" TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"

View File

@ -7,13 +7,29 @@ from PIL import Image
from onnx_web.diffusers.run import ( from onnx_web.diffusers.run import (
run_blend_pipeline, run_blend_pipeline,
run_img2img_pipeline, run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline, run_txt2img_pipeline,
run_upscale_pipeline, run_upscale_pipeline,
) )
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams from onnx_web.image.mask_filter import mask_filter_none
from onnx_web.image.noise_source import noise_source_uniform
from onnx_web.params import (
Border,
HighresParams,
ImageParams,
Size,
TileOrder,
UpscaleParams,
)
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
TEST_MODEL_DIFFUSION_SD15_INPAINT,
test_device,
test_needs_models,
test_worker,
)
TEST_PROMPT = "an astronaut eating a hamburger" TEST_PROMPT = "an astronaut eating a hamburger"
TEST_SCHEDULER = "ddim" TEST_SCHEDULER = "ddim"
@ -213,25 +229,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
class TestImg2ImgPipeline(unittest.TestCase): class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self): def test_basic(self):
cancel = Value("L", 0) worker = test_worker()
logs = Queue()
pending = Queue()
progress = Queue()
active = Value("L", 0)
idle = Value("L", 0)
worker = WorkerContext(
"test",
test_device(),
cancel,
logs,
pending,
progress,
active,
idle,
3,
0.1,
)
worker.start("test") worker.start("test")
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
@ -257,6 +255,80 @@ class TestImg2ImgPipeline(unittest.TestCase):
self.assertTrue(path.exists("../outputs/test-img2img.png")) self.assertTrue(path.exists("../outputs/test-img2img.png"))
class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_white(self):
worker = test_worker()
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "white")
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img",
TEST_SCHEDULER,
TEST_PROMPT,
3.0,
1,
1,
),
Size(*source.size),
["test-inpaint-white.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
mask,
Border.even(0),
noise_source_uniform,
mask_filter_none,
"white",
TileOrder.spiral,
False,
0.0,
)
self.assertTrue(path.exists("../outputs/test-inpaint-white.png"))
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_black(self):
worker = test_worker()
worker.start("test")
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "black")
run_inpaint_pipeline(
worker,
ServerContext(model_path="../models", output_path="../outputs"),
ImageParams(
TEST_MODEL_DIFFUSION_SD15_INPAINT,
"txt2img",
TEST_SCHEDULER,
TEST_PROMPT,
3.0,
1,
1,
),
Size(*source.size),
["test-inpaint-black.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
mask,
Border.even(0),
noise_source_uniform,
mask_filter_none,
"black",
TileOrder.spiral,
False,
0.0,
)
self.assertTrue(path.exists("../outputs/test-inpaint-black.png"))
class TestUpscalePipeline(unittest.TestCase): class TestUpscalePipeline(unittest.TestCase):
@test_needs_models(["../models/upscaling-stable-diffusion-x4"]) @test_needs_models(["../models/upscaling-stable-diffusion-x4"])
def test_basic(self): def test_basic(self):