pass progress on to most stages
This commit is contained in:
parent
9d1c5dca52
commit
2f6a3afddb
|
@ -82,7 +82,8 @@ class ChainPipeline:
|
|||
TODO: handle List[Image] inputs and outputs
|
||||
"""
|
||||
if callback is not None:
|
||||
callback = ChainProgress(callback, start=callback.step)
|
||||
start = callback.step if hasattr(callback, "step") else 0
|
||||
callback = ChainProgress(callback, start=start)
|
||||
|
||||
start = monotonic()
|
||||
logger.info(
|
||||
|
@ -115,7 +116,15 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
|
||||
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
|
||||
tile = stage_pipe(
|
||||
job,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
tile,
|
||||
callback=callback,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-tile.png", tile)
|
||||
|
@ -131,7 +140,15 @@ class ChainPipeline:
|
|||
)
|
||||
else:
|
||||
logger.info("image within tile size, running stage")
|
||||
image = stage_pipe(job, server, stage_params, params, image, **kwargs)
|
||||
image = stage_pipe(
|
||||
job,
|
||||
server,
|
||||
stage_params,
|
||||
params,
|
||||
image,
|
||||
callback=callback,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"finished stage %s, result size: %sx%s", name, image.width, image.height
|
||||
|
|
|
@ -51,6 +51,7 @@ pipeline_schedulers = {
|
|||
"pndm": PNDMScheduler,
|
||||
}
|
||||
|
||||
|
||||
def get_scheduler_name(scheduler: Any) -> Optional[str]:
|
||||
for k, v in pipeline_schedulers.items():
|
||||
if scheduler == v or scheduler == v.__name__:
|
||||
|
|
|
@ -154,7 +154,7 @@ def run_inpaint_pipeline(
|
|||
tile_order: str,
|
||||
) -> None:
|
||||
# device = job.get_device()
|
||||
# progress = job.get_progress_callback()
|
||||
progress = job.get_progress_callback()
|
||||
stage = StageParams(tile_order=tile_order)
|
||||
|
||||
image = upscale_outpaint(
|
||||
|
@ -168,6 +168,7 @@ def run_inpaint_pipeline(
|
|||
fill_color=fill_color,
|
||||
mask_filter=mask_filter,
|
||||
noise_source=noise_source,
|
||||
callback=progress,
|
||||
)
|
||||
logger.info("applying mask filter and generating noise source")
|
||||
|
||||
|
@ -176,7 +177,9 @@ def run_inpaint_pipeline(
|
|||
else:
|
||||
logger.info("output image size does not match source, skipping post-blend")
|
||||
|
||||
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale)
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
|
@ -197,11 +200,11 @@ def run_upscale_pipeline(
|
|||
source_image: Image.Image,
|
||||
) -> None:
|
||||
# device = job.get_device()
|
||||
# progress = job.get_progress_callback()
|
||||
progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, source_image, upscale=upscale
|
||||
job, server, stage, params, source_image, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
|
|
|
@ -9,7 +9,7 @@ from .chain import (
|
|||
upscale_resrgan,
|
||||
upscale_stable_diffusion,
|
||||
)
|
||||
from .device_pool import JobContext
|
||||
from .device_pool import JobContext, ProgressCallback
|
||||
from .params import ImageParams, SizeChart, StageParams, UpscaleParams
|
||||
from .utils import ServerContext
|
||||
|
||||
|
@ -24,6 +24,7 @@ def run_upscale_correction(
|
|||
image: Image.Image,
|
||||
*,
|
||||
upscale: UpscaleParams,
|
||||
callback: ProgressCallback = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
This is a convenience method for a chain pipeline that will run upscaling and
|
||||
|
@ -35,10 +36,10 @@ def run_upscale_correction(
|
|||
|
||||
if upscale.scale > 1:
|
||||
if "esrgan" in upscale.upscale_model:
|
||||
resr_stage = StageParams(
|
||||
esrgan_stage = StageParams(
|
||||
tile_size=stage.tile_size, outscale=upscale.outscale
|
||||
)
|
||||
chain.append((upscale_resrgan, resr_stage, None))
|
||||
chain.append((upscale_resrgan, esrgan_stage, None))
|
||||
elif "stable-diffusion" in upscale.upscale_model:
|
||||
mini_tile = min(SizeChart.mini, stage.tile_size)
|
||||
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
|
||||
|
@ -57,4 +58,12 @@ def run_upscale_correction(
|
|||
else:
|
||||
logger.warn("unknown correction model: %s", upscale.correction_model)
|
||||
|
||||
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale)
|
||||
return chain(
|
||||
job,
|
||||
server,
|
||||
params,
|
||||
image,
|
||||
prompt=params.prompt,
|
||||
upscale=upscale,
|
||||
callback=callback,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue