move stages and tests to using stage result
This commit is contained in:
parent
7e6749e0d7
commit
eb77c83d80
|
@ -35,22 +35,23 @@ class BlendGridStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("combining source images using grid layout")
|
logger.info("combining source images using grid layout")
|
||||||
|
|
||||||
size = sources[0].size
|
images = sources.as_image()
|
||||||
|
size = images[0].size
|
||||||
|
|
||||||
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
||||||
|
|
||||||
# TODO: labels
|
# TODO: labels
|
||||||
if order is None:
|
if order is None:
|
||||||
order = range(len(sources))
|
order = range(len(images))
|
||||||
|
|
||||||
for i in range(len(order)):
|
for i in range(len(order)):
|
||||||
x = i % width
|
x = i % width
|
||||||
y = i // width
|
y = i // width
|
||||||
|
|
||||||
n = order[i]
|
n = order[i]
|
||||||
output.paste(sources[n], (x * size[0], y * size[1]))
|
output.paste(images[n], (x * size[0], y * size[1]))
|
||||||
|
|
||||||
return StageResult(images=[*sources, output])
|
return StageResult(images=[*images, output])
|
||||||
|
|
||||||
def outputs(
|
def outputs(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -29,5 +29,5 @@ class BlendLinearStage(BaseStage):
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return StageResult(
|
return StageResult(
|
||||||
images=[Image.blend(source, stage_source, alpha) for source in sources]
|
images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()]
|
||||||
)
|
)
|
||||||
|
|
|
@ -40,6 +40,6 @@ class BlendMaskStage(BaseStage):
|
||||||
|
|
||||||
return StageResult(
|
return StageResult(
|
||||||
images=[
|
images=[
|
||||||
Image.composite(stage_source, source, mult_mask) for source in sources
|
Image.composite(stage_source, source, mult_mask) for source in sources.as_image()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -33,4 +33,4 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
|
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
||||||
return StageResult(images=[pipe(source) for source in sources])
|
return StageResult(images=[pipe(source) for source in sources.as_image()])
|
||||||
|
|
|
@ -30,7 +30,7 @@ class PersistDiskStage(BaseStage):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info(
|
logger.info(
|
||||||
"persisting images to disk: %s, %s", [s.size for s in sources], output
|
"persisting %s images to disk: %s", len(sources), output
|
||||||
)
|
)
|
||||||
|
|
||||||
for source, name in zip(sources, output):
|
for source, name in zip(sources, output):
|
||||||
|
|
|
@ -12,8 +12,8 @@ from ..server import ServerContext
|
||||||
from ..utils import is_debug, run_gc
|
from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .result import StageResult
|
|
||||||
from .tile import needs_tile, process_tile_order
|
from .tile import needs_tile, process_tile_order
|
||||||
|
from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ class ChainPipeline:
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
callback: Optional[ProgressCallback],
|
callback: Optional[ProgressCallback],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> List[Image.Image]:
|
||||||
result = self(
|
result = self(
|
||||||
worker, server, params, sources=sources, callback=callback, **kwargs
|
worker, server, params, sources=sources, callback=callback, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -136,7 +136,7 @@ class ChainPipeline:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"running stage %s with %s source images, parameters: %s",
|
"running stage %s with %s source images, parameters: %s",
|
||||||
name,
|
name,
|
||||||
len(stage_sources) - stage_sources.count(None),
|
len(stage_sources),
|
||||||
kwargs.keys(),
|
kwargs.keys(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ class ChainPipeline:
|
||||||
size=kwargs.get("size", None),
|
size=kwargs.get("size", None),
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
for source in stage_sources
|
for source in stage_sources.as_image()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -162,9 +162,10 @@ class ChainPipeline:
|
||||||
if stage_pipe.max_tile > 0:
|
if stage_pipe.max_tile > 0:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
|
# TODO: stage_sources will always be defined here
|
||||||
if stage_sources or must_tile:
|
if stage_sources or must_tile:
|
||||||
stage_results = []
|
stage_results = []
|
||||||
for source in stage_sources:
|
for source in stage_sources.as_image():
|
||||||
logger.info(
|
logger.info(
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
tile,
|
tile,
|
||||||
|
@ -182,7 +183,7 @@ class ChainPipeline:
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
per_stage_params,
|
per_stage_params,
|
||||||
[source_tile],
|
StageResult(images=[source_tile]),
|
||||||
tile_mask=tile_mask,
|
tile_mask=tile_mask,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
|
@ -193,7 +194,8 @@ class ChainPipeline:
|
||||||
for j, image in enumerate(tile_result.as_image()):
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
save_image(server, f"last-tile-{j}.png", image)
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
return tile_result
|
# TODO: return whole result
|
||||||
|
return tile_result.as_image()[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
worker.retries = worker.retries - 1
|
worker.retries = worker.retries - 1
|
||||||
logger.exception(
|
logger.exception(
|
||||||
|
@ -257,7 +259,8 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "last-stage.png", stage_sources[0])
|
for j, image in enumerate(stage_sources.as_image()):
|
||||||
|
save_image(server, f"last-stage-{j}.png", image)
|
||||||
|
|
||||||
end = monotonic()
|
end = monotonic()
|
||||||
duration = timedelta(seconds=(end - start))
|
duration = timedelta(seconds=(end - start))
|
||||||
|
|
|
@ -15,6 +15,10 @@ class StageResult:
|
||||||
arrays: Optional[List[np.ndarray]]
|
arrays: Optional[List[np.ndarray]]
|
||||||
images: Optional[List[Image.Image]]
|
images: Optional[List[Image.Image]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def empty():
|
||||||
|
return StageResult(images=[])
|
||||||
|
|
||||||
def __init__(self, arrays=None, images=None) -> None:
|
def __init__(self, arrays=None, images=None) -> None:
|
||||||
if arrays is not None and images is not None:
|
if arrays is not None and images is not None:
|
||||||
raise ValueError("stages must only return one type of result")
|
raise ValueError("stages must only return one type of result")
|
||||||
|
|
|
@ -156,7 +156,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = list(sources)
|
outputs = sources.as_image()
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
logger.debug("produced %s outputs", len(outputs))
|
logger.debug("produced %s outputs", len(outputs))
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class SourceURLStage(BaseStage):
|
||||||
"source images were passed to a source stage, new images will be appended"
|
"source images were passed to a source stage, new images will be appended"
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = list(sources)
|
outputs = sources.as_image()
|
||||||
for url in source_urls:
|
for url in source_urls:
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
output = Image.open(BytesIO(response.content))
|
output = Image.open(BytesIO(response.content))
|
||||||
|
|
|
@ -79,28 +79,23 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
dest = np.zeros(
|
logger.trace("BSRGAN output shape: %s", (
|
||||||
(
|
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
)
|
))
|
||||||
)
|
|
||||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
|
||||||
|
|
||||||
dest = bsrgan(image)
|
output = bsrgan(image)
|
||||||
|
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
output = np.clip(np.squeeze(output, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
output = output[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
output = (output * 255.0).round().astype(np.uint8)
|
||||||
|
|
||||||
output = Image.fromarray(dest, "RGB")
|
|
||||||
logger.debug("output image size: %s x %s", output.width, output.height)
|
|
||||||
|
|
||||||
|
logger.debug("output image shape: %s", output.shape)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(arrays=outputs)
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -42,7 +42,7 @@ class UpscaleHighresStage(BaseStage):
|
||||||
source,
|
source,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
)
|
)
|
||||||
for source in sources
|
for source in sources.as_image()
|
||||||
]
|
]
|
||||||
|
|
||||||
return StageResult(images=outputs)
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -38,13 +38,11 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
|
|
||||||
if method == "bilinear":
|
if method == "bilinear":
|
||||||
logger.debug("using bilinear interpolation for highres")
|
logger.debug("using bilinear interpolation for highres")
|
||||||
source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR))
|
||||||
elif method == "lanczos":
|
elif method == "lanczos":
|
||||||
logger.debug("using Lanczos interpolation for highres")
|
logger.debug("using Lanczos interpolation for highres")
|
||||||
source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS))
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown upscaling method: %s", method)
|
logger.warning("unknown upscaling method: %s", method)
|
||||||
|
|
||||||
outputs.append(source)
|
return StageResult(images=outputs)
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
|
@ -73,4 +73,4 @@ class UpscaleStableDiffusionStage(BaseStage):
|
||||||
)
|
)
|
||||||
outputs.extend(result.images)
|
outputs.extend(result.images)
|
||||||
|
|
||||||
return outputs
|
return StageResult(images=outputs)
|
||||||
|
|
|
@ -105,7 +105,7 @@ def run_txt2img_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, [], callback=progress, latents=latents)
|
images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
|
@ -200,7 +200,7 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress)
|
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.append(source)
|
images.append(source)
|
||||||
|
@ -380,7 +380,7 @@ def run_inpaint_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress, latents=latents)
|
images = chain.run(worker, server, params, [source], callback=progress, latents=latents)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -455,7 +455,7 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, [source], callback=progress)
|
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -487,7 +487,7 @@ def run_blend_pipeline(
|
||||||
outputs: List[str],
|
outputs: List[str],
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
# highres: HighresParams,
|
# highres: HighresParams,
|
||||||
sources: StageResult,
|
sources: List[Image.Image],
|
||||||
mask: Image.Image,
|
mask: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
# set up the chain pipeline and base stage
|
# set up the chain pipeline and base stage
|
||||||
|
@ -505,7 +505,7 @@ def run_blend_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain(worker, server, params, sources, callback=progress)
|
images = chain.run(worker, server, params, StageResult(images=sources), callback=progress)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||||
|
|
|
@ -3,19 +3,19 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_grid import BlendGridStage
|
from onnx_web.chain.blend_grid import BlendGridStage
|
||||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
|
|
||||||
class BlendGridStageTests(unittest.TestCase):
|
class BlendGridStageTests(unittest.TestCase):
|
||||||
def test_stage(self):
|
def test_stage(self):
|
||||||
stage = BlendGridStage()
|
stage = BlendGridStage()
|
||||||
sources = [
|
sources = StageResult(images=[
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "white"),
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
Image.new("RGB", (64, 64), "white"),
|
Image.new("RGB", (64, 64), "white"),
|
||||||
]
|
])
|
||||||
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
result = stage.run(None, None, None, None, sources, height=2, width=2)
|
||||||
|
|
||||||
self.assertEqual(len(result), 5)
|
self.assertEqual(len(result), 5)
|
||||||
self.assertEqual(result[-1].getpixel((0,0)), (0, 0, 0))
|
self.assertEqual(result.as_image()[-1].getpixel((0,0)), (0, 0, 0))
|
|
@ -3,16 +3,17 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_linear import BlendLinearStage
|
from onnx_web.chain.blend_linear import BlendLinearStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
|
|
||||||
class BlendLinearStageTests(unittest.TestCase):
|
class BlendLinearStageTests(unittest.TestCase):
|
||||||
def test_stage(self):
|
def test_stage(self):
|
||||||
stage = BlendLinearStage()
|
stage = BlendLinearStage()
|
||||||
sources = [
|
sources = StageResult(images=[
|
||||||
Image.new("RGB", (64, 64), "black"),
|
Image.new("RGB", (64, 64), "black"),
|
||||||
]
|
])
|
||||||
stage_source = Image.new("RGB", (64, 64), "white")
|
stage_source = Image.new("RGB", (64, 64), "white")
|
||||||
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
result = stage.run(None, None, None, None, sources, alpha=0.5, stage_source=stage_source)
|
||||||
|
|
||||||
self.assertEqual(len(result), 1)
|
self.assertEqual(len(result), 1)
|
||||||
self.assertEqual(result[0].getpixel((0,0)), (127, 127, 127))
|
self.assertEqual(result.as_image()[0].getpixel((0,0)), (127, 127, 127))
|
|
@ -3,13 +3,14 @@ import unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.blend_mask import BlendMaskStage
|
from onnx_web.chain.blend_mask import BlendMaskStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.params import HighresParams, UpscaleParams
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
class BlendMaskStageTests(unittest.TestCase):
|
class BlendMaskStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = BlendMaskStage()
|
stage = BlendMaskStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -25,7 +25,7 @@ class CorrectCodeformerStageTests(unittest.TestCase):
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
stage = CorrectCodeformerStage()
|
stage = CorrectCodeformerStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
result = stage.run(worker, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from onnx_web.chain.reduce_crop import ReduceCropStage
|
from onnx_web.chain.reduce_crop import ReduceCropStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
class ReduceCropStageTests(unittest.TestCase):
|
class ReduceCropStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = ReduceCropStage()
|
stage = ReduceCropStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -2,8 +2,8 @@ import unittest
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.reduce_crop import ReduceCropStage
|
|
||||||
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
|
from onnx_web.chain.reduce_thumbnail import ReduceThumbnailStage
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ class ReduceThumbnailStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage_source = Image.new("RGB", (64, 64))
|
stage_source = Image.new("RGB", (64, 64))
|
||||||
stage = ReduceThumbnailStage()
|
stage = ReduceThumbnailStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
from onnx_web.chain.source_noise import SourceNoiseStage
|
from onnx_web.chain.source_noise import SourceNoiseStage
|
||||||
from onnx_web.image.noise_source import noise_source_fill_edge
|
from onnx_web.image.noise_source import noise_source_fill_edge
|
||||||
|
@ -8,7 +9,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
class SourceNoiseStageTests(unittest.TestCase):
|
class SourceNoiseStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = SourceNoiseStage()
|
stage = SourceNoiseStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
from onnx_web.chain.source_s3 import SourceS3Stage
|
from onnx_web.chain.source_s3 import SourceS3Stage
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
class SourceS3StageTests(unittest.TestCase):
|
class SourceS3StageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = SourceS3Stage()
|
stage = SourceS3Stage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
from onnx_web.chain.source_url import SourceURLStage
|
from onnx_web.chain.source_url import SourceURLStage
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
class SourceURLStageTests(unittest.TestCase):
|
class SourceURLStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = SourceURLStage()
|
stage = SourceURLStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(
|
result = stage.run(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
|
|
||||||
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
||||||
from onnx_web.params import HighresParams, UpscaleParams
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
@ -7,7 +8,7 @@ from onnx_web.params import HighresParams, UpscaleParams
|
||||||
class UpscaleHighresStageTests(unittest.TestCase):
|
class UpscaleHighresStageTests(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
stage = UpscaleHighresStage()
|
stage = UpscaleHighresStage()
|
||||||
sources = []
|
sources = StageResult.empty()
|
||||||
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
result = stage.run(None, None, None, None, sources, highres=HighresParams(False,1, 0, 0), upscale=UpscaleParams(""))
|
||||||
|
|
||||||
self.assertEqual(len(result), 0)
|
self.assertEqual(len(result), 0)
|
||||||
|
|
Loading…
Reference in New Issue