From 44a8d610822e851a07f89fc267389fc919f2b3a3 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 3 Jan 2024 19:09:18 -0600 Subject: [PATCH] feat: add batch endpoints for cancel and status, update responses --- api/onnx_web/chain/correct_codeformer.py | 5 +- api/onnx_web/chain/result.py | 180 ++++++++++++++-- api/onnx_web/diffusers/run.py | 116 +++------- api/onnx_web/output.py | 259 +++++------------------ api/onnx_web/params.py | 18 ++ api/onnx_web/server/admin.py | 8 +- api/onnx_web/server/api.py | 194 ++++++++++++----- api/onnx_web/transformers/run.py | 3 +- api/onnx_web/utils.py | 35 ++- api/onnx_web/worker/command.py | 53 +++-- api/onnx_web/worker/context.py | 63 ++---- api/onnx_web/worker/pool.py | 102 +++++---- api/onnx_web/worker/worker.py | 2 +- api/scripts/test-diffusers.py | 6 - api/tests/test_diffusers/test_run.py | 28 +-- api/tests/worker/test_pool.py | 29 +-- api/tests/worker/test_worker.py | 6 +- docs/index.md | 14 +- gui/src/client/api.ts | 99 ++++----- gui/src/client/base.ts | 39 ++-- gui/src/client/local.ts | 9 +- gui/src/client/utils.ts | 10 +- gui/src/components/ImageHistory.tsx | 23 +- gui/src/components/card/ErrorCard.tsx | 20 +- gui/src/components/card/ImageCard.tsx | 40 ++-- gui/src/components/card/LoadingCard.tsx | 103 +++++---- gui/src/components/tab/Blend.tsx | 4 +- gui/src/components/tab/Img2Img.tsx | 4 +- gui/src/components/tab/Inpaint.tsx | 8 +- gui/src/components/tab/Txt2Img.tsx | 4 +- gui/src/components/tab/Upscale.tsx | 4 +- gui/src/state/history.ts | 22 +- gui/src/strings/en.ts | 2 +- gui/src/types/api-v2.ts | 160 ++++++++++++++ gui/src/types/api.ts | 12 ++ onnx-web.code-workspace | 4 +- 36 files changed, 981 insertions(+), 707 deletions(-) create mode 100644 gui/src/types/api-v2.ts diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 84c8a9d9..a03da0f6 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -36,15 +36,14 @@ class CorrectCodeformerStage(BaseStage): # https://pypi.org/project/codeformer-perceptor/ # import must be within the load function for patches to take effect - # TODO: rewrite and remove + from codeformer.basicsr.archs.codeformer_arch import CodeFormer from codeformer.basicsr.utils import img2tensor, tensor2img - from codeformer.basicsr.utils.registry import ARCH_REGISTRY from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper upscale = upscale.with_args(**kwargs) device = worker.get_device() - net = ARCH_REGISTRY.get("CodeFormer")( + net = CodeFormer( dim_embd=512, codebook_size=1024, n_head=8, diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 2e60877f..7166f397 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -1,10 +1,28 @@ -from typing import Any, List, Optional, Tuple +from json import dumps +from logging import getLogger +from os import path +from typing import Any, List, Optional import numpy as np from PIL import Image -from ..output import json_params +from ..convert.utils import resolve_tensor from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams +from ..server.load import get_extra_hashes +from ..utils import hash_file + +logger = getLogger(__name__) + + +class NetworkMetadata: + name: str + hash: str + weight: float + + def __init__(self, name: str, hash: str, weight: float) -> None: + self.name = name + self.hash = hash + self.weight = weight class ImageMetadata: @@ -13,8 +31,9 @@ class ImageMetadata: params: ImageParams 
    size: Size
     upscale: UpscaleParams
-    inversions: Optional[List[Tuple[str, float]]]
-    loras: Optional[List[Tuple[str, float]]]
+    inversions: Optional[List[NetworkMetadata]]
+    loras: Optional[List[NetworkMetadata]]
+    models: Optional[List[NetworkMetadata]]
 
     def __init__(
         self,
@@ -23,8 +42,9 @@ class ImageMetadata:
         upscale: Optional[UpscaleParams] = None,
         border: Optional[Border] = None,
         highres: Optional[HighresParams] = None,
-        inversions: Optional[List[Tuple[str, float]]] = None,
-        loras: Optional[List[Tuple[str, float]]] = None,
+        inversions: Optional[List[NetworkMetadata]] = None,
+        loras: Optional[List[NetworkMetadata]] = None,
+        models: Optional[List[NetworkMetadata]] = None,
     ) -> None:
         self.params = params
         self.size = size
@@ -33,19 +53,108 @@ class ImageMetadata:
         self.highres = highres
         self.inversions = inversions
         self.loras = loras
+        self.models = models
+
+    def to_auto1111(self, server, outputs) -> str:
+        model_name = path.basename(path.normpath(self.params.model))
+        logger.debug("getting model hash for %s", model_name)
+
+        model_hash = get_extra_hashes().get(model_name, None)
+        if model_hash is None:
+            model_hash_path = path.join(self.params.model, "hash.txt")
+            if path.exists(model_hash_path):
+                with open(model_hash_path, "r") as f:
+                    model_hash = f.readline().rstrip(",. \n\t\r")
+
+        model_hash = model_hash or "unknown"
+        hash_map = {
+            model_name: model_hash,
+        }
+
+        # inversions and loras are lists of NetworkMetadata now, not (name, weight) tuples
+        inversion_hashes = ""
+        if self.inversions is not None:
+            inversion_pairs = [
+                (
+                    network.name,
+                    hash_file(
+                        resolve_tensor(
+                            path.join(server.model_path, "inversion", network.name)
+                        )
+                    ).upper(),
+                )
+                for network in self.inversions
+            ]
+            inversion_hashes = ",".join(
+                [f"{name}: {hash}" for name, hash in inversion_pairs]
+            )
+            hash_map.update(dict(inversion_pairs))
+
+        lora_hashes = ""
+        if self.loras is not None:
+            lora_pairs = [
+                (
+                    network.name,
+                    hash_file(
+                        resolve_tensor(
+                            path.join(server.model_path, "lora", network.name)
+                        )
+                    ).upper(),
+                )
+                for network in self.loras
+            ]
+            lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
+            hash_map.update(dict(lora_pairs))
+
+        return (
+            f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
+            f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
+            f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
+            f"Model hash: {model_hash}, Model: {model_name}, "
+            f"Tool: onnx-web, Version: {server.server_version}, "
+            f'Inversion hashes: "{inversion_hashes}", '
+            f'Lora hashes: "{lora_hashes}", '
+            f"Hashes: {dumps(hash_map)}"
+        )
 
     def tojson(self, server, outputs):
-        return json_params(
-            server,
-            outputs,
-            self.params,
-            self.size,
-            upscale=self.upscale,
-            border=self.border,
-            highres=self.highres,
-            inversions=self.inversions,
-            loras=self.loras,
-        )
+        json = {
+            "input_size": self.size.tojson(),
+            "outputs": outputs,
+            "params": self.params.tojson(),
+            "inversions": {},
+            "loras": {},
+        }
+
+        json["params"]["model"] = path.basename(self.params.model)
+        json["params"]["scheduler"] = self.params.scheduler  # TODO: is this override still needed?
+
+        # calculate final output size
+        output_size = self.size
+        if self.border is not None:
+            json["border"] = self.border.tojson()
+            output_size = output_size.add_border(self.border)
+
+        if self.highres is not None:
+            json["highres"] = self.highres.tojson()
+            output_size = self.highres.resize(output_size)
+
+        if self.upscale is not None:
+            json["upscale"] = self.upscale.tojson()
+            output_size = self.upscale.resize(output_size)
+
+        json["size"] = output_size.tojson()
+
+        if self.inversions is not None:
+            for network in self.inversions:
+                hash = hash_file(
+                    resolve_tensor(
+                        path.join(server.model_path, "inversion", network.name)
+                    )
+                ).upper()
+                json["inversions"][network.name] = {
+                    "weight": network.weight,
+                    "hash": hash,
+                }
+
+        if self.loras is not None:
+            for network in self.loras:
+                hash = hash_file(
+                    resolve_tensor(path.join(server.model_path, "lora", network.name))
+                ).upper()
+                json["loras"][network.name] = {
+                    "weight": network.weight,
+                    "hash": hash,
+                }
+
+        return json
 
 
 class StageResult:
@@ -86,6 +195,7 @@ class StageResult:
         self.arrays = arrays
         self.images = images
         self.source = source
+        self.metadata = []
 
     def __len__(self) -> int:
         if self.arrays is not None:
@@ -117,7 +227,7 @@ class StageResult:
         elif self.images is not None:
             self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
         else:
-            raise ValueError("invalid stage result")
+            self.arrays = [array]
 
         if metadata is not None:
             self.metadata.append(metadata)
@@ -130,13 +240,45 @@ class StageResult:
         elif self.arrays is not None:
             self.arrays.append(np.array(image))
         else:
-            raise ValueError("invalid stage result")
+            self.images = [image]
 
         if metadata is not None:
             self.metadata.append(metadata)
         else:
             self.metadata.append(ImageMetadata())
 
+    def insert_array(
+        self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata] = None
+    ):
+        if self.arrays is not None:
+            self.arrays.insert(index, array)
+        elif self.images is not None:
+            self.images.insert(
+                index, Image.fromarray(np.uint8(array), shape_mode(array))
+            )
+        else:
+            self.arrays = [array]
+
+        if metadata is not None:
+            self.metadata.insert(index, metadata)
+        else:
+            self.metadata.insert(index, ImageMetadata())
+
+    def insert_image(
+        self, index: int, image: Image.Image, metadata: Optional[ImageMetadata] = None
+    ):
+        if self.images is not None:
+            self.images.insert(index, image)
+        elif self.arrays is not None:
+            self.arrays.insert(index, np.array(image))
+        else:
+            self.images = [image]
+
+        if metadata is not None:
+            self.metadata.insert(index, metadata)
+        else:
+            self.metadata.insert(index, ImageMetadata())
+
 
 def shape_mode(arr: np.ndarray) -> str:
     if len(arr.shape) != 3:
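
For context, a rough sketch of how a chain stage is expected to use the metadata-aware result type above; the stage output, server, and network names are hypothetical, not part of this patch:

    result = StageResult(images=[])
    result.push_image(
        generated,  # a PIL.Image produced by the stage
        ImageMetadata(
            params,
            size,
            loras=[NetworkMetadata("fine-detail", "ABC123", 0.75)],
        ),
    )
    # each image can later be serialized next to its own parameters
    blob = result.metadata[0].tojson(server, ["output_0.png"])
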
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index 1fef8087..5bc625c4 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
 from ..chain.result import StageResult
 from ..chain.upscale import split_upscale, stage_upscale_correction
 from ..image import expand_image
-from ..output import save_image
+from ..output import save_image, save_result
 from ..params import (
     Border,
     HighresParams,
@@ -29,7 +29,7 @@ from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import is_debug, run_gc, show_system_toast
 from ..worker import WorkerContext
-from .utils import get_latents_from_seed, parse_prompt
+from .utils import get_latents_from_seed
 
 logger = getLogger(__name__)
 
@@ -57,7 +57,6 @@ def run_txt2img_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
 ) -> None:
@@ -114,50 +113,34 @@ def run_txt2img_pipeline(
     # run and save
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult.empty(), callback=progress, latents=latents
     )
 
-    _pairs, loras, inversions, _rest = parse_prompt(params)
-
     # add a thumbnail, if requested
-    cover = images[0]
+    cover = images.as_image()[0]
     if params.thumbnail and (
         cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
     ):
         thumbnail = cover.copy()
         thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
-        images.insert(0, thumbnail)
-        outputs.insert(0, f"{worker.name}-thumb.{server.image_format}")
+        images.insert_image(0, thumbnail)
 
-    for image, output in zip(images, outputs):
-        logger.trace("saving output image %s: %s", output, image.size)
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            highres=highres,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)
 
     # clean up
     run_gc([worker.get_device()])
 
     # notify the user
-    show_system_toast(f"finished txt2img job: {dest}")
-    logger.info("finished txt2img job: %s", dest)
+    show_system_toast(f"finished txt2img job: {worker.job}")
+    logger.info("finished txt2img job: %s", worker.job)
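
The txt2img flow above no longer threads a pre-computed `outputs` list through the pipeline; file names are derived from the worker's job name inside `save_result`. As a sketch, assuming a job named `txt2img_42_abc_1704300000` and PNG output:

    # save_result writes <job>_<index>.<image_format> for each image and
    # returns the saved paths under server.output_path, e.g.
    # [".../txt2img_42_abc_1704300000_0.png", ".../txt2img_42_abc_1704300000_1.png"]
    paths = save_result(server, images, worker.job)
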
    highres: HighresParams,
 ) -> None:
 
 
 def run_img2img_pipeline(
     worker: WorkerContext,
     server: ServerContext,
     params: ImageParams,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -228,36 +211,21 @@ def run_img2img_pipeline(
     # run and append the filtered source
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=[source]), callback=progress
     )
 
     if source_filter is not None and source_filter != "none":
-        images.append(source)
+        images.push_image(source)
 
-    # save with metadata
-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    size = Size(*source.size)
-
-    for image, output in zip(images, outputs):
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            highres=highres,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)
 
     # clean up
     run_gc([worker.get_device()])
 
     # notify the user
-    show_system_toast(f"finished img2img job: {dest}")
-    logger.info("finished img2img job: %s", dest)
+    show_system_toast(f"finished img2img job: {worker.job}")
+    logger.info("finished img2img job: %s", worker.job)
 
 
 def run_inpaint_pipeline(
@@ -265,7 +233,6 @@ def run_inpaint_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -290,7 +257,7 @@ def run_inpaint_pipeline(
         mask = ImageOps.contain(mask, (mask_max, mask_max))
         mask = mask.crop((0, 0, source.width, source.height))
 
-    source, mask, noise, full_size = expand_image(
+    source, mask, noise, _full_size = expand_image(
         source,
         mask,
         border,
@@ -414,7 +381,7 @@ def run_inpaint_pipeline(
     # run and save
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker,
         server,
         params,
@@ -423,33 +390,28 @@ def run_inpaint_pipeline(
         latents=latents,
     )
 
-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    for image, output in zip(images, outputs):
+    # enumerate() wraps the zipped pairs, so each pair must be unpacked as a tuple
+    for i, (image, metadata) in enumerate(zip(images.as_image(), images.metadata)):
         if full_res_inpaint:
             if is_debug():
-                save_image(server, "adjusted-output.png", image)
+                save_image(server, "adjusted-output.png", image, metadata)
+
             mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
             image = original_source
             image.paste(mini_image, box=adj_mask_border)
-        dest = save_image(
+
+        save_image(
             server,
-            output,
+            f"{worker.job}_{i}.{server.image_format}",
             image,
-            params,
-            size,
-            upscale=upscale,
-            border=border,
-            inversions=inversions,
-            loras=loras,
+            metadata,
         )
 
     # clean up
-    del image
     run_gc([worker.get_device()])
 
     # notify the user
-    show_system_toast(f"finished inpaint job: {dest}")
-    logger.info("finished inpaint job: %s", dest)
+    show_system_toast(f"finished inpaint job: {worker.job}")
+    logger.info("finished inpaint job: %s", worker.job)
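
A quick illustration of the loop shape used in the inpaint save step, since the tuple parentheses are easy to drop by accident (`for i, image, metadata in enumerate(...)` would raise a ValueError, because enumerate yields two-item pairs):

    images = ["img-a", "img-b"]
    metadata = ["meta-a", "meta-b"]
    for i, (image, meta) in enumerate(zip(images, metadata)):
        print(i, image, meta)  # 0 img-a meta-a, then 1 img-b meta-b
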
 
 
 def run_upscale_pipeline(
     worker: WorkerContext,
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -497,30 +458,18 @@ def run_upscale_pipeline(
     # run and save
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=[source]), callback=progress
     )
 
-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    for image, output in zip(images, outputs):
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)
 
     # clean up
-    del image
     run_gc([worker.get_device()])
 
     # notify the user
-    show_system_toast(f"finished upscale job: {dest}")
-    logger.info("finished upscale job: %s", dest)
+    show_system_toast(f"finished upscale job: {worker.job}")
+    logger.info("finished upscale job: %s", worker.job)
 
 
 def run_blend_pipeline(
@@ -528,7 +477,6 @@ def run_blend_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     # highres: HighresParams,
     sources: List[Image.Image],
@@ -559,17 +507,15 @@ def run_blend_pipeline(
     # run and save
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=sources), callback=progress
     )
 
-    for image, output in zip(images, outputs):
-        dest = save_image(server, output, image, params, size, upscale=upscale)
+    save_result(server, images, worker.job)
 
     # clean up
-    del image
     run_gc([worker.get_device()])
 
     # notify the user
-    show_system_toast(f"finished blend job: {dest}")
-    logger.info("finished blend job: %s", dest)
+    show_system_toast(f"finished blend job: {worker.job}")
+    logger.info("finished blend job: %s", worker.job)
diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py
index 816d2184..7faf301e 100644
--- a/api/onnx_web/output.py
+++ b/api/onnx_web/output.py
@@ -1,173 +1,20 @@
 from hashlib import sha256
 from json import dumps
 from logging import getLogger
-from os import path
-from struct import pack
 from time import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional
 
 from piexif import ExifIFD, ImageIFD, dump
 from piexif.helper import UserComment
 from PIL import Image, PngImagePlugin
 
-from .convert.utils import resolve_tensor
-from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
+from .chain.result import ImageMetadata, StageResult
+from .params import ImageParams, Param, Size
 from .server import ServerContext
-from .server.load import get_extra_hashes
-from .utils import base_join
+from .utils import base_join, hash_value
 
 logger = getLogger(__name__)
 
-HASH_BUFFER_SIZE = 2**22  # 4MB
-
-
-def hash_file(name: 
str): - sha = sha256() - with open(name, "rb") as f: - while True: - data = f.read(HASH_BUFFER_SIZE) - if not data: - break - - sha.update(data) - - return sha.hexdigest() - - -def hash_value(sha, param: Optional[Param]): - if param is None: - return - elif isinstance(param, bool): - sha.update(bytearray(pack("!B", param))) - elif isinstance(param, float): - sha.update(bytearray(pack("!f", param))) - elif isinstance(param, int): - sha.update(bytearray(pack("!I", param))) - elif isinstance(param, str): - sha.update(param.encode("utf-8")) - else: - logger.warning("cannot hash param: %s, %s", param, type(param)) - - -def json_params( - server: ServerContext, - outputs: List[str], - params: ImageParams, - size: Size, - upscale: Optional[UpscaleParams] = None, - border: Optional[Border] = None, - highres: Optional[HighresParams] = None, - inversions: Optional[List[Tuple[str, float]]] = None, - loras: Optional[List[Tuple[str, float]]] = None, - parent: Optional[Dict] = None, -) -> Any: - json = { - "input_size": size.tojson(), - "outputs": outputs, - "params": params.tojson(), - "inversions": {}, - "loras": {}, - } - - json["params"]["model"] = path.basename(params.model) - json["params"]["scheduler"] = params.scheduler - - # calculate final output size - output_size = size - if border is not None: - json["border"] = border.tojson() - output_size = output_size.add_border(border) - - if highres is not None: - json["highres"] = highres.tojson() - output_size = highres.resize(output_size) - - if upscale is not None: - json["upscale"] = upscale.tojson() - output_size = upscale.resize(output_size) - - json["size"] = output_size.tojson() - - if inversions is not None: - for name, weight in inversions: - hash = hash_file( - resolve_tensor(path.join(server.model_path, "inversion", name)) - ).upper() - json["inversions"][name] = {"weight": weight, "hash": hash} - - if loras is not None: - for name, weight in loras: - hash = hash_file( - resolve_tensor(path.join(server.model_path, "lora", name)) - ).upper() - json["loras"][name] = {"weight": weight, "hash": hash} - - return json - - -def str_params( - server: ServerContext, - params: ImageParams, - size: Size, - inversions: List[Tuple[str, float]] = None, - loras: List[Tuple[str, float]] = None, -) -> str: - model_name = path.basename(path.normpath(params.model)) - logger.debug("getting model hash for %s", model_name) - - model_hash = get_extra_hashes().get(model_name, None) - if model_hash is None: - model_hash_path = path.join(params.model, "hash.txt") - if path.exists(model_hash_path): - with open(model_hash_path, "r") as f: - model_hash = f.readline().rstrip(",. 
\n\t\r") - - model_hash = model_hash or "unknown" - hash_map = { - model_name: model_hash, - } - - inversion_hashes = "" - if inversions is not None: - inversion_pairs = [ - ( - name, - hash_file( - resolve_tensor(path.join(server.model_path, "inversion", name)) - ).upper(), - ) - for name, _weight in inversions - ] - inversion_hashes = ",".join( - [f"{name}: {hash}" for name, hash in inversion_pairs] - ) - hash_map.update(dict(inversion_pairs)) - - lora_hashes = "" - if loras is not None: - lora_pairs = [ - ( - name, - hash_file( - resolve_tensor(path.join(server.model_path, "lora", name)) - ).upper(), - ) - for name, _weight in loras - ] - lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs]) - hash_map.update(dict(lora_pairs)) - - return ( - f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n" - f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, " - f"Seed: {params.seed}, Size: {size.width}x{size.height}, " - f"Model hash: {model_hash}, Model: {model_name}, " - f"Tool: onnx-web, Version: {server.server_version}, " - f'Inversion hashes: "{inversion_hashes}", ' - f'Lora hashes: "{lora_hashes}", ' - f"Hashes: {dumps(hash_map)}" - ) - def make_output_name( server: ServerContext, @@ -179,6 +26,19 @@ def make_output_name( offset: int = 0, ) -> List[str]: count = count or params.batch + job_name = make_job_name(mode, params, size, extras) + + return [ + f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset) + ] + + +def make_job_name( + mode: str, + params: ImageParams, + size: Size, + extras: Optional[List[Optional[Param]]] = None, +) -> str: now = int(time()) sha = sha256() @@ -200,49 +60,49 @@ def make_output_name( for param in extras: hash_value(sha, param) - return [ - f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}" - for i in range(offset, count + offset) - ] + return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}" + + +def save_result( + server: ServerContext, + result: StageResult, + base_name: str, +) -> List[str]: + results = [] + for i, image, metadata in enumerate(zip(result.as_image(), result.metadata)): + results.append( + save_image( + server, + base_name + f"_{i}.{server.image_format}", + image, + metadata, + ) + ) + + return results def save_image( server: ServerContext, output: str, image: Image.Image, - params: Optional[ImageParams] = None, - size: Optional[Size] = None, - upscale: Optional[UpscaleParams] = None, - border: Optional[Border] = None, - highres: Optional[HighresParams] = None, - inversions: List[Tuple[str, float]] = None, - loras: List[Tuple[str, float]] = None, + metadata: ImageMetadata, ) -> str: path = base_join(server.output_path, output) if server.image_format == "png": exif = PngImagePlugin.PngInfo() - if params is not None: + if metadata is not None: exif.add_text("make", "onnx-web") exif.add_text( "maker note", - dumps( - json_params( - server, - [output], - params, - size, - upscale=upscale, - border=border, - highres=highres, - ) - ), + dumps(metadata.tojson(server, [output])), ) exif.add_text("model", server.server_version) exif.add_text( "parameters", - str_params(server, params, size, inversions=inversions, loras=loras), + metadata.to_auto1111(server, [output]), ) image.save(path, format=server.image_format, pnginfo=exif) @@ -251,23 +111,11 @@ def save_image( { "0th": { ExifIFD.MakerNote: UserComment.dump( - dumps( - json_params( - server, - [output], - params, - size, - upscale=upscale, - border=border, - highres=highres, - ) - ), 
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index b8dc3049..894d04f2 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -13,6 +13,24 @@ Param = Union[str, int, float]
 Point = Tuple[int, int]
 
 
+class Progress:
+    current: int
+    total: int
+
+    def __init__(self, current: int, total: int) -> None:
+        self.current = current
+        self.total = total
+
+    def __str__(self) -> str:
+        return "%s/%s" % (self.current, self.total)
+
+    def tojson(self):
+        return {
+            "current": self.current,
+            "total": self.total,
+        }
+
+
 class SizeChart(IntEnum):
     micro = 64
     mini = 128  # small tile for very expensive models
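
Progress is deliberately minimal: a current/total pair that renders as a fraction and serializes to a small JSON object, for example:

    progress = Progress(3, 25)
    str(progress)       # "3/25"
    progress.tojson()   # {"current": 3, "total": 25}
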
diff --git a/api/onnx_web/server/admin.py b/api/onnx_web/server/admin.py
index bdbc9ade..54358f9d 100644
--- a/api/onnx_web/server/admin.py
+++ b/api/onnx_web/server/admin.py
@@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
     pool.recycle(recycle_all=True)
     logger.info("restarted worker pool")
 
-    return jsonify(pool.status())
+    return jsonify(pool.summary())
 
 
 def worker_status(server: ServerContext, pool: DevicePoolExecutor):
     if not check_admin(server):
         return make_response(jsonify({})), 401
 
-    return jsonify(pool.status())
+    return jsonify(pool.summary())
 
 
 def get_extra_models(server: ServerContext):
@@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExe
         app.route("/api/extras", methods=["PUT"])(
             wrap_route(update_extra_models, server)
         ),
-        app.route("/api/restart", methods=["POST"])(
+        app.route("/api/worker/restart", methods=["POST"])(
            wrap_route(restart_workers, server, pool=pool)
         ),
-        app.route("/api/status")(wrap_route(worker_status, server, pool=pool)),
+        app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)),
     ]
diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py
index 078bbf62..c20c948c 100644
--- a/api/onnx_web/server/api.py
+++ b/api/onnx_web/server/api.py
@@ -1,14 +1,14 @@
 from io import BytesIO
 from logging import getLogger
 from os import path
-from typing import Any, Dict
+from typing import Any, Dict, List, Optional
 
 from flask import Flask, jsonify, make_response, request, url_for
 from jsonschema import validate
 from PIL import Image
 
 from ..chain import CHAIN_STAGES, ChainPipeline
-from ..chain.result import StageResult
+from ..chain.result import ImageMetadata, StageResult
 from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
 from ..diffusers.run import (
     run_blend_pipeline,
@@ -18,8 +18,8 @@ from ..diffusers.run import (
     run_upscale_pipeline,
 )
 from ..diffusers.utils import replace_wildcards
-from ..output import json_params, make_output_name
-from ..params import Size, StageParams, TileOrder
+from ..output import make_job_name
+from ..params import Progress, Size, StageParams, TileOrder
 from ..transformers.run import run_txt2txt_pipeline
 from ..utils import (
     base_join,
@@ -34,6 +34,7 @@ from ..utils import (
     load_config_str,
     sanitize_name,
 )
+from ..worker.command import JobType
 from ..worker.pool import DevicePoolExecutor
 from .context import ServerContext
 from .load import (
@@ -92,6 +93,64 @@ def error_reply(err: str):
     return response
 
 
+def job_reply(name: str):
+    return jsonify(
+        {
+            "name": name,
+        }
+    )
+
+
+def image_reply(
+    name: str,
+    status: str,
+    job_type: str,
+    stages: Optional[Progress] = None,
+    steps: Optional[Progress] = None,
+    tiles: Optional[Progress] = None,
+    outputs: Optional[List[str]] = None,
+    metadata: Optional[List[ImageMetadata]] = None,
+):
+    # Progress requires both counters, so default to an empty 0/0 pair
+    if stages is None:
+        stages = Progress(0, 0)
+
+    if steps is None:
+        steps = Progress(0, 0)
+
+    if tiles is None:
+        tiles = Progress(0, 0)
+
+    data = {
+        "name": name,
+        "status": status,
+        "type": job_type,
+        "stages": stages.tojson(),
+        "steps": steps.tojson(),
+        "tiles": tiles.tojson(),
+    }
+
+    if outputs is not None:
+        data["outputs"] = outputs
+
+    if metadata is not None:
+        # only compare lengths when both lists were provided
+        if outputs is None or len(metadata) != len(outputs):
+            logger.error("metadata and outputs must be the same length")
+            return error_reply("metadata and outputs must be the same length")
+
+        data["metadata"] = metadata
+
+    return jsonify(data)
+
+
+def multi_image_reply(results: Dict[str, Any]):
+    # TODO: not that
+    return jsonify(
+        {
+            "results": results,
+        }
+    )
+
+
 def url_from_rule(rule) -> str:
     options = {}
     for arg in rule.arguments:
@@ -197,17 +256,15 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
         )
         output_count += 1
 
-    output = make_output_name(
-        server, "img2img", params, size, extras=[strength], count=output_count
-    )
-
-    job_name = output[0]
+    job_name = make_job_name("img2img", params, size, extras=[strength])
     pool.submit(
         job_name,
+        JobType.IMG2IMG,
         run_img2img_pipeline,
         server,
         params,
-        output,
         upscale,
         highres,
         source,
@@ -218,9 +275,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
 
     logger.info("img2img job queued for: %s", job_name)
 
-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)
 
 
 def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@@ -230,16 +285,15 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
 
     replace_wildcards(params, get_wildcard_data())
 
-    output = make_output_name(server, "txt2img", params, size, count=params.batch)
+    job_name = make_job_name("txt2img", params, size)
 
-    job_name = output[0]
     pool.submit(
         job_name,
+        JobType.TXT2IMG,
         run_txt2img_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         needs_device=device,
     )
 
     logger.info("txt2img job queued for: %s", job_name)
 
-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)
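
Unlike the legacy endpoints, which responded with the full parameter JSON, the v2 handlers return only a job handle from job_reply; clients then poll the status route for progress. A sketch of the exchange, with a hypothetical job name (the type field is still a TODO in job_status below, and the Progress totals are not populated yet):

    # POST /api/txt2img?...  -> {"name": "txt2img_42_abc_1704300000"}
    # GET  /api/job/status?jobs=txt2img_42_abc_1704300000
    #   -> {"name": "txt2img_42_abc_1704300000", "status": "running", "type": "TODO",
    #       "stages": {"current": 0, "total": 0},
    #       "steps": {"current": 12, "total": 0},
    #       "tiles": {"current": 0, "total": 0}}
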
 
 
 def inpaint(server: ServerContext, pool: DevicePoolExecutor):
@@ -295,7 +347,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
 
     replace_wildcards(params, get_wildcard_data())
 
-    output = make_output_name(
-        server,
+    job_name = make_job_name(
         "inpaint",
         params,
         size,
@@ -312,14 +364,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
         ],
     )
 
-    job_name = output[0]
     pool.submit(
         job_name,
+        JobType.INPAINT,
         run_inpaint_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         source,
@@ -336,17 +387,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
 
     logger.info("inpaint job queued for: %s", job_name)
 
-    return jsonify(
-        json_params(
-            server,
-            output,
-            params,
-            size,
-            upscale=upscale,
-            border=expand,
-            highres=highres,
-        )
-    )
+    return job_reply(job_name)
 
 
 def upscale(server: ServerContext, pool: DevicePoolExecutor):
@@ -362,16 +403,14 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
 
     replace_wildcards(params, get_wildcard_data())
 
-    output = make_output_name(server, "upscale", params, size)
-
-    job_name = output[0]
+    job_name = make_job_name("upscale", params, size)
     pool.submit(
         job_name,
+        JobType.UPSCALE,
         run_upscale_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         source,
@@ -380,9 +419,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
 
     logger.info("upscale job queued for: %s", job_name)
 
-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)
 
 
 # keys that are specially parsed by params and should not show up in with_args
@@ -478,25 +515,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
 
     logger.info("running chain pipeline with %s stages", len(pipeline.stages))
 
-    output = make_output_name(
-        server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
-    )
-    job_name = output[0]
+    job_name = make_job_name("chain", base_params, base_size)
 
     # build and run chain pipeline
     pool.submit(
         job_name,
+        JobType.CHAIN,
         pipeline,
         server,
         base_params,
         StageResult.empty(),
-        output=output,
         size=base_size,
         needs_device=device,
     )
 
-    step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
-    return jsonify(json_params(server, output, step_params, base_size))
+    return job_reply(job_name)
 
 
 def blend(server: ServerContext, pool: DevicePoolExecutor):
@@ -520,15 +553,14 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
     device, params, size = pipeline_from_request(server)
     upscale = build_upscale()
 
-    output = make_output_name(server, "upscale", params, size)
-    job_name = output[0]
+    job_name = make_job_name("blend", params, size)
     pool.submit(
         job_name,
+        JobType.BLEND,
         run_blend_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         # TODO: highres
         sources,
@@ -538,27 +570,26 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
     )
 
-    logger.info("upscale job queued for: %s", job_name)
+    logger.info("blend job queued for: %s", job_name)
 
-    return jsonify(json_params(server, output, params, size, upscale=upscale))
+    return job_reply(job_name)
 
 
 def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
     device, params, size = pipeline_from_request(server)
 
-    output = make_output_name(server, "txt2txt", params, size)
-    job_name = output[0]
+    job_name = make_job_name("txt2txt", params, size)
 
-    logger.info("upscale job queued for: %s", job_name)
+    logger.info("txt2txt job queued for: %s", job_name)
 
     pool.submit(
         job_name,
+        JobType.TXT2TXT,
         run_txt2txt_pipeline,
         server,
         params,
         size,
-        output,
         needs_device=device,
     )
 
-    return jsonify(json_params(server, output, params, size))
+    return job_reply(job_name)
 
 
 def cancel(server: ServerContext, pool: DevicePoolExecutor):
@@ -601,9 +632,64 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
     )
 
 
+def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
+    legacy_job_name = request.args.get("job", None)
+    # split() returns [""] for an empty string, so filter out blank names
+    job_list = [name for name in request.args.get("jobs", "").split(",") if name]
+
+    if legacy_job_name is not None:
+        job_list.append(legacy_job_name)
+
+    if len(job_list) == 0:
+        return error_reply("at least one job name is required")
+
+    results = {}
+    for job_name in job_list:
+        job_name = sanitize_name(job_name)
+        cancelled = pool.cancel(job_name)
+        results[job_name] = cancelled
+
+    return multi_image_reply(results)
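
Both batch endpoints accept a comma-separated jobs parameter and still honor the older single job parameter; cancelling two jobs might look like this, with multi_image_reply wrapping the per-job booleans:

    # PUT /api/job/cancel?jobs=job-a,job-b
    # -> {"results": {"job-a": true, "job-b": false}}
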
+
+
+def job_status(server: ServerContext, pool: DevicePoolExecutor):
+    legacy_job_name = request.args.get("job", None)
+    # split() returns [""] for an empty string, so filter out blank names
+    job_list = [name for name in request.args.get("jobs", "").split(",") if name]
+
+    if legacy_job_name is not None:
+        job_list.append(legacy_job_name)
+
+    if len(job_list) == 0:
+        return error_reply("at least one job name is required")
+
+    for job_name in job_list:
+        job_name = sanitize_name(job_name)
+        status, progress = pool.status(job_name)
+
+        # TODO: accumulate results
+        if progress is not None:
+            # TODO: add output paths based on progress.results counter
+            return image_reply(
+                job_name,
+                status,
+                "TODO",
+                stages=Progress(progress.stages, 0),
+                steps=Progress(progress.steps, 0),
+                tiles=Progress(progress.tiles, 0),
+            )
+
+    return image_reply(job_name, status, "TODO")
+
+
 def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
     return [
         app.route("/api")(wrap_route(introspect, server, app=app)),
+        # job routes
+        app.route("/api/job", methods=["POST"])(wrap_route(chain, server, pool=pool)),
+        app.route("/api/job/cancel", methods=["PUT"])(
+            wrap_route(job_cancel, server, pool=pool)
+        ),
+        app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
+        # settings routes
         app.route("/api/settings/filters")(wrap_route(list_filters, server)),
         app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
         app.route("/api/settings/models")(wrap_route(list_models, server)),
@@ -614,6 +700,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
         app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
         app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
         app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
+        # legacy job routes
         app.route("/api/img2img", methods=["POST"])(
             wrap_route(img2img, server, pool=pool)
         ),
@@ -631,6 +718,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
         ),
         app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
         app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
+        # deprecated routes
         app.route("/api/cancel", methods=["PUT"])(
             wrap_route(cancel, server, pool=pool)
         ),
diff --git a/api/onnx_web/transformers/run.py b/api/onnx_web/transformers/run.py
index eb789639..e45ab4f7 100644
--- a/api/onnx_web/transformers/run.py
+++ b/api/onnx_web/transformers/run.py
@@ -12,7 +12,6 @@ def run_txt2txt_pipeline(
     _server: ServerContext,
     params: ImageParams,
     _size: Size,
-    output: str,
 ) -> None:
     from transformers import AutoTokenizer, GPTJForCausalLM
 
@@ -38,4 +37,4 @@ def run_txt2txt_pipeline(
 
     print("Server says: %s" % result_text)
 
-    logger.info("finished txt2txt job: %s", output)
+    logger.info("finished txt2txt job: %s", worker.job)
diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py
index d047ec0d..b2eacec3 100644
--- a/api/onnx_web/utils.py
+++ b/api/onnx_web/utils.py
@@ -2,16 +2,18 @@ import gc
 import importlib
 import json
 import threading
+from hashlib import sha256
 from json import JSONDecodeError
 from logging import getLogger
 from os 
import environ, path from platform import system +from struct import pack from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union import torch from yaml import safe_load -from .params import DeviceParams, SizeChart +from .params import DeviceParams, Param, SizeChart logger = getLogger(__name__) @@ -218,3 +220,34 @@ def load_config_str(raw: str) -> Dict: return json.loads(raw) except JSONDecodeError: return safe_load(raw) + + +HASH_BUFFER_SIZE = 2**22 # 4MB + + +def hash_file(name: str): + sha = sha256() + with open(name, "rb") as f: + while True: + data = f.read(HASH_BUFFER_SIZE) + if not data: + break + + sha.update(data) + + return sha.hexdigest() + + +def hash_value(sha, param: Optional[Param]): + if param is None: + return + elif isinstance(param, bool): + sha.update(bytearray(pack("!B", param))) + elif isinstance(param, float): + sha.update(bytearray(pack("!f", param))) + elif isinstance(param, int): + sha.update(bytearray(pack("!I", param))) + elif isinstance(param, str): + sha.update(param.encode("utf-8")) + else: + logger.warning("cannot hash param: %s, %s", param, type(param)) diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 1d7db225..03f15068 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -1,34 +1,61 @@ +from enum import Enum from typing import Any, Callable, Dict +class JobStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + CANCELLED = "cancelled" + UNKNOWN = "unknown" + + +class JobType(str, Enum): + TXT2TXT = "txt2txt" + TXT2IMG = "txt2img" + IMG2IMG = "img2img" + INPAINT = "inpaint" + UPSCALE = "upscale" + BLEND = "blend" + CHAIN = "chain" + + class ProgressCommand: device: str job: str - finished: bool - progress: int - cancelled: bool - failed: bool + job_type: str + status: JobStatus + results: int + steps: int + stages: int + tiles: int def __init__( self, job: str, + job_type: str, device: str, - finished: bool, - progress: int, - cancelled: bool = False, - failed: bool = False, + status: JobStatus, + results: int = 0, + steps: int = 0, + stages: int = 0, + tiles: int = 0, ): self.job = job + self.job_type = job_type self.device = device - self.finished = finished - self.progress = progress - self.cancelled = cancelled - self.failed = failed + self.status = status + self.results = results + self.steps = steps + self.stages = stages + self.tiles = tiles class JobCommand: device: str name: str + job_type: str fn: Callable[..., None] args: Any kwargs: Dict[str, Any] @@ -37,12 +64,14 @@ class JobCommand: self, name: str, device: str, + job_type: str, fn: Callable[..., None], args: Any, kwargs: Dict[str, Any], ): self.device = device self.name = name + self.job_type = job_type self.fn = fn self.args = args self.kwargs = kwargs diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2d6d0278..e3de145c 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -2,21 +2,23 @@ from logging import getLogger from os import getpid from typing import Any, Callable, Optional +import numpy as np from torch.multiprocessing import Queue, Value from ..errors import CancelledException from ..params import DeviceParams -from .command import JobCommand, ProgressCommand +from .command import JobCommand, JobStatus, ProgressCommand logger = getLogger(__name__) -ProgressCallback = Callable[[int, int, Any], None] +ProgressCallback = Callable[[int, int, np.ndarray], None] class WorkerContext: cancel: 
"Value[bool]" job: Optional[str] + job_type: Optional[str] name: str pending: "Queue[JobCommand]" active_pid: "Value[int]" @@ -41,6 +43,7 @@ class WorkerContext: timeout: float, ): self.job = None + self.job_type = None self.name = name self.device = device self.cancel = cancel @@ -54,9 +57,15 @@ class WorkerContext: self.retries = retries self.timeout = timeout - def start(self, job: str) -> None: - self.job = job + def start(self, job: JobCommand) -> None: + # set job name and type + self.job = job.name + self.job_type = job.job_type + + # reset retries self.retries = self.initial_retries + + # clear flags self.set_cancel(cancel=False) self.set_idle(idle=False) @@ -81,7 +90,7 @@ class WorkerContext: def get_progress(self) -> int: if self.last_progress is not None: - return self.last_progress.progress + return self.last_progress.steps return 0 @@ -112,13 +121,11 @@ class WorkerContext: logger.debug("setting progress for job %s to %s", self.job, progress) self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - False, - progress, - self.is_cancelled(), - False, + JobStatus.RUNNING, + steps=progress, ) - self.progress.put( self.last_progress, block=False, @@ -131,11 +138,10 @@ class WorkerContext: logger.debug("setting finished for job %s", self.job) self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - False, + JobStatus.SUCCESS, # TODO: FAILED + steps=self.get_progress(), ) self.progress.put( self.last_progress, @@ -150,11 +156,10 @@ class WorkerContext: try: self.last_progress = ProgressCommand( self.job, + self.job_type, self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - True, + JobStatus.FAILED, + steps=self.get_progress(), ) self.progress.put( self.last_progress, @@ -162,25 +167,3 @@ class WorkerContext: ) except Exception: logger.exception("error setting failure on job %s", self.job) - - -class JobStatus: - name: str - device: str - progress: int - cancelled: bool - finished: bool - - def __init__( - self, - name: str, - device: DeviceParams, - progress: int = 0, - cancelled: bool = False, - finished: bool = False, - ) -> None: - self.name = name - self.device = device.device - self.progress = progress - self.cancelled = cancelled - self.finished = finished diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index d210421e..a1dd0da0 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value from ..params import DeviceParams from ..server import ServerContext -from .command import JobCommand, ProgressCommand +from .command import JobCommand, JobStatus, ProgressCommand from .context import WorkerContext from .utils import Interval from .worker import worker_main @@ -201,6 +201,10 @@ class DevicePoolExecutor: should be cancelled on the next progress callback. 
""" + if key in self.cancelled_jobs: + logger.debug("cancelling already cancelled job: %s", key) + return True + for job in self.finished_jobs: if job.job == key: logger.debug("cannot cancel finished job: %s", key) @@ -209,6 +213,9 @@ class DevicePoolExecutor: for job in self.pending_jobs: if job.name == key: self.pending_jobs.remove(job) + self.cancelled_jobs.append( + key + ) # ensure workers never pick up this job and the status endpoint knows about it later logger.info("cancelled pending job: %s", key) return True @@ -221,28 +228,31 @@ class DevicePoolExecutor: self.cancelled_jobs.append(key) return True - def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]: + def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]: """ Check if a job has been finished and report the last progress update. - - If the job is still pending, the first item will be True and there will be no ProgressCommand. """ + + if key in self.cancelled_jobs: + logger.debug("checking status for cancelled job: %s", key) + return (JobStatus.CANCELLED, None) + if key in self.running_jobs: logger.debug("checking status for running job: %s", key) - return (False, self.running_jobs[key]) + return (JobStatus.RUNNING, self.running_jobs[key]) for job in self.finished_jobs: if job.job == key: logger.debug("checking status for finished job: %s", key) - return (False, job) + return (job.status, job) for job in self.pending_jobs: if job.name == key: logger.debug("checking status for pending job: %s", key) - return (True, None) + return (JobStatus.PENDING, None) logger.trace("checking status for unknown job: %s", key) - return (False, None) + return (JobStatus.UNKNOWN, None) def join(self): logger.info("stopping worker pool") @@ -383,6 +393,7 @@ class DevicePoolExecutor: def submit( self, key: str, + job_type: str, fn: Callable[..., None], /, *args, @@ -399,56 +410,63 @@ class DevicePoolExecutor: ) # build and queue job - job = JobCommand(key, device, fn, args, kwargs) + job = JobCommand(key, device, job_type, fn, args, kwargs) self.pending_jobs.append(job) - def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]: + def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]: """ Returns a tuple of: job/device, progress, progress, finished, cancelled, failed """ - return { - "cancelled": [], - "finished": [ + + jobs: Tuple[str, int, JobStatus] = [] + jobs.extend( + [ ( - job.job, - job.progress, - False, - job.finished, - job.cancelled, - job.failed, + job, + 0, + JobStatus.CANCELLED, ) - for job in self.finished_jobs - ], - "pending": [ + for job in self.cancelled_jobs + ] + ) + jobs.extend( + [ ( job.name, 0, - True, - False, - False, - False, + JobStatus.PENDING, ) for job in self.pending_jobs - ], - "running": [ + ] + ) + jobs.extend( + [ ( name, - job.progress, - False, - job.finished, - job.cancelled, - job.failed, + job.steps, + job.status, ) for name, job in self.running_jobs.items() - ], - "total": [ + ] + ) + jobs.extend( + [ + ( + job.job, + job.steps, + job.status, + ) + for job in self.finished_jobs + ] + ) + + return { + "jobs": jobs, + "workers": [ ( device, total, self.workers[device].is_alive(), - False, - False, - False, ) for device, total in self.total_jobs.items() ], @@ -476,20 +494,18 @@ class DevicePoolExecutor: self.cancelled_jobs.remove(progress.job) def update_job(self, progress: ProgressCommand): - if progress.finished: + if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]: return self.finish_job(progress) # move from pending to running - 
logger.debug( - "progress update for job: %s to %s", progress.job, progress.progress - ) + logger.debug("progress update for job: %s to %s", progress.job, progress.steps) self.running_jobs[progress.job] = progress self.pending_jobs[:] = [ job for job in self.pending_jobs if job.name != progress.job ] # increment job counter if this is the start of a new job - if progress.progress == 0: + if progress.steps == 0: if progress.device in self.total_jobs: self.total_jobs[progress.device] += 1 else: diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 55ebcaac..062d58e4 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -57,7 +57,7 @@ def worker_main( logger.info("worker %s got job: %s", worker.device.device, job.name) # clear flags and save the job name - worker.start(job.name) + worker.start(job) logger.info("starting job: %s", job.name) # reset progress, which does a final check for cancellation diff --git a/api/scripts/test-diffusers.py b/api/scripts/test-diffusers.py index 5852eb6b..3fba1ce4 100644 --- a/api/scripts/test-diffusers.py +++ b/api/scripts/test-diffusers.py @@ -1,12 +1,6 @@ from diffusers import OnnxStableDiffusionPipeline from os import path -import cv2 -import numpy as np -import onnxruntime as ort -import torch -import time - cfg = 8 steps = 22 height = 512 diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 322712e4..1796d978 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -22,6 +22,7 @@ from onnx_web.params import ( UpscaleParams, ) from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobCommand from onnx_web.worker.context import WorkerContext from tests.helpers import ( TEST_MODEL_DIFFUSION_SD15, @@ -57,7 +58,7 @@ class TestTxt2ImgPipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) run_txt2img_pipeline( worker, @@ -72,7 +73,6 @@ class TestTxt2ImgPipeline(unittest.TestCase): 1, ), Size(256, 256), - ["test-txt2img-basic.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), ) @@ -103,7 +103,7 @@ class TestTxt2ImgPipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) run_txt2img_pipeline( worker, @@ -119,7 +119,6 @@ class TestTxt2ImgPipeline(unittest.TestCase): batch=2, ), Size(256, 256), - ["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), ) @@ -152,7 +151,7 @@ class TestTxt2ImgPipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) run_txt2img_pipeline( worker, @@ -168,7 +167,6 @@ class TestTxt2ImgPipeline(unittest.TestCase): unet_tile=256, ), Size(256, 256), - ["test-txt2img-highres.png"], UpscaleParams("test", scale=2, outscale=2), HighresParams(True, 2, 0, 0), ) @@ -198,7 +196,7 @@ class TestTxt2ImgPipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) run_txt2img_pipeline( worker, @@ -214,7 +212,6 @@ class TestTxt2ImgPipeline(unittest.TestCase): batch=2, ), Size(256, 256), - ["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"], UpscaleParams("test"), HighresParams(True, 2, 0, 0), ) @@ -230,7 +227,7 @@ class TestImg2ImgPipeline(unittest.TestCase): 
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) def test_basic(self): worker = test_worker() - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) source = Image.new("RGB", (64, 64), "black") run_img2img_pipeline( @@ -245,7 +242,6 @@ class TestImg2ImgPipeline(unittest.TestCase): 1, 1, ), - ["test-img2img.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), source, @@ -259,7 +255,7 @@ class TestInpaintPipeline(unittest.TestCase): @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) def test_basic_white(self): worker = test_worker() - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) source = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "white") @@ -277,7 +273,6 @@ class TestInpaintPipeline(unittest.TestCase): unet_tile=64, ), Size(*source.size), - ["test-inpaint-white.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), source, @@ -296,7 +291,7 @@ class TestInpaintPipeline(unittest.TestCase): @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) def test_basic_black(self): worker = test_worker() - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {})) source = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "black") @@ -314,7 +309,6 @@ class TestInpaintPipeline(unittest.TestCase): unet_tile=64, ), Size(*source.size), - ["test-inpaint-black.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), source, @@ -353,7 +347,7 @@ class TestUpscalePipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {})) source = Image.new("RGB", (64, 64), "black") run_upscale_pipeline( @@ -369,7 +363,6 @@ class TestUpscalePipeline(unittest.TestCase): 1, ), Size(256, 256), - ["test-upscale.png"], UpscaleParams("test"), HighresParams(False, 1, 0, 0), source, @@ -399,7 +392,7 @@ class TestBlendPipeline(unittest.TestCase): 3, 0.1, ) - worker.start("test") + worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {})) source = Image.new("RGBA", (64, 64), "black") mask = Image.new("RGBA", (64, 64), "white") @@ -417,7 +410,6 @@ class TestBlendPipeline(unittest.TestCase): unet_tile=64, ), Size(64, 64), - ["test-blend.png"], UpscaleParams("test"), [source, source], mask, diff --git a/api/tests/worker/test_pool.py b/api/tests/worker/test_pool.py index ea709156..721fe87d 100644 --- a/api/tests/worker/test_pool.py +++ b/api/tests/worker/test_pool.py @@ -5,6 +5,7 @@ from typing import Optional from onnx_web.params import DeviceParams from onnx_web.server.context import ServerContext +from onnx_web.worker.command import JobStatus from onnx_web.worker.pool import DevicePoolExecutor TEST_JOIN_TIMEOUT = 0.2 @@ -50,11 +51,11 @@ class TestWorkerPool(unittest.TestCase): self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start() - self.pool.submit("test", wait_job, lock=lock) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", "test", wait_job, lock=lock) + self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) self.assertTrue(self.pool.cancel("test")) - self.assertEqual(self.pool.done("test"), (False, None)) + self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None)) def test_cancel_running(self): pass @@ -88,12 +89,14 @@ class TestWorkerPool(unittest.TestCase): self.pool = DevicePoolExecutor( server, [device], 
join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) + + lock.clear() self.pool.start(lock) - self.pool.submit("test", test_job) + self.pool.submit("test", "test", test_job) sleep(5.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + status, _progress = self.pool.status("test") + self.assertEqual(status, JobStatus.RUNNING) def test_done_pending(self): device = DeviceParams("cpu", "CPUProvider") @@ -102,9 +105,9 @@ class TestWorkerPool(unittest.TestCase): self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool.start(lock) - self.pool.submit("test1", test_job) - self.pool.submit("test2", test_job) - self.assertTrue(self.pool.done("test2"), (True, None)) + self.pool.submit("test1", "test", test_job) + self.pool.submit("test2", "test", test_job) + self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None)) lock.set() @@ -119,12 +122,12 @@ class TestWorkerPool(unittest.TestCase): server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 ) self.pool.start() - self.pool.submit("test", wait_job) - self.assertEqual(self.pool.done("test"), (True, None)) + self.pool.submit("test", "test", wait_job) + self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None)) sleep(5.0) - pending, _progress = self.pool.done("test") - self.assertFalse(pending) + status, _progress = self.pool.status("test") + self.assertEqual(status, JobStatus.SUCCESS) def test_recycle_live(self): pass diff --git a/api/tests/worker/test_worker.py b/api/tests/worker/test_worker.py index 6365fac9..a0684fff 100644 --- a/api/tests/worker/test_worker.py +++ b/api/tests/worker/test_worker.py @@ -40,7 +40,7 @@ class WorkerMainTests(unittest.TestCase): nonlocal status status = exit_status - job = JobCommand("test", "test", main_interrupt, [], {}) + job = JobCommand("test", "test", "test", main_interrupt, [], {}) cancel = Value("L", False) logs = Queue() pending = Queue() @@ -75,7 +75,7 @@ class WorkerMainTests(unittest.TestCase): nonlocal status status = exit_status - job = JobCommand("test", "test", main_retry, [], {}) + job = JobCommand("test", "test", "test", main_retry, [], {}) cancel = Value("L", False) logs = Queue() pending = Queue() @@ -144,7 +144,7 @@ class WorkerMainTests(unittest.TestCase): nonlocal status status = exit_status - job = JobCommand("test", "test", main_memory, [], {}) + job = JobCommand("test", "test", "test", main_memory, [], {}) cancel = Value("L", False) logs = Queue() pending = Queue() diff --git a/docs/index.md b/docs/index.md index c0cfffc3..e839f1ac 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and -servers with a seamless user experience. +multi-GPU servers with a seamless user experience. You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers, including your go-to mobile device. 
Here, you have the flexibility to choose diffusion models and accelerators for each @@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features: - includes both the API and GUI bundle in a single container - runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services -## Contents - -- [onnx-web](#onnx-web) - - [Features](#features) - - [Contents](#contents) - - [Setup](#setup) - - [Adding your own models](#adding-your-own-models) - - [Usage](#usage) - - [Known errors and solutions](#known-errors-and-solutions) - - [Running the containers](#running-the-containers) - - [Credits](#credits) - ## Setup There are a few ways to run onnx-web: diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 8ea871ea..65eea73e 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -4,10 +4,7 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils'; import { ServerParams } from '../config.js'; import { FilterResponse, - ImageResponse, - ImageResponseWithRetry, ModelResponse, - ReadyResponse, RetryParams, WriteExtrasResponse, } from '../types/api.js'; @@ -27,6 +24,7 @@ import { } from '../types/params.js'; import { range } from '../utils.js'; import { ApiClient } from './base.js'; +import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js'; /** * Fixed precision for integer parameters. @@ -43,8 +41,9 @@ export const FIXED_INTEGER = 0; export const FIXED_FLOAT = 2; export const STATUS_SUCCESS = 200; -export function equalResponse(a: ImageResponse, b: ImageResponse): boolean { - return a.outputs === b.outputs; +export function equalResponse(a: JobResponse, b: JobResponse): boolean { + return a.name === b.name && a.status === b.status && a.type === b.type; + // return a.outputs === b.outputs; } /** @@ -141,8 +140,8 @@ export function appendHighresToURL(url: URL, highres: HighresParams) { * Make an API client using the given API root and fetch client. 
 */
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
-  function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
-    return f(url, options).then((res) => parseApiResponse(root, res));
+  function parseRequest(url: URL, options: RequestInit): Promise<JobResponse> {
+    return f(url, options).then((res) => parseJobResponse(root, res));
   }

   return {
@@ -218,7 +217,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       const res = await f(path);
       return await res.json() as Array<string>;
     },
-    async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
+    async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
       const url = makeImageURL(root, 'img2img', params);
       appendModelToURL(url, model);

@@ -240,12 +239,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       const body = new FormData();
       body.append('source', params.source, 'source');

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         body,
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'img2img',
           model,
@@ -254,7 +253,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         },
       };
     },
-    async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
+    async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
       const url = makeImageURL(root, 'txt2img', params);
       appendModelToURL(url, model);

@@ -274,11 +273,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         appendHighresToURL(url, highres);
       }

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'txt2img',
           model,
@@ -288,7 +287,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         },
       };
     },
-    async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
+    async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
       const url = makeImageURL(root, 'inpaint', params);
       appendModelToURL(url, model);

@@ -309,12 +308,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       body.append('mask', params.mask, 'mask');
       body.append('source', params.source, 'source');

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         body,
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'inpaint',
           model,
@@ -323,7 +322,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         },
       };
     },
-    async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
+    async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
       const url = makeImageURL(root, 'inpaint', params);
       appendModelToURL(url, model);

@@ -361,12 +360,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       body.append('mask', params.mask, 'mask');
       body.append('source', params.source, 'source');

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         body,
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'outpaint',
           model,
@@ -375,7 +374,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         },
       };
     },
-    async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
+    async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
       const url = makeApiUrl(root, 'upscale');
       appendModelToURL(url, model);

@@ -396,12 +395,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       const body = new FormData();
       body.append('source', params.source, 'source');

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         body,
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'upscale',
           model,
@@ -410,7 +409,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         },
       };
     },
-    async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
+    async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry> {
       const url = makeApiUrl(root, 'blend');
       appendModelToURL(url, model);

@@ -426,12 +425,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         body.append(name, params.sources[i], name);
       }

-      const image = await parseRequest(url, {
+      const job = await parseRequest(url, {
         body,
         method: 'POST',
       });
       return {
-        image,
+        job,
         retry: {
           type: 'blend',
           model,
@@ -440,8 +439,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       }
       };
     },
-    async chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse> {
-      const url = makeApiUrl(root, 'chain');
+    async chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse> {
+      const url = makeApiUrl(root, 'job');
       const body = JSON.stringify({
         ...chain,
         platform: model.platform,
@@ -456,23 +455,23 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
         method: 'POST',
       });
     },
-    async ready(key: string): Promise<ReadyResponse> {
-      const path = makeApiUrl(root, 'ready');
-      path.searchParams.append('output', key);
+    async status(keys: Array<string>): Promise<Array<JobResponse>> {
+      const path = makeApiUrl(root, 'job', 'status');
+      path.searchParams.append('jobs', keys.join(','));

       const res = await f(path);
-      return await res.json() as ReadyResponse;
+      return await res.json() as Array<JobResponse>;
     },
-    async cancel(key: string): Promise<boolean> {
-      const path = makeApiUrl(root, 'cancel');
-      path.searchParams.append('output', key);
+    async cancel(keys: Array<string>): Promise<Array<JobResponse>> {
+      const path = makeApiUrl(root, 'job', 'cancel');
+      path.searchParams.append('jobs', keys.join(','));

       const res = await f(path, {
         method: 'PUT',
       });
-      return res.status === STATUS_SUCCESS;
+      return await res.json() as Array<JobResponse>;
     },
-    async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
+    async retry(retry: RetryParams): Promise<JobResponseWithRetry> {
       switch (retry.type) {
         case 'blend':
           return this.blend(retry.model, retry.params, retry.upscale);
@@ -491,7 +490,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       }
     },
     async restart(): Promise<boolean> {
-      const path = makeApiUrl(root, 'restart');
+      const path = makeApiUrl(root, 'worker', 'restart');

       if (doesExist(token)) {
         path.searchParams.append('token', token);
@@ -502,8 +501,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       });
       return res.status === STATUS_SUCCESS;
     },
-    async status(): Promise<Array<unknown>> {
-      const path = makeApiUrl(root, 'status');
+    async workers(): Promise<Array<unknown>> {
+      const path = makeApiUrl(root, 'worker', 'status');

       if (doesExist(token)) {
         path.searchParams.append('token', token);
@@ -512,6 +511,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
       }

       const res = await f(path);
       return res.json();
     },
+    outputURL(image: SuccessJobResponse, index: number): string {
+      return new URL(joinPath('output', image.outputs[index]), root).toString();
+    },
   };
 }

@@ -521,24 +523,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
  * The server sends over the output key, and the client is in the best position to turn
  * that into a full URL, since it already knows the root URL of the server.
  */
-export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
-  type LimitedResponse = Omit<ImageResponse, 'outputs'> & { outputs: Array<string> };
-
+export async function parseJobResponse(root: string, res: Response): Promise<JobResponse> {
   if (res.status === STATUS_SUCCESS) {
-    const data = await res.json() as LimitedResponse;
-
-    const outputs = data.outputs.map((output) => {
-      const url = new URL(joinPath('output', output), root).toString();
-      return {
-        key: output,
-        url,
-      };
-    });
-
-    return {
-      ...data,
-      outputs,
-    };
+    return await res.json() as JobResponse;
   } else {
     throw new Error('request error');
   }
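A note for reviewers: a minimal sketch of the reworked submission flow, for anyone reading the client changes above in isolation. The server root URL and the `model` and `params` values are assumptions for this example, not part of the patch.

    // Sketch only: submit a txt2img job and keep the retry parameters around.
    import { makeClient } from './client/api.js';
    import { ModelParams, Txt2ImgParams } from './types/params.js';

    async function example(model: ModelParams, params: Txt2ImgParams) {
      const client = makeClient('http://127.0.0.1:5000');
      const { job, retry } = await client.txt2img(model, params);
      // job.name is the key used by the new batched status and cancel endpoints
      console.log(job.name, job.status);
      return { job, retry };
    }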
diff --git a/gui/src/client/base.ts b/gui/src/client/base.ts
index 62ef440a..0c1df7b6 100644
--- a/gui/src/client/base.ts
+++ b/gui/src/client/base.ts
@@ -1,12 +1,19 @@
 import { ServerParams } from '../config.js';
 import { ExtrasFile } from '../types/model.js';
-import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
+import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
 import { ChainPipeline } from '../types/chain.js';
 import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
+import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';

 export interface ApiClient {
+  /**
+   * Get the first extras file.
+   */
   extras(): Promise<ExtrasFile>;

+  /**
+   * Update the first extras file.
+   */
   writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;

   /**
@@ -51,54 +58,60 @@ export interface ApiClient {
     translation: Record<string, string>;
   }>>;

+  /**
+   * Get the available wildcards.
+   */
   wildcards(): Promise<Array<string>>;

   /**
    * Start a txt2img pipeline.
    */
-  txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
+  txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;

   /**
    * Start an im2img pipeline.
    */
-  img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
+  img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;

   /**
    * Start an inpaint pipeline.
    */
-  inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
+  inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;

   /**
    * Start an outpaint pipeline.
    */
-  outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
+  outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;

   /**
    * Start an upscale pipeline.
    */
-  upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
+  upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;

   /**
    * Start a blending pipeline.
    */
-  blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
+  blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry>;

-  chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse>;
+  /**
+   * Start a custom chain pipeline.
+   */
+  chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse>;

   /**
    * Check whether job has finished and its output is ready.
    */
-  ready(key: string): Promise<ReadyResponse>;
+  status(keys: Array<string>): Promise<Array<JobResponse>>;

   /**
    * Cancel an existing job.
    */
-  cancel(key: string): Promise<boolean>;
+  cancel(keys: Array<string>): Promise<Array<JobResponse>>;

   /**
    * Retry a previous job using the same parameters.
    */
-  retry(params: RetryParams): Promise<ImageResponseWithRetry>;
+  retry(params: RetryParams): Promise<JobResponseWithRetry>;

   /**
    * Restart the image job workers.
@@ -108,5 +121,7 @@ export interface ApiClient {
   /**
    * Check the status of the image job workers.
    */
-  status(): Promise<Array<unknown>>;
+  workers(): Promise<Array<unknown>>;
+
+  outputURL(image: SuccessJobResponse, index: number): string;
 }
diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts
index 06dcd6b0..64804876 100644
--- a/gui/src/client/local.ts
+++ b/gui/src/client/local.ts
@@ -48,7 +48,7 @@ export const LOCAL_CLIENT = {
   async params() {
     throw new NoServerError();
   },
-  async ready(key) {
+  async status(key) {
     throw new NoServerError();
   },
   async cancel(key) {
@@ -78,7 +78,10 @@ export const LOCAL_CLIENT = {
   async restart() {
     throw new NoServerError();
   },
-  async status() {
+  async workers() {
     throw new NoServerError();
-  }
+  },
+  outputURL(image, index) {
+    throw new NoServerError();
+  },
 } as ApiClient;
diff --git a/gui/src/client/utils.ts b/gui/src/client/utils.ts
index 967c818d..bde45282 100644
--- a/gui/src/client/utils.ts
+++ b/gui/src/client/utils.ts
@@ -97,11 +97,19 @@ export function expandRanges(range: string): Array<number> {
 export const GRID_TILE_SIZE = 8192;

 // eslint-disable-next-line max-params
-export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
+export function makeTxt2ImgGridPipeline(
+  grid: PipelineGrid,
+  model: ModelParams,
+  params: Txt2ImgParams,
+  upscale?: UpscaleParams,
+  highres?: HighresParams,
+): ChainPipeline {
   const pipeline: ChainPipeline = {
     defaults: {
       ...model,
       ...params,
+      ...(upscale || {}),
+      ...(highres || {}),
     },
     stages: [],
   };
diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx
index 20a520ed..38d9e9bb 100644
--- a/gui/src/components/ImageHistory.tsx
+++ b/gui/src/components/ImageHistory.tsx
@@ -10,6 +10,7 @@ import { OnnxState, StateContext } from '../state/full.js';
 import { ErrorCard } from './card/ErrorCard.js';
 import { ImageCard } from './card/ImageCard.js';
 import { LoadingCard } from './card/LoadingCard.js';
+import { JobStatus } from '../types/api-v2.js';

 export function ImageHistory() {
   const store = mustExist(useContext(StateContext));
@@ -25,19 +26,19 @@ export function ImageHistory() {
   const limited = history.slice(0, limit);
   for (const item of limited) {
-    const key = item.image.outputs[0].key;
+    const key = item.image.name;

-    if (doesExist(item.ready) && item.ready.ready) {
-      if (item.ready.cancelled || item.ready.failed) {
-        children.push([key, <ErrorCard key={key} image={item.image} ready={item.ready} retry={item.retry} />]);
-        continue;
-      }
-
-      children.push([key, <ImageCard key={key} image={item.image} />]);
-      continue;
+    switch (item.image.status) {
+      case JobStatus.SUCCESS:
+        children.push([key, <ImageCard key={key} image={item.image} />]);
+        break;
+      case JobStatus.FAILED:
+        children.push([key, <ErrorCard key={key} image={item.image} retry={item.retry} />]);
+        break;
+      default:
+        children.push([key, <LoadingCard key={key} image={item.image} />]);
+        break;
     }
-
-    children.push([key, <LoadingCard key={key} index={0} image={item.image} />]);
   }

   return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;
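A note for reviewers: the interface above replaces the one-key ready/cancel calls with batched variants, so a caller can track several jobs with a single request. A sketch of polling until every job settles; the helper name and the five-second interval are illustrative, not part of the patch.

    // Sketch: poll the batched status endpoint until all jobs settle.
    async function pollJobs(client: ApiClient, names: Array<string>): Promise<Array<JobResponse>> {
      for (;;) {
        // one request covers every pending job, replacing the per-output /ready calls
        const jobs = await client.status(names);
        const settled = jobs.every((job) =>
          job.status !== JobStatus.PENDING && job.status !== JobStatus.RUNNING);
        if (settled) {
          return jobs;
        }
        await new Promise((resolve) => setTimeout(resolve, 5000));
      }
    }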
diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx
index f7106584..78b2698e 100644
--- a/gui/src/components/card/ErrorCard.tsx
+++ b/gui/src/components/card/ErrorCard.tsx
@@ -10,16 +10,15 @@ import { useStore } from 'zustand';
 import { shallow } from 'zustand/shallow';

 import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
-import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
+import { FailedJobResponse, RetryParams } from '../../types/api-v2.js';

 export interface ErrorCardProps {
-  image: ImageResponse;
-  ready: ReadyResponse;
+  image: FailedJobResponse;
   retry: Maybe<RetryParams>;
 }

 export function ErrorCard(props: ErrorCardProps) {
-  const { image, ready, retry: retryParams } = props;
+  const { image, retry: retryParams } = props;

   const client = mustExist(useContext(ClientContext));
   const { params } = mustExist(useContext(ConfigContext));
@@ -32,8 +31,8 @@ export function ErrorCard(props: ErrorCardProps) {
     removeHistory(image);

     if (doesExist(retryParams)) {
-      const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
-      pushHistory(nextImage, nextRetry);
+      const { job: nextJob, retry: nextRetry } = await client.retry(retryParams);
+      pushHistory(nextJob, nextRetry);
     }
   }

@@ -52,10 +51,11 @@ export function ErrorCard(props: ErrorCardProps) {
         spacing={2}
         sx={{ alignItems: 'center' }}
       >
-        <Typography>
-          {t('loading.progress', {
-            current: ready.progress,
-            total: image.params.steps,
-          })}
-        </Typography>
+        <Typography>
+          {t('loading.progress', image.steps)}
+        </Typography>
+        <Typography>
+          {image.error}
+        </Typography>
         <Button onClick={() => retry.mutate()}>
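A note for reviewers: the retry flow above in isolation. The two history callbacks are assumptions standing in for the zustand slice actions used by ErrorCard.

    // Sketch: drop the failed entry, resubmit with the saved RetryParams,
    // and push the fresh job into history, as ErrorCard does.
    async function retryJob(
      client: ApiClient,
      failed: FailedJobResponse,
      params: RetryParams,
      removeHistory: (image: JobResponse) => void,
      pushHistory: (image: JobResponse, retry?: RetryParams) => void,
    ): Promise<void> {
      removeHistory(failed);
      const { job, retry } = await client.retry(params);
      pushHistory(job, retry);
    }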
diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx
index e614d430..75a2ac0c 100644
--- a/gui/src/components/card/ImageCard.tsx
+++ b/gui/src/components/card/ImageCard.tsx
@@ -2,21 +2,22 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
 import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
 import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
 import * as React from 'react';
-import { useContext, useState } from 'react';
+import { useContext, useMemo, useState } from 'react';
 import { useTranslation } from 'react-i18next';
 import { useHash } from 'react-use/lib/useHash';
 import { useStore } from 'zustand';
 import { shallow } from 'zustand/shallow';

-import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
-import { ImageResponse } from '../../types/api.js';
+import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
 import { range, visibleIndex } from '../../utils.js';
 import { BLEND_SOURCES } from '../../constants.js';
+import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
+import { getApiRoot } from '../../config.js';

 export interface ImageCardProps {
-  image: ImageResponse;
+  image: SuccessJobResponse;

-  onDelete?: (key: ImageResponse) => void;
+  onDelete?: (key: JobResponse) => void;
 }

 export function GridItem(props: { xs: number; children: React.ReactNode }) {
@@ -27,18 +28,19 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {

 export function ImageCard(props: ImageCardProps) {
   const { image } = props;
-  const { params, outputs, size } = image;
+  const { metadata, outputs } = image;

   const [_hash, setHash] = useHash();
   const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
   const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();

+  const client = mustExist(useContext(ClientContext));
   const config = mustExist(useContext(ConfigContext));
   const store = mustExist(useContext(StateContext));
   const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);

   async function loadSource() {
-    const req = await fetch(outputs[index].url);
+    const req = await fetch(url);
     return req.blob();
   }
@@ -84,12 +86,12 @@ export function ImageCard(props: ImageCardProps) {
   }

   function downloadImage() {
-    window.open(outputs[index].url, '_blank');
+    window.open(url, '_blank');
     close();
   }

   function downloadMetadata() {
-    window.open(outputs[index].url + '.json', '_blank');
+    window.open(url + '.json', '_blank');
     close();
   }

@@ -106,14 +108,16 @@ export function ImageCard(props: ImageCardProps) {
     return mustDefault(t(`${key}.${name}`), name);
   }

-  const model = getLabel('model', params.model);
-  const scheduler = getLabel('scheduler', params.scheduler);
+  const url = useMemo(() => client.outputURL(image, index), [image, index]);
+
+  const model = getLabel('model', metadata[index].model);
+  const scheduler = getLabel('scheduler', metadata[index].scheduler);

   return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
@@ -146,12 +150,12 @@ export function ImageCard(props: ImageCardProps) {
             <GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
             <GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
-            <GridItem xs={4}>{t('parameter.seed')}: {params.seed}</GridItem>
-            <GridItem xs={4}>{t('parameter.cfg')}: {params.cfg}</GridItem>
-            <GridItem xs={4}>{t('parameter.steps')}: {params.steps}</GridItem>
-            <GridItem xs={4}>{t('parameter.size')}: {size.width}x{size.height}</GridItem>
+            <GridItem xs={4}>{t('parameter.seed')}: {metadata[index].params.seed}</GridItem>
+            <GridItem xs={4}>{t('parameter.cfg')}: {metadata[index].params.cfg}</GridItem>
+            <GridItem xs={4}>{t('parameter.steps')}: {metadata[index].params.steps}</GridItem>
+            <GridItem xs={4}>{t('parameter.size')}: {metadata[index].size.width}x{metadata[index].size.height}</GridItem>
           </Grid>
-          <GridItem xs={12}>{params.prompt}</GridItem>
+          <GridItem xs={12}>{metadata[index].params.prompt}</GridItem>
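A note for reviewers: ImageCard now derives output URLs through the client instead of trusting URLs in the response body. The same pattern outside a component might look like this sketch; the helper name is illustrative, and the '.json' suffix matches the sidecar convention in downloadMetadata above.

    // Sketch: build image and metadata links for a finished job.
    function downloadLinks(client: ApiClient, image: SuccessJobResponse): Array<{ image: string; params: string }> {
      return image.outputs.map((_output, index) => ({
        image: client.outputURL(image, index),
        params: `${client.outputURL(image, index)}.json`,
      }));
    }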
diff --git a/gui/src/components/card/LoadingCard.tsx b/gui/src/components/card/LoadingCard.tsx
index 71bfb5f0..e2a81247 100644
--- a/gui/src/components/card/LoadingCard.tsx
+++ b/gui/src/components/card/LoadingCard.tsx
@@ -1,4 +1,4 @@
-import { doesExist, mustExist } from '@apextoaster/js-utils';
+import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
 import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
 import { Stack } from '@mui/system';
 import { useMutation, useQuery } from '@tanstack/react-query';
@@ -10,19 +10,17 @@ import { shallow } from 'zustand/shallow';

 import { POLL_TIME } from '../../config.js';
 import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
-import { ImageResponse } from '../../types/api.js';
+import { JobResponse, JobStatus } from '../../types/api-v2.js';

 const LOADING_PERCENT = 100;
 const LOADING_OVERAGE = 99;

 export interface LoadingCardProps {
-  image: ImageResponse;
-  index: number;
+  image: JobResponse;
 }

 export function LoadingCard(props: LoadingCardProps) {
-  const { image, index } = props;
-  const { steps } = props.image.params;
+  const { image } = props;

   const client = mustExist(useContext(ClientContext));
   const { params } = mustExist(useContext(ConfigContext));
@@ -31,50 +29,22 @@ export function LoadingCard(props: LoadingCardProps) {
   const { removeHistory, setReady } = useStore(store, selectActions, shallow);
   const { t } = useTranslation();

-  const cancel = useMutation(() => client.cancel(image.outputs[index].key));
-  const ready = useQuery(['ready', image.outputs[index].key], () => client.ready(image.outputs[index].key), {
+  const cancel = useMutation(() => client.cancel([image.name]));
+  const ready = useQuery(['ready', image.name], () => client.status([image.name]), {
     // data will always be ready without this, even if the API says its not
     cacheTime: 0,
     refetchInterval: POLL_TIME,
   });

-  function getProgress() {
-    if (doesExist(ready.data)) {
-      return ready.data.progress;
-    }
-
-    return 0;
-  }
-
-  function getPercent() {
-    const progress = getProgress();
-    if (progress > steps) {
-      // steps was not complete, show 99% until done
-      return LOADING_OVERAGE;
-    }
-
-    const pct = progress / steps;
-    return Math.ceil(pct * LOADING_PERCENT);
-  }
-
-  function getTotal() {
-    const progress = getProgress();
-    if (progress > steps) {
-      // steps was not complete, show 99% until done
-      return t('loading.unknown');
-    }
-
-    return steps.toFixed(0);
-  }
-
   function getReady() {
-    return doesExist(ready.data) && ready.data.ready;
+    return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
   }

   function renderProgress() {
-    const progress = getProgress();
-    if (progress > 0 && progress <= steps) {
-      return <CircularProgress variant='determinate' value={getPercent()} />;
+    const progress = getProgress(ready.data);
+    const total = getTotal(ready.data);
+    if (progress > 0 && progress <= total) {
+      return <CircularProgress variant='determinate' value={getPercent(progress, total)} />;
     } else {
       return <CircularProgress />;
     }
   }
@@ -88,9 +58,9 @@ export function LoadingCard(props: LoadingCardProps) {

   useEffect(() => {
     if (ready.status === 'success' && getReady()) {
-      setReady(props.image, ready.data);
+      setReady(ready.data[0]);
     }
   }, [ready.status, getReady(), getProgress(ready.data)]);

   return
@@ -106,10 +76,7 @@ export function LoadingCard(props: LoadingCardProps) {
         sx={{ alignItems: 'center' }}
       >
         {renderProgress()}
-        <Typography>
-          {t('loading.progress', {
-            current: getProgress(),
-            total: getTotal(),
-          })}
-        </Typography>
+        <Typography>
+          {t('loading.progress', selectStatus(ready.data, image))}
+        </Typography>
@@ -125,3 +92,45 @@ export function selectActions(state: OnnxState) {
     setReady: state.setReady,
   };
 }
+
+export function selectStatus(data: Maybe<Array<JobResponse>>, defaultData: JobResponse) {
+  if (doesExist(data) && data.length > 0) {
+    return {
+      steps: data[0].steps,
+      stages: data[0].stages,
+      tiles: data[0].tiles,
+    };
+  }
+
+  return {
+    steps: defaultData.steps,
+    stages: defaultData.stages,
+    tiles: defaultData.tiles,
+  };
+}
+
+export function getPercent(current: number, total: number): number {
+  if (current > total) {
+    // steps was not complete, show 99% until done
+    return LOADING_OVERAGE;
+  }
+
+  const pct = current / total;
+  return Math.ceil(pct * LOADING_PERCENT);
+}
+
+export function getProgress(data: Maybe<Array<JobResponse>>) {
+  if (doesExist(data)) {
+    return data[0].steps.current;
+  }
+
+  return 0;
+}
+
+export function getTotal(data: Maybe<Array<JobResponse>>) {
+  if (doesExist(data)) {
+    return data[0].steps.total;
+  }
+
+  return 0;
+}
diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx
index 21fe47b6..23f8d29a 100644
--- a/gui/src/components/tab/Blend.tsx
+++ b/gui/src/components/tab/Blend.tsx
@@ -20,13 +20,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
 export function Blend() {
   async function uploadSource() {
     const { blend, blendModel, blendUpscale } = store.getState();
-    const { image, retry } = await client.blend(blendModel, {
+    const { job, retry } = await client.blend(blendModel, {
       ...blend,
       mask: mustExist(blend.mask),
       sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
     }, blendUpscale);

-    pushHistory(image, retry);
+    pushHistory(job, retry);
   }

   const client = mustExist(useContext(ClientContext));
diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx
index 90a75919..61911f4b 100644
--- a/gui/src/components/tab/Img2Img.tsx
+++ b/gui/src/components/tab/Img2Img.tsx
@@ -27,12 +27,12 @@ export function Img2Img() {
     const state = store.getState();
     const img2img = selectParams(state);

-    const { image, retry } = await client.img2img(model, {
+    const { job, retry } = await client.img2img(model, {
       ...img2img,
       source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
     }, selectUpscale(state), selectHighres(state));

-    pushHistory(image, retry);
+    pushHistory(job, retry);
   }

   const client = mustExist(useContext(ClientContext));
diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx
index cd098b06..87ea7405 100644
--- a/gui/src/components/tab/Inpaint.tsx
+++ b/gui/src/components/tab/Inpaint.tsx
@@ -39,22 +39,22 @@ export function Inpaint() {
     const inpaint = selectParams(state);

     if (outpaint.enabled) {
-      const { image, retry } = await client.outpaint(model, {
+      const { job, retry } = await client.outpaint(model, {
         ...inpaint,
         ...outpaint,
         mask: mustExist(mask),
         source: mustExist(source),
       }, selectUpscale(state), selectHighres(state));

-      pushHistory(image, retry);
+      pushHistory(job, retry);
     } else {
-      const { image, retry } = await client.inpaint(model, {
+      const { job, retry } = await client.inpaint(model, {
         ...inpaint,
         mask: mustExist(mask),
         source: mustExist(source),
       }, selectUpscale(state), selectHighres(state));

-      pushHistory(image, retry);
+      pushHistory(job, retry);
     }
   }
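A note for reviewers: a few worked values for the getPercent helper added to LoadingCard above, since the overage behavior is easy to misread. The import path assumes the helper stays exported as in this patch.

    import { getPercent } from './components/card/LoadingCard.js';

    getPercent(10, 50); // => 20, Math.ceil((10 / 50) * 100)
    getPercent(50, 50); // => 100, the job finished its expected steps
    getPercent(60, 50); // => 99, overage is pinned to LOADING_OVERAGE until the job reports success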
diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx
index 210cbaa3..060e6766 100644
--- a/gui/src/components/tab/Txt2Img.tsx
+++ b/gui/src/components/tab/Txt2Img.tsx
@@ -69,8 +69,8 @@ export function Txt2Img() {
       const image = await client.chain(model, chain);
       pushHistory(image);
     } else {
-      const { image, retry } = await client.txt2img(model, params2, upscale, highres);
-      pushHistory(image, retry);
+      const { job, retry } = await client.txt2img(model, params2, upscale, highres);
+      pushHistory(job, retry);
     }
   }
diff --git a/gui/src/components/tab/Upscale.tsx b/gui/src/components/tab/Upscale.tsx
index 0fe08cc0..4c0ea2e1 100644
--- a/gui/src/components/tab/Upscale.tsx
+++ b/gui/src/components/tab/Upscale.tsx
@@ -21,12 +21,12 @@ import { PromptInput } from '../input/PromptInput.js';
 export function Upscale() {
   async function uploadSource() {
     const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
-    const { image, retry } = await client.upscale(upscaleModel, {
+    const { job, retry } = await client.upscale(upscaleModel, {
       ...upscale,
       source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
     }, upscaleUpscale, upscaleHighres);

-    pushHistory(image, retry);
+    pushHistory(job, retry);
   }

   const client = mustExist(useContext(ClientContext));
diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts
index de71ef72..c71da4a1 100644
--- a/gui/src/state/history.ts
+++ b/gui/src/state/history.ts
@@ -2,6 +2,7 @@ import { Maybe } from '@apextoaster/js-utils';
 import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
 import { Slice } from './types.js';
 import { DEFAULT_HISTORY } from '../constants.js';
+import { JobResponse } from '../types/api-v2.js';

 export interface HistoryItem {
   image: ImageResponse;
@@ -9,14 +10,19 @@ export interface HistoryItem {
   retry: Maybe<RetryParams>;
 }

+export interface HistoryItemV2 {
+  image: JobResponse;
+  retry: Maybe<RetryParams>;
+}
+
 export interface HistorySlice {
-  history: Array<HistoryItem>;
+  history: Array<HistoryItemV2>;
   limit: number;

-  pushHistory(image: ImageResponse, retry?: RetryParams): void;
-  removeHistory(image: ImageResponse): void;
+  pushHistory(image: JobResponse, retry?: RetryParams): void;
+  removeHistory(image: JobResponse): void;
   setLimit(limit: number): void;
-  setReady(image: ImageResponse, ready: ReadyResponse): void;
+  setReady(image: JobResponse): void;
 }

 export function createHistorySlice(): Slice<HistorySlice> {
@@ -39,7 +45,7 @@ export function createHistorySlice(): Slice<HistorySlice> {
       set((prev) => ({
         ...prev,
-        history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
+        history: prev.history.filter((it) => it.image.name !== image.name),
       }));
     },
     setLimit(limit) {
@@ -48,12 +54,12 @@ export function createHistorySlice(): Slice<HistorySlice> {
       const history = [...prev.history];
-      const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
+      const idx = history.findIndex((it) => it.image.name === image.name);

       if (idx >= 0) {
-        history[idx].ready = ready;
+        history[idx].image = image;
       } else {
         // TODO: error
       }
diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts
index 6b4e7acc..9aa16e73 100644
--- a/gui/src/strings/en.ts
+++ b/gui/src/strings/en.ts
@@ -67,7 +67,7 @@ export const I18N_STRINGS_EN = {
   },
   loading: {
     cancel: 'Cancel',
-    progress: '{{current}} of {{total}} steps',
+    progress: '{{steps.current}} of {{steps.total}} steps, {{tiles.current}} of {{tiles.total}} tiles, {{stages.current}} of {{stages.total}} stages',
     server: 'Connecting to server...',
     unknown: 'many',
   },
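A note for reviewers: the new loading.progress string interpolates three Progress pairs, which i18next resolves from nested keys on the object returned by selectStatus. A sketch of that payload shape; the numbers are arbitrary example values.

    // Sketch: the object shape consumed by the new loading.progress string.
    const payload = {
      steps: { current: 10, total: 25 },
      tiles: { current: 1, total: 4 },
      stages: { current: 2, total: 3 },
    };
    // t('loading.progress', payload)
    // => '10 of 25 steps, 1 of 4 tiles, 2 of 3 stages'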
diff --git a/gui/src/types/api-v2.ts b/gui/src/types/api-v2.ts
new file mode 100644
index 00000000..d7a993d7
--- /dev/null
+++ b/gui/src/types/api-v2.ts
@@ -0,0 +1,160 @@
+import { RetryParams } from './api.js';
+import { BaseImgParams, HighresParams, Img2ImgParams, InpaintParams, Txt2ImgParams, UpscaleParams } from './params.js';
+
+export interface Progress {
+  current: number;
+  total: number;
+}
+
+export interface Size {
+  width: number;
+  height: number;
+}
+
+export interface NetworkMetadata {
+  name: string;
+  hash: string;
+  weight: number;
+}
+
+export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
+  input_size: Size;
+  size: Size;
+  outputs: Array<string>;
+  params: TParams;
+  inversions: Array<NetworkMetadata>;
+  loras: Array<NetworkMetadata>;
+  model: string;
+  scheduler: string;
+  border: unknown;
+  highres: HighresParams;
+  upscale: UpscaleParams;
+  type: TType;
+}
+
+export enum JobStatus {
+  PENDING = 'pending',
+  RUNNING = 'running',
+  SUCCESS = 'success',
+  FAILED = 'failed',
+  CANCELLED = 'cancelled',
+  UNKNOWN = 'unknown',
+}
+
+export enum JobType {
+  TXT2IMG = 'txt2img',
+  IMG2IMG = 'img2img',
+  INPAINT = 'inpaint',
+  UPSCALE = 'upscale',
+  BLEND = 'blend',
+  CHAIN = 'chain',
+}
+
+export interface BaseJobResponse {
+  name: string;
+  status: JobStatus;
+  type: JobType;
+
+  stages: Progress;
+  steps: Progress;
+  tiles: Progress;
+}
+
+/**
+ * Pending image job.
+ */
+export interface PendingJobResponse extends BaseJobResponse {
+  status: JobStatus.PENDING | JobStatus.RUNNING;
+  queue: Progress;
+}
+
+/**
+ * Failed image job with error information.
+ */
+export interface FailedJobResponse extends BaseJobResponse {
+  status: JobStatus.FAILED;
+  error: string;
+}
+
+/**
+ * Successful txt2img image job with output keys and metadata.
+ */
+export interface SuccessTxt2ImgJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.TXT2IMG;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
+}
+
+/**
+ * Successful img2img job with output keys and metadata.
+ */
+export interface SuccessImg2ImgJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.IMG2IMG;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
+}
+
+/**
+ * Successful inpaint job with output keys and metadata.
+ */
+export interface SuccessInpaintJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.INPAINT;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
+}
+
+/**
+ * Successful upscale job with output keys and metadata.
+ */
+export interface SuccessUpscaleJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.UPSCALE;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
+}
+
+/**
+ * Successful blend job with output keys and metadata.
+ */
+export interface SuccessBlendJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.BLEND;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
+}
+
+/**
+ * Successful chain pipeline job with output keys and metadata.
+ */
+export interface SuccessChainJobResponse extends BaseJobResponse {
+  status: JobStatus.SUCCESS;
+  type: JobType.CHAIN;
+  outputs: Array<string>;
+  metadata: Array<ImageMetadata<BaseImgParams, JobType.CHAIN>>; // TODO: could be all kinds
+}
+
+export type SuccessJobResponse
+  = SuccessTxt2ImgJobResponse
+  | SuccessImg2ImgJobResponse
+  | SuccessInpaintJobResponse
+  | SuccessUpscaleJobResponse
+  | SuccessBlendJobResponse
+  | SuccessChainJobResponse;
+
+export type JobResponse = PendingJobResponse | FailedJobResponse | SuccessJobResponse;
+
+/**
+ * Status response from the job endpoint, with parameters to retry the job if it fails.
+ */
+export interface JobResponseWithRetry {
+  job: JobResponse;
+  retry: RetryParams;
+}
+
+/**
+ * Re-export `RetryParams` for convenience.
+ */
+export { RetryParams };
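A note for reviewers: each member of the JobResponse union above pins status to specific literals, so TypeScript can narrow the union in a switch. A sketch; the helper name is illustrative.

    // Sketch: narrowing JobResponse by its status discriminant. outputs and
    // error only exist on the narrowed members; steps exists on all of them.
    function describeJob(job: JobResponse): string {
      switch (job.status) {
        case JobStatus.SUCCESS:
          return `${job.type}: ${job.outputs.length} outputs ready`;
        case JobStatus.FAILED:
          return `${job.type} failed: ${job.error}`;
        default:
          return `${job.status}: step ${job.steps.current} of ${job.steps.total}`;
      }
    }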
diff --git a/gui/src/types/api.ts b/gui/src/types/api.ts
index 65179280..884e103c 100644
--- a/gui/src/types/api.ts
+++ b/gui/src/types/api.ts
@@ -14,6 +14,8 @@ import {
 /**
  * Output image data within the response.
+ *
+ * @deprecated
  */
 export interface ImageOutput {
   key: string;
@@ -22,6 +24,8 @@ export interface ImageOutput {
 /**
  * General response for most image requests.
+ *
+ * @deprecated
  */
 export interface ImageResponse {
   outputs: Array<ImageOutput>;
@@ -119,11 +123,19 @@ export type RetryParams = {
   upscale?: UpscaleParams;
 };

+/**
+ * Status response from the image endpoint, with parameters to retry the job if it fails.
+ *
+ * @deprecated
+ */
 export interface ImageResponseWithRetry {
   image: ImageResponse;
   retry: RetryParams;
 }

+/**
+ * @deprecated
+ */
 export interface ImageMetadata {
   highres: HighresParams;
   outputs: string | Array<string>;
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index 2e4ba298..f64caca5 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -43,6 +43,7 @@
     "dtype",
     "ESRGAN",
     "Exif",
+    "fromarray",
     "ftfy",
     "gfpgan",
     "Heun",
@@ -115,6 +116,7 @@
     "webp",
     "xformers",
     "zustand"
-  ]
+  ],
+  "git.ignoreLimitWarning": true
   }
 }
\ No newline at end of file
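To close out, a sketch of the batched cancel endpoint from the caller's side; the helper name is illustrative, and the job names are placeholders for names returned by earlier submissions.

    // Sketch: cancel several jobs in one PUT and inspect the returned statuses.
    async function cancelAll(client: ApiClient, names: Array<string>): Promise<void> {
      const jobs = await client.cancel(names);
      for (const job of jobs) {
        // jobs that were still queued should come back cancelled; finished jobs keep their status
        console.log(`${job.name}: ${job.status}`);
      }
    }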