correct param names for VAE encoder wrapper
This commit is contained in:
parent
9a2421ee47
commit
6d05f40ecb
|
@ -55,7 +55,7 @@ class VAEWrapper(object):
|
|||
if self.decoder:
|
||||
return self.tiled_decode(latent_sample, **kwargs)
|
||||
else:
|
||||
return self.tiled_encode(latent_sample, **kwargs)
|
||||
return self.tiled_encode(sample, **kwargs)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
@ -106,7 +106,7 @@ class VAEWrapper(object):
|
|||
i : i + self.tile_sample_min_size,
|
||||
j : j + self.tile_sample_min_size,
|
||||
]
|
||||
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
|
||||
tile = torch.from_numpy(self.wrapped(sample=tile.numpy())[0])
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
|
|
Loading…
Reference in New Issue