1
0
Fork 0

work on non-square regions

This commit is contained in:
Sean Sube 2023-11-09 22:42:45 -06:00
parent 30a9d01432
commit 3622ac4bfb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 22 additions and 10 deletions

View File

@ -93,23 +93,31 @@ def get_tile_grads(
def make_tile_mask( def make_tile_mask(
shape: Any, shape: Any,
tile: int, tile: Tuple[int, int],
overlap: float, overlap: float,
) -> np.ndarray: ) -> np.ndarray:
mask = np.ones(shape) mask = np.ones(shape)
adj_tile = int(float(tile) * (1.0 - overlap))
tile_h, tile_w = tile
adj_tile_h = int(float(tile_h) * (1.0 - overlap))
adj_tile_w = int(float(tile_w) * (1.0 - overlap))
# sort gradient points # sort gradient points
p1 = adj_tile p1_h = adj_tile_h
p2 = tile - adj_tile p2_h = tile_h - adj_tile_h
points = [0, min(p1, p2), max(p1, p2), tile] points_h = [0, min(p1_h, p2_h), max(p1_h, p2_h), tile]
p1_w = adj_tile_w
p2_w = tile_w - adj_tile_w
points_w = [0, min(p1_w, p2_w), max(p1_w, p2_w), tile]
# build gradients # build gradients
grad_x, grad_y = [0, 1, 1, 0], [0, 1, 1, 0] grad_x, grad_y = [0, 1, 1, 0], [0, 1, 1, 0]
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y) logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
mult_x = [np.interp(i, points, grad_x) for i in range(tile)] mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
mult_y = [np.interp(i, points, grad_y) for i in range(tile)] mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
mask = ((mask * mult_x).T * mult_y).T mask = ((mask * mult_x).T * mult_y).T

View File

@ -668,7 +668,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
if feather > 0.0: if feather > 0.0:
mask = make_tile_mask( mask = make_tile_mask(
(h_end - h_start, w_end - w_start), self.window, feather (h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather,
) )
mask = np.expand_dims(mask, axis=0) mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0) mask = np.repeat(mask, 4, axis=0)

View File

@ -522,7 +522,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix
if feather > 0.0: if feather > 0.0:
mask = make_tile_mask( mask = make_tile_mask(
(h_end - h_start, w_end - w_start), self.window, feather (h_end - h_start, w_end - w_start),
(h_end - h_start, w_end - w_start),
feather,
) )
mask = np.expand_dims(mask, axis=0) mask = np.expand_dims(mask, axis=0)
mask = np.repeat(mask, 4, axis=0) mask = np.repeat(mask, 4, axis=0)