1
0
Fork 0

resize mask to match source

This commit is contained in:
Sean Sube 2023-04-29 15:40:26 -05:00
parent bfe989997e
commit fbab26fe31
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 13 additions and 7 deletions

View File

@ -13,7 +13,7 @@ from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..utils import is_debug from ..utils import is_debug
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_grid, process_tile_order, complete_tile from .utils import complete_tile, process_tile_grid, process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
@ -35,7 +35,12 @@ def upscale_outpaint(
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
source = stage_source or source source = stage_source or source
logger.info("upscaling image by expanding borders: %s", border) logger.info(
"upscaling %s x %s image by expanding borders: %s",
source.width,
source.height,
border,
)
margin_x = float(max(border.left, border.right)) margin_x = float(max(border.left, border.right))
margin_y = float(max(border.top, border.bottom)) margin_y = float(max(border.top, border.bottom))
@ -53,7 +58,8 @@ def upscale_outpaint(
noise_source=noise_source, noise_source=noise_source,
mask_filter=mask_filter, mask_filter=mask_filter,
) )
full_latents = get_latents_from_seed(params.seed, full_size) stage_mask = stage_mask.resize(source.size)
full_latents = get_latents_from_seed(params.seed, Size(*full_size))
draw_mask = ImageDraw.Draw(stage_mask) draw_mask = ImageDraw.Draw(stage_mask)

View File

@ -33,7 +33,6 @@ def complete_tile(
return source return source
def process_tile_grid( def process_tile_grid(
source: Image.Image, source: Image.Image,
tile: int, tile: int,
@ -152,7 +151,7 @@ def generate_tile_spiral(
while accum_width < walk_width and accum_height < walk_height: while accum_width < walk_width and accum_height < walk_height:
# add a tile # add a tile
logger.trace( logger.trace(
"adding tile at %s:%s, %s:%s, %s:%s", "adding tile at %s:%s, %s:%s, %s:%s, %s",
tile_left, tile_left,
tile_top, tile_top,
accum_width, accum_width,

View File

@ -17,6 +17,7 @@ def expand_image(
mask_filter=mask_filter_none, mask_filter=mask_filter_none,
): ):
size = Size(*source.size).add_border(expand).round_to_tile() size = Size(*source.size).add_border(expand).round_to_tile()
size = tuple(size)
origin = (expand.left, expand.top) origin = (expand.left, expand.top)
full_source = Image.new("RGB", size, fill) full_source = Image.new("RGB", size, fill)
@ -29,7 +30,7 @@ def expand_image(
full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) full_source = Image.composite(full_noise, full_source, full_mask.convert("L"))
return (full_source, full_mask, full_noise, (size.width, size.height)) return (full_source, full_mask, full_noise, size)
def valid_image( def valid_image(

View File

@ -79,7 +79,7 @@ class Size:
border.top + self.height + border.bottom, border.top + self.height + border.bottom,
) )
def round_to_tile(self, tile = 512): def round_to_tile(self, tile=512):
return Size( return Size(
ceil(self.width / tile) * tile, ceil(self.width / tile) * tile,
ceil(self.height / tile) * tile, ceil(self.height / tile) * tile,