pass progress on to most stages
This commit is contained in:
parent
9d1c5dca52
commit
2f6a3afddb
|
@ -82,7 +82,8 @@ class ChainPipeline:
|
||||||
TODO: handle List[Image] inputs and outputs
|
TODO: handle List[Image] inputs and outputs
|
||||||
"""
|
"""
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback = ChainProgress(callback, start=callback.step)
|
start = callback.step if hasattr(callback, "step") else 0
|
||||||
|
callback = ChainProgress(callback, start=start)
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -115,7 +116,15 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||||
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
|
tile = stage_pipe(
|
||||||
|
job,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
params,
|
||||||
|
tile,
|
||||||
|
callback=callback,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "last-tile.png", tile)
|
save_image(server, "last-tile.png", tile)
|
||||||
|
@ -131,7 +140,15 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("image within tile size, running stage")
|
logger.info("image within tile size, running stage")
|
||||||
image = stage_pipe(job, server, stage_params, params, image, **kwargs)
|
image = stage_pipe(
|
||||||
|
job,
|
||||||
|
server,
|
||||||
|
stage_params,
|
||||||
|
params,
|
||||||
|
image,
|
||||||
|
callback=callback,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"finished stage %s, result size: %sx%s", name, image.width, image.height
|
"finished stage %s, result size: %sx%s", name, image.width, image.height
|
||||||
|
|
|
@ -51,6 +51,7 @@ pipeline_schedulers = {
|
||||||
"pndm": PNDMScheduler,
|
"pndm": PNDMScheduler,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||||
for k, v in pipeline_schedulers.items():
|
for k, v in pipeline_schedulers.items():
|
||||||
if scheduler == v or scheduler == v.__name__:
|
if scheduler == v or scheduler == v.__name__:
|
||||||
|
|
|
@ -154,7 +154,7 @@ def run_inpaint_pipeline(
|
||||||
tile_order: str,
|
tile_order: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# device = job.get_device()
|
# device = job.get_device()
|
||||||
# progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
stage = StageParams(tile_order=tile_order)
|
stage = StageParams(tile_order=tile_order)
|
||||||
|
|
||||||
image = upscale_outpaint(
|
image = upscale_outpaint(
|
||||||
|
@ -168,6 +168,7 @@ def run_inpaint_pipeline(
|
||||||
fill_color=fill_color,
|
fill_color=fill_color,
|
||||||
mask_filter=mask_filter,
|
mask_filter=mask_filter,
|
||||||
noise_source=noise_source,
|
noise_source=noise_source,
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
logger.info("applying mask filter and generating noise source")
|
logger.info("applying mask filter and generating noise source")
|
||||||
|
|
||||||
|
@ -176,7 +177,9 @@ def run_inpaint_pipeline(
|
||||||
else:
|
else:
|
||||||
logger.info("output image size does not match source, skipping post-blend")
|
logger.info("output image size does not match source, skipping post-blend")
|
||||||
|
|
||||||
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
|
image = run_upscale_correction(
|
||||||
|
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||||
|
@ -197,11 +200,11 @@ def run_upscale_pipeline(
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
# device = job.get_device()
|
# device = job.get_device()
|
||||||
# progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
stage = StageParams()
|
stage = StageParams()
|
||||||
|
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, stage, params, source_image, upscale=upscale
|
job, server, stage, params, source_image, upscale=upscale, callback=progress
|
||||||
)
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from .chain import (
|
||||||
upscale_resrgan,
|
upscale_resrgan,
|
||||||
upscale_stable_diffusion,
|
upscale_stable_diffusion,
|
||||||
)
|
)
|
||||||
from .device_pool import JobContext
|
from .device_pool import JobContext, ProgressCallback
|
||||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||||
from .utils import ServerContext
|
from .utils import ServerContext
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ def run_upscale_correction(
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
*,
|
*,
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
|
callback: ProgressCallback = None,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
This is a convenience method for a chain pipeline that will run upscaling and
|
This is a convenience method for a chain pipeline that will run upscaling and
|
||||||
|
@ -35,10 +36,10 @@ def run_upscale_correction(
|
||||||
|
|
||||||
if upscale.scale > 1:
|
if upscale.scale > 1:
|
||||||
if "esrgan" in upscale.upscale_model:
|
if "esrgan" in upscale.upscale_model:
|
||||||
resr_stage = StageParams(
|
esrgan_stage = StageParams(
|
||||||
tile_size=stage.tile_size, outscale=upscale.outscale
|
tile_size=stage.tile_size, outscale=upscale.outscale
|
||||||
)
|
)
|
||||||
chain.append((upscale_resrgan, resr_stage, None))
|
chain.append((upscale_resrgan, esrgan_stage, None))
|
||||||
elif "stable-diffusion" in upscale.upscale_model:
|
elif "stable-diffusion" in upscale.upscale_model:
|
||||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||||
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||||
|
@ -57,4 +58,12 @@ def run_upscale_correction(
|
||||||
else:
|
else:
|
||||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||||
|
|
||||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
return chain(
|
||||||
|
job,
|
||||||
|
server,
|
||||||
|
params,
|
||||||
|
image,
|
||||||
|
prompt=params.prompt,
|
||||||
|
upscale=upscale,
|
||||||
|
callback=callback,
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue