From e059f11253eebc638d71053735b39a61672f8f67 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 5 Feb 2023 08:37:47 -0600 Subject: [PATCH] feat(api): add CodeFormer stage for chain pipelines --- api/Makefile | 4 +- api/onnx_web/chain/__init__.py | 1 + api/onnx_web/chain/correct_codeformer.py | 103 ++--------------------- api/onnx_web/serve.py | 2 + api/requirements.txt | 3 +- 5 files changed, 15 insertions(+), 98 deletions(-) diff --git a/api/Makefile b/api/Makefile index 5f248cbd..9584d6f3 100644 --- a/api/Makefile +++ b/api/Makefile @@ -27,12 +27,12 @@ package-upload: lint-check: black --check --preview onnx_web isort --check-only --skip __init__.py --filter-files onnx_web - flake8 --per-file-ignores="__init__.py:F401" onnx_web + flake8 onnx_web lint-fix: black onnx_web isort --skip __init__.py --filter-files onnx_web - flake8 --per-file-ignores="__init__.py:F401" onnx_web + flake8 onnx_web typecheck: mypy -m onnx_web.serve diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index a6349eb6..0312d85e 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,6 +1,7 @@ from .base import ChainPipeline, PipelineStage, StageCallback, StageParams from .blend_img2img import blend_img2img from .blend_inpaint import blend_inpaint +from .correct_codeformer import correct_codeformer from .correct_gfpgan import correct_gfpgan from .persist_disk import persist_disk from .persist_s3 import persist_s3 diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 34fa34d9..c5dfd7ce 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,14 +1,10 @@ from logging import getLogger -import torch -from basicsr.utils import img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from facexlib.utils.face_restoration_helper import FaceRestoreHelper +from codeformer import CodeFormer from PIL import Image -from torchvision.transforms.functional import normalize from ..device_pool import JobContext -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import ImageParams, StageParams from ..utils import ServerContext logger = getLogger(__name__) @@ -18,27 +14,17 @@ pretrain_model_url = ( ) device = "cpu" -upscale = 2 def correct_codeformer( - job: JobContext, - server: ServerContext, - stage: StageParams, - params: ImageParams, + _job: JobContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, source_image: Image.Image, - *, - upscale: UpscaleParams = None, **kwargs, ) -> Image.Image: - ARCH_REGISTRY = {} - bg_upsampler = None - face_upsampler = None - model = "TODO" - w = None - - # ------------------ set up CodeFormer restorer ------------------- - net = ARCH_REGISTRY.get("CodeFormer")( + pipe = CodeFormer( dim_embd=512, codebook_size=1024, n_head=8, @@ -46,77 +32,4 @@ def correct_codeformer( connect_list=["32", "64", "128", "256"], ).to(device) - # ckpt_path = 'weights/CodeFormer/codeformer.pth' - ckpt_path = load_file_from_url( - url=pretrain_model_url, - model_dir="weights/CodeFormer", - progress=True, - file_name=None, - ) - checkpoint = torch.load(ckpt_path) - checkpoint = checkpoint["params_ema"] - net.load_state_dict(checkpoint) - net.eval() - - # ------------------ set up FaceRestoreHelper ------------------- - # large det_model: 'YOLOv5l', 'retinaface_resnet50' - # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' - - face_helper = FaceRestoreHelper( - upscale, - face_size=512, - crop_ratio=(1, 1), - det_model=model, - save_ext="png", - use_parse=True, - device=device, - ) - - # get face landmarks for each face - num_det_faces = face_helper.get_face_landmarks_5( - only_center_face=False, resize=640, eye_dist_threshold=5 - ) - logger.info("detect %s faces", num_det_faces) - # align and warp each face - face_helper.align_warp_face() - - # face restoration for each cropped face - for idx, cropped_face in enumerate(face_helper.cropped_faces): - # prepare data - cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(device) - - try: - with torch.no_grad(): - output = net(cropped_face_t, w=w, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - torch.cuda.empty_cache() - except Exception as error: - logger.error("Failed inference for CodeFormer: %s", error) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - - restored_face = restored_face.astype("uint8") - face_helper.add_restored_face(restored_face, cropped_face) - - # upsample the background - if bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = bg_upsampler.enhance(source_image, outscale=upscale.scale)[0] - else: - bg_img = None - - # paste_back - face_helper.get_inverse_affine(None) - # paste each restored face to the input image - if face_upsampler is not None: - restored_img = face_helper.paste_faces_to_input_image( - upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler - ) - else: - restored_img = face_helper.paste_faces_to_input_image( - upsample_img=bg_img, draw_box=False - ) - - return restored_img + return pipe(source_image) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index c361c203..4fbc7067 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -34,6 +34,7 @@ from .chain import ( ChainPipeline, blend_img2img, blend_inpaint, + correct_codeformer, correct_gfpgan, persist_disk, persist_s3, @@ -121,6 +122,7 @@ mask_filters = { chain_stages = { "blend-img2img": blend_img2img, "blend-inpaint": blend_inpaint, + "correct-codeformer": correct_codeformer, "correct-gfpgan": correct_gfpgan, "persist-disk": persist_disk, "persist-s3": persist_s3, diff --git a/api/requirements.txt b/api/requirements.txt index f99232cf..c023071f 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -11,6 +11,7 @@ transformers #### Upscaling and face correction basicsr +codeformer-perceptor facexlib gfpgan realesrgan @@ -20,4 +21,4 @@ boto3 flask flask-cors jsonschema -pyyaml \ No newline at end of file +pyyaml