add image metadata to stage results
This commit is contained in:
parent
0fa7eff8a8
commit
4edd39740b
|
@ -3,6 +3,41 @@ from typing import List, Optional
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..output import json_params
|
||||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||
|
||||
|
||||
class ImageMetadata:
|
||||
border: Border
|
||||
highres: HighresParams
|
||||
params: ImageParams
|
||||
size: Size
|
||||
upscale: UpscaleParams
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
) -> None:
|
||||
self.params = params
|
||||
self.size = size
|
||||
self.upscale = upscale
|
||||
self.border = border
|
||||
self.highres = highres
|
||||
|
||||
def tojson(self):
|
||||
return json_params(
|
||||
[],
|
||||
self.params,
|
||||
self.size,
|
||||
upscale=self.upscale,
|
||||
border=self.border,
|
||||
highres=self.highres,
|
||||
)
|
||||
|
||||
|
||||
class StageResult:
|
||||
"""
|
||||
|
@ -14,6 +49,7 @@ class StageResult:
|
|||
|
||||
arrays: Optional[List[np.ndarray]]
|
||||
images: Optional[List[Image.Image]]
|
||||
metadata: List[ImageMetadata]
|
||||
|
||||
@staticmethod
|
||||
def empty():
|
||||
|
@ -60,6 +96,32 @@ class StageResult:
|
|||
else:
|
||||
return []
|
||||
|
||||
def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
|
||||
if self.arrays is not None:
|
||||
self.arrays.append(array)
|
||||
elif self.images is not None:
|
||||
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
|
||||
else:
|
||||
raise ValueError("invalid stage result")
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
else:
|
||||
self.metadata.append(ImageMetadata())
|
||||
|
||||
def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
|
||||
if self.images is not None:
|
||||
self.images.append(image)
|
||||
elif self.arrays is not None:
|
||||
self.arrays.append(np.array(image))
|
||||
else:
|
||||
raise ValueError("invalid stage result")
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
else:
|
||||
self.metadata.append(ImageMetadata())
|
||||
|
||||
|
||||
def shape_mode(arr: np.ndarray) -> str:
|
||||
if len(arr.shape) != 3:
|
||||
|
|
Loading…
Reference in New Issue