1
0
Fork 0

feat(api): add CodeFormer stage for chain pipelines

This commit is contained in:
Sean Sube 2023-02-05 08:37:47 -06:00
parent 35681efc1b
commit e059f11253
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 15 additions and 98 deletions

View File

@@ -27,12 +27,12 @@ package-upload:
lint-check: lint-check:
black --check --preview onnx_web black --check --preview onnx_web
isort --check-only --skip __init__.py --filter-files onnx_web isort --check-only --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web flake8 onnx_web
lint-fix: lint-fix:
black onnx_web black onnx_web
isort --skip __init__.py --filter-files onnx_web isort --skip __init__.py --filter-files onnx_web
flake8 --per-file-ignores="__init__.py:F401" onnx_web flake8 onnx_web
typecheck: typecheck:
mypy -m onnx_web.serve mypy -m onnx_web.serve

View File

@@ -1,6 +1,7 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint from .blend_inpaint import blend_inpaint
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk from .persist_disk import persist_disk
from .persist_s3 import persist_s3 from .persist_s3 import persist_s3

View File

@@ -1,14 +1,10 @@
from logging import getLogger from logging import getLogger
import torch from codeformer import CodeFormer
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from PIL import Image from PIL import Image
from torchvision.transforms.functional import normalize
from ..device_pool import JobContext from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams from ..params import ImageParams, StageParams
from ..utils import ServerContext from ..utils import ServerContext
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -18,27 +14,17 @@ pretrain_model_url = (
) )
device = "cpu" device = "cpu"
upscale = 2
def correct_codeformer( def correct_codeformer(
job: JobContext, _job: JobContext,
server: ServerContext, _server: ServerContext,
stage: StageParams, _stage: StageParams,
params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*,
upscale: UpscaleParams = None,
**kwargs, **kwargs,
) -> Image.Image: ) -> Image.Image:
ARCH_REGISTRY = {} pipe = CodeFormer(
bg_upsampler = None
face_upsampler = None
model = "TODO"
w = None
# ------------------ set up CodeFormer restorer -------------------
net = ARCH_REGISTRY.get("CodeFormer")(
dim_embd=512, dim_embd=512,
codebook_size=1024, codebook_size=1024,
n_head=8, n_head=8,
@@ -46,77 +32,4 @@ def correct_codeformer(
connect_list=["32", "64", "128", "256"], connect_list=["32", "64", "128", "256"],
).to(device) ).to(device)
# ckpt_path = 'weights/CodeFormer/codeformer.pth' return pipe(source_image)
ckpt_path = load_file_from_url(
url=pretrain_model_url,
model_dir="weights/CodeFormer",
progress=True,
file_name=None,
)
checkpoint = torch.load(ckpt_path)
checkpoint = checkpoint["params_ema"]
net.load_state_dict(checkpoint)
net.eval()
# ------------------ set up FaceRestoreHelper -------------------
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=model,
save_ext="png",
use_parse=True,
device=device,
)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
)
logger.info("detect %s faces", num_det_faces)
# align and warp each face
face_helper.align_warp_face()
# face restoration for each cropped face
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = net(cropped_face_t, w=w, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
logger.error("Failed inference for CodeFormer: %s", error)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype("uint8")
face_helper.add_restored_face(restored_face, cropped_face)
# upsample the background
if bg_upsampler is not None:
# Now only support RealESRGAN for upsampling background
bg_img = bg_upsampler.enhance(source_image, outscale=upscale.scale)[0]
else:
bg_img = None
# paste_back
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if face_upsampler is not None:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler
)
else:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=False
)
return restored_img

View File

@@ -34,6 +34,7 @@ from .chain import (
ChainPipeline, ChainPipeline,
blend_img2img, blend_img2img,
blend_inpaint, blend_inpaint,
correct_codeformer,
correct_gfpgan, correct_gfpgan,
persist_disk, persist_disk,
persist_s3, persist_s3,
@@ -121,6 +122,7 @@ mask_filters = {
chain_stages = { chain_stages = {
"blend-img2img": blend_img2img, "blend-img2img": blend_img2img,
"blend-inpaint": blend_inpaint, "blend-inpaint": blend_inpaint,
"correct-codeformer": correct_codeformer,
"correct-gfpgan": correct_gfpgan, "correct-gfpgan": correct_gfpgan,
"persist-disk": persist_disk, "persist-disk": persist_disk,
"persist-s3": persist_s3, "persist-s3": persist_s3,

View File

@@ -11,6 +11,7 @@ transformers
#### Upscaling and face correction #### Upscaling and face correction
basicsr basicsr
codeformer-perceptor
facexlib facexlib
gfpgan gfpgan
realesrgan realesrgan