compile pattern
This commit is contained in:
parent
95886430a4
commit
404f24f9ad
|
@ -1,5 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from re import compile
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
|
@ -26,7 +27,7 @@ SPECIAL_KEYS = {
|
|||
"model.10.weight": "conv_last.weight",
|
||||
}
|
||||
|
||||
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"
|
||||
SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
|
||||
|
||||
|
||||
def fix_resrgan_keys(model):
|
||||
|
@ -36,8 +37,14 @@ def fix_resrgan_keys(model):
|
|||
new_key = SPECIAL_KEYS[key]
|
||||
else:
|
||||
# convert RDBN keys
|
||||
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
|
||||
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
||||
matched = SUB_NAME.match(key)
|
||||
if matched is not None:
|
||||
sub_index, rdb_index, conv_index, node_type = matched.groups()
|
||||
new_key = (
|
||||
f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown key format")
|
||||
|
||||
model[new_key] = model[key]
|
||||
del model[key]
|
||||
|
|
Loading…
Reference in New Issue