1
0
Fork 0
onnx-web/api/onnx_web/chain/result.py

518 lines
17 KiB
Python
Raw Permalink Normal View History

from json import dumps
from logging import getLogger
from os import path
from re import compile
2024-01-04 04:15:50 +00:00
from typing import Any, List, Optional, Tuple
import numpy as np
2023-11-19 00:13:13 +00:00
from PIL import Image
from ..convert.utils import resolve_tensor
2024-01-03 03:24:27 +00:00
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
2024-01-04 04:15:50 +00:00
from ..server.context import ServerContext
from ..server.load import get_extra_hashes
from ..utils import coalesce, hash_file, load_config_str
# Module-level logger for this file.
logger = getLogger(__name__)
# Matches a signed decimal or scientific-notation number, e.g. "7.5", "-1", "1e-3".
# NOTE: `.match` only anchors at the start of the string, so "1234abcd" matches too;
# use `.fullmatch` to check that an entire value is numeric.
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
class NetworkMetadata:
    """
    Name, hash, and weight of one network used to generate an image.

    Used for the `inversions`, `loras`, and `models` lists of `ImageMetadata`.
    """

    name: str
    hash: str
    weight: float

    def __init__(self, name: str, hash: str, weight: float) -> None:
        self.name, self.hash, self.weight = name, hash, weight
2024-01-03 03:24:27 +00:00
class ImageMetadata:
    """
    Metadata describing how an image was generated.

    Holds the generation params, the input size, the networks (inversions, LoRAs,
    extra models) that were applied, the optional border/highres/upscale stages,
    and the metadata of any ancestor images this one was derived from. It can be
    serialized to a plain-text parameter string (`to_exif`) or to a dict (`tojson`).
    """

    ancestors: List["ImageMetadata"]  # metadata of the image(s) this one was derived from
    note: str  # free-form note, empty by default
    params: ImageParams
    size: Size  # input size, before border/highres/upscale are applied

    # models
    inversions: List[NetworkMetadata]
    loras: List[NetworkMetadata]
    models: List[NetworkMetadata]

    # optional params
    border: Optional[Border]
    highres: Optional[HighresParams]
    upscale: Optional[UpscaleParams]

    # One "Key: value" pair from the EXIF params line. Values may be double-quoted
    # strings or flat {...} JSON objects, both of which can contain commas and
    # colons (to_exif writes both kinds), so a plain split on "," and ":" is not enough.
    _EXIF_PARAM_PATTERN = compile(
        r'\s*([\w ]+):\s*("(?:\\.|[^\\"])*"|\{[^}]*\}|[^,]*)(?:,|$)'
    )

    @staticmethod
    def unknown_image() -> "ImageMetadata":
        """Create placeholder metadata for an image whose origin is not known."""
        UNKNOWN_STR = "unknown"
        return ImageMetadata(
            ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
            Size(0, 0),
        )

    def __init__(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
        ancestors: Optional[List["ImageMetadata"]] = None,
    ) -> None:
        """
        Args:
            params: image generation params.
            size: input size, before border/highres/upscale are applied.
            upscale, border, highres: optional stages that change the output size.
            inversions, loras, models: networks used; None becomes an empty list.
            ancestors: metadata of the image(s) this one was derived from.
        """
        self.params = params
        self.size = size
        self.upscale = upscale
        self.border = border
        self.highres = highres
        self.inversions = inversions or []
        self.loras = loras or []
        self.models = models or []
        self.ancestors = ancestors or []
        self.note = ""

    def child(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
    ) -> "ImageMetadata":
        """Create metadata for an image derived from this one, with `self` as its only ancestor."""
        return ImageMetadata(
            params,
            size,
            upscale,
            border,
            highres,
            inversions,
            loras,
            models,
            [self],
        )

    @staticmethod
    def _network_name_weight(network: Any) -> Tuple[str, float]:
        """
        Return (name, weight) for one network entry.

        Entries are declared as NetworkMetadata, but the serializers have always
        unpacked them as (name, weight) pairs, so both forms are accepted.
        """
        if isinstance(network, NetworkMetadata):
            return (network.name, network.weight)

        name, weight = network
        return (name, weight)

    def _network_hashes(
        self, server: ServerContext, networks: Optional[List[Any]], network_type: str
    ) -> List[Tuple[str, str]]:
        """Return (name, hash) pairs for each entry in `networks`, in order."""
        return [
            self.get_network_hash(
                server, self._network_name_weight(network)[0], network_type
            )
            for network in networks or []
        ]

    def get_model_hash(
        self, server: ServerContext, model: Optional[str] = None
    ) -> Tuple[str, str]:
        """
        Look up the hash of a model, defaulting to the model in `self.params`.

        The hash comes from, in order: the server's hash cache (keyed by the model's
        base name), `get_extra_hashes()`, or the first line of `hash.txt` inside the
        model path (trailing commas, periods, and whitespace removed). If none of
        those has it, the hash is "unknown". The result is stored in the cache.

        Returns:
            (model_name, model_hash), where model_name is the base name of the model path.
        """
        model_path = model or self.params.model
        model_name = path.basename(path.normpath(model_path))
        logger.debug("getting model hash for %s", model_name)

        if model_name in server.hash_cache:
            logger.debug("using cached model hash for %s", model_name)
            return (model_name, server.hash_cache[model_name])

        model_hash = get_extra_hashes().get(model_name, None)
        if model_hash is None:
            # read hash.txt from the requested model, not always the primary model
            model_hash_path = path.join(model_path, "hash.txt")
            if path.exists(model_hash_path):
                with open(model_hash_path, "r") as f:
                    model_hash = f.readline().rstrip(",. \n\t\r")

        model_hash = model_hash or "unknown"
        server.hash_cache[model_name] = model_hash
        return (model_name, model_hash)

    def get_network_hash(
        self, server: ServerContext, network_name: str, network_type: str
    ) -> Tuple[str, str]:
        """
        Hash a network file under `<model_path>/<network_type>/<network_name>`.

        The hash is upper-cased and cached by resolved file path.

        Returns:
            (network_name, network_hash)
        """
        # run this again just in case the file path changes
        network_path = resolve_tensor(
            path.join(server.model_path, network_type, network_name)
        )

        if network_path in server.hash_cache:
            logger.debug("using cached network hash for %s", network_path)
            return (network_name, server.hash_cache[network_path])

        network_hash = hash_file(network_path).upper()
        server.hash_cache[network_path] = network_hash
        return (network_name, network_hash)

    def to_exif(self, server: ServerContext) -> str:
        """
        Serialize to a plain-text parameter string.

        The string is a prompt line, then a "Negative prompt:" line, then one line
        of comma-separated "Key: value" params that includes model and network hashes.
        """
        model_name, model_hash = self.get_model_hash(server)
        hash_map = {
            model_name: model_hash,
        }

        inversion_pairs = self._network_hashes(server, self.inversions, "inversion")
        inversion_hashes = ",".join(
            f"{name}: {network_hash}" for name, network_hash in inversion_pairs
        )
        hash_map.update(dict(inversion_pairs))

        lora_pairs = self._network_hashes(server, self.loras, "lora")
        lora_hashes = ",".join(
            f"{name}: {network_hash}" for name, network_hash in lora_pairs
        )
        hash_map.update(dict(lora_pairs))

        return (
            f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
            f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
            f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
            f"Model hash: {model_hash}, Model: {model_name}, "
            f"Tool: onnx-web, Version: {server.server_version}, "
            f'Inversion hashes: "{inversion_hashes}", '
            f'Lora hashes: "{lora_hashes}", '
            f"Hashes: {dumps(hash_map)}"
        )

    def tojson(self, server: ServerContext, output: List[str]) -> dict:
        """
        Serialize to a JSON-compatible dict.

        The dict contains the input and output sizes, the output paths, the params,
        and the hashed inversions, LoRAs, and models. The primary model from
        `self.params` is always listed first in "models", with weight 1.0.
        """
        data = {
            "input_size": self.size.tojson(),
            "outputs": output,
            "params": self.params.tojson(),
            "inversions": [],
            "loras": [],
            "models": [],
        }

        # fix up some fields
        model_name, model_hash = self.get_model_hash(server, self.params.model)
        data["params"]["model"] = model_name
        data["models"].append(
            {
                "hash": model_hash,
                "name": model_name,
                "weight": 1.0,
            }
        )

        # add optional params
        if self.border is not None:
            data["border"] = self.border.tojson()

        if self.highres is not None:
            data["highres"] = self.highres.tojson()

        if self.upscale is not None:
            data["upscale"] = self.upscale.tojson()

        # calculate final output size
        data["size"] = self.get_output_size().tojson()

        # hash and add models and networks
        for network in self.inversions or []:
            name, weight = self._network_name_weight(network)
            network_hash = self.get_network_hash(server, name, "inversion")[1]
            data["inversions"].append(
                {"name": name, "weight": weight, "hash": network_hash}
            )

        for network in self.loras or []:
            name, weight = self._network_name_weight(network)
            network_hash = self.get_network_hash(server, name, "lora")[1]
            data["loras"].append({"name": name, "weight": weight, "hash": network_hash})

        for model in self.models or []:
            name, weight = self._network_name_weight(model)
            # hash this entry's model, not the primary model from self.params
            name, extra_hash = self.get_model_hash(server, name)
            data["models"].append({"name": name, "weight": weight, "hash": extra_hash})

        return data

    def get_output_size(self) -> Size:
        """Return the input size after applying border, then highres, then upscale."""
        output_size = self.size
        if self.border is not None:
            output_size = output_size.add_border(self.border)

        if self.highres is not None:
            output_size = self.highres.resize(output_size)

        if self.upscale is not None:
            output_size = self.upscale.resize(output_size)

        return output_size

    def with_args(
        self,
        params: Optional[ImageParams] = None,
        size: Optional[Size] = None,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
        ancestors: Optional[List["ImageMetadata"]] = None,
        **kwargs,
    ) -> "ImageMetadata":
        """
        Copy this metadata, replacing the fields that are passed.

        `params` and `size` are replaced when truthy. The other fields are merged with
        `coalesce` (presumably the first non-None value wins). Unrecognized kwargs
        are logged and ignored.
        """
        if kwargs:
            logger.info("ignoring extra kwargs for metadata: %s", kwargs)

        return ImageMetadata(
            params or self.params,
            size or self.size,
            upscale=coalesce(upscale, self.upscale),
            border=coalesce(border, self.border),
            highres=coalesce(highres, self.highres),
            inversions=coalesce(inversions, self.inversions),
            loras=coalesce(loras, self.loras),
            models=coalesce(models, self.models),
            ancestors=coalesce(ancestors, self.ancestors),
        )

    @staticmethod
    def from_exif(input: str) -> "ImageMetadata":
        """
        Parse the text format written by `to_exif`.

        The format is a prompt line, an optional "Negative prompt:" line, then
        comma-separated "Key: value" params (joined with spaces if they span
        several lines). Keys are lower-cased. Quoted values are unquoted. Decimal
        integers become int and other numbers become float.

        Raises:
            ValueError: if the input is empty or the Size value is malformed.
            KeyError: if the sampler, CFG scale, steps, or seed params are missing.
        """
        lines = input.splitlines()
        if len(lines) == 0:
            raise ValueError("metadata text is empty")

        prompt, *rest = lines

        # process negative prompt, if present
        negative_prompt = None
        if len(rest) > 0 and rest[0].startswith("Negative prompt:"):
            negative_prompt = rest.pop(0)[len("Negative prompt:") :].strip()

        # process other params
        params = {}
        size = None
        for match in ImageMetadata._EXIF_PARAM_PATTERN.finditer(" ".join(rest)):
            key = match.group(1).strip().lower()
            value = match.group(2).strip()

            if len(value) >= 2 and value.startswith('"') and value.endswith('"'):
                value = value[1:-1]

            if key == "size":
                width, height = value.split("x")
                size = Size(int(width.strip()), int(height.strip()))
            elif value.isdecimal():
                value = int(value)
            elif FLOAT_PATTERN.fullmatch(value) is not None:
                # fullmatch, so values like hex hashes "1234abcd" stay strings
                value = float(value)

            params[key] = value

        params = ImageParams(
            "TODO",
            "txt2img",  # TODO: can this be detected?
            params["sampler"],
            prompt,
            params["cfg scale"],
            params["steps"],
            params["seed"],
            negative_prompt,
        )

        # fall back to the same placeholder size as unknown_image
        return ImageMetadata(params, size if size is not None else Size(0, 0))

    @staticmethod
    def from_json(input: str) -> "ImageMetadata":
        """
        Load metadata from a string parsed by `load_config_str`.

        NOTE(review): the parsed values are passed through as-is. "params" and
        "input_size" are not converted to ImageParams/Size, and network entries are
        not converted to NetworkMetadata. Callers may receive plain dicts.
        """
        data = load_config_str(input)
        # TODO: enforce schema
        return ImageMetadata(
            data["params"],
            data["input_size"],
            data.get("upscale", None),
            data.get("border", None),
            data.get("highres", None),
            data.get("inversions", None),
            data.get("loras", None),
            data.get("models", None),
        )
