1
0
Fork 0

fix(api): pass current device when loading GFPGAN

This commit is contained in:
Sean Sube 2023-02-06 08:07:06 -06:00
parent 811b6640a8
commit c7e0041229
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 5 additions and 4 deletions

View File

@ -8,7 +8,7 @@ from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from ..device_pool import JobContext from ..device_pool import JobContext
from ..params import ImageParams, StageParams, UpscaleParams from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan from .upscale_resrgan import load_resrgan
@ -20,7 +20,7 @@ last_pipeline_params = None
def load_gfpgan( def load_gfpgan(
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams, upsampler: Optional[RealESRGANer] = None
): ):
global last_pipeline_instance global last_pipeline_instance
global last_pipeline_params global last_pipeline_params
@ -54,7 +54,7 @@ def load_gfpgan(
def correct_gfpgan( def correct_gfpgan(
_job: JobContext, job: JobContext,
server: ServerContext, server: ServerContext,
_stage: StageParams, _stage: StageParams,
_params: ImageParams, _params: ImageParams,
@ -69,7 +69,8 @@ def correct_gfpgan(
return source_image return source_image
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler) device = job.get_device()
gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler)
output = np.array(source_image) output = np.array(source_image)
_, _, output = gfpgan.enhance( _, _, output = gfpgan.enhance(