From 3622ac4bfb05a5b398f48b432645c7d472143d7a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 9 Nov 2023 22:42:45 -0600 Subject: [PATCH] work on non-square regions --- api/onnx_web/chain/tile.py | 24 ++++++++++++------- api/onnx_web/diffusers/pipelines/panorama.py | 4 +++- .../diffusers/pipelines/panorama_xl.py | 4 +++- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index efbdc771..4d8bee2d 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -93,23 +93,31 @@ def get_tile_grads( def make_tile_mask( shape: Any, - tile: int, + tile: Tuple[int, int], overlap: float, ) -> np.ndarray: mask = np.ones(shape) - adj_tile = int(float(tile) * (1.0 - overlap)) + + tile_h, tile_w = tile + + adj_tile_h = int(float(tile_h) * (1.0 - overlap)) + adj_tile_w = int(float(tile_w) * (1.0 - overlap)) # sort gradient points - p1 = adj_tile - p2 = tile - adj_tile - points = [0, min(p1, p2), max(p1, p2), tile] + p1_h = adj_tile_h + p2_h = tile_h - adj_tile_h + points_h = [0, min(p1_h, p2_h), max(p1_h, p2_h), tile] + + p1_w = adj_tile_w + p2_w = tile_w - adj_tile_w + points_w = [0, min(p1_w, p2_w), max(p1_w, p2_w), tile] # build gradients grad_x, grad_y = [0, 1, 1, 0], [0, 1, 1, 0] - logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y) + logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y) - mult_x = [np.interp(i, points, grad_x) for i in range(tile)] - mult_y = [np.interp(i, points, grad_y) for i in range(tile)] + mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)] + mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)] mask = ((mask * mult_x).T * mult_y).T diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py index 98da1ac8..c5c39f7a 100644 --- a/api/onnx_web/diffusers/pipelines/panorama.py +++ b/api/onnx_web/diffusers/pipelines/panorama.py @@ -668,7 +668,9 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline): if feather > 0.0: mask = make_tile_mask( - (h_end - h_start, w_end - w_start), self.window, feather + (h_end - h_start, w_end - w_start), + (h_end - h_start, w_end - w_start), + feather, ) mask = np.expand_dims(mask, axis=0) mask = np.repeat(mask, 4, axis=0) diff --git a/api/onnx_web/diffusers/pipelines/panorama_xl.py b/api/onnx_web/diffusers/pipelines/panorama_xl.py index 7c71748e..6b230108 100644 --- a/api/onnx_web/diffusers/pipelines/panorama_xl.py +++ b/api/onnx_web/diffusers/pipelines/panorama_xl.py @@ -522,7 +522,9 @@ class StableDiffusionXLPanoramaPipelineMixin(StableDiffusionXLImg2ImgPipelineMix if feather > 0.0: mask = make_tile_mask( - (h_end - h_start, w_end - w_start), self.window, feather + (h_end - h_start, w_end - w_start), + (h_end - h_start, w_end - w_start), + feather, ) mask = np.expand_dims(mask, axis=0) mask = np.repeat(mask, 4, axis=0)