1
0
Fork 0

pass progress on to most stages

This commit is contained in:
Sean Sube 2023-02-12 12:33:36 -06:00
parent 9d1c5dca52
commit 2f6a3afddb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 41 additions and 11 deletions

View File

@ -82,7 +82,8 @@ class ChainPipeline:
TODO: handle List[Image] inputs and outputs TODO: handle List[Image] inputs and outputs
""" """
if callback is not None: if callback is not None:
callback = ChainProgress(callback, start=callback.step) start = callback.step if hasattr(callback, "step") else 0
callback = ChainProgress(callback, start=start)
start = monotonic() start = monotonic()
logger.info( logger.info(
@ -115,7 +116,15 @@ class ChainPipeline:
) )
def stage_tile(tile: Image.Image, _dims) -> Image.Image: def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(job, server, stage_params, params, tile, **kwargs) tile = stage_pipe(
job,
server,
stage_params,
params,
tile,
callback=callback,
**kwargs
)
if is_debug(): if is_debug():
save_image(server, "last-tile.png", tile) save_image(server, "last-tile.png", tile)
@ -131,7 +140,15 @@ class ChainPipeline:
) )
else: else:
logger.info("image within tile size, running stage") logger.info("image within tile size, running stage")
image = stage_pipe(job, server, stage_params, params, image, **kwargs) image = stage_pipe(
job,
server,
stage_params,
params,
image,
callback=callback,
**kwargs
)
logger.info( logger.info(
"finished stage %s, result size: %sx%s", name, image.width, image.height "finished stage %s, result size: %sx%s", name, image.width, image.height

View File

@ -51,6 +51,7 @@ pipeline_schedulers = {
"pndm": PNDMScheduler, "pndm": PNDMScheduler,
} }
def get_scheduler_name(scheduler: Any) -> Optional[str]: def get_scheduler_name(scheduler: Any) -> Optional[str]:
for k, v in pipeline_schedulers.items(): for k, v in pipeline_schedulers.items():
if scheduler == v or scheduler == v.__name__: if scheduler == v or scheduler == v.__name__:

View File

@ -154,7 +154,7 @@ def run_inpaint_pipeline(
tile_order: str, tile_order: str,
) -> None: ) -> None:
# device = job.get_device() # device = job.get_device()
# progress = job.get_progress_callback() progress = job.get_progress_callback()
stage = StageParams(tile_order=tile_order) stage = StageParams(tile_order=tile_order)
image = upscale_outpaint( image = upscale_outpaint(
@ -168,6 +168,7 @@ def run_inpaint_pipeline(
fill_color=fill_color, fill_color=fill_color,
mask_filter=mask_filter, mask_filter=mask_filter,
noise_source=noise_source, noise_source=noise_source,
callback=progress,
) )
logger.info("applying mask filter and generating noise source") logger.info("applying mask filter and generating noise source")
@ -176,7 +177,9 @@ def run_inpaint_pipeline(
else: else:
logger.info("output image size does not match source, skipping post-blend") logger.info("output image size does not match source, skipping post-blend")
image = run_upscale_correction(job, server, stage, params, image, upscale=upscale) image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
)
dest = save_image(server, output, image) dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale, border=border) save_params(server, output, params, size, upscale=upscale, border=border)
@ -197,11 +200,11 @@ def run_upscale_pipeline(
source_image: Image.Image, source_image: Image.Image,
) -> None: ) -> None:
# device = job.get_device() # device = job.get_device()
# progress = job.get_progress_callback() progress = job.get_progress_callback()
stage = StageParams() stage = StageParams()
image = run_upscale_correction( image = run_upscale_correction(
job, server, stage, params, source_image, upscale=upscale job, server, stage, params, source_image, upscale=upscale, callback=progress
) )
dest = save_image(server, output, image) dest = save_image(server, output, image)

View File

@ -9,7 +9,7 @@ from .chain import (
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
) )
from .device_pool import JobContext from .device_pool import JobContext, ProgressCallback
from .params import ImageParams, SizeChart, StageParams, UpscaleParams from .params import ImageParams, SizeChart, StageParams, UpscaleParams
from .utils import ServerContext from .utils import ServerContext
@ -24,6 +24,7 @@ def run_upscale_correction(
image: Image.Image, image: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
callback: ProgressCallback = None,
) -> Image.Image: ) -> Image.Image:
""" """
This is a convenience method for a chain pipeline that will run upscaling and This is a convenience method for a chain pipeline that will run upscaling and
@ -35,10 +36,10 @@ def run_upscale_correction(
if upscale.scale > 1: if upscale.scale > 1:
if "esrgan" in upscale.upscale_model: if "esrgan" in upscale.upscale_model:
resr_stage = StageParams( esrgan_stage = StageParams(
tile_size=stage.tile_size, outscale=upscale.outscale tile_size=stage.tile_size, outscale=upscale.outscale
) )
chain.append((upscale_resrgan, resr_stage, None)) chain.append((upscale_resrgan, esrgan_stage, None))
elif "stable-diffusion" in upscale.upscale_model: elif "stable-diffusion" in upscale.upscale_model:
mini_tile = min(SizeChart.mini, stage.tile_size) mini_tile = min(SizeChart.mini, stage.tile_size)
sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) sd_stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale)
@ -57,4 +58,12 @@ def run_upscale_correction(
else: else:
logger.warn("unknown correction model: %s", upscale.correction_model) logger.warn("unknown correction model: %s", upscale.correction_model)
return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) return chain(
job,
server,
params,
image,
prompt=params.prompt,
upscale=upscale,
callback=callback,
)