From 95886430a4b41c47ff17cf9eab78c01317b02097 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 27 Dec 2023 05:08:15 -0600 Subject: [PATCH] feat(api): support more RealESRGAN-based models --- api/onnx_web/convert/upscaling/resrgan.py | 38 ++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index 291670a6..dabed6c5 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -11,6 +11,39 @@ logger = getLogger(__name__) TAG_X4_V3 = "real-esrgan-x4-v3" +SPECIAL_KEYS = { + "model.0.bias": "conv_first.bias", + "model.0.weight": "conv_first.weight", + "model.1.sub.23.bias": "conv_body.bias", + "model.1.sub.23.weight": "conv_body.weight", + "model.3.bias": "conv_up1.bias", + "model.3.weight": "conv_up1.weight", + "model.6.bias": "conv_up2.bias", + "model.6.weight": "conv_up2.weight", + "model.8.bias": "conv_hr.bias", + "model.8.weight": "conv_hr.weight", + "model.10.bias": "conv_last.bias", + "model.10.weight": "conv_last.weight", +} + +SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)" + + +def fix_resrgan_keys(model): + original_keys = list(model.keys()) + for key in original_keys: + if key in SPECIAL_KEYS: + new_key = SPECIAL_KEYS[key] + else: + # convert RDBN keys + sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME) + new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}" + + model[new_key] = model[key] + del model[key] + + return model + @torch.no_grad() def convert_upscale_resrgan( @@ -54,8 +87,11 @@ def convert_upscale_resrgan( torch_model = torch.load(source, map_location=conversion.map_location) if "params_ema" in torch_model: model.load_state_dict(torch_model["params_ema"]) - else: + elif "params" in torch_model: model.load_state_dict(torch_model["params"], strict=False) + else: + # keys need fixed up to match + model.load_state_dict(fix_resrgan_keys(torch_model), strict=False) model.to(conversion.training_device).train(False) model.eval()