1
0
Fork 0

feat(api): add tile size and stride to image parameters

This commit is contained in:
Sean Sube 2023-05-01 23:20:40 -05:00
parent 746e33b1f5
commit 95725fff79
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
15 changed files with 140 additions and 47 deletions

View File

@ -31,6 +31,7 @@ def blend_controlnet(
pipe = load_pipeline(
server,
params,
"controlnet",
params.model,
params.scheduler,

View File

@ -33,6 +33,7 @@ def blend_img2img(
pipe_type = "lpw" if params.lpw() else "img2img"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,

View File

@ -62,6 +62,7 @@ def blend_inpaint(
pipe_type = "lpw" if params.lpw() else "inpaint"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,

View File

@ -34,6 +34,7 @@ def blend_pix2pix(
pipe = load_pipeline(
server,
params,
"pix2pix",
params.model,
params.scheduler,

View File

@ -40,6 +40,7 @@ def source_txt2img(
pipe_type = "lpw" if params.lpw() else "txt2img"
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,

View File

@ -82,6 +82,7 @@ def upscale_outpaint(
pipe_type = params.get_valid_pipeline("inpaint", params.pipeline)
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
@ -125,9 +126,7 @@ def upscale_outpaint(
if params.pipeline == "panorama":
logger.debug("outpainting with one shot panorama, no tiling")
return outpaint(
source, (source.width, source.height, max(source.width, source.height))
)
return outpaint(source, (0, 0, max(source.width, source.height)))
if overlap == 0:
logger.debug("outpainting with 0 margin, using grid tiling")
output = process_tile_grid(source, SizeChart.auto, 1, [outpaint])

View File

@ -37,6 +37,7 @@ def upscale_stable_diffusion(
pipeline = load_pipeline(
server,
params,
"upscale",
path.join(server.model_path, upscale.upscale_model),
params.scheduler,

View File

@ -10,8 +10,7 @@ from ..convert.diffusion.lora import blend_loras, buffer_external_data_tensors
from ..convert.diffusion.textual_inversion import blend_textual_inversions
from ..diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ..diffusers.utils import expand_prompt
from ..models.meta import NetworkModel
from ..params import DeviceParams
from ..params import DeviceParams, ImageParams
from ..server import ServerContext
from ..utils import run_gc
from .patches.unet import UNetWrapper
@ -93,20 +92,20 @@ def get_scheduler_name(scheduler: Any) -> Optional[str]:
def load_pipeline(
server: ServerContext,
params: ImageParams,
pipeline: str,
model: str,
scheduler_name: str,
device: DeviceParams,
control: Optional[NetworkModel] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
):
inversions = inversions or []
loras = loras or []
control_key = control.name if control is not None else None
model = params.model
torch_dtype = server.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
control_key = params.control.name if params.control is not None else None
pipe_key = (
pipeline,
model,
@ -116,8 +115,8 @@ def load_pipeline(
inversions,
loras,
)
scheduler_key = (scheduler_name, model)
scheduler_type = pipeline_schedulers[scheduler_name]
scheduler_key = (params.scheduler, model)
scheduler_type = pipeline_schedulers[params.scheduler]
cache_pipe = server.cache.get("diffusion", pipe_key)
@ -164,8 +163,10 @@ def load_pipeline(
unet_type = "unet"
# ControlNet component
if pipeline == "controlnet" and control is not None:
cnet_path = path.join(server.model_path, "control", f"{control.name}.onnx")
if pipeline == "controlnet" and params.control is not None:
cnet_path = path.join(
server.model_path, "control", f"{params.control.name}.onnx"
)
logger.debug("loading ControlNet weights from %s", cnet_path)
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
@ -317,6 +318,11 @@ def load_pipeline(
)
)
# additional options for panorama pipeline
if pipeline == "panorama":
components["window"] = params.tiles
components["stride"] = params.stride()
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
@ -333,12 +339,12 @@ def load_pipeline(
optimize_pipeline(server, pipe)
# TODO: CPU VAE, etc
# TODO: remove this, not relevant with ONNX
if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_str())
# monkey-patch pipeline
patch_pipeline(server, pipe, pipeline)
patch_pipeline(server, pipe, pipeline_class, params)
server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, components["scheduler"])
@ -402,6 +408,7 @@ def patch_pipeline(
server: ServerContext,
pipe: StableDiffusionPipeline,
pipeline: Any,
params: ImageParams,
) -> None:
logger.debug("patching SD pipeline")
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
@ -411,9 +418,21 @@ def patch_pipeline(
if hasattr(pipe, "vae_decoder"):
original_decoder = pipe.vae_decoder
pipe.vae_decoder = VAEWrapper(server, original_decoder, decoder=True)
pipe.vae_decoder = VAEWrapper(
server,
original_decoder,
decoder=True,
tiles=params.tiles,
stride=params.stride(),
)
original_encoder = pipe.vae_encoder
pipe.vae_encoder = VAEWrapper(server, original_encoder, decoder=False)
pipe.vae_encoder = VAEWrapper(
server,
original_encoder,
decoder=False,
tiles=params.tiles,
stride=params.stride(),
)
elif hasattr(pipe, "vae"):
pass # TODO: current wrapper does not work with upscaling VAE
else:

View File

@ -12,8 +12,6 @@ from ...server import ServerContext
logger = getLogger(__name__)
LATENT_CHANNELS = 4
LATENT_SIZE = 32
SAMPLE_SIZE = 256
# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32
@ -25,14 +23,23 @@ def set_vae_dtype(dtype):
class VAEWrapper(object):
def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
def __init__(
self,
server: ServerContext,
wrapped: OnnxRuntimeModel,
decoder: bool,
tiles: int,
stride: int,
):
self.server = server
self.wrapped = wrapped
self.decoder = decoder
self.tiles = tiles
self.stride = stride
self.tile_sample_min_size = SAMPLE_SIZE
self.tile_latent_min_size = LATENT_SIZE
self.tile_overlap_factor = 0.25
self.tile_latent_min_size = tiles
self.tile_sample_min_size = tiles * 8
self.tile_overlap_factor = stride / tiles
def __call__(self, latent_sample=None, sample=None, **kwargs):
global timestep_dtype
@ -52,10 +59,16 @@ class VAEWrapper(object):
logger.debug("converting VAE sample dtype")
sample = sample.astype(timestep_dtype)
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
if self.tiles is not None and self.stride is not None:
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:
return self.tiled_encode(sample, **kwargs)
else:
return self.tiled_encode(sample, **kwargs)
if self.decoder:
return self.wrapped(latent_sample=latent_sample)
else:
return self.wrapped(sample=sample)
def __getattr__(self, attr):
    """Delegate any attribute not found on the wrapper to the wrapped ONNX model."""
    # NOTE: only reached when normal lookup fails, so wrapper-own attributes
    # (set in __init__) are never shadowed by this proxy.
    inner = self.wrapped
    return getattr(inner, attr)

View File

@ -33,6 +33,9 @@ logger = logging.get_logger(__name__)
NUM_UNET_INPUT_CHANNELS = 9
NUM_LATENT_CHANNELS = 4
DEFAULT_WINDOW = 32
DEFAULT_STRIDE = 8
def preprocess(image):
if isinstance(image, torch.Tensor):
@ -105,9 +108,14 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
window: Optional[int] = None,
stride: Optional[int] = None,
):
super().__init__()
self.window = window or DEFAULT_WINDOW
self.stride = stride or DEFAULT_STRIDE
if (
hasattr(scheduler.config, "steps_offset")
and scheduler.config.steps_offset != 1
@ -338,7 +346,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
def get_views(self, panorama_height, panorama_width, window_size, stride):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
@ -514,7 +522,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)
@ -816,7 +824,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)
@ -1124,7 +1132,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
# panorama additions
views = self.get_views(height, width)
views = self.get_views(height, width, self.window, self.stride)
count = np.zeros_like(latents)
value = np.zeros_like(latents)

View File

@ -55,9 +55,8 @@ def run_loopback(
pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
@ -140,9 +139,8 @@ def run_highres(
highres_pipe = pipeline or load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
@ -242,9 +240,8 @@ def run_txt2img_pipeline(
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
inversions=inversions,
loras=loras,
@ -350,11 +347,9 @@ def run_img2img_pipeline(
pipe_type = params.get_valid_pipeline("img2img")
pipe = load_pipeline(
server,
params,
pipe_type,
params.model,
params.scheduler,
job.get_device(),
control=params.control,
inversions=inversions,
loras=loras,
)

View File

@ -191,6 +191,9 @@ class ImageParams:
input_prompt: str
input_negative_prompt: str
loopback: int
tiled_vae: bool
tiles: int
overlap: float
def __init__(
self,
@ -208,6 +211,9 @@ class ImageParams:
input_prompt: Optional[str] = None,
input_negative_prompt: Optional[str] = None,
loopback: int = 0,
tiled_vae: bool = False,
tiles: int = 512,
overlap: float = 0.25,
) -> None:
self.model = model
self.pipeline = pipeline
@ -223,6 +229,9 @@ class ImageParams:
self.input_prompt = input_prompt or prompt
self.input_negative_prompt = input_negative_prompt or negative_prompt
self.loopback = loopback
self.tiled_vae = tiled_vae
self.tiles = tiles
self.overlap = overlap
def do_cfg(self):
    """Return True when classifier-free guidance is enabled (CFG scale above 1)."""
    return 1.0 < self.cfg
@ -251,6 +260,9 @@ class ImageParams:
def lpw(self):
    """Return True when the long-prompt-weighting pipeline has been selected."""
    return "lpw" == self.pipeline
def stride(self):
    """Return the tiling stride in pixels for panorama/tiled-VAE processing.

    The stride is the ``overlap`` fraction of the tile size, so downstream
    consumers (e.g. VAEWrapper's ``tile_overlap_factor = stride / tiles``)
    recover ``overlap`` — matching the previously hardcoded 0.25 factor.
    """
    # BUG FIX: the original returned int(self.tiles * self.stride), which
    # multiplies by the bound method object itself and raises TypeError.
    # The intended factor is the overlap fraction stored on this object.
    return int(self.tiles * self.overlap)
def tojson(self) -> Dict[str, Optional[Param]]:
return {
"model": self.model,
@ -267,6 +279,9 @@ class ImageParams:
"input_prompt": self.input_prompt,
"input_negative_prompt": self.input_negative_prompt,
"loopback": self.loopback,
"tiled_vae": self.tiled_vae,
"tiles": self.tiles,
"overlap": self.overlap,
}
def with_args(self, **kwargs):
@ -285,6 +300,9 @@ class ImageParams:
kwargs.get("input_prompt", self.input_prompt),
kwargs.get("input_negative_prompt", self.input_negative_prompt),
kwargs.get("loopback", self.loopback),
kwargs.get("tiled_vae", self.tiled_vae),
kwargs.get("tiles", self.tiles),
kwargs.get("overlap", self.overlap),
)
@ -405,14 +423,12 @@ class HighresParams:
strength: float,
method: Literal["bilinear", "lanczos", "upscale"] = "lanczos",
iterations: int = 1,
tiled_vae: bool = False,
):
self.scale = scale
self.steps = steps
self.strength = strength
self.method = method
self.iterations = iterations
self.tiled_vae = tiled_vae
def resize(self, size: Size) -> Size:
return Size(
@ -427,5 +443,4 @@ class HighresParams:
"scale": self.scale,
"steps": self.steps,
"strength": self.strength,
"tiled_vae": self.tiled_vae,
}

View File

@ -125,6 +125,21 @@ def pipeline_from_request(
get_config_value("width", "max"),
get_config_value("width", "min"),
)
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
tiles = get_and_clamp_int(
request.args,
"tiles",
get_config_value("tiles"),
get_config_value("tiles", "max"),
get_config_value("tiles", "min"),
)
overlap = get_and_clamp_float(
request.args,
"overlap",
get_config_value("overlap"),
get_config_value("overlap", "max"),
get_config_value("overlap", "min"),
)
seed = int(request.args.get("seed", -1))
if seed == -1:
@ -159,6 +174,9 @@ def pipeline_from_request(
batch=batch,
control=control,
loopback=loopback,
tiled_vae=tiled_vae,
tiles=tiles,
overlap=overlap,
)
size = Size(width, height)
return (device, params, size)
@ -279,17 +297,13 @@ def highres_from_request() -> HighresParams:
request.args,
"highresStrength",
get_config_value("highresStrength"),
get_config_value("highresStrength"),
get_config_value("highresStrength", "max"),
get_config_value("highresStrength", "min"),
)
tiled_vae = get_boolean(request.args, "tiledVAE", get_config_value("tiledVAE"))
return HighresParams(
scale,
steps,
strength,
method=method,
iterations=iterations,
tiled_vae=tiled_vae,
)

View File

@ -130,6 +130,12 @@
"max": 4,
"step": 1
},
"overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"pipeline": {
"default": "",
"keys": [
@ -183,6 +189,12 @@
"tiledVAE": {
"default": false
},
"tiles": {
"default": 512,
"min": 128,
"max": 1024,
"step": 128
},
"tileOrder": {
"default": "spiral",
"keys": [

View File

@ -128,6 +128,12 @@
"max": 4,
"step": 1
},
"overlap": {
"default": 0.25,
"min": 0.0,
"max": 0.9,
"step": 0.01
},
"pipeline": {
"default": "",
"keys": [
@ -181,6 +187,12 @@
"tiledVAE": {
"default": false
},
"tiles": {
"default": 512,
"min": 128,
"max": 1024,
"step": 128
},
"tileOrder": {
"default": "spiral",
"keys": [