1
0
Fork 0
onnx-web/api/onnx_web/chain/correct_gfpgan.py

85 lines
2.2 KiB
Python

from logging import getLogger
from os import path
from typing import Optional
import numpy as np
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer
from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan
# Module-level logger, named after this module.
logger = getLogger(__name__)
# Single-slot cache written by load_gfpgan: the most recently built GFPGAN
# pipeline, or None until the first one has been loaded.
last_pipeline_instance = None
# Key recorded by load_gfpgan alongside the cached pipeline; it is compared on
# the next call to decide whether the cached instance can be reused.
last_pipeline_params = None
def load_gfpgan(
    ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
):
    """
    Return a GFPGAN pipeline for the given upscale parameters.

    The most recently built pipeline is kept in the module-level cache and is
    reused when both the model file and the output scale match the request.

    Args:
        ctx: server context; its ``model_path`` is the directory that holds the
            ``<correction_model>.pth`` weights.
        upscale: upscale parameters; ``correction_model`` names the weights and
            ``outscale`` is passed to GFPGAN as its upscale factor.
        upsampler: optional background upsampler. When omitted, one is loaded
            through ``load_resrgan`` using ``upscale`` rescaled to ``outscale``.

    Returns:
        The ``GFPGANer`` instance, either freshly built or taken from the cache.
    """
    global last_pipeline_instance
    global last_pipeline_params

    face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))

    # The pipeline is constructed from the model file *and* the output scale,
    # so both must match before the cached instance can be handed back.
    # NOTE(review): the background upsampler is not part of the key, so a
    # cached pipeline keeps the upsampler it was first built with.
    cache_key = (face_path, upscale.outscale)

    if last_pipeline_instance is not None and cache_key == last_pipeline_params:
        logger.info("reusing existing GFPGAN pipeline")
        return last_pipeline_instance

    # Only load the background upsampler on a cache miss; it is not needed when
    # the cached pipeline is returned.
    if upsampler is None:
        bg_upscale = upscale.rescale(upscale.outscale)
        upsampler = load_resrgan(ctx, bg_upscale)

    logger.debug("loading GFPGAN model from %s", face_path)

    # TODO: find a way to pass the ONNX model to underlying architectures
    gfpgan = GFPGANer(
        model_path=face_path,
        upscale=upscale.outscale,
        arch="clean",
        channel_multiplier=2,
        bg_upsampler=upsampler,
    )

    last_pipeline_instance = gfpgan
    last_pipeline_params = cache_key
    # Collect garbage after swapping the cached pipeline.
    run_gc()

    return gfpgan
def correct_gfpgan(
    _job: JobContext,
    server: ServerContext,
    _stage: StageParams,
    _params: ImageParams,
    source_image: Image.Image,
    *,
    upscale: UpscaleParams,
    upsampler: Optional[RealESRGANer] = None,
    **kwargs,
) -> Image.Image:
    """
    Run GFPGAN face correction over an image.

    Args:
        _job: job context (unused here).
        server: server context, forwarded to ``load_gfpgan``.
        _stage: stage parameters (unused here).
        _params: image parameters (unused here).
        source_image: the image to correct.
        upscale: upscale parameters; ``correction_model`` selects the GFPGAN
            weights and ``face_strength`` is passed as the enhancement weight.
        upsampler: optional background upsampler, forwarded to ``load_gfpgan``.
        **kwargs: accepted and ignored.

    Returns:
        The corrected image, or ``source_image`` unchanged when no correction
        model is configured.
    """
    if upscale.correction_model is None:
        # Logger.warn is a deprecated alias; warning() is the supported method.
        logger.warning("no face model given, skipping")
        return source_image

    logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
    gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)

    # NOTE(review): the array keeps PIL's channel order; GFPGAN's reference
    # scripts feed it OpenCV (BGR) arrays - confirm this is intended.
    output = np.array(source_image)
    # enhance() returns three values; only the last one is used here.
    _, _, output = gfpgan.enhance(
        output,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
        weight=upscale.face_strength,
    )

    output = Image.fromarray(output, "RGB")
    return output