From 120056f878593989d83fdf0c5f8f59bcf8f8d0bc Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 13:02:15 -0600 Subject: [PATCH] fix(api): get ESRGAN/GFPGAN paths from server context, clean up test scripts --- api/onnx_web/upscale.py | 59 ++++++++++++++++++++++++++++------------- api/test-diffusers.py | 23 ---------------- api/test-resrgan.py | 10 +++---- 3 files changed, 44 insertions(+), 48 deletions(-) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index c9f11c0d..9551319d 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -10,19 +10,21 @@ from typing import Any import numpy as np import torch -denoise_strength = 0.5 +from .utils import ( + ServerContext +) + +# TODO: these should all be params or config fp16 = False -netscale = 4 outscale = 4 pre_pad = 0 -tile = 0 tile_pad = 10 gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth' resrgan_name = 'RealESRGAN_x4plus' resrgan_url = [ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] -resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx') + class ONNXImage(): def __init__(self, source) -> None: @@ -55,9 +57,13 @@ class ONNXNet(): Provides the RRDBNet interface but using ONNX. ''' - def __init__(self) -> None: + def __init__(self, ctx: ServerContext) -> None: + ''' + TODO: get platform provider from request params + ''' + model_path = path.join(ctx.model_path, resrgan_name + '.onnx') self.session = InferenceSession( - resrgan_path, providers=['DmlExecutionProvider']) + model_path, providers=['DmlExecutionProvider']) def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name @@ -80,26 +86,37 @@ class ONNXNet(): return self -def make_resrgan(model_path, tile=0): - model_path = path.join(model_path, resrgan_name + '.pth') +class UpscaleParams(): + def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None: + self.denoise = denoise + self.scale = scale + self.faces = faces + self.platform = platform + + +def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): + model_path = path.join(ctx.model_path, resrgan_name + '.pth') if not path.isfile(model_path): for url in resrgan_url: model_path = load_file_from_url( url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None) - # model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - # num_block=23, num_grow_ch=32, scale=4) - model = ONNXNet() + # use ONNX acceleration, if available + if params.platform == 'onnx': + model = ONNXNet(ctx) + else: + model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, + num_block=23, num_grow_ch=32, scale=params.scale) dni_weight = None - if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1: + if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1: wdn_model_path = model_path.replace( 'realesr-general-x4v3', 'realesr-general-wdn-x4v3') model_path = [model_path, wdn_model_path] - dni_weight = [denoise_strength, 1 - denoise_strength] + dni_weight = [params.denoise, 1 - params.denoise] upsampler = RealESRGANer( - scale=netscale, + scale=params.scale, model_path=model_path, dni_weight=dni_weight, model=model, @@ -111,19 +128,23 @@ def make_resrgan(model_path, tile=0): return upsampler -def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image: +def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image: image = np.array(source_image) - upsampler = make_resrgan(model_path) + upsampler = make_resrgan(ctx.model_path) + # TODO: what is outscale for here? output, _ = upsampler.enhance(image, outscale=outscale) - if faces: - output = upscale_gfpgan(output, make_resrgan(model_path, 512)) + if params.faces: + output = upscale_gfpgan(ctx, output) return Image.fromarray(output, 'RGB') -def upscale_gfpgan(image, upsampler) -> Image: +def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image: + if upsampler is None: + upsampler = make_resrgan(ctx.model_path, 512) + face_enhancer = GFPGANer( model_path=gfpgan_url, upscale=outscale, diff --git a/api/test-diffusers.py b/api/test-diffusers.py index 555304df..5852eb6b 100644 --- a/api/test-diffusers.py +++ b/api/test-diffusers.py @@ -21,26 +21,3 @@ pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecution image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0] image.save(output) print('saved test image to %s' % output) - - -upscale = path.join('..', 'outputs', 'test-large.png') -esrgan = path.join('..', 'models', 'RealESRGAN_x4plus.onnx') - -print('upscaling test image...') -sess = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider']) - -in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED) - -in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB) -in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis] -in_mat = in_mat.astype(np.float32) -in_mat = in_mat/255 - -start_time = time.time() -input_name = sess.get_inputs()[0].name -output_name = sess.get_outputs()[0].name -in_mat = torch.tensor(in_mat) -out_mat = sess.run([output_name], {input_name: in_mat.cpu().numpy()})[0] -elapsed_time = time.time() - start_time -print(elapsed_time) -print('upscaled test image to %s') \ No newline at end of file diff --git a/api/test-resrgan.py b/api/test-resrgan.py index f62c63db..efebb80b 100644 --- a/api/test-resrgan.py +++ b/api/test-resrgan.py @@ -20,13 +20,13 @@ session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider']) in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED) +# convert to input format in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB) -print('shape before', np.shape(in_mat)) in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis] -print('shape after', np.shape(in_mat)) in_mat = in_mat.astype(np.float32) -in_mat = in_mat/255 +in_mat = in_mat /255 +# run network start_time = time.time() input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name @@ -37,11 +37,9 @@ out_mat = session.run([output_name], { elapsed_time = time.time() - start_time print(elapsed_time) -print('output shape', np.shape(out_mat)) +# convert back to original format out_mat = np.squeeze(out_mat, (0)) -print(np.shape(out_mat)) out_mat = np.transpose(out_mat, (2, 1, 0)) -print(out_mat, np.shape(out_mat)) out_mat = np.clip(out_mat, 0.0, 1.0) out_mat = out_mat * 255 out_mat = out_mat.astype(np.uint8)