fix(api): consistently handle tile size across premade pipelines
This commit is contained in:
parent
02447f5fd6
commit
d78e843af4
|
@ -34,6 +34,24 @@ from .utils import get_latents_from_seed, parse_prompt
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_tile(params: ImageParams, size: Size) -> int:
|
||||||
|
if params.is_panorama():
|
||||||
|
tile = max(params.unet_tile, size.width, size.height)
|
||||||
|
logger.debug("adjusting tile size for panorama to %s", tile)
|
||||||
|
return tile
|
||||||
|
|
||||||
|
return params.unet_tile
|
||||||
|
|
||||||
|
|
||||||
|
def get_highres_tile(
|
||||||
|
server: ServerContext, params: ImageParams, highres: HighresParams, tile: int
|
||||||
|
) -> int:
|
||||||
|
if params.is_panorama() and server.has_feature("panorama-highres"):
|
||||||
|
return tile * highres.scale
|
||||||
|
|
||||||
|
return params.unet_tile
|
||||||
|
|
||||||
|
|
||||||
def run_txt2img_pipeline(
|
def run_txt2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -44,11 +62,7 @@ def run_txt2img_pipeline(
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
# if using panorama, the pipeline will tile itself (views)
|
# if using panorama, the pipeline will tile itself (views)
|
||||||
if params.is_panorama():
|
tile_size = get_base_tile(params, size)
|
||||||
tile_size = max(params.unet_tile, size.width, size.height)
|
|
||||||
logger.debug("adjusting tile size for panorama to %s", tile_size)
|
|
||||||
else:
|
|
||||||
tile_size = params.unet_tile
|
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
|
@ -63,12 +77,8 @@ def run_txt2img_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
highres_size = params.unet_tile
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
if params.is_panorama():
|
if params.is_panorama():
|
||||||
if server.has_feature("panorama-highres"):
|
|
||||||
# run the whole highres pass with one panorama call
|
|
||||||
highres_size = tile_size * highres.scale
|
|
||||||
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendDenoiseStage(),
|
BlendDenoiseStage(),
|
||||||
StageParams(tile_size=highres_size),
|
StageParams(tile_size=highres_size),
|
||||||
|
@ -151,13 +161,13 @@ def run_img2img_pipeline(
|
||||||
source = f(server, source)
|
source = f(server, source)
|
||||||
|
|
||||||
# prepare the chain pipeline and first stage
|
# prepare the chain pipeline and first stage
|
||||||
|
tile_size = get_base_tile(params, Size(*source.size))
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(
|
|
||||||
tile_size=params.unet_tile,
|
|
||||||
)
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
StageParams(
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
prompt_index=0,
|
prompt_index=0,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
overlap=params.vae_overlap,
|
overlap=params.vae_overlap,
|
||||||
|
@ -167,7 +177,10 @@ def run_img2img_pipeline(
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(
|
||||||
|
outscale=first_upscale.outscale,
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -177,13 +190,16 @@ def run_img2img_pipeline(
|
||||||
for _i in range(params.loopback):
|
for _i in range(params.loopback):
|
||||||
chain.stage(
|
chain.stage(
|
||||||
BlendImg2ImgStage(),
|
BlendImg2ImgStage(),
|
||||||
stage,
|
StageParams(
|
||||||
|
tile_size=tile_size,
|
||||||
|
),
|
||||||
strength=strength,
|
strength=strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
# highres, if selected
|
# highres, if selected
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(tile_size=highres_size, outscale=highres.scale),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -193,7 +209,7 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(tile_size=tile_size, outscale=after_upscale.scale),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -252,7 +268,7 @@ def run_inpaint_pipeline(
|
||||||
full_res_inpaint_padding: float,
|
full_res_inpaint_padding: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.debug("building inpaint pipeline")
|
logger.debug("building inpaint pipeline")
|
||||||
tile_size = params.unet_tile
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# if no mask was provided, keep the full source image
|
# if no mask was provided, keep the full source image
|
||||||
|
@ -339,10 +355,9 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_order=tile_order, tile_size=tile_size)
|
|
||||||
chain.stage(
|
chain.stage(
|
||||||
UpscaleOutpaintStage(),
|
UpscaleOutpaintStage(),
|
||||||
stage,
|
StageParams(tile_order=tile_order, tile_size=tile_size),
|
||||||
border=border,
|
border=border,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
fill_color=fill_color,
|
fill_color=fill_color,
|
||||||
|
@ -356,15 +371,16 @@ def run_inpaint_pipeline(
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -374,7 +390,7 @@ def run_inpaint_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=after_upscale.outscale),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -433,21 +449,22 @@ def run_upscale_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline, no base stage for upscaling
|
# set up the chain pipeline, no base stage for upscaling
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams(tile_size=params.unet_tile)
|
tile_size = get_base_tile(params, size)
|
||||||
|
|
||||||
# apply upscaling and correction, before highres
|
# apply upscaling and correction, before highres
|
||||||
first_upscale, after_upscale = split_upscale(upscale)
|
first_upscale, after_upscale = split_upscale(upscale)
|
||||||
if first_upscale:
|
if first_upscale:
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=first_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=first_upscale,
|
upscale=first_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
# apply highres
|
# apply highres
|
||||||
|
highres_size = get_highres_tile(server, params, highres, tile_size)
|
||||||
stage_highres(
|
stage_highres(
|
||||||
stage,
|
StageParams(outscale=highres.scale, tile_size=highres_size),
|
||||||
params,
|
params,
|
||||||
highres,
|
highres,
|
||||||
upscale,
|
upscale,
|
||||||
|
@ -457,7 +474,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# apply upscaling and correction, after highres
|
# apply upscaling and correction, after highres
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=after_upscale.outscale, tile_size=tile_size),
|
||||||
params,
|
params,
|
||||||
upscale=after_upscale,
|
upscale=after_upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
@ -504,12 +521,18 @@ def run_blend_pipeline(
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
chain = ChainPipeline()
|
chain = ChainPipeline()
|
||||||
stage = StageParams()
|
tile_size = get_base_tile(params, size)
|
||||||
chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask)
|
|
||||||
|
chain.stage(
|
||||||
|
BlendMaskStage(),
|
||||||
|
StageParams(tile_size=tile_size),
|
||||||
|
stage_source=sources[1],
|
||||||
|
stage_mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
# apply upscaling and correction
|
# apply upscaling and correction
|
||||||
stage_upscale_correction(
|
stage_upscale_correction(
|
||||||
stage,
|
StageParams(outscale=upscale.outscale),
|
||||||
params,
|
params,
|
||||||
upscale=upscale,
|
upscale=upscale,
|
||||||
chain=chain,
|
chain=chain,
|
||||||
|
|
|
@ -488,10 +488,14 @@ class HighresParams:
|
||||||
self.method = method
|
self.method = method
|
||||||
self.iterations = iterations
|
self.iterations = iterations
|
||||||
|
|
||||||
|
def outscale(self) -> int:
|
||||||
|
return self.scale**self.iterations
|
||||||
|
|
||||||
def resize(self, size: Size) -> Size:
|
def resize(self, size: Size) -> Size:
|
||||||
|
outscale = self.outscale()
|
||||||
return Size(
|
return Size(
|
||||||
size.width * (self.scale**self.iterations),
|
size.width * outscale,
|
||||||
size.height * (self.scale**self.iterations),
|
size.height * outscale,
|
||||||
)
|
)
|
||||||
|
|
||||||
def tojson(self):
|
def tojson(self):
|
||||||
|
|
Loading…
Reference in New Issue