fix(api): be more careful with VAE patch flags, add margin to latents if needed
This commit is contained in:
parent
57fc183b15
commit
93e3125e28
|
@ -64,4 +64,4 @@ def shape_mode(arr: np.ndarray) -> str:
|
|||
elif arr.shape[-1] == 4:
|
||||
return "RGBA"
|
||||
|
||||
raise ValueError("unknown image format")
|
||||
raise ValueError("unknown image format")
|
||||
|
|
|
@ -253,10 +253,11 @@ def load_pipeline(
|
|||
for vae in VAE_COMPONENTS:
|
||||
if hasattr(pipe, vae):
|
||||
vae_model = getattr(pipe, vae)
|
||||
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||
vae_model.set_window_size(
|
||||
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||
)
|
||||
if isinstance(vae_model, VAEWrapper):
|
||||
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||
vae_model.set_window_size(
|
||||
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||
)
|
||||
|
||||
# update panorama params
|
||||
if params.is_panorama():
|
||||
|
|
|
@ -300,9 +300,7 @@ def get_tile_latents(
|
|||
|
||||
tile_latents = full_latents[:, :, y:yt, x:xt]
|
||||
|
||||
if tile_latents.shape != full_latents.shape and (
|
||||
tile_latents.shape[2] < t or tile_latents.shape[3] < t
|
||||
):
|
||||
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
||||
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
||||
extra_latents[
|
||||
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from multiprocessing import Queue, Value
|
||||
from os import path
|
||||
from typing import List
|
||||
from unittest import skipUnless
|
||||
|
||||
from onnx_web.params import DeviceParams
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
|
||||
|
||||
def test_needs_models(models: List[str]):
|
||||
|
@ -21,6 +23,29 @@ def test_device() -> DeviceParams:
|
|||
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||
|
||||
|
||||
def test_worker() -> WorkerContext:
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
active = Value("L", 0)
|
||||
idle = Value("L", 0)
|
||||
|
||||
return WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
active,
|
||||
idle,
|
||||
3,
|
||||
0.1,
|
||||
)
|
||||
|
||||
|
||||
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
|
||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||
TEST_MODEL_DIFFUSION_SD15_INPAINT = "../models/stable-diffusion-onnx-v1-inpainting"
|
||||
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
||||
|
|
|
@ -7,13 +7,29 @@ from PIL import Image
|
|||
from onnx_web.diffusers.run import (
|
||||
run_blend_pipeline,
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
run_upscale_pipeline,
|
||||
)
|
||||
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams
|
||||
from onnx_web.image.mask_filter import mask_filter_none
|
||||
from onnx_web.image.noise_source import noise_source_uniform
|
||||
from onnx_web.params import (
|
||||
Border,
|
||||
HighresParams,
|
||||
ImageParams,
|
||||
Size,
|
||||
TileOrder,
|
||||
UpscaleParams,
|
||||
)
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
||||
from tests.helpers import (
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||
test_device,
|
||||
test_needs_models,
|
||||
test_worker,
|
||||
)
|
||||
|
||||
TEST_PROMPT = "an astronaut eating a hamburger"
|
||||
TEST_SCHEDULER = "ddim"
|
||||
|
@ -213,25 +229,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
class TestImg2ImgPipeline(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_basic(self):
|
||||
cancel = Value("L", 0)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
progress = Queue()
|
||||
active = Value("L", 0)
|
||||
idle = Value("L", 0)
|
||||
|
||||
worker = WorkerContext(
|
||||
"test",
|
||||
test_device(),
|
||||
cancel,
|
||||
logs,
|
||||
pending,
|
||||
progress,
|
||||
active,
|
||||
idle,
|
||||
3,
|
||||
0.1,
|
||||
)
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
|
@ -257,6 +255,80 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
|||
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
||||
|
||||
|
||||
class TestInpaintPipeline(unittest.TestCase):
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||
def test_basic_white(self):
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
mask = Image.new("RGB", (64, 64), "white")
|
||||
run_inpaint_pipeline(
|
||||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||
"txt2img",
|
||||
TEST_SCHEDULER,
|
||||
TEST_PROMPT,
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(*source.size),
|
||||
["test-inpaint-white.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
mask,
|
||||
Border.even(0),
|
||||
noise_source_uniform,
|
||||
mask_filter_none,
|
||||
"white",
|
||||
TileOrder.spiral,
|
||||
False,
|
||||
0.0,
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-inpaint-white.png"))
|
||||
|
||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||
def test_basic_black(self):
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
mask = Image.new("RGB", (64, 64), "black")
|
||||
run_inpaint_pipeline(
|
||||
worker,
|
||||
ServerContext(model_path="../models", output_path="../outputs"),
|
||||
ImageParams(
|
||||
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||
"txt2img",
|
||||
TEST_SCHEDULER,
|
||||
TEST_PROMPT,
|
||||
3.0,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
Size(*source.size),
|
||||
["test-inpaint-black.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
mask,
|
||||
Border.even(0),
|
||||
noise_source_uniform,
|
||||
mask_filter_none,
|
||||
"black",
|
||||
TileOrder.spiral,
|
||||
False,
|
||||
0.0,
|
||||
)
|
||||
|
||||
self.assertTrue(path.exists("../outputs/test-inpaint-black.png"))
|
||||
|
||||
|
||||
class TestUpscalePipeline(unittest.TestCase):
|
||||
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
||||
def test_basic(self):
|
||||
|
|
Loading…
Reference in New Issue