1
0
Fork 0

fix(api): update codeformer patches for new lib

This commit is contained in:
Sean Sube 2023-12-18 21:12:08 -06:00
parent 7ed30ee470
commit 3e1db707ac
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 12 additions and 8 deletions

View File

@ -14,6 +14,9 @@ from .result import StageResult
logger = getLogger(__name__)
CORRECTION_MODEL = "correction-codeformer.pth"
DETECTION_MODEL = "retinaface_resnet50"
class CorrectCodeformerStage(BaseStage):
def run(
@ -28,7 +31,10 @@ class CorrectCodeformerStage(BaseStage):
upscale: UpscaleParams,
**kwargs,
) -> StageResult:
# must be within the load function for patch to take effect
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and
# https://pypi.org/project/codeformer-perceptor/
# import must be within the load function for patches to take effect
# TODO: rewrite and remove
from codeformer.basicsr.utils import img2tensor, tensor2img
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
@ -45,17 +51,16 @@ class CorrectCodeformerStage(BaseStage):
connect_list=["32", "64", "128", "256"],
).to(device.torch_str())
ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
ckpt_path = path.join(server.cache_path, CORRECTION_MODEL)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()
det_model = "retinaface_resnet50"
face_helper = FaceRestoreHelper(
upscale.face_outscale,
face_size=512,
crop_ratio=(1, 1),
det_model=det_model,
det_model=DETECTION_MODEL,
save_ext="png",
use_parse=True,
device=device.torch_str(),

View File

@ -150,11 +150,10 @@ def apply_patch_basicsr(server: ServerContext):
def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
try:
import codeformer.basicsr.utils # download_util
import codeformer.facelib.utils.misc
import codeformer.basicsr.utils.download_util
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(
codeformer.basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
codeformer.basicsr.utils.download_util.load_file_from_url = partial(
patch_cache_path, server
)
except ImportError: