From 091c4e6109e5e276ac59228df0c091fc9e583477 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 14:17:50 -0600 Subject: [PATCH] fix(api): pass upscale params when creating RESRGAN --- api/onnx_web/convert.py | 5 +++-- api/onnx_web/pipeline.py | 6 +++--- api/onnx_web/upscale.py | 10 +++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 8720a597..b3e2f922 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -3,9 +3,9 @@ from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils.download_util import load_file_from_url from os import path, environ from sys import exit +from torch.onnx import export import torch -import torch.onnx from .upscale import ( gfpgan_url, @@ -17,6 +17,7 @@ model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models')) +@torch.no_grad() def convert_real_esrgan(): dest_path = path.join(model_path, resrgan_name + '.pth') print('converting Real ESRGAN into %s' % dest_path) @@ -44,7 +45,7 @@ def convert_real_esrgan(): with torch.no_grad(): dest_onnx = path.join(model_path, resrgan_name + '.onnx') print('exporting Real ESRGAN model to %s' % dest_onnx) - torch.onnx.export( + export( model, rng, dest_onnx, diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py index be8eb7dc..ba0da1cc 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/pipeline.py @@ -97,7 +97,7 @@ def run_txt2img_pipeline( ).images[0] if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, image, upscale) + image = upscale_resrgan(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) @@ -129,7 +129,7 @@ def run_img2img_pipeline( ).images[0] if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, image, upscale) + image = upscale_resrgan(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) @@ -182,7 +182,7 @@ def run_inpaint_pipeline( ).images[0] if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, image, upscale) + image = upscale_resrgan(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 23fb5d77..bf87a449 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -128,26 +128,26 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): return upsampler -def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image: +def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image: print('upscaling image with Real ESRGAN', params) image = np.array(source_image) - upsampler = make_resrgan(ctx.model_path) + upsampler = make_resrgan(ctx, params) # TODO: what is outscale for here? output, _ = upsampler.enhance(image, outscale=outscale) if params.faces: - output = upscale_gfpgan(ctx, output) + output = upscale_gfpgan(ctx, params, output) return Image.fromarray(output, 'RGB') -def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image: +def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image: print('correcting faces with GFPGAN') if upsampler is None: - upsampler = make_resrgan(ctx.model_path, 512) + upsampler = make_resrgan(ctx, params, tile=512) face_enhancer = GFPGANer( model_path=gfpgan_url,