1
0
Fork 0

fix(api): get ESRGAN/GFPGAN paths from server context, clean up test scripts

This commit is contained in:
Sean Sube 2023-01-16 13:02:15 -06:00
parent 0f2d6d2ec7
commit 120056f878
3 changed files with 44 additions and 48 deletions

View File

@ -10,19 +10,21 @@ from typing import Any
import numpy as np
import torch
denoise_strength = 0.5
from .utils import (
ServerContext
)
# TODO: these should all be params or config
fp16 = False
netscale = 4
outscale = 4
pre_pad = 0
tile = 0
tile_pad = 10
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_name = 'RealESRGAN_x4plus'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
class ONNXImage():
def __init__(self, source) -> None:
@ -55,9 +57,13 @@ class ONNXNet():
Provides the RRDBNet interface but using ONNX.
'''
def __init__(self) -> None:
def __init__(self, ctx: ServerContext) -> None:
'''
TODO: get platform provider from request params
'''
model_path = path.join(ctx.model_path, resrgan_name + '.onnx')
self.session = InferenceSession(
resrgan_path, providers=['DmlExecutionProvider'])
model_path, providers=['DmlExecutionProvider'])
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
@ -80,26 +86,37 @@ class ONNXNet():
return self
def make_resrgan(model_path, tile=0):
model_path = path.join(model_path, resrgan_name + '.pth')
class UpscaleParams():
def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None:
self.denoise = denoise
self.scale = scale
self.faces = faces
self.platform = platform
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_path = path.join(ctx.model_path, resrgan_name + '.pth')
if not path.isfile(model_path):
for url in resrgan_url:
model_path = load_file_from_url(
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
# num_block=23, num_grow_ch=32, scale=4)
model = ONNXNet()
# use ONNX acceleration, if available
if params.platform == 'onnx':
model = ONNXNet(ctx)
else:
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
dni_weight = None
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
dni_weight = [denoise_strength, 1 - denoise_strength]
dni_weight = [params.denoise, 1 - params.denoise]
upsampler = RealESRGANer(
scale=netscale,
scale=params.scale,
model_path=model_path,
dni_weight=dni_weight,
model=model,
@ -111,19 +128,23 @@ def make_resrgan(model_path, tile=0):
return upsampler
def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image:
def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image:
image = np.array(source_image)
upsampler = make_resrgan(model_path)
upsampler = make_resrgan(ctx.model_path)
# TODO: what is outscale for here?
output, _ = upsampler.enhance(image, outscale=outscale)
if faces:
output = upscale_gfpgan(output, make_resrgan(model_path, 512))
if params.faces:
output = upscale_gfpgan(ctx, output)
return Image.fromarray(output, 'RGB')
def upscale_gfpgan(image, upsampler) -> Image:
def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image:
if upsampler is None:
upsampler = make_resrgan(ctx.model_path, 512)
face_enhancer = GFPGANer(
model_path=gfpgan_url,
upscale=outscale,

View File

@ -21,26 +21,3 @@ pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecution
image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0]
image.save(output)
print('saved test image to %s' % output)
upscale = path.join('..', 'outputs', 'test-large.png')
esrgan = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
print('upscaling test image...')
sess = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])
in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED)
in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
in_mat = in_mat.astype(np.float32)
in_mat = in_mat/255
start_time = time.time()
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
in_mat = torch.tensor(in_mat)
out_mat = sess.run([output_name], {input_name: in_mat.cpu().numpy()})[0]
elapsed_time = time.time() - start_time
print(elapsed_time)
print('upscaled test image to %s')

View File

@ -20,13 +20,13 @@ session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])
in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED)
# convert to input format
in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
print('shape before', np.shape(in_mat))
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
print('shape after', np.shape(in_mat))
in_mat = in_mat.astype(np.float32)
in_mat = in_mat/255
in_mat = in_mat /255
# run network
start_time = time.time()
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
@ -37,11 +37,9 @@ out_mat = session.run([output_name], {
elapsed_time = time.time() - start_time
print(elapsed_time)
print('output shape', np.shape(out_mat))
# convert back to original format
out_mat = np.squeeze(out_mat, (0))
print(np.shape(out_mat))
out_mat = np.transpose(out_mat, (2, 1, 0))
print(out_mat, np.shape(out_mat))
out_mat = np.clip(out_mat, 0.0, 1.0)
out_mat = out_mat * 255
out_mat = out_mat.astype(np.uint8)