rewrite tile handling for image stacks
This commit is contained in:
parent
eb77c83d80
commit
98fcc07524
|
@ -29,5 +29,8 @@ class BlendLinearStage(BaseStage):
|
|||
logger.info("blending source images using linear interpolation")
|
||||
|
||||
return StageResult(
|
||||
images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()]
|
||||
images=[
|
||||
Image.blend(source, stage_source, alpha)
|
||||
for source in sources.as_image()
|
||||
]
|
||||
)
|
||||
|
|
|
@ -40,6 +40,7 @@ class BlendMaskStage(BaseStage):
|
|||
|
||||
return StageResult(
|
||||
images=[
|
||||
Image.composite(stage_source, source, mult_mask) for source in sources.as_image()
|
||||
Image.composite(stage_source, source, mult_mask)
|
||||
for source in sources.as_image()
|
||||
]
|
||||
)
|
||||
|
|
|
@ -29,9 +29,7 @@ class PersistDiskStage(BaseStage):
|
|||
stage_source: Optional[Image.Image] = None,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
logger.info(
|
||||
"persisting %s images to disk: %s", len(sources), output
|
||||
)
|
||||
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||
|
||||
for source, name in zip(sources, output):
|
||||
dest = save_image(server, name, source, params=params, size=size)
|
||||
|
|
|
@ -12,8 +12,8 @@ from ..server import ServerContext
|
|||
from ..utils import is_debug, run_gc
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
from .tile import needs_tile, process_tile_order
|
||||
from .result import StageResult
|
||||
from .tile import needs_tile, process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -163,19 +163,17 @@ class ChainPipeline:
|
|||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||
|
||||
# TODO: stage_sources will always be defined here
|
||||
if stage_sources or must_tile:
|
||||
stage_results = []
|
||||
for source in stage_sources.as_image():
|
||||
if must_tile:
|
||||
logger.info(
|
||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||
tile,
|
||||
)
|
||||
|
||||
def stage_tile(
|
||||
source_tile: Image.Image,
|
||||
source_tile: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
dims: Tuple[int, int, int],
|
||||
) -> StageResult:
|
||||
) -> List[Image.Image]:
|
||||
for _i in range(worker.retries):
|
||||
try:
|
||||
tile_result = stage_pipe.run(
|
||||
|
@ -194,8 +192,7 @@ class ChainPipeline:
|
|||
for j, image in enumerate(tile_result.as_image()):
|
||||
save_image(server, f"last-tile-{j}.png", image)
|
||||
|
||||
# TODO: return whole result
|
||||
return tile_result.as_image()[0]
|
||||
return tile_result.as_image()
|
||||
except Exception:
|
||||
worker.retries = worker.retries - 1
|
||||
logger.exception(
|
||||
|
@ -207,17 +204,15 @@ class ChainPipeline:
|
|||
|
||||
raise RetryException("exhausted retries on tile")
|
||||
|
||||
output = process_tile_order(
|
||||
stage_results = process_tile_order(
|
||||
stage_params.tile_order,
|
||||
source,
|
||||
stage_sources,
|
||||
tile,
|
||||
stage_params.outscale,
|
||||
[stage_tile],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
stage_results.append(output)
|
||||
|
||||
stage_sources = StageResult(images=stage_results)
|
||||
else:
|
||||
logger.debug(
|
||||
|
|
|
@ -2,7 +2,7 @@ import itertools
|
|||
from enum import Enum
|
||||
from logging import getLogger
|
||||
from math import ceil
|
||||
from typing import Any, List, Optional, Protocol, Tuple
|
||||
from typing import Any, Callable, List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -17,12 +17,17 @@ from .result import StageResult
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
|
||||
|
||||
|
||||
class TileCallback(Protocol):
|
||||
"""
|
||||
Definition for a tile job function.
|
||||
"""
|
||||
|
||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
||||
def __call__(
|
||||
self, image: Image.Image, dims: Tuple[int, int, int]
|
||||
) -> List[Image.Image]:
|
||||
"""
|
||||
Run this stage against a single tile.
|
||||
"""
|
||||
|
@ -33,6 +38,9 @@ def complete_tile(
|
|||
source: Image.Image,
|
||||
tile: int,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
TODO: clean up
|
||||
"""
|
||||
if source is None:
|
||||
return source
|
||||
|
||||
|
@ -67,7 +75,7 @@ def needs_tile(
|
|||
return False
|
||||
|
||||
|
||||
def get_tile_grads(
|
||||
def make_tile_grads(
|
||||
left: int,
|
||||
top: int,
|
||||
tile: int,
|
||||
|
@ -161,7 +169,7 @@ def blend_tiles(
|
|||
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
|
||||
|
||||
# gradient blending
|
||||
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
|
||||
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
|
||||
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
||||
|
||||
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
||||
|
@ -225,60 +233,18 @@ def blend_tiles(
|
|||
return Image.fromarray(np.uint8(pixels))
|
||||
|
||||
|
||||
def process_tile_grid(
|
||||
source: Image.Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
overlap: float = 0.0,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
width, height = kwargs.get("size", source.size if source else None)
|
||||
|
||||
adj_tile = int(float(tile) * (1.0 - overlap))
|
||||
tiles_x = ceil(width / adj_tile)
|
||||
tiles_y = ceil(height / adj_tile)
|
||||
total = tiles_x * tiles_y
|
||||
logger.debug(
|
||||
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
||||
total,
|
||||
tiles_x,
|
||||
tiles_y,
|
||||
adj_tile,
|
||||
overlap,
|
||||
)
|
||||
|
||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||
|
||||
for y in range(tiles_y):
|
||||
for x in range(tiles_x):
|
||||
idx = (y * tiles_x) + x
|
||||
left = x * adj_tile
|
||||
top = y * adj_tile
|
||||
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
|
||||
|
||||
tile_image = (
|
||||
source.crop((left, top, left + tile, top + tile)) if source else None
|
||||
)
|
||||
tile_image = complete_tile(tile_image, tile)
|
||||
|
||||
for filter in filters:
|
||||
tile_image = filter(tile_image, (left, top, tile))
|
||||
|
||||
tiles.append((left, top, tile_image))
|
||||
|
||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
||||
|
||||
|
||||
def process_tile_spiral(
|
||||
source: Image.Image,
|
||||
def process_tile_stack(
|
||||
stack: StageResult,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
tile_generator: TileGenerator,
|
||||
overlap: float = 0.5,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
width, height = kwargs.get("size", source.size if source else None)
|
||||
) -> List[Image.Image]:
|
||||
sources = stack.as_image()
|
||||
|
||||
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||
mask = kwargs.get("mask", None)
|
||||
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
||||
fill_color = kwargs.get("fill_color", None)
|
||||
|
@ -286,18 +252,9 @@ def process_tile_spiral(
|
|||
tile_mask = None
|
||||
|
||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||
tile_coords = tile_generator(width, height, tile, overlap)
|
||||
|
||||
# tile tuples is source, multiply by scale for dest
|
||||
counter = 0
|
||||
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
|
||||
|
||||
if len(tile_coords) == 1:
|
||||
single_tile = True
|
||||
else:
|
||||
single_tile = False
|
||||
|
||||
for left, top in tile_coords:
|
||||
counter += 1
|
||||
for counter, (left, top) in enumerate(tile_coords):
|
||||
logger.info(
|
||||
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
|
||||
)
|
||||
|
@ -321,37 +278,28 @@ def process_tile_spiral(
|
|||
needs_margin = True
|
||||
bottom_margin = height - bottom
|
||||
|
||||
# if no source given, we don't have a source image
|
||||
if not source:
|
||||
tile_image = None
|
||||
elif needs_margin:
|
||||
# in the special case where the image is smaller than the specified tile size, just use the image
|
||||
if single_tile:
|
||||
logger.debug("creating and processing single-tile subtile")
|
||||
tile_image = source
|
||||
if mask:
|
||||
tile_mask = mask
|
||||
# otherwise use add histogram noise outside of the image border
|
||||
else:
|
||||
if needs_margin:
|
||||
logger.debug(
|
||||
"tiling and adding margins: %s, %s, %s, %s",
|
||||
"tiling with added margins: %s, %s, %s, %s",
|
||||
left_margin,
|
||||
top_margin,
|
||||
right_margin,
|
||||
bottom_margin,
|
||||
)
|
||||
base_image = source.crop(
|
||||
(
|
||||
left + left_margin,
|
||||
top + top_margin,
|
||||
right + right_margin,
|
||||
bottom + bottom_margin,
|
||||
tile_stack = add_margin(
|
||||
stack,
|
||||
left,
|
||||
top,
|
||||
right,
|
||||
bottom,
|
||||
left_margin,
|
||||
top_margin,
|
||||
right_margin,
|
||||
bottom_margin,
|
||||
tile,
|
||||
noise_source,
|
||||
fill_color,
|
||||
)
|
||||
)
|
||||
tile_image = noise_source(
|
||||
base_image, (tile, tile), (0, 0), fill=fill_color
|
||||
)
|
||||
tile_image.paste(base_image, (left_margin, top_margin))
|
||||
|
||||
if mask:
|
||||
base_mask = mask.crop(
|
||||
|
@ -367,24 +315,30 @@ def process_tile_spiral(
|
|||
|
||||
else:
|
||||
logger.debug("tiling normally")
|
||||
tile_image = source.crop((left, top, right, bottom))
|
||||
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
|
||||
if mask:
|
||||
tile_mask = mask.crop((left, top, right, bottom))
|
||||
|
||||
for image_filter in filters:
|
||||
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
|
||||
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
||||
|
||||
tiles.append((left, top, tile_image))
|
||||
tiles.append((left, top, tile_stack))
|
||||
|
||||
if single_tile:
|
||||
return tile_image
|
||||
else:
|
||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
||||
lefts, tops, stacks = list(zip(*tiles))
|
||||
coords = list(zip(lefts, tops))
|
||||
stacks = list(zip(*stacks))
|
||||
|
||||
result = []
|
||||
for stack in stacks:
|
||||
stack_tiles = zip(coords, stack)
|
||||
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def process_tile_order(
|
||||
order: TileOrder,
|
||||
source: Image.Image,
|
||||
stack: StageResult,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
|
@ -395,13 +349,17 @@ def process_tile_order(
|
|||
"""
|
||||
if order == TileOrder.grid:
|
||||
logger.debug("using grid tile order with tile size: %s", tile)
|
||||
return process_tile_grid(source, tile, scale, filters, **kwargs)
|
||||
return process_tile_stack(
|
||||
stack, tile, scale, filters, generate_tile_grid, **kwargs
|
||||
)
|
||||
elif order == TileOrder.kernel:
|
||||
logger.debug("using kernel tile order with tile size: %s", tile)
|
||||
raise NotImplementedError()
|
||||
elif order == TileOrder.spiral:
|
||||
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
||||
return process_tile_stack(
|
||||
stack, tile, scale, filters, generate_tile_spiral, **kwargs
|
||||
)
|
||||
else:
|
||||
logger.warning("unknown tile order: %s", order)
|
||||
raise ValueError()
|
||||
|
@ -495,3 +453,76 @@ def generate_tile_spiral(
|
|||
height_tile_target -= abs(state.value[1])
|
||||
|
||||
return tile_coords
|
||||
|
||||
|
||||
def generate_tile_grid(
|
||||
width: int,
|
||||
height: int,
|
||||
tile: int,
|
||||
overlap: float = 0.0,
|
||||
) -> List[Tuple[int, int]]:
|
||||
adj_tile = int(float(tile) * (1.0 - overlap))
|
||||
tiles_x = ceil(width / adj_tile)
|
||||
tiles_y = ceil(height / adj_tile)
|
||||
total = tiles_x * tiles_y
|
||||
logger.debug(
|
||||
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
||||
total,
|
||||
tiles_x,
|
||||
tiles_y,
|
||||
adj_tile,
|
||||
overlap,
|
||||
)
|
||||
|
||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||
|
||||
for y in range(tiles_y):
|
||||
for x in range(tiles_x):
|
||||
left = x * adj_tile
|
||||
top = y * adj_tile
|
||||
|
||||
tiles.append((int(left), int(top)))
|
||||
|
||||
return tiles
|
||||
|
||||
|
||||
def get_result_tile(
|
||||
result: StageResult,
|
||||
origin: Tuple[int, int],
|
||||
tile: Size,
|
||||
) -> List[Image.Image]:
|
||||
top, left = origin
|
||||
return [
|
||||
layer.crop((top, left, top + tile.height, left + tile.width))
|
||||
for layer in result.as_image()
|
||||
]
|
||||
|
||||
|
||||
def add_margin(
|
||||
stack: List[Image.Image],
|
||||
left: int,
|
||||
top: int,
|
||||
right: int,
|
||||
bottom: int,
|
||||
left_margin: int,
|
||||
top_margin: int,
|
||||
right_margin: int,
|
||||
bottom_margin: int,
|
||||
tile: int,
|
||||
noise_source,
|
||||
fill_color,
|
||||
) -> List[Image.Image]:
|
||||
results = []
|
||||
for source in stack:
|
||||
base_image = source.crop(
|
||||
(
|
||||
left + left_margin,
|
||||
top + top_margin,
|
||||
right + right_margin,
|
||||
bottom + bottom_margin,
|
||||
)
|
||||
)
|
||||
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
|
||||
tile_image.paste(base_image, (left_margin, top_margin))
|
||||
|
||||
return results
|
||||
|
|
|
@ -79,12 +79,15 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||
|
||||
scale = upscale.outscale
|
||||
logger.trace("BSRGAN output shape: %s", (
|
||||
logger.trace(
|
||||
"BSRGAN output shape: %s",
|
||||
(
|
||||
image.shape[0],
|
||||
image.shape[1],
|
||||
image.shape[2] * scale,
|
||||
image.shape[3] * scale,
|
||||
))
|
||||
),
|
||||
)
|
||||
|
||||
output = bsrgan(image)
|
||||
|
||||
|
|
|
@ -38,10 +38,14 @@ class UpscaleSimpleStage(BaseStage):
|
|||
|
||||
if method == "bilinear":
|
||||
logger.debug("using bilinear interpolation for highres")
|
||||
outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR))
|
||||
outputs.append(
|
||||
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||
)
|
||||
elif method == "lanczos":
|
||||
logger.debug("using Lanczos interpolation for highres")
|
||||
outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS))
|
||||
outputs.append(
|
||||
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||
)
|
||||
else:
|
||||
logger.warning("unknown upscaling method: %s", method)
|
||||
|
||||
|
|
|
@ -105,7 +105,9 @@ def run_txt2img_pipeline(
|
|||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
|
||||
|
@ -200,7 +202,9 @@ def run_img2img_pipeline(
|
|||
|
||||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
images.append(source)
|
||||
|
@ -380,7 +384,9 @@ def run_inpaint_pipeline(
|
|||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, [source], callback=progress, latents=latents)
|
||||
images = chain.run(
|
||||
worker, server, params, [source], callback=progress, latents=latents
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
|
@ -455,7 +461,9 @@ def run_upscale_pipeline(
|
|||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
|
@ -505,7 +513,9 @@ def run_blend_pipeline(
|
|||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(worker, server, params, StageResult(images=sources), callback=progress)
|
||||
images = chain.run(
|
||||
worker, server, params, StageResult(images=sources), callback=progress
|
||||
)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from onnx_web.chain.result import StageResult
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.source_noise import SourceNoiseStage
|
||||
from onnx_web.image.noise_source import noise_source_fill_edge
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from onnx_web.chain.result import StageResult
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.source_s3 import SourceS3Stage
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from onnx_web.chain.result import StageResult
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.source_url import SourceURLStage
|
||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||
|
||||
|
|
|
@ -4,11 +4,11 @@ from PIL import Image
|
|||
|
||||
from onnx_web.chain.tile import (
|
||||
complete_tile,
|
||||
generate_tile_grid,
|
||||
generate_tile_spiral,
|
||||
get_tile_grads,
|
||||
make_tile_grads,
|
||||
needs_tile,
|
||||
process_tile_grid,
|
||||
process_tile_spiral,
|
||||
process_tile_stack,
|
||||
)
|
||||
from onnx_web.params import Size
|
||||
|
||||
|
@ -59,24 +59,46 @@ class TestNeedsTile(unittest.TestCase):
|
|||
|
||||
class TestTileGrads(unittest.TestCase):
|
||||
def test_center_tile(self):
|
||||
grad_x, grad_y = get_tile_grads(32, 32, 8, 64, 64)
|
||||
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
|
||||
|
||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||
|
||||
def test_vertical_edge_tile(self):
|
||||
grad_x, grad_y = get_tile_grads(32, 0, 8, 64, 8)
|
||||
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
|
||||
|
||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||
self.assertEqual(grad_y, [1, 1, 1, 1])
|
||||
|
||||
def test_horizontal_edge_tile(self):
|
||||
grad_x, grad_y = get_tile_grads(0, 32, 8, 8, 64)
|
||||
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
|
||||
|
||||
self.assertEqual(grad_x, [1, 1, 1, 1])
|
||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||
|
||||
|
||||
class TestGenerateTileGrid(unittest.TestCase):
|
||||
def test_grid_complete(self):
|
||||
tiles = generate_tile_grid(16, 16, 8, 0.0)
|
||||
|
||||
self.assertEqual(len(tiles), 4)
|
||||
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
||||
|
||||
def test_grid_no_overlap(self):
|
||||
tiles = generate_tile_grid(64, 64, 8, 0.0)
|
||||
|
||||
self.assertEqual(len(tiles), 64)
|
||||
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
||||
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
||||
|
||||
def test_grid_50_overlap(self):
|
||||
tiles = generate_tile_grid(64, 64, 8, 0.5)
|
||||
|
||||
self.assertEqual(len(tiles), 225)
|
||||
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||
|
||||
|
||||
class TestGenerateTileSpiral(unittest.TestCase):
|
||||
def test_spiral_complete(self):
|
||||
tiles = generate_tile_spiral(16, 16, 8, 0.0)
|
||||
|
@ -99,29 +121,15 @@ class TestGenerateTileSpiral(unittest.TestCase):
|
|||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||
|
||||
|
||||
class TestProcessTileGrid(unittest.TestCase):
|
||||
class TestProcessTileStack(unittest.TestCase):
|
||||
def test_grid_full(self):
|
||||
source = Image.new("RGB", (64, 64))
|
||||
blend = process_tile_grid(source, 32, 1, [])
|
||||
blend = process_tile_stack(source, 32, 1, [])
|
||||
|
||||
self.assertEqual(blend.size, (64, 64))
|
||||
|
||||
def test_grid_partial(self):
|
||||
source = Image.new("RGB", (72, 72))
|
||||
blend = process_tile_grid(source, 32, 1, [])
|
||||
|
||||
self.assertEqual(blend.size, (72, 72))
|
||||
|
||||
|
||||
class TestProcessTileSpiral(unittest.TestCase):
|
||||
def test_grid_full(self):
|
||||
source = Image.new("RGB", (64, 64))
|
||||
blend = process_tile_spiral(source, 32, 1, [])
|
||||
|
||||
self.assertEqual(blend.size, (64, 64))
|
||||
|
||||
def test_grid_partial(self):
|
||||
source = Image.new("RGB", (72, 72))
|
||||
blend = process_tile_spiral(source, 32, 1, [])
|
||||
blend = process_tile_stack(source, 32, 1, [])
|
||||
|
||||
self.assertEqual(blend.size, (72, 72))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
from onnx_web.chain.result import StageResult
|
||||
|
||||
from onnx_web.chain.result import StageResult
|
||||
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
||||
from onnx_web.params import HighresParams, UpscaleParams
|
||||
|
||||
|
|
Loading…
Reference in New Issue