diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 70ff74f6..4022c06c 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -1,16 +1,13 @@ from logging import getLogger from os import path -from typing import Optional import numpy as np from gfpgan import GFPGANer from PIL import Image -from realesrgan import RealESRGANer from ..device_pool import JobContext from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..utils import ServerContext, run_gc -from .upscale_resrgan import load_resrgan logger = getLogger(__name__) @@ -19,9 +16,7 @@ last_pipeline_instance = None last_pipeline_params = None -def load_gfpgan( - ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams -): +def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams): global last_pipeline_instance global last_pipeline_params @@ -56,7 +51,6 @@ def correct_gfpgan( source_image: Image.Image, *, upscale: UpscaleParams, - upsampler: Optional[RealESRGANer] = None, **kwargs, ) -> Image.Image: if upscale.correction_model is None: @@ -65,7 +59,7 @@ def correct_gfpgan( logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) device = job.get_device() - gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler) + gfpgan = load_gfpgan(server, upscale, device) output = np.array(source_image) _, _, output = gfpgan.enhance(