1
0
Fork 0

move stages and tests to using stage result

This commit is contained in:
Sean Sube 2023-11-18 21:35:00 -06:00
parent 7e6749e0d7
commit eb77c83d80
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
24 changed files with 68 additions and 60 deletions

View File

@ -35,22 +35,23 @@ class BlendGridStage(BaseStage):
) -> StageResult: ) -> StageResult:
logger.info("combining source images using grid layout") logger.info("combining source images using grid layout")
size = sources[0].size images = sources.as_image()
size = images[0].size
output = Image.new("RGB", (size[0] * width, size[1] * height)) output = Image.new("RGB", (size[0] * width, size[1] * height))
# TODO: labels # TODO: labels
if order is None: if order is None:
order = range(len(sources)) order = range(len(images))
for i in range(len(order)): for i in range(len(order)):
x = i % width x = i % width
y = i // width y = i // width
n = order[i] n = order[i]
output.paste(sources[n], (x * size[0], y * size[1])) output.paste(images[n], (x * size[0], y * size[1]))
return StageResult(images=[*sources, output]) return StageResult(images=[*images, output])
def outputs( def outputs(
self, self,

View File

@ -29,5 +29,5 @@ class BlendLinearStage(BaseStage):
logger.info("blending source images using linear interpolation") logger.info("blending source images using linear interpolation")
return StageResult( return StageResult(
images=[Image.blend(source, stage_source, alpha) for source in sources] images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()]
) )

View File

@ -40,6 +40,6 @@ class BlendMaskStage(BaseStage):
return StageResult( return StageResult(
images=[ images=[
Image.composite(stage_source, source, mult_mask) for source in sources Image.composite(stage_source, source, mult_mask) for source in sources.as_image()
] ]
) )

View File

@ -33,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
device = worker.get_device() device = worker.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return StageResult(images=[pipe(source) for source in sources]) return StageResult(images=[pipe(source) for source in sources.as_image()])

View File

@ -30,7 +30,7 @@ class PersistDiskStage(BaseStage):
**kwargs, **kwargs,
) -> StageResult: ) -> StageResult:
logger.info( logger.info(
"persisting images to disk: %s, %s", [s.size for s in sources], output "persisting %s images to disk: %s", len(sources), output
) )
for source, name in zip(sources, output): for source, name in zip(sources, output):

View File

@ -12,8 +12,8 @@ from ..server import ServerContext
from ..utils import is_debug, run_gc from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
from .result import StageResult
from .tile import needs_tile, process_tile_order from .tile import needs_tile, process_tile_order
from .result import StageResult
logger = getLogger(__name__) logger = getLogger(__name__)
@ -77,7 +77,7 @@ class ChainPipeline:
sources: StageResult, sources: StageResult,
callback: Optional[ProgressCallback], callback: Optional[ProgressCallback],
**kwargs, **kwargs,
) -> StageResult: ) -> List[Image.Image]:
result = self( result = self(
worker, server, params, sources=sources, callback=callback, **kwargs worker, server, params, sources=sources, callback=callback, **kwargs
) )
@ -136,7 +136,7 @@ class ChainPipeline:
logger.debug( logger.debug(
"running stage %s with %s source images, parameters: %s", "running stage %s with %s source images, parameters: %s",
name, name,
len(stage_sources) - stage_sources.count(None), len(stage_sources),
kwargs.keys(), kwargs.keys(),
) )
@ -154,7 +154,7 @@ class ChainPipeline:
size=kwargs.get("size", None), size=kwargs.get("size", None),
source=source, source=source,
) )
for source in stage_sources for source in stage_sources.as_image()
] ]
) )
@ -162,9 +162,10 @@ class ChainPipeline:
if stage_pipe.max_tile > 0: if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
# TODO: stage_sources will always be defined here
if stage_sources or must_tile: if stage_sources or must_tile:
stage_results = [] stage_results = []
for source in stage_sources: for source in stage_sources.as_image():
logger.info( logger.info(
"image contains sources or is larger than tile size of %s, tiling stage", "image contains sources or is larger than tile size of %s, tiling stage",
tile, tile,
@ -182,7 +183,7 @@ class ChainPipeline:
server, server,
stage_params, stage_params,
per_stage_params, per_stage_params,
[source_tile], StageResult(images=[source_tile]),
tile_mask=tile_mask, tile_mask=tile_mask,
callback=callback, callback=callback,
dims=dims, dims=dims,
@ -193,7 +194,8 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()): for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image) save_image(server, f"last-tile-{j}.png", image)
return tile_result # TODO: return whole result
return tile_result.as_image()[0]
except Exception: except Exception:
worker.retries = worker.retries - 1 worker.retries = worker.retries - 1
logger.exception( logger.exception(
@ -257,7 +259,8 @@ class ChainPipeline:
) )
if is_debug(): if is_debug():
save_image(server, "last-stage.png", stage_sources[0]) for j, image in enumerate(stage_sources.as_image()):
save_image(server, f"last-stage-{j}.png", image)
end = monotonic() end = monotonic()
duration = timedelta(seconds=(end - start)) duration = timedelta(seconds=(end - start))

