fix(api): put conversion RNG on training device (#67)
This commit is contained in:
parent
246aa3dd15
commit
abc1ae5112
|
@ -84,7 +84,7 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||||
model.to(training_device).train(False)
|
model.to(training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64)
|
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||||
input_names = ['data']
|
input_names = ['data']
|
||||||
output_names = ['output']
|
output_names = ['output']
|
||||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||||
|
@ -134,7 +134,7 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int):
|
||||||
model.to(training_device).train(False)
|
model.to(training_device).train(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
rng = torch.rand(1, 3, 64, 64)
|
rng = torch.rand(1, 3, 64, 64, device=map_location)
|
||||||
input_names = ['data']
|
input_names = ['data']
|
||||||
output_names = ['output']
|
output_names = ['output']
|
||||||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||||
|
|
Loading…
Reference in New Issue