diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index f5bd3a06..11939b02 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -34,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt logger = getLogger(__name__) +def get_base_tile(params: ImageParams, size: Size) -> int: + if params.is_panorama(): + tile = max(params.unet_tile, size.width, size.height) + logger.debug("adjusting tile size for panorama to %s", tile) + return tile + + return params.unet_tile + + +def get_highres_tile( + server: ServerContext, params: ImageParams, highres: HighresParams, tile: int +) -> int: + if params.is_panorama() and server.has_feature("panorama-highres"): + return tile * highres.scale + + return params.unet_tile + + def run_txt2img_pipeline( worker: WorkerContext, server: ServerContext, @@ -44,11 +62,7 @@ def run_txt2img_pipeline( highres: HighresParams, ) -> None: # if using panorama, the pipeline will tile itself (views) - if params.is_panorama(): - tile_size = max(params.unet_tile, size.width, size.height) - logger.debug("adjusting tile size for panorama to %s", tile_size) - else: - tile_size = params.unet_tile + tile_size = get_base_tile(params, size) # prepare the chain pipeline and first stage chain = ChainPipeline() @@ -63,12 +77,8 @@ def run_txt2img_pipeline( ) # apply upscaling and correction, before highres - highres_size = params.unet_tile + highres_size = get_highres_tile(server, params, highres, tile_size) if params.is_panorama(): - if server.has_feature("panorama-highres"): - # run the whole highres pass with one panorama call - highres_size = tile_size * highres.scale - chain.stage( BlendDenoiseStage(), StageParams(tile_size=highres_size), @@ -151,13 +161,13 @@ def run_img2img_pipeline( source = f(server, source) # prepare the chain pipeline and first stage + tile_size = get_base_tile(params, Size(*source.size)) chain = ChainPipeline() - stage = StageParams( - tile_size=params.unet_tile, - ) chain.stage( BlendImg2ImgStage(), - stage, + StageParams( + tile_size=tile_size, + ), prompt_index=0, strength=strength, overlap=params.vae_overlap, @@ -167,7 +177,10 @@ def run_img2img_pipeline( first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams( + outscale=first_upscale.outscale, + tile_size=tile_size, + ), params, upscale=first_upscale, chain=chain, @@ -177,13 +190,16 @@ def run_img2img_pipeline( for _i in range(params.loopback): chain.stage( BlendImg2ImgStage(), - stage, + StageParams( + tile_size=tile_size, + ), strength=strength, ) # highres, if selected + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(tile_size=highres_size, outscale=highres.scale), params, highres, upscale, @@ -193,7 +209,7 @@ def run_img2img_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(tile_size=tile_size, outscale=after_upscale.scale), params, upscale=after_upscale, chain=chain, @@ -252,7 +268,7 @@ def run_inpaint_pipeline( full_res_inpaint_padding: float, ) -> None: logger.debug("building inpaint pipeline") - tile_size = params.unet_tile + tile_size = get_base_tile(params, size) if mask is None: # if no mask was provided, keep the full source image @@ -339,10 +355,9 @@ def run_inpaint_pipeline( # set up the chain pipeline and base stage chain = ChainPipeline() - stage = StageParams(tile_order=tile_order, tile_size=tile_size) chain.stage( UpscaleOutpaintStage(), - stage, + StageParams(tile_order=tile_order, tile_size=tile_size), border=border, mask=mask, fill_color=fill_color, @@ -356,15 +371,16 @@ def run_inpaint_pipeline( first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=tile_size), params, upscale=first_upscale, chain=chain, ) # apply highres + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(outscale=highres.scale, tile_size=highres_size), params, highres, upscale, @@ -374,7 +390,7 @@ def run_inpaint_pipeline( # apply upscaling and correction stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale), params, upscale=after_upscale, chain=chain, @@ -433,21 +449,22 @@ def run_upscale_pipeline( ) -> None: # set up the chain pipeline, no base stage for upscaling chain = ChainPipeline() - stage = StageParams(tile_size=params.unet_tile) + tile_size = get_base_tile(params, size) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=tile_size), params, upscale=first_upscale, chain=chain, ) # apply highres + highres_size = get_highres_tile(server, params, highres, tile_size) stage_highres( - stage, + StageParams(outscale=highres.scale, tile_size=highres_size), params, highres, upscale, @@ -457,7 +474,7 @@ def run_upscale_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale, tile_size=tile_size), params, upscale=after_upscale, chain=chain, @@ -504,12 +521,18 @@ def run_blend_pipeline( ) -> None: # set up the chain pipeline and base stage chain = ChainPipeline() - stage = StageParams() - chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask) + tile_size = get_base_tile(params, size) + + chain.stage( + BlendMaskStage(), + StageParams(tile_size=tile_size), + stage_source=sources[1], + stage_mask=mask, + ) # apply upscaling and correction stage_upscale_correction( - stage, + StageParams(outscale=upscale.outscale), params, upscale=upscale, chain=chain, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 7d0ad48c..34f1d070 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -488,10 +488,14 @@ class HighresParams: self.method = method self.iterations = iterations + def outscale(self) -> int: + return self.scale**self.iterations + def resize(self, size: Size) -> Size: + outscale = self.outscale() return Size( - size.width * (self.scale**self.iterations), - size.height * (self.scale**self.iterations), + size.width * outscale, + size.height * outscale, ) def tojson(self):