feat: add batch endpoints for cancel and status, update responses

Sean Sube 2024-01-03 19:09:18 -06:00
parent 19c91f70f5
commit 44a8d61082
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
36 changed files with 981 additions and 707 deletions
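
The new batch endpoints accept a comma-separated `jobs` query parameter, with the older single `job` parameter still honored for compatibility. A minimal client sketch against the routes added below, assuming a server on localhost:5000 (the host and job names are hypothetical):

    import requests

    host = "http://localhost:5000"  # hypothetical server address
    jobs = ["txt2img_42_abc123_1704330000", "img2img_42_def456_1704330001"]  # hypothetical job names

    # poll status for several jobs at once; each reply carries name, status,
    # type, and stages/steps/tiles progress counters
    status = requests.get(f"{host}/api/job/status", params={"jobs": ",".join(jobs)})
    print(status.json())

    # cancel several jobs at once; the reply maps each job name to a boolean
    cancelled = requests.put(f"{host}/api/job/cancel", params={"jobs": ",".join(jobs)})
    print(cancelled.json())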

View File

@@ -36,15 +36,14 @@ class CorrectCodeformerStage(BaseStage):
         # https://pypi.org/project/codeformer-perceptor/
         # import must be within the load function for patches to take effect
-        # TODO: rewrite and remove
+        from codeformer.basicsr.archs.codeformer_arch import CodeFormer
         from codeformer.basicsr.utils import img2tensor, tensor2img
-        from codeformer.basicsr.utils.registry import ARCH_REGISTRY
         from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper

         upscale = upscale.with_args(**kwargs)
         device = worker.get_device()

-        net = ARCH_REGISTRY.get("CodeFormer")(
+        net = CodeFormer(
             dim_embd=512,
             codebook_size=1024,
             n_head=8,

View File

@@ -1,10 +1,28 @@
-from typing import Any, List, Optional, Tuple
+from json import dumps
+from logging import getLogger
+from os import path
+from typing import Any, List, Optional

 import numpy as np
 from PIL import Image

-from ..output import json_params
+from ..convert.utils import resolve_tensor
 from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
+from ..server.load import get_extra_hashes
+from ..utils import hash_file
+
+logger = getLogger(__name__)
+
+
+class NetworkMetadata:
+    name: str
+    hash: str
+    weight: float
+
+    def __init__(self, name: str, hash: str, weight: float) -> None:
+        self.name = name
+        self.hash = hash
+        self.weight = weight


 class ImageMetadata:
@@ -13,8 +31,9 @@ class ImageMetadata:
     params: ImageParams
     size: Size
     upscale: UpscaleParams
-    inversions: Optional[List[Tuple[str, float]]]
-    loras: Optional[List[Tuple[str, float]]]
+    inversions: Optional[List[NetworkMetadata]]
+    loras: Optional[List[NetworkMetadata]]
+    models: Optional[List[NetworkMetadata]]

     def __init__(
         self,
@@ -23,8 +42,9 @@ class ImageMetadata:
         upscale: Optional[UpscaleParams] = None,
         border: Optional[Border] = None,
         highres: Optional[HighresParams] = None,
-        inversions: Optional[List[Tuple[str, float]]] = None,
-        loras: Optional[List[Tuple[str, float]]] = None,
+        inversions: Optional[List[NetworkMetadata]] = None,
+        loras: Optional[List[NetworkMetadata]] = None,
+        models: Optional[List[NetworkMetadata]] = None,
     ) -> None:
         self.params = params
         self.size = size
@@ -33,19 +53,108 @@ class ImageMetadata:
         self.highres = highres
         self.inversions = inversions
         self.loras = loras
+        self.models = models
+
+    def to_auto1111(self, server, outputs) -> str:
+        model_name = path.basename(path.normpath(self.params.model))
+        logger.debug("getting model hash for %s", model_name)
+
+        model_hash = get_extra_hashes().get(model_name, None)
+        if model_hash is None:
+            model_hash_path = path.join(self.params.model, "hash.txt")
+            if path.exists(model_hash_path):
+                with open(model_hash_path, "r") as f:
+                    model_hash = f.readline().rstrip(",. \n\t\r")
+
+        model_hash = model_hash or "unknown"
+        hash_map = {
+            model_name: model_hash,
+        }
+
+        inversion_hashes = ""
+        if self.inversions is not None:
+            inversion_pairs = [
+                (
+                    name,
+                    hash_file(
+                        resolve_tensor(path.join(server.model_path, "inversion", name))
+                    ).upper(),
+                )
+                for name, _weight in self.inversions
+            ]
+            inversion_hashes = ",".join(
+                [f"{name}: {hash}" for name, hash in inversion_pairs]
+            )
+            hash_map.update(dict(inversion_pairs))
+
+        lora_hashes = ""
+        if self.loras is not None:
+            lora_pairs = [
+                (
+                    name,
+                    hash_file(
+                        resolve_tensor(path.join(server.model_path, "lora", name))
+                    ).upper(),
+                )
+                for name, _weight in self.loras
+            ]
+            lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
+            hash_map.update(dict(lora_pairs))
+
+        return (
+            f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
+            f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
+            f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
+            f"Model hash: {model_hash}, Model: {model_name}, "
+            f"Tool: onnx-web, Version: {server.server_version}, "
+            f'Inversion hashes: "{inversion_hashes}", '
+            f'Lora hashes: "{lora_hashes}", '
+            f"Hashes: {dumps(hash_map)}"
+        )

     def tojson(self, server, outputs):
-        return json_params(
-            server,
-            outputs,
-            self.params,
-            self.size,
-            upscale=self.upscale,
-            border=self.border,
-            highres=self.highres,
-            inversions=self.inversions,
-            loras=self.loras,
-        )
+        json = {
+            "input_size": self.size.tojson(),
+            "outputs": outputs,
+            "params": self.params.tojson(),
+            "inversions": {},
+            "loras": {},
+        }
+
+        json["params"]["model"] = path.basename(self.params.model)
+        json["params"]["scheduler"] = self.params.scheduler  # TODO: why tho?
+
+        # calculate final output size
+        output_size = self.size
+        if self.border is not None:
+            json["border"] = self.border.tojson()
+            output_size = output_size.add_border(self.border)
+
+        if self.highres is not None:
+            json["highres"] = self.highres.tojson()
+            output_size = self.highres.resize(output_size)
+
+        if self.upscale is not None:
+            json["upscale"] = self.upscale.tojson()
+            output_size = self.upscale.resize(output_size)
+
+        json["size"] = output_size.tojson()
+
+        if self.inversions is not None:
+            for name, weight in self.inversions:
+                hash = hash_file(
+                    resolve_tensor(path.join(server.model_path, "inversion", name))
+                ).upper()
+                json["inversions"][name] = {"weight": weight, "hash": hash}
+
+        if self.loras is not None:
+            for name, weight in self.loras:
+                hash = hash_file(
+                    resolve_tensor(path.join(server.model_path, "lora", name))
+                ).upper()
+                json["loras"][name] = {"weight": weight, "hash": hash}
+
+        return json


 class StageResult:
@@ -86,6 +195,7 @@ class StageResult:
         self.arrays = arrays
         self.images = images
         self.source = source
+        self.metadata = []

     def __len__(self) -> int:
         if self.arrays is not None:
@@ -117,7 +227,7 @@ class StageResult:
         elif self.images is not None:
             self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
         else:
-            raise ValueError("invalid stage result")
+            self.arrays = [array]

         if metadata is not None:
             self.metadata.append(metadata)
@@ -130,13 +240,45 @@ class StageResult:
         elif self.arrays is not None:
             self.arrays.append(np.array(image))
         else:
-            raise ValueError("invalid stage result")
+            self.images = [image]

         if metadata is not None:
             self.metadata.append(metadata)
         else:
             self.metadata.append(ImageMetadata())

+    def insert_array(
+        self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
+    ):
+        if self.arrays is not None:
+            self.arrays.insert(index, array)
+        elif self.images is not None:
+            self.images.insert(
+                index, Image.fromarray(np.uint8(array), shape_mode(array))
+            )
+        else:
+            self.arrays = [array]
+
+        if metadata is not None:
+            self.metadata.insert(index, metadata)
+        else:
+            self.metadata.insert(index, ImageMetadata())
+
+    def insert_image(
+        self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
+    ):
+        if self.images is not None:
+            self.images.insert(index, image)
+        elif self.arrays is not None:
+            self.arrays.insert(index, np.array(image))
+        else:
+            self.images = [image]
+
+        if metadata is not None:
+            self.metadata.insert(index, metadata)
+        else:
+            self.metadata.insert(index, ImageMetadata())


 def shape_mode(arr: np.ndarray) -> str:
     if len(arr.shape) != 3:
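
A minimal sketch of the new NetworkMetadata container introduced above, which replaces the (name, weight) tuples previously used for inversion and LoRA entries; the name and hash values here are hypothetical:

    from onnx_web.chain.result import NetworkMetadata

    # one LoRA entry: network name, file hash, and blending weight
    lora = NetworkMetadata("detail-tweaker", "A1B2C3D4", 0.8)
    print(lora.name, lora.hash, lora.weight)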

View File

@@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
 from ..chain.result import StageResult
 from ..chain.upscale import split_upscale, stage_upscale_correction
 from ..image import expand_image
-from ..output import save_image
+from ..output import save_image, save_result
 from ..params import (
     Border,
     HighresParams,
@@ -29,7 +29,7 @@ from ..server import ServerContext
 from ..server.load import get_source_filters
 from ..utils import is_debug, run_gc, show_system_toast
 from ..worker import WorkerContext
-from .utils import get_latents_from_seed, parse_prompt
+from .utils import get_latents_from_seed

 logger = getLogger(__name__)

@@ -57,7 +57,6 @@ def run_txt2img_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
 ) -> None:
@@ -114,50 +113,34 @@ def run_txt2img_pipeline(
     # run and save
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult.empty(), callback=progress, latents=latents
     )

-    _pairs, loras, inversions, _rest = parse_prompt(params)
-
     # add a thumbnail, if requested
-    cover = images[0]
+    cover = images.as_image()[0]
     if params.thumbnail and (
         cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
     ):
         thumbnail = cover.copy()
         thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
-        images.insert(0, thumbnail)
-        outputs.insert(0, f"{worker.name}-thumb.{server.image_format}")
+        images.insert_image(0, thumbnail)

-    for image, output in zip(images, outputs):
-        logger.trace("saving output image %s: %s", output, image.size)
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            highres=highres,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)

     # clean up
     run_gc([worker.get_device()])

     # notify the user
-    show_system_toast(f"finished txt2img job: {dest}")
-    logger.info("finished txt2img job: %s", dest)
+    show_system_toast(f"finished txt2img job: {worker.job}")
+    logger.info("finished txt2img job: %s", worker.job)


 def run_img2img_pipeline(
     worker: WorkerContext,
     server: ServerContext,
     params: ImageParams,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -228,36 +211,21 @@ def run_img2img_pipeline(
     # run and append the filtered source
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=[source]), callback=progress
     )

     if source_filter is not None and source_filter != "none":
-        images.append(source)
+        images.push_image(source)

-    # save with metadata
-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    size = Size(*source.size)
-
-    for image, output in zip(images, outputs):
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            highres=highres,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)

     # clean up
     run_gc([worker.get_device()])

     # notify the user
-    show_system_toast(f"finished img2img job: {dest}")
-    logger.info("finished img2img job: %s", dest)
+    show_system_toast(f"finished img2img job: {worker.job}")
+    logger.info("finished img2img job: %s", worker.job)


 def run_inpaint_pipeline(
@@ -265,7 +233,6 @@ def run_inpaint_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -290,7 +257,7 @@ def run_inpaint_pipeline(
         mask = ImageOps.contain(mask, (mask_max, mask_max))
         mask = mask.crop((0, 0, source.width, source.height))

-    source, mask, noise, full_size = expand_image(
+    source, mask, noise, _full_size = expand_image(
         source,
         mask,
         border,
@@ -414,7 +381,7 @@ def run_inpaint_pipeline(
     # run and save
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker,
         server,
         params,
@@ -423,33 +390,28 @@ def run_inpaint_pipeline(
         latents=latents,
     )

-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    for image, output in zip(images, outputs):
+    for i, (image, metadata) in enumerate(zip(images.as_image(), images.metadata)):
         if full_res_inpaint:
             if is_debug():
                 save_image(server, "adjusted-output.png", image)
             mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
             image = original_source
             image.paste(mini_image, box=adj_mask_border)
-        dest = save_image(
+
+        save_image(
             server,
-            output,
+            f"{worker.job}_{i}.{server.image_format}",
             image,
-            params,
-            size,
-            upscale=upscale,
-            border=border,
-            inversions=inversions,
-            loras=loras,
+            metadata,
         )

     # clean up
-    del image
     run_gc([worker.get_device()])

     # notify the user
-    show_system_toast(f"finished inpaint job: {dest}")
-    logger.info("finished inpaint job: %s", dest)
+    show_system_toast(f"finished inpaint job: {worker.job}")
+    logger.info("finished inpaint job: %s", worker.job)


 def run_upscale_pipeline(
@@ -457,7 +419,6 @@ def run_upscale_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     highres: HighresParams,
     source: Image.Image,
@@ -497,30 +458,18 @@ def run_upscale_pipeline(
     # run and save
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=[source]), callback=progress
     )

-    _pairs, loras, inversions, _rest = parse_prompt(params)
-    for image, output in zip(images, outputs):
-        dest = save_image(
-            server,
-            output,
-            image,
-            params,
-            size,
-            upscale=upscale,
-            inversions=inversions,
-            loras=loras,
-        )
+    save_result(server, images, worker.job)

     # clean up
-    del image
     run_gc([worker.get_device()])

     # notify the user
-    show_system_toast(f"finished upscale job: {dest}")
-    logger.info("finished upscale job: %s", dest)
+    show_system_toast(f"finished upscale job: {worker.job}")
+    logger.info("finished upscale job: %s", worker.job)


 def run_blend_pipeline(
@@ -528,7 +477,6 @@ def run_blend_pipeline(
     server: ServerContext,
     params: ImageParams,
     size: Size,
-    outputs: List[str],
     upscale: UpscaleParams,
     # highres: HighresParams,
     sources: List[Image.Image],
@@ -559,17 +507,15 @@ def run_blend_pipeline(
     # run and save
     progress = worker.get_progress_callback()
-    images = chain.run(
+    images = chain(
         worker, server, params, StageResult(images=sources), callback=progress
     )

-    for image, output in zip(images, outputs):
-        dest = save_image(server, output, image, params, size, upscale=upscale)
+    save_result(server, images, worker.job)

     # clean up
-    del image
     run_gc([worker.get_device()])

     # notify the user
-    show_system_toast(f"finished blend job: {dest}")
-    logger.info("finished blend job: %s", dest)
+    show_system_toast(f"finished blend job: {worker.job}")
+    logger.info("finished blend job: %s", worker.job)

View File

@@ -1,173 +1,20 @@
 from hashlib import sha256
 from json import dumps
 from logging import getLogger
-from os import path
-from struct import pack
 from time import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional

 from piexif import ExifIFD, ImageIFD, dump
 from piexif.helper import UserComment
 from PIL import Image, PngImagePlugin

-from .convert.utils import resolve_tensor
-from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
+from .chain.result import ImageMetadata, StageResult
+from .params import ImageParams, Param, Size
 from .server import ServerContext
-from .server.load import get_extra_hashes
-from .utils import base_join
+from .utils import base_join, hash_value

 logger = getLogger(__name__)

-HASH_BUFFER_SIZE = 2**22  # 4MB
-
-
-def hash_file(name: str):
-    sha = sha256()
-    with open(name, "rb") as f:
-        while True:
-            data = f.read(HASH_BUFFER_SIZE)
-            if not data:
-                break
-
-            sha.update(data)
-
-    return sha.hexdigest()
-
-
-def hash_value(sha, param: Optional[Param]):
-    if param is None:
-        return
-    elif isinstance(param, bool):
-        sha.update(bytearray(pack("!B", param)))
-    elif isinstance(param, float):
-        sha.update(bytearray(pack("!f", param)))
-    elif isinstance(param, int):
-        sha.update(bytearray(pack("!I", param)))
-    elif isinstance(param, str):
-        sha.update(param.encode("utf-8"))
-    else:
-        logger.warning("cannot hash param: %s, %s", param, type(param))
-
-
-def json_params(
-    server: ServerContext,
-    outputs: List[str],
-    params: ImageParams,
-    size: Size,
-    upscale: Optional[UpscaleParams] = None,
-    border: Optional[Border] = None,
-    highres: Optional[HighresParams] = None,
-    inversions: Optional[List[Tuple[str, float]]] = None,
-    loras: Optional[List[Tuple[str, float]]] = None,
-    parent: Optional[Dict] = None,
-) -> Any:
-    json = {
-        "input_size": size.tojson(),
-        "outputs": outputs,
-        "params": params.tojson(),
-        "inversions": {},
-        "loras": {},
-    }
-
-    json["params"]["model"] = path.basename(params.model)
-    json["params"]["scheduler"] = params.scheduler
-
-    # calculate final output size
-    output_size = size
-    if border is not None:
-        json["border"] = border.tojson()
-        output_size = output_size.add_border(border)
-
-    if highres is not None:
-        json["highres"] = highres.tojson()
-        output_size = highres.resize(output_size)
-
-    if upscale is not None:
-        json["upscale"] = upscale.tojson()
-        output_size = upscale.resize(output_size)
-
-    json["size"] = output_size.tojson()
-
-    if inversions is not None:
-        for name, weight in inversions:
-            hash = hash_file(
-                resolve_tensor(path.join(server.model_path, "inversion", name))
-            ).upper()
-            json["inversions"][name] = {"weight": weight, "hash": hash}
-
-    if loras is not None:
-        for name, weight in loras:
-            hash = hash_file(
-                resolve_tensor(path.join(server.model_path, "lora", name))
-            ).upper()
-            json["loras"][name] = {"weight": weight, "hash": hash}
-
-    return json
-
-
-def str_params(
-    server: ServerContext,
-    params: ImageParams,
-    size: Size,
-    inversions: List[Tuple[str, float]] = None,
-    loras: List[Tuple[str, float]] = None,
-) -> str:
-    model_name = path.basename(path.normpath(params.model))
-    logger.debug("getting model hash for %s", model_name)
-
-    model_hash = get_extra_hashes().get(model_name, None)
-    if model_hash is None:
-        model_hash_path = path.join(params.model, "hash.txt")
-        if path.exists(model_hash_path):
-            with open(model_hash_path, "r") as f:
-                model_hash = f.readline().rstrip(",. \n\t\r")
-
-    model_hash = model_hash or "unknown"
-    hash_map = {
-        model_name: model_hash,
-    }
-
-    inversion_hashes = ""
-    if inversions is not None:
-        inversion_pairs = [
-            (
-                name,
-                hash_file(
-                    resolve_tensor(path.join(server.model_path, "inversion", name))
-                ).upper(),
-            )
-            for name, _weight in inversions
-        ]
-        inversion_hashes = ",".join(
-            [f"{name}: {hash}" for name, hash in inversion_pairs]
-        )
-        hash_map.update(dict(inversion_pairs))
-
-    lora_hashes = ""
-    if loras is not None:
-        lora_pairs = [
-            (
-                name,
-                hash_file(
-                    resolve_tensor(path.join(server.model_path, "lora", name))
-                ).upper(),
-            )
-            for name, _weight in loras
-        ]
-        lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
-        hash_map.update(dict(lora_pairs))
-
-    return (
-        f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n"
-        f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
-        f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
-        f"Model hash: {model_hash}, Model: {model_name}, "
-        f"Tool: onnx-web, Version: {server.server_version}, "
-        f'Inversion hashes: "{inversion_hashes}", '
-        f'Lora hashes: "{lora_hashes}", '
-        f"Hashes: {dumps(hash_map)}"
-    )
-

 def make_output_name(
     server: ServerContext,
@@ -179,6 +26,19 @@ def make_output_name(
     offset: int = 0,
 ) -> List[str]:
     count = count or params.batch
+    job_name = make_job_name(mode, params, size, extras)
+
+    return [
+        f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
+    ]
+
+
+def make_job_name(
+    mode: str,
+    params: ImageParams,
+    size: Size,
+    extras: Optional[List[Optional[Param]]] = None,
+) -> str:
     now = int(time())
     sha = sha256()

@@ -200,49 +60,49 @@ def make_output_name(
         for param in extras:
             hash_value(sha, param)

-    return [
-        f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
-        for i in range(offset, count + offset)
-    ]
+    return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}"
+
+
+def save_result(
+    server: ServerContext,
+    result: StageResult,
+    base_name: str,
+) -> List[str]:
+    results = []
+    for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)):
+        results.append(
+            save_image(
+                server,
+                base_name + f"_{i}.{server.image_format}",
+                image,
+                metadata,
+            )
+        )
+
+    return results


 def save_image(
     server: ServerContext,
     output: str,
     image: Image.Image,
-    params: Optional[ImageParams] = None,
-    size: Optional[Size] = None,
-    upscale: Optional[UpscaleParams] = None,
-    border: Optional[Border] = None,
-    highres: Optional[HighresParams] = None,
-    inversions: List[Tuple[str, float]] = None,
-    loras: List[Tuple[str, float]] = None,
+    metadata: Optional[ImageMetadata] = None,
 ) -> str:
     path = base_join(server.output_path, output)
     if server.image_format == "png":
         exif = PngImagePlugin.PngInfo()

-        if params is not None:
+        if metadata is not None:
             exif.add_text("make", "onnx-web")
             exif.add_text(
                 "maker note",
-                dumps(
-                    json_params(
-                        server,
-                        [output],
-                        params,
-                        size,
-                        upscale=upscale,
-                        border=border,
-                        highres=highres,
-                    )
-                ),
+                dumps(metadata.tojson(server, [output])),
             )
             exif.add_text("model", server.server_version)
             exif.add_text(
                 "parameters",
-                str_params(server, params, size, inversions=inversions, loras=loras),
+                metadata.to_auto1111(server, [output]),
             )

         image.save(path, format=server.image_format, pnginfo=exif)
@@ -251,23 +111,11 @@ def save_image(
             {
                 "0th": {
                     ExifIFD.MakerNote: UserComment.dump(
-                        dumps(
-                            json_params(
-                                server,
-                                [output],
-                                params,
-                                size,
-                                upscale=upscale,
-                                border=border,
-                                highres=highres,
-                            )
-                        ),
+                        dumps(metadata.tojson(server, [output])),
                         encoding="unicode",
                     ),
                     ExifIFD.UserComment: UserComment.dump(
-                        str_params(
-                            server, params, size, inversions=inversions, loras=loras
-                        ),
+                        metadata.to_auto1111(server, [output]),
                         encoding="unicode",
                     ),
                     ImageIFD.Make: "onnx-web",
@@ -277,34 +125,23 @@ def save_image(
         )
         image.save(path, format=server.image_format, exif=exif)

-    if params is not None:
-        save_params(
+    if metadata is not None:
+        save_metadata(
             server,
             output,
-            params,
-            size,
-            upscale=upscale,
-            border=border,
-            highres=highres,
+            metadata,
         )

     logger.debug("saved output image to: %s", path)
     return path


-def save_params(
+def save_metadata(
     server: ServerContext,
     output: str,
-    params: ImageParams,
-    size: Size,
-    upscale: Optional[UpscaleParams] = None,
-    border: Optional[Border] = None,
-    highres: Optional[HighresParams] = None,
+    metadata: ImageMetadata,
 ) -> str:
     path = base_join(server.output_path, f"{output}.json")
-    json = json_params(
-        server, output, params, size, upscale=upscale, border=border, highres=highres
-    )
+    json = metadata.tojson(server, [output])
     with open(path, "w") as f:
         f.write(dumps(json))

     logger.debug("saved image params to: %s", path)

View File

@@ -13,6 +13,24 @@ Param = Union[str, int, float]
 Point = Tuple[int, int]


+class Progress:
+    current: int
+    total: int
+
+    def __init__(self, current: int, total: int) -> None:
+        self.current = current
+        self.total = total
+
+    def __str__(self) -> str:
+        return "%s/%s" % (self.current, self.total)
+
+    def tojson(self):
+        return {
+            "current": self.current,
+            "total": self.total,
+        }
+
+
 class SizeChart(IntEnum):
     micro = 64
     mini = 128  # small tile for very expensive models
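
The new Progress type pairs a current step count with an expected total, and serializes into the job status replies shown later in this commit. A quick sketch:

    from onnx_web.params import Progress

    steps = Progress(3, 10)
    print(str(steps))      # 3/10
    print(steps.tojson())  # {'current': 3, 'total': 10}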

View File

@@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
     pool.recycle(recycle_all=True)
     logger.info("restarted worker pool")

-    return jsonify(pool.status())
+    return jsonify(pool.summary())


 def worker_status(server: ServerContext, pool: DevicePoolExecutor):
     if not check_admin(server):
         return make_response(jsonify({})), 401

-    return jsonify(pool.status())
+    return jsonify(pool.summary())


 def get_extra_models(server: ServerContext):
@@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
     app.route("/api/extras", methods=["PUT"])(
         wrap_route(update_extra_models, server)
     ),
-    app.route("/api/restart", methods=["POST"])(
+    app.route("/api/worker/restart", methods=["POST"])(
         wrap_route(restart_workers, server, pool=pool)
     ),
-    app.route("/api/status")(wrap_route(worker_status, server, pool=pool)),
+    app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)),
 ]
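
A sketch of calling the relocated worker admin routes, assuming the same hypothetical host; both routes are gated by check_admin, and the credential mechanism is outside this diff:

    import requests

    host = "http://localhost:5000"  # hypothetical server address

    # pool summary: jobs with their status, plus worker liveness
    print(requests.get(f"{host}/api/worker/status").json())

    # recycle all workers and return the same summary shape
    print(requests.post(f"{host}/api/worker/restart").json())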

View File

@@ -1,14 +1,14 @@
 from io import BytesIO
 from logging import getLogger
 from os import path
-from typing import Any, Dict
+from typing import Any, Dict, List

 from flask import Flask, jsonify, make_response, request, url_for
 from jsonschema import validate
 from PIL import Image

 from ..chain import CHAIN_STAGES, ChainPipeline
-from ..chain.result import StageResult
+from ..chain.result import ImageMetadata, StageResult
 from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
 from ..diffusers.run import (
     run_blend_pipeline,
@@ -18,8 +18,8 @@ from ..diffusers.run import (
     run_upscale_pipeline,
 )
 from ..diffusers.utils import replace_wildcards
-from ..output import json_params, make_output_name
-from ..params import Size, StageParams, TileOrder
+from ..output import make_job_name
+from ..params import Progress, Size, StageParams, TileOrder
 from ..transformers.run import run_txt2txt_pipeline
 from ..utils import (
     base_join,
@@ -34,6 +34,7 @@ from ..utils import (
     load_config_str,
     sanitize_name,
 )
+from ..worker.command import JobType
 from ..worker.pool import DevicePoolExecutor
 from .context import ServerContext
 from .load import (
@@ -92,6 +93,64 @@ def error_reply(err: str):
     return response


+def job_reply(name: str):
+    return jsonify(
+        {
+            "name": name,
+        }
+    )
+
+
+def image_reply(
+    name: str,
+    status: str,
+    job_type: str,
+    stages: Progress = None,
+    steps: Progress = None,
+    tiles: Progress = None,
+    outputs: List[str] = None,
+    metadata: List[ImageMetadata] = None,
+):
+    if stages is None:
+        stages = Progress(0, 0)
+
+    if steps is None:
+        steps = Progress(0, 0)
+
+    if tiles is None:
+        tiles = Progress(0, 0)
+
+    data = {
+        "name": name,
+        "status": status,
+        "type": job_type,
+        "stages": stages.tojson(),
+        "steps": steps.tojson(),
+        "tiles": tiles.tojson(),
+    }
+
+    if metadata is not None and outputs is not None:
+        if len(metadata) != len(outputs):
+            logger.error("metadata and outputs must be the same length")
+            return error_reply("metadata and outputs must be the same length")
+
+    if outputs is not None:
+        data["outputs"] = outputs
+
+    if metadata is not None:
+        data["metadata"] = metadata
+
+    return jsonify(data)
+
+
+def multi_image_reply(results: Dict[str, Any]):
+    # TODO: not that
+    return jsonify(
+        {
+            "results": results,
+        }
+    )
+
+
 def url_from_rule(rule) -> str:
     options = {}
     for arg in rule.arguments:
@@ -197,17 +256,15 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
     )
     output_count += 1

-    output = make_output_name(
-        server, "img2img", params, size, extras=[strength], count=output_count
-    )
-    job_name = output[0]
+    job_name = make_job_name("img2img", params, size, extras=[strength])

     pool.submit(
         job_name,
+        JobType.IMG2IMG,
         run_img2img_pipeline,
         server,
         params,
-        output,
         upscale,
         highres,
         source,
@@ -218,9 +275,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("img2img job queued for: %s", job_name)

-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)


 def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@@ -230,16 +285,15 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):

     replace_wildcards(params, get_wildcard_data())

-    output = make_output_name(server, "txt2img", params, size, count=params.batch)
-    job_name = output[0]
+    job_name = make_job_name("txt2img", params, size)

     pool.submit(
         job_name,
+        JobType.TXT2IMG,
         run_txt2img_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         needs_device=device,
@@ -247,9 +301,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("txt2img job queued for: %s", job_name)

-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)


 def inpaint(server: ServerContext, pool: DevicePoolExecutor):
@@ -295,7 +347,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):

     replace_wildcards(params, get_wildcard_data())

-    output = make_output_name(
-        server,
+    job_name = make_job_name(
         "inpaint",
         params,
@@ -312,14 +364,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
         ],
     )
-    job_name = output[0]

     pool.submit(
         job_name,
+        JobType.INPAINT,
         run_inpaint_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         source,
@@ -336,17 +387,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("inpaint job queued for: %s", job_name)

-    return jsonify(
-        json_params(
-            server,
-            output,
-            params,
-            size,
-            upscale=upscale,
-            border=expand,
-            highres=highres,
-        )
-    )
+    return job_reply(job_name)


 def upscale(server: ServerContext, pool: DevicePoolExecutor):
@@ -362,16 +403,14 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):

     replace_wildcards(params, get_wildcard_data())

-    output = make_output_name(server, "upscale", params, size)
-    job_name = output[0]
+    job_name = make_job_name("upscale", params, size)

     pool.submit(
         job_name,
+        JobType.UPSCALE,
         run_upscale_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         highres,
         source,
@@ -380,9 +419,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("upscale job queued for: %s", job_name)

-    return jsonify(
-        json_params(server, output, params, size, upscale=upscale, highres=highres)
-    )
+    return job_reply(job_name)


 # keys that are specially parsed by params and should not show up in with_args
@@ -478,25 +515,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("running chain pipeline with %s stages", len(pipeline.stages))

-    output = make_output_name(
-        server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
-    )
-    job_name = output[0]
+    job_name = make_job_name("chain", base_params, base_size)

     # build and run chain pipeline
     pool.submit(
         job_name,
+        JobType.CHAIN,
         pipeline,
         server,
         base_params,
         StageResult.empty(),
-        output=output,
         size=base_size,
         needs_device=device,
     )

-    step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
-    return jsonify(json_params(server, output, step_params, base_size))
+    return job_reply(job_name)


 def blend(server: ServerContext, pool: DevicePoolExecutor):
@@ -520,15 +553,14 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
     device, params, size = pipeline_from_request(server)
     upscale = build_upscale()

-    output = make_output_name(server, "upscale", params, size)
-    job_name = output[0]
+    job_name = make_job_name("blend", params, size)

     pool.submit(
         job_name,
+        JobType.BLEND,
         run_blend_pipeline,
         server,
         params,
         size,
-        output,
         upscale,
         # TODO: highres
         sources,
@@ -538,27 +570,26 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):

     logger.info("upscale job queued for: %s", job_name)

-    return jsonify(json_params(server, output, params, size, upscale=upscale))
+    return job_reply(job_name)


 def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
     device, params, size = pipeline_from_request(server)

-    output = make_output_name(server, "txt2txt", params, size)
-    job_name = output[0]
+    job_name = make_job_name("txt2txt", params, size)

     logger.info("upscale job queued for: %s", job_name)

     pool.submit(
         job_name,
+        JobType.TXT2TXT,
         run_txt2txt_pipeline,
         server,
         params,
         size,
-        output,
         needs_device=device,
     )

-    return jsonify(json_params(server, output, params, size))
+    return job_reply(job_name)


 def cancel(server: ServerContext, pool: DevicePoolExecutor):
@@ -601,9 +632,64 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
     )


+def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
+    legacy_job_name = request.args.get("job", None)
+    job_list = [name for name in request.args.get("jobs", "").split(",") if name]
+    if legacy_job_name is not None:
+        job_list.append(legacy_job_name)
+
+    if len(job_list) == 0:
+        return error_reply("at least one job name is required")
+
+    results = {}
+    for job_name in job_list:
+        job_name = sanitize_name(job_name)
+        cancelled = pool.cancel(job_name)
+        results[job_name] = cancelled
+
+    return multi_image_reply(results)
+
+
+def job_status(server: ServerContext, pool: DevicePoolExecutor):
+    legacy_job_name = request.args.get("job", None)
+    job_list = [name for name in request.args.get("jobs", "").split(",") if name]
+    if legacy_job_name is not None:
+        job_list.append(legacy_job_name)
+
+    if len(job_list) == 0:
+        return error_reply("at least one job name is required")
+
+    for job_name in job_list:
+        job_name = sanitize_name(job_name)
+        status, progress = pool.status(job_name)
+
+        # TODO: accumulate results
+        if progress is not None:
+            # TODO: add output paths based on progress.results counter
+            return image_reply(
+                job_name,
+                status,
+                "TODO",
+                stages=Progress(progress.stages, 0),
+                steps=Progress(progress.steps, 0),
+                tiles=Progress(progress.tiles, 0),
+            )
+
+        return image_reply(job_name, status, "TODO")
+
+
 def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
     return [
         app.route("/api")(wrap_route(introspect, server, app=app)),
+        # job routes
+        app.route("/api/job", methods=["POST"])(wrap_route(chain, server, pool=pool)),
+        app.route("/api/job/cancel", methods=["PUT"])(
+            wrap_route(job_cancel, server, pool=pool)
+        ),
+        app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
+        # settings routes
         app.route("/api/settings/filters")(wrap_route(list_filters, server)),
         app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
         app.route("/api/settings/models")(wrap_route(list_models, server)),
@@ -614,6 +700,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
         app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
         app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
         app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
+        # legacy job routes
         app.route("/api/img2img", methods=["POST"])(
             wrap_route(img2img, server, pool=pool)
         ),
@@ -631,6 +718,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
         ),
         app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
         app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
+        # deprecated routes
         app.route("/api/cancel", methods=["PUT"])(
             wrap_route(cancel, server, pool=pool)
         ),

View File

@@ -12,7 +12,6 @@ def run_txt2txt_pipeline(
     _server: ServerContext,
     params: ImageParams,
     _size: Size,
-    output: str,
 ) -> None:
     from transformers import AutoTokenizer, GPTJForCausalLM

@@ -38,4 +37,4 @@ def run_txt2txt_pipeline(

     print("Server says: %s" % result_text)

-    logger.info("finished txt2txt job: %s", output)
+    logger.info("finished txt2txt job: %s", worker.job)

View File

@@ -2,16 +2,18 @@ import gc
 import importlib
 import json
 import threading
+from hashlib import sha256
 from json import JSONDecodeError
 from logging import getLogger
 from os import environ, path
 from platform import system
+from struct import pack
 from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union

 import torch
 from yaml import safe_load

-from .params import DeviceParams, SizeChart
+from .params import DeviceParams, Param, SizeChart

 logger = getLogger(__name__)

@@ -218,3 +220,34 @@ def load_config_str(raw: str) -> Dict:
         return json.loads(raw)
     except JSONDecodeError:
         return safe_load(raw)
+
+
+HASH_BUFFER_SIZE = 2**22  # 4MB
+
+
+def hash_file(name: str):
+    sha = sha256()
+    with open(name, "rb") as f:
+        while True:
+            data = f.read(HASH_BUFFER_SIZE)
+            if not data:
+                break
+
+            sha.update(data)
+
+    return sha.hexdigest()
+
+
+def hash_value(sha, param: Optional[Param]):
+    if param is None:
+        return
+    elif isinstance(param, bool):
+        sha.update(bytearray(pack("!B", param)))
+    elif isinstance(param, float):
+        sha.update(bytearray(pack("!f", param)))
+    elif isinstance(param, int):
+        sha.update(bytearray(pack("!I", param)))
+    elif isinstance(param, str):
+        sha.update(param.encode("utf-8"))
+    else:
+        logger.warning("cannot hash param: %s, %s", param, type(param))
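
hash_file and hash_value now live in utils, so output naming and image metadata can share them. A short sketch of both, mirroring how make_job_name folds typed params into a digest; the file path is hypothetical:

    from hashlib import sha256
    from onnx_web.utils import hash_file, hash_value

    # hash a tensor file on disk, reading in 4MB chunks
    digest = hash_file("models/lora/detail-tweaker.safetensors")  # hypothetical path

    # fold mixed-type params into one digest, as make_job_name does
    sha = sha256()
    for param in ("txt2img", 42, 7.5, True):
        hash_value(sha, param)
    print(digest, sha.hexdigest())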

View File

@@ -1,34 +1,61 @@
+from enum import Enum
 from typing import Any, Callable, Dict


+class JobStatus(str, Enum):
+    PENDING = "pending"
+    RUNNING = "running"
+    SUCCESS = "success"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+    UNKNOWN = "unknown"
+
+
+class JobType(str, Enum):
+    TXT2TXT = "txt2txt"
+    TXT2IMG = "txt2img"
+    IMG2IMG = "img2img"
+    INPAINT = "inpaint"
+    UPSCALE = "upscale"
+    BLEND = "blend"
+    CHAIN = "chain"
+
+
 class ProgressCommand:
     device: str
     job: str
-    finished: bool
-    progress: int
-    cancelled: bool
-    failed: bool
+    job_type: str
+    status: JobStatus
+    results: int
+    steps: int
+    stages: int
+    tiles: int

     def __init__(
         self,
         job: str,
+        job_type: str,
         device: str,
-        finished: bool,
-        progress: int,
-        cancelled: bool = False,
-        failed: bool = False,
+        status: JobStatus,
+        results: int = 0,
+        steps: int = 0,
+        stages: int = 0,
+        tiles: int = 0,
     ):
         self.job = job
+        self.job_type = job_type
         self.device = device
-        self.finished = finished
-        self.progress = progress
-        self.cancelled = cancelled
-        self.failed = failed
+        self.status = status
+        self.results = results
+        self.steps = steps
+        self.stages = stages
+        self.tiles = tiles


 class JobCommand:
     device: str
     name: str
+    job_type: str
     fn: Callable[..., None]
     args: Any
     kwargs: Dict[str, Any]
@@ -37,12 +64,14 @@ class JobCommand:
         self,
         name: str,
         device: str,
+        job_type: str,
         fn: Callable[..., None],
         args: Any,
         kwargs: Dict[str, Any],
     ):
         self.device = device
         self.name = name
+        self.job_type = job_type
         self.fn = fn
         self.args = args
         self.kwargs = kwargs
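
A sketch of the richer progress model: a ProgressCommand now carries the job type and a single JobStatus instead of separate finished/cancelled/failed flags (the values here are hypothetical):

    from onnx_web.worker.command import JobStatus, JobType, ProgressCommand

    # a txt2img job on cuda:0 that has completed 10 steps and is still running
    progress = ProgressCommand(
        "job-123",
        JobType.TXT2IMG,
        "cuda:0",
        JobStatus.RUNNING,
        steps=10,
    )
    print(progress.status.value)  # running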

View File

@@ -2,21 +2,23 @@ from logging import getLogger
 from os import getpid
 from typing import Any, Callable, Optional

+import numpy as np
 from torch.multiprocessing import Queue, Value

 from ..errors import CancelledException
 from ..params import DeviceParams
-from .command import JobCommand, ProgressCommand
+from .command import JobCommand, JobStatus, ProgressCommand

 logger = getLogger(__name__)

-ProgressCallback = Callable[[int, int, Any], None]
+ProgressCallback = Callable[[int, int, np.ndarray], None]


 class WorkerContext:
     cancel: "Value[bool]"
     job: Optional[str]
+    job_type: Optional[str]
     name: str
     pending: "Queue[JobCommand]"
     active_pid: "Value[int]"
@@ -41,6 +43,7 @@ class WorkerContext:
         timeout: float,
     ):
         self.job = None
+        self.job_type = None
         self.name = name
         self.device = device
         self.cancel = cancel
@@ -54,9 +57,15 @@ class WorkerContext:
         self.retries = retries
         self.timeout = timeout

-    def start(self, job: str) -> None:
-        self.job = job
+    def start(self, job: JobCommand) -> None:
+        # set job name and type
+        self.job = job.name
+        self.job_type = job.job_type
+
+        # reset retries
         self.retries = self.initial_retries
+
+        # clear flags
         self.set_cancel(cancel=False)
         self.set_idle(idle=False)

@@ -81,7 +90,7 @@ class WorkerContext:
     def get_progress(self) -> int:
         if self.last_progress is not None:
-            return self.last_progress.progress
+            return self.last_progress.steps

         return 0

@@ -112,13 +121,11 @@ class WorkerContext:
         logger.debug("setting progress for job %s to %s", self.job, progress)
         self.last_progress = ProgressCommand(
             self.job,
+            self.job_type,
             self.device.device,
-            False,
-            progress,
-            self.is_cancelled(),
-            False,
+            JobStatus.RUNNING,
+            steps=progress,
         )
         self.progress.put(
             self.last_progress,
             block=False,
@@ -131,11 +138,10 @@ class WorkerContext:
         logger.debug("setting finished for job %s", self.job)
         self.last_progress = ProgressCommand(
             self.job,
+            self.job_type,
             self.device.device,
-            True,
-            self.get_progress(),
-            self.is_cancelled(),
-            False,
+            JobStatus.SUCCESS,  # TODO: FAILED
+            steps=self.get_progress(),
         )
         self.progress.put(
             self.last_progress,
@@ -150,11 +156,10 @@ class WorkerContext:
         try:
             self.last_progress = ProgressCommand(
                 self.job,
+                self.job_type,
                 self.device.device,
-                True,
-                self.get_progress(),
-                self.is_cancelled(),
-                True,
+                JobStatus.FAILED,
+                steps=self.get_progress(),
             )
             self.progress.put(
                 self.last_progress,
@@ -162,25 +167,3 @@ class WorkerContext:
             )
         except Exception:
             logger.exception("error setting failure on job %s", self.job)
-
-
-class JobStatus:
-    name: str
-    device: str
-    progress: int
-    cancelled: bool
-    finished: bool
-
-    def __init__(
-        self,
-        name: str,
-        device: DeviceParams,
-        progress: int = 0,
-        cancelled: bool = False,
-        finished: bool = False,
-    ) -> None:
-        self.name = name
-        self.device = device.device
-        self.progress = progress
-        self.cancelled = cancelled
-        self.finished = finished

View File

@@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value

 from ..params import DeviceParams
 from ..server import ServerContext
-from .command import JobCommand, ProgressCommand
+from .command import JobCommand, JobStatus, ProgressCommand
 from .context import WorkerContext
 from .utils import Interval
 from .worker import worker_main
@@ -201,6 +201,10 @@ class DevicePoolExecutor:
         should be cancelled on the next progress callback.
         """

+        if key in self.cancelled_jobs:
+            logger.debug("cancelling already cancelled job: %s", key)
+            return True
+
         for job in self.finished_jobs:
             if job.job == key:
                 logger.debug("cannot cancel finished job: %s", key)
@@ -209,6 +213,9 @@ class DevicePoolExecutor:
         for job in self.pending_jobs:
             if job.name == key:
                 self.pending_jobs.remove(job)
+                self.cancelled_jobs.append(
+                    key
+                )  # ensure workers never pick up this job and the status endpoint knows about it later
                 logger.info("cancelled pending job: %s", key)
                 return True

@@ -221,28 +228,31 @@ class DevicePoolExecutor:
         self.cancelled_jobs.append(key)
         return True

-    def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
+    def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
         """
         Check if a job has been finished and report the last progress update.
-
-        If the job is still pending, the first item will be True and there will be no ProgressCommand.
         """

+        if key in self.cancelled_jobs:
+            logger.debug("checking status for cancelled job: %s", key)
+            return (JobStatus.CANCELLED, None)
+
         if key in self.running_jobs:
             logger.debug("checking status for running job: %s", key)
-            return (False, self.running_jobs[key])
+            return (JobStatus.RUNNING, self.running_jobs[key])

         for job in self.finished_jobs:
             if job.job == key:
                 logger.debug("checking status for finished job: %s", key)
-                return (False, job)
+                return (job.status, job)

         for job in self.pending_jobs:
             if job.name == key:
                 logger.debug("checking status for pending job: %s", key)
-                return (True, None)
+                return (JobStatus.PENDING, None)

         logger.trace("checking status for unknown job: %s", key)
-        return (False, None)
+        return (JobStatus.UNKNOWN, None)

     def join(self):
         logger.info("stopping worker pool")
@@ -383,6 +393,7 @@ class DevicePoolExecutor:
     def submit(
         self,
         key: str,
+        job_type: str,
         fn: Callable[..., None],
         /,
         *args,
@@ -399,56 +410,63 @@ class DevicePoolExecutor:
         )

         # build and queue job
-        job = JobCommand(key, device, fn, args, kwargs)
+        job = JobCommand(key, device, job_type, fn, args, kwargs)
         self.pending_jobs.append(job)

-    def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]:
+    def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
         """
-        Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
+        Returns a tuple for each job (name, progress, status) and each worker (device, total jobs, alive).
         """
-        return {
-            "cancelled": [],
-            "finished": [
+
+        jobs: List[Tuple[str, int, JobStatus]] = []
+        jobs.extend(
+            [
                 (
-                    job.job,
-                    job.progress,
-                    False,
-                    job.finished,
-                    job.cancelled,
-                    job.failed,
+                    job,
+                    0,
+                    JobStatus.CANCELLED,
                 )
-                for job in self.finished_jobs
-            ],
-            "pending": [
+                for job in self.cancelled_jobs
+            ]
+        )
+        jobs.extend(
+            [
                 (
                     job.name,
                     0,
-                    True,
-                    False,
-                    False,
-                    False,
+                    JobStatus.PENDING,
                 )
                 for job in self.pending_jobs
-            ],
-            "running": [
+            ]
+        )
+        jobs.extend(
+            [
                 (
                     name,
-                    job.progress,
-                    False,
-                    job.finished,
-                    job.cancelled,
-                    job.failed,
+                    job.steps,
+                    job.status,
                 )
                 for name, job in self.running_jobs.items()
-            ],
-            "total": [
+            ]
+        )
+        jobs.extend(
+            [
+                (
+                    job.job,
+                    job.steps,
+                    job.status,
+                )
+                for job in self.finished_jobs
+            ]
+        )
+
+        return {
+            "jobs": jobs,
+            "workers": [
                 (
                     device,
                     total,
                     self.workers[device].is_alive(),
-                    False,
-                    False,
-                    False,
                 )
                 for device, total in self.total_jobs.items()
             ],
@@ -476,20 +494,18 @@ class DevicePoolExecutor:
             self.cancelled_jobs.remove(progress.job)

     def update_job(self, progress: ProgressCommand):
-        if progress.finished:
+        if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]:
             return self.finish_job(progress)

         # move from pending to running
-        logger.debug(
-            "progress update for job: %s to %s", progress.job, progress.progress
-        )
+        logger.debug("progress update for job: %s to %s", progress.job, progress.steps)
         self.running_jobs[progress.job] = progress
         self.pending_jobs[:] = [
             job for job in self.pending_jobs if job.name != progress.job
         ]

         # increment job counter if this is the start of a new job
-        if progress.progress == 0:
+        if progress.steps == 0:
             if progress.device in self.total_jobs:
                 self.total_jobs[progress.device] += 1
             else:
View File

@@ -57,7 +57,7 @@ def worker_main(
             logger.info("worker %s got job: %s", worker.device.device, job.name)

             # clear flags and save the job name
-            worker.start(job.name)
+            worker.start(job)
             logger.info("starting job: %s", job.name)

             # reset progress, which does a final check for cancellation

View File

@ -1,12 +1,6 @@
from diffusers import OnnxStableDiffusionPipeline from diffusers import OnnxStableDiffusionPipeline
from os import path from os import path
import cv2
import numpy as np
import onnxruntime as ort
import torch
import time
cfg = 8 cfg = 8
steps = 22 steps = 22
height = 512 height = 512

View File

@ -22,6 +22,7 @@ from onnx_web.params import (
UpscaleParams, UpscaleParams,
) )
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobCommand
from onnx_web.worker.context import WorkerContext from onnx_web.worker.context import WorkerContext
from tests.helpers import ( from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15, TEST_MODEL_DIFFUSION_SD15,
@ -57,7 +58,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline( run_txt2img_pipeline(
worker, worker,
@ -72,7 +73,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1, 1,
), ),
Size(256, 256), Size(256, 256),
["test-txt2img-basic.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
) )
@ -103,7 +103,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline( run_txt2img_pipeline(
worker, worker,
@ -119,7 +119,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
batch=2, batch=2,
), ),
Size(256, 256), Size(256, 256),
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
) )
@ -152,7 +151,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline( run_txt2img_pipeline(
worker, worker,
@ -168,7 +167,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
unet_tile=256, unet_tile=256,
), ),
Size(256, 256), Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test", scale=2, outscale=2), UpscaleParams("test", scale=2, outscale=2),
HighresParams(True, 2, 0, 0), HighresParams(True, 2, 0, 0),
) )
@ -198,7 +196,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline( run_txt2img_pipeline(
worker, worker,
@ -214,7 +212,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
batch=2, batch=2,
), ),
Size(256, 256), Size(256, 256),
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(True, 2, 0, 0), HighresParams(True, 2, 0, 0),
) )
@ -230,7 +227,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self): def test_basic(self):
worker = test_worker() worker = test_worker()
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline( run_img2img_pipeline(
@ -245,7 +242,6 @@ class TestImg2ImgPipeline(unittest.TestCase):
1, 1,
1, 1,
), ),
["test-img2img.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
source, source,
@ -259,7 +255,7 @@ class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_white(self): def test_basic_white(self):
worker = test_worker() worker = test_worker()
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "white") mask = Image.new("RGB", (64, 64), "white")
@ -277,7 +273,6 @@ class TestInpaintPipeline(unittest.TestCase):
unet_tile=64, unet_tile=64,
), ),
Size(*source.size), Size(*source.size),
["test-inpaint-white.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
source, source,
@ -296,7 +291,7 @@ class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT]) @test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_black(self): def test_basic_black(self):
worker = test_worker() worker = test_worker()
worker.start("test") worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "black") mask = Image.new("RGB", (64, 64), "black")
@ -314,7 +309,6 @@ class TestInpaintPipeline(unittest.TestCase):
unet_tile=64, unet_tile=64,
), ),
Size(*source.size), Size(*source.size),
["test-inpaint-black.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
source, source,
@ -353,7 +347,7 @@ class TestUpscalePipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black") source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline( run_upscale_pipeline(
@ -369,7 +363,6 @@ class TestUpscalePipeline(unittest.TestCase):
1, 1,
), ),
Size(256, 256), Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"), UpscaleParams("test"),
HighresParams(False, 1, 0, 0), HighresParams(False, 1, 0, 0),
source, source,
@ -399,7 +392,7 @@ class TestBlendPipeline(unittest.TestCase):
3, 3,
0.1, 0.1,
) )
worker.start("test") worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {}))
source = Image.new("RGBA", (64, 64), "black") source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white") mask = Image.new("RGBA", (64, 64), "white")
@ -417,7 +410,6 @@ class TestBlendPipeline(unittest.TestCase):
unet_tile=64, unet_tile=64,
), ),
Size(64, 64), Size(64, 64),
["test-blend.png"],
UpscaleParams("test"), UpscaleParams("test"),
[source, source], [source, source],
mask, mask,

View File

@ -5,6 +5,7 @@ from typing import Optional
from onnx_web.params import DeviceParams from onnx_web.params import DeviceParams
from onnx_web.server.context import ServerContext from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobStatus
from onnx_web.worker.pool import DevicePoolExecutor from onnx_web.worker.pool import DevicePoolExecutor
TEST_JOIN_TIMEOUT = 0.2 TEST_JOIN_TIMEOUT = 0.2
@ -50,11 +51,11 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start() self.pool.start()
self.pool.submit("test", wait_job, lock=lock) self.pool.submit("test", "test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None)) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
self.assertTrue(self.pool.cancel("test")) self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None)) self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
def test_cancel_running(self): def test_cancel_running(self):
pass pass
@ -88,12 +89,14 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor( self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
) )
lock.clear()
self.pool.start(lock) self.pool.start(lock)
self.pool.submit("test", test_job) self.pool.submit("test", "test", test_job)
sleep(5.0) sleep(5.0)
pending, _progress = self.pool.done("test") status, _progress = self.pool.status("test")
self.assertFalse(pending) self.assertEqual(status, JobStatus.RUNNING)
def test_done_pending(self): def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider") device = DeviceParams("cpu", "CPUProvider")
@ -102,9 +105,9 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT) self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock) self.pool.start(lock)
self.pool.submit("test1", test_job) self.pool.submit("test1", "test", test_job)
self.pool.submit("test2", test_job) self.pool.submit("test2", "test", test_job)
self.assertTrue(self.pool.done("test2"), (True, None)) self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
lock.set() lock.set()
@ -119,12 +122,12 @@ class TestWorkerPool(unittest.TestCase):
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1 server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
) )
self.pool.start() self.pool.start()
self.pool.submit("test", wait_job) self.pool.submit("test", "test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None)) self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
sleep(5.0) sleep(5.0)
pending, _progress = self.pool.done("test") status, _progress = self.pool.status("test")
self.assertFalse(pending) self.assertEqual(status, JobStatus.SUCCESS)
def test_recycle_live(self): def test_recycle_live(self):
pass pass

View File

@ -40,7 +40,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_interrupt, [], {}) job = JobCommand("test", "test", "test", main_interrupt, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
@ -75,7 +75,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_retry, [], {}) job = JobCommand("test", "test", "test", main_retry, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()
@ -144,7 +144,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status nonlocal status
status = exit_status status = exit_status
job = JobCommand("test", "test", main_memory, [], {}) job = JobCommand("test", "test", "test", main_memory, [], {})
cancel = Value("L", False) cancel = Value("L", False)
logs = Queue() logs = Queue()
pending = Queue() pending = Queue()

View File

@ -3,7 +3,7 @@
onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you
can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and
Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and
servers with a seamless user experience. multi-GPU servers with a seamless user experience.
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers, You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers,
including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each
@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features:
- includes both the API and GUI bundle in a single container - includes both the API and GUI bundle in a single container
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services - runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
## Contents
- [onnx-web](#onnx-web)
- [Features](#features)
- [Contents](#contents)
- [Setup](#setup)
- [Adding your own models](#adding-your-own-models)
- [Usage](#usage)
- [Known errors and solutions](#known-errors-and-solutions)
- [Running the containers](#running-the-containers)
- [Credits](#credits)
## Setup ## Setup
There are a few ways to run onnx-web: There are a few ways to run onnx-web:

View File

@ -4,10 +4,7 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
import { ServerParams } from '../config.js'; import { ServerParams } from '../config.js';
import { import {
FilterResponse, FilterResponse,
ImageResponse,
ImageResponseWithRetry,
ModelResponse, ModelResponse,
ReadyResponse,
RetryParams, RetryParams,
WriteExtrasResponse, WriteExtrasResponse,
} from '../types/api.js'; } from '../types/api.js';
@ -27,6 +24,7 @@ import {
} from '../types/params.js'; } from '../types/params.js';
import { range } from '../utils.js'; import { range } from '../utils.js';
import { ApiClient } from './base.js'; import { ApiClient } from './base.js';
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
/** /**
* Fixed precision for integer parameters. * Fixed precision for integer parameters.
@ -43,8 +41,9 @@ export const FIXED_INTEGER = 0;
export const FIXED_FLOAT = 2; export const FIXED_FLOAT = 2;
export const STATUS_SUCCESS = 200; export const STATUS_SUCCESS = 200;
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean { export function equalResponse(a: JobResponse, b: JobResponse): boolean {
return a.outputs === b.outputs; return a.name === b.name && a.status === b.status && a.type === b.type;
// return a.outputs === b.outputs;
} }
/** /**
@ -141,8 +140,8 @@ export function appendHighresToURL(url: URL, highres: HighresParams) {
* Make an API client using the given API root and fetch client. * Make an API client using the given API root and fetch client.
*/ */
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient { export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> { function parseRequest(url: URL, options: RequestInit): Promise<JobResponse> {
return f(url, options).then((res) => parseApiResponse(root, res)); return f(url, options).then((res) => parseJobResponse(root, res));
} }
return { return {
@ -218,7 +217,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const res = await f(path); const res = await f(path);
return await res.json() as Array<string>; return await res.json() as Array<string>;
}, },
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> { async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params); const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model); appendModelToURL(url, model);
@ -240,12 +239,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const body = new FormData(); const body = new FormData();
body.append('source', params.source, 'source'); body.append('source', params.source, 'source');
const image = await parseRequest(url, { const job = await parseRequest(url, {
body, body,
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'img2img', type: 'img2img',
model, model,
@ -254,7 +253,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}, },
}; };
}, },
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> { async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params); const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model); appendModelToURL(url, model);
@ -274,11 +273,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
appendHighresToURL(url, highres); appendHighresToURL(url, highres);
} }
const image = await parseRequest(url, { const job = await parseRequest(url, {
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'txt2img', type: 'txt2img',
model, model,
@ -288,7 +287,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}, },
}; };
}, },
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> { async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params); const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model); appendModelToURL(url, model);
@ -309,12 +308,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append('mask', params.mask, 'mask'); body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source'); body.append('source', params.source, 'source');
const image = await parseRequest(url, { const job = await parseRequest(url, {
body, body,
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'inpaint', type: 'inpaint',
model, model,
@ -323,7 +322,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}, },
}; };
}, },
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> { async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params); const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model); appendModelToURL(url, model);
@ -361,12 +360,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append('mask', params.mask, 'mask'); body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source'); body.append('source', params.source, 'source');
const image = await parseRequest(url, { const job = await parseRequest(url, {
body, body,
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'outpaint', type: 'outpaint',
model, model,
@ -375,7 +374,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}, },
}; };
}, },
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> { async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeApiUrl(root, 'upscale'); const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model); appendModelToURL(url, model);
@ -396,12 +395,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const body = new FormData(); const body = new FormData();
body.append('source', params.source, 'source'); body.append('source', params.source, 'source');
const image = await parseRequest(url, { const job = await parseRequest(url, {
body, body,
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'upscale', type: 'upscale',
model, model,
@ -410,7 +409,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}, },
}; };
}, },
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> { async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry> {
const url = makeApiUrl(root, 'blend'); const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model); appendModelToURL(url, model);
@ -426,12 +425,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append(name, params.sources[i], name); body.append(name, params.sources[i], name);
} }
const image = await parseRequest(url, { const job = await parseRequest(url, {
body, body,
method: 'POST', method: 'POST',
}); });
return { return {
image, job,
retry: { retry: {
type: 'blend', type: 'blend',
model, model,
@ -440,8 +439,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
} }
}; };
}, },
async chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse> { async chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse> {
const url = makeApiUrl(root, 'chain'); const url = makeApiUrl(root, 'job');
const body = JSON.stringify({ const body = JSON.stringify({
...chain, ...chain,
platform: model.platform, platform: model.platform,
@ -456,23 +455,23 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
method: 'POST', method: 'POST',
}); });
}, },
async ready(key: string): Promise<ReadyResponse> { async status(keys: Array<string>): Promise<Array<JobResponse>> {
const path = makeApiUrl(root, 'ready'); const path = makeApiUrl(root, 'job', 'status');
path.searchParams.append('output', key); path.searchParams.append('jobs', keys.join(','));
const res = await f(path); const res = await f(path);
return await res.json() as ReadyResponse; return await res.json() as Array<JobResponse>;
}, },
async cancel(key: string): Promise<boolean> { async cancel(keys: Array<string>): Promise<Array<JobResponse>> {
const path = makeApiUrl(root, 'cancel'); const path = makeApiUrl(root, 'job', 'cancel');
path.searchParams.append('output', key); path.searchParams.append('jobs', keys.join(','));
const res = await f(path, { const res = await f(path, {
method: 'PUT', method: 'PUT',
}); });
return res.status === STATUS_SUCCESS; return await res.json() as Array<JobResponse>;
}, },
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> { async retry(retry: RetryParams): Promise<JobResponseWithRetry> {
switch (retry.type) { switch (retry.type) {
case 'blend': case 'blend':
return this.blend(retry.model, retry.params, retry.upscale); return this.blend(retry.model, retry.params, retry.upscale);
@ -491,7 +490,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
} }
}, },
async restart(): Promise<boolean> { async restart(): Promise<boolean> {
const path = makeApiUrl(root, 'restart'); const path = makeApiUrl(root, 'worker', 'restart');
if (doesExist(token)) { if (doesExist(token)) {
path.searchParams.append('token', token); path.searchParams.append('token', token);
@ -502,8 +501,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}); });
return res.status === STATUS_SUCCESS; return res.status === STATUS_SUCCESS;
}, },
async status(): Promise<Array<unknown>> { async workers(): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'status'); const path = makeApiUrl(root, 'worker', 'status');
if (doesExist(token)) { if (doesExist(token)) {
path.searchParams.append('token', token); path.searchParams.append('token', token);
@ -512,6 +511,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const res = await f(path); const res = await f(path);
return res.json(); return res.json();
}, },
outputURL(image: SuccessJobResponse, index: number): string {
return new URL(joinPath('output', image.outputs[index]), root).toString();
},
}; };
} }
@ -521,24 +523,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
* The server sends over the output key, and the client is in the best position to turn * The server sends over the output key, and the client is in the best position to turn
* that into a full URL, since it already knows the root URL of the server. * that into a full URL, since it already knows the root URL of the server.
*/ */
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> { export async function parseJobResponse(root: string, res: Response): Promise<JobResponse> {
type LimitedResponse = Omit<ImageResponse, 'outputs'> & { outputs: Array<string> };
if (res.status === STATUS_SUCCESS) { if (res.status === STATUS_SUCCESS) {
const data = await res.json() as LimitedResponse; return await res.json() as JobResponse;
const outputs = data.outputs.map((output) => {
const url = new URL(joinPath('output', output), root).toString();
return {
key: output,
url,
};
});
return {
...data,
outputs,
};
} else { } else {
throw new Error('request error'); throw new Error('request error');
} }
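
Putting the new batch methods together, a caller can poll status() until a job settles, then resolve its first output with outputURL(). A sketch only; the module paths, server root, and poll interval are assumptions:

import { makeClient } from './client/api.js'; // module path assumed
import { JobStatus } from './types/api-v2.js';

const client = makeClient('http://localhost:5000/api'); // assumed server root

export async function waitForOutput(name: string): Promise<string> {
  for (;;) {
    const [job] = await client.status([name]);
    if (job.status === JobStatus.SUCCESS) {
      // narrowed to SuccessJobResponse, so outputURL accepts it directly
      return client.outputURL(job, 0);
    }
    if (job.status === JobStatus.FAILED) {
      throw new Error(`job ${name} failed: ${job.error}`);
    }
    // PENDING or RUNNING: poll again after a short delay
    await new Promise((resolve) => setTimeout(resolve, 1000));
  }
}
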

View File

@ -1,12 +1,19 @@
import { ServerParams } from '../config.js'; import { ServerParams } from '../config.js';
import { ExtrasFile } from '../types/model.js'; import { ExtrasFile } from '../types/model.js';
import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js'; import { ChainPipeline } from '../types/chain.js';
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js'; import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
export interface ApiClient { export interface ApiClient {
/**
* Get the first extras file.
*/
extras(): Promise<ExtrasFile>; extras(): Promise<ExtrasFile>;
/**
* Update the first extras file.
*/
writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>; writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;
/** /**
@ -51,54 +58,60 @@ export interface ApiClient {
translation: Record<string, string>; translation: Record<string, string>;
}>>; }>>;
/**
* Get the available wildcards.
*/
wildcards(): Promise<Array<string>>; wildcards(): Promise<Array<string>>;
/** /**
* Start a txt2img pipeline. * Start a txt2img pipeline.
*/ */
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>; txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/** /**
* Start an img2img pipeline. * Start an img2img pipeline.
*/ */
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>; img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/** /**
* Start an inpaint pipeline. * Start an inpaint pipeline.
*/ */
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>; inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/** /**
* Start an outpaint pipeline. * Start an outpaint pipeline.
*/ */
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>; outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/** /**
* Start an upscale pipeline. * Start an upscale pipeline.
*/ */
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>; upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/** /**
* Start a blending pipeline. * Start a blending pipeline.
*/ */
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>; blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry>;
chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse>; /**
* Start a custom chain pipeline.
*/
chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse>;
/** /**
* Check whether job has finished and its output is ready. * Check whether one or more jobs have finished and their outputs are ready.
*/ */
ready(key: string): Promise<ReadyResponse>; status(keys: Array<string>): Promise<Array<JobResponse>>;
/** /**
* Cancel an existing job. * Cancel one or more existing jobs.
*/ */
cancel(key: string): Promise<boolean>; cancel(keys: Array<string>): Promise<Array<JobResponse>>;
/** /**
* Retry a previous job using the same parameters. * Retry a previous job using the same parameters.
*/ */
retry(params: RetryParams): Promise<ImageResponseWithRetry>; retry(params: RetryParams): Promise<JobResponseWithRetry>;
/** /**
* Restart the image job workers. * Restart the image job workers.
@ -108,5 +121,7 @@ export interface ApiClient {
/** /**
* Check the status of the image job workers. * Check the status of the image job workers.
*/ */
status(): Promise<Array<unknown>>; workers(): Promise<Array<unknown>>;
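/**
* Get the full URL for one of the outputs of a successful job.
*/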
outputURL(image: SuccessJobResponse, index: number): string;
} }
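
For example, cancel() now takes a list of job names and resolves to the affected jobs rather than a bare boolean; a short sketch, given an existing ApiClient instance (the job names and import path are hypothetical):

import { ApiClient } from './client/base.js'; // path assumed

async function cancelAll(client: ApiClient, names: Array<string>): Promise<void> {
  const jobs = await client.cancel(names);
  for (const job of jobs) {
    // each response reports the job's final status
    console.log(`${job.name}: ${job.status}`);
  }
}

// usage: await cancelAll(client, ['job-1', 'job-2']);
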

View File

@ -48,7 +48,7 @@ export const LOCAL_CLIENT = {
async params() { async params() {
throw new NoServerError(); throw new NoServerError();
}, },
async ready(key) { async status(key) {
throw new NoServerError(); throw new NoServerError();
}, },
async cancel(key) { async cancel(key) {
@ -78,7 +78,10 @@ export const LOCAL_CLIENT = {
async restart() { async restart() {
throw new NoServerError(); throw new NoServerError();
}, },
async status() { async workers() {
throw new NoServerError(); throw new NoServerError();
} },
outputURL(image, index) {
throw new NoServerError();
},
} as ApiClient; } as ApiClient;

View File

@ -97,11 +97,19 @@ export function expandRanges(range: string): Array<string | number> {
export const GRID_TILE_SIZE = 8192; export const GRID_TILE_SIZE = 8192;
// eslint-disable-next-line max-params // eslint-disable-next-line max-params
export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { export function makeTxt2ImgGridPipeline(
grid: PipelineGrid,
model: ModelParams,
params: Txt2ImgParams,
upscale?: UpscaleParams,
highres?: HighresParams,
): ChainPipeline {
const pipeline: ChainPipeline = { const pipeline: ChainPipeline = {
defaults: { defaults: {
...model, ...model,
...params, ...params,
...(upscale || {}),
...(highres || {}),
}, },
stages: [], stages: [],
}; };

View File

@ -10,6 +10,7 @@ import { OnnxState, StateContext } from '../state/full.js';
import { ErrorCard } from './card/ErrorCard.js'; import { ErrorCard } from './card/ErrorCard.js';
import { ImageCard } from './card/ImageCard.js'; import { ImageCard } from './card/ImageCard.js';
import { LoadingCard } from './card/LoadingCard.js'; import { LoadingCard } from './card/LoadingCard.js';
import { JobStatus } from '../types/api-v2.js';
export function ImageHistory() { export function ImageHistory() {
const store = mustExist(useContext(StateContext)); const store = mustExist(useContext(StateContext));
@ -25,19 +26,19 @@ export function ImageHistory() {
const limited = history.slice(0, limit); const limited = history.slice(0, limit);
for (const item of limited) { for (const item of limited) {
const key = item.image.outputs[0].key; const key = item.image.name;
if (doesExist(item.ready) && item.ready.ready) {
if (item.ready.cancelled || item.ready.failed) {
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
continue;
}
switch (item.image.status) {
case JobStatus.SUCCESS:
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]); children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
continue; break;
case JobStatus.FAILED:
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} retry={item.retry} />]);
break;
default:
children.push([key, <LoadingCard key={`history-${key}`} image={item.image} />]);
break;
} }
children.push([key, <LoadingCard key={`history-${key}`} index={0} image={item.image} />]);
} }
return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>; return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;

View File

@ -10,16 +10,15 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; import { FailedJobResponse, RetryParams } from '../../types/api-v2.js';
export interface ErrorCardProps { export interface ErrorCardProps {
image: ImageResponse; image: FailedJobResponse;
ready: ReadyResponse;
retry: Maybe<RetryParams>; retry: Maybe<RetryParams>;
} }
export function ErrorCard(props: ErrorCardProps) { export function ErrorCard(props: ErrorCardProps) {
const { image, ready, retry: retryParams } = props; const { image, retry: retryParams } = props;
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext)); const { params } = mustExist(useContext(ConfigContext));
@ -32,8 +31,8 @@ export function ErrorCard(props: ErrorCardProps) {
removeHistory(image); removeHistory(image);
if (doesExist(retryParams)) { if (doesExist(retryParams)) {
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); const { job: nextJob, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry); pushHistory(nextJob, nextRetry);
} }
} }
@ -52,10 +51,11 @@ export function ErrorCard(props: ErrorCardProps) {
spacing={2} spacing={2}
sx={{ alignItems: 'center' }} sx={{ alignItems: 'center' }}
> >
<Alert severity='error'>{t('loading.progress', { <Alert severity='error'>
current: ready.progress, {t('loading.progress', image)}
total: image.params.steps, <br />
})}</Alert> {image.error}
</Alert>
<Stack direction='row' spacing={2}> <Stack direction='row' spacing={2}>
<Tooltip title={t('tooltip.retry')}> <Tooltip title={t('tooltip.retry')}>
<IconButton onClick={() => retry.mutate()}> <IconButton onClick={() => retry.mutate()}>

View File

@ -2,21 +2,22 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material'; import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material'; import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext, useState } from 'react'; import { useContext, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useHash } from 'react-use/lib/useHash'; import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js';
import { range, visibleIndex } from '../../utils.js'; import { range, visibleIndex } from '../../utils.js';
import { BLEND_SOURCES } from '../../constants.js'; import { BLEND_SOURCES } from '../../constants.js';
import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
import { getApiRoot } from '../../config.js';
export interface ImageCardProps { export interface ImageCardProps {
image: ImageResponse; image: SuccessJobResponse;
onDelete?: (key: ImageResponse) => void; onDelete?: (key: JobResponse) => void;
} }
export function GridItem(props: { xs: number; children: React.ReactNode }) { export function GridItem(props: { xs: number; children: React.ReactNode }) {
@ -27,18 +28,19 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {
export function ImageCard(props: ImageCardProps) { export function ImageCard(props: ImageCardProps) {
const { image } = props; const { image } = props;
const { params, outputs, size } = image; const { metadata, outputs } = image;
const [_hash, setHash] = useHash(); const [_hash, setHash] = useHash();
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>(); const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>(); const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
const client = mustExist(useContext(ClientContext));
const config = mustExist(useContext(ConfigContext)); const config = mustExist(useContext(ConfigContext));
const store = mustExist(useContext(StateContext)); const store = mustExist(useContext(StateContext));
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow); const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
async function loadSource() { async function loadSource() {
const req = await fetch(outputs[index].url); const req = await fetch(url);
return req.blob(); return req.blob();
} }
@ -84,12 +86,12 @@ export function ImageCard(props: ImageCardProps) {
} }
function downloadImage() { function downloadImage() {
window.open(outputs[index].url, '_blank'); window.open(url, '_blank');
close(); close();
} }
function downloadMetadata() { function downloadMetadata() {
window.open(outputs[index].url + '.json', '_blank'); window.open(url + '.json', '_blank');
close(); close();
} }
@ -106,14 +108,16 @@ export function ImageCard(props: ImageCardProps) {
return mustDefault(t(`${key}.${name}`), name); return mustDefault(t(`${key}.${name}`), name);
} }
const model = getLabel('model', params.model); const url = useMemo(() => client.outputURL(image, index), [image, index]);
const scheduler = getLabel('scheduler', params.scheduler);
const model = getLabel('model', metadata[index].model);
const scheduler = getLabel('scheduler', metadata[index].scheduler);
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}> return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }} <CardMedia sx={{ height: config.params.height.default }}
component='img' component='img'
image={outputs[index].url} image={url}
title={params.prompt} title={metadata[index].params.prompt}
/> />
<CardContent> <CardContent>
<Box textAlign='center'> <Box textAlign='center'>
@ -146,12 +150,12 @@ export function ImageCard(props: ImageCardProps) {
</GridItem> </GridItem>
<GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem> <GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
<GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem> <GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
<GridItem xs={4}>{t('parameter.seed')}: {params.seed}</GridItem> <GridItem xs={4}>{t('parameter.seed')}: {metadata[index].params.seed}</GridItem>
<GridItem xs={4}>{t('parameter.cfg')}: {params.cfg}</GridItem> <GridItem xs={4}>{t('parameter.cfg')}: {metadata[index].params.cfg}</GridItem>
<GridItem xs={4}>{t('parameter.steps')}: {params.steps}</GridItem> <GridItem xs={4}>{t('parameter.steps')}: {metadata[index].params.steps}</GridItem>
<GridItem xs={4}>{t('parameter.size')}: {size.width}x{size.height}</GridItem> <GridItem xs={4}>{t('parameter.size')}: {metadata[index].size.width}x{metadata[index].size.height}</GridItem>
<GridItem xs={12}> <GridItem xs={12}>
<Box textAlign='left'>{params.prompt}</Box> <Box textAlign='left'>{metadata[index].params.prompt}</Box>
</GridItem> </GridItem>
<GridItem xs={2}> <GridItem xs={2}>
<Tooltip title={t('tooltip.save')}> <Tooltip title={t('tooltip.save')}>

View File

@ -1,4 +1,4 @@
import { doesExist, mustExist } from '@apextoaster/js-utils'; import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material'; import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
import { Stack } from '@mui/system'; import { Stack } from '@mui/system';
import { useMutation, useQuery } from '@tanstack/react-query'; import { useMutation, useQuery } from '@tanstack/react-query';
@ -10,19 +10,17 @@ import { shallow } from 'zustand/shallow';
import { POLL_TIME } from '../../config.js'; import { POLL_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js'; import { JobResponse, JobStatus } from '../../types/api-v2.js';
const LOADING_PERCENT = 100; const LOADING_PERCENT = 100;
const LOADING_OVERAGE = 99; const LOADING_OVERAGE = 99;
export interface LoadingCardProps { export interface LoadingCardProps {
image: ImageResponse; image: JobResponse;
index: number;
} }
export function LoadingCard(props: LoadingCardProps) { export function LoadingCard(props: LoadingCardProps) {
const { image, index } = props; const { image } = props;
const { steps } = props.image.params;
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext)); const { params } = mustExist(useContext(ConfigContext));
@ -31,50 +29,22 @@ export function LoadingCard(props: LoadingCardProps) {
const { removeHistory, setReady } = useStore(store, selectActions, shallow); const { removeHistory, setReady } = useStore(store, selectActions, shallow);
const { t } = useTranslation(); const { t } = useTranslation();
const cancel = useMutation(() => client.cancel(image.outputs[index].key)); const cancel = useMutation(() => client.cancel([image.name]));
const ready = useQuery(['ready', image.outputs[index].key], () => client.ready(image.outputs[index].key), { const ready = useQuery(['ready', image.name], () => client.status([image.name]), {
// data will always be ready without this, even if the API says it's not // data will always be ready without this, even if the API says it's not
cacheTime: 0, cacheTime: 0,
refetchInterval: POLL_TIME, refetchInterval: POLL_TIME,
}); });
function getProgress() {
if (doesExist(ready.data)) {
return ready.data.progress;
}
return 0;
}
function getPercent() {
const progress = getProgress();
if (progress > steps) {
// steps was not complete, show 99% until done
return LOADING_OVERAGE;
}
const pct = progress / steps;
return Math.ceil(pct * LOADING_PERCENT);
}
function getTotal() {
const progress = getProgress();
if (progress > steps) {
// steps was not complete, show 99% until done
return t('loading.unknown');
}
return steps.toFixed(0);
}
function getReady() { function getReady() {
return doesExist(ready.data) && ready.data.ready; return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
} }
function renderProgress() { function renderProgress() {
const progress = getProgress(); const progress = getProgress(ready.data);
if (progress > 0 && progress <= steps) { const total = getTotal(ready.data);
return <CircularProgress variant='determinate' value={getPercent()} />; if (progress > 0 && progress <= total) {
return <CircularProgress variant='determinate' value={getPercent(progress, total)} />;
} else { } else {
return <CircularProgress />; return <CircularProgress />;
} }
@ -88,9 +58,9 @@ export function LoadingCard(props: LoadingCardProps) {
useEffect(() => { useEffect(() => {
if (ready.status === 'success' && getReady()) { if (ready.status === 'success' && getReady()) {
setReady(props.image, ready.data); setReady(ready.data[0]);
} }
}, [ready.status, getReady(), getProgress()]); }, [ready.status, getReady(), getProgress(ready.data)]);
return <Card sx={{ maxWidth: params.width.default }}> return <Card sx={{ maxWidth: params.width.default }}>
<CardContent sx={{ height: params.height.default }}> <CardContent sx={{ height: params.height.default }}>
@ -106,10 +76,7 @@ export function LoadingCard(props: LoadingCardProps) {
sx={{ alignItems: 'center' }} sx={{ alignItems: 'center' }}
> >
{renderProgress()} {renderProgress()}
<Typography>{t('loading.progress', { <Typography>{t('loading.progress', selectStatus(ready.data, image))}</Typography>
current: getProgress(),
total: getTotal(),
})}</Typography>
<Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button> <Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button>
</Stack> </Stack>
</Box> </Box>
@ -125,3 +92,45 @@ export function selectActions(state: OnnxState) {
setReady: state.setReady, setReady: state.setReady,
}; };
} }
export function selectStatus(data: Maybe<Array<JobResponse>>, defaultData: JobResponse) {
if (doesExist(data) && data.length > 0) {
return {
steps: data[0].steps,
stages: data[0].stages,
tiles: data[0].tiles,
};
}
return {
steps: defaultData.steps,
stages: defaultData.stages,
tiles: defaultData.tiles,
};
}
export function getPercent(current: number, total: number): number {
if (current > total) {
// steps was not complete, show 99% until done
return LOADING_OVERAGE;
}
const pct = current / total;
return Math.ceil(pct * LOADING_PERCENT);
}
export function getProgress(data: Maybe<Array<JobResponse>>) {
if (doesExist(data)) {
return data[0].steps.current;
}
return 0;
}
export function getTotal(data: Maybe<Array<JobResponse>>) {
if (doesExist(data)) {
return data[0].steps.total;
}
return 0;
}
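
As a quick check of the percent math:

console.assert(getPercent(20, 40) === 50);  // 20 of 40 steps reported
console.assert(getPercent(45, 40) === 99);  // overran the estimate: pinned at LOADING_OVERAGE until done
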

View File

@ -20,13 +20,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() { export function Blend() {
async function uploadSource() { async function uploadSource() {
const { blend, blendModel, blendUpscale } = store.getState(); const { blend, blendModel, blendUpscale } = store.getState();
const { image, retry } = await client.blend(blendModel, { const { job, retry } = await client.blend(blendModel, {
...blend, ...blend,
mask: mustExist(blend.mask), mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, blendUpscale); }, blendUpscale);
pushHistory(image, retry); pushHistory(job, retry);
} }
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));

View File

@ -27,12 +27,12 @@ export function Img2Img() {
const state = store.getState(); const state = store.getState();
const img2img = selectParams(state); const img2img = selectParams(state);
const { image, retry } = await client.img2img(model, { const { job, retry } = await client.img2img(model, {
...img2img, ...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state));
pushHistory(image, retry); pushHistory(job, retry);
} }
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));

View File

@ -39,22 +39,22 @@ export function Inpaint() {
const inpaint = selectParams(state); const inpaint = selectParams(state);
if (outpaint.enabled) { if (outpaint.enabled) {
const { image, retry } = await client.outpaint(model, { const { job, retry } = await client.outpaint(model, {
...inpaint, ...inpaint,
...outpaint, ...outpaint,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state));
pushHistory(image, retry); pushHistory(job, retry);
} else { } else {
const { image, retry } = await client.inpaint(model, { const { job, retry } = await client.inpaint(model, {
...inpaint, ...inpaint,
mask: mustExist(mask), mask: mustExist(mask),
source: mustExist(source), source: mustExist(source),
}, selectUpscale(state), selectHighres(state)); }, selectUpscale(state), selectHighres(state));
pushHistory(image, retry); pushHistory(job, retry);
} }
} }

View File

@ -69,8 +69,8 @@ export function Txt2Img() {
const image = await client.chain(model, chain); const image = await client.chain(model, chain);
pushHistory(image); pushHistory(image);
} else { } else {
const { image, retry } = await client.txt2img(model, params2, upscale, highres); const { job, retry } = await client.txt2img(model, params2, upscale, highres);
pushHistory(image, retry); pushHistory(job, retry);
} }
} }

View File

@ -21,12 +21,12 @@ import { PromptInput } from '../input/PromptInput.js';
export function Upscale() { export function Upscale() {
async function uploadSource() { async function uploadSource() {
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState(); const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
const { image, retry } = await client.upscale(upscaleModel, { const { job, retry } = await client.upscale(upscaleModel, {
...upscale, ...upscale,
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
}, upscaleUpscale, upscaleHighres); }, upscaleUpscale, upscaleHighres);
pushHistory(image, retry); pushHistory(job, retry);
} }
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));

View File

@ -2,6 +2,7 @@ import { Maybe } from '@apextoaster/js-utils';
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
import { Slice } from './types.js'; import { Slice } from './types.js';
import { DEFAULT_HISTORY } from '../constants.js'; import { DEFAULT_HISTORY } from '../constants.js';
import { JobResponse } from '../types/api-v2.js';
export interface HistoryItem { export interface HistoryItem {
image: ImageResponse; image: ImageResponse;
@ -9,14 +10,19 @@ export interface HistoryItem {
retry: Maybe<RetryParams>; retry: Maybe<RetryParams>;
} }
export interface HistoryItemV2 {
image: JobResponse;
retry: Maybe<RetryParams>;
}
export interface HistorySlice { export interface HistorySlice {
history: Array<HistoryItem>; history: Array<HistoryItemV2>;
limit: number; limit: number;
pushHistory(image: ImageResponse, retry?: RetryParams): void; pushHistory(image: JobResponse, retry?: RetryParams): void;
removeHistory(image: ImageResponse): void; removeHistory(image: JobResponse): void;
setLimit(limit: number): void; setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void; setReady(image: JobResponse): void;
} }
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> { export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
@ -39,7 +45,7 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
removeHistory(image) { removeHistory(image) {
set((prev) => ({ set((prev) => ({
...prev, ...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), history: prev.history.filter((it) => it.image.name !== image.name),
})); }));
}, },
setLimit(limit) { setLimit(limit) {
@ -48,12 +54,12 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
limit, limit,
})); }));
}, },
setReady(image, ready) { setReady(image) {
set((prev) => { set((prev) => {
const history = [...prev.history]; const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); const idx = history.findIndex((it) => it.image.name === image.name);
if (idx >= 0) { if (idx >= 0) {
history[idx].ready = ready; history[idx].image = image;
} else { } else {
// TODO: error // TODO: error
} }

View File

@ -67,7 +67,7 @@ export const I18N_STRINGS_EN = {
}, },
loading: { loading: {
cancel: 'Cancel', cancel: 'Cancel',
progress: '{{current}} of {{total}} steps', progress: '{{steps.current}} of {{steps.total}} steps, {{tiles.current}} of {{tiles.total}} tiles, {{stages.current}} of {{stages.total}} stages',
server: 'Connecting to server...', server: 'Connecting to server...',
unknown: 'many', unknown: 'many',
}, },
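
The updated template expects all three Progress pairs in its interpolation object, which is what selectStatus() in LoadingCard builds; a sketch with hypothetical values, assuming react-i18next nested interpolation:

const progress = {
  steps: { current: 10, total: 40 },
  tiles: { current: 1, total: 4 },
  stages: { current: 1, total: 2 },
};
// t('loading.progress', progress) renders '10 of 40 steps, 1 of 4 tiles, 1 of 2 stages'
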

gui/src/types/api-v2.ts (new file, 160 lines)
View File

@ -0,0 +1,160 @@
import { RetryParams } from './api.js';
import { BaseImgParams, HighresParams, Img2ImgParams, InpaintParams, Txt2ImgParams, UpscaleParams } from './params.js';
export interface Progress {
current: number;
total: number;
}
export interface Size {
width: number;
height: number;
}
export interface NetworkMetadata {
name: string;
hash: string;
weight: number;
}
export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
input_size: Size;
size: Size;
outputs: Array<string>;
params: TParams;
inversions: Array<NetworkMetadata>;
loras: Array<NetworkMetadata>;
model: string;
scheduler: string;
border: unknown;
highres: HighresParams;
upscale: UpscaleParams;
type: TType;
}
export enum JobStatus {
PENDING = 'pending',
RUNNING = 'running',
SUCCESS = 'success',
FAILED = 'failed',
CANCELLED = 'cancelled',
UNKNOWN = 'unknown',
}
export enum JobType {
TXT2IMG = 'txt2img',
IMG2IMG = 'img2img',
INPAINT = 'inpaint',
UPSCALE = 'upscale',
BLEND = 'blend',
CHAIN = 'chain',
}
export interface BaseJobResponse {
name: string;
status: JobStatus;
type: JobType;
stages: Progress;
steps: Progress;
tiles: Progress;
}
/**
* Pending image job.
*/
export interface PendingJobResponse extends BaseJobResponse {
status: JobStatus.PENDING | JobStatus.RUNNING;
queue: Progress;
}
/**
* Failed image job with error information.
*/
export interface FailedJobResponse extends BaseJobResponse {
status: JobStatus.FAILED;
error: string;
}
/**
* Successful txt2img image job with output keys and metadata.
*/
export interface SuccessTxt2ImgJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.TXT2IMG;
outputs: Array<string>;
metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
}
/**
* Successful img2img job with output keys and metadata.
*/
export interface SuccessImg2ImgJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.IMG2IMG;
outputs: Array<string>;
metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
}
/**
* Successful inpaint job with output keys and metadata.
*/
export interface SuccessInpaintJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.INPAINT;
outputs: Array<string>;
metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
}
/**
* Successful upscale job with output keys and metadata.
*/
export interface SuccessUpscaleJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.UPSCALE;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
}
/**
* Successful blend job with output keys and metadata.
*/
export interface SuccessBlendJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.BLEND;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
}
/**
* Successful chain pipeline job with output keys and metadata.
*/
export interface SuccessChainJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.CHAIN;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType>>; // TODO: could be all kinds
}
export type SuccessJobResponse
= SuccessTxt2ImgJobResponse
| SuccessImg2ImgJobResponse
| SuccessInpaintJobResponse
| SuccessUpscaleJobResponse
| SuccessBlendJobResponse
| SuccessChainJobResponse;
export type JobResponse = PendingJobResponse | FailedJobResponse | SuccessJobResponse;
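
Since status discriminates this union, TypeScript narrows each branch without casts; a minimal sketch:

export function describe(job: JobResponse): string {
  switch (job.status) {
    case JobStatus.SUCCESS:
      // narrowed to SuccessJobResponse: outputs and metadata are available
      return `${job.name}: ${job.outputs.length} outputs ready`;
    case JobStatus.FAILED:
      // narrowed to FailedJobResponse: error is available
      return `${job.name}: failed (${job.error})`;
    default:
      // PENDING or RUNNING: the progress counters are available
      return `${job.name}: ${job.steps.current} of ${job.steps.total} steps`;
  }
}
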
/**
* Status response from the job endpoint, with parameters to retry the job if it fails.
*/
export interface JobResponseWithRetry {
job: JobResponse;
retry: RetryParams;
}
/**
* Re-export `RetryParams` for convenience.
*/
export { RetryParams };

View File

@ -14,6 +14,8 @@ import {
/** /**
* Output image data within the response. * Output image data within the response.
*
* @deprecated
*/ */
export interface ImageOutput { export interface ImageOutput {
key: string; key: string;
@ -22,6 +24,8 @@ export interface ImageOutput {
/** /**
* General response for most image requests. * General response for most image requests.
*
* @deprecated
*/ */
export interface ImageResponse { export interface ImageResponse {
outputs: Array<ImageOutput>; outputs: Array<ImageOutput>;
@ -119,11 +123,19 @@ export type RetryParams = {
upscale?: UpscaleParams; upscale?: UpscaleParams;
}; };
/**
* Status response from the image endpoint, with parameters to retry the job if it fails.
*
* @deprecated
*/
export interface ImageResponseWithRetry { export interface ImageResponseWithRetry {
image: ImageResponse; image: ImageResponse;
retry: RetryParams; retry: RetryParams;
} }
/**
* @deprecated
*/
export interface ImageMetadata { export interface ImageMetadata {
highres: HighresParams; highres: HighresParams;
outputs: string | Array<string>; outputs: string | Array<string>;

View File

@ -43,6 +43,7 @@
"dtype", "dtype",
"ESRGAN", "ESRGAN",
"Exif", "Exif",
"fromarray",
"ftfy", "ftfy",
"gfpgan", "gfpgan",
"Heun", "Heun",
@ -115,6 +116,7 @@
"webp", "webp",
"xformers", "xformers",
"zustand" "zustand"
] ],
"git.ignoreLimitWarning": true
} }
} }