fix(api): use SRVGG net for Real ESRGAN v3
This commit is contained in:
parent
ae5cf1fd28
commit
de4a3818a0
|
@ -5,6 +5,7 @@ import numpy as np
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext
|
||||||
from ..onnx import OnnxNet
|
from ..onnx import OnnxNet
|
||||||
|
@ -16,6 +17,8 @@ logger = getLogger(__name__)
|
||||||
last_pipeline_instance = None
|
last_pipeline_instance = None
|
||||||
last_pipeline_params = (None, None)
|
last_pipeline_params = (None, None)
|
||||||
|
|
||||||
|
x4_v3_tag = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
|
||||||
def load_resrgan(
|
def load_resrgan(
|
||||||
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
|
||||||
|
@ -33,8 +36,18 @@ def load_resrgan(
|
||||||
logger.info("reusing existing Real ESRGAN pipeline")
|
logger.info("reusing existing Real ESRGAN pipeline")
|
||||||
return last_pipeline_instance
|
return last_pipeline_instance
|
||||||
|
|
||||||
# use ONNX acceleration, if available
|
if x4_v3_tag in model_file:
|
||||||
if params.format == "onnx":
|
# the x4-v3 model needs a different network
|
||||||
|
model = SRVGGNetCompact(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_conv=32,
|
||||||
|
upscale=4,
|
||||||
|
act_type="prelu",
|
||||||
|
)
|
||||||
|
elif params.format == "onnx":
|
||||||
|
# use ONNX acceleration, if available
|
||||||
model = OnnxNet(
|
model = OnnxNet(
|
||||||
ctx, model_file, provider=device.provider, provider_options=device.options
|
ctx, model_file, provider=device.provider, provider_options=device.options
|
||||||
)
|
)
|
||||||
|
@ -47,13 +60,12 @@ def load_resrgan(
|
||||||
num_grow_ch=32,
|
num_grow_ch=32,
|
||||||
scale=params.scale,
|
scale=params.scale,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
raise Exception("unknown platform %s" % params.format)
|
raise Exception("unknown platform %s" % params.format)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if params.upscale_model == "real-esrgan-x4-v3" and params.denoise != 1:
|
if params.upscale_model == x4_v3_tag and params.denoise != 1:
|
||||||
wdn_model_path = model_path.replace(
|
wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag))
|
||||||
"real-esrgan-x4-v3", "real-esrgan-x4-v3-wdn"
|
|
||||||
)
|
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [params.denoise, 1 - params.denoise]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue