2024-01-04 01:09:18 +00:00
|
|
|
from json import dumps
|
|
|
|
from logging import getLogger
|
|
|
|
from os import path
|
2024-01-13 16:01:50 +00:00
|
|
|
from re import compile
|
2024-01-04 04:15:50 +00:00
|
|
|
from typing import Any, List, Optional, Tuple
|
2023-11-18 23:18:23 +00:00
|
|
|
|
|
|
|
import numpy as np
|
2023-11-19 00:13:13 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
2024-01-04 01:09:18 +00:00
|
|
|
from ..convert.utils import resolve_tensor
|
2024-01-03 03:24:27 +00:00
|
|
|
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
2024-01-04 04:15:50 +00:00
|
|
|
from ..server.context import ServerContext
|
2024-01-04 01:09:18 +00:00
|
|
|
from ..server.load import get_extra_hashes
|
2024-01-13 16:01:50 +00:00
|
|
|
from ..utils import hash_file, load_config_str
|
2024-01-04 01:09:18 +00:00
|
|
|
|
|
|
|
# module-level logger, named after this module per logging convention
logger = getLogger(__name__)
|
|
|
|
|
2024-01-13 16:01:50 +00:00
|
|
|
# matches a signed decimal number with optional exponent, e.g. "1.5", "-2e-3";
# used by from_exif to detect float-valued params
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
|
|
|
|
|
2024-01-04 01:09:18 +00:00
|
|
|
|
|
|
|
class NetworkMetadata:
    """
    Name, hash, and blending weight of an extra network (LoRA, Textual
    Inversion) or model used to generate an image.
    """

    # display name of the network or model
    name: str
    # content hash, as computed by the hash helpers
    hash: str
    # blending weight applied when the network was used
    weight: float

    def __init__(self, name: str, hash: str, weight: float) -> None:
        self.name = name
        self.hash = hash
        self.weight = weight

    def __iter__(self):
        # allow `for name, weight in networks` unpacking, matching how
        # metadata consumers in this module iterate network lists
        return iter((self.name, self.weight))

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"hash={self.hash!r}, weight={self.weight!r})"
        )
|
2024-01-03 03:24:27 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ImageMetadata:
    """
    Metadata for a generated image: the generation params, input size, the
    networks and models used, and optional border/highres/upscale stages,
    with serialization to and from EXIF parameter strings and JSON.
    """

    # metadata of the images this image was derived from, most recent first
    ancestors: List["ImageMetadata"]
    # free-form note attached to the image
    note: str
    params: ImageParams
    size: Size

    # models
    inversions: List[NetworkMetadata]
    loras: List[NetworkMetadata]
    models: List[NetworkMetadata]

    # optional params
    border: Optional[Border]
    highres: Optional[HighresParams]
    upscale: Optional[UpscaleParams]

    @staticmethod
    def unknown_image() -> "ImageMetadata":
        """Return placeholder metadata for an image of unknown provenance."""
        UNKNOWN_STR = "unknown"
        return ImageMetadata(
            ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
            Size(0, 0),
        )

    def __init__(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
        ancestors: Optional[List["ImageMetadata"]] = None,
    ) -> None:
        self.params = params
        self.size = size
        self.upscale = upscale
        self.border = border
        self.highres = highres
        # network/ancestor lists are always stored as lists, never None
        self.inversions = inversions or []
        self.loras = loras or []
        self.models = models or []
        self.ancestors = ancestors or []
        self.note = ""

    def child(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
    ) -> "ImageMetadata":
        """Create metadata for an image derived from this one, recording
        this metadata as its ancestor."""
        return ImageMetadata(
            params,
            size,
            upscale,
            border,
            highres,
            inversions,
            loras,
            models,
            [self],
        )

    def get_model_hash(
        self, server: ServerContext, model: Optional[str] = None
    ) -> Tuple[str, str]:
        """
        Return (name, hash) for the given model path, defaulting to the
        params' model. Hashes come from the extra-hashes registry or a
        hash.txt file next to the model, and are cached on the server.
        """
        model_path = model or self.params.model
        model_name = path.basename(path.normpath(model_path))
        logger.debug("getting model hash for %s", model_name)

        if model_name in server.hash_cache:
            logger.debug("using cached model hash for %s", model_name)
            return (model_name, server.hash_cache[model_name])

        model_hash = get_extra_hashes().get(model_name, None)
        if model_hash is None:
            # fix: look for hash.txt next to the model being asked about,
            # not always next to the params' model
            model_hash_path = path.join(model_path, "hash.txt")
            if path.exists(model_hash_path):
                with open(model_hash_path, "r") as f:
                    model_hash = f.readline().rstrip(",. \n\t\r")

        model_hash = model_hash or "unknown"
        server.hash_cache[model_name] = model_hash

        return (model_name, model_hash)

    def get_network_hash(
        self, server: ServerContext, network_name: str, network_type: str
    ) -> Tuple[str, str]:
        """
        Return (name, hash) for an extra network, hashing its tensor file
        and caching the result on the server by resolved path.
        """
        # run this again just in case the file path changes
        network_path = resolve_tensor(
            path.join(server.model_path, network_type, network_name)
        )

        if network_path in server.hash_cache:
            logger.debug("using cached network hash for %s", network_path)
            return (network_name, server.hash_cache[network_path])

        network_hash = hash_file(network_path).upper()
        server.hash_cache[network_path] = network_hash

        return (network_name, network_hash)

    def to_exif(self, server: ServerContext) -> str:
        """Format this metadata as an A1111-style EXIF parameter string."""
        model_name, model_hash = self.get_model_hash(server)
        hash_map = {
            model_name: model_hash,
        }

        inversion_hashes = ""
        if self.inversions is not None:
            inversion_pairs = [
                (
                    name,
                    self.get_network_hash(server, name, "inversion")[1],
                )
                for name, _weight in self.inversions
            ]
            inversion_hashes = ",".join(
                [f"{name}: {net_hash}" for name, net_hash in inversion_pairs]
            )
            hash_map.update(dict(inversion_pairs))

        lora_hashes = ""
        if self.loras is not None:
            lora_pairs = [
                (
                    name,
                    self.get_network_hash(server, name, "lora")[1],
                )
                for name, _weight in self.loras
            ]
            lora_hashes = ",".join(
                [f"{name}: {net_hash}" for name, net_hash in lora_pairs]
            )
            hash_map.update(dict(lora_pairs))

        return (
            f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
            f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
            f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
            f"Model hash: {model_hash}, Model: {model_name}, "
            f"Tool: onnx-web, Version: {server.server_version}, "
            f'Inversion hashes: "{inversion_hashes}", '
            f'Lora hashes: "{lora_hashes}", '
            f"Hashes: {dumps(hash_map)}"
        )

    def tojson(self, server: ServerContext, output: List[str]):
        """
        Serialize this metadata to a JSON-compatible dict, resolving model
        and network hashes and computing the final output size.
        """
        json = {
            "input_size": self.size.tojson(),
            "outputs": output,
            "params": self.params.tojson(),
            "inversions": [],
            "loras": [],
            "models": [],
        }

        # fix up some fields
        model_name, model_hash = self.get_model_hash(server, self.params.model)
        json["params"]["model"] = model_name
        json["models"].append(
            {
                "hash": model_hash,
                "name": model_name,
                "weight": 1.0,
            }
        )

        # add optional params
        if self.border is not None:
            json["border"] = self.border.tojson()

        if self.highres is not None:
            json["highres"] = self.highres.tojson()

        if self.upscale is not None:
            json["upscale"] = self.upscale.tojson()

        # calculate final output size
        json["size"] = self.get_output_size().tojson()

        # hash and add models and networks
        if self.inversions is not None:
            for name, weight in self.inversions:
                model_hash = self.get_network_hash(server, name, "inversion")[1]
                json["inversions"].append(
                    {"name": name, "weight": weight, "hash": model_hash}
                )

        if self.loras is not None:
            for name, weight in self.loras:
                model_hash = self.get_network_hash(server, name, "lora")[1]
                json["loras"].append(
                    {"name": name, "weight": weight, "hash": model_hash}
                )

        if self.models is not None:
            for name, weight in self.models:
                # fix: hash the model named in this entry, rather than
                # discarding the name and re-hashing the params' model
                model_name, model_hash = self.get_model_hash(server, name)
                json["models"].append(
                    {"name": model_name, "weight": weight, "hash": model_hash}
                )

        return json

    def get_output_size(self) -> Size:
        """Return the final image size after applying the optional border,
        highres, and upscale stages in that order."""
        output_size = self.size
        if self.border is not None:
            output_size = output_size.add_border(self.border)

        if self.highres is not None:
            output_size = self.highres.resize(output_size)

        if self.upscale is not None:
            output_size = self.upscale.resize(output_size)

        return output_size

    def with_args(
        self,
        params: Optional[ImageParams] = None,
        size: Optional[Size] = None,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
        ancestors: Optional[List["ImageMetadata"]] = None,
    ) -> "ImageMetadata":
        """
        Return a copy with the given fields replaced.

        NOTE: fields fall back with `or`, so explicitly passing a falsy
        value (e.g. an empty list) keeps this metadata's value instead.
        """
        return ImageMetadata(
            params or self.params,
            size or self.size,
            upscale=upscale or self.upscale,
            border=border or self.border,
            highres=highres or self.highres,
            inversions=inversions or self.inversions,
            loras=loras or self.loras,
            models=models or self.models,
            ancestors=ancestors or self.ancestors,
        )

    @staticmethod
    def from_exif(input: str) -> "ImageMetadata":
        """
        Parse metadata from an A1111-style EXIF parameter string.

        NOTE(review): key/value parsing splits on bare "," and ":", so
        values containing those characters (including the hash lists that
        to_exif writes) will not round-trip cleanly — confirm against the
        format written by to_exif.
        """
        lines = input.splitlines()
        prompt, maybe_negative, *rest = lines

        # process negative prompt or put that line back into rest
        if maybe_negative.startswith("Negative prompt:"):
            negative_prompt = maybe_negative[len("Negative prompt:") :]
            negative_prompt = negative_prompt.strip()
        else:
            rest.insert(0, maybe_negative)
            negative_prompt = None

        rest = " ".join(rest)
        other_params = rest.split(",")

        # process other params
        params = {}
        size = None
        for param in other_params:
            key, value = param.split(":")
            key = key.strip().lower()
            value = value.strip()

            if key == "size":
                width, height = value.split("x")
                width = int(width.strip())
                height = int(height.strip())
                size = Size(width, height)
            elif value.isdecimal():
                value = int(value)
            elif FLOAT_PATTERN.match(value) is not None:
                value = float(value)

            params[key] = value

        params = ImageParams(
            "TODO",
            "txt2img",  # TODO: can this be detected?
            params["sampler"],
            prompt,
            params["cfg scale"],
            params["steps"],
            params["seed"],
            negative_prompt,
        )
        return ImageMetadata(params, size)

    @staticmethod
    def from_json(input: str) -> "ImageMetadata":
        """
        Parse metadata from a JSON string.

        NOTE(review): the parsed values are passed through as plain dicts
        rather than being converted to the typed params classes.
        """
        data = load_config_str(input)
        # TODO: enforce schema

        return ImageMetadata(
            data["params"],
            data["input_size"],
            data.get("upscale", None),
            data.get("border", None),
            data.get("highres", None),
            data.get("inversions", None),
            data.get("loras", None),
            data.get("models", None),
        )
