From de4a3818a08aa0de741caca3ff31ccfddb1d6975 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 6 Feb 2023 21:36:20 -0600 Subject: [PATCH] fix(api): use SRVGG net for Real ESRGAN v3 --- api/onnx_web/chain/upscale_resrgan.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 25835ef7..f543086a 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -5,6 +5,7 @@ import numpy as np from basicsr.archs.rrdbnet_arch import RRDBNet from PIL import Image from realesrgan import RealESRGANer +from realesrgan.archs.srvgg_arch import SRVGGNetCompact from ..device_pool import JobContext from ..onnx import OnnxNet @@ -16,6 +17,8 @@ logger = getLogger(__name__) last_pipeline_instance = None last_pipeline_params = (None, None) +x4_v3_tag = "real-esrgan-x4-v3" + def load_resrgan( ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0 @@ -33,8 +36,18 @@ def load_resrgan( logger.info("reusing existing Real ESRGAN pipeline") return last_pipeline_instance - # use ONNX acceleration, if available - if params.format == "onnx": + if x4_v3_tag in model_file: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ) + elif params.format == "onnx": + # use ONNX acceleration, if available model = OnnxNet( ctx, model_file, provider=device.provider, provider_options=device.options ) @@ -47,13 +60,12 @@ def load_resrgan( num_grow_ch=32, scale=params.scale, ) + else: raise Exception("unknown platform %s" % params.format) dni_weight = None - if params.upscale_model == "real-esrgan-x4-v3" and params.denoise != 1: - wdn_model_path = model_path.replace( - "real-esrgan-x4-v3", "real-esrgan-x4-v3-wdn" - ) + if params.upscale_model == x4_v3_tag and params.denoise != 1: + wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag)) model_path = [model_path, wdn_model_path] dni_weight = [params.denoise, 1 - params.denoise]