diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 15ef79dc..ef470d9c 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -60,7 +60,7 @@ class SourceTxt2ImgStage(BaseStage): if latents is None: latents = get_latents_from_seed(params.seed, latent_size, params.batch) else: - latents = get_tile_latents(latents, dims, latent_size) + latents = get_tile_latents(latents, params.seed, latent_size, dims) pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 9753b2b3..ec1a9f1c 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -79,7 +79,7 @@ class UpscaleOutpaintStage(BaseStage): if latents is None: latents = get_latents_from_seed(params.seed, latent_size, params.batch) else: - latents = get_tile_latents(latents, dims, latent_size) + latents = get_tile_latents(latents, params.seed, latent_size, dims) if params.lpw(): logger.debug("using LPW pipeline for inpaint") diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 02dcb35c..64bc354b 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -268,8 +268,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: def get_tile_latents( full_latents: np.ndarray, - dims: Tuple[int, int, int], + seed: int, size: Size, + dims: Tuple[int, int, int], ) -> np.ndarray: x, y, tile = dims t = tile // LATENT_FACTOR @@ -284,12 +285,9 @@ def get_tile_latents( tile_latents = full_latents[:, :, y:yt, x:xt] if tile_latents.shape[2] < t or tile_latents.shape[3] < t: - px = mx - tile_latents.shape[3] - py = my - tile_latents.shape[2] - - tile_latents = np.pad( - tile_latents, ((0, 0), (0, 0), (0, py), (0, px)), mode="reflect" - ) + extra_latents = get_latents_from_seed(seed, size, batch=tile_latents.shape[0]) + extra_latents[:, :, 0:t, 0:t] = tile_latents + tile_latents = extra_latents return tile_latents