From 3e1db707ac79265546e8be4bbe432be9b8f8ba08 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 18 Dec 2023 21:12:08 -0600 Subject: [PATCH] fix(api): update codeformer patches for new lib --- api/onnx_web/chain/correct_codeformer.py | 13 +++++++++---- api/onnx_web/server/hacks.py | 7 +++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 3e0bcd77..b23a2fc0 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -14,6 +14,9 @@ from .result import StageResult logger = getLogger(__name__) +CORRECTION_MODEL = "correction-codeformer.pth" +DETECTION_MODEL = "retinaface_resnet50" + class CorrectCodeformerStage(BaseStage): def run( @@ -28,7 +31,10 @@ class CorrectCodeformerStage(BaseStage): upscale: UpscaleParams, **kwargs, ) -> StageResult: - # must be within the load function for patch to take effect + # adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and + # https://pypi.org/project/codeformer-perceptor/ + + # import must be within the load function for patches to take effect # TODO: rewrite and remove from codeformer.basicsr.utils import img2tensor, tensor2img from codeformer.basicsr.utils.registry import ARCH_REGISTRY @@ -45,17 +51,16 @@ class CorrectCodeformerStage(BaseStage): connect_list=["32", "64", "128", "256"], ).to(device.torch_str()) - ckpt_path = path.join(server.cache_path, "correction-codeformer.pth") + ckpt_path = path.join(server.cache_path, CORRECTION_MODEL) checkpoint = torch.load(ckpt_path)["params_ema"] net.load_state_dict(checkpoint) net.eval() - det_model = "retinaface_resnet50" face_helper = FaceRestoreHelper( upscale.face_outscale, face_size=512, crop_ratio=(1, 1), - det_model=det_model, + det_model=DETECTION_MODEL, save_ext="png", use_parse=True, device=device.torch_str(), diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py index 78ff213d..c2eff43b 100644 --- a/api/onnx_web/server/hacks.py +++ b/api/onnx_web/server/hacks.py @@ -150,11 +150,10 @@ def apply_patch_basicsr(server: ServerContext): def apply_patch_codeformer(server: ServerContext): logger.debug("patching CodeFormer module") try: - import codeformer.basicsr.utils # download_util - import codeformer.facelib.utils.misc + import codeformer.basicsr.utils.download_util - codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl - codeformer.facelib.utils.misc.load_file_from_url = partial( + codeformer.basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl + codeformer.basicsr.utils.download_util.load_file_from_url = partial( patch_cache_path, server ) except ImportError: