diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index d5f41b03..234f6c3b 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -35,22 +35,23 @@ class BlendGridStage(BaseStage): ) -> StageResult: logger.info("combining source images using grid layout") - size = sources[0].size + images = sources.as_image() + size = images[0].size output = Image.new("RGB", (size[0] * width, size[1] * height)) # TODO: labels if order is None: - order = range(len(sources)) + order = range(len(images)) for i in range(len(order)): x = i % width y = i // width n = order[i] - output.paste(sources[n], (x * size[0], y * size[1])) + output.paste(images[n], (x * size[0], y * size[1])) - return StageResult(images=[*sources, output]) + return StageResult(images=[*images, output]) def outputs( self, diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 4200e3fb..e4a98d9d 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -29,5 +29,5 @@ class BlendLinearStage(BaseStage): logger.info("blending source images using linear interpolation") return StageResult( - images=[Image.blend(source, stage_source, alpha) for source in sources] + images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()] ) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 2e1b2ca0..926331a3 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -40,6 +40,6 @@ class BlendMaskStage(BaseStage): return StageResult( images=[ - Image.composite(stage_source, source, mult_mask) for source in sources + Image.composite(stage_source, source, mult_mask) for source in sources.as_image() ] ) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 8f72b636..1169d4fb 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -33,4 +33,4 @@ class CorrectCodeformerStage(BaseStage): device = worker.get_device() pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return StageResult(images=[pipe(source) for source in sources]) + return StageResult(images=[pipe(source) for source in sources.as_image()]) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index f55d54e1..f7d988cc 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -30,7 +30,7 @@ class PersistDiskStage(BaseStage): **kwargs, ) -> StageResult: logger.info( - "persisting images to disk: %s, %s", [s.size for s in sources], output + "persisting %s images to disk: %s", len(sources), output ) for source, name in zip(sources, output): diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 732f20e9..40a43ccf 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -12,8 +12,8 @@ from ..server import ServerContext from ..utils import is_debug, run_gc from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .result import StageResult from .tile import needs_tile, process_tile_order +from .result import StageResult logger = getLogger(__name__) @@ -77,7 +77,7 @@ class ChainPipeline: sources: StageResult, callback: Optional[ProgressCallback], **kwargs, - ) -> StageResult: + ) -> List[Image.Image]: result = self( worker, server, params, sources=sources, callback=callback, **kwargs ) @@ -136,7 +136,7 @@ class ChainPipeline: logger.debug( "running stage %s with %s source images, parameters: %s", name, - len(stage_sources) - stage_sources.count(None), + len(stage_sources), kwargs.keys(), ) @@ -154,7 +154,7 @@ class ChainPipeline: size=kwargs.get("size", None), source=source, ) - for source in stage_sources + for source in stage_sources.as_image() ] ) @@ -162,9 +162,10 @@ class ChainPipeline: if stage_pipe.max_tile > 0: tile = min(stage_pipe.max_tile, stage_params.tile_size) + # TODO: stage_sources will always be defined here if stage_sources or must_tile: stage_results = [] - for source in stage_sources: + for source in stage_sources.as_image(): logger.info( "image contains sources or is larger than tile size of %s, tiling stage", tile, @@ -182,7 +183,7 @@ class ChainPipeline: server, stage_params, per_stage_params, - [source_tile], + StageResult(images=[source_tile]), tile_mask=tile_mask, callback=callback, dims=dims, @@ -193,7 +194,8 @@ class ChainPipeline: for j, image in enumerate(tile_result.as_image()): save_image(server, f"last-tile-{j}.png", image) - return tile_result + # TODO: return whole result + return tile_result.as_image()[0] except Exception: worker.retries = worker.retries - 1 logger.exception( @@ -257,7 +259,8 @@ class ChainPipeline: ) if is_debug(): - save_image(server, "last-stage.png", stage_sources[0]) + for j, image in enumerate(stage_sources.as_image()): + save_image(server, f"last-stage-{j}.png", image) end = monotonic() duration = timedelta(seconds=(end - start)) diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index b6fbc0bf..3bc54e43 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -15,6 +15,10 @@ class StageResult: arrays: Optional[List[np.ndarray]] images: Optional[List[Image.Image]] + @staticmethod + def empty(): + return StageResult(images=[]) + def __init__(self, arrays=None, images=None) -> None: if arrays is not None and images is not None: raise ValueError("stages must only return one type of result") diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 49d63d44..21881e18 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -156,7 +156,7 @@ class SourceTxt2ImgStage(BaseStage): callback=callback, ) - outputs = list(sources) + outputs = sources.as_image() outputs.extend(result.images) logger.debug("produced %s outputs", len(outputs)) return StageResult(images=outputs) diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index c8d100e1..b6aa62cd 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -34,7 +34,7 @@ class SourceURLStage(BaseStage): "source images were passed to a source stage, new images will be appended" ) - outputs = list(sources) + outputs = sources.as_image() for url in source_urls: response = requests.get(url) output = Image.open(BytesIO(response.content)) diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 41182f73..d68c0042 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -79,28 +79,23 @@ class UpscaleBSRGANStage(BaseStage): logger.trace("BSRGAN input shape: %s", image.shape) scale = upscale.outscale - dest = np.zeros( - ( + logger.trace("BSRGAN output shape: %s", ( image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale, - ) - ) - logger.trace("BSRGAN output shape: %s", dest.shape) + )) - dest = bsrgan(image) + output = bsrgan(image) - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) - - output = Image.fromarray(dest, "RGB") - logger.debug("output image size: %s x %s", output.width, output.height) + output = np.clip(np.squeeze(output, axis=0), 0, 1) + output = output[[2, 1, 0], :, :].transpose((1, 2, 0)) + output = (output * 255.0).round().astype(np.uint8) + logger.debug("output image shape: %s", output.shape) outputs.append(output) - return StageResult(images=outputs) + return StageResult(arrays=outputs) def steps( self, diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 0e027ca1..bd7f826a 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage): source, callback=callback, ) - for source in sources + for source in sources.as_image() ] return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index f19f3b84..0ec0499c 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -38,13 +38,11 @@ class UpscaleSimpleStage(BaseStage): if method == "bilinear": logger.debug("using bilinear interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) + outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR)) elif method == "lanczos": logger.debug("using Lanczos interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) + outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS)) else: logger.warning("unknown upscaling method: %s", method) - outputs.append(source) - - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 8ff01f68..bf2fa7ea 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage): ) outputs.extend(result.images) - return outputs + return StageResult(images=outputs) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index a622c693..317c73c5 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -105,7 +105,7 @@ def run_txt2img_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain.run(worker, server, params, [], callback=progress, latents=latents) + images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents) _pairs, loras, inversions, _rest = parse_prompt(params) @@ -200,7 +200,7 @@ def run_img2img_pipeline( # run and append the filtered source progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress) + images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress) if source_filter is not None and source_filter != "none": images.append(source) @@ -380,7 +380,7 @@ def run_inpaint_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress, latents=latents) + images = chain.run(worker, server, params, [source], callback=progress, latents=latents) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -455,7 +455,7 @@ def run_upscale_pipeline( # run and save progress = worker.get_progress_callback() - images = chain(worker, server, params, [source], callback=progress) + images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -487,7 +487,7 @@ def run_blend_pipeline( outputs: List[str], upscale: UpscaleParams, # highres: HighresParams, - sources: StageResult, + sources: List[Image.Image], mask: Image.Image, ) -> None: # set up the chain pipeline and base stage @@ -505,7 +505,7 @@ def run_blend_pipeline( # run and save progress = worker.get_progress_callback() - images = chain(worker, server, params, sources, callback=progress) + images = chain.run(worker, server, params, StageResult(images=sources), callback=progress) for image, output in zip(images, outputs): dest = save_image(server, output, image, params, size, upscale=upscale) diff --git a/api/tests/chain/test_blend_grid.py b/api/tests/chain/test_blend_grid.py index 8244df5e..b1623019 100644 --- a/api/tests/chain/test_blend_grid.py +++ b/api/tests/chain/test_blend_grid.py @@ -3,19 +3,19 @@ import unittest from PIL import Image from onnx_web.chain.blend_grid import BlendGridStage -from onnx_web.chain.blend_linear import BlendLinearStage +from onnx_web.chain.result import StageResult class BlendGridStageTests(unittest.TestCase): def test_stage(self): stage = BlendGridStage() - sources = [ + sources = StageResult(images=[ Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "white"), - ] + ]) result = stage.run(None, None, None, None, sources, height=2, width=2) self.assertEqual(len(result), 5) - self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0)) \ No newline at end of file + self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0)) \ No newline at end of file diff --git a/api/tests/chain/test_blend_linear.py b/api/tests/chain/test_blend_linear.py index 9d20fe55..a983a2e1 100644 --- a/api/tests/chain/test_blend_linear.py +++ b/api/tests/chain/test_blend_linear.py @@ -3,16 +3,17 @@ import unittest from PIL import Image from onnx_web.chain.blend_linear import BlendLinearStage +from onnx_web.chain.result import StageResult class BlendLinearStageTests(unittest.TestCase): def test_stage(self): stage = BlendLinearStage() - sources = [ + sources = StageResult(images=[ Image.new("RGB", (64, 64), "black"), - ] + ]) stage_source = Image.new("RGB", (64, 64), "white") result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source) self.assertEqual(len(result), 1) - self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) \ No newline at end of file + self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127)) \ No newline at end of file diff --git a/api/tests/chain/test_blend_mask.py b/api/tests/chain/test_blend_mask.py index cf70535f..4fcb8130 100644 --- a/api/tests/chain/test_blend_mask.py +++ b/api/tests/chain/test_blend_mask.py @@ -3,13 +3,14 @@ import unittest from PIL import Image from onnx_web.chain.blend_mask import BlendMaskStage +from onnx_web.chain.result import StageResult from onnx_web.params import HighresParams, UpscaleParams class BlendMaskStageTests(unittest.TestCase): def test_empty(self): stage = BlendMaskStage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_correct_codeformer.py b/api/tests/chain/test_correct_codeformer.py index 8203e876..9cc24de0 100644 --- a/api/tests/chain/test_correct_codeformer.py +++ b/api/tests/chain/test_correct_codeformer.py @@ -25,7 +25,7 @@ class CorrectCodeformerStageTests(unittest.TestCase): 0, ) stage = CorrectCodeformerStage() - sources = [] + sources = StageResult.empty() result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) self.assertEqual(len(result), 0) diff --git a/api/tests/chain/test_reduce_crop.py b/api/tests/chain/test_reduce_crop.py index 4e79d8f8..bfc7adc4 100644 --- a/api/tests/chain/test_reduce_crop.py +++ b/api/tests/chain/test_reduce_crop.py @@ -1,13 +1,14 @@ import unittest from onnx_web.chain.reduce_crop import ReduceCropStage +from onnx_web.chain.result import StageResult from onnx_web.params import HighresParams, Size, UpscaleParams class ReduceCropStageTests(unittest.TestCase): def test_empty(self): stage = ReduceCropStage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_reduce_thumbnail.py b/api/tests/chain/test_reduce_thumbnail.py index 14cb12a7..8b129672 100644 --- a/api/tests/chain/test_reduce_thumbnail.py +++ b/api/tests/chain/test_reduce_thumbnail.py @@ -2,8 +2,8 @@ import unittest from PIL import Image -from onnx_web.chain.reduce_crop import ReduceCropStage from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage +from onnx_web.chain.result import StageResult from onnx_web.params import HighresParams, Size, UpscaleParams @@ -11,7 +11,7 @@ class ReduceThumbnailStageTests(unittest.TestCase): def test_empty(self): stage_source = Image.new("RGB", (64, 64)) stage = ReduceThumbnailStage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py index 8187a751..f43a8f86 100644 --- a/api/tests/chain/test_source_noise.py +++ b/api/tests/chain/test_source_noise.py @@ -1,4 +1,5 @@ import unittest +from onnx_web.chain.result import StageResult from onnx_web.chain.source_noise import SourceNoiseStage from onnx_web.image.noise_source import noise_source_fill_edge @@ -8,7 +9,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams class SourceNoiseStageTests(unittest.TestCase): def test_empty(self): stage = SourceNoiseStage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py index aad37c5b..9b1e11ea 100644 --- a/api/tests/chain/test_source_s3.py +++ b/api/tests/chain/test_source_s3.py @@ -1,4 +1,5 @@ import unittest +from onnx_web.chain.result import StageResult from onnx_web.chain.source_s3 import SourceS3Stage from onnx_web.params import HighresParams, Size, UpscaleParams @@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams class SourceS3StageTests(unittest.TestCase): def test_empty(self): stage = SourceS3Stage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py index 1f185b7b..fe7588c7 100644 --- a/api/tests/chain/test_source_url.py +++ b/api/tests/chain/test_source_url.py @@ -1,4 +1,5 @@ import unittest +from onnx_web.chain.result import StageResult from onnx_web.chain.source_url import SourceURLStage from onnx_web.params import HighresParams, Size, UpscaleParams @@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams class SourceURLStageTests(unittest.TestCase): def test_empty(self): stage = SourceURLStage() - sources = [] + sources = StageResult.empty() result = stage.run( None, None, diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py index 95897c2c..8789e447 100644 --- a/api/tests/chain/test_upscale_highres.py +++ b/api/tests/chain/test_upscale_highres.py @@ -1,4 +1,5 @@ import unittest +from onnx_web.chain.result import StageResult from onnx_web.chain.upscale_highres import UpscaleHighresStage from onnx_web.params import HighresParams, UpscaleParams @@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, UpscaleParams class UpscaleHighresStageTests(unittest.TestCase): def test_empty(self): stage = UpscaleHighresStage() - sources = [] + sources = StageResult.empty() result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) self.assertEqual(len(result), 0)