feat(api): switch to codeformer lib that works with torch 2.x
This commit is contained in:
parent
6a6a3f04bc
commit
7ed30ee470
|
@ -1,7 +1,10 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchvision.transforms.functional import normalize
|
||||||
|
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams, UpscaleParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
@ -16,7 +19,7 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
_server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
sources: StageResult,
|
sources: StageResult,
|
||||||
|
@ -27,10 +30,88 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
# must be within the load function for patch to take effect
|
# must be within the load function for patch to take effect
|
||||||
# TODO: rewrite and remove
|
# TODO: rewrite and remove
|
||||||
from codeformer import CodeFormer
|
from codeformer.basicsr.utils import img2tensor, tensor2img
|
||||||
|
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
|
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
|
|
||||||
return StageResult(images=[pipe(source) for source in sources.as_image()])
|
net = ARCH_REGISTRY.get("CodeFormer")(
|
||||||
|
dim_embd=512,
|
||||||
|
codebook_size=1024,
|
||||||
|
n_head=8,
|
||||||
|
n_layers=9,
|
||||||
|
connect_list=["32", "64", "128", "256"],
|
||||||
|
).to(device.torch_str())
|
||||||
|
|
||||||
|
ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
|
||||||
|
checkpoint = torch.load(ckpt_path)["params_ema"]
|
||||||
|
net.load_state_dict(checkpoint)
|
||||||
|
net.eval()
|
||||||
|
|
||||||
|
det_model = "retinaface_resnet50"
|
||||||
|
face_helper = FaceRestoreHelper(
|
||||||
|
upscale.face_outscale,
|
||||||
|
face_size=512,
|
||||||
|
crop_ratio=(1, 1),
|
||||||
|
det_model=det_model,
|
||||||
|
save_ext="png",
|
||||||
|
use_parse=True,
|
||||||
|
device=device.torch_str(),
|
||||||
|
)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for img in sources.as_image():
|
||||||
|
# clean all the intermediate results to process the next image
|
||||||
|
face_helper.clean_all()
|
||||||
|
face_helper.read_image(img)
|
||||||
|
|
||||||
|
# get face landmarks for each face
|
||||||
|
num_det_faces = face_helper.get_face_landmarks_5(
|
||||||
|
only_center_face=False, resize=640, eye_dist_threshold=5
|
||||||
|
)
|
||||||
|
logger.debug("detected %s faces", num_det_faces)
|
||||||
|
|
||||||
|
# align and warp each face
|
||||||
|
face_helper.align_warp_face()
|
||||||
|
|
||||||
|
# face restoration for each cropped face
|
||||||
|
for cropped_face in face_helper.cropped_faces:
|
||||||
|
# prepare data
|
||||||
|
cropped_face_t = img2tensor(
|
||||||
|
cropped_face / 255.0, bgr2rgb=True, float32=True
|
||||||
|
)
|
||||||
|
normalize(
|
||||||
|
cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
|
||||||
|
)
|
||||||
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(device.torch_str())
|
||||||
|
|
||||||
|
try:
|
||||||
|
with torch.no_grad():
|
||||||
|
output = net(
|
||||||
|
cropped_face_t, w=upscale.face_strength, adain=True
|
||||||
|
)[0]
|
||||||
|
restored_face = tensor2img(
|
||||||
|
output, rgb2bgr=True, min_max=(-1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
del output
|
||||||
|
except Exception:
|
||||||
|
logger.exception("inference failed for CodeFormer")
|
||||||
|
restored_face = tensor2img(
|
||||||
|
cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
restored_face = restored_face.astype("uint8")
|
||||||
|
face_helper.add_restored_face(restored_face, cropped_face)
|
||||||
|
|
||||||
|
# paste_back
|
||||||
|
face_helper.get_inverse_affine(None)
|
||||||
|
|
||||||
|
# paste each restored face to the input image
|
||||||
|
results.append(
|
||||||
|
face_helper.paste_faces_to_input_image(upsample_img=img, draw_box=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
return StageResult.from_images(results)
|
||||||
|
|
|
@ -150,6 +150,7 @@ def apply_patch_basicsr(server: ServerContext):
|
||||||
def apply_patch_codeformer(server: ServerContext):
|
def apply_patch_codeformer(server: ServerContext):
|
||||||
logger.debug("patching CodeFormer module")
|
logger.debug("patching CodeFormer module")
|
||||||
try:
|
try:
|
||||||
|
import codeformer.basicsr.utils # download_util
|
||||||
import codeformer.facelib.utils.misc
|
import codeformer.facelib.utils.misc
|
||||||
|
|
||||||
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
|
||||||
|
|
|
@ -18,11 +18,12 @@ onnx==1.15.0
|
||||||
optimum==1.16.0
|
optimum==1.16.0
|
||||||
safetensors==0.4.1
|
safetensors==0.4.1
|
||||||
timm==0.9.12
|
timm==0.9.12
|
||||||
|
torchsde==0.2.6
|
||||||
transformers==4.36.1
|
transformers==4.36.1
|
||||||
|
|
||||||
#### Upscaling and face correction
|
#### Upscaling and face correction
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
codeformer-perceptor==0.1.2
|
# codeformer-perceptor==0.1.2
|
||||||
facexlib==0.2.5
|
facexlib==0.2.5
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
realesrgan==0.3.0
|
realesrgan==0.3.0
|
||||||
|
|
Loading…
Reference in New Issue