fix(api): pass current device when loading GFPGAN
This commit is contained in:
parent
811b6640a8
commit
c7e0041229
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from ..device_pool import JobContext
|
from ..device_pool import JobContext
|
||||||
from ..params import ImageParams, StageParams, UpscaleParams
|
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||||
from ..utils import ServerContext, run_gc
|
from ..utils import ServerContext, run_gc
|
||||||
from .upscale_resrgan import load_resrgan
|
from .upscale_resrgan import load_resrgan
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ last_pipeline_params = None
|
||||||
|
|
||||||
|
|
||||||
def load_gfpgan(
|
def load_gfpgan(
|
||||||
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
|
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams, upsampler: Optional[RealESRGANer] = None
|
||||||
):
|
):
|
||||||
global last_pipeline_instance
|
global last_pipeline_instance
|
||||||
global last_pipeline_params
|
global last_pipeline_params
|
||||||
|
@ -54,7 +54,7 @@ def load_gfpgan(
|
||||||
|
|
||||||
|
|
||||||
def correct_gfpgan(
|
def correct_gfpgan(
|
||||||
_job: JobContext,
|
job: JobContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
_stage: StageParams,
|
_stage: StageParams,
|
||||||
_params: ImageParams,
|
_params: ImageParams,
|
||||||
|
@ -69,7 +69,8 @@ def correct_gfpgan(
|
||||||
return source_image
|
return source_image
|
||||||
|
|
||||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||||
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
|
device = job.get_device()
|
||||||
|
gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler)
|
||||||
|
|
||||||
output = np.array(source_image)
|
output = np.array(source_image)
|
||||||
_, _, output = gfpgan.enhance(
|
_, _, output = gfpgan.enhance(
|
||||||
|
|
Loading…
Reference in New Issue