1
0
Fork 0

fix(api): pin outscale for GFPGAN to 1 to avoid sparse tiling

This commit is contained in:
Sean Sube 2023-01-31 17:08:30 -06:00
parent f4fc6271bc
commit c34ddacf55
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 21 additions and 4 deletions

View File

@ -31,7 +31,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
global last_pipeline_params
if upsampler is None:
upsampler = load_resrgan(ctx, upscale)
bg_upscale = upscale.rescale(upscale.outscale)
upsampler = load_resrgan(ctx, bg_upscale)
face_path = path.join(ctx.model_path, '%s.pth' %
(upscale.correction_model))
@ -40,7 +41,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
logger.info('reusing existing GFPGAN pipeline')
return last_pipeline_instance
# TODO: doesn't have a model param, not sure how to pass ONNX model
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
model_path=face_path,
upscale=upscale.outscale,
@ -73,7 +74,7 @@ def correct_gfpgan(
output = np.array(source_image)
_, _, output = gfpgan.enhance(
source_image, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength)
output = Image.fromarray(output, 'RGB')
return output

View File

@ -131,5 +131,21 @@ class UpscaleParams():
self.scale = scale
self.tile_pad = tile_pad
def rescale(self, scale: int):
return UpscaleParams(
self.upscale_model,
self.provider,
correction_model=self.correction_model,
denoise=self.denoise,
faces=self.faces,
face_strength=self.face_strength,
format=self.format,
half=self.half,
outscale=scale,
scale=scale,
pre_pad=self.pre_pad,
tile_pad=self.tile_pad,
)
def resize(self, size: Size) -> Size:
return Size(size.width * self.outscale, size.height * self.outscale)

View File

@ -49,7 +49,7 @@ def run_upscale_correction(
if upscale.faces:
stage = StageParams(tile_size=stage.tile_size,
outscale=upscale.outscale)
outscale=1)
chain.append((correct_gfpgan, stage, kwargs))
return chain(ctx, params, image)