1
0
Fork 0

attempt to fix face correction at various scales

This commit is contained in:
Sean Sube 2023-01-16 21:36:50 -06:00
parent 073ff8e02f
commit 03e06193eb
3 changed files with 10 additions and 7 deletions

View File

@ -24,8 +24,10 @@ sources: Dict[str, List[Tuple[str, str]]] = {
'stabilityai/stable-diffusion-2-inpainting'),
],
'gfpgan': [
('correction-gfpgan-v1-3',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
('correction-gfpgan-v1-3-x2',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 2),
('correction-gfpgan-v1-3-x4',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
],
'real_esrgan': [
('upscaling-real-esrgan-x2-plus',
@ -92,7 +94,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
@torch.no_grad()
def convert_gfpgan(name: str, url: str, opset: int):
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name + '.pth')
dest_onnx = path.join(model_path, name + '.onnx')
print('converting GFPGAN model: %s -> %s' % (name, dest_onnx))
@ -109,7 +111,7 @@ def convert_gfpgan(name: str, url: str, opset: int):
print('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
num_block=23, num_grow_ch=32, scale=scale)
torch_model = torch.load(dest_path, map_location=map_location)
# TODO: make sure strict=False is safe here

View File

@ -177,8 +177,7 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
return image
if upsampler is None:
bg_params = params.rescale(params.outscale)
upsampler = make_resrgan(ctx, bg_params, tile=512)
upsampler = make_resrgan(ctx, params, tile=512)
face_path = path.join(ctx.model_path, '%s.pth' % (params.correction_model))

View File

@ -11,7 +11,9 @@ export const MODEL_LABELS = {
'upscaling-real-esrgan-x4-plus': 'Real ESRGAN x4 Plus',
'upscaling-real-esrgan-x4-v3': 'Real ESRGAN x4 v3',
// correction
'correction-gfpgan-v1-3': 'GFPGAN v1.3',
'correction-gfpgan-v1-3-x2': 'GFPGAN v1.3 x2',
'correction-gfpgan-v1-3-x4': 'GFPGAN v1.3 x4',
};
export const PLATFORM_LABELS: Record<string, string> = {