fix(api): remove inner loops in upscale stages
This commit is contained in:
parent
99c91a301c
commit
12698d38eb
|
@ -71,10 +71,6 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
device = job.get_device()
|
||||
bsrgan = self.load(server, stage, upscale, device)
|
||||
|
||||
tile_size = (64, 64)
|
||||
tile_x = source.width // tile_size[0]
|
||||
tile_y = source.height // tile_size[1]
|
||||
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
@ -91,33 +87,7 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
)
|
||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||
|
||||
for x in range(tile_x):
|
||||
for y in range(tile_y):
|
||||
xt = x * tile_size[0]
|
||||
yt = y * tile_size[1]
|
||||
|
||||
ix1 = xt
|
||||
ix2 = xt + tile_size[0]
|
||||
iy1 = yt
|
||||
iy2 = yt + tile_size[1]
|
||||
logger.debug(
|
||||
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
||||
ix1,
|
||||
ix2,
|
||||
iy1,
|
||||
iy2,
|
||||
ix1 * scale,
|
||||
ix2 * scale,
|
||||
iy1 * scale,
|
||||
iy2 * scale,
|
||||
)
|
||||
|
||||
dest[
|
||||
:,
|
||||
:,
|
||||
ix1 * scale : ix2 * scale,
|
||||
iy1 * scale : iy2 * scale,
|
||||
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
|
||||
dest = bsrgan(image)
|
||||
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
|
@ -129,7 +99,8 @@ class UpscaleBSRGANStage(BaseStage):
|
|||
|
||||
def steps(
|
||||
self,
|
||||
_params: ImageParams,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
) -> int:
|
||||
return size.width // self.max_tile * size.height // self.max_tile
|
||||
tile = min(params.tiles, self.max_tile)
|
||||
return size.width // tile * size.height // tile
|
||||
|
|
|
@ -71,11 +71,6 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
device = job.get_device()
|
||||
swinir = self.load(server, stage, upscale, device)
|
||||
|
||||
# TODO: add support for other sizes
|
||||
tile_size = (64, 64)
|
||||
tile_x = source.width // tile_size[0]
|
||||
tile_y = source.height // tile_size[1]
|
||||
|
||||
# TODO: add support for grayscale (1-channel) images
|
||||
image = np.array(source) / 255.0
|
||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||
|
@ -93,34 +88,7 @@ class UpscaleSwinIRStage(BaseStage):
|
|||
)
|
||||
logger.info("SwinIR output shape: %s", dest.shape)
|
||||
|
||||
for x in range(tile_x):
|
||||
for y in range(tile_y):
|
||||
xt = x * tile_size[0]
|
||||
yt = y * tile_size[1]
|
||||
|
||||
ix1 = xt
|
||||
ix2 = xt + tile_size[0]
|
||||
iy1 = yt
|
||||
iy2 = yt + tile_size[1]
|
||||
logger.info(
|
||||
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
||||
ix1,
|
||||
ix2,
|
||||
iy1,
|
||||
iy2,
|
||||
ix1 * scale,
|
||||
ix2 * scale,
|
||||
iy1 * scale,
|
||||
iy2 * scale,
|
||||
)
|
||||
|
||||
dest[
|
||||
:,
|
||||
:,
|
||||
ix1 * scale : ix2 * scale,
|
||||
iy1 * scale : iy2 * scale,
|
||||
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
|
||||
|
||||
dest = swinir(image)
|
||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||
dest = (dest * 255.0).round().astype(np.uint8)
|
||||
|
|
Loading…
Reference in New Issue