View File

@ -15,6 +15,10 @@ class StageResult:
arrays: Optional[List[np.ndarray]] arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]] images: Optional[List[Image.Image]]
@staticmethod
def empty():
return StageResult(images=[])
def __init__(self, arrays=None, images=None) -> None: def __init__(self, arrays=None, images=None) -> None:
if arrays is not None and images is not None: if arrays is not None and images is not None:
raise ValueError("stages must only return one type of result") raise ValueError("stages must only return one type of result")

View File

@ -156,7 +156,7 @@ class SourceTxt2ImgStage(BaseStage):
callback=callback, callback=callback,
) )
outputs = list(sources) outputs = sources.as_image()
outputs.extend(result.images) outputs.extend(result.images)
logger.debug("produced %s outputs", len(outputs)) logger.debug("produced %s outputs", len(outputs))
return StageResult(images=outputs) return StageResult(images=outputs)

View File

@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
"source images were passed to a source stage, new images will be appended" "source images were passed to a source stage, new images will be appended"
) )
outputs = list(sources) outputs = sources.as_image()
for url in source_urls: for url in source_urls:
response = requests.get(url) response = requests.get(url)
output = Image.open(BytesIO(response.content)) output = Image.open(BytesIO(response.content))

View File

@ -79,28 +79,23 @@ class UpscaleBSRGANStage(BaseStage):
logger.trace("BSRGAN input shape: %s", image.shape) logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale scale = upscale.outscale
dest = np.zeros( logger.trace("BSRGAN output shape: %s", (
(
image.shape[0], image.shape[0],
image.shape[1], image.shape[1],
image.shape[2] * scale, image.shape[2] * scale,
image.shape[3] * scale, image.shape[3] * scale,
) ))
)
logger.trace("BSRGAN output shape: %s", dest.shape)
dest = bsrgan(image) output = bsrgan(image)
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) output = np.clip(np.squeeze(output, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) output = (output * 255.0).round().astype(np.uint8)
output = Image.fromarray(dest, "RGB")
logger.debug("output image size: %s x %s", output.width, output.height)
logger.debug("output image shape: %s", output.shape)
outputs.append(output) outputs.append(output)
return StageResult(images=outputs) return StageResult(arrays=outputs)
def steps( def steps(
self, self,

View File

@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
source, source,
callback=callback, callback=callback,
) )
for source in sources for source in sources.as_image()
] ]
return StageResult(images=outputs) return StageResult(images=outputs)

View File

@ -38,13 +38,11 @@ class UpscaleSimpleStage(BaseStage):
if method == "bilinear": if method == "bilinear":
logger.debug("using bilinear interpolation for highres") logger.debug("using bilinear interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR))
elif method == "lanczos": elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres") logger.debug("using Lanczos interpolation for highres")
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS))
else: else:
logger.warning("unknown upscaling method: %s", method) logger.warning("unknown upscaling method: %s", method)
outputs.append(source) return StageResult(images=outputs)
return outputs

View File

@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
) )
outputs.extend(result.images) outputs.extend(result.images)
return outputs return StageResult(images=outputs)

View File

@ -105,7 +105,7 @@ def run_txt2img_pipeline(
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain.run(worker, server, params, [], callback=progress, latents=latents) images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents)
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
@ -200,7 +200,7 @@ def run_img2img_pipeline(
# run and append the filtered source # run and append the filtered source
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress) images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
if source_filter is not None and source_filter != "none": if source_filter is not None and source_filter != "none":
images.append(source) images.append(source)
@ -380,7 +380,7 @@ def run_inpaint_pipeline(
# run and save # run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch) latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress, latents=latents) images = chain.run(worker, server, params, [source], callback=progress, latents=latents)
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
@ -455,7 +455,7 @@ def run_upscale_pipeline(
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, [source], callback=progress) images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
_pairs, loras, inversions, _rest = parse_prompt(params) _pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
@ -487,7 +487,7 @@ def run_blend_pipeline(
outputs: List[str], outputs: List[str],
upscale: UpscaleParams, upscale: UpscaleParams,
# highres: HighresParams, # highres: HighresParams,
sources: StageResult, sources: List[Image.Image],
mask: Image.Image, mask: Image.Image,
) -> None: ) -> None:
# set up the chain pipeline and base stage # set up the chain pipeline and base stage
@ -505,7 +505,7 @@ def run_blend_pipeline(
# run and save # run and save
progress = worker.get_progress_callback() progress = worker.get_progress_callback()
images = chain(worker, server, params, sources, callback=progress) images = chain.run(worker, server, params, StageResult(images=sources), callback=progress)
for image, output in zip(images, outputs): for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale) dest = save_image(server, output, image, params, size, upscale=upscale)

View File

