fix(api): use training device when loading Real ESRGAN model (#67)
This commit is contained in:
parent
5286c4f596
commit
8c9c99eeb5
|
@ -53,7 +53,7 @@ extra_models: Models = {
|
||||||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||||
path.join('..', 'models'))
|
path.join('..', 'models'))
|
||||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
map_location = None if torch.cuda.is_available() else torch.device('cpu')
|
map_location = torch.device(training_device)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||||
|
|
Loading…
Reference in New Issue