fix(api): be more careful with VAE patch flags, add margin to latents if needed
This commit is contained in:
parent
57fc183b15
commit
93e3125e28
|
@ -64,4 +64,4 @@ def shape_mode(arr: np.ndarray) -> str:
|
||||||
elif arr.shape[-1] == 4:
|
elif arr.shape[-1] == 4:
|
||||||
return "RGBA"
|
return "RGBA"
|
||||||
|
|
||||||
raise ValueError("unknown image format")
|
raise ValueError("unknown image format")
|
||||||
|
|
|
@ -253,10 +253,11 @@ def load_pipeline(
|
||||||
for vae in VAE_COMPONENTS:
|
for vae in VAE_COMPONENTS:
|
||||||
if hasattr(pipe, vae):
|
if hasattr(pipe, vae):
|
||||||
vae_model = getattr(pipe, vae)
|
vae_model = getattr(pipe, vae)
|
||||||
vae_model.set_tiled(tiled=params.tiled_vae)
|
if isinstance(vae_model, VAEWrapper):
|
||||||
vae_model.set_window_size(
|
vae_model.set_tiled(tiled=params.tiled_vae)
|
||||||
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
vae_model.set_window_size(
|
||||||
)
|
params.vae_tile // LATENT_FACTOR, params.vae_overlap
|
||||||
|
)
|
||||||
|
|
||||||
# update panorama params
|
# update panorama params
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
|
|
|
@ -300,9 +300,7 @@ def get_tile_latents(
|
||||||
|
|
||||||
tile_latents = full_latents[:, :, y:yt, x:xt]
|
tile_latents = full_latents[:, :, y:yt, x:xt]
|
||||||
|
|
||||||
if tile_latents.shape != full_latents.shape and (
|
if tile_latents.shape[2] < t or tile_latents.shape[3] < t:
|
||||||
tile_latents.shape[2] < t or tile_latents.shape[3] < t
|
|
||||||
):
|
|
||||||
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0])
|
||||||
extra_latents[
|
extra_latents[
|
||||||
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
|
:, :, 0 : tile_latents.shape[2], 0 : tile_latents.shape[3]
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
from multiprocessing import Queue, Value
|
||||||
from os import path
|
from os import path
|
||||||
from typing import List
|
from typing import List
|
||||||
from unittest import skipUnless
|
from unittest import skipUnless
|
||||||
|
|
||||||
from onnx_web.params import DeviceParams
|
from onnx_web.params import DeviceParams
|
||||||
|
from onnx_web.worker.context import WorkerContext
|
||||||
|
|
||||||
|
|
||||||
def test_needs_models(models: List[str]):
|
def test_needs_models(models: List[str]):
|
||||||
|
@ -21,6 +23,29 @@ def test_device() -> DeviceParams:
|
||||||
return DeviceParams("cpu", "CPUExecutionProvider")
|
return DeviceParams("cpu", "CPUExecutionProvider")
|
||||||
|
|
||||||
|
|
||||||
|
def test_worker() -> WorkerContext:
|
||||||
|
cancel = Value("L", 0)
|
||||||
|
logs = Queue()
|
||||||
|
pending = Queue()
|
||||||
|
progress = Queue()
|
||||||
|
active = Value("L", 0)
|
||||||
|
idle = Value("L", 0)
|
||||||
|
|
||||||
|
return WorkerContext(
|
||||||
|
"test",
|
||||||
|
test_device(),
|
||||||
|
cancel,
|
||||||
|
logs,
|
||||||
|
pending,
|
||||||
|
progress,
|
||||||
|
active,
|
||||||
|
idle,
|
||||||
|
3,
|
||||||
|
0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
|
TEST_MODEL_CORRECTION_CODEFORMER = "../models/.cache/correction-codeformer.pth"
|
||||||
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
TEST_MODEL_DIFFUSION_SD15 = "../models/stable-diffusion-onnx-v1-5"
|
||||||
|
TEST_MODEL_DIFFUSION_SD15_INPAINT = "../models/stable-diffusion-onnx-v1-inpainting"
|
||||||
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
TEST_MODEL_UPSCALING_SWINIR = "../models/.cache/upscaling-swinir.pth"
|
||||||
|
|
|
@ -7,13 +7,29 @@ from PIL import Image
|
||||||
from onnx_web.diffusers.run import (
|
from onnx_web.diffusers.run import (
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
|
run_inpaint_pipeline,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from onnx_web.params import HighresParams, ImageParams, Size, UpscaleParams
|
from onnx_web.image.mask_filter import mask_filter_none
|
||||||
|
from onnx_web.image.noise_source import noise_source_uniform
|
||||||
|
from onnx_web.params import (
|
||||||
|
Border,
|
||||||
|
HighresParams,
|
||||||
|
ImageParams,
|
||||||
|
Size,
|
||||||
|
TileOrder,
|
||||||
|
UpscaleParams,
|
||||||
|
)
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from tests.helpers import TEST_MODEL_DIFFUSION_SD15, test_device, test_needs_models
|
from tests.helpers import (
|
||||||
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||||
|
test_device,
|
||||||
|
test_needs_models,
|
||||||
|
test_worker,
|
||||||
|
)
|
||||||
|
|
||||||
TEST_PROMPT = "an astronaut eating a hamburger"
|
TEST_PROMPT = "an astronaut eating a hamburger"
|
||||||
TEST_SCHEDULER = "ddim"
|
TEST_SCHEDULER = "ddim"
|
||||||
|
@ -213,25 +229,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
class TestImg2ImgPipeline(unittest.TestCase):
|
class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
cancel = Value("L", 0)
|
worker = test_worker()
|
||||||
logs = Queue()
|
|
||||||
pending = Queue()
|
|
||||||
progress = Queue()
|
|
||||||
active = Value("L", 0)
|
|
||||||
idle = Value("L", 0)
|
|
||||||
|
|
||||||
worker = WorkerContext(
|
|
||||||
"test",
|
|
||||||
test_device(),
|
|
||||||
cancel,
|
|
||||||
logs,
|
|
||||||
pending,
|
|
||||||
progress,
|
|
||||||
active,
|
|
||||||
idle,
|
|
||||||
3,
|
|
||||||
0.1,
|
|
||||||
)
|
|
||||||
worker.start("test")
|
worker.start("test")
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
|
@ -257,6 +255,80 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
self.assertTrue(path.exists("../outputs/test-img2img.png"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestInpaintPipeline(unittest.TestCase):
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||||
|
def test_basic_white(self):
|
||||||
|
worker = test_worker()
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
|
mask = Image.new("RGB", (64, 64), "white")
|
||||||
|
run_inpaint_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||||
|
"txt2img",
|
||||||
|
TEST_SCHEDULER,
|
||||||
|
TEST_PROMPT,
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
Size(*source.size),
|
||||||
|
["test-inpaint-white.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
mask,
|
||||||
|
Border.even(0),
|
||||||
|
noise_source_uniform,
|
||||||
|
mask_filter_none,
|
||||||
|
"white",
|
||||||
|
TileOrder.spiral,
|
||||||
|
False,
|
||||||
|
0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-inpaint-white.png"))
|
||||||
|
|
||||||
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||||
|
def test_basic_black(self):
|
||||||
|
worker = test_worker()
|
||||||
|
worker.start("test")
|
||||||
|
|
||||||
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
|
mask = Image.new("RGB", (64, 64), "black")
|
||||||
|
run_inpaint_pipeline(
|
||||||
|
worker,
|
||||||
|
ServerContext(model_path="../models", output_path="../outputs"),
|
||||||
|
ImageParams(
|
||||||
|
TEST_MODEL_DIFFUSION_SD15_INPAINT,
|
||||||
|
"txt2img",
|
||||||
|
TEST_SCHEDULER,
|
||||||
|
TEST_PROMPT,
|
||||||
|
3.0,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
),
|
||||||
|
Size(*source.size),
|
||||||
|
["test-inpaint-black.png"],
|
||||||
|
UpscaleParams("test"),
|
||||||
|
HighresParams(False, 1, 0, 0),
|
||||||
|
source,
|
||||||
|
mask,
|
||||||
|
Border.even(0),
|
||||||
|
noise_source_uniform,
|
||||||
|
mask_filter_none,
|
||||||
|
"black",
|
||||||
|
TileOrder.spiral,
|
||||||
|
False,
|
||||||
|
0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(path.exists("../outputs/test-inpaint-black.png"))
|
||||||
|
|
||||||
|
|
||||||
class TestUpscalePipeline(unittest.TestCase):
|
class TestUpscalePipeline(unittest.TestCase):
|
||||||
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
@test_needs_models(["../models/upscaling-stable-diffusion-x4"])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
|
Loading…
Reference in New Issue