fix(api): get ESRGAN/GFPGAN paths from server context, clean up test scripts
This commit is contained in:
parent
0f2d6d2ec7
commit
120056f878
|
@ -10,19 +10,21 @@ from typing import Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
denoise_strength = 0.5
|
from .utils import (
|
||||||
|
ServerContext
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: these should all be params or config
|
||||||
fp16 = False
|
fp16 = False
|
||||||
netscale = 4
|
|
||||||
outscale = 4
|
outscale = 4
|
||||||
pre_pad = 0
|
pre_pad = 0
|
||||||
tile = 0
|
|
||||||
tile_pad = 10
|
tile_pad = 10
|
||||||
|
|
||||||
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||||
resrgan_name = 'RealESRGAN_x4plus'
|
resrgan_name = 'RealESRGAN_x4plus'
|
||||||
resrgan_url = [
|
resrgan_url = [
|
||||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||||
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
|
||||||
|
|
||||||
class ONNXImage():
|
class ONNXImage():
|
||||||
def __init__(self, source) -> None:
|
def __init__(self, source) -> None:
|
||||||
|
@ -55,9 +57,13 @@ class ONNXNet():
|
||||||
Provides the RRDBNet interface but using ONNX.
|
Provides the RRDBNet interface but using ONNX.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, ctx: ServerContext) -> None:
|
||||||
|
'''
|
||||||
|
TODO: get platform provider from request params
|
||||||
|
'''
|
||||||
|
model_path = path.join(ctx.model_path, resrgan_name + '.onnx')
|
||||||
self.session = InferenceSession(
|
self.session = InferenceSession(
|
||||||
resrgan_path, providers=['DmlExecutionProvider'])
|
model_path, providers=['DmlExecutionProvider'])
|
||||||
|
|
||||||
def __call__(self, image: Any) -> Any:
|
def __call__(self, image: Any) -> Any:
|
||||||
input_name = self.session.get_inputs()[0].name
|
input_name = self.session.get_inputs()[0].name
|
||||||
|
@ -80,26 +86,37 @@ class ONNXNet():
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def make_resrgan(model_path, tile=0):
|
class UpscaleParams():
|
||||||
model_path = path.join(model_path, resrgan_name + '.pth')
|
def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None:
|
||||||
|
self.denoise = denoise
|
||||||
|
self.scale = scale
|
||||||
|
self.faces = faces
|
||||||
|
self.platform = platform
|
||||||
|
|
||||||
|
|
||||||
|
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||||
|
model_path = path.join(ctx.model_path, resrgan_name + '.pth')
|
||||||
if not path.isfile(model_path):
|
if not path.isfile(model_path):
|
||||||
for url in resrgan_url:
|
for url in resrgan_url:
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
||||||
|
|
||||||
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
# use ONNX acceleration, if available
|
||||||
# num_block=23, num_grow_ch=32, scale=4)
|
if params.platform == 'onnx':
|
||||||
model = ONNXNet()
|
model = ONNXNet(ctx)
|
||||||
|
else:
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||||
|
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
|
if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1:
|
||||||
wdn_model_path = model_path.replace(
|
wdn_model_path = model_path.replace(
|
||||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||||
model_path = [model_path, wdn_model_path]
|
model_path = [model_path, wdn_model_path]
|
||||||
dni_weight = [denoise_strength, 1 - denoise_strength]
|
dni_weight = [params.denoise, 1 - params.denoise]
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
scale=netscale,
|
scale=params.scale,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
dni_weight=dni_weight,
|
dni_weight=dni_weight,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -111,19 +128,23 @@ def make_resrgan(model_path, tile=0):
|
||||||
return upsampler
|
return upsampler
|
||||||
|
|
||||||
|
|
||||||
def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image:
|
def upscale_resrgan(ctx: ServerContext, source_image: Image, params: UpscaleParams) -> Image:
|
||||||
image = np.array(source_image)
|
image = np.array(source_image)
|
||||||
upsampler = make_resrgan(model_path)
|
upsampler = make_resrgan(ctx.model_path)
|
||||||
|
|
||||||
|
# TODO: what is outscale for here?
|
||||||
output, _ = upsampler.enhance(image, outscale=outscale)
|
output, _ = upsampler.enhance(image, outscale=outscale)
|
||||||
|
|
||||||
if faces:
|
if params.faces:
|
||||||
output = upscale_gfpgan(output, make_resrgan(model_path, 512))
|
output = upscale_gfpgan(ctx, output)
|
||||||
|
|
||||||
return Image.fromarray(output, 'RGB')
|
return Image.fromarray(output, 'RGB')
|
||||||
|
|
||||||
|
|
||||||
def upscale_gfpgan(image, upsampler) -> Image:
|
def upscale_gfpgan(ctx: ServerContext, image, upsampler=None) -> Image:
|
||||||
|
if upsampler is None:
|
||||||
|
upsampler = make_resrgan(ctx.model_path, 512)
|
||||||
|
|
||||||
face_enhancer = GFPGANer(
|
face_enhancer = GFPGANer(
|
||||||
model_path=gfpgan_url,
|
model_path=gfpgan_url,
|
||||||
upscale=outscale,
|
upscale=outscale,
|
||||||
|
|
|
@ -21,26 +21,3 @@ pipe = OnnxStableDiffusionPipeline.from_pretrained(model, provider='DmlExecution
|
||||||
image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0]
|
image = pipe(prompt, height, width, num_inference_steps=steps, guidance_scale=cfg).images[0]
|
||||||
image.save(output)
|
image.save(output)
|
||||||
print('saved test image to %s' % output)
|
print('saved test image to %s' % output)
|
||||||
|
|
||||||
|
|
||||||
upscale = path.join('..', 'outputs', 'test-large.png')
|
|
||||||
esrgan = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
|
||||||
|
|
||||||
print('upscaling test image...')
|
|
||||||
sess = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])
|
|
||||||
|
|
||||||
in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED)
|
|
||||||
|
|
||||||
in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
|
|
||||||
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
|
|
||||||
in_mat = in_mat.astype(np.float32)
|
|
||||||
in_mat = in_mat/255
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
input_name = sess.get_inputs()[0].name
|
|
||||||
output_name = sess.get_outputs()[0].name
|
|
||||||
in_mat = torch.tensor(in_mat)
|
|
||||||
out_mat = sess.run([output_name], {input_name: in_mat.cpu().numpy()})[0]
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
print(elapsed_time)
|
|
||||||
print('upscaled test image to %s')
|
|
|
@ -20,13 +20,13 @@ session = ort.InferenceSession(esrgan, providers=['DmlExecutionProvider'])
|
||||||
|
|
||||||
in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED)
|
in_image = cv2.imread(output, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
|
# convert to input format
|
||||||
in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
|
in_mat = cv2.cvtColor(in_image, cv2.COLOR_BGR2RGB)
|
||||||
print('shape before', np.shape(in_mat))
|
|
||||||
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
|
in_mat = np.transpose(in_mat, (2, 1, 0))[np.newaxis]
|
||||||
print('shape after', np.shape(in_mat))
|
|
||||||
in_mat = in_mat.astype(np.float32)
|
in_mat = in_mat.astype(np.float32)
|
||||||
in_mat = in_mat/255
|
in_mat = in_mat /255
|
||||||
|
|
||||||
|
# run network
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
input_name = session.get_inputs()[0].name
|
input_name = session.get_inputs()[0].name
|
||||||
output_name = session.get_outputs()[0].name
|
output_name = session.get_outputs()[0].name
|
||||||
|
@ -37,11 +37,9 @@ out_mat = session.run([output_name], {
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(elapsed_time)
|
print(elapsed_time)
|
||||||
|
|
||||||
print('output shape', np.shape(out_mat))
|
# convert back to original format
|
||||||
out_mat = np.squeeze(out_mat, (0))
|
out_mat = np.squeeze(out_mat, (0))
|
||||||
print(np.shape(out_mat))
|
|
||||||
out_mat = np.transpose(out_mat, (2, 1, 0))
|
out_mat = np.transpose(out_mat, (2, 1, 0))
|
||||||
print(out_mat, np.shape(out_mat))
|
|
||||||
out_mat = np.clip(out_mat, 0.0, 1.0)
|
out_mat = np.clip(out_mat, 0.0, 1.0)
|
||||||
out_mat = out_mat * 255
|
out_mat = out_mat * 255
|
||||||
out_mat = out_mat.astype(np.uint8)
|
out_mat = out_mat.astype(np.uint8)
|
||||||
|
|
Loading…
Reference in New Issue