diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 409545b9..ac1243d2 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -231,9 +231,7 @@ class ChainPipeline: ) metadata = stage_sources.metadata - stage_sources = StageResult( - images=stage_results, metadata=metadata - ) + stage_sources = StageResult(images=stage_results, metadata=metadata) else: logger.debug( "image does not contain sources and is within tile size of %s, running stage", diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index b145d91c..0b59f293 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -327,10 +327,14 @@ class StageResult: def size(self) -> Size: if self.images is not None: - return Size(self.images[0].width, self.images[0].height) + return Size( + max([image.width for image in self.images]), + max([image.height for image in self.images]), + ) elif self.arrays is not None: return Size( - self.arrays[0].shape[0], self.arrays[0].shape[1] + max([array.shape[0] for array in self.arrays]), + max([array.shape[1] for array in self.arrays]), ) # TODO: which fields within the shape are width/height? else: return Size(0, 0)