1
0
Fork 0

fix sample size, use first output

This commit is contained in:
Sean Sube 2023-04-28 13:30:37 -05:00
parent 878e29ad3d
commit 577e2320f5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 6 additions and 4 deletions

View File

@@ -13,7 +13,8 @@ from diffusers import OnnxRuntimeModel
logger = getLogger(__name__)
LATENT_CHANNELS = 4
SAMPLE_SIZE = 32
LATENT_SIZE = 32
SAMPLE_SIZE = 256
# TODO: does this need to change for fp16 modes?
timestep_dtype = np.float32
@@ -31,7 +32,7 @@ class VAEWrapper(object):
self.decoder = decoder
self.tile_sample_min_size = SAMPLE_SIZE
self.tile_latent_min_size = SAMPLE_SIZE
self.tile_latent_min_size = LATENT_SIZE
self.tile_overlap_factor = 0.25
self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
@@ -88,7 +89,7 @@ class VAEWrapper(object):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
@@ -142,9 +143,10 @@ class VAEWrapper(object):
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
tile = self.post_quant_conv(tile)
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []