1
0
Fork 0

fix(api): consistently handle tile size across premade pipelines

This commit is contained in:
Sean Sube 2023-11-25 14:02:42 -06:00
parent 02447f5fd6
commit d78e843af4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 60 additions and 33 deletions

View File

@ -34,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt
logger = getLogger(__name__) logger = getLogger(__name__)
def get_base_tile(params: ImageParams, size: Size) -> int:
if params.is_panorama():
tile = max(params.unet_tile, size.width, size.height)
logger.debug("adjusting tile size for panorama to %s", tile)
return tile
return params.unet_tile
def get_highres_tile(
server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
) -> int:
if params.is_panorama() and server.has_feature("panorama-highres"):
return tile * highres.scale
return params.unet_tile
def run_txt2img_pipeline( def run_txt2img_pipeline(
worker: WorkerContext, worker: WorkerContext,
server: ServerContext, server: ServerContext,
@ -44,11 +62,7 @@ def run_txt2img_pipeline(
highres: HighresParams, highres: HighresParams,
) -> None: ) -> None:
# if using panorama, the pipeline will tile itself (views) # if using panorama, the pipeline will tile itself (views)
if params.is_panorama(): tile_size = get_base_tile(params, size)
tile_size = max(params.unet_tile, size.width, size.height)
logger.debug("adjusting tile size for panorama to %s", tile_size)
else:
tile_size = params.unet_tile
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
chain = ChainPipeline() chain = ChainPipeline()
@ -63,12 +77,8 @@ def run_txt2img_pipeline(
) )
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
highres_size = params.unet_tile highres_size = get_highres_tile(server, params, highres, tile_size)
if params.is_panorama(): if params.is_panorama():
if server.has_feature("panorama-highres"):
# run the whole highres pass with one panorama call
highres_size = tile_size * highres.scale
chain.stage( chain.stage(
BlendDenoiseStage(), BlendDenoiseStage(),
StageParams(tile_size=highres_size), StageParams(tile_size=highres_size),
@ -151,13 +161,13 @@ def run_img2img_pipeline(
source = f(server, source) source = f(server, source)
# prepare the chain pipeline and first stage # prepare the chain pipeline and first stage
tile_size = get_base_tile(params, Size(*source.size))
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(
tile_size=params.unet_tile,
)
chain.stage( chain.stage(
BlendImg2ImgStage(), BlendImg2ImgStage(),
stage, StageParams(
tile_size=tile_size,
),
prompt_index=0, prompt_index=0,
strength=strength, strength=strength,
overlap=params.vae_overlap, overlap=params.vae_overlap,
@ -167,7 +177,10 @@ def run_img2img_pipeline(
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(
outscale=first_upscale.outscale,
tile_size=tile_size,
),
params, params,
upscale=first_upscale, upscale=first_upscale,
chain=chain, chain=chain,
@ -177,13 +190,16 @@ def run_img2img_pipeline(
for _i in range(params.loopback): for _i in range(params.loopback):
chain.stage( chain.stage(
BlendImg2ImgStage(), BlendImg2ImgStage(),
stage, StageParams(
tile_size=tile_size,
),
strength=strength, strength=strength,
) )
# highres, if selected # highres, if selected
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres( stage_highres(
stage, StageParams(tile_size=highres_size, outscale=highres.scale),
params, params,
highres, highres,
upscale, upscale,
@ -193,7 +209,7 @@ def run_img2img_pipeline(
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(tile_size=tile_size, outscale=after_upscale.scale),
params, params,
upscale=after_upscale, upscale=after_upscale,
chain=chain, chain=chain,
@ -252,7 +268,7 @@ def run_inpaint_pipeline(
full_res_inpaint_padding: float, full_res_inpaint_padding: float,
) -> None: ) -> None:
logger.debug("building inpaint pipeline") logger.debug("building inpaint pipeline")
tile_size = params.unet_tile tile_size = get_base_tile(params, size)
if mask is None: if mask is None:
# if no mask was provided, keep the full source image # if no mask was provided, keep the full source image
@ -339,10 +355,9 @@ def run_inpaint_pipeline(
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
chain.stage( chain.stage(
UpscaleOutpaintStage(), UpscaleOutpaintStage(),
stage, StageParams(tile_order=tile_order, tile_size=tile_size),
border=border, border=border,
mask=mask, mask=mask,
fill_color=fill_color, fill_color=fill_color,
@ -356,15 +371,16 @@ def run_inpaint_pipeline(
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params, params,
upscale=first_upscale, upscale=first_upscale,
chain=chain, chain=chain,
) )
# apply highres # apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres( stage_highres(
stage, StageParams(outscale=highres.scale, tile_size=highres_size),
params, params,
highres, highres,
upscale, upscale,
@ -374,7 +390,7 @@ def run_inpaint_pipeline(
# apply upscaling and correction # apply upscaling and correction
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=after_upscale.outscale),
params, params,
upscale=after_upscale, upscale=after_upscale,
chain=chain, chain=chain,
@ -433,21 +449,22 @@ def run_upscale_pipeline(
) -> None: ) -> None:
# set up the chain pipeline, no base stage for upscaling # set up the chain pipeline, no base stage for upscaling
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams(tile_size=params.unet_tile) tile_size = get_base_tile(params, size)
# apply upscaling and correction, before highres # apply upscaling and correction, before highres
first_upscale, after_upscale = split_upscale(upscale) first_upscale, after_upscale = split_upscale(upscale)
if first_upscale: if first_upscale:
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
params, params,
upscale=first_upscale, upscale=first_upscale,
chain=chain, chain=chain,
) )
# apply highres # apply highres
highres_size = get_highres_tile(server, params, highres, tile_size)
stage_highres( stage_highres(
stage, StageParams(outscale=highres.scale, tile_size=highres_size),
params, params,
highres, highres,
upscale, upscale,
@ -457,7 +474,7 @@ def run_upscale_pipeline(
# apply upscaling and correction, after highres # apply upscaling and correction, after highres
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
params, params,
upscale=after_upscale, upscale=after_upscale,
chain=chain, chain=chain,
@ -504,12 +521,18 @@ def run_blend_pipeline(
) -> None: ) -> None:
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
chain = ChainPipeline() chain = ChainPipeline()
stage = StageParams() tile_size = get_base_tile(params, size)
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
chain.stage(
BlendMaskStage(),
StageParams(tile_size=tile_size),
stage_source=sources[1],
stage_mask=mask,
)
# apply upscaling and correction # apply upscaling and correction
stage_upscale_correction( stage_upscale_correction(
stage, StageParams(outscale=upscale.outscale),
params, params,
upscale=upscale, upscale=upscale,
chain=chain, chain=chain,

View File

@ -488,10 +488,14 @@ class HighresParams:
self.method = method self.method = method
self.iterations = iterations self.iterations = iterations
def outscale(self) -> int:
return self.scale**self.iterations
def resize(self, size: Size) -> Size: def resize(self, size: Size) -> Size:
outscale = self.outscale()
return Size( return Size(
size.width * (self.scale**self.iterations), size.width * outscale,
size.height * (self.scale**self.iterations), size.height * outscale,
) )
def tojson(self): def tojson(self):