From 7d5668952759953fd31c5b67fb68bc1e49ee1b0a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 27 Dec 2023 09:01:38 -0600 Subject: [PATCH] pass device to wrapper --- api/onnx_web/chain/upscale_resrgan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 2e8d952f..33338b43 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -46,6 +46,7 @@ class UpscaleRealESRGANStage(BaseStage): self.mod_scale = None self.half = half self.model = model + self.device = device model_file = "%s.%s" % (params.upscale_model, params.format) model_path = path.join(server.model_path, model_file) @@ -75,15 +76,16 @@ class UpscaleRealESRGANStage(BaseStage): logger.debug("loading Real ESRGAN upscale model from %s", model_path) - # TODO: shouldn't need the PTH file upsampler = RealESRGANWrapper( scale=params.scale, + model_path=None, dni_weight=dni_weight, model=model, tile=tile, tile_pad=params.tile_pad, pre_pad=params.pre_pad, half=("torch-fp16" in server.optimizations), + device=device.torch_str(), ) server.cache.set(ModelTypes.upscaling, cache_key, upsampler)