1
0
Fork 0

fix(api): put conversion RNG on training device (#67)

This commit is contained in:
Sean Sube 2023-01-21 19:59:58 -06:00
parent 246aa3dd15
commit abc1ae5112
1 changed files with 2 additions and 2 deletions

View File

@ -84,7 +84,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
model.to(training_device).train(False) model.to(training_device).train(False)
model.eval() model.eval()
rng = torch.rand(1, 3, 64, 64) rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data'] input_names = ['data']
output_names = ['output'] output_names = ['output']
dynamic_axes = {'data': {2: 'width', 3: 'height'}, dynamic_axes = {'data': {2: 'width', 3: 'height'},
@ -134,7 +134,7 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
model.to(training_device).train(False) model.to(training_device).train(False)
model.eval() model.eval()
rng = torch.rand(1, 3, 64, 64) rng = torch.rand(1, 3, 64, 64, device=map_location)
input_names = ['data'] input_names = ['data']
output_names = ['output'] output_names = ['output']
dynamic_axes = {'data': {2: 'width', 3: 'height'}, dynamic_axes = {'data': {2: 'width', 3: 'height'},