fix sample size, use first output
This commit is contained in:
parent
878e29ad3d
commit
577e2320f5
|
@ -13,7 +13,8 @@ from diffusers import OnnxRuntimeModel
|
|||
logger = getLogger(__name__)
|
||||
|
||||
LATENT_CHANNELS = 4
|
||||
SAMPLE_SIZE = 32
|
||||
LATENT_SIZE = 32
|
||||
SAMPLE_SIZE = 256
|
||||
|
||||
# TODO: does this need to change for fp16 modes?
|
||||
timestep_dtype = np.float32
|
||||
|
@ -31,7 +32,7 @@ class VAEWrapper(object):
|
|||
self.decoder = decoder
|
||||
|
||||
self.tile_sample_min_size = SAMPLE_SIZE
|
||||
self.tile_latent_min_size = SAMPLE_SIZE
|
||||
self.tile_latent_min_size = LATENT_SIZE
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
|
||||
|
@ -88,7 +89,7 @@ class VAEWrapper(object):
|
|||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
|
||||
tile = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
@ -142,9 +143,10 @@ class VAEWrapper(object):
|
|||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
tile = self.post_quant_conv(tile)
|
||||
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy()))
|
||||
decoded = torch.from_numpy(self.wrapped(latent_sample=tile.numpy())[0])
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
|
|
Loading…
Reference in New Issue