From af416c252ddc50060fe9af70adc65cdc2ebe9a23 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 29 Jun 2023 23:36:45 -0500 Subject: [PATCH] feat(api): make chain pipeline work without a source image --- api/onnx_web/chain/base.py | 35 ++++++++++++++++++++++------------- api/onnx_web/diffusers/run.py | 10 ++++++++-- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 112b64bf..26d22a5b 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -99,27 +99,36 @@ class ChainPipeline: callback = ChainProgress.from_progress(callback) start = monotonic() - logger.info( - "running pipeline on source image with dimensions %sx%s", - source.width, - source.height, - ) image = source + if source is not None: + logger.info( + "running pipeline on source image with dimensions %sx%s", + source.width, + source.height, + ) + else: + logger.info("running pipeline without source image") + for stage_pipe, stage_params, stage_kwargs in self.stages: name = stage_params.name or stage_pipe.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} - logger.debug( - "running stage %s on image with dimensions %sx%s, %s", - name, - image.width, - image.height, - kwargs.keys(), - ) + if image is not None: + logger.debug( + "running stage %s on source image with dimensions %sx%s, %s", + name, + image.width, + image.height, + kwargs.keys(), + ) + else: + logger.debug( + "running stage %s without source image, %s", name, kwargs.keys() + ) - if ( + if image is not None and ( image.width > stage_params.tile_size or image.height > stage_params.tile_size ): diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index b1c48928..89646fd6 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -3,7 +3,13 @@ from typing import Any, List, Optional from PIL import Image -from ..chain import blend_img2img, blend_mask, upscale_highres, upscale_outpaint +from ..chain import ( + blend_img2img, + blend_mask, + source_txt2img, + upscale_highres, + upscale_outpaint, +) from ..chain.base import ChainPipeline from ..output import save_image from ..params import ( @@ -36,7 +42,7 @@ def run_txt2img_pipeline( # prepare the chain pipeline and first stage chain = ChainPipeline() stage = StageParams() - chain.append((blend_img2img, stage, None)) + chain.append((source_txt2img, stage, None)) # apply upscaling and correction, before highres first_upscale, after_upscale = split_upscale(upscale)