diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 49b06ee4..2e8d952f 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -25,6 +25,28 @@ class UpscaleRealESRGANStage(BaseStage): # TODO: rewrite and remove from realesrgan import RealESRGANer + class RealESRGANWrapper(RealESRGANer): + def __init__( + self, + scale, + model_path, + dni_weight=None, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None, + ): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + self.model = model + model_file = "%s.%s" % (params.upscale_model, params.format) model_path = path.join(server.model_path, model_file) @@ -54,16 +76,14 @@ class UpscaleRealESRGANStage(BaseStage): logger.debug("loading Real ESRGAN upscale model from %s", model_path) # TODO: shouldn't need the PTH file - model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model)) - upsampler = RealESRGANer( + upsampler = RealESRGANWrapper( scale=params.scale, - model_path=model_path_pth, dni_weight=dni_weight, model=model, tile=tile, tile_pad=params.tile_pad, pre_pad=params.pre_pad, - half=False, # TODO: use server optimizations + half=("torch-fp16" in server.optimizations), ) server.cache.set(ModelTypes.upscaling, cache_key, upsampler) diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index 0069bcd6..f3c189af 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -27,7 +27,7 @@ SPECIAL_KEYS = { "model.10.weight": "conv_last.weight", } -SUB_NAME = compile(r"^model\.1\.sub\.(\d)+\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$") +SUB_NAME = compile(r"^model\.1\.sub\.(\d+)\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$") def fix_resrgan_keys(model): @@ -41,11 +41,14 @@ def fix_resrgan_keys(model): if matched is not None: sub_index, rdb_index, conv_index, node_type = matched.groups() new_key = ( - f"model.1.sub.{sub_index}.rdb{rdb_index}.{conv_index}.{node_type}" + f"body.{sub_index}.rdb{rdb_index}.conv{conv_index}.{node_type}" ) else: raise ValueError("unknown key format") + if new_key in model: + raise ValueError("key collision") + model[new_key] = model[key] del model[key]