remove gradients from tiled VAE tensors
This commit is contained in:
parent
dbd9a186ae
commit
878e29ad3d
|
@ -63,6 +63,7 @@ class VAEWrapper(object):
|
|||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
@torch.no_grad()
|
||||
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
Args:
|
||||
|
@ -87,7 +88,7 @@ class VAEWrapper(object):
|
|||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = torch.from_numpy(self(tile.numpy()))
|
||||
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
@ -113,6 +114,7 @@ class VAEWrapper(object):
|
|||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
@torch.no_grad()
|
||||
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""Decode a batch of images using a tiled decoder.
|
||||
Args:
|
||||
|
@ -140,7 +142,7 @@ class VAEWrapper(object):
|
|||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = torch.from_numpy(self(tile.numpy()))
|
||||
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
|
|
Loading…
Reference in New Issue