fix(api): accumulate progress from inpaint pipelines (#90)
This commit is contained in:
parent
b85c806ba7
commit
034be3266e
|
@ -5,14 +5,14 @@ formatters:
|
|||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
formatter: simple
|
||||
stream: ext://sys.stdout
|
||||
loggers:
|
||||
'':
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
handlers: [console]
|
||||
propagate: True
|
||||
root:
|
||||
level: INFO
|
||||
level: DEBUG
|
||||
handlers: [console]
|
|
@ -47,6 +47,11 @@ class ChainProgress:
|
|||
def get_total(self) -> int:
|
||||
return self.step + self.total
|
||||
|
||||
@classmethod
|
||||
def from_progress(cls, parent: ProgressCallback):
|
||||
start = parent.step if hasattr(parent, "step") else 0
|
||||
return ChainProgress(parent, start=start)
|
||||
|
||||
|
||||
class ChainPipeline:
|
||||
"""
|
||||
|
@ -82,8 +87,7 @@ class ChainPipeline:
|
|||
TODO: handle List[Image] inputs and outputs
|
||||
"""
|
||||
if callback is not None:
|
||||
start = callback.step if hasattr(callback, "step") else 0
|
||||
callback = ChainProgress(callback, start=start)
|
||||
callback = ChainProgress.from_progress(callback)
|
||||
|
||||
start = monotonic()
|
||||
logger.info(
|
||||
|
|
|
@ -6,6 +6,8 @@ import torch
|
|||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from onnx_web.chain.base import ChainProgress
|
||||
|
||||
from ..chain import upscale_outpaint
|
||||
from ..device_pool import JobContext
|
||||
from ..output import save_image, save_params
|
||||
|
@ -65,7 +67,13 @@ def run_txt2img_pipeline(
|
|||
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
|
||||
job,
|
||||
server,
|
||||
StageParams(),
|
||||
params,
|
||||
image,
|
||||
upscale=upscale,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
|
@ -123,7 +131,13 @@ def run_img2img_pipeline(
|
|||
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
|
||||
job,
|
||||
server,
|
||||
StageParams(),
|
||||
params,
|
||||
image,
|
||||
upscale=upscale,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
|
@ -157,6 +171,9 @@ def run_inpaint_pipeline(
|
|||
progress = job.get_progress_callback()
|
||||
stage = StageParams(tile_order=tile_order)
|
||||
|
||||
# calling the upscale_outpaint stage directly needs accumulating progress
|
||||
progress = ChainProgress.from_progress(progress)
|
||||
|
||||
image = upscale_outpaint(
|
||||
job,
|
||||
server,
|
||||
|
|
Loading…
Reference in New Issue