fix(api): accumulate progress from inpaint pipelines (#90)
This commit is contained in:
parent
b85c806ba7
commit
034be3266e
|
@ -5,14 +5,14 @@ formatters:
|
||||||
handlers:
|
handlers:
|
||||||
console:
|
console:
|
||||||
class: logging.StreamHandler
|
class: logging.StreamHandler
|
||||||
level: INFO
|
level: DEBUG
|
||||||
formatter: simple
|
formatter: simple
|
||||||
stream: ext://sys.stdout
|
stream: ext://sys.stdout
|
||||||
loggers:
|
loggers:
|
||||||
'':
|
'':
|
||||||
level: INFO
|
level: DEBUG
|
||||||
handlers: [console]
|
handlers: [console]
|
||||||
propagate: True
|
propagate: True
|
||||||
root:
|
root:
|
||||||
level: INFO
|
level: DEBUG
|
||||||
handlers: [console]
|
handlers: [console]
|
|
@ -47,6 +47,11 @@ class ChainProgress:
|
||||||
def get_total(self) -> int:
|
def get_total(self) -> int:
|
||||||
return self.step + self.total
|
return self.step + self.total
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_progress(cls, parent: ProgressCallback):
|
||||||
|
start = parent.step if hasattr(parent, "step") else 0
|
||||||
|
return ChainProgress(parent, start=start)
|
||||||
|
|
||||||
|
|
||||||
class ChainPipeline:
|
class ChainPipeline:
|
||||||
"""
|
"""
|
||||||
|
@ -82,8 +87,7 @@ class ChainPipeline:
|
||||||
TODO: handle List[Image] inputs and outputs
|
TODO: handle List[Image] inputs and outputs
|
||||||
"""
|
"""
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
start = callback.step if hasattr(callback, "step") else 0
|
callback = ChainProgress.from_progress(callback)
|
||||||
callback = ChainProgress(callback, start=start)
|
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -6,6 +6,8 @@ import torch
|
||||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
|
from onnx_web.chain.base import ChainProgress
|
||||||
|
|
||||||
from ..chain import upscale_outpaint
|
from ..chain import upscale_outpaint
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext
|
||||||
from ..output import save_image, save_params
|
from ..output import save_image, save_params
|
||||||
|
@ -65,7 +67,13 @@ def run_txt2img_pipeline(
|
||||||
|
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
|
job,
|
||||||
|
server,
|
||||||
|
StageParams(),
|
||||||
|
params,
|
||||||
|
image,
|
||||||
|
upscale=upscale,
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
|
@ -123,7 +131,13 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
image = result.images[0]
|
image = result.images[0]
|
||||||
image = run_upscale_correction(
|
image = run_upscale_correction(
|
||||||
job, server, StageParams(), params, image, upscale=upscale, callback=progress,
|
job,
|
||||||
|
server,
|
||||||
|
StageParams(),
|
||||||
|
params,
|
||||||
|
image,
|
||||||
|
upscale=upscale,
|
||||||
|
callback=progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
dest = save_image(server, output, image)
|
dest = save_image(server, output, image)
|
||||||
|
@ -157,6 +171,9 @@ def run_inpaint_pipeline(
|
||||||
progress = job.get_progress_callback()
|
progress = job.get_progress_callback()
|
||||||
stage = StageParams(tile_order=tile_order)
|
stage = StageParams(tile_order=tile_order)
|
||||||
|
|
||||||
|
# calling the upscale_outpaint stage directly needs accumulating progress
|
||||||
|
progress = ChainProgress.from_progress(progress)
|
||||||
|
|
||||||
image = upscale_outpaint(
|
image = upscale_outpaint(
|
||||||
job,
|
job,
|
||||||
server,
|
server,
|
||||||
|
|
Loading…
Reference in New Issue