From 7ed30ee47019ab83c7fa33babaca5331bbfb1396 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 18 Dec 2023 21:06:39 -0600 Subject: [PATCH] feat(api): switch to codeformer lib that works with torch 2.x --- api/onnx_web/chain/correct_codeformer.py | 91 ++++++++++++++++++++++-- api/onnx_web/server/hacks.py | 1 + api/requirements/base.txt | 3 +- 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 1169d4fb..3e0bcd77 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -1,7 +1,10 @@ from logging import getLogger +from os import path from typing import Optional +import torch from PIL import Image +from torchvision.transforms.functional import normalize from ..params import ImageParams, StageParams, UpscaleParams from ..server import ServerContext @@ -16,7 +19,7 @@ class CorrectCodeformerStage(BaseStage): def run( self, worker: WorkerContext, - _server: ServerContext, + server: ServerContext, _stage: StageParams, _params: ImageParams, sources: StageResult, @@ -27,10 +30,88 @@ class CorrectCodeformerStage(BaseStage): ) -> StageResult: # must be within the load function for patch to take effect # TODO: rewrite and remove - from codeformer import CodeFormer + from codeformer.basicsr.utils import img2tensor, tensor2img + from codeformer.basicsr.utils.registry import ARCH_REGISTRY + from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper upscale = upscale.with_args(**kwargs) - device = worker.get_device() - pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return StageResult(images=[pipe(source) for source in sources.as_image()]) + + net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(device.torch_str()) + + ckpt_path = path.join(server.cache_path, "correction-codeformer.pth") + checkpoint = torch.load(ckpt_path)["params_ema"] + net.load_state_dict(checkpoint) + net.eval() + + det_model = "retinaface_resnet50" + face_helper = FaceRestoreHelper( + upscale.face_outscale, + face_size=512, + crop_ratio=(1, 1), + det_model=det_model, + save_ext="png", + use_parse=True, + device=device.torch_str(), + ) + + results = [] + for img in sources.as_image(): + # clean all the intermediate results to process the next image + face_helper.clean_all() + face_helper.read_image(img) + + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=False, resize=640, eye_dist_threshold=5 + ) + logger.debug("detected %s faces", num_det_faces) + + # align and warp each face + face_helper.align_warp_face() + + # face restoration for each cropped face + for cropped_face in face_helper.cropped_faces: + # prepare data + cropped_face_t = img2tensor( + cropped_face / 255.0, bgr2rgb=True, float32=True + ) + normalize( + cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True + ) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device.torch_str()) + + try: + with torch.no_grad(): + output = net( + cropped_face_t, w=upscale.face_strength, adain=True + )[0] + restored_face = tensor2img( + output, rgb2bgr=True, min_max=(-1, 1) + ) + + del output + except Exception: + logger.exception("inference failed for CodeFormer") + restored_face = tensor2img( + cropped_face_t, rgb2bgr=True, min_max=(-1, 1) + ) + + restored_face = restored_face.astype("uint8") + face_helper.add_restored_face(restored_face, cropped_face) + + # paste_back + face_helper.get_inverse_affine(None) + + # paste each restored face to the input image + results.append( + face_helper.paste_faces_to_input_image(upsample_img=img, draw_box=False) + ) + + return StageResult.from_images(results) diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 052a5c42..78ff213d 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -150,6 +150,7 @@ def apply_patch_basicsr(server: ServerContext): def apply_patch_codeformer(server: ServerContext): logger.debug("patching CodeFormer module") try: + import codeformer.basicsr.utils # download_util import codeformer.facelib.utils.misc codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl diff --git a/api/requirements/base.txt b/api/requirements/base.txt index 59e112d1..4dfea9b8 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -18,11 +18,12 @@ onnx==1.15.0 optimum==1.16.0 safetensors==0.4.1 timm==0.9.12 +torchsde==0.2.6 transformers==4.36.1 #### Upscaling and face correction basicsr==1.4.2 -codeformer-perceptor==0.1.2 +# codeformer-perceptor==0.1.2 facexlib==0.2.5 gfpgan==1.3.8 realesrgan==0.3.0