pass device to wrapper
This commit is contained in:
parent
d17b946091
commit
7d56689527
|
@ -46,6 +46,7 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
self.mod_scale = None
|
self.mod_scale = None
|
||||||
self.half = half
|
self.half = half
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.device = device
|
||||||
|
|
||||||
model_file = "%s.%s" % (params.upscale_model, params.format)
|
model_file = "%s.%s" % (params.upscale_model, params.format)
|
||||||
model_path = path.join(server.model_path, model_file)
|
model_path = path.join(server.model_path, model_file)
|
||||||
|
@ -75,15 +76,16 @@ class UpscaleRealESRGANStage(BaseStage):
|
||||||
|
|
||||||
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
logger.debug("loading Real ESRGAN upscale model from %s", model_path)
|
||||||
|
|
||||||
# TODO: shouldn't need the PTH file
|
|
||||||
upsampler = RealESRGANWrapper(
|
upsampler = RealESRGANWrapper(
|
||||||
scale=params.scale,
|
scale=params.scale,
|
||||||
|
model_path=None,
|
||||||
dni_weight=dni_weight,
|
dni_weight=dni_weight,
|
||||||
model=model,
|
model=model,
|
||||||
tile=tile,
|
tile=tile,
|
||||||
tile_pad=params.tile_pad,
|
tile_pad=params.tile_pad,
|
||||||
pre_pad=params.pre_pad,
|
pre_pad=params.pre_pad,
|
||||||
half=("torch-fp16" in server.optimizations),
|
half=("torch-fp16" in server.optimizations),
|
||||||
|
device=device.torch_str(),
|
||||||
)
|
)
|
||||||
|
|
||||||
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
|
server.cache.set(ModelTypes.upscaling, cache_key, upsampler)
|
||||||
|
|
Loading…
Reference in New Issue