2024-01-13 04:58:52 +00:00
# Error message for the ValueError raised when StageResult data is pushed or inserted without metadata.
ERROR_NO_METADATA = "metadata must be provided"
class StageResult:
    """
    Chain pipeline stage result.

    Can contain PIL images or numpy arrays, with helpers to convert between them.
    Each image or array has a matching ImageMetadata entry at the same index.

    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they are expected.
    """

    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]
    metadata: List[ImageMetadata]

    # output paths, filled in when the result is saved
    outputs: Optional[List[str]]
    thumbnails: Optional[List[str]]

    @staticmethod
    def empty():
        """Create a result with an empty image list and no metadata."""
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
        """Create a result holding numpy arrays."""
        return StageResult(arrays=arrays, metadata=metadata)

    @staticmethod
    def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
        """Create a result holding PIL images."""
        return StageResult(images=images, metadata=metadata)

    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
        metadata: Optional[List[ImageMetadata]] = None,  # TODO: should not be optional
        source: Optional[Any] = None,
    ) -> None:
        """
        Exactly one of `arrays`, `images`, or `source` must be provided.

        When `source` is given (presumably another StageResult), its arrays, images,
        and metadata lists are shared rather than copied, and `metadata` is ignored.

        Raises:
            ValueError: if none, or more than one, of the data arguments is provided.
        """
        data_provided = sum(
            [arrays is not None, images is not None, source is not None]
        )
        if data_provided > 1:
            raise ValueError("results must only contain one type of data")
        elif data_provided == 0:
            raise ValueError("results must contain some data")

        self.outputs = None
        self.thumbnails = None

        if source is not None:
            self.arrays = source.arrays
            self.images = source.images
            self.metadata = source.metadata
        else:
            self.arrays = arrays
            self.images = images
            self.metadata = metadata or []

    def __len__(self) -> int:
        """Return the number of arrays or images held."""
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_arrays(self) -> List[np.ndarray]:
        """Return the data as numpy arrays, converting from images if needed."""
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_images(self) -> List[Image.Image]:
        """Return the data as PIL images, converting from arrays if needed."""
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []

    def push_array(self, array: np.ndarray, metadata: ImageMetadata):
        """
        Append an array and its metadata, converting the array to an image if this result holds images.

        Raises:
            ValueError: if `metadata` is None, or the array shape has no image mode.
        """
        # check metadata first, so a rejected push leaves data and metadata in sync
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.arrays is not None:
            self.arrays.append(array)
        elif self.images is not None:
            self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
        else:
            self.arrays = [array]

        self.metadata.append(metadata)

    def push_image(self, image: Image.Image, metadata: ImageMetadata):
        """
        Append an image and its metadata, converting the image to an array if this result holds arrays.

        Raises:
            ValueError: if `metadata` is None.
        """
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.images is not None:
            self.images.append(image)
        elif self.arrays is not None:
            self.arrays.append(np.array(image))
        else:
            self.images = [image]

        self.metadata.append(metadata)

    def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
        """
        Insert an array and its metadata at `index`, converting to an image if this result holds images.

        Raises:
            ValueError: if `metadata` is None, or the array shape has no image mode.
        """
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.arrays is not None:
            self.arrays.insert(index, array)
        elif self.images is not None:
            self.images.insert(
                index, Image.fromarray(np.uint8(array), shape_mode(array))
            )
        else:
            self.arrays = [array]

        self.metadata.insert(index, metadata)

    def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
        """
        Insert an image and its metadata at `index`, converting to an array if this result holds arrays.

        Raises:
            ValueError: if `metadata` is None.
        """
        if metadata is None:
            raise ValueError(ERROR_NO_METADATA)

        if self.images is not None:
            self.images.insert(index, image)
        elif self.arrays is not None:
            self.arrays.insert(index, np.array(image))
        else:
            self.images = [image]

        self.metadata.insert(index, metadata)

    def size(self) -> Size:
        """Return the largest width and largest height across all entries, or 0x0 when empty."""
        if self.images is not None:
            return Size(
                max([image.width for image in self.images], default=0),
                max([image.height for image in self.images], default=0),
            )
        elif self.arrays is not None:
            # arrays are channels-last (height, width, channels), the layout that
            # shape_mode and Image.fromarray expect: width is axis 1, height is axis 0
            return Size(
                max([array.shape[1] for array in self.arrays], default=0),
                max([array.shape[0] for array in self.arrays], default=0),
            )
        else:
            return Size(0, 0)

    def validate(self) -> None:
        """
        Make sure the data exists and that data and metadata match in length.
        """
        if self.arrays is None and self.images is None:
            raise ValueError("no data in result")

        if len(self) != len(self.metadata):
            raise ValueError("metadata and data do not match in length")
def shape_mode(arr: np.ndarray) -> str:
    """
    Pick the PIL image mode for a channels-last image array.

    Raises:
        ValueError: if the array is not 3-D, or its last axis is not 3 (RGB) or 4 (RGBA).
    """
    rank = len(arr.shape)
    if rank == 3:
        channels = arr.shape[-1]
        modes = {3: "RGB", 4: "RGBA"}
        if channels in modes:
            return modes[channels]
        raise ValueError("unknown image format")

    raise ValueError("unknown array format")