
feat(api): switch to codeformer lib that works with torch 2.x

Sean Sube 2023-12-18 21:06:39 -06:00
parent 6a6a3f04bc
commit 7ed30ee470
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 89 additions and 6 deletions
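In short, per the commit message, the CodeFormer correction stage stops calling the one-shot pipeline wrapper from codeformer-perceptor (now commented out in the requirements) and instead builds the network itself from the library's ARCH_REGISTRY, loads the params_ema weights from the server cache, and drives detection, alignment, and paste-back through FaceRestoreHelper. A minimal sketch of the new loading step, reusing the constructor arguments from the diff below; the device string and checkpoint path are placeholders, since the stage takes them from the worker and server contexts:

import torch
from codeformer.basicsr.utils.registry import ARCH_REGISTRY

# placeholders for illustration; the stage uses worker.get_device().torch_str()
# and path.join(server.cache_path, "correction-codeformer.pth")
device = "cpu"
ckpt_path = "correction-codeformer.pth"

# build the CodeFormer architecture from the registry (assumes importing the
# codeformer package has registered the arch, as in the stage's own imports)
net = ARCH_REGISTRY.get("CodeFormer")(
    dim_embd=512,
    codebook_size=1024,
    n_head=8,
    n_layers=9,
    connect_list=["32", "64", "128", "256"],
).to(device)
net.load_state_dict(torch.load(ckpt_path)["params_ema"])
net.eval()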

View File

@@ -1,7 +1,10 @@
from logging import getLogger
from os import path
from typing import Optional
import torch
from PIL import Image
from torchvision.transforms.functional import normalize
from ..params import ImageParams, StageParams, UpscaleParams
from ..server import ServerContext
@@ -16,7 +19,7 @@ class CorrectCodeformerStage(BaseStage):
    def run(
        self,
        worker: WorkerContext,
        _server: ServerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
@@ -27,10 +30,88 @@ class CorrectCodeformerStage(BaseStage):
    ) -> StageResult:
        # must be within the load function for patch to take effect
        # TODO: rewrite and remove
        from codeformer import CodeFormer
        from codeformer.basicsr.utils import img2tensor, tensor2img
        from codeformer.basicsr.utils.registry import ARCH_REGISTRY
        from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper

        upscale = upscale.with_args(**kwargs)
        device = worker.get_device()

        pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
        return StageResult(images=[pipe(source) for source in sources.as_image()])

        net = ARCH_REGISTRY.get("CodeFormer")(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=["32", "64", "128", "256"],
        ).to(device.torch_str())

        ckpt_path = path.join(server.cache_path, "correction-codeformer.pth")
        checkpoint = torch.load(ckpt_path)["params_ema"]
        net.load_state_dict(checkpoint)
        net.eval()

        det_model = "retinaface_resnet50"
        face_helper = FaceRestoreHelper(
            upscale.face_outscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model=det_model,
            save_ext="png",
            use_parse=True,
            device=device.torch_str(),
        )

        results = []
        for img in sources.as_image():
            # clean all the intermediate results to process the next image
            face_helper.clean_all()
            face_helper.read_image(img)

            # get face landmarks for each face
            num_det_faces = face_helper.get_face_landmarks_5(
                only_center_face=False, resize=640, eye_dist_threshold=5
            )
            logger.debug("detected %s faces", num_det_faces)

            # align and warp each face
            face_helper.align_warp_face()

            # face restoration for each cropped face
            for cropped_face in face_helper.cropped_faces:
                # prepare data
                cropped_face_t = img2tensor(
                    cropped_face / 255.0, bgr2rgb=True, float32=True
                )
                normalize(
                    cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True
                )
                cropped_face_t = cropped_face_t.unsqueeze(0).to(device.torch_str())

                try:
                    with torch.no_grad():
                        output = net(
                            cropped_face_t, w=upscale.face_strength, adain=True
                        )[0]
                        restored_face = tensor2img(
                            output, rgb2bgr=True, min_max=(-1, 1)
                        )
                    del output
                except Exception:
                    logger.exception("inference failed for CodeFormer")
                    restored_face = tensor2img(
                        cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
                    )

                restored_face = restored_face.astype("uint8")
                face_helper.add_restored_face(restored_face, cropped_face)

            # paste_back
            face_helper.get_inverse_affine(None)
            # paste each restored face to the input image
            results.append(
                face_helper.paste_faces_to_input_image(upsample_img=img, draw_box=False)
            )

        return StageResult.from_images(results)

View File

@@ -150,6 +150,7 @@ def apply_patch_basicsr(server: ServerContext):
def apply_patch_codeformer(server: ServerContext):
    logger.debug("patching CodeFormer module")
    try:
        import codeformer.basicsr.utils  # download_util
        import codeformer.facelib.utils.misc

        codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl

View File

@@ -18,11 +18,12 @@ onnx==1.15.0
optimum==1.16.0
safetensors==0.4.1
timm==0.9.12
torchsde==0.2.6
transformers==4.36.1

#### Upscaling and face correction
basicsr==1.4.2
codeformer-perceptor==0.1.2
# codeformer-perceptor==0.1.2
facexlib==0.2.5
gfpgan==1.3.8
realesrgan==0.3.0