1
0
Fork 0

add image metadata to stage results

This commit is contained in:
Sean Sube 2024-01-02 21:24:27 -06:00
parent 0fa7eff8a8
commit 4edd39740b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 62 additions and 0 deletions

View File

@ -3,6 +3,41 @@ from typing import List, Optional
import numpy as np
from PIL import Image
from ..output import json_params
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
class ImageMetadata:
border: Border
highres: HighresParams
params: ImageParams
size: Size
upscale: UpscaleParams
def __init__(
self,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
) -> None:
self.params = params
self.size = size
self.upscale = upscale
self.border = border
self.highres = highres
def tojson(self):
return json_params(
[],
self.params,
self.size,
upscale=self.upscale,
border=self.border,
highres=self.highres,
)
class StageResult:
"""
@ -14,6 +49,7 @@ class StageResult:
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]]
metadata: List[ImageMetadata]
@staticmethod
def empty():
@ -60,6 +96,32 @@ class StageResult:
else:
return []
def push_array(self, array: np.ndarray, metadata: Optional[ImageMetadata]):
if self.arrays is not None:
self.arrays.append(array)
elif self.images is not None:
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
else:
raise ValueError("invalid stage result")
if metadata is not None:
self.metadata.append(metadata)
else:
self.metadata.append(ImageMetadata())
def push_image(self, image: Image.Image, metadata: Optional[ImageMetadata]):
if self.images is not None:
self.images.append(image)
elif self.arrays is not None:
self.arrays.append(np.array(image))
else:
raise ValueError("invalid stage result")
if metadata is not None:
self.metadata.append(metadata)
else:
self.metadata.append(ImageMetadata())
def shape_mode(arr: np.ndarray) -> str:
if len(arr.shape) != 3: