From 03e06193ebd4349f07dd2ea916c3dd74b54f96aa Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 16 Jan 2023 21:36:50 -0600 Subject: [PATCH] attempt to fix face correction at various scales --- api/onnx_web/convert.py | 10 ++++++---- api/onnx_web/upscale.py | 3 +-- gui/src/strings.ts | 4 +++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index baabb532..c8e6fae9 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -24,8 +24,10 @@ sources: Dict[str, List[Tuple[str, str]]] = { 'stabilityai/stable-diffusion-2-inpainting'), ], 'gfpgan': [ - ('correction-gfpgan-v1-3', - 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'), + ('correction-gfpgan-v1-3-x2', + 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 2), + ('correction-gfpgan-v1-3-x4', + 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4), ], 'real_esrgan': [ ('upscaling-real-esrgan-x2-plus', @@ -92,7 +94,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int): @torch.no_grad() -def convert_gfpgan(name: str, url: str, opset: int): +def convert_gfpgan(name: str, url: str, scale: int, opset: int): dest_path = path.join(model_path, name + '.pth') dest_onnx = path.join(model_path, name + '.onnx') print('converting GFPGAN model: %s -> %s' % (name, dest_onnx)) @@ -109,7 +111,7 @@ def convert_gfpgan(name: str, url: str, opset: int): print('loading and training model') model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - num_block=23, num_grow_ch=32, scale=4) + num_block=23, num_grow_ch=32, scale=scale) torch_model = torch.load(dest_path, map_location=map_location) # TODO: make sure strict=False is safe here diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 5993424b..0248ce14 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -177,8 +177,7 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N return image if upsampler is None: - bg_params = params.rescale(params.outscale) - upsampler = make_resrgan(ctx, bg_params, tile=512) + upsampler = make_resrgan(ctx, params, tile=512) face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model)) diff --git a/gui/src/strings.ts b/gui/src/strings.ts index a6d86e5b..a768b0b2 100644 --- a/gui/src/strings.ts +++ b/gui/src/strings.ts @@ -11,7 +11,9 @@ export const MODEL_LABELS = { 'upscaling-real-esrgan-x4-plus': 'Real ESRGAN x4 Plus', 'upscaling-real-esrgan-x4-v3': 'Real ESRGAN x4 v3', // correction - 'correction-gfpgan-v1-3': 'GFPGAN v1.3', + 'correction-gfpgan-v1-3-x2': 'GFPGAN v1.3 x2', + 'correction-gfpgan-v1-3-x4': 'GFPGAN v1.3 x4', + }; export const PLATFORM_LABELS: Record = {