From 404f24f9adaf5ffa46044b0ce753cddfd3bd27e9 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 27 Dec 2023 05:17:17 -0600 Subject: [PATCH] compile pattern --- api/onnx_web/convert/upscaling/resrgan.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index dabed6c5..0069bcd6 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -1,5 +1,6 @@ from logging import getLogger from os import path +from re import compile import torch from torch.onnx import export @@ -26,7 +27,7 @@ SPECIAL_KEYS = { "model.10.weight": "conv_last.weight", } -SUB_NAME = r"model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)" +SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$") def fix_resrgan_keys(model): @@ -36,8 +37,14 @@ def fix_resrgan_keys(model): new_key = SPECIAL_KEYS[key] else: # convert RDBN keys - sub_index, rdb_index, conv_index, node_type = key.match(SUB_NAME) - new_key = f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}" + matched = SUB_NAME.match(key) + if matched is not None: + sub_index, rdb_index, conv_index, node_type = matched.groups() + new_key = ( + f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}" + ) + else: + raise ValueError("unknown key format") model[new_key] = model[key] del model[key]