fix(api): remove inner loops in upscale stages
This commit is contained in:
parent
99c91a301c
commit
12698d38eb
|
@ -71,10 +71,6 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
bsrgan = self.load(server, stage, upscale, device)
|
bsrgan = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
tile_size = (64, 64)
|
|
||||||
tile_x = source.width // tile_size[0]
|
|
||||||
tile_y = source.height // tile_size[1]
|
|
||||||
|
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
image = np.expand_dims(image, axis=0)
|
image = np.expand_dims(image, axis=0)
|
||||||
|
@ -91,33 +87,7 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
)
|
)
|
||||||
logger.trace("BSRGAN output shape: %s", dest.shape)
|
logger.trace("BSRGAN output shape: %s", dest.shape)
|
||||||
|
|
||||||
for x in range(tile_x):
|
dest = bsrgan(image)
|
||||||
for y in range(tile_y):
|
|
||||||
xt = x * tile_size[0]
|
|
||||||
yt = y * tile_size[1]
|
|
||||||
|
|
||||||
ix1 = xt
|
|
||||||
ix2 = xt + tile_size[0]
|
|
||||||
iy1 = yt
|
|
||||||
iy2 = yt + tile_size[1]
|
|
||||||
logger.debug(
|
|
||||||
"running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
|
||||||
ix1,
|
|
||||||
ix2,
|
|
||||||
iy1,
|
|
||||||
iy2,
|
|
||||||
ix1 * scale,
|
|
||||||
ix2 * scale,
|
|
||||||
iy1 * scale,
|
|
||||||
iy2 * scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
dest[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
ix1 * scale : ix2 * scale,
|
|
||||||
iy1 * scale : iy2 * scale,
|
|
||||||
] = bsrgan(image[:, :, ix1:ix2, iy1:iy2])
|
|
||||||
|
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
|
@ -129,7 +99,8 @@ class UpscaleBSRGANStage(BaseStage):
|
||||||
|
|
||||||
def steps(
|
def steps(
|
||||||
self,
|
self,
|
||||||
_params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
) -> int:
|
) -> int:
|
||||||
return size.width // self.max_tile * size.height // self.max_tile
|
tile = min(params.tiles, self.max_tile)
|
||||||
|
return size.width // tile * size.height // tile
|
||||||
|
|
|
@ -71,11 +71,6 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
device = job.get_device()
|
device = job.get_device()
|
||||||
swinir = self.load(server, stage, upscale, device)
|
swinir = self.load(server, stage, upscale, device)
|
||||||
|
|
||||||
# TODO: add support for other sizes
|
|
||||||
tile_size = (64, 64)
|
|
||||||
tile_x = source.width // tile_size[0]
|
|
||||||
tile_y = source.height // tile_size[1]
|
|
||||||
|
|
||||||
# TODO: add support for grayscale (1-channel) images
|
# TODO: add support for grayscale (1-channel) images
|
||||||
image = np.array(source) / 255.0
|
image = np.array(source) / 255.0
|
||||||
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
|
||||||
|
@ -93,34 +88,7 @@ class UpscaleSwinIRStage(BaseStage):
|
||||||
)
|
)
|
||||||
logger.info("SwinIR output shape: %s", dest.shape)
|
logger.info("SwinIR output shape: %s", dest.shape)
|
||||||
|
|
||||||
for x in range(tile_x):
|
dest = swinir(image)
|
||||||
for y in range(tile_y):
|
|
||||||
xt = x * tile_size[0]
|
|
||||||
yt = y * tile_size[1]
|
|
||||||
|
|
||||||
ix1 = xt
|
|
||||||
ix2 = xt + tile_size[0]
|
|
||||||
iy1 = yt
|
|
||||||
iy2 = yt + tile_size[1]
|
|
||||||
logger.info(
|
|
||||||
"running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)",
|
|
||||||
ix1,
|
|
||||||
ix2,
|
|
||||||
iy1,
|
|
||||||
iy2,
|
|
||||||
ix1 * scale,
|
|
||||||
ix2 * scale,
|
|
||||||
iy1 * scale,
|
|
||||||
iy2 * scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
dest[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
ix1 * scale : ix2 * scale,
|
|
||||||
iy1 * scale : iy2 * scale,
|
|
||||||
] = swinir(image[:, :, ix1:ix2, iy1:iy2])
|
|
||||||
|
|
||||||
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
dest = np.clip(np.squeeze(dest, axis=0), 0, 1)
|
||||||
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0))
|
||||||
dest = (dest * 255.0).round().astype(np.uint8)
|
dest = (dest * 255.0).round().astype(np.uint8)
|
||||||
|
|
Loading…
Reference in New Issue