fix(api): update codeformer patches for new lib
This commit is contained in:
parent
7ed30ee470
commit
3e1db707ac
|
@ -14,6 +14,9 @@ from .result import StageResult
|
|||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
CORRECTION_MODEL = "correction-codeformer.pth"
|
||||
DETECTION_MODEL = "retinaface_resnet50"
|
||||
|
||||
|
||||
class CorrectCodeformerStage(BaseStage):
|
||||
def run(
|
||||
|
@ -28,7 +31,10 @@ class CorrectCodeformerStage(BaseStage):
|
|||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> StageResult:
|
||||
# must be within the load function for patch to take effect
|
||||
# adapted from https://github.com/kadirnar/codeformer-pip/blob/main/codeformer/app.py and
|
||||
# https://pypi.org/project/codeformer-perceptor/
|
||||
|
||||
# import must be within the load function for patches to take effect
|
||||
# TODO: rewrite and remove
|
||||
from codeformer.basicsr.utils import img2tensor, tensor2img
|
||||
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
@ -45,17 +51,16 @@ class CorrectCodeformerStage(BaseStage):
|
|||
connect_list=["32", "64", "128", "256"],
|
||||
).to(device.torch_str())
|
||||
|
||||
ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
|
||||
ckpt_path = path.join(server.cache_path, CORRECTION_MODEL)
|
||||
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||
net.load_state_dict(checkpoint)
|
||||
net.eval()
|
||||
|
||||
det_model = "retinaface_resnet50"
|
||||
face_helper = FaceRestoreHelper(
|
||||
upscale.face_outscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model=det_model,
|
||||
det_model=DETECTION_MODEL,
|
||||
save_ext="png",
|
||||
use_parse=True,
|
||||
device=device.torch_str(),
|
||||
|
|
|
@ -150,11 +150,10 @@ def apply_patch_basicsr(server: ServerContext):
|
|||
def apply_patch_codeformer(server: ServerContext):
|
||||
logger.debug("patching CodeFormer module")
|
||||
try:
|
||||
import codeformer.basicsr.utils # download_util
|
||||
import codeformer.facelib.utils.misc
|
||||
import codeformer.basicsr.utils.download_util
|
||||
|
||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||
codeformer.facelib.utils.misc.load_file_from_url = partial(
|
||||
codeformer.basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
|
||||
codeformer.basicsr.utils.download_util.load_file_from_url = partial(
|
||||
patch_cache_path, server
|
||||
)
|
||||
except ImportError:
|
||||
|
|
Loading…
Reference in New Issue