from logging import getLogger
from os import path
from typing import Optional

import numpy as np
from PIL import Image

from ..models.onnx import OnnxModel
from ..params import (
    DeviceParams,
    HighresParams,
    ImageParams,
    SizeChart,
    StageParams,
    UpscaleParams,
)
from ..server import ModelTypes, ServerContext
from ..utils import run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult

logger = getLogger(__name__)

# Index list that reverses the order of the first three channels.
_CHANNEL_FLIP = [2, 1, 0]


def _pack_for_model(pixels: np.ndarray) -> np.ndarray:
    """Turn one source array into the batched float32 array passed to the model.

    The values are divided by 255, the first three entries of the last axis
    are reversed, the last axis is moved to the front and a leading axis of
    length one is added.
    """
    # TODO: add support for grayscale (1-channel) images
    unit_range = pixels / 255.0
    # presumably RGB <-> BGR; any channel past the third is dropped here
    flipped = unit_range[:, :, _CHANNEL_FLIP]
    # presumably HWC -> CHW - confirm against the exported model
    planar = flipped.astype(np.float32).transpose((2, 0, 1))
    return np.expand_dims(planar, axis=0)


def _unpack_from_model(prediction: np.ndarray) -> np.ndarray:
    """Turn the model result back into a uint8 array, undoing `_pack_for_model`.

    The leading axis is removed, values are clipped to [0, 1], the first three
    channels are reversed again, the channel axis is moved to the end and the
    values are scaled by 255 and rounded.
    """
    squeezed = np.squeeze(prediction, axis=0)
    bounded = np.clip(squeezed, 0, 1)
    interleaved = bounded[_CHANNEL_FLIP, :, :].transpose((1, 2, 0))
    return (interleaved * 255.0).round().astype(np.uint8)


class UpscaleSwinIRStage(BaseStage):
    """Chain stage that runs each source image through a SwinIR ONNX model."""

    # Tile size for this stage; presumably read by the chain runner, which is
    # not visible from this file.
    max_tile = SizeChart.micro

    def load(
        self,
        server: ServerContext,
        _stage: StageParams,
        upscale: UpscaleParams,
        device: DeviceParams,
    ):
        """Return the SwinIR model for `upscale.upscale_model`, using the cache.

        The model file is `<server.model_path>/<upscale_model>.onnx`. When the
        server cache already holds an upscaling entry for that path it is
        returned as-is; otherwise a new `OnnxModel` is created with the
        device's provider and session options, stored in the cache, and
        `run_gc` is called for the device before returning it.
        """
        # NOTE(review): the previous comment here read "must be within the
        # load function for patch to take effect", but nothing in this method
        # is patched or imported locally - confirm whether it is still needed.
        onnx_file = path.join(
            server.model_path, "%s.onnx" % (upscale.upscale_model)
        )
        key = (onnx_file,)

        cached = server.cache.get(ModelTypes.upscaling, key)
        if cached is not None:
            logger.info("reusing existing SwinIR pipeline")
            return cached

        logger.debug("loading SwinIR model from %s", onnx_file)
        model = OnnxModel(
            server,
            onnx_file,
            provider=device.ort_provider("swinir"),
            sess_options=device.sess_options(),
        )
        server.cache.set(ModelTypes.upscaling, key, model)
        run_gc([device])

        return model

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        upscale: UpscaleParams,
        highres: Optional[HighresParams] = None,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        """Upscale every array in `sources` with the configured SwinIR model.

        Extra keyword arguments are applied to `upscale` through `with_args`.
        When no upscale model is set, a warning is logged and `sources` is
        returned unchanged. Otherwise the model is loaded for the worker's
        device, each source array is converted, passed through the model and
        converted back to uint8, and the results are returned in a new
        `StageResult` carrying the metadata of `sources`.

        `_params`, `highres`, `stage_source` and `callback` are accepted but
        not used by this method.
        """
        upscale = upscale.with_args(**kwargs)

        if upscale.upscale_model is None:
            logger.warning("no upscale model given, skipping")
            return sources

        logger.info(
            "upscaling %sx with SwinIR model: %s",
            upscale.outscale,
            upscale.upscale_model,
        )

        model = self.load(server, stage, upscale, worker.get_device())

        upscaled = []
        for pixels in sources.as_arrays():
            tensor = _pack_for_model(pixels)
            # logger.trace is not a stdlib Logger method; presumably the
            # project registers a TRACE level elsewhere.
            logger.trace("SwinIR input shape: %s", tensor.shape)

            # Expected result shape, logged before the model is invoked.
            factor = upscale.outscale
            batch, channels, rows, cols = tensor.shape
            logger.trace(
                "SwinIR output shape: %s",
                (batch, channels, rows * factor, cols * factor),
            )

            result = _unpack_from_model(model(tensor))
            logger.info("output image size: %s", result.shape)
            upscaled.append(result)

        # NOTE(review): `upscaled` holds numpy arrays but is passed through
        # the `images` keyword - confirm StageResult accepts arrays there.
        return StageResult(images=upscaled, metadata=sources.metadata)