fix(api): load upscaling model from models dir
This commit is contained in:
parent
45d65d1342
commit
806503c709
|
@ -273,8 +273,7 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
|
|||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
).images[0]
|
||||
|
||||
image = upscale_resrgan(image)
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image.save(output)
|
||||
|
||||
print('saved txt2img output: %s' % (output))
|
||||
|
@ -295,6 +294,7 @@ def run_img2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
|
|||
num_inference_steps=steps,
|
||||
strength=strength,
|
||||
).images[0]
|
||||
image = upscale_resrgan(image, model_path)
|
||||
image.save(output)
|
||||
|
||||
print('saved img2img output: %s' % (output))
|
||||
|
|
|
@ -11,7 +11,7 @@ denoise_strength = 0.5
|
|||
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||
resrgan_url = [
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||
fp32 = True
|
||||
fp16 = False
|
||||
model_name = 'RealESRGAN_x4plus'
|
||||
netscale = 4
|
||||
outscale = 4
|
||||
|
@ -20,13 +20,12 @@ tile = 0
|
|||
tile_pad = 10
|
||||
|
||||
|
||||
def upscale_resrgan(source_image: Image, faces=True) -> Image:
|
||||
model_path = path.join('weights', model_name + '.pth')
|
||||
def make_resrgan(model_path):
|
||||
model_path = path.join(model_path, model_name + '.pth')
|
||||
if not path.isfile(model_path):
|
||||
ROOT_DIR = path.dirname(path.abspath(__file__))
|
||||
for url in resrgan_url:
|
||||
model_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||
url=url, model_dir=path.join(model_path, model_name), progress=True, file_name=None)
|
||||
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
|
@ -46,9 +45,15 @@ def upscale_resrgan(source_image: Image, faces=True) -> Image:
|
|||
tile=tile,
|
||||
tile_pad=tile_pad,
|
||||
pre_pad=pre_pad,
|
||||
half=fp32)
|
||||
half=fp16)
|
||||
|
||||
return upsampler
|
||||
|
||||
|
||||
def upscale_resrgan(source_image: Image, model_path: str, faces=True) -> Image:
|
||||
image = np.array(source_image)
|
||||
upsampler = make_resrgan(model_path)
|
||||
|
||||
output, _ = upsampler.enhance(image, outscale=outscale)
|
||||
|
||||
if faces:
|
||||
|
@ -65,6 +70,7 @@ def upscale_gfpgan(image, upsampler) -> Image:
|
|||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
|
||||
_, _, output = face_enhancer.enhance(image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
_, _, output = face_enhancer.enhance(
|
||||
image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue