rewrite tile handling for image stacks
This commit is contained in:
parent
eb77c83d80
commit
98fcc07524
|
@ -29,5 +29,8 @@ class BlendLinearStage(BaseStage):
|
||||||
logger.info("blending source images using linear interpolation")
|
logger.info("blending source images using linear interpolation")
|
||||||
|
|
||||||
return StageResult(
|
return StageResult(
|
||||||
images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()]
|
images=[
|
||||||
|
Image.blend(source, stage_source, alpha)
|
||||||
|
for source in sources.as_image()
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -40,6 +40,7 @@ class BlendMaskStage(BaseStage):
|
||||||
|
|
||||||
return StageResult(
|
return StageResult(
|
||||||
images=[
|
images=[
|
||||||
Image.composite(stage_source, source, mult_mask) for source in sources.as_image()
|
Image.composite(stage_source, source, mult_mask)
|
||||||
|
for source in sources.as_image()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,9 +29,7 @@ class PersistDiskStage(BaseStage):
|
||||||
stage_source: Optional[Image.Image] = None,
|
stage_source: Optional[Image.Image] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info(
|
logger.info("persisting %s images to disk: %s", len(sources), output)
|
||||||
"persisting %s images to disk: %s", len(sources), output
|
|
||||||
)
|
|
||||||
|
|
||||||
for source, name in zip(sources, output):
|
for source, name in zip(sources, output):
|
||||||
dest = save_image(server, name, source, params=params, size=size)
|
dest = save_image(server, name, source, params=params, size=size)
|
||||||
|
|
|
@ -12,8 +12,8 @@ from ..server import ServerContext
|
||||||
from ..utils import is_debug, run_gc
|
from ..utils import is_debug, run_gc
|
||||||
from ..worker import ProgressCallback, WorkerContext
|
from ..worker import ProgressCallback, WorkerContext
|
||||||
from .base import BaseStage
|
from .base import BaseStage
|
||||||
from .tile import needs_tile, process_tile_order
|
|
||||||
from .result import StageResult
|
from .result import StageResult
|
||||||
|
from .tile import needs_tile, process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -163,60 +163,55 @@ class ChainPipeline:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
# TODO: stage_sources will always be defined here
|
# TODO: stage_sources will always be defined here
|
||||||
if stage_sources or must_tile:
|
if must_tile:
|
||||||
stage_results = []
|
logger.info(
|
||||||
for source in stage_sources.as_image():
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
logger.info(
|
tile,
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
)
|
||||||
tile,
|
|
||||||
)
|
|
||||||
|
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: Image.Image,
|
source_tile: List[Image.Image],
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
dims: Tuple[int, int, int],
|
dims: Tuple[int, int, int],
|
||||||
) -> StageResult:
|
) -> List[Image.Image]:
|
||||||
for _i in range(worker.retries):
|
for _i in range(worker.retries):
|
||||||
try:
|
try:
|
||||||
tile_result = stage_pipe.run(
|
tile_result = stage_pipe.run(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
stage_params,
|
stage_params,
|
||||||
per_stage_params,
|
per_stage_params,
|
||||||
StageResult(images=[source_tile]),
|
StageResult(images=[source_tile]),
|
||||||
tile_mask=tile_mask,
|
tile_mask=tile_mask,
|
||||||
callback=callback,
|
callback=callback,
|
||||||
dims=dims,
|
dims=dims,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
for j, image in enumerate(tile_result.as_image()):
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
save_image(server, f"last-tile-{j}.png", image)
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
# TODO: return whole result
|
return tile_result.as_image()
|
||||||
return tile_result.as_image()[0]
|
except Exception:
|
||||||
except Exception:
|
worker.retries = worker.retries - 1
|
||||||
worker.retries = worker.retries - 1
|
logger.exception(
|
||||||
logger.exception(
|
"error while running stage pipeline for tile, %s retries left",
|
||||||
"error while running stage pipeline for tile, %s retries left",
|
worker.retries,
|
||||||
worker.retries,
|
)
|
||||||
)
|
server.cache.clear()
|
||||||
server.cache.clear()
|
run_gc([worker.get_device()])
|
||||||
run_gc([worker.get_device()])
|
|
||||||
|
|
||||||
raise RetryException("exhausted retries on tile")
|
raise RetryException("exhausted retries on tile")
|
||||||
|
|
||||||
output = process_tile_order(
|
stage_results = process_tile_order(
|
||||||
stage_params.tile_order,
|
stage_params.tile_order,
|
||||||
source,
|
stage_sources,
|
||||||
tile,
|
tile,
|
||||||
stage_params.outscale,
|
stage_params.outscale,
|
||||||
[stage_tile],
|
[stage_tile],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
stage_results.append(output)
|
|
||||||
|
|
||||||
stage_sources = StageResult(images=stage_results)
|
stage_sources = StageResult(images=stage_results)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,7 +2,7 @@ import itertools
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Any, List, Optional, Protocol, Tuple
|
from typing import Any, Callable, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
@ -17,12 +17,17 @@ from .result import StageResult
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
|
||||||
|
|
||||||
|
|
||||||
class TileCallback(Protocol):
|
class TileCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
Definition for a tile job function.
|
Definition for a tile job function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
|
def __call__(
|
||||||
|
self, image: Image.Image, dims: Tuple[int, int, int]
|
||||||
|
) -> List[Image.Image]:
|
||||||
"""
|
"""
|
||||||
Run this stage against a single tile.
|
Run this stage against a single tile.
|
||||||
"""
|
"""
|
||||||
|
@ -33,6 +38,9 @@ def complete_tile(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
tile: int,
|
tile: int,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""
|
||||||
|
TODO: clean up
|
||||||
|
"""
|
||||||
if source is None:
|
if source is None:
|
||||||
return source
|
return source
|
||||||
|
|
||||||
|
@ -67,7 +75,7 @@ def needs_tile(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_tile_grads(
|
def make_tile_grads(
|
||||||
left: int,
|
left: int,
|
||||||
top: int,
|
top: int,
|
||||||
tile: int,
|
tile: int,
|
||||||
|
@ -161,7 +169,7 @@ def blend_tiles(
|
||||||
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
|
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
|
||||||
|
|
||||||
# gradient blending
|
# gradient blending
|
||||||
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
|
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
|
||||||
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
||||||
|
|
||||||
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
||||||
|
@ -225,60 +233,18 @@ def blend_tiles(
|
||||||
return Image.fromarray(np.uint8(pixels))
|
return Image.fromarray(np.uint8(pixels))
|
||||||
|
|
||||||
|
|
||||||
def process_tile_grid(
|
def process_tile_stack(
|
||||||
source: Image.Image,
|
stack: StageResult,
|
||||||
tile: int,
|
|
||||||
scale: int,
|
|
||||||
filters: List[TileCallback],
|
|
||||||
overlap: float = 0.0,
|
|
||||||
**kwargs,
|
|
||||||
) -> Image.Image:
|
|
||||||
width, height = kwargs.get("size", source.size if source else None)
|
|
||||||
|
|
||||||
adj_tile = int(float(tile) * (1.0 - overlap))
|
|
||||||
tiles_x = ceil(width / adj_tile)
|
|
||||||
tiles_y = ceil(height / adj_tile)
|
|
||||||
total = tiles_x * tiles_y
|
|
||||||
logger.debug(
|
|
||||||
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
|
||||||
total,
|
|
||||||
tiles_x,
|
|
||||||
tiles_y,
|
|
||||||
adj_tile,
|
|
||||||
overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
|
||||||
|
|
||||||
for y in range(tiles_y):
|
|
||||||
for x in range(tiles_x):
|
|
||||||
idx = (y * tiles_x) + x
|
|
||||||
left = x * adj_tile
|
|
||||||
top = y * adj_tile
|
|
||||||
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
|
|
||||||
|
|
||||||
tile_image = (
|
|
||||||
source.crop((left, top, left + tile, top + tile)) if source else None
|
|
||||||
)
|
|
||||||
tile_image = complete_tile(tile_image, tile)
|
|
||||||
|
|
||||||
for filter in filters:
|
|
||||||
tile_image = filter(tile_image, (left, top, tile))
|
|
||||||
|
|
||||||
tiles.append((left, top, tile_image))
|
|
||||||
|
|
||||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
|
||||||
|
|
||||||
|
|
||||||
def process_tile_spiral(
|
|
||||||
source: Image.Image,
|
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
|
tile_generator: TileGenerator,
|
||||||
overlap: float = 0.5,
|
overlap: float = 0.5,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> List[Image.Image]:
|
||||||
width, height = kwargs.get("size", source.size if source else None)
|
sources = stack.as_image()
|
||||||
|
|
||||||
|
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
||||||
mask = kwargs.get("mask", None)
|
mask = kwargs.get("mask", None)
|
||||||
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
||||||
fill_color = kwargs.get("fill_color", None)
|
fill_color = kwargs.get("fill_color", None)
|
||||||
|
@ -286,18 +252,9 @@ def process_tile_spiral(
|
||||||
tile_mask = None
|
tile_mask = None
|
||||||
|
|
||||||
tiles: List[Tuple[int, int, Image.Image]] = []
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||||
|
tile_coords = tile_generator(width, height, tile, overlap)
|
||||||
|
|
||||||
# tile tuples is source, multiply by scale for dest
|
for counter, (left, top) in enumerate(tile_coords):
|
||||||
counter = 0
|
|
||||||
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
|
|
||||||
|
|
||||||
if len(tile_coords) == 1:
|
|
||||||
single_tile = True
|
|
||||||
else:
|
|
||||||
single_tile = False
|
|
||||||
|
|
||||||
for left, top in tile_coords:
|
|
||||||
counter += 1
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
|
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
|
||||||
)
|
)
|
||||||
|
@ -321,26 +278,31 @@ def process_tile_spiral(
|
||||||
needs_margin = True
|
needs_margin = True
|
||||||
bottom_margin = height - bottom
|
bottom_margin = height - bottom
|
||||||
|
|
||||||
# if no source given, we don't have a source image
|
if needs_margin:
|
||||||
if not source:
|
logger.debug(
|
||||||
tile_image = None
|
"tiling with added margins: %s, %s, %s, %s",
|
||||||
elif needs_margin:
|
left_margin,
|
||||||
# in the special case where the image is smaller than the specified tile size, just use the image
|
top_margin,
|
||||||
if single_tile:
|
right_margin,
|
||||||
logger.debug("creating and processing single-tile subtile")
|
bottom_margin,
|
||||||
tile_image = source
|
)
|
||||||
if mask:
|
tile_stack = add_margin(
|
||||||
tile_mask = mask
|
stack,
|
||||||
# otherwise use add histogram noise outside of the image border
|
left,
|
||||||
else:
|
top,
|
||||||
logger.debug(
|
right,
|
||||||
"tiling and adding margins: %s, %s, %s, %s",
|
bottom,
|
||||||
left_margin,
|
left_margin,
|
||||||
top_margin,
|
top_margin,
|
||||||
right_margin,
|
right_margin,
|
||||||
bottom_margin,
|
bottom_margin,
|
||||||
)
|
tile,
|
||||||
base_image = source.crop(
|
noise_source,
|
||||||
|
fill_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mask:
|
||||||
|
base_mask = mask.crop(
|
||||||
(
|
(
|
||||||
left + left_margin,
|
left + left_margin,
|
||||||
top + top_margin,
|
top + top_margin,
|
||||||
|
@ -348,43 +310,35 @@ def process_tile_spiral(
|
||||||
bottom + bottom_margin,
|
bottom + bottom_margin,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tile_image = noise_source(
|
tile_mask = Image.new("L", (tile, tile), color=0)
|
||||||
base_image, (tile, tile), (0, 0), fill=fill_color
|
tile_mask.paste(base_mask, (left_margin, top_margin))
|
||||||
)
|
|
||||||
tile_image.paste(base_image, (left_margin, top_margin))
|
|
||||||
|
|
||||||
if mask:
|
|
||||||
base_mask = mask.crop(
|
|
||||||
(
|
|
||||||
left + left_margin,
|
|
||||||
top + top_margin,
|
|
||||||
right + right_margin,
|
|
||||||
bottom + bottom_margin,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tile_mask = Image.new("L", (tile, tile), color=0)
|
|
||||||
tile_mask.paste(base_mask, (left_margin, top_margin))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("tiling normally")
|
logger.debug("tiling normally")
|
||||||
tile_image = source.crop((left, top, right, bottom))
|
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
|
||||||
if mask:
|
if mask:
|
||||||
tile_mask = mask.crop((left, top, right, bottom))
|
tile_mask = mask.crop((left, top, right, bottom))
|
||||||
|
|
||||||
for image_filter in filters:
|
for image_filter in filters:
|
||||||
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
|
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
|
||||||
|
|
||||||
tiles.append((left, top, tile_image))
|
tiles.append((left, top, tile_stack))
|
||||||
|
|
||||||
if single_tile:
|
lefts, tops, stacks = list(zip(*tiles))
|
||||||
return tile_image
|
coords = list(zip(lefts, tops))
|
||||||
else:
|
stacks = list(zip(*stacks))
|
||||||
return blend_tiles(tiles, scale, width, height, tile, overlap)
|
|
||||||
|
result = []
|
||||||
|
for stack in stacks:
|
||||||
|
stack_tiles = zip(coords, stack)
|
||||||
|
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def process_tile_order(
|
def process_tile_order(
|
||||||
order: TileOrder,
|
order: TileOrder,
|
||||||
source: Image.Image,
|
stack: StageResult,
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
|
@ -395,13 +349,17 @@ def process_tile_order(
|
||||||
"""
|
"""
|
||||||
if order == TileOrder.grid:
|
if order == TileOrder.grid:
|
||||||
logger.debug("using grid tile order with tile size: %s", tile)
|
logger.debug("using grid tile order with tile size: %s", tile)
|
||||||
return process_tile_grid(source, tile, scale, filters, **kwargs)
|
return process_tile_stack(
|
||||||
|
stack, tile, scale, filters, generate_tile_grid, **kwargs
|
||||||
|
)
|
||||||
elif order == TileOrder.kernel:
|
elif order == TileOrder.kernel:
|
||||||
logger.debug("using kernel tile order with tile size: %s", tile)
|
logger.debug("using kernel tile order with tile size: %s", tile)
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
elif order == TileOrder.spiral:
|
elif order == TileOrder.spiral:
|
||||||
logger.debug("using spiral tile order with tile size: %s", tile)
|
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||||
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
return process_tile_stack(
|
||||||
|
stack, tile, scale, filters, generate_tile_spiral, **kwargs
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown tile order: %s", order)
|
logger.warning("unknown tile order: %s", order)
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
@ -495,3 +453,76 @@ def generate_tile_spiral(
|
||||||
height_tile_target -= abs(state.value[1])
|
height_tile_target -= abs(state.value[1])
|
||||||
|
|
||||||
return tile_coords
|
return tile_coords
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tile_grid(
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
tile: int,
|
||||||
|
overlap: float = 0.0,
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
adj_tile = int(float(tile) * (1.0 - overlap))
|
||||||
|
tiles_x = ceil(width / adj_tile)
|
||||||
|
tiles_y = ceil(height / adj_tile)
|
||||||
|
total = tiles_x * tiles_y
|
||||||
|
logger.debug(
|
||||||
|
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
||||||
|
total,
|
||||||
|
tiles_x,
|
||||||
|
tiles_y,
|
||||||
|
adj_tile,
|
||||||
|
overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
||||||
|
|
||||||
|
for y in range(tiles_y):
|
||||||
|
for x in range(tiles_x):
|
||||||
|
left = x * adj_tile
|
||||||
|
top = y * adj_tile
|
||||||
|
|
||||||
|
tiles.append((int(left), int(top)))
|
||||||
|
|
||||||
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
|
def get_result_tile(
|
||||||
|
result: StageResult,
|
||||||
|
origin: Tuple[int, int],
|
||||||
|
tile: Size,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
top, left = origin
|
||||||
|
return [
|
||||||
|
layer.crop((top, left, top + tile.height, left + tile.width))
|
||||||
|
for layer in result.as_image()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_margin(
|
||||||
|
stack: List[Image.Image],
|
||||||
|
left: int,
|
||||||
|
top: int,
|
||||||
|
right: int,
|
||||||
|
bottom: int,
|
||||||
|
left_margin: int,
|
||||||
|
top_margin: int,
|
||||||
|
right_margin: int,
|
||||||
|
bottom_margin: int,
|
||||||
|
tile: int,
|
||||||
|
noise_source,
|
||||||
|
fill_color,
|
||||||
|
) -> List[Image.Image]:
|
||||||
|
results = []
|
||||||
|
for source in stack:
|
||||||
|
base_image = source.crop(
|
||||||
|
(
|
||||||
|
left + left_margin,
|
||||||
|
top + top_margin,
|
||||||
|
right + right_margin,
|
||||||
|
bottom + bottom_margin,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
|
||||||
|
tile_image.paste(base_image, (left_margin, top_margin))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
|
@ -79,12 +79,15 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
logger.trace("BSRGAN input shape: %s", image.shape)
|
logger.trace("BSRGAN input shape: %s", image.shape)
|
||||||
|
|
||||||
scale = upscale.outscale
|
scale = upscale.outscale
|
||||||
logger.trace("BSRGAN output shape: %s", (
|
logger.trace(
|
||||||
|
"BSRGAN output shape: %s",
|
||||||
|
(
|
||||||
image.shape[0],
|
image.shape[0],
|
||||||
image.shape[1],
|
image.shape[1],
|
||||||
image.shape[2] * scale,
|
image.shape[2] * scale,
|
||||||
image.shape[3] * scale,
|
image.shape[3] * scale,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
output = bsrgan(image)
|
output = bsrgan(image)
|
||||||
|
|
||||||
|
|
|
@ -38,10 +38,14 @@ class UpscaleSimpleStage(BaseStage):
|
||||||
|
|
||||||
if method == "bilinear":
|
if method == "bilinear":
|
||||||
logger.debug("using bilinear interpolation for highres")
|
logger.debug("using bilinear interpolation for highres")
|
||||||
outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR))
|
outputs.append(
|
||||||
|
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
|
||||||
|
)
|
||||||
elif method == "lanczos":
|
elif method == "lanczos":
|
||||||
logger.debug("using Lanczos interpolation for highres")
|
logger.debug("using Lanczos interpolation for highres")
|
||||||
outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS))
|
outputs.append(
|
||||||
|
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("unknown upscaling method: %s", method)
|
logger.warning("unknown upscaling method: %s", method)
|
||||||
|
|
||||||
|
|
|
@ -105,7 +105,9 @@ def run_txt2img_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
|
|
||||||
|
@ -200,7 +202,9 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.append(source)
|
images.append(source)
|
||||||
|
@ -380,7 +384,9 @@ def run_inpaint_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, [source], callback=progress, latents=latents)
|
images = chain.run(
|
||||||
|
worker, server, params, [source], callback=progress, latents=latents
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -455,7 +461,9 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
|
@ -505,7 +513,9 @@ def run_blend_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(worker, server, params, StageResult(images=sources), callback=progress)
|
images = chain.run(
|
||||||
|
worker, server, params, StageResult(images=sources), callback=progress
|
||||||
|
)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
for image, output in zip(images, outputs):
|
||||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from onnx_web.chain.result import StageResult
|
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.chain.source_noise import SourceNoiseStage
|
from onnx_web.chain.source_noise import SourceNoiseStage
|
||||||
from onnx_web.image.noise_source import noise_source_fill_edge
|
from onnx_web.image.noise_source import noise_source_fill_edge
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from onnx_web.chain.result import StageResult
|
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.chain.source_s3 import SourceS3Stage
|
from onnx_web.chain.source_s3 import SourceS3Stage
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from onnx_web.chain.result import StageResult
|
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.chain.source_url import SourceURLStage
|
from onnx_web.chain.source_url import SourceURLStage
|
||||||
from onnx_web.params import HighresParams, Size, UpscaleParams
|
from onnx_web.params import HighresParams, Size, UpscaleParams
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,11 @@ from PIL import Image
|
||||||
|
|
||||||
from onnx_web.chain.tile import (
|
from onnx_web.chain.tile import (
|
||||||
complete_tile,
|
complete_tile,
|
||||||
|
generate_tile_grid,
|
||||||
generate_tile_spiral,
|
generate_tile_spiral,
|
||||||
get_tile_grads,
|
make_tile_grads,
|
||||||
needs_tile,
|
needs_tile,
|
||||||
process_tile_grid,
|
process_tile_stack,
|
||||||
process_tile_spiral,
|
|
||||||
)
|
)
|
||||||
from onnx_web.params import Size
|
from onnx_web.params import Size
|
||||||
|
|
||||||
|
@ -59,24 +59,46 @@ class TestNeedsTile(unittest.TestCase):
|
||||||
|
|
||||||
class TestTileGrads(unittest.TestCase):
|
class TestTileGrads(unittest.TestCase):
|
||||||
def test_center_tile(self):
|
def test_center_tile(self):
|
||||||
grad_x, grad_y = get_tile_grads(32, 32, 8, 64, 64)
|
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||||
|
|
||||||
def test_vertical_edge_tile(self):
|
def test_vertical_edge_tile(self):
|
||||||
grad_x, grad_y = get_tile_grads(32, 0, 8, 64, 8)
|
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [0, 1, 1, 0])
|
self.assertEqual(grad_x, [0, 1, 1, 0])
|
||||||
self.assertEqual(grad_y, [1, 1, 1, 1])
|
self.assertEqual(grad_y, [1, 1, 1, 1])
|
||||||
|
|
||||||
def test_horizontal_edge_tile(self):
|
def test_horizontal_edge_tile(self):
|
||||||
grad_x, grad_y = get_tile_grads(0, 32, 8, 8, 64)
|
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
|
||||||
|
|
||||||
self.assertEqual(grad_x, [1, 1, 1, 1])
|
self.assertEqual(grad_x, [1, 1, 1, 1])
|
||||||
self.assertEqual(grad_y, [0, 1, 1, 0])
|
self.assertEqual(grad_y, [0, 1, 1, 0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateTileGrid(unittest.TestCase):
|
||||||
|
def test_grid_complete(self):
|
||||||
|
tiles = generate_tile_grid(16, 16, 8, 0.0)
|
||||||
|
|
||||||
|
self.assertEqual(len(tiles), 4)
|
||||||
|
self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)])
|
||||||
|
|
||||||
|
def test_grid_no_overlap(self):
|
||||||
|
tiles = generate_tile_grid(64, 64, 8, 0.0)
|
||||||
|
|
||||||
|
self.assertEqual(len(tiles), 64)
|
||||||
|
self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
|
||||||
|
self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)])
|
||||||
|
|
||||||
|
def test_grid_50_overlap(self):
|
||||||
|
tiles = generate_tile_grid(64, 64, 8, 0.5)
|
||||||
|
|
||||||
|
self.assertEqual(len(tiles), 225)
|
||||||
|
self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
|
||||||
|
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateTileSpiral(unittest.TestCase):
|
class TestGenerateTileSpiral(unittest.TestCase):
|
||||||
def test_spiral_complete(self):
|
def test_spiral_complete(self):
|
||||||
tiles = generate_tile_spiral(16, 16, 8, 0.0)
|
tiles = generate_tile_spiral(16, 16, 8, 0.0)
|
||||||
|
@ -99,29 +121,15 @@ class TestGenerateTileSpiral(unittest.TestCase):
|
||||||
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
|
||||||
|
|
||||||
|
|
||||||
class TestProcessTileGrid(unittest.TestCase):
|
class TestProcessTileStack(unittest.TestCase):
|
||||||
def test_grid_full(self):
|
def test_grid_full(self):
|
||||||
source = Image.new("RGB", (64, 64))
|
source = Image.new("RGB", (64, 64))
|
||||||
blend = process_tile_grid(source, 32, 1, [])
|
blend = process_tile_stack(source, 32, 1, [])
|
||||||
|
|
||||||
self.assertEqual(blend.size, (64, 64))
|
self.assertEqual(blend.size, (64, 64))
|
||||||
|
|
||||||
def test_grid_partial(self):
|
def test_grid_partial(self):
|
||||||
source = Image.new("RGB", (72, 72))
|
source = Image.new("RGB", (72, 72))
|
||||||
blend = process_tile_grid(source, 32, 1, [])
|
blend = process_tile_stack(source, 32, 1, [])
|
||||||
|
|
||||||
self.assertEqual(blend.size, (72, 72))
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessTileSpiral(unittest.TestCase):
|
|
||||||
def test_grid_full(self):
|
|
||||||
source = Image.new("RGB", (64, 64))
|
|
||||||
blend = process_tile_spiral(source, 32, 1, [])
|
|
||||||
|
|
||||||
self.assertEqual(blend.size, (64, 64))
|
|
||||||
|
|
||||||
def test_grid_partial(self):
|
|
||||||
source = Image.new("RGB", (72, 72))
|
|
||||||
blend = process_tile_spiral(source, 32, 1, [])
|
|
||||||
|
|
||||||
self.assertEqual(blend.size, (72, 72))
|
self.assertEqual(blend.size, (72, 72))
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from onnx_web.chain.result import StageResult
|
|
||||||
|
|
||||||
|
from onnx_web.chain.result import StageResult
|
||||||
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
from onnx_web.chain.upscale_highres import UpscaleHighresStage
|
||||||
from onnx_web.params import HighresParams, UpscaleParams
|
from onnx_web.params import HighresParams, UpscaleParams
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue