compile pattern
This commit is contained in:
parent
95886430a4
commit
404f24f9ad
|
@ -1,5 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
|
from re import compile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
@ -26,7 +27,7 @@ SPECIAL_KEYS = {
|
||||||
"model.10.weight": "conv_last.weight",
|
"model.10.weight": "conv_last.weight",
|
||||||
}
|
}
|
||||||
|
|
||||||
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"
|
SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
|
||||||
|
|
||||||
|
|
||||||
def fix_resrgan_keys(model):
|
def fix_resrgan_keys(model):
|
||||||
|
@ -36,8 +37,14 @@ def fix_resrgan_keys(model):
|
||||||
new_key = SPECIAL_KEYS[key]
|
new_key = SPECIAL_KEYS[key]
|
||||||
else:
|
else:
|
||||||
# convert RDBN keys
|
# convert RDBN keys
|
||||||
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
|
matched = SUB_NAME.match(key)
|
||||||
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
if matched is not None:
|
||||||
|
sub_index, rdb_index, conv_index, node_type = matched.groups()
|
||||||
|
new_key = (
|
||||||
|
f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("unknown key format")
|
||||||
|
|
||||||
model[new_key] = model[key]
|
model[new_key] = model[key]
|
||||||
del model[key]
|
del model[key]
|
||||||
|
|
Loading…
Reference in New Issue