diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 2e8d952f..33338b43 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -46,6 +46,7 @@ class UpscaleRealESRGANStage(BaseStage): self.mod_scale = None self.half = half self.model = model + self.device = device model_file = "%s.%s" % (params.upscale_model, params.format) model_path = path.join(server.model_path, model_file) @@ -75,15 +76,16 @@ class UpscaleRealESRGANStage(BaseStage): logger.debug("loading Real ESRGAN upscale model from %s", model_path) - # TODO: shouldn't need the PTH file upsampler = RealESRGANWrapper( scale=params.scale, + model_path=None, dni_weight=dni_weight, model=model, tile=tile, tile_pad=params.tile_pad, pre_pad=params.pre_pad, half=("torch-fp16" in server.optimizations), + device=device.torch_str(), ) server.cache.set(ModelTypes.upscaling, cache_key, upsampler)