1
0
Fork 0

compile pattern

This commit is contained in:
Sean Sube 2023-12-27 05:17:17 -06:00
parent 95886430a4
commit 404f24f9ad
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 10 additions and 3 deletions

View File

@ -1,5 +1,6 @@
from logging import getLogger
from os import path
from re import compile
import torch
from torch.onnx import export
@ -26,7 +27,7 @@ SPECIAL_KEYS = {
"model.10.weight": "conv_last.weight",
}
SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)"
SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
def fix_resrgan_keys(model):
@ -36,8 +37,14 @@ def fix_resrgan_keys(model):
new_key = SPECIAL_KEYS[key]
else:
# convert RDBN keys
sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME)
new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
matched = SUB_NAME.match(key)
if matched is not None:
sub_index, rdb_index, conv_index, node_type = matched.groups()
new_key = (
f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
)
else:
raise ValueError("unknown key format")
model[new_key] = model[key]
del model[key]