1
0
Fork 0

correct param names for VAE encoder wrapper

This commit is contained in:
Sean Sube 2023-04-28 20:37:59 -05:00
parent 9a2421ee47
commit 6d05f40ecb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 2 additions and 2 deletions

View File

@ -55,7 +55,7 @@ class VAEWrapper(object):
if self.decoder:
return self.tiled_decode(latent_sample, **kwargs)
else:
return self.tiled_encode(latent_sample, **kwargs)
return self.tiled_encode(sample, **kwargs)
def __getattr__(self, attr):
return getattr(self.wrapped, attr)
@ -106,7 +106,7 @@ class VAEWrapper(object):
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
tile = torch.from_numpy(self.wrapped(sample=tile.numpy())[0])
row.append(tile)
rows.append(row)
result_rows = []