diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 5d3c9e20..5d1ed922 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -71,10 +71,6 @@ class UpscaleBSRGANStage(BaseStage): device = job.get_device() bsrgan = self.load(server, stage, upscale, device) - tile_size = (64, 64) - tile_x = source.width // tile_size[0] - tile_y = source.height // tile_size[1] - image = np.array(source) / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = np.expand_dims(image, axis=0) @@ -91,33 +87,7 @@ class UpscaleBSRGANStage(BaseStage): ) logger.trace("BSRGAN output shape: %s", dest.shape) - for x in range(tile_x): - for y in range(tile_y): - xt = x * tile_size[0] - yt = y * tile_size[1] - - ix1 = xt - ix2 = xt + tile_size[0] - iy1 = yt - iy2 = yt + tile_size[1] - logger.debug( - "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", - ix1, - ix2, - iy1, - iy2, - ix1 * scale, - ix2 * scale, - iy1 * scale, - iy2 * scale, - ) - - dest[ - :, - :, - ix1 * scale : ix2 * scale, - iy1 * scale : iy2 * scale, - ] = bsrgan(image[:, :, ix1:ix2, iy1:iy2]) + dest = bsrgan(image) dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) @@ -129,7 +99,8 @@ class UpscaleBSRGANStage(BaseStage): def steps( self, - _params: ImageParams, + params: ImageParams, size: Size, ) -> int: - return size.width // self.max_tile * size.height // self.max_tile + tile = min(params.tiles, self.max_tile) + return size.width // tile * size.height // tile diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index c99253c9..e7379ff5 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -71,11 +71,6 @@ class UpscaleSwinIRStage(BaseStage): device = job.get_device() swinir = self.load(server, stage, upscale, device) - # TODO: add support for other sizes - tile_size = (64, 64) - tile_x = source.width // tile_size[0] - tile_y = source.height // tile_size[1] - # TODO: add support for grayscale (1-channel) images image = np.array(source) / 255.0 image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) @@ -93,34 +88,7 @@ class UpscaleSwinIRStage(BaseStage): ) logger.info("SwinIR output shape: %s", dest.shape) - for x in range(tile_x): - for y in range(tile_y): - xt = x * tile_size[0] - yt = y * tile_size[1] - - ix1 = xt - ix2 = xt + tile_size[0] - iy1 = yt - iy2 = yt + tile_size[1] - logger.info( - "running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", - ix1, - ix2, - iy1, - iy2, - ix1 * scale, - ix2 * scale, - iy1 * scale, - iy2 * scale, - ) - - dest[ - :, - :, - ix1 * scale : ix2 * scale, - iy1 * scale : iy2 * scale, - ] = swinir(image[:, :, ix1:ix2, iy1:iy2]) - + dest = swinir(image) dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) dest = (dest * 255.0).round().astype(np.uint8)