# onnx-web/api/onnx_web/chain/result.py
from json import dumps
from logging import getLogger
from os import path
from typing import Any, List, Optional, Tuple
import numpy as np
from PIL import Image
from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext
from ..server.load import get_extra_hashes
from ..utils import hash_file
logger = getLogger(__name__)
class NetworkMetadata:
    """Name, content hash, and blending weight for an extra network."""

    name: str
    hash: str
    weight: float

    def __init__(self, name: str, hash: str, weight: float) -> None:
        self.name, self.hash, self.weight = name, hash, weight
class ImageMetadata:
    """
    Metadata describing how an image was generated: the pipeline parameters,
    the networks blended into the model, and the metadata of any ancestor
    images that this image was derived from.
    """

    # metadata of the images this one was derived from, for chained stages
    ancestors: List["ImageMetadata"]
    params: ImageParams
    size: Size

    # models
    inversions: List[NetworkMetadata]
    loras: List[NetworkMetadata]
    models: List[NetworkMetadata]

    # optional params
    border: Optional[Border]
    highres: Optional[HighresParams]
    upscale: Optional[UpscaleParams]

    @staticmethod
    def unknown_image() -> "ImageMetadata":
        """Return placeholder metadata for an image whose parameters are not known."""
        UNKNOWN_STR = "unknown"
        return ImageMetadata(
            ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
            Size(0, 0),
        )

    def __init__(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
        ancestors: Optional[List["ImageMetadata"]] = None,
    ) -> None:
        self.params = params
        self.size = size
        self.upscale = upscale
        self.border = border
        self.highres = highres

        # normalize the optional lists so the other methods can iterate them
        self.inversions = inversions or []
        self.loras = loras or []
        self.models = models or []
        self.ancestors = ancestors or []

    def child(
        self,
        params: ImageParams,
        size: Size,
        upscale: Optional[UpscaleParams] = None,
        border: Optional[Border] = None,
        highres: Optional[HighresParams] = None,
        inversions: Optional[List[NetworkMetadata]] = None,
        loras: Optional[List[NetworkMetadata]] = None,
        models: Optional[List[NetworkMetadata]] = None,
    ) -> "ImageMetadata":
        """Create metadata for a derived image, recording this one as its ancestor."""
        return ImageMetadata(
            params,
            size,
            upscale,
            border,
            highres,
            inversions,
            loras,
            models,
            [self],
        )

    def get_model_hash(
        self, server: ServerContext, model: Optional[str] = None
    ) -> Tuple[str, str]:
        """
        Return a (name, hash) pair for the given model path, falling back to the
        diffusion model from the params.

        The hash is looked up in the server's cache first, then the extra hashes
        file, then a hash.txt file inside the model folder, defaulting to
        "unknown". The result is cached under the model's base name.
        """
        model_path = model or self.params.model
        model_name = path.basename(path.normpath(model_path))
        logger.debug("getting model hash for %s", model_name)

        if model_name in server.hash_cache:
            logger.debug("using cached model hash for %s", model_name)
            return (model_name, server.hash_cache[model_name])

        model_hash = get_extra_hashes().get(model_name, None)
        if model_hash is None:
            # look for a hash.txt in the model folder; this previously always
            # used self.params.model, ignoring the model argument
            model_hash_path = path.join(model_path, "hash.txt")
            if path.exists(model_hash_path):
                with open(model_hash_path, "r") as f:
                    model_hash = f.readline().rstrip(",. \n\t\r")

        model_hash = model_hash or "unknown"
        server.hash_cache[model_name] = model_hash

        return (model_name, model_hash)

    def get_network_hash(
        self, server: ServerContext, network_name: str, network_type: str
    ) -> Tuple[str, str]:
        """
        Return a (name, hash) pair for an extra network, hashing the tensor file
        on disk and caching the result by resolved path.
        """
        # run this again just in case the file path changes
        network_path = resolve_tensor(
            path.join(server.model_path, network_type, network_name)
        )

        if network_path in server.hash_cache:
            logger.debug("using cached network hash for %s", network_path)
            return (network_name, server.hash_cache[network_path])

        network_hash = hash_file(network_path).upper()
        server.hash_cache[network_path] = network_hash

        return (network_name, network_hash)

    def to_exif(self, server: ServerContext, output: List[str]) -> str:
        """
        Format the metadata as an A1111-style EXIF parameters string.

        NOTE(review): `output` is unused here; kept for signature parity with
        `tojson`. The inversion/lora lists are unpacked as (name, weight) pairs,
        although they are declared as List[NetworkMetadata] — confirm which
        shape callers actually store.
        """
        model_name, model_hash = self.get_model_hash(server)
        hash_map = {
            model_name: model_hash,
        }

        inversion_hashes = ""
        if self.inversions is not None:
            inversion_pairs = [
                (
                    name,
                    self.get_network_hash(server, name, "inversion")[1],
                )
                for name, _weight in self.inversions
            ]
            inversion_hashes = ",".join(
                [f"{name}: {hash}" for name, hash in inversion_pairs]
            )
            hash_map.update(dict(inversion_pairs))

        lora_hashes = ""
        if self.loras is not None:
            lora_pairs = [
                (
                    name,
                    self.get_network_hash(server, name, "lora")[1],
                )
                for name, _weight in self.loras
            ]
            lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
            hash_map.update(dict(lora_pairs))

        return (
            f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
            f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
            f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
            f"Model hash: {model_hash}, Model: {model_name}, "
            f"Tool: onnx-web, Version: {server.server_version}, "
            f'Inversion hashes: "{inversion_hashes}", '
            f'Lora hashes: "{lora_hashes}", '
            f"Hashes: {dumps(hash_map)}"
        )

    def tojson(self, server: ServerContext, output: List[str]):
        """
        Format the metadata as a JSON-compatible dict, including the output
        filenames, resolved model hashes, and the final output size after
        border, highres, and upscale adjustments.
        """
        json = {
            "input_size": self.size.tojson(),
            "outputs": output,
            "params": self.params.tojson(),
            "inversions": [],
            "loras": [],
            "models": [],
        }

        # fix up some fields
        model_name, model_hash = self.get_model_hash(server, self.params.model)
        json["params"]["model"] = model_name
        json["models"].append(
            {
                "hash": model_hash,
                "name": model_name,
                "weight": 1.0,
            }
        )

        # calculate final output size
        output_size = self.size
        if self.border is not None:
            json["border"] = self.border.tojson()
            output_size = output_size.add_border(self.border)

        if self.highres is not None:
            json["highres"] = self.highres.tojson()
            output_size = self.highres.resize(output_size)

        if self.upscale is not None:
            json["upscale"] = self.upscale.tojson()
            output_size = self.upscale.resize(output_size)

        json["size"] = output_size.tojson()

        if self.inversions is not None:
            for name, weight in self.inversions:
                hash = self.get_network_hash(server, name, "inversion")[1]
                json["inversions"].append(
                    {"name": name, "weight": weight, "hash": hash}
                )

        if self.loras is not None:
            for name, weight in self.loras:
                hash = self.get_network_hash(server, name, "lora")[1]
                json["loras"].append({"name": name, "weight": weight, "hash": hash})

        if self.models is not None:
            for name, weight in self.models:
                # was self.get_model_hash() with no arguments, which always
                # raised TypeError because server is required
                name, hash = self.get_model_hash(server, name)
                json["models"].append({"name": name, "weight": weight, "hash": hash})

        return json
