feat(api): support more RealESRGAN-based models
This commit is contained in:
parent
f5e7b3b865
commit
95886430a4
|
@ -11,6 +11,39 @@ logger = getLogger(__name__)
|
|||
|
||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||
|
||||
SPECIAL_KEYS = {
|
||||
"model.0.bias": "conv_first.bias",
|
||||
"model.0.weight": "conv_first.weight",
|
||||
"model.1.sub.23.bias": "conv_body.bias",
|
||||
"model.1.sub.23.weight": "conv_body.weight",
|
||||
"model.3.bias": "conv_up1.bias",
|
||||
"model.3.weight": "conv_up1.weight",
|
||||
"model.6.bias": "conv_up2.bias",
|
||||
"model.6.weight": "conv_up2.weight",
|
||||
"model.8.bias": "conv_hr.bias",
|
||||
"model.8.weight": "conv_hr.weight",
|
||||
"model.10.bias": "conv_last.bias",
|
||||
"model.10.weight": "conv_last.weight",
|
||||
}
|
||||
|
||||
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"
|
||||
|
||||
|
||||
def fix_resrgan_keys(model):
|
||||
original_keys = list(model.keys())
|
||||
for key in original_keys:
|
||||
if key in SPECIAL_KEYS:
|
||||
new_key = SPECIAL_KEYS[key]
|
||||
else:
|
||||
# convert RDBN keys
|
||||
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
|
||||
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
||||
|
||||
model[new_key] = model[key]
|
||||
del model[key]
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_upscale_resrgan(
|
||||
|
@ -54,8 +87,11 @@ def convert_upscale_resrgan(
|
|||
torch_model = torch.load(source, map_location=conversion.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
model.load_state_dict(torch_model["params_ema"])
|
||||
else:
|
||||
elif "params" in torch_model:
|
||||
model.load_state_dict(torch_model["params"], strict=False)
|
||||
else:
|
||||
# keys need fixed up to match
|
||||
model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
|
||||
|
||||
model.to(conversion.training_device).train(False)
|
||||
model.eval()
|
||||
|
|
Loading…
Reference in New Issue