@ -3,19 +3,19 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_grid import BlendGridStage from onnx_web.chain.blend_grid import BlendGridStage
from onnx_web.chain.blend_linear import BlendLinearStage from onnx_web.chain.result import StageResult
class BlendGridStageTests(unittest.TestCase): class BlendGridStageTests(unittest.TestCase):
def test_stage(self): def test_stage(self):
stage = BlendGridStage() stage = BlendGridStage()
sources = [ sources = StageResult(images=[
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "white"),
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
Image.new("RGB", (64, 64), "white"), Image.new("RGB", (64, 64), "white"),
] ])
result = stage.run(None, None, None, None, sources, height=2, width=2) result = stage.run(None, None, None, None, sources, height=2, width=2)
self.assertEqual(len(result), 5) self.assertEqual(len(result), 5)
self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0)) self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))

View File

@ -3,16 +3,17 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_linear import BlendLinearStage from onnx_web.chain.blend_linear import BlendLinearStage
from onnx_web.chain.result import StageResult
class BlendLinearStageTests(unittest.TestCase): class BlendLinearStageTests(unittest.TestCase):
def test_stage(self): def test_stage(self):
stage = BlendLinearStage() stage = BlendLinearStage()
sources = [ sources = StageResult(images=[
Image.new("RGB", (64, 64), "black"), Image.new("RGB", (64, 64), "black"),
] ])
stage_source = Image.new("RGB", (64, 64), "white") stage_source = Image.new("RGB", (64, 64), "white")
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source) result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
self.assertEqual(len(result), 1) self.assertEqual(len(result), 1)
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127)) self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))

View File

@ -3,13 +3,14 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.blend_mask import BlendMaskStage from onnx_web.chain.blend_mask import BlendMaskStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, UpscaleParams from onnx_web.params import HighresParams, UpscaleParams
class BlendMaskStageTests(unittest.TestCase): class BlendMaskStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = BlendMaskStage() stage = BlendMaskStage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -25,7 +25,7 @@ class CorrectCodeformerStageTests(unittest.TestCase):
0, 0,
) )
stage = CorrectCodeformerStage() stage = CorrectCodeformerStage()
sources = [] sources = StageResult.empty()
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)

View File

@ -1,13 +1,14 @@
import unittest import unittest
from onnx_web.chain.reduce_crop import ReduceCropStage from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams from onnx_web.params import HighresParams, Size, UpscaleParams
class ReduceCropStageTests(unittest.TestCase): class ReduceCropStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = ReduceCropStage() stage = ReduceCropStage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -2,8 +2,8 @@ import unittest
from PIL import Image from PIL import Image
from onnx_web.chain.reduce_crop import ReduceCropStage
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
from onnx_web.chain.result import StageResult
from onnx_web.params import HighresParams, Size, UpscaleParams from onnx_web.params import HighresParams, Size, UpscaleParams
@ -11,7 +11,7 @@ class ReduceThumbnailStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage_source = Image.new("RGB", (64, 64)) stage_source = Image.new("RGB", (64, 64))
stage = ReduceThumbnailStage() stage = ReduceThumbnailStage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -1,4 +1,5 @@
import unittest import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_noise import SourceNoiseStage from onnx_web.chain.source_noise import SourceNoiseStage
from onnx_web.image.noise_source import noise_source_fill_edge from onnx_web.image.noise_source import noise_source_fill_edge
@ -8,7 +9,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceNoiseStageTests(unittest.TestCase): class SourceNoiseStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = SourceNoiseStage() stage = SourceNoiseStage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -1,4 +1,5 @@
import unittest import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_s3 import SourceS3Stage from onnx_web.chain.source_s3 import SourceS3Stage
from onnx_web.params import HighresParams, Size, UpscaleParams from onnx_web.params import HighresParams, Size, UpscaleParams
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceS3StageTests(unittest.TestCase): class SourceS3StageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = SourceS3Stage() stage = SourceS3Stage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -1,4 +1,5 @@
import unittest import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_url import SourceURLStage from onnx_web.chain.source_url import SourceURLStage
from onnx_web.params import HighresParams, Size, UpscaleParams from onnx_web.params import HighresParams, Size, UpscaleParams
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
class SourceURLStageTests(unittest.TestCase): class SourceURLStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = SourceURLStage() stage = SourceURLStage()
sources = [] sources = StageResult.empty()
result = stage.run( result = stage.run(
None, None,
None, None,

View File

@ -1,4 +1,5 @@
import unittest import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_highres import UpscaleHighresStage from onnx_web.chain.upscale_highres import UpscaleHighresStage
from onnx_web.params import HighresParams, UpscaleParams from onnx_web.params import HighresParams, UpscaleParams
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, UpscaleParams
class UpscaleHighresStageTests(unittest.TestCase): class UpscaleHighresStageTests(unittest.TestCase):
def test_empty(self): def test_empty(self):
stage = UpscaleHighresStage() stage = UpscaleHighresStage()
sources = [] sources = StageResult.empty()
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams("")) result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
self.assertEqual(len(result), 0) self.assertEqual(len(result), 0)