fix(api): update codeformer patches for new lib
This commit is contained in:
parent
7ed30ee470
commit
3e1db707ac
|
@ -14,6 +14,9 @@ from .result import StageResult
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
CORRECTION_MODEL = "correction-codeformer.pth"
|
||||||
|
DETECTION_MODEL = "retinaface_resnet50"
|
||||||
|
|
||||||
|
|
||||||
class CorrectCodeformerStage(BaseStage):
|
class CorrectCodeformerStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
|
@ -28,7 +31,10 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
# must be within the load function for patch to take effect
|
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and
|
||||||
|
# https://pypi.org/project/codeformer-perceptor/
|
||||||
|
|
||||||
|
# import must be within the load function for patches to take effect
|
||||||
# TODO: rewrite and remove
|
# TODO: rewrite and remove
|
||||||
from codeformer.basicsr.utils import img2tensor, tensor2img
|
from codeformer.basicsr.utils import img2tensor, tensor2img
|
||||||
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
@ -45,17 +51,16 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
connect_list=["32", "64", "128", "256"],
|
connect_list=["32", "64", "128", "256"],
|
||||||
).to(device.torch_str())
|
).to(device.torch_str())
|
||||||
|
|
||||||
ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
|
ckpt_path = path.join(server.cache_path, CORRECTION_MODEL)
|
||||||
checkpoint = torch.load(ckpt_path)["params_ema"]
|
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||||
net.load_state_dict(checkpoint)
|
net.load_state_dict(checkpoint)
|
||||||
net.eval()
|
net.eval()
|
||||||
|
|
||||||
det_model = "retinaface_resnet50"
|
|
||||||
face_helper = FaceRestoreHelper(
|
face_helper = FaceRestoreHelper(
|
||||||
upscale.face_outscale,
|
upscale.face_outscale,
|
||||||
face_size=512,
|
face_size=512,
|
||||||
crop_ratio=(1, 1),
|
crop_ratio=(1, 1),
|
||||||
det_model=det_model,
|
det_model=DETECTION_MODEL,
|
||||||
save_ext="png",
|
save_ext="png",
|
||||||
use_parse=True,
|
use_parse=True,
|
||||||
device=device.torch_str(),
|
device=device.torch_str(),
|
||||||
|
|
|
@ -150,11 +150,10 @@ def apply_patch_basicsr(server: ServerContext):
|
||||||
def apply_patch_codeformer(server: ServerContext):
|
def apply_patch_codeformer(server: ServerContext):
|
||||||
logger.debug("patching CodeFormer module")
|
logger.debug("patching CodeFormer module")
|
||||||
try:
|
try:
|
||||||
import codeformer.basicsr.utils # download_util
|
import codeformer.basicsr.utils.download_util
|
||||||
import codeformer.facelib.utils.misc
|
|
||||||
|
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||||
codeformer.facelib.utils.misc.load_file_from_url = partial(
|
codeformer.basicsr.utils.download_util.load_file_from_url = partial(
|
||||||
patch_cache_path, server
|
patch_cache_path, server
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
Loading…
Reference in New Issue