diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 2adaa866..8483bcc9 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -273,8 +273,7 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf negative_prompt=negative_prompt, num_inference_steps=steps, ).images[0] - - image = upscale_resrgan(image) + image = upscale_resrgan(image, model_path) image.save(output) print('saved txt2img output: %s' % (output)) @@ -295,6 +294,7 @@ def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf num_inference_steps=steps, strength=strength, ).images[0] + image = upscale_resrgan(image, model_path) image.save(output) print('saved img2img output: %s' % (output)) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 17614dd1..b9a71d34 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -11,7 +11,7 @@ denoise_strength = 0.5 gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth' resrgan_url = [ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] -fp32 = True +fp16 = False model_name = 'RealESRGAN_x4plus' netscale = 4 outscale = 4 @@ -20,13 +20,12 @@ tile = 0 tile_pad = 10 -def upscale_resrgan(source_image: Image, faces=True) -> Image: - model_path = path.join('weights', model_name + '.pth') +def make_resrgan(model_path): + model_path = path.join(model_path, model_name + '.pth') if not path.isfile(model_path): - ROOT_DIR = path.dirname(path.abspath(__file__)) for url in resrgan_url: model_path = load_file_from_url( - url=url, model_dir=path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) + url=url, model_dir=path.join(model_path, model_name), progress=True, file_name=None) model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) @@ -46,13 +45,19 @@ def upscale_resrgan(source_image: Image, faces=True) -> Image: tile=tile, tile_pad=tile_pad, pre_pad=pre_pad, - half=fp32) + half=fp16) + return upsampler + + +def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image: image = np.array(source_image) + upsampler = make_resrgan(model_path) + output, _ = upsampler.enhance(image, outscale=outscale) if faces: - output = upscale_gfpgan(output, upsampler) + output = upscale_gfpgan(output, upsampler) return Image.fromarray(output, 'RGB') @@ -65,6 +70,7 @@ def upscale_gfpgan(image, upsampler) -> Image: channel_multiplier=2, bg_upsampler=upsampler) - _, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True) + _, _, output = face_enhancer.enhance( + image, has_aligned=False, only_center_face=False, paste_back=True) return output