1
0
Fork 0
onnx-web/api/onnx_web/chain/result.py

136 lines
3.8 KiB
Python

from typing import List, Optional
import numpy as np
from PIL import Image
from ..output import json_params
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
class ImageMetadata:
    """
    Bundle of the parameter objects recorded alongside an image.

    The image parameters and size are always present. The upscale, border, and highres parameters
    are optional and are stored as `None` when they are not provided.
    """

    # These three are Optional because the constructor accepts and stores None for them.
    border: Optional[Border]
    highres: Optional[HighresParams]
    params: ImageParams
    size: Size
    upscale: Optional[UpscaleParams]

    def __init__(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
    ) -> None:
        """
        Store the given parameter objects on the instance without copying or validating them.

        :param params: image parameters, required.
        :param size: image size, required.
        :param upscale: upscale parameters, or None.
        :param border: border parameters, or None.
        :param highres: highres parameters, or None.
        """
        self.params = params
        self.size = size
        self.upscale = upscale
        self.border = border
        self.highres = highres

    def tojson(self):
        """
        Convert the metadata using `json_params` and return whatever that helper returns.

        The first positional argument is always an empty list.
        NOTE(review): presumably that argument is the list of outputs, which is not known at this
        level - confirm against `json_params`.
        """
        return json_params(
            [],
            self.params,
            self.size,
            upscale=self.upscale,
            border=self.border,
            highres=self.highres,
        )
class StageResult:
    """
    Chain pipeline stage result.

    Can contain PIL images or numpy arrays, with helpers to convert between them.

    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they are expected.

    The constructor requires exactly one of `arrays` or `images`; the other attribute is left as
    `None`. Metadata is kept in a separate list that the `push_*` methods extend by one entry for
    every item they add.
    """

    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]
    metadata: List[ImageMetadata]

    @staticmethod
    def empty():
        """Create an image-backed result with no images and no metadata."""
        return StageResult(images=[])

    @staticmethod
    def from_arrays(
        arrays: List[np.ndarray],
        metadata: Optional[List[ImageMetadata]] = None,
    ):
        """Create an array-backed result, optionally with an existing metadata list."""
        return StageResult(arrays=arrays, metadata=metadata)

    @staticmethod
    def from_images(
        images: List[Image.Image],
        metadata: Optional[List[ImageMetadata]] = None,
    ):
        """Create an image-backed result, optionally with an existing metadata list."""
        return StageResult(images=images, metadata=metadata)

    def __init__(self, arrays=None, images=None, metadata=None) -> None:
        """
        Store the given list by reference.

        :param arrays: list of numpy arrays, mutually exclusive with `images`.
        :param images: list of PIL images, mutually exclusive with `arrays`.
        :param metadata: optional list of metadata to start from; a new empty list is used when
            omitted. Its length is not checked against the arrays or images.
        :raises ValueError: when both or neither of `arrays` and `images` are provided.
        """
        if arrays is not None and images is not None:
            raise ValueError("stages must only return one type of result")
        elif arrays is None and images is None:
            raise ValueError("stages must return results")

        self.arrays = arrays
        self.images = images
        # The class-level annotation alone does not create this attribute. Without the assignment,
        # every push_array/push_image call fails with AttributeError.
        self.metadata = metadata if metadata is not None else []

    def __len__(self) -> int:
        """Return the number of stored arrays or images."""
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_numpy(self) -> List[np.ndarray]:
        """
        Return the results as numpy arrays.

        An array-backed result returns its own list; an image-backed result returns a new list with
        each image converted through `np.array`.
        """
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_image(self) -> List[Image.Image]:
        """
        Return the results as PIL images.

        An image-backed result returns its own list; an array-backed result returns a new list with
        each array cast to uint8 and converted using the mode chosen by `shape_mode`.
        """
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []

    def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
        """
        Append an array and its metadata, converting the array to an image for image-backed results.

        All checks and conversions run before anything is appended, so a failed call leaves the
        result unchanged.

        :raises ValueError: when the result has neither arrays nor images, when `metadata` is None,
            or when `shape_mode` rejects the array during conversion.
        """
        if self.arrays is None and self.images is None:
            raise ValueError("invalid stage result")

        # A placeholder cannot be built here: ImageMetadata requires params and size.
        if metadata is None:
            raise ValueError("metadata must be provided")

        if self.arrays is not None:
            self.arrays.append(array)
        else:
            self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))

        self.metadata.append(metadata)

    def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
        """
        Append an image and its metadata, converting the image to an array for array-backed results.

        All checks run before anything is appended, so a failed call leaves the result unchanged.

        :raises ValueError: when the result has neither arrays nor images, or when `metadata` is
            None.
        """
        if self.arrays is None and self.images is None:
            raise ValueError("invalid stage result")

        # A placeholder cannot be built here: ImageMetadata requires params and size.
        if metadata is None:
            raise ValueError("metadata must be provided")

        if self.images is not None:
            self.images.append(image)
        else:
            self.arrays.append(np.array(image))

        self.metadata.append(metadata)
def shape_mode(arr: np.ndarray) -> str:
    """
    Pick the PIL mode string that matches the length of the last axis of an array.

    Only arrays with exactly three dimensions are accepted. A last axis of length 3 maps to "RGB"
    and a last axis of length 4 maps to "RGBA".

    :raises ValueError: "unknown array format" when the array does not have three dimensions, or
        "unknown image format" when the last axis has any other length.
    """
    dims = arr.shape
    if len(dims) != 3:
        raise ValueError("unknown array format")

    mode = {3: "RGB", 4: "RGBA"}.get(dims[-1])
    if mode is None:
        raise ValueError("unknown image format")

    return mode