1
0
Fork 0

fix(api): use SRVGG net for Real ESRGAN v3

This commit is contained in:
Sean Sube 2023-02-06 21:36:20 -06:00
parent ae5cf1fd28
commit de4a3818a0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 18 additions and 6 deletions

View File

@ -5,6 +5,7 @@ import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from ..device_pool import JobContext from ..device_pool import JobContext
from ..onnx import OnnxNet from ..onnx import OnnxNet
@ -16,6 +17,8 @@ logger = getLogger(__name__)
last_pipeline_instance = None last_pipeline_instance = None
last_pipeline_params = (None, None) last_pipeline_params = (None, None)
x4_v3_tag = "real-esrgan-x4-v3"
def load_resrgan( def load_resrgan(
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0 ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
@ -33,8 +36,18 @@ def load_resrgan(
logger.info("reusing existing Real ESRGAN pipeline") logger.info("reusing existing Real ESRGAN pipeline")
return last_pipeline_instance return last_pipeline_instance
# use ONNX acceleration, if available if x4_v3_tag in model_file:
if params.format == "onnx": # the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=4,
act_type="prelu",
)
elif params.format == "onnx":
# use ONNX acceleration, if available
model = OnnxNet( model = OnnxNet(
ctx, model_file, provider=device.provider, provider_options=device.options ctx, model_file, provider=device.provider, provider_options=device.options
) )
@ -47,13 +60,12 @@ def load_resrgan(
num_grow_ch=32, num_grow_ch=32,
scale=params.scale, scale=params.scale,
) )
else:
raise Exception("unknown platform %s" % params.format) raise Exception("unknown platform %s" % params.format)
dni_weight = None dni_weight = None
if params.upscale_model == "real-esrgan-x4-v3" and params.denoise != 1: if params.upscale_model == x4_v3_tag and params.denoise != 1:
wdn_model_path = model_path.replace( wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag))
"real-esrgan-x4-v3", "real-esrgan-x4-v3-wdn"
)
model_path = [model_path, wdn_model_path] model_path = [model_path, wdn_model_path]
dni_weight = [params.denoise, 1 - params.denoise] dni_weight = [params.denoise, 1 - params.denoise]