fix(api): pass upscale params when creating RESRGAN
This commit is contained in:
parent
5e5d748c0b
commit
091c4e6109
|
@ -3,9 +3,9 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
|
|||
from basicsr.utils.download_util import load_file_from_url
|
||||
from os import path, environ
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
|
||||
import torch
|
||||
import torch.onnx
|
||||
|
||||
from .upscale import (
|
||||
gfpgan_url,
|
||||
|
@ -17,6 +17,7 @@ model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
|||
path.join('..', 'models'))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan():
|
||||
dest_path = path.join(model_path, resrgan_name + '.pth')
|
||||
print('converting Real ESRGAN into %s' % dest_path)
|
||||
|
@ -44,7 +45,7 @@ def convert_real_esrgan():
|
|||
with torch.no_grad():
|
||||
dest_onnx = path.join(model_path, resrgan_name + '.onnx')
|
||||
print('exporting Real ESRGAN model to %s' % dest_onnx)
|
||||
torch.onnx.export(
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
dest_onnx,
|
||||
|
|
|
@ -97,7 +97,7 @@ def run_txt2img_pipeline(
|
|||
).images[0]
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, image, upscale)
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -129,7 +129,7 @@ def run_img2img_pipeline(
|
|||
).images[0]
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, image, upscale)
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
@ -182,7 +182,7 @@ def run_inpaint_pipeline(
|
|||
).images[0]
|
||||
|
||||
if upscale.faces or upscale.scale > 1:
|
||||
image = upscale_resrgan(ctx, image, upscale)
|
||||
image = upscale_resrgan(ctx, upscale, image)
|
||||
|
||||
dest = safer_join(ctx.output_path, output)
|
||||
image.save(dest)
|
||||
|
|
|
@ -128,26 +128,26 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
|||
return upsampler
|
||||
|
||||
|
||||
def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image:
|
||||
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image:
|
||||
print('upscaling image with Real ESRGAN', params)
|
||||
|
||||
image = np.array(source_image)
|
||||
upsampler = make_resrgan(ctx.model_path)
|
||||
upsampler = make_resrgan(ctx, params)
|
||||
|
||||
# TODO: what is outscale for here?
|
||||
output, _ = upsampler.enhance(image, outscale=outscale)
|
||||
|
||||
if params.faces:
|
||||
output = upscale_gfpgan(ctx, output)
|
||||
output = upscale_gfpgan(ctx, params, output)
|
||||
|
||||
return Image.fromarray(output, 'RGB')
|
||||
|
||||
|
||||
def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image:
|
||||
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
|
||||
print('correcting faces with GFPGAN')
|
||||
|
||||
if upsampler is None:
|
||||
upsampler = make_resrgan(ctx.model_path, 512)
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=gfpgan_url,
|
||||
|
|
Loading…
Reference in New Issue