fix(api): put conversion RNG on training device (#67)
This commit is contained in:
parent
246aa3dd15
commit
abc1ae5112
|
@ -84,7 +84,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
|||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64)
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ['data']
|
||||
output_names = ['output']
|
||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
|
@ -134,7 +134,7 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
|||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64)
|
||||
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||
input_names = ['data']
|
||||
output_names = ['output']
|
||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
|
|
Loading…
Reference in New Issue