1
0
Fork 0

compile pattern

This commit is contained in:
Sean Sube 2023-12-27 05:17:17 -06:00
parent 95886430a4
commit 404f24f9ad
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 10 additions and 3 deletions

View File

@ -1,5 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from re import compile
import torch import torch
from torch.onnx import export from torch.onnx import export
@ -26,7 +27,7 @@ SPECIAL_KEYS = {
"model.10.weight": "conv_last.weight", "model.10.weight": "conv_last.weight",
} }
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)" SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
def fix_resrgan_keys(model): def fix_resrgan_keys(model):
@ -36,8 +37,14 @@ def fix_resrgan_keys(model):
new_key = SPECIAL_KEYS[key] new_key = SPECIAL_KEYS[key]
else: else:
# convert RDBN keys # convert RDBN keys
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME) matched = SUB_NAME.match(key)
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}" if matched is not None:
sub_index, rdb_index, conv_index, node_type = matched.groups()
new_key = (
f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
)
else:
raise ValueError("unknown key format")
model[new_key] = model[key] model[new_key] = model[key]
del model[key] del model[key]