# onnx-web/api/onnx_web/chain/result.py

from typing import List, Optional
import numpy as np
from PIL import Image
class StageResult:
    """
    Chain pipeline stage result.

    Can contain PIL images or numpy arrays, with helpers to convert between them.

    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they expect.
    """

    # The constructor requires exactly one of these to be provided; the other stays None.
    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]

    @staticmethod
    def empty() -> "StageResult":
        """Return a result backed by an empty list of images."""
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray]) -> "StageResult":
        """Return a result backed by the given list of numpy arrays."""
        return StageResult(arrays=arrays)

    @staticmethod
    def from_images(images: List[Image.Image]) -> "StageResult":
        """Return a result backed by the given list of PIL images."""
        return StageResult(images=images)

    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
    ) -> None:
        """
        Store exactly one kind of result.

        Raises:
            ValueError: if both `arrays` and `images` are given, or if neither is.
        """
        if arrays is not None and images is not None:
            raise ValueError("stages must only return one type of result")
        elif arrays is None and images is None:
            raise ValueError("stages must return results")

        self.arrays = arrays
        self.images = images

    def __len__(self) -> int:
        """Return the number of stored results, whichever format they are held in."""
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            # Only reachable if both attributes were reset to None after construction.
            return 0

    def as_numpy(self) -> List[np.ndarray]:
        """
        Return the results as numpy arrays.

        Stored arrays are returned as the same list object (no copy); stored images are
        converted with `np.array` into a new list.
        """
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            # Only reachable if both attributes were reset to None after construction.
            return []

    def as_image(self) -> List[Image.Image]:
        """
        Return the results as PIL images.

        Stored images are returned as the same list object (no copy). Stored arrays are cast
        to uint8 and converted with `Image.fromarray`, using the mode that `shape_mode` picks
        from each array's shape; `shape_mode` raises ValueError for shapes it does not know.
        """
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            # NOTE(review): the uint8 cast does no rescaling or clipping - confirm that
            # callers only pass arrays whose values already fit the 0-255 range.
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            # Only reachable if both attributes were reset to None after construction.
            return []
def shape_mode(arr: np.ndarray) -> str:
    """
    Pick the PIL image mode name that matches an array's shape.

    Args:
        arr: array with exactly three dimensions whose last dimension has size 3 or 4.

    Returns:
        "RGB" when the last dimension has size 3, "RGBA" when it has size 4.

    Raises:
        ValueError: if the array does not have exactly three dimensions, or if its last
            dimension has any other size.
    """
    if len(arr.shape) != 3:
        raise ValueError("unknown array format")

    # Size of the last dimension -> PIL mode name.
    modes = {3: "RGB", 4: "RGBA"}
    mode = modes.get(arr.shape[-1])
    if mode is None:
        raise ValueError("unknown image format")

    return mode