diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 1056b195..0c395159 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -125,14 +125,17 @@ def load_pipeline( logger.debug("reusing existing diffusion pipeline") pipe = cache_pipe + cache_pipe.vae_encoder.set_tiled(tiled=params.tiled_vae) + cache_pipe.vae_decoder.set_tiled(tiled=params.tiled_vae) + # update panorama params if pipeline == "panorama": latent_window = params.tiles // 8 - latent_stride = params.stride() // 8 + latent_stride = params.stride // 8 cache_pipe.set_window_size(latent_window, latent_stride) - cache_pipe.vae_encoder.set_window_size(latent_window, latent_stride) - cache_pipe.vae_decoder.set_window_size(latent_window, latent_stride) + cache_pipe.vae_encoder.set_window_size(latent_window, params.overlap) + cache_pipe.vae_decoder.set_window_size(latent_window, params.overlap) # update scheduler cache_scheduler = server.cache.get("scheduler", scheduler_key) @@ -332,7 +335,7 @@ def load_pipeline( # additional options for panorama pipeline if pipeline == "panorama": components["window"] = params.tiles // 8 - components["stride"] = params.stride() // 8 + components["stride"] = params.stride // 8 pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__) @@ -433,16 +436,16 @@ def patch_pipeline( server, original_decoder, decoder=True, - tiles=params.tiles, - stride=params.stride(), + window=params.tiles, + overlap=params.overlap, ) original_encoder = pipe.vae_encoder pipe.vae_encoder = VAEWrapper( server, original_encoder, decoder=False, - tiles=params.tiles, - stride=params.stride(), + window=params.tiles, + overlap=params.overlap, ) elif hasattr(pipe, "vae"): pass # TODO: current wrapper does not work with upscaling VAE diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py index 02c9b925..5ae6731c 100644 --- a/api/onnx_web/diffusers/patches/vae.py +++ b/api/onnx_web/diffusers/patches/vae.py @@ -28,21 +28,22 @@ class VAEWrapper(object): server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool, - tiles: int, - stride: int, + window: int, + overlap: float, ): self.server = server self.wrapped = wrapped self.decoder = decoder - self.set_window_size(tiles, stride) + self.tiled = False + self.set_window_size(window, overlap) - def set_window_size(self, window: int, stride: int): - self.window = window - self.stride = stride + def set_tiled(self, tiled: bool = True): + self.tiled = tiled - self.tile_latent_min_size = self.window - self.tile_sample_min_size = self.window * 8 - self.tile_overlap_factor = self.stride / self.window + def set_window_size(self, window: int, overlap: float): + self.tile_latent_min_size = window + self.tile_sample_min_size = window * 8 + self.tile_overlap_factor = overlap def __call__(self, latent_sample=None, sample=None, **kwargs): global timestep_dtype @@ -62,7 +63,7 @@ class VAEWrapper(object): logger.debug("converting VAE sample dtype") sample = sample.astype(timestep_dtype) - if self.window is not None and self.stride is not None: + if self.tiled: if self.decoder: return self.tiled_decode(latent_sample, **kwargs) else: diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index b40fbd09..a1ba68fb 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -369,6 +369,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): ) return views + @torch.no_grad() def text2img( self, prompt: Union[str, List[str]] = None, @@ -619,6 +620,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): images=image, nsfw_content_detected=has_nsfw_concept ) + @torch.no_grad() def img2img( self, prompt: Union[str, List[str]] = None, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f0b62945..4db48ee0 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -214,6 +214,7 @@ class ImageParams: tiled_vae: bool = False, tiles: int = 512, overlap: float = 0.25, + stride: int = 64, ) -> None: self.model = model self.pipeline = pipeline @@ -232,6 +233,7 @@ class ImageParams: self.tiled_vae = tiled_vae self.tiles = tiles self.overlap = overlap + self.stride = stride def do_cfg(self): return self.cfg > 1.0 @@ -260,9 +262,6 @@ class ImageParams: def lpw(self): return self.pipeline == "lpw" - def stride(self): - return int(self.tiles * self.overlap) - def tojson(self) -> Dict[str, Optional[Param]]: return { "model": self.model, @@ -282,6 +281,7 @@ class ImageParams: "tiled_vae": self.tiled_vae, "tiles": self.tiles, "overlap": self.overlap, + "stride": self.stride, } def with_args(self, **kwargs): @@ -303,6 +303,7 @@ class ImageParams: kwargs.get("tiled_vae", self.tiled_vae), kwargs.get("tiles", self.tiles), kwargs.get("overlap", self.overlap), + kwargs.get("stride", self.stride), ) diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index 1908e9ee..627e4b0e 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -140,6 +140,13 @@ def pipeline_from_request( get_config_value("overlap", "max"), get_config_value("overlap", "min"), ) + stride = get_and_clamp_float( + request.args, + "stride", + get_config_value("stride"), + get_config_value("stride", "max"), + get_config_value("stride", "min"), + ) seed = int(request.args.get("seed", -1)) if seed == -1: @@ -177,6 +184,7 @@ def pipeline_from_request( tiled_vae=tiled_vae, tiles=tiles, overlap=overlap, + stride=stride, ) size = Size(width, height) return (device, params, size) diff --git a/api/params.json b/api/params.json index e61c6656..7f21d78c 100644 --- a/api/params.json +++ b/api/params.json @@ -186,6 +186,12 @@ "max": 1, "step": 0.01 }, + "stride": { + "default": 128, + "min": 64, + "max": 512, + "step": 64 + }, "tiledVAE": { "default": false }, diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index a4c076fc..af537eb8 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -51,6 +51,7 @@ export interface BaseImgParams { tiledVAE: boolean; tiles: number; overlap: number; + stride: number; cfg: number; steps: number; @@ -409,6 +410,7 @@ export function makeImageURL(root: string, type: string, params: BaseImgParams): url.searchParams.append('tiledVAE', String(params.tiledVAE)); url.searchParams.append('tiles', params.tiles.toFixed(FIXED_INTEGER)); url.searchParams.append('overlap', params.overlap.toFixed(FIXED_FLOAT)); + url.searchParams.append('stride', params.stride.toFixed(FIXED_FLOAT)); if (doesExist(params.scheduler)) { url.searchParams.append('scheduler', params.scheduler); diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 1189c26a..297189f5 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -178,6 +178,21 @@ export function ImageControl(props: ImageControlProps) { } }} /> + { + if (doesExist(props.onChange)) { + props.onChange({ + ...controlState, + stride, + }); + } + }} + />