feat(api): add CodeFormer stage for chain pipelines
This commit is contained in:
parent
35681efc1b
commit
e059f11253
|
@ -27,12 +27,12 @@ package-upload:
|
||||||
lint-check:
|
lint-check:
|
||||||
black --check --preview onnx_web
|
black --check --preview onnx_web
|
||||||
isort --check-only --skip __init__.py --filter-files onnx_web
|
isort --check-only --skip __init__.py --filter-files onnx_web
|
||||||
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
flake8 onnx_web
|
||||||
|
|
||||||
lint-fix:
|
lint-fix:
|
||||||
black onnx_web
|
black onnx_web
|
||||||
isort --skip __init__.py --filter-files onnx_web
|
isort --skip __init__.py --filter-files onnx_web
|
||||||
flake8 --per-file-ignores="__init__.py:F401" onnx_web
|
flake8 onnx_web
|
||||||
|
|
||||||
typecheck:
|
typecheck:
|
||||||
mypy -m onnx_web.serve
|
mypy -m onnx_web.serve
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||||
from .blend_img2img import blend_img2img
|
from .blend_img2img import blend_img2img
|
||||||
from .blend_inpaint import blend_inpaint
|
from .blend_inpaint import blend_inpaint
|
||||||
|
from .correct_codeformer import correct_codeformer
|
||||||
from .correct_gfpgan import correct_gfpgan
|
from .correct_gfpgan import correct_gfpgan
|
||||||
from .persist_disk import persist_disk
|
from .persist_disk import persist_disk
|
||||||
from .persist_s3 import persist_s3
|
from .persist_s3 import persist_s3
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
|
||||||
import torch
|
from codeformer import CodeFormer
|
||||||
from basicsr.utils import img2tensor, tensor2img
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms.functional import normalize
|
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..utils import ServerContext
|
from ..utils import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -18,27 +14,17 @@ pretrain_model_url = (
|
||||||
)
|
)
|
||||||
|
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
upscale = 2
|
|
||||||
|
|
||||||
|
|
||||||
def correct_codeformer(
|
def correct_codeformer(
|
||||||
job: JobContext,
|
_job: JobContext,
|
||||||
server: ServerContext,
|
_server: ServerContext,
|
||||||
stage: StageParams,
|
_stage: StageParams,
|
||||||
params: ImageParams,
|
_params: ImageParams,
|
||||||
source_image: Image.Image,
|
source_image: Image.Image,
|
||||||
*,
|
|
||||||
upscale: UpscaleParams = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
ARCH_REGISTRY = {}
|
pipe = CodeFormer(
|
||||||
bg_upsampler = None
|
|
||||||
face_upsampler = None
|
|
||||||
model = "TODO"
|
|
||||||
w = None
|
|
||||||
|
|
||||||
# ------------------ set up CodeFormer restorer -------------------
|
|
||||||
net = ARCH_REGISTRY.get("CodeFormer")(
|
|
||||||
dim_embd=512,
|
dim_embd=512,
|
||||||
codebook_size=1024,
|
codebook_size=1024,
|
||||||
n_head=8,
|
n_head=8,
|
||||||
|
@ -46,77 +32,4 @@ def correct_codeformer(
|
||||||
connect_list=["32", "64", "128", "256"],
|
connect_list=["32", "64", "128", "256"],
|
||||||
).to(device)
|
).to(device)
|
||||||
|
|
||||||
# ckpt_path = 'weights/CodeFormer/codeformer.pth'
|
return pipe(source_image)
|
||||||
ckpt_path = load_file_from_url(
|
|
||||||
url=pretrain_model_url,
|
|
||||||
model_dir="weights/CodeFormer",
|
|
||||||
progress=True,
|
|
||||||
file_name=None,
|
|
||||||
)
|
|
||||||
checkpoint = torch.load(ckpt_path)
|
|
||||||
checkpoint = checkpoint["params_ema"]
|
|
||||||
net.load_state_dict(checkpoint)
|
|
||||||
net.eval()
|
|
||||||
|
|
||||||
# ------------------ set up FaceRestoreHelper -------------------
|
|
||||||
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
|
|
||||||
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
|
|
||||||
|
|
||||||
face_helper = FaceRestoreHelper(
|
|
||||||
upscale,
|
|
||||||
face_size=512,
|
|
||||||
crop_ratio=(1, 1),
|
|
||||||
det_model=model,
|
|
||||||
save_ext="png",
|
|
||||||
use_parse=True,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get face landmarks for each face
|
|
||||||
num_det_faces = face_helper.get_face_landmarks_5(
|
|
||||||
only_center_face=False, resize=640, eye_dist_threshold=5
|
|
||||||
)
|
|
||||||
logger.info("detect %s faces", num_det_faces)
|
|
||||||
# align and warp each face
|
|
||||||
face_helper.align_warp_face()
|
|
||||||
|
|
||||||
# face restoration for each cropped face
|
|
||||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
|
||||||
# prepare data
|
|
||||||
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
|
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
|
||||||
|
|
||||||
try:
|
|
||||||
with torch.no_grad():
|
|
||||||
output = net(cropped_face_t, w=w, adain=True)[0]
|
|
||||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
|
||||||
del output
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
except Exception as error:
|
|
||||||
logger.error("Failed inference for CodeFormer: %s", error)
|
|
||||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
|
||||||
face_helper.add_restored_face(restored_face, cropped_face)
|
|
||||||
|
|
||||||
# upsample the background
|
|
||||||
if bg_upsampler is not None:
|
|
||||||
# Now only support RealESRGAN for upsampling background
|
|
||||||
bg_img = bg_upsampler.enhance(source_image, outscale=upscale.scale)[0]
|
|
||||||
else:
|
|
||||||
bg_img = None
|
|
||||||
|
|
||||||
# paste_back
|
|
||||||
face_helper.get_inverse_affine(None)
|
|
||||||
# paste each restored face to the input image
|
|
||||||
if face_upsampler is not None:
|
|
||||||
restored_img = face_helper.paste_faces_to_input_image(
|
|
||||||
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
restored_img = face_helper.paste_faces_to_input_image(
|
|
||||||
upsample_img=bg_img, draw_box=False
|
|
||||||
)
|
|
||||||
|
|
||||||
return restored_img
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ from .chain import (
|
||||||
ChainPipeline,
|
ChainPipeline,
|
||||||
blend_img2img,
|
blend_img2img,
|
||||||
blend_inpaint,
|
blend_inpaint,
|
||||||
|
correct_codeformer,
|
||||||
correct_gfpgan,
|
correct_gfpgan,
|
||||||
persist_disk,
|
persist_disk,
|
||||||
persist_s3,
|
persist_s3,
|
||||||
|
@ -121,6 +122,7 @@ mask_filters = {
|
||||||
chain_stages = {
|
chain_stages = {
|
||||||
"blend-img2img": blend_img2img,
|
"blend-img2img": blend_img2img,
|
||||||
"blend-inpaint": blend_inpaint,
|
"blend-inpaint": blend_inpaint,
|
||||||
|
"correct-codeformer": correct_codeformer,
|
||||||
"correct-gfpgan": correct_gfpgan,
|
"correct-gfpgan": correct_gfpgan,
|
||||||
"persist-disk": persist_disk,
|
"persist-disk": persist_disk,
|
||||||
"persist-s3": persist_s3,
|
"persist-s3": persist_s3,
|
||||||
|
|
|
@ -11,6 +11,7 @@ transformers
|
||||||
|
|
||||||
#### Upscaling and face correction
|
#### Upscaling and face correction
|
||||||
basicsr
|
basicsr
|
||||||
|
codeformer-perceptor
|
||||||
facexlib
|
facexlib
|
||||||
gfpgan
|
gfpgan
|
||||||
realesrgan
|
realesrgan
|
||||||
|
|
Loading…
Reference in New Issue