1
0
Fork 0

pass device to wrapper

This commit is contained in:
Sean Sube 2023-12-27 09:01:38 -06:00
parent d17b946091
commit 7d56689527
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 1 deletions

View File

@ -46,6 +46,7 @@ class UpscaleRealESRGANStage(BaseStage):
self.mod_scale = None self.mod_scale = None
self.half = half self.half = half
self.model = model self.model = model
self.device = device
model_file = "%s.%s" % (params.upscale_model, params.format) model_file = "%s.%s" % (params.upscale_model, params.format)
model_path = path.join(server.model_path, model_file) model_path = path.join(server.model_path, model_file)
@ -75,15 +76,16 @@ class UpscaleRealESRGANStage(BaseStage):
logger.debug("loading Real ESRGAN upscale model from %s", model_path) logger.debug("loading Real ESRGAN upscale model from %s", model_path)
# TODO: shouldn't need the PTH file
upsampler = RealESRGANWrapper( upsampler = RealESRGANWrapper(
scale=params.scale, scale=params.scale,
model_path=None,
dni_weight=dni_weight, dni_weight=dni_weight,
model=model, model=model,
tile=tile, tile=tile,
tile_pad=params.tile_pad, tile_pad=params.tile_pad,
pre_pad=params.pre_pad, pre_pad=params.pre_pad,
half=("torch-fp16" in server.optimizations), half=("torch-fp16" in server.optimizations),
device=device.torch_str(),
) )
server.cache.set(ModelTypes.upscaling, cache_key, upsampler) server.cache.set(ModelTypes.upscaling, cache_key, upsampler)