1
0
Fork 0

fix(api): ensure panorama never generates a negative number of views

This commit is contained in:
Sean Sube 2023-08-20 15:07:51 -05:00
parent 8e1f188d8f
commit 944c92b824
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 10 additions and 8 deletions

View File

@ -350,9 +350,17 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
panorama_height /= 8
panorama_width /= 8
num_blocks_height = (panorama_height - window_size) // stride + 1
num_blocks_width = (panorama_width - window_size) // stride + 1
num_blocks_height = abs((panorama_height - window_size) // stride) + 1
num_blocks_width = abs((panorama_width - window_size) // stride) + 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
logger.debug(
"panorama generated %s views, %s by %s blocks",
total_num_blocks,
num_blocks_height,
num_blocks_width,
)
views = []
for i in range(total_num_blocks):
h_start = int((i // num_blocks_width) * stride)
@ -361,12 +369,6 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
w_end = w_start + window_size
views.append((h_start, h_end, w_start, w_end))
logger.debug(
"panorama generated %s views, %s by %s blocks",
total_num_blocks,
num_blocks_height,
num_blocks_width,
)
return views
@torch.no_grad()