diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 813f5863..ce1ee515 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -3,6 +3,41 @@ from typing import List, Optional import numpy as np from PIL import Image +from ..output import json_params +from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams + + +class ImageMetadata: + border: Border + highres: HighresParams + params: ImageParams + size: Size + upscale: UpscaleParams + + def __init__( + self, + params: ImageParams, + size: Size, + upscale: Optional[UpscaleParams] = None, + border: Optional[Border] = None, + highres: Optional[HighresParams] = None, + ) -> None: + self.params = params + self.size = size + self.upscale = upscale + self.border = border + self.highres = highres + + def tojson(self): + return json_params( + [], + self.params, + self.size, + upscale=self.upscale, + border=self.border, + highres=self.highres, + ) + class StageResult: """ @@ -14,6 +49,7 @@ class StageResult: arrays: Optional[List[np.ndarray]] images: Optional[List[Image.Image]] + metadata: List[ImageMetadata] @staticmethod def empty(): @@ -60,6 +96,32 @@ class StageResult: else: return [] + def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]): + if self.arrays is not None: + self.arrays.append(array) + elif self.images is not None: + self.images.append(Image.fromarray(np.uint8(array), shape_mode(array))) + else: + raise ValueError("invalid stage result") + + if metadata is not None: + self.metadata.append(metadata) + else: + self.metadata.append(ImageMetadata()) + + def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]): + if self.images is not None: + self.images.append(image) + elif self.arrays is not None: + self.arrays.append(np.array(image)) + else: + raise ValueError("invalid stage result") + + if metadata is not None: + self.metadata.append(metadata) + else: + self.metadata.append(ImageMetadata()) + def shape_mode(arr: np.ndarray) -> str: if len(arr.shape) != 3: