1
0
Fork 0

remove gradients from tiled VAE tensors

This commit is contained in:
Sean Sube 2023-04-27 23:41:43 -05:00
parent dbd9a186ae
commit 878e29ad3d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 4 additions and 2 deletions

View File

@ -63,6 +63,7 @@ class VAEWrapper(object):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
@torch.no_grad()
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
Args:
@ -87,7 +88,7 @@ class VAEWrapper(object):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = torch.from_numpy(self(tile.numpy()))
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
@ -113,6 +114,7 @@ class VAEWrapper(object):
return AutoencoderKLOutput(latent_dist=posterior)
@torch.no_grad()
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""Decode a batch of images using a tiled decoder.
Args:
@ -140,7 +142,7 @@ class VAEWrapper(object):
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = torch.from_numpy(self(tile.numpy()))
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
row.append(decoded)
rows.append(row)
result_rows = []