fix(api): pass current device when loading GFPGAN
This commit is contained in:
parent
811b6640a8
commit
c7e0041229
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from realesrgan import RealESRGANer
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext, run_gc
|
||||
from .upscale_resrgan import load_resrgan
|
||||
|
||||
|
@ -20,7 +20,7 @@ last_pipeline_params = None
|
|||
|
||||
|
||||
def load_gfpgan(
|
||||
ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None
|
||||
ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams, upsampler: Optional[RealESRGANer] = None
|
||||
):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_params
|
||||
|
@ -54,7 +54,7 @@ def load_gfpgan(
|
|||
|
||||
|
||||
def correct_gfpgan(
|
||||
_job: JobContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
|
@ -69,7 +69,8 @@ def correct_gfpgan(
|
|||
return source_image
|
||||
|
||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
|
||||
device = job.get_device()
|
||||
gfpgan = load_gfpgan(server, upscale, device, upsampler=upsampler)
|
||||
|
||||
output = np.array(source_image)
|
||||
_, _, output = gfpgan.enhance(
|
||||
|
|
Loading…
Reference in New Issue