class StageResult:
    """
    Chain pipeline stage result.
    Can contain PIL images or numpy arrays, with helpers to convert between them.
    This class intentionally does not provide `__iter__`, to ensure clients get results in the format
    they are expected.
    """

    # only one of arrays/images is populated; the as_* helpers convert on demand
    arrays: Optional[List[np.ndarray]]
    images: Optional[List[Image.Image]]
    metadata: List[ImageMetadata]

    @staticmethod
    def empty():
        """Return a result with no data and no metadata."""
        return StageResult(images=[])

    @staticmethod
    def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
        """Create a result holding numpy arrays."""
        return StageResult(arrays=arrays, metadata=metadata)

    @staticmethod
    def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
        """Create a result holding PIL images."""
        return StageResult(images=images, metadata=metadata)

    def __init__(
        self,
        arrays: Optional[List[np.ndarray]] = None,
        images: Optional[List[Image.Image]] = None,
        metadata: Optional[List[ImageMetadata]] = None,  # TODO: should not be optional
        source: Optional[Any] = None,
    ) -> None:
        # exactly one of arrays, images, or source must be provided
        data_provided = sum(
            [arrays is not None, images is not None, source is not None]
        )
        if data_provided > 1:
            raise ValueError("results must only contain one type of data")
        elif data_provided == 0:
            raise ValueError("results must contain some data")

        if source is not None:
            # copy data and metadata from an existing result
            self.arrays = source.arrays
            self.images = source.images
            self.metadata = source.metadata
        else:
            self.arrays = arrays
            self.images = images
            self.metadata = metadata or []

    def __len__(self) -> int:
        if self.arrays is not None:
            return len(self.arrays)
        elif self.images is not None:
            return len(self.images)
        else:
            return 0

    def as_arrays(self) -> List[np.ndarray]:
        """Return the results as numpy arrays, converting from PIL images if needed."""
        if self.arrays is not None:
            return self.arrays
        elif self.images is not None:
            return [np.array(i) for i in self.images]
        else:
            return []

    def as_images(self) -> List[Image.Image]:
        """Return the results as PIL images, converting from numpy arrays if needed."""
        if self.images is not None:
            return self.images
        elif self.arrays is not None:
            return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
        else:
            return []

    def push_array(self, array: np.ndarray, metadata: ImageMetadata):
        """Append an array and its metadata, converting to match the existing data type."""
        if self.arrays is not None:
            self.arrays.append(array)
        elif self.images is not None:
            self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
        else:
            self.arrays = [array]

        if metadata is not None:
            self.metadata.append(metadata)
        else:
            raise ValueError("metadata must be provided")

    def push_image(self, image: Image.Image, metadata: ImageMetadata):
        """Append an image and its metadata, converting to match the existing data type."""
        if self.images is not None:
            self.images.append(image)
        elif self.arrays is not None:
            self.arrays.append(np.array(image))
        else:
            self.images = [image]

        if metadata is not None:
            self.metadata.append(metadata)
        else:
            raise ValueError("metadata must be provided")

    def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
        """Insert an array and its metadata at the given index."""
        if self.arrays is not None:
            self.arrays.insert(index, array)
        elif self.images is not None:
            self.images.insert(
                index, Image.fromarray(np.uint8(array), shape_mode(array))
            )
        else:
            self.arrays = [array]

        if metadata is not None:
            self.metadata.insert(index, metadata)
        else:
            raise ValueError("metadata must be provided")

    def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
        """Insert an image and its metadata at the given index."""
        if self.images is not None:
            self.images.insert(index, image)
        elif self.arrays is not None:
            self.arrays.insert(index, np.array(image))
        else:
            self.images = [image]

        if metadata is not None:
            self.metadata.insert(index, metadata)
        else:
            raise ValueError("metadata must be provided")

    def size(self) -> Size:
        """Return the size of the largest result, as (width, height)."""
        if self.images is not None:
            return Size(
                max([image.width for image in self.images], default=0),
                max([image.height for image in self.images], default=0),
            )
        elif self.arrays is not None:
            # numpy arrays converted from PIL images are (height, width, channels),
            # so shape[1] is the width and shape[0] is the height; an earlier
            # version had these reversed
            return Size(
                max([array.shape[1] for array in self.arrays], default=0),
                max([array.shape[0] for array in self.arrays], default=0),
            )
        else:
            return Size(0, 0)

    def validate(self) -> None:
        """
        Make sure the data exists and that data and metadata match in length.
        """
        if self.arrays is None and self.images is None:
            raise ValueError("no data in result")

        if len(self) != len(self.metadata):
            raise ValueError("metadata and data do not match in length")
def shape_mode(arr: np.ndarray) -> str:
    """Map a numpy image array's channel count to a PIL mode string.

    Raises ValueError for arrays that are not 3-dimensional or whose last
    axis is not 3 (RGB) or 4 (RGBA) channels.
    """
    if arr.ndim != 3:
        raise ValueError("unknown array format")

    channels = arr.shape[-1]
    if channels == 3:
        return "RGB"

    if channels == 4:
        return "RGBA"

    raise ValueError("unknown image format")