|
|
|
|
|
2023-11-18 23:18:23 +00:00
|
|
|
|
2024-01-13 04:58:52 +00:00
|
|
|
# error message used when data is pushed into a StageResult without metadata
ERROR_NO_METADATA = "metadata must be provided"
|
|
|
|
|
|
|
|
|
2023-11-18 23:18:23 +00:00
|
|
|
class StageResult:
    """
    Chain pipeline stage result.

    Can contain PIL images or numpy arrays, with helpers to convert between them.
    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they are expected.
    """

    # exactly one of arrays/images holds the data; the other stays None
    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]
    # one metadata entry per data item, kept in the same order
    metadata: List[ImageMetadata]

    @staticmethod
    def empty():
        """Return a result with no data and no metadata."""
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
        """Build a result from numpy arrays with matching metadata."""
        return StageResult(arrays=arrays, metadata=metadata)

    @staticmethod
    def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
        """Build a result from PIL images with matching metadata."""
        return StageResult(images=images, metadata=metadata)

    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
        metadata: Optional[List[ImageMetadata]] = None,  # TODO: should not be optional
        source: Optional[Any] = None,
    ) -> None:
        """
        Exactly one of arrays, images, or source must be provided.

        Raises ValueError when zero or more than one data source is given.
        """
        data_provided = sum(
            [arrays is not None, images is not None, source is not None]
        )
        if data_provided > 1:
            raise ValueError("results must only contain one type of data")
        elif data_provided == 0:
            raise ValueError("results must contain some data")

        if source is not None:
            # copy data and metadata from an existing result
            self.arrays = source.arrays
            self.images = source.images
            self.metadata = source.metadata
        else:
            self.arrays = arrays
            self.images = images
            self.metadata = metadata or []

    def __len__(self) -> int:
        # number of data items, regardless of storage format
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_arrays(self) -> List[np.ndarray]:
        """Return the data as numpy arrays, converting from images if needed."""
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_images(self) -> List[Image.Image]:
        """Return the data as PIL images, converting from arrays if needed."""
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []

    def push_array(self, array: np.ndarray, metadata: ImageMetadata):
        """Append an array and its metadata, converting to match the
        existing storage format. Raises ValueError if metadata is None."""
        # fix: validate metadata before mutating, so a failed push cannot
        # leave data and metadata with different lengths
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.arrays is not None:
            self.arrays.append(array)
        elif self.images is not None:
            self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
        else:
            self.arrays = [array]

        self.metadata.append(metadata)

    def push_image(self, image: Image.Image, metadata: ImageMetadata):
        """Append an image and its metadata, converting to match the
        existing storage format. Raises ValueError if metadata is None."""
        # fix: validate metadata before mutating, so a failed push cannot
        # leave data and metadata with different lengths
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.images is not None:
            self.images.append(image)
        elif self.arrays is not None:
            self.arrays.append(np.array(image))
        else:
            self.images = [image]

        self.metadata.append(metadata)

    def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
        """Insert an array and its metadata at the given index, converting
        to match the existing storage format. Raises ValueError if metadata
        is None."""
        # fix: validate metadata before mutating, so a failed insert cannot
        # leave data and metadata with different lengths
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.arrays is not None:
            self.arrays.insert(index, array)
        elif self.images is not None:
            self.images.insert(
                index, Image.fromarray(np.uint8(array), shape_mode(array))
            )
        else:
            self.arrays = [array]

        self.metadata.insert(index, metadata)

    def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
        """Insert an image and its metadata at the given index, converting
        to match the existing storage format. Raises ValueError if metadata
        is None."""
        # fix: validate metadata before mutating, so a failed insert cannot
        # leave data and metadata with different lengths
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.images is not None:
            self.images.insert(index, image)
        elif self.arrays is not None:
            self.arrays.insert(index, np.array(image))
        else:
            self.images = [image]

        self.metadata.insert(index, metadata)

    def size(self) -> Size:
        """Return the maximum width and height across all data items."""
        if self.images is not None:
            return Size(
                max([image.width for image in self.images], default=0),
                max([image.height for image in self.images], default=0),
            )
        elif self.arrays is not None:
            # fix: arrays here follow numpy image layout (height, width,
            # channels) — see np.array(Image) in as_arrays/push_image —
            # so width is shape[1] and height is shape[0]
            return Size(
                max([array.shape[1] for array in self.arrays], default=0),
                max([array.shape[0] for array in self.arrays], default=0),
            )
        else:
            return Size(0, 0)

    def validate(self) -> None:
        """
        Make sure the data exists and that data and metadata match in length.
        """
        if self.arrays is None and self.images is None:
            raise ValueError("no data in result")

        if len(self) != len(self.metadata):
            raise ValueError("metadata and data do not match in length")
|
2024-01-04 01:09:18 +00:00
|
|
|
|
2023-11-26 00:52:47 +00:00
|
|
|
|
|
|
|
def shape_mode(arr: np.ndarray) -> str:
    """
    Map an image array's channel count to a PIL mode string.

    Only 3-channel ("RGB") and 4-channel ("RGBA") arrays of rank 3 are
    supported; anything else raises ValueError.
    """
    # require (height, width, channels) layout
    if arr.ndim != 3:
        raise ValueError("unknown array format")

    channels = arr.shape[-1]
    if channels == 3:
        return "RGB"
    if channels == 4:
        return "RGBA"

    raise ValueError("unknown image format")
|