from logging import getLogger from os import path import numpy as np from PIL import Image from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..server import JobContext, ServerContext from ..utils import run_gc logger = getLogger(__name__) def load_gfpgan( server: ServerContext, _stage: StageParams, upscale: UpscaleParams, device: DeviceParams, ): # must be within the load function for patch to take effect from gfpgan import GFPGANer face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model)) cache_key = (face_path,) cache_pipe = server.cache.get("gfpgan", cache_key) if cache_pipe is not None: logger.info("reusing existing GFPGAN pipeline") return cache_pipe logger.debug("loading GFPGAN model from %s", face_path) # TODO: find a way to pass the ONNX model to underlying architectures gfpgan = GFPGANer( arch="clean", bg_upsampler=None, channel_multiplier=2, model_path=face_path, upscale=upscale.face_outscale, ) server.cache.set("gfpgan", cache_key, gfpgan) run_gc([device]) return gfpgan def correct_gfpgan( job: JobContext, server: ServerContext, stage: StageParams, _params: ImageParams, source: Image.Image, *, upscale: UpscaleParams, stage_source: Image.Image, **kwargs, ) -> Image.Image: upscale = upscale.with_args(**kwargs) source = stage_source or source if upscale.correction_model is None: logger.warn("no face model given, skipping") return source logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) device = job.get_device() gfpgan = load_gfpgan(server, stage, upscale, device) output = np.array(source) _, _, output = gfpgan.enhance( output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength, ) output = Image.fromarray(output, "RGB") return output