feat(api): support more RealESRGAN-based models
This commit is contained in:
parent
f5e7b3b865
commit
95886430a4
|
@ -11,6 +11,39 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||||
|
|
||||||
|
SPECIAL_KEYS = {
|
||||||
|
"model.0.bias": "conv_first.bias",
|
||||||
|
"model.0.weight": "conv_first.weight",
|
||||||
|
"model.1.sub.23.bias": "conv_body.bias",
|
||||||
|
"model.1.sub.23.weight": "conv_body.weight",
|
||||||
|
"model.3.bias": "conv_up1.bias",
|
||||||
|
"model.3.weight": "conv_up1.weight",
|
||||||
|
"model.6.bias": "conv_up2.bias",
|
||||||
|
"model.6.weight": "conv_up2.weight",
|
||||||
|
"model.8.bias": "conv_hr.bias",
|
||||||
|
"model.8.weight": "conv_hr.weight",
|
||||||
|
"model.10.bias": "conv_last.bias",
|
||||||
|
"model.10.weight": "conv_last.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"
|
||||||
|
|
||||||
|
|
||||||
|
def fix_resrgan_keys(model):
|
||||||
|
original_keys = list(model.keys())
|
||||||
|
for key in original_keys:
|
||||||
|
if key in SPECIAL_KEYS:
|
||||||
|
new_key = SPECIAL_KEYS[key]
|
||||||
|
else:
|
||||||
|
# convert RDBN keys
|
||||||
|
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
|
||||||
|
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
||||||
|
|
||||||
|
model[new_key] = model[key]
|
||||||
|
del model[key]
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_upscale_resrgan(
|
def convert_upscale_resrgan(
|
||||||
|
@ -54,8 +87,11 @@ def convert_upscale_resrgan(
|
||||||
torch_model = torch.load(source, map_location=conversion.map_location)
|
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||||
if "params_ema" in torch_model:
|
if "params_ema" in torch_model:
|
||||||
model.load_state_dict(torch_model["params_ema"])
|
model.load_state_dict(torch_model["params_ema"])
|
||||||
else:
|
elif "params" in torch_model:
|
||||||
model.load_state_dict(torch_model["params"], strict=False)
|
model.load_state_dict(torch_model["params"], strict=False)
|
||||||
|
else:
|
||||||
|
# keys need fixed up to match
|
||||||
|
model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
|
||||||
|
|
||||||
model.to(conversion.training_device).train(False)
|
model.to(conversion.training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
Loading…
Reference in New Issue