2023-11-18 23:18:23 +00:00
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
import numpy as np
|
2023-11-19 00:13:13 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
2023-11-18 23:18:23 +00:00
|
|
|
|
|
|
|
class StageResult:
    """
    Chain pipeline stage result.

    Can contain PIL images or numpy arrays, with helpers to convert between them.
    The constructor requires exactly one of ``arrays`` / ``images`` to be given.

    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they expect.
    """

    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]

    # Alternate constructors are classmethods so that subclasses build instances
    # of their own type rather than a hard-coded StageResult.
    @classmethod
    def empty(cls) -> "StageResult":
        """Return a result holding an empty list of images."""
        return cls(images=[])

    @classmethod
    def from_arrays(cls, arrays: List[np.ndarray]) -> "StageResult":
        """Wrap a list of numpy arrays as a stage result."""
        return cls(arrays=arrays)

    @classmethod
    def from_images(cls, images: List[Image.Image]) -> "StageResult":
        """Wrap a list of PIL images as a stage result."""
        return cls(images=images)

    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
    ) -> None:
        """
        Store the stage output.

        :param arrays: numpy array results, or None.
        :param images: PIL image results, or None.
        :raises ValueError: if both are given, or if neither is given.
        """
        if arrays is not None and images is not None:
            raise ValueError("stages must only return one type of result")
        elif arrays is None and images is None:
            raise ValueError("stages must return results")

        self.arrays = arrays
        self.images = images

    def __len__(self) -> int:
        """Return the number of results, whichever representation is held."""
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            # Only reachable if both attributes were reset to None after construction.
            return 0

    def as_numpy(self) -> List[np.ndarray]:
        """
        Return the results as numpy arrays.

        Stored arrays are returned as-is (the same list object, not a copy);
        stored images are converted with ``np.array``.
        """
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_image(self) -> List[Image.Image]:
        """
        Return the results as PIL images.

        Stored images are returned as-is (the same list object, not a copy).
        Stored arrays are cast with ``np.uint8`` (no rescaling or clipping) and
        converted using the mode chosen by ``shape_mode``, which raises
        ValueError for arrays that are not 3-D with 3 or 4 channels.
        """
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []
|
2023-11-26 00:52:47 +00:00
|
|
|
|
|
|
|
|
|
|
|
def shape_mode(arr: np.ndarray) -> str:
    """
    Pick the PIL image mode matching an array's channel count.

    :param arr: array whose shape must have exactly three dimensions.
    :returns: "RGB" for 3 trailing channels, "RGBA" for 4.
    :raises ValueError: if the array is not 3-D, or the channel count is neither 3 nor 4.
    """
    dims = arr.shape
    if len(dims) != 3:
        raise ValueError("unknown array format")

    mode = {3: "RGB", 4: "RGBA"}.get(dims[-1])
    if mode is None:
        raise ValueError("unknown image format")
    return mode
|