1
0
Fork 0

fix(api): pass upscale params when creating RESRGAN

This commit is contained in:
Sean Sube 2023-01-16 14:17:50 -06:00
parent 5e5d748c0b
commit 091c4e6109
3 changed files with 11 additions and 10 deletions

View File

@ -3,9 +3,9 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from os import path, environ
from sys import exit
from torch.onnx import export
import torch
import torch.onnx
from .upscale import (
gfpgan_url,
@ -17,6 +17,7 @@ model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
@torch.no_grad()
def convert_real_esrgan():
dest_path = path.join(model_path, resrgan_name + '.pth')
print('converting Real ESRGAN into %s' % dest_path)
@ -44,7 +45,7 @@ def convert_real_esrgan():
with torch.no_grad():
dest_onnx = path.join(model_path, resrgan_name + '.onnx')
print('exporting Real ESRGAN model to %s' % dest_onnx)
torch.onnx.export(
export(
model,
rng,
dest_onnx,

View File

@ -97,7 +97,7 @@ def run_txt2img_pipeline(
).images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)
@ -129,7 +129,7 @@ def run_img2img_pipeline(
).images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)
@ -182,7 +182,7 @@ def run_inpaint_pipeline(
).images[0]
if upscale.faces or upscale.scale > 1:
image = upscale_resrgan(ctx, image, upscale)
image = upscale_resrgan(ctx, upscale, image)
dest = safer_join(ctx.output_path, output)
image.save(dest)

View File

@ -128,26 +128,26 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
return upsampler
def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image:
def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Image) -> Image:
print('upscaling image with Real ESRGAN', params)
image = np.array(source_image)
upsampler = make_resrgan(ctx.model_path)
upsampler = make_resrgan(ctx, params)
# TODO: what is outscale for here?
output, _ = upsampler.enhance(image, outscale=outscale)
if params.faces:
output = upscale_gfpgan(ctx, output)
output = upscale_gfpgan(ctx, params, output)
return Image.fromarray(output, 'RGB')
def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image:
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN')
if upsampler is None:
upsampler = make_resrgan(ctx.model_path, 512)
upsampler = make_resrgan(ctx, params, tile=512)
face_enhancer = GFPGANer(
model_path=gfpgan_url,