diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index ffd55181..203c8375 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -8,7 +8,7 @@ from PIL import Image from realesrgan import RealESRGANer from ..device_pool import JobContext -from ..params import ImageParams, StageParams, UpscaleParams +from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..utils import ServerContext, run_gc from .upscale_resrgan import load_resrgan @@ -20,7 +20,7 @@ last_pipeline_params = None def load_gfpgan( - ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None + ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams, upsampler: Optional[RealESRGANer] = None ): global last_pipeline_instance global last_pipeline_params @@ -54,7 +54,7 @@ def load_gfpgan( def correct_gfpgan( - _job: JobContext, + job: JobContext, server: ServerContext, _stage: StageParams, _params: ImageParams, @@ -69,7 +69,8 @@ def correct_gfpgan( return source_image logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) - gfpgan = load_gfpgan(server, upscale, upsampler=upsampler) + device = job.get_device() + gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler) output = np.array(source_image) _, _, output = gfpgan.enhance(