1
0
Fork 0
onnx-web/api/onnx_web/chain/upscale_swinir.py

102 lines
3.0 KiB
Python
Raw Normal View History

from logging import getLogger
from os import path
2023-11-19 00:13:13 +00:00
from typing import Optional
import numpy as np
from PIL import Image
from ..models.onnx import OnnxModel
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import WorkerContext
from .base import BaseStage
from .result import StageResult
# module-level logger, named after this module
logger = getLogger(__name__)
class UpscaleSwinIRStage(BaseStage):
    """
    Chain stage that upscales images with a SwinIR model exported to ONNX.

    The model file is ``<server.model_path>/<upscale.upscale_model>.onnx`` and
    loaded models are kept in the server's upscaling model cache.
    """

    # NOTE(review): not read anywhere in this class; presumably consumed by
    # BaseStage or the chain runner as a tile-size limit - confirm.
    max_tile = 64

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        """
        Load the SwinIR ONNX model, reusing a cached instance when one exists.

        :param server: server context providing ``model_path`` and the model cache
        :param _stage: unused
        :param upscale: upscaling parameters; ``upscale_model`` names the model file
        :param device: device whose ORT provider and session options are used
        :returns: the cached or newly created ``OnnxModel``
        """
        # must be within the load function for patch to take effect
        model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model))

        # NOTE(review): the cache key ignores `device`, so a model loaded for one
        # device will be reused for another - confirm this is intended.
        cache_key = (model_path,)
        cache_pipe = server.cache.get(ModelTypes.upscaling, cache_key)

        if cache_pipe is not None:
            logger.info("reusing existing SwinIR pipeline")
            return cache_pipe

        logger.debug("loading SwinIR model from %s", model_path)
        pipe = OnnxModel(
            server,
            model_path,
            provider=device.ort_provider(),
            sess_options=device.sess_options(),
        )

        server.cache.set(ModelTypes.upscaling, cache_key, pipe)
        run_gc([device])

        return pipe

    @staticmethod
    def _to_model_input(source: np.ndarray) -> np.ndarray:
        """
        Convert a channels-last image array into the model's input tensor.

        The result is float32 with shape ``(1, 3, H, W)``: values divided by 255,
        the first three channels reversed, and moved to channels-first.
        NOTE(review): assumes ``source`` is uint8 RGB(A) with values in 0-255, so
        the reversal yields BGR - confirm against the model's training data.
        """
        # TODO: add support for grayscale (1-channel) images; a 2-D array fails
        # on the channel indexing below
        image = source / 255.0
        # reverse channel order (dropping any alpha channel), then HWC -> CHW
        image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1))
        # add the batch dimension
        return np.expand_dims(image, axis=0)

    @staticmethod
    def _from_model_output(output) -> np.ndarray:
        """
        Convert the model's ``(1, 3, H, W)`` output back into an ``(H, W, 3)``
        uint8 array, undoing the channel reversal of ``_to_model_input``.
        """
        # drop the batch dimension and clamp to the valid [0, 1] range
        image = np.clip(np.squeeze(output, axis=0), 0, 1)
        # restore channel order, then CHW -> HWC
        image = image[[2, 1, 0], :, :].transpose((1, 2, 0))
        return (image * 255.0).round().astype(np.uint8)

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> StageResult:
        """
        Upscale every source image with the configured SwinIR model.

        :param worker: worker context providing the device to run on
        :param server: server context used to load the model
        :param stage: stage parameters, forwarded to ``load``
        :param _params: unused
        :param sources: images to upscale, read via ``as_numpy()``
        :param upscale: upscaling parameters, combined with ``kwargs`` via ``with_args``
        :param stage_source: unused
        :returns: ``sources`` unchanged when no model is configured, otherwise a
            new ``StageResult`` holding one PIL image per source
        """
        upscale = upscale.with_args(**kwargs)

        if upscale.upscale_model is None:
            logger.warning("no upscaling model given, skipping")
            return sources

        logger.info("upscaling with SwinIR model: %s", upscale.upscale_model)
        device = worker.get_device()
        swinir = self.load(server, stage, upscale, device)

        outputs = []
        for source in sources.as_numpy():
            image = self._to_model_input(source)
            logger.trace("SwinIR input shape: %s", image.shape)

            # predicted, not measured: NOTE(review) assumes the model's native
            # scale equals `outscale` - confirm
            scale = upscale.outscale
            logger.trace(
                "SwinIR output shape: %s",
                (
                    image.shape[0],
                    image.shape[1],
                    image.shape[2] * scale,
                    image.shape[3] * scale,
                ),
            )

            output = self._from_model_output(swinir(image))
            logger.info("output image size: %s", output.shape)

            # StageResult's `images` slot holds PIL images (the type this file
            # uses for images elsewhere), so convert rather than storing ndarrays
            outputs.append(Image.fromarray(output))

        return StageResult(images=outputs)