1
0
Fork 0

wrap ESRGAN helper

This commit is contained in:
Sean Sube 2023-12-27 08:47:06 -06:00
parent 404f24f9ad
commit d17b946091
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 29 additions and 6 deletions

View File

@ -25,6 +25,28 @@ class UpscaleRealESRGANStage(BaseStage):
# TODO: rewrite and remove
from realesrgan import RealESRGANer
class RealESRGANWrapper(RealESRGANer):
def __init__(
self,
scale,
model_path,
dni_weight=None,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None,
):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
self.model = model
model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file)
@ -54,16 +76,14 @@ class UpscaleRealESRGANStage(BaseStage):
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file
model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model))
upsampler = RealESRGANer(
upsampler = RealESRGANWrapper(
scale=params.scale,
model_path=model_path_pth,
dni_weight=dni_weight,
model=model,
tile=tile,
tile_pad=params.tile_pad,
pre_pad=params.pre_pad,
half=False, # TODO: use server optimizations
half=("torch-fp16" in server.optimizations),
)
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)

View File

@ -27,7 +27,7 @@ SPECIAL_KEYS = {
"model.10.weight": "conv_last.weight",
}
SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
SUB_NAME = compile(r"^model\.1\.sub\.(\d+)\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
def fix_resrgan_keys(model):
@ -41,11 +41,14 @@ def fix_resrgan_keys(model):
if matched is not None:
sub_index, rdb_index, conv_index, node_type = matched.groups()
new_key = (
f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}"
f"body.{sub_index}.rdb{rdb_index}.conv{conv_index}.{node_type}"
)
else:
raise ValueError("unknown key format")
if new_key in model:
raise ValueError("key collision")
model[new_key] = model[key]
del model[key]