From e8d7d9a88186fedcf12ac6ce41819a5c59698715 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 4 Nov 2023 20:41:58 -0500 Subject: [PATCH] feat: split up UNet and VAE tile size and overlap/stride params --- api/onnx_web/chain/highres.py | 6 +- api/onnx_web/chain/source_txt2img.py | 4 +- api/onnx_web/chain/upscale_bsrgan.py | 2 +- api/onnx_web/chain/upscale_outpaint.py | 2 +- api/onnx_web/diffusers/load.py | 17 ++-- api/onnx_web/diffusers/run.py | 18 ++-- api/onnx_web/params.py | 34 ++++--- api/onnx_web/server/params.py | 50 +++++----- api/params.json | 44 +++++---- api/scripts/test-release.py | 4 +- gui/src/client/api.ts | 9 +- gui/src/components/control/ImageControl.tsx | 100 +++++++++++--------- gui/src/config.json | 32 ++++--- gui/src/state.ts | 10 +- gui/src/strings/de.ts | 6 +- gui/src/strings/en.ts | 8 +- gui/src/strings/es.ts | 6 +- gui/src/strings/fr.ts | 6 +- gui/src/types/params.ts | 9 +- 19 files changed, 210 insertions(+), 157 deletions(-) diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 87b52d9b..482b86c7 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -43,7 +43,7 @@ def stage_highres( outscale=highres.scale, ), chain=chain, - overlap=params.overlap, + overlap=params.vae_overlap, ) else: logger.debug("using simple upscaling for highres") @@ -51,14 +51,14 @@ def stage_highres( UpscaleSimpleStage(), stage, method=highres.method, - overlap=params.overlap, + overlap=params.vae_overlap, upscale=upscale.with_args(scale=highres.scale, outscale=highres.scale), ) chain.stage( BlendImg2ImgStage(), stage, - overlap=params.overlap, + overlap=params.vae_overlap, prompt_index=prompt_index + i, strength=highres.strength, ) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 5448364e..13dc70d9 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -60,9 +60,9 @@ class SourceTxt2ImgStage(BaseStage): ) if params.is_xl(): - tile_size = max(stage.tile_size, params.tiles) + tile_size = max(stage.tile_size, params.unet_tile) else: - tile_size = params.tiles + tile_size = params.unet_tile # this works for panorama as well, because tile_size is already max(tile_size, *size) latent_size = size.min(tile_size, tile_size) diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index f9f02f1e..0137750e 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -106,5 +106,5 @@ class UpscaleBSRGANStage(BaseStage): params: ImageParams, size: Size, ) -> int: - tile = min(params.tiles, self.max_tile) + tile = min(params.unet_tile, self.max_tile) return size.width // tile * size.height // tile diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 78d32077..85ddc079 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -71,7 +71,7 @@ class UpscaleOutpaintStage(BaseStage): outputs.append(source) continue - tile_size = params.tiles + tile_size = params.unet_tile size = Size(*source.size) latent_size = size.min(tile_size, tile_size) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 38b61e50..07310e74 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -266,14 +266,13 @@ def load_pipeline( # update panorama params if params.is_panorama(): - latent_window = params.tiles // 8 - latent_stride = params.stride // 8 - - pipe.set_window_size(latent_window, latent_stride) + unet_stride = (params.unet_tile * (1 - params.unet_overlap)) // 8 + logger.debug("setting panorama window parameters: %s/%s for UNet, %s/%s for VAE", params.unet_tile, unet_stride, params.vae_tile, params.vae_overlap) + pipe.set_window_size(params.unet_tile // 8, unet_stride) for vae in VAE_COMPONENTS: if hasattr(pipe, vae): - getattr(pipe, vae).set_window_size(latent_window, params.overlap) + getattr(pipe, vae).set_window_size(params.vae_tile // 8, params.vae_overlap) run_gc([device]) @@ -626,8 +625,8 @@ def patch_pipeline( server, original_decoder, decoder=True, - window=params.tiles, - overlap=params.overlap, + window=params.unet_tile, + overlap=params.vae_overlap, ) logger.debug("patched VAE decoder with wrapper") @@ -637,8 +636,8 @@ def patch_pipeline( server, original_encoder, decoder=False, - window=params.tiles, - overlap=params.overlap, + window=params.unet_tile, + overlap=params.vae_overlap, ) logger.debug("patched VAE encoder with wrapper") diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index a9c72d2f..b86be2b1 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -44,9 +44,9 @@ def run_txt2img_pipeline( ) -> None: # if using panorama, the pipeline will tile itself (views) if params.is_panorama() or params.is_xl(): - tile_size = max(params.tiles, size.width, size.height) + tile_size = max(params.unet_tile, size.width, size.height) else: - tile_size = params.tiles + tile_size = params.unet_tile # prepare the chain pipeline and first stage chain = ChainPipeline() @@ -57,11 +57,11 @@ def run_txt2img_pipeline( ), size=size, prompt_index=0, - overlap=params.overlap, + overlap=params.vae_overlap, ) # apply upscaling and correction, before highres - stage = StageParams(tile_size=params.tiles) + stage = StageParams(tile_size=params.unet_tile) first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( @@ -139,14 +139,14 @@ def run_img2img_pipeline( # prepare the chain pipeline and first stage chain = ChainPipeline() stage = StageParams( - tile_size=params.tiles, + tile_size=params.unet_tile, ) chain.stage( BlendImg2ImgStage(), stage, prompt_index=0, strength=strength, - overlap=params.overlap, + overlap=params.vae_overlap, ) # apply upscaling and correction, before highres @@ -236,7 +236,7 @@ def run_inpaint_pipeline( full_res_inpaint_padding: float, ) -> None: logger.debug("building inpaint pipeline") - tile_size = params.tiles + tile_size = params.unet_tile if mask is None: # if no mask was provided, keep the full source image @@ -332,7 +332,7 @@ def run_inpaint_pipeline( fill_color=fill_color, mask_filter=mask_filter, noise_source=noise_source, - overlap=params.overlap, + overlap=params.vae_overlap, prompt_index=0, ) @@ -410,7 +410,7 @@ def run_upscale_pipeline( ) -> None: # set up the chain pipeline, no base stage for upscaling chain = ChainPipeline() - stage = StageParams(tile_size=params.tiles) + stage = StageParams(tile_size=params.unet_tile) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 885b09b3..3b896cae 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -204,8 +204,10 @@ class ImageParams: input_negative_prompt: str loopback: int tiled_vae: bool - tiles: int - overlap: float + unet_tile: int + unet_overlap: float + vae_tile: int + vae_overlap: float def __init__( self, @@ -224,9 +226,10 @@ class ImageParams: input_negative_prompt: Optional[str] = None, loopback: int = 0, tiled_vae: bool = False, - tiles: int = 512, - overlap: float = 0.25, - stride: int = 64, + unet_overlap: float = 0.25, + unet_tile: int = 512, + vae_overlap: float = 0.25, + vae_tile: int = 512, ) -> None: self.model = model self.pipeline = pipeline @@ -243,9 +246,10 @@ class ImageParams: self.input_negative_prompt = input_negative_prompt or negative_prompt self.loopback = loopback self.tiled_vae = tiled_vae - self.tiles = tiles - self.overlap = overlap - self.stride = stride + self.unet_overlap = unet_overlap + self.unet_tile = unet_tile + self.vae_overlap = vae_overlap + self.vae_tile = vae_tile def do_cfg(self): return self.cfg > 1.0 @@ -312,9 +316,10 @@ class ImageParams: "input_negative_prompt": self.input_negative_prompt, "loopback": self.loopback, "tiled_vae": self.tiled_vae, - "tiles": self.tiles, - "overlap": self.overlap, - "stride": self.stride, + "unet_overlap": self.unet_overlap, + "unet_tile": self.unet_tile, + "vae_overlap": self.vae_overlap, + "vae_tile": self.vae_tile, } def with_args(self, **kwargs): @@ -334,9 +339,10 @@ class ImageParams: kwargs.get("input_negative_prompt", self.input_negative_prompt), kwargs.get("loopback", self.loopback), kwargs.get("tiled_vae", self.tiled_vae), - kwargs.get("tiles", self.tiles), - kwargs.get("overlap", self.overlap), - kwargs.get("stride", self.stride), + kwargs.get("unet_overlap", self.unet_overlap), + kwargs.get("unet_tile", self.unet_tile), + kwargs.get("vae_overlap", self.vae_overlap), + kwargs.get("vae_tile", self.vae_tile), ) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index d68e2dcc..b8dfe871 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -117,32 +117,35 @@ def build_params( get_config_value("steps", "max"), get_config_value("steps", "min"), ) - tiled_vae = get_boolean(data, "tiledVAE", get_config_value("tiledVAE")) - tiles = get_and_clamp_int( + tiled_vae = get_boolean(data, "tiled_vae", get_config_value("tiled_vae")) + unet_overlap = get_and_clamp_float( data, - "tiles", - get_config_value("tiles"), - get_config_value("tiles", "max"), - get_config_value("tiles", "min"), + "unet_overlap", + get_config_value("unet_overlap"), + get_config_value("unet_overlap", "max"), + get_config_value("unet_overlap", "min"), ) - overlap = get_and_clamp_float( + unet_tile = get_and_clamp_int( data, - "overlap", - get_config_value("overlap"), - get_config_value("overlap", "max"), - get_config_value("overlap", "min"), + "unet_tile", + get_config_value("unet_tile"), + get_config_value("unet_tile", "max"), + get_config_value("unet_tile", "min"), ) - stride = get_and_clamp_int( + vae_overlap = get_and_clamp_float( data, - "stride", - get_config_value("stride"), - get_config_value("stride", "max"), - get_config_value("stride", "min"), + "vae_overlap", + get_config_value("vae_overlap"), + get_config_value("vae_overlap", "max"), + get_config_value("vae_overlap", "min"), + ) + vae_tile = get_and_clamp_int( + data, + "vae_tile", + get_config_value("vae_tile"), + get_config_value("vae_tile", "max"), + get_config_value("vae_tile", "min"), ) - - if stride > tiles: - logger.info("limiting stride to tile size, %s > %s", stride, tiles) - stride = tiles seed = int(data.get("seed", -1)) if seed == -1: @@ -163,9 +166,10 @@ def build_params( control=control, loopback=loopback, tiled_vae=tiled_vae, - tiles=tiles, - overlap=overlap, - stride=stride, + unet_overlap=unet_overlap, + unet_tile=unet_tile, + vae_overlap=vae_overlap, + vae_tile=vae_tile, ) return params diff --git a/api/params.json b/api/params.json index c4a1ee32..9ed7451f 100644 --- a/api/params.json +++ b/api/params.json @@ -141,12 +141,6 @@ "max": 4, "step": 1 }, - "overlap": { - "default": 0.25, - "min": 0.0, - "max": 0.9, - "step": 0.01 - }, "pipeline": { "default": "", "keys": [ @@ -197,21 +191,9 @@ "max": 1, "step": 0.01 }, - "stride": { - "default": 128, - "min": 64, - "max": 512, - "step": 64 - }, - "tiledVAE": { + "tiled_vae": { "default": false }, - "tiles": { - "default": 512, - "min": 128, - "max": 2048, - "step": 128 - }, "tileOrder": { "default": "spiral", "keys": [ @@ -225,6 +207,18 @@ "max": 1024, "step": 8 }, + "unet_overlap": { + "default": 0.25, + "min": 0.0, + "max": 0.9, + "step": 0.01 + }, + "unet_tile": { + "default": 512, + "min": 128, + "max": 2048, + "step": 128 + }, "upscaleOrder": { "default": "correction-first", "keys": [ @@ -237,6 +231,18 @@ "default": "", "keys": [] }, + "vae_overlap": { + "default": 0.25, + "min": 0.0, + "max": 0.9, + "step": 0.01 + }, + "vae_tile": { + "default": 512, + "min": 256, + "max": 1024, + "step": 128 + }, "width": { "default": 512, "min": 128, diff --git a/api/scripts/test-release.py b/api/scripts/test-release.py index 2e6d7fd2..e46843a2 100644 --- a/api/scripts/test-release.py +++ b/api/scripts/test-release.py @@ -305,12 +305,12 @@ TEST_DATA = [ ), TestCase( "txt2img-panorama-1024x768-muffin", - "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiledVAE=true", + "txt2img?prompt=a+giant+muffin&seed=0&scheduler=ddim&width=1024&height=768&pipeline=panorama&tiled_vae=true", max_attempts=VERY_SLOW_TEST, ), TestCase( "img2img-panorama-1024x768-pumpkin", - "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiledVAE=true", + "img2img?prompt=a+giant+pumpkin&seed=0&scheduler=ddim&sourceFilter=none&pipeline=panorama&tiled_vae=true", source="txt2img-panorama-1024x768-muffin-0", max_attempts=VERY_SLOW_TEST, ), diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 3702440a..64fc89dc 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -70,10 +70,11 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): url.searchParams.append('cfg', params.cfg.toFixed(FIXED_FLOAT)); url.searchParams.append('eta', params.eta.toFixed(FIXED_FLOAT)); url.searchParams.append('steps', params.steps.toFixed(FIXED_INTEGER)); - url.searchParams.append('tiledVAE', String(params.tiledVAE)); - url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER)); - url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT)); - url.searchParams.append('stride', params.stride.toFixed(FIXED_INTEGER)); + url.searchParams.append('tiled_vae', String(params.tiled_vae)); + url.searchParams.append('unet_overlap', params.unet_overlap.toFixed(FIXED_FLOAT)); + url.searchParams.append('unet_tile', params.unet_tile.toFixed(FIXED_INTEGER)); + url.searchParams.append('vae_overlap', params.vae_overlap.toFixed(FIXED_FLOAT)); + url.searchParams.append('vae_tile', params.vae_tile.toFixed(FIXED_INTEGER)); if (doesExist(params.scheduler)) { url.searchParams.append('scheduler', params.scheduler); diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index d031b9c5..bca1d2ba 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -1,3 +1,4 @@ +/* eslint-disable camelcase */ import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; import { Casino } from '@mui/icons-material'; import { Button, Checkbox, FormControlLabel, Stack } from '@mui/material'; @@ -47,9 +48,6 @@ export function ImageControl(props: ImageControlProps) { staleTime: STALE_TIME, }); - // max stride is the lesser of tile size and server's max stride - const maxStride = Math.min(state.tiles, params.stride.max); - return { + label={t('parameter.unet_tile')} + min={params.unet_tile.min} + max={params.unet_tile.max} + step={params.unet_tile.step} + value={state.unet_tile} + onChange={(unet_tile) => { props.onChange({ ...state, - tiles, + unet_tile, + }); + }} + /> + { + props.onChange({ + ...state, + unet_overlap, + }); + }} + /> + { + props.onChange({ + ...state, + tiled_vae: state.tiled_vae === false, + }); + }} + />} + /> + { + props.onChange({ + ...state, + vae_tile, }); }} /> { + disabled={state.tiled_vae === false} + label={t('parameter.vae_tile')} + min={params.vae_tile.min} + max={params.vae_tile.max} + step={params.vae_tile.step} + value={state.vae_tile} + onChange={(vae_tile) => { props.onChange({ ...state, - overlap, + vae_tile, }); }} /> - { - props.onChange({ - ...state, - stride, - }); - }} - /> - { - props.onChange({ - ...state, - tiledVAE: state.tiledVAE === false, - }); - }} - />} - />