1
0
Fork 0

fix(api): remove inner loops in upscale stages

This commit is contained in:
Sean Sube 2023-07-02 20:38:52 -05:00
parent 99c91a301c
commit 12698d38eb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 5 additions and 66 deletions

View File

@ -71,10 +71,6 @@ class UpscaleBSRGANStage(BaseStage):
device = job.get_device() device = job.get_device()
bsrgan = self.load(server, stage, upscale, device) bsrgan = self.load(server, stage, upscale, device)
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
image = np.array(source) / 255.0 image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
image = np.expand_dims(image, axis=0) image = np.expand_dims(image, axis=0)
@ -91,33 +87,7 @@ class UpscaleBSRGANStage(BaseStage):
) )
logger.trace("BSRGAN output shape: %s", dest.shape) logger.trace("BSRGAN output shape: %s", dest.shape)
for x in range(tile_x): dest = bsrgan(image)
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.debug(
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
@ -129,7 +99,8 @@ class UpscaleBSRGANStage(BaseStage):
def steps( def steps(
self, self,
_params: ImageParams, params: ImageParams,
size: Size, size: Size,
) -> int: ) -> int:
return size.width // self.max_tile * size.height // self.max_tile tile = min(params.tiles, self.max_tile)
return size.width // tile * size.height // tile

View File

@ -71,11 +71,6 @@ class UpscaleSwinIRStage(BaseStage):
device = job.get_device() device = job.get_device()
swinir = self.load(server, stage, upscale, device) swinir = self.load(server, stage, upscale, device)
# TODO: add support for other sizes
tile_size = (64, 64)
tile_x = source.width // tile_size[0]
tile_y = source.height // tile_size[1]
# TODO: add support for grayscale (1-channel) images # TODO: add support for grayscale (1-channel) images
image = np.array(source) / 255.0 image = np.array(source) / 255.0
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
@ -93,34 +88,7 @@ class UpscaleSwinIRStage(BaseStage):
) )
logger.info("SwinIR output shape: %s", dest.shape) logger.info("SwinIR output shape: %s", dest.shape)
for x in range(tile_x): dest = swinir(image)
for y in range(tile_y):
xt = x * tile_size[0]
yt = y * tile_size[1]
ix1 = xt
ix2 = xt + tile_size[0]
iy1 = yt
iy2 = yt + tile_size[1]
logger.info(
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
ix1,
ix2,
iy1,
iy2,
ix1 * scale,
ix2 * scale,
iy1 * scale,
iy2 * scale,
)
dest[
:,
:,
ix1 * scale : ix2 * scale,
iy1 * scale : iy2 * scale,
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
dest = np.clip(np.squeeze(dest, axis=0), 0, 1) dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
dest = (dest * 255.0).round().astype(np.uint8) dest = (dest * 255.0).round().astype(np.uint8)