1
0
Fork 0

fix(api): accumulate progress from inpaint pipelines (#90)

This commit is contained in:
Sean Sube 2023-02-12 13:16:17 -06:00
parent b85c806ba7
commit 034be3266e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 29 additions and 8 deletions

View File

@ -5,14 +5,14 @@ formatters:
handlers:
console:
class: logging.StreamHandler
level: INFO
level: DEBUG
formatter: simple
stream: ext://sys.stdout
loggers:
'':
level: INFO
level: DEBUG
handlers: [console]
propagate: True
root:
level: INFO
level: DEBUG
handlers: [console]

View File

@ -47,6 +47,11 @@ class ChainProgress:
def get_total(self) -> int:
return self.step + self.total
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
return ChainProgress(parent, start=start)
class ChainPipeline:
"""
@ -82,8 +87,7 @@ class ChainPipeline:
TODO: handle List[Image] inputs and outputs
"""
if callback is not None:
start = callback.step if hasattr(callback, "step") else 0
callback = ChainProgress(callback, start=start)
callback = ChainProgress.from_progress(callback)
start = monotonic()
logger.info(

View File

@ -6,6 +6,8 @@ import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
from onnx_web.chain.base import ChainProgress
from ..chain import upscale_outpaint
from ..device_pool import JobContext
from ..output import save_image, save_params
@ -65,7 +67,13 @@ def run_txt2img_pipeline(
image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(server, output, image)
@ -123,7 +131,13 @@ def run_img2img_pipeline(
image = result.images[0]
image = run_upscale_correction(
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
job,
server,
StageParams(),
params,
image,
upscale=upscale,
callback=progress,
)
dest = save_image(server, output, image)
@ -157,6 +171,9 @@ def run_inpaint_pipeline(
progress = job.get_progress_callback()
stage = StageParams(tile_order=tile_order)
# calling the upscale_outpaint stage directly needs accumulating progress
progress = ChainProgress.from_progress(progress)
image = upscale_outpaint(
job,
server,