diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py index c58e2ff2..8422a86f 100644 --- a/api/onnx_web/diffusers/patches/vae.py +++ b/api/onnx_web/diffusers/patches/vae.py @@ -13,7 +13,8 @@ from diffusers import OnnxRuntimeModel logger = getLogger(__name__) LATENT_CHANNELS = 4 -SAMPLE_SIZE = 32 +LATENT_SIZE = 32 +SAMPLE_SIZE = 256 # TODO: does this need to change for fp16 modes? timestep_dtype = np.float32 @@ -31,7 +32,7 @@ class VAEWrapper(object): self.decoder = decoder self.tile_sample_min_size = SAMPLE_SIZE - self.tile_latent_min_size = SAMPLE_SIZE + self.tile_latent_min_size = LATENT_SIZE self.tile_overlap_factor = 0.25 self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1) @@ -88,7 +89,7 @@ class VAEWrapper(object): row = [] for j in range(0, x.shape[3], overlap_size): tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] - tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())) + tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0]) tile = self.quant_conv(tile) row.append(tile) rows.append(row) @@ -142,9 +143,10 @@ class VAEWrapper(object): for j in range(0, z.shape[3], overlap_size): tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] tile = self.post_quant_conv(tile) - decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())) + decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0]) row.append(decoded) rows.append(row) + result_rows = [] for i, row in enumerate(rows): result_row = []