1
0
Fork 0
onnx-web/api/onnx_web/chain/result.py

509 lines
16 KiB
Python

from json import dumps
from logging import getLogger
from os import path
from re import compile
from typing import Any, List, Optional, Tuple
import numpy as np
from PIL import Image
from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.context import ServerContext
from ..server.load import get_extra_hashes
from ..utils import hash_file, load_config_str
logger = getLogger(__name__)
FLOAT_PATTERN = compile(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?")
class NetworkMetadata:
name: str
hash: str
weight: float
def __init__(self, name: str, hash: str, weight: float) -> None:
self.name = name
self.hash = hash
self.weight = weight
class ImageMetadata:
ancestors: List["ImageMetadata"]
note: str
params: ImageParams
size: Size
# models
inversions: List[NetworkMetadata]
loras: List[NetworkMetadata]
models: List[NetworkMetadata]
# optional params
border: Optional[Border]
highres: Optional[HighresParams]
upscale: Optional[UpscaleParams]
@staticmethod
def unknown_image() -> "ImageMetadata":
UNKNOWN_STR = "unknown"
return ImageMetadata(
ImageParams(UNKNOWN_STR, UNKNOWN_STR, UNKNOWN_STR, "", 0, 0, 0),
Size(0, 0),
)
def __init__(
self,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
ancestors: Optional[List["ImageMetadata"]] = None,
) -> None:
self.params = params
self.size = size
self.upscale = upscale
self.border = border
self.highres = highres
self.inversions = inversions or []
self.loras = loras or []
self.models = models or []
self.ancestors = ancestors or []
self.note = ""
def child(
self,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
) -> "ImageMetadata":
return ImageMetadata(
params,
size,
upscale,
border,
highres,
inversions,
loras,
models,
[self],
)
def get_model_hash(
self, server: ServerContext, model: Optional[str] = None
) -> Tuple[str, str]:
model_name = path.basename(path.normpath(model or self.params.model))
logger.debug("getting model hash for %s", model_name)
if model_name in server.hash_cache:
logger.debug("using cached model hash for %s", model_name)
return (model_name, server.hash_cache[model_name])
model_hash = get_extra_hashes().get(model_name, None)
if model_hash is None:
model_hash_path = path.join(self.params.model, "hash.txt")
if path.exists(model_hash_path):
with open(model_hash_path, "r") as f:
model_hash = f.readline().rstrip(",. \n\t\r")
model_hash = model_hash or "unknown"
server.hash_cache[model_name] = model_hash
return (model_name, model_hash)
def get_network_hash(
self, server: ServerContext, network_name: str, network_type: str
) -> Tuple[str, str]:
# run this again just in case the file path changes
network_path = resolve_tensor(
path.join(server.model_path, network_type, network_name)
)
if network_path in server.hash_cache:
logger.debug("using cached network hash for %s", network_path)
return (network_name, server.hash_cache[network_path])
network_hash = hash_file(network_path).upper()
server.hash_cache[network_path] = network_hash
return (network_name, network_hash)
def to_exif(self, server: ServerContext, output: List[str]) -> str:
model_name, model_hash = self.get_model_hash(server)
hash_map = {
model_name: model_hash,
}
inversion_hashes = ""
if self.inversions is not None:
inversion_pairs = [
(
name,
self.get_network_hash(server, name, "inversion")[1],
)
for name, _weight in self.inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))
lora_hashes = ""
if self.loras is not None:
lora_pairs = [
(
name,
self.get_network_hash(server, name, "lora")[1],
)
for name, _weight in self.loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))
return (
f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", '
f"Hashes: {dumps(hash_map)}"
)
def tojson(self, server: ServerContext, output: List[str]):
json = {
"input_size": self.size.tojson(),
"outputs": output,
"params": self.params.tojson(),
"inversions": [],
"loras": [],
"models": [],
}
# fix up some fields
model_name, model_hash = self.get_model_hash(server, self.params.model)
json["params"]["model"] = model_name
json["models"].append(
{
"hash": model_hash,
"name": model_name,
"weight": 1.0,
}
)
# add optional params
if self.border is not None:
json["border"] = self.border.tojson()
if self.highres is not None:
json["highres"] = self.highres.tojson()
if self.upscale is not None:
json["upscale"] = self.upscale.tojson()
# calculate final output size
json["size"] = self.get_output_size().tojson()
# hash and add models and networks
if self.inversions is not None:
for name, weight in self.inversions:
model_hash = self.get_network_hash(server, name, "inversion")[1]
json["inversions"].append(
{"name": name, "weight": weight, "hash": model_hash}
)
if self.loras is not None:
for name, weight in self.loras:
model_hash = self.get_network_hash(server, name, "lora")[1]
json["loras"].append(
{"name": name, "weight": weight, "hash": model_hash}
)
if self.models is not None:
for name, weight in self.models:
name, model_hash = self.get_model_hash(server)
json["models"].append(
{"name": name, "weight": weight, "hash": model_hash}
)
return json
def get_output_size(self) -> Size:
output_size = self.size
if self.border is not None:
output_size = output_size.add_border(self.border)
if self.highres is not None:
output_size = self.highres.resize(output_size)
if self.upscale is not None:
output_size = self.upscale.resize(output_size)
return output_size
def with_args(
self,
params: Optional[ImageParams] = None,
size: Optional[Size] = None,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
ancestors: Optional[List["ImageMetadata"]] = None,
) -> "ImageMetadata":
return ImageMetadata(
params or self.params,
size or self.size,
upscale=upscale or self.upscale,
border=border or self.border,
highres=highres or self.highres,
inversions=inversions or self.inversions,
loras=loras or self.loras,
models=models or self.models,
ancestors=ancestors or self.ancestors,
)
@staticmethod
def from_exif(input: str) -> "ImageMetadata":
lines = input.splitlines()
prompt, maybe_negative, *rest = lines
# process negative prompt or put that line back into rest
if maybe_negative.startswith("Negative prompt:"):
negative_prompt = maybe_negative[len("Negative prompt:") :]
negative_prompt = negative_prompt.strip()
else:
rest.insert(0, maybe_negative)
negative_prompt = None
rest = " ".join(rest)
other_params = rest.split(",")
# process other params
params = {}
size = None
for param in other_params:
key, value = param.split(":")
key = key.strip().lower()
value = value.strip()
if key == "size":
width, height = value.split("x")
width = int(width.strip())
height = int(height.strip())
size = Size(width, height)
elif value.isdecimal():
value = int(value)
elif FLOAT_PATTERN.match(value) is not None:
value = float(value)
params[key] = value
params = ImageParams(
"TODO",
"txt2img", # TODO: can this be detected?
params["sampler"],
prompt,
params["cfg scale"],
params["steps"],
params["seed"],
negative_prompt,
)
return ImageMetadata(params, size)
@staticmethod
def from_json(input: str) -> "ImageMetadata":
data = load_config_str(input)
# TODO: enforce schema
return ImageMetadata(
data["params"],
data["input_size"],
data.get("upscale", None),
data.get("border", None),
data.get("highres", None),
data.get("inversions", None),
data.get("loras", None),
data.get("models", None),
)
ERROR_NO_METADATA = "metadata must be provided"
class StageResult:
"""
Chain pipeline stage result.
Can contain PIL images or numpy arrays, with helpers to convert between them.
This class intentionally does not provide `__iter__`, to ensure clients get results in the format
they are expected.
"""
arrays: Optional[List[np.ndarray]]
images: Optional[List[Image.Image]]
metadata: List[ImageMetadata]
@staticmethod
def empty():
return StageResult(images=[])
@staticmethod
def from_arrays(arrays: List[np.ndarray], metadata: List[ImageMetadata]):
return StageResult(arrays=arrays, metadata=metadata)
@staticmethod
def from_images(images: List[Image.Image], metadata: List[ImageMetadata]):
return StageResult(images=images, metadata=metadata)
def __init__(
self,
arrays: Optional[List[np.ndarray]] = None,
images: Optional[List[Image.Image]] = None,
metadata: Optional[List[ImageMetadata]] = None, # TODO: should not be optional
source: Optional[Any] = None,
) -> None:
data_provided = sum(
[arrays is not None, images is not None, source is not None]
)
if data_provided > 1:
raise ValueError("results must only contain one type of data")
elif data_provided == 0:
raise ValueError("results must contain some data")
if source is not None:
self.arrays = source.arrays
self.images = source.images
self.metadata = source.metadata
else:
self.arrays = arrays
self.images = images
self.metadata = metadata or []
def __len__(self) -> int:
if self.arrays is not None:
return len(self.arrays)
elif self.images is not None:
return len(self.images)
else:
return 0
def as_arrays(self) -> List[np.ndarray]:
if self.arrays is not None:
return self.arrays
elif self.images is not None:
return [np.array(i) for i in self.images]
else:
return []
def as_images(self) -> List[Image.Image]:
if self.images is not None:
return self.images
elif self.arrays is not None:
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
else:
return []
def push_array(self, array: np.ndarray, metadata: ImageMetadata):
if self.arrays is not None:
self.arrays.append(array)
elif self.images is not None:
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
else:
self.arrays = [array]
if metadata is not None:
self.metadata.append(metadata)
else:
raise ValueError(ERROR_NO_METADATA)
def push_image(self, image: Image.Image, metadata: ImageMetadata):
if self.images is not None:
self.images.append(image)
elif self.arrays is not None:
self.arrays.append(np.array(image))
else:
self.images = [image]
if metadata is not None:
self.metadata.append(metadata)
else:
raise ValueError(ERROR_NO_METADATA)
def insert_array(self, index: int, array: np.ndarray, metadata: ImageMetadata):
if self.arrays is not None:
self.arrays.insert(index, array)
elif self.images is not None:
self.images.insert(
index, Image.fromarray(np.uint8(array), shape_mode(array))
)
else:
self.arrays = [array]
if metadata is not None:
self.metadata.insert(index, metadata)
else:
raise ValueError(ERROR_NO_METADATA)
def insert_image(self, index: int, image: Image.Image, metadata: ImageMetadata):
if self.images is not None:
self.images.insert(index, image)
elif self.arrays is not None:
self.arrays.insert(index, np.array(image))
else:
self.images = [image]
if metadata is not None:
self.metadata.insert(index, metadata)
else:
raise ValueError(ERROR_NO_METADATA)
def size(self) -> Size:
if self.images is not None:
return Size(
max([image.width for image in self.images], default=0),
max([image.height for image in self.images], default=0),
)
elif self.arrays is not None:
return Size(
max([array.shape[0] for array in self.arrays], default=0),
max([array.shape[1] for array in self.arrays], default=0),
) # TODO: which fields within the shape are width/height?
else:
return Size(0, 0)
def validate(self) -> None:
"""
Make sure the data exists and that data and metadata match in length.
"""
if self.arrays is None and self.images is None:
raise ValueError("no data in result")
if len(self) != len(self.metadata):
raise ValueError("metadata and data do not match in length")
def shape_mode(arr: np.ndarray) -> str:
if len(arr.shape) != 3:
raise ValueError("unknown array format")
if arr.shape[-1] == 3:
return "RGB"
elif arr.shape[-1] == 4:
return "RGBA"
raise ValueError("unknown image format")