
feat: add batch endpoints for cancel and status, update responses

Sean Sube 2024-01-03 19:09:18 -06:00
parent 19c91f70f5
commit 44a8d61082
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
36 changed files with 981 additions and 707 deletions

View File

@ -36,15 +36,14 @@ class CorrectCodeformerStage(BaseStage):
# https://pypi.org/project/codeformer-perceptor/
# import must be within the load function for patches to take effect
# TODO: rewrite and remove
from codeformer.basicsr.archs.codeformer_arch import CodeFormer
from codeformer.basicsr.utils import img2tensor, tensor2img
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
upscale = upscale.with_args(**kwargs)
device = worker.get_device()
net = ARCH_REGISTRY.get("CodeFormer")(
net = CodeFormer(
dim_embd=512,
codebook_size=1024,
n_head=8,

View File

@ -1,10 +1,28 @@
from typing import Any, List, Optional, Tuple
from json import dumps
from logging import getLogger
from os import path
from typing import Any, List, Optional
import numpy as np
from PIL import Image
from ..output import json_params
from ..convert.utils import resolve_tensor
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
from ..server.load import get_extra_hashes
from ..utils import hash_file
logger = getLogger(__name__)
class NetworkMetadata:
name: str
hash: str
weight: float
def __init__(self, name: str, hash: str, weight: float) -> None:
self.name = name
self.hash = hash
self.weight = weight
class ImageMetadata:
@ -13,8 +31,9 @@ class ImageMetadata:
params: ImageParams
size: Size
upscale: UpscaleParams
inversions: Optional[List[Tuple[str, float]]]
loras: Optional[List[Tuple[str, float]]]
inversions: Optional[List[NetworkMetadata]]
loras: Optional[List[NetworkMetadata]]
models: Optional[List[NetworkMetadata]]
def __init__(
self,
@ -23,8 +42,9 @@ class ImageMetadata:
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
inversions: Optional[List[NetworkMetadata]] = None,
loras: Optional[List[NetworkMetadata]] = None,
models: Optional[List[NetworkMetadata]] = None,
) -> None:
self.params = params
self.size = size
@ -33,19 +53,108 @@ class ImageMetadata:
self.highres = highres
self.inversions = inversions
self.loras = loras
self.models = models
def to_auto1111(self, server, outputs) -> str:
model_name = path.basename(path.normpath(self.params.model))
logger.debug("getting model hash for %s", model_name)
model_hash = get_extra_hashes().get(model_name, None)
if model_hash is None:
model_hash_path = path.join(self.params.model, "hash.txt")
if path.exists(model_hash_path):
with open(model_hash_path, "r") as f:
model_hash = f.readline().rstrip(",. \n\t\r")
model_hash = model_hash or "unknown"
hash_map = {
model_name: model_hash,
}
inversion_hashes = ""
if self.inversions is not None:
inversion_pairs = [
(
network.name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", network.name))
).upper(),
)
for network in self.inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))
lora_hashes = ""
if self.loras is not None:
lora_pairs = [
(
network.name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", network.name))
).upper(),
)
for network in self.loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))
return (
f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", '
f"Hashes: {dumps(hash_map)}"
)
def tojson(self, server, outputs):
return json_params(
server,
outputs,
self.params,
self.size,
upscale=self.upscale,
border=self.border,
highres=self.highres,
inversions=self.inversions,
loras=self.loras,
)
json = {
"input_size": self.size.tojson(),
"outputs": outputs,
"params": self.params.tojson(),
"inversions": {},
"loras": {},
}
json["params"]["model"] = path.basename(self.params.model)
json["params"]["scheduler"] = self.params.scheduler # TODO: why tho?
# calculate final output size
output_size = self.size
if self.border is not None:
json["border"] = self.border.tojson()
output_size = output_size.add_border(self.border)
if self.highres is not None:
json["highres"] = self.highres.tojson()
output_size = self.highres.resize(output_size)
if self.upscale is not None:
json["upscale"] = self.upscale.tojson()
output_size = self.upscale.resize(output_size)
json["size"] = output_size.tojson()
if self.inversions is not None:
for network in self.inversions:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "inversion", network.name))
).upper()
json["inversions"][network.name] = {"weight": network.weight, "hash": hash}
if self.loras is not None:
for network in self.loras:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "lora", network.name))
).upper()
json["loras"][network.name] = {"weight": network.weight, "hash": hash}
return json
class StageResult:
@ -86,6 +195,7 @@ class StageResult:
self.arrays = arrays
self.images = images
self.source = source
self.metadata = []
def __len__(self) -> int:
if self.arrays is not None:
@ -117,7 +227,7 @@ class StageResult:
elif self.images is not None:
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
else:
raise ValueError("invalid stage result")
self.arrays = [array]
if metadata is not None:
self.metadata.append(metadata)
@ -130,13 +240,45 @@ class StageResult:
elif self.arrays is not None:
self.arrays.append(np.array(image))
else:
raise ValueError("invalid stage result")
self.images = [image]
if metadata is not None:
self.metadata.append(metadata)
else:
self.metadata.append(ImageMetadata())
def insert_array(
self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata] = None
):
if self.arrays is not None:
self.arrays.insert(index, array)
elif self.images is not None:
self.images.insert(
index, Image.fromarray(np.uint8(array), shape_mode(array))
)
else:
self.arrays = [array]
if metadata is not None:
self.metadata.insert(index, metadata)
else:
self.metadata.insert(index, ImageMetadata())
def insert_image(
self, index: int, image: Image.Image, metadata: Optional[ImageMetadata] = None
):
if self.images is not None:
self.images.insert(index, image)
elif self.arrays is not None:
self.arrays.insert(index, np.array(image))
else:
self.images = [image]
if metadata is not None:
self.metadata.insert(index, metadata)
else:
self.metadata.insert(index, ImageMetadata())
def shape_mode(arr: np.ndarray) -> str:
if len(arr.shape) != 3:

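A minimal sketch of how the new metadata types compose. This assumes the ImageParams constructor takes roughly (model, pipeline, scheduler, prompt, cfg, steps, seed) and that push_image accepts an optional metadata argument, as the diff suggests; the values are illustrative, not the exact signatures:

    from PIL import Image

    from onnx_web.chain.result import ImageMetadata, NetworkMetadata, StageResult
    from onnx_web.params import ImageParams, Size

    # networks are now structured records instead of (name, weight) tuples
    lora = NetworkMetadata("detail-tweaker", "ABC123", 0.8)

    metadata = ImageMetadata(
        ImageParams("stable-diffusion-v1-5", "txt2img", "ddim", "an example prompt", 7.5, 25, 42),
        Size(512, 512),
        loras=[lora],
    )

    result = StageResult.empty()
    result.push_image(Image.new("RGB", (512, 512)), metadata)  # metadata rides along with each image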
View File

@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
from ..chain.result import StageResult
from ..chain.upscale import split_upscale, stage_upscale_correction
from ..image import expand_image
from ..output import save_image
from ..output import save_image, save_result
from ..params import (
Border,
HighresParams,
@ -29,7 +29,7 @@ from ..server import ServerContext
from ..server.load import get_source_filters
from ..utils import is_debug, run_gc, show_system_toast
from ..worker import WorkerContext
from .utils import get_latents_from_seed, parse_prompt
from .utils import get_latents_from_seed
logger = getLogger(__name__)
@ -57,7 +57,6 @@ def run_txt2img_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
) -> None:
@ -114,50 +113,34 @@ def run_txt2img_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain.run(
images = chain(
worker, server, params, StageResult.empty(), callback=progress, latents=latents
)
_pairs, loras, inversions, _rest = parse_prompt(params)
# add a thumbnail, if requested
cover = images[0]
cover = images.as_image()[0]
if params.thumbnail and (
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
):
thumbnail = cover.copy()
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
images.insert(0, thumbnail)
outputs.insert(0, f"{worker.name}-thumb.{server.image_format}")
images.insert_image(0, thumbnail)
for image, output in zip(images, outputs):
logger.trace("saving output image %s: %s", output, image.size)
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
save_result(server, images, worker.job)
# clean up
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished txt2img job: {dest}")
logger.info("finished txt2img job: %s", dest)
show_system_toast(f"finished txt2img job: {worker.job}")
logger.info("finished txt2img job: %s", worker.job)
def run_img2img_pipeline(
worker: WorkerContext,
server: ServerContext,
params: ImageParams,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image,
@ -228,36 +211,21 @@ def run_img2img_pipeline(
# run and append the filtered source
progress = worker.get_progress_callback()
images = chain.run(
images = chain(
worker, server, params, StageResult(images=[source]), callback=progress
)
if source_filter is not None and source_filter != "none":
images.append(source)
images.push_image(source)
# save with metadata
_pairs, loras, inversions, _rest = parse_prompt(params)
size = Size(*source.size)
for image, output in zip(images, outputs):
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
highres=highres,
inversions=inversions,
loras=loras,
)
save_result(server, images, worker.job)
# clean up
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished img2img job: {dest}")
logger.info("finished img2img job: %s", dest)
show_system_toast(f"finished img2img job: {worker.job}")
logger.info("finished img2img job: %s", worker.job)
def run_inpaint_pipeline(
@ -265,7 +233,6 @@ def run_inpaint_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image,
@ -290,7 +257,7 @@ def run_inpaint_pipeline(
mask = ImageOps.contain(mask, (mask_max, mask_max))
mask = mask.crop((0, 0, source.width, source.height))
source, mask, noise, full_size = expand_image(
source, mask, noise, _full_size = expand_image(
source,
mask,
border,
@ -414,7 +381,7 @@ def run_inpaint_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain.run(
images = chain(
worker,
server,
params,
@ -423,33 +390,28 @@ def run_inpaint_pipeline(
latents=latents,
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
for i, (image, metadata) in enumerate(zip(images.as_image(), images.metadata)):
if full_res_inpaint:
if is_debug():
save_image(server, "adjusted-output.png", image)
mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
image = original_source
image.paste(mini_image, box=adj_mask_border)
dest = save_image(
save_image(
server,
output,
f"{worker.job}_{i}.{server.image_format}",
image,
params,
size,
upscale=upscale,
border=border,
inversions=inversions,
loras=loras,
metadata,
)
# clean up
del image
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished inpaint job: {dest}")
logger.info("finished inpaint job: %s", dest)
show_system_toast(f"finished inpaint job: {worker.job}")
logger.info("finished inpaint job: %s", worker.job)
def run_upscale_pipeline(
@ -457,7 +419,6 @@ def run_upscale_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
outputs: List[str],
upscale: UpscaleParams,
highres: HighresParams,
source: Image.Image,
@ -497,30 +458,18 @@ def run_upscale_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain.run(
images = chain(
worker, server, params, StageResult(images=[source]), callback=progress
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
dest = save_image(
server,
output,
image,
params,
size,
upscale=upscale,
inversions=inversions,
loras=loras,
)
save_result(server, images, worker.job)
# clean up
del image
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished upscale job: {dest}")
logger.info("finished upscale job: %s", dest)
show_system_toast(f"finished upscale job: {worker.job}")
logger.info("finished upscale job: %s", worker.job)
def run_blend_pipeline(
@ -528,7 +477,6 @@ def run_blend_pipeline(
server: ServerContext,
params: ImageParams,
size: Size,
outputs: List[str],
upscale: UpscaleParams,
# highres: HighresParams,
sources: List[Image.Image],
@ -559,17 +507,15 @@ def run_blend_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain.run(
images = chain(
worker, server, params, StageResult(images=sources), callback=progress
)
for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale)
save_result(server, images, worker.job)
# clean up
del image
run_gc([worker.get_device()])
# notify the user
show_system_toast(f"finished blend job: {dest}")
logger.info("finished blend job: %s", dest)
show_system_toast(f"finished blend job: {worker.job}")
logger.info("finished blend job: %s", worker.job)

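The common shape of these pipeline changes, condensed: the chain is invoked directly instead of through chain.run, metadata travels inside the StageResult instead of a parallel outputs list, and save_result derives file names from the worker's job name. A sketch, assuming the names from the surrounding context above:

    # condensed sketch of the new save flow shared by the pipelines in this file
    progress = worker.get_progress_callback()
    images = chain(worker, server, params, StageResult.empty(), callback=progress)

    # one call replaces the per-image save_image loop; names become {job}_{i}.{format}
    outputs = save_result(server, images, worker.job)

    run_gc([worker.get_device()])
    logger.info("finished job: %s", worker.job)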
View File

@ -1,173 +1,20 @@
from hashlib import sha256
from json import dumps
from logging import getLogger
from os import path
from struct import pack
from time import time
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional
from piexif import ExifIFD, ImageIFD, dump
from piexif.helper import UserComment
from PIL import Image, PngImagePlugin
from .convert.utils import resolve_tensor
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
from .chain.result import ImageMetadata, StageResult
from .params import ImageParams, Param, Size
from .server import ServerContext
from .server.load import get_extra_hashes
from .utils import base_join
from .utils import base_join, hash_value
logger = getLogger(__name__)
HASH_BUFFER_SIZE = 2**22 # 4MB
def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break
sha.update(data)
return sha.hexdigest()
def hash_value(sha, param: Optional[Param]):
if param is None:
return
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
sha.update(bytearray(pack("!I", param)))
elif isinstance(param, str):
sha.update(param.encode("utf-8"))
else:
logger.warning("cannot hash param: %s, %s", param, type(param))
def json_params(
server: ServerContext,
outputs: List[str],
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None,
parent: Optional[Dict] = None,
) -> Any:
json = {
"input_size": size.tojson(),
"outputs": outputs,
"params": params.tojson(),
"inversions": {},
"loras": {},
}
json["params"]["model"] = path.basename(params.model)
json["params"]["scheduler"] = params.scheduler
# calculate final output size
output_size = size
if border is not None:
json["border"] = border.tojson()
output_size = output_size.add_border(border)
if highres is not None:
json["highres"] = highres.tojson()
output_size = highres.resize(output_size)
if upscale is not None:
json["upscale"] = upscale.tojson()
output_size = upscale.resize(output_size)
json["size"] = output_size.tojson()
if inversions is not None:
for name, weight in inversions:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper()
json["inversions"][name] = {"weight": weight, "hash": hash}
if loras is not None:
for name, weight in loras:
hash = hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper()
json["loras"][name] = {"weight": weight, "hash": hash}
return json
def str_params(
server: ServerContext,
params: ImageParams,
size: Size,
inversions: List[Tuple[str, float]] = None,
loras: List[Tuple[str, float]] = None,
) -> str:
model_name = path.basename(path.normpath(params.model))
logger.debug("getting model hash for %s", model_name)
model_hash = get_extra_hashes().get(model_name, None)
if model_hash is None:
model_hash_path = path.join(params.model, "hash.txt")
if path.exists(model_hash_path):
with open(model_hash_path, "r") as f:
model_hash = f.readline().rstrip(",. \n\t\r")
model_hash = model_hash or "unknown"
hash_map = {
model_name: model_hash,
}
inversion_hashes = ""
if inversions is not None:
inversion_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "inversion", name))
).upper(),
)
for name, _weight in inversions
]
inversion_hashes = ",".join(
[f"{name}: {hash}" for name, hash in inversion_pairs]
)
hash_map.update(dict(inversion_pairs))
lora_hashes = ""
if loras is not None:
lora_pairs = [
(
name,
hash_file(
resolve_tensor(path.join(server.model_path, "lora", name))
).upper(),
)
for name, _weight in loras
]
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
hash_map.update(dict(lora_pairs))
return (
f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n"
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
f"Model hash: {model_hash}, Model: {model_name}, "
f"Tool: onnx-web, Version: {server.server_version}, "
f'Inversion hashes: "{inversion_hashes}", '
f'Lora hashes: "{lora_hashes}", '
f"Hashes: {dumps(hash_map)}"
)
def make_output_name(
server: ServerContext,
@ -179,6 +26,19 @@ def make_output_name(
offset: int = 0,
) -> List[str]:
count = count or params.batch
job_name = make_job_name(mode, params, size, extras)
return [
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
]
def make_job_name(
mode: str,
params: ImageParams,
size: Size,
extras: Optional[List[Optional[Param]]] = None,
) -> str:
now = int(time())
sha = sha256()
@ -200,49 +60,49 @@ def make_output_name(
for param in extras:
hash_value(sha, param)
return [
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
for i in range(offset, count + offset)
]
return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}"
def save_result(
server: ServerContext,
result: StageResult,
base_name: str,
) -> List[str]:
results = []
for i, (image, metadata) in enumerate(zip(result.as_image(), result.metadata)):
results.append(
save_image(
server,
base_name + f"_{i}.{server.image_format}",
image,
metadata,
)
)
return results
def save_image(
server: ServerContext,
output: str,
image: Image.Image,
params: Optional[ImageParams] = None,
size: Optional[Size] = None,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
inversions: List[Tuple[str, float]] = None,
loras: List[Tuple[str, float]] = None,
metadata: Optional[ImageMetadata] = None,
) -> str:
path = base_join(server.output_path, output)
if server.image_format == "png":
exif = PngImagePlugin.PngInfo()
if params is not None:
if metadata is not None:
exif.add_text("make", "onnx-web")
exif.add_text(
"maker note",
dumps(
json_params(
server,
[output],
params,
size,
upscale=upscale,
border=border,
highres=highres,
)
),
dumps(metadata.tojson(server, [output])),
)
exif.add_text("model", server.server_version)
exif.add_text(
"parameters",
str_params(server, params, size, inversions=inversions, loras=loras),
metadata.to_auto1111(server, [output]),
)
image.save(path, format=server.image_format, pnginfo=exif)
@ -251,23 +111,11 @@ def save_image(
{
"0th": {
ExifIFD.MakerNote: UserComment.dump(
dumps(
json_params(
server,
[output],
params,
size,
upscale=upscale,
border=border,
highres=highres,
)
),
dumps(metadata.tojson(server, [output])),
encoding="unicode",
),
ExifIFD.UserComment: UserComment.dump(
str_params(
server, params, size, inversions=inversions, loras=loras
),
metadata.to_auto1111(server, [output]),
encoding="unicode",
),
ImageIFD.Make: "onnx-web",
@ -277,34 +125,23 @@ def save_image(
)
image.save(path, format=server.image_format, exif=exif)
if params is not None:
save_params(
if metadata is not None:
save_metadata(
server,
output,
params,
size,
upscale=upscale,
border=border,
highres=highres,
)
logger.debug("saved output image to: %s", path)
return path
def save_params(
def save_metadata(
server: ServerContext,
output: str,
params: ImageParams,
size: Size,
upscale: Optional[UpscaleParams] = None,
border: Optional[Border] = None,
highres: Optional[HighresParams] = None,
metadata: ImageMetadata,
) -> str:
path = base_join(server.output_path, f"{output}.json")
json = json_params(
server, output, params, size, upscale=upscale, border=border, highres=highres
)
json = metadata.tojson(server, [output])
with open(path, "w") as f:
f.write(dumps(json))
logger.debug("saved image params to: %s", path)

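make_output_name is now a thin wrapper over make_job_name: one job name per request, expanded into per-image file names. A sketch, given params, size, and a populated StageResult; the hash and timestamp shown are placeholders:

    job_name = make_job_name("txt2img", params, size)
    # e.g. "txt2img_42_<sha256>_<timestamp>"

    outputs = save_result(server, result, job_name)
    # e.g. ["txt2img_42_<sha256>_<timestamp>_0.png", "txt2img_42_<sha256>_<timestamp>_1.png"]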
View File

@ -13,6 +13,24 @@ Param = Union[str, int, float]
Point = Tuple[int, int]
class Progress:
current: int
total: int
def __init__(self, current: int, total: int) -> None:
self.current = current
self.total = total
def __str__(self) -> str:
return "%s/%s" % (self.current, self.total)
def tojson(self):
return {
"current": self.current,
"total": self.total,
}
class SizeChart(IntEnum):
micro = 64
mini = 128 # small tile for very expensive models

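Progress is a plain current/total counter used by the new status responses:

    steps = Progress(5, 25)
    print(str(steps))      # "5/25"
    print(steps.tojson())  # {"current": 5, "total": 25}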
View File

@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
pool.recycle(recycle_all=True)
logger.info("restarted worker pool")
return jsonify(pool.status())
return jsonify(pool.summary())
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
if not check_admin(server):
return make_response(jsonify({})), 401
return jsonify(pool.status())
return jsonify(pool.summary())
def get_extra_models(server: ServerContext):
@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExe
app.route("/api/extras", methods=["PUT"])(
wrap_route(update_extra_models, server)
),
app.route("/api/restart", methods=["POST"])(
app.route("/api/worker/restart", methods=["POST"])(
wrap_route(restart_workers, server, pool=pool)
),
app.route("/api/status")(wrap_route(worker_status, server, pool=pool)),
app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)),
]

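Worker administration now lives under /api/worker/. A hedged example of calling the renamed routes with requests, assuming a server on the default port and an admin token passed as a query parameter (both assumptions, not confirmed by this diff):

    import requests

    BASE = "http://127.0.0.1:5000/api"  # assumed default bind address
    TOKEN = "example-admin-token"       # placeholder

    summary = requests.get(f"{BASE}/worker/status", params={"token": TOKEN}).json()
    requests.post(f"{BASE}/worker/restart", params={"token": TOKEN})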
View File

@ -1,14 +1,14 @@
from io import BytesIO
from logging import getLogger
from os import path
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from flask import Flask, jsonify, make_response, request, url_for
from jsonschema import validate
from PIL import Image
from ..chain import CHAIN_STAGES, ChainPipeline
from ..chain.result import StageResult
from ..chain.result import ImageMetadata, StageResult
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
from ..diffusers.run import (
run_blend_pipeline,
@ -18,8 +18,8 @@ from ..diffusers.run import (
run_upscale_pipeline,
)
from ..diffusers.utils import replace_wildcards
from ..output import json_params, make_output_name
from ..params import Size, StageParams, TileOrder
from ..output import make_job_name
from ..params import Progress, Size, StageParams, TileOrder
from ..transformers.run import run_txt2txt_pipeline
from ..utils import (
base_join,
@ -34,6 +34,7 @@ from ..utils import (
load_config_str,
sanitize_name,
)
from ..worker.command import JobType
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .load import (
@ -92,6 +93,64 @@ def error_reply(err: str):
return response
def job_reply(name: str):
return jsonify(
{
"name": name,
}
)
def image_reply(
name: str,
status: str,
job_type: str,
stages: Optional[Progress] = None,
steps: Optional[Progress] = None,
tiles: Optional[Progress] = None,
outputs: Optional[List[str]] = None,
metadata: Optional[List[ImageMetadata]] = None,
):
if stages is None:
stages = Progress(0, 0)
if steps is None:
steps = Progress(0, 0)
if tiles is None:
tiles = Progress(0, 0)
data = {
"name": name,
"status": status,
"type": job_type,
"stages": stages.tojson(),
"steps": steps.tojson(),
"tiles": tiles.tojson(),
}
if metadata is not None and outputs is not None and len(metadata) != len(outputs):
logger.error("metadata and outputs must be the same length")
return error_reply("metadata and outputs must be the same length")
if outputs is not None:
data["outputs"] = outputs
if metadata is not None:
data["metadata"] = metadata
return jsonify(data)
def multi_image_reply(results: Dict[str, Any]):
# TODO: not that
return jsonify(
{
"results": results,
}
)
def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
@ -197,17 +256,15 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
)
output_count += 1
output = make_output_name(
job_name = make_job_name(
server, "img2img", params, size, extras=[strength], count=output_count
)
job_name = output[0]
pool.submit(
job_name,
JobType.IMG2IMG,
run_img2img_pipeline,
server,
params,
output,
upscale,
highres,
source,
@ -218,9 +275,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
logger.info("img2img job queued for: %s", job_name)
return jsonify(
json_params(server, output, params, size, upscale=upscale, highres=highres)
)
return job_reply(job_name)
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
@ -230,16 +285,15 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data())
output = make_output_name(server, "txt2img", params, size, count=params.batch)
job_name = make_job_name(server, "txt2img", params, size, count=params.batch)
job_name = output[0]
pool.submit(
job_name,
JobType.TXT2IMG,
run_txt2img_pipeline,
server,
params,
size,
output,
upscale,
highres,
needs_device=device,
@ -247,9 +301,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
logger.info("txt2img job queued for: %s", job_name)
return jsonify(
json_params(server, output, params, size, upscale=upscale, highres=highres)
)
return job_reply(job_name)
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
@ -295,7 +347,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data())
output = make_output_name(
job_name = make_job_name(
"inpaint",
params,
@ -312,14 +364,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
],
)
job_name = output[0]
pool.submit(
job_name,
JobType.INPAINT,
run_inpaint_pipeline,
server,
params,
size,
output,
upscale,
highres,
source,
@ -336,17 +387,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
logger.info("inpaint job queued for: %s", job_name)
return jsonify(
json_params(
server,
output,
params,
size,
upscale=upscale,
border=expand,
highres=highres,
)
)
return job_reply(job_name)
def upscale(server: ServerContext, pool: DevicePoolExecutor):
@ -362,16 +403,14 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
replace_wildcards(params, get_wildcard_data())
output = make_output_name(server, "upscale", params, size)
job_name = output[0]
job_name = make_job_name(server, "upscale", params, size)
pool.submit(
job_name,
JobType.UPSCALE,
run_upscale_pipeline,
server,
params,
size,
output,
upscale,
highres,
source,
@ -380,9 +419,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
logger.info("upscale job queued for: %s", job_name)
return jsonify(
json_params(server, output, params, size, upscale=upscale, highres=highres)
)
return job_reply(job_name)
# keys that are specially parsed by params and should not show up in with_args
@ -478,25 +515,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
output = make_output_name(
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
)
job_name = output[0]
job_name = make_job_name(server, "chain", base_params, base_size)
# build and run chain pipeline
pool.submit(
job_name,
JobType.CHAIN,
pipeline,
server,
base_params,
StageResult.empty(),
output=output,
size=base_size,
needs_device=device,
)
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
return jsonify(json_params(server, output, step_params, base_size))
return job_reply(job_name)
def blend(server: ServerContext, pool: DevicePoolExecutor):
@ -520,15 +553,14 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
upscale = build_upscale()
output = make_output_name(server, "upscale", params, size)
job_name = output[0]
job_name = make_job_name(server, "blend", params, size)
pool.submit(
job_name,
JobType.BLEND,
run_blend_pipeline,
server,
params,
size,
output,
upscale,
# TODO: highres
sources,
@ -538,27 +570,26 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
logger.info("upscale job queued for: %s", job_name)
return jsonify(json_params(server, output, params, size, upscale=upscale))
return job_reply(job_name)
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
device, params, size = pipeline_from_request(server)
output = make_output_name(server, "txt2txt", params, size)
job_name = output[0]
job_name = make_job_name(server, "txt2txt", params, size)
logger.info("upscale job queued for: %s", job_name)
pool.submit(
job_name,
JobType.TXT2TXT,
run_txt2txt_pipeline,
server,
params,
size,
output,
needs_device=device,
)
return jsonify(json_params(server, output, params, size))
return job_reply(job_name)
def cancel(server: ServerContext, pool: DevicePoolExecutor):
@ -601,9 +632,64 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
)
def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
legacy_job_name = request.args.get("job", None)
job_list = [name for name in request.args.get("jobs", "").split(",") if name]
if legacy_job_name is not None:
job_list.append(legacy_job_name)
if len(job_list) == 0:
return error_reply("at least one job name is required")
results = {}
for job_name in job_list:
job_name = sanitize_name(job_name)
cancelled = pool.cancel(job_name)
results[job_name] = cancelled
return multi_image_reply(results)
def job_status(server: ServerContext, pool: DevicePoolExecutor):
legacy_job_name = request.args.get("job", None)
job_list = [name for name in request.args.get("jobs", "").split(",") if name]
if legacy_job_name is not None:
job_list.append(legacy_job_name)
if len(job_list) == 0:
return error_reply("at least one job name is required")
for job_name in job_list:
job_name = sanitize_name(job_name)
status, progress = pool.status(job_name)
# TODO: accumulate results
if progress is not None:
# TODO: add output paths based on progress.results counter
return image_reply(
job_name,
status,
"TODO",
stages=Progress(progress.stages, 0),
steps=Progress(progress.steps, 0),
tiles=Progress(progress.tiles, 0),
)
return image_reply(job_name, status, "TODO")
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, server, app=app)),
# job routes
app.route("/api/job", methods=["POST"])(wrap_route(chain, server, pool=pool)),
app.route("/api/job/cancel", methods=["PUT"])(
wrap_route(job_cancel, server, pool=pool)
),
app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
# settings routes
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
app.route("/api/settings/models")(wrap_route(list_models, server)),
@ -614,6 +700,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
# legacy job routes
app.route("/api/img2img", methods=["POST"])(
wrap_route(img2img, server, pool=pool)
),
@ -631,6 +718,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
),
app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
# deprecated routes
app.route("/api/cancel", methods=["PUT"])(
wrap_route(cancel, server, pool=pool)
),

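The new batch endpoints accept a comma-separated jobs parameter, with the single job parameter kept for backwards compatibility. A sketch of queueing, polling, and cancelling from Python, assuming a local server on the default port and illustrative query parameters:

    import requests

    BASE = "http://127.0.0.1:5000/api"  # assumed default bind address

    # legacy image routes now return {"name": ...} instead of the full params JSON
    name = requests.post(f"{BASE}/txt2img", params={"prompt": "an example prompt"}).json()["name"]

    # poll one or more jobs by name
    status = requests.get(f"{BASE}/job/status", params={"jobs": name}).json()
    # {"name": ..., "status": ..., "type": ..., "stages": ..., "steps": ..., "tiles": ...}

    # cancel a batch of jobs in one call
    requests.put(f"{BASE}/job/cancel", params={"jobs": f"{name},another-job"})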
View File

@ -12,7 +12,6 @@ def run_txt2txt_pipeline(
_server: ServerContext,
params: ImageParams,
_size: Size,
output: str,
) -> None:
from transformers import AutoTokenizer, GPTJForCausalLM
@ -38,4 +37,4 @@ def run_txt2txt_pipeline(
print("Server says: %s" % result_text)
logger.info("finished txt2txt job: %s", output)
logger.info("finished txt2txt job: %s", worker.job)

View File

@ -2,16 +2,18 @@ import gc
import importlib
import json
import threading
from hashlib import sha256
from json import JSONDecodeError
from logging import getLogger
from os import environ, path
from platform import system
from struct import pack
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
import torch
from yaml import safe_load
from .params import DeviceParams, SizeChart
from .params import DeviceParams, Param, SizeChart
logger = getLogger(__name__)
@ -218,3 +220,34 @@ def load_config_str(raw: str) -> Dict:
return json.loads(raw)
except JSONDecodeError:
return safe_load(raw)
HASH_BUFFER_SIZE = 2**22 # 4MB
def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break
sha.update(data)
return sha.hexdigest()
def hash_value(sha, param: Optional[Param]):
if param is None:
return
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
sha.update(bytearray(pack("!I", param)))
elif isinstance(param, str):
sha.update(param.encode("utf-8"))
else:
logger.warning("cannot hash param: %s, %s", param, type(param))

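hash_file and hash_value move here from output.py unchanged: hash_file streams a file through SHA-256 in 4MB chunks, and hash_value folds typed params into an existing digest. For example, with an illustrative model path:

    from hashlib import sha256

    digest = hash_file("models/lora/example.safetensors")  # hex digest of the file contents

    sha = sha256()
    hash_value(sha, "txt2img")  # strings are encoded as UTF-8
    hash_value(sha, 42)         # ints pack as unsigned 32-bit
    hash_value(sha, 7.5)        # floats pack as 32-bit
    print(sha.hexdigest())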
View File

@ -1,34 +1,61 @@
from enum import Enum
from typing import Any, Callable, Dict
class JobStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
UNKNOWN = "unknown"
class JobType(str, Enum):
TXT2TXT = "txt2txt"
TXT2IMG = "txt2img"
IMG2IMG = "img2img"
INPAINT = "inpaint"
UPSCALE = "upscale"
BLEND = "blend"
CHAIN = "chain"
class ProgressCommand:
device: str
job: str
finished: bool
progress: int
cancelled: bool
failed: bool
job_type: str
status: JobStatus
results: int
steps: int
stages: int
tiles: int
def __init__(
self,
job: str,
job_type: str,
device: str,
finished: bool,
progress: int,
cancelled: bool = False,
failed: bool = False,
status: JobStatus,
results: int = 0,
steps: int = 0,
stages: int = 0,
tiles: int = 0,
):
self.job = job
self.job_type = job_type
self.device = device
self.finished = finished
self.progress = progress
self.cancelled = cancelled
self.failed = failed
self.status = status
self.results = results
self.steps = steps
self.stages = stages
self.tiles = tiles
class JobCommand:
device: str
name: str
job_type: str
fn: Callable[..., None]
args: Any
kwargs: Dict[str, Any]
@ -37,12 +64,14 @@ class JobCommand:
self,
name: str,
device: str,
job_type: str,
fn: Callable[..., None],
args: Any,
kwargs: Dict[str, Any],
):
self.device = device
self.name = name
self.job_type = job_type
self.fn = fn
self.args = args
self.kwargs = kwargs

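A short sketch of the new command objects together, using the enums above; the job name and device are placeholders:

    def run_job(*args, **kwargs) -> None:
        pass  # placeholder worker function

    job = JobCommand("txt2img_42_abc123_1704000000", "cuda:0", JobType.TXT2IMG, run_job, [], {})

    progress = ProgressCommand(
        job.name,
        job.job_type,
        job.device,
        JobStatus.RUNNING,
        steps=5,
    )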
View File

@ -2,21 +2,23 @@ from logging import getLogger
from os import getpid
from typing import Any, Callable, Optional
import numpy as np
from torch.multiprocessing import Queue, Value
from ..errors import CancelledException
from ..params import DeviceParams
from .command import JobCommand, ProgressCommand
from .command import JobCommand, JobStatus, ProgressCommand
logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None]
ProgressCallback = Callable[[int, int, np.ndarray], None]
class WorkerContext:
cancel: "Value[bool]"
job: Optional[str]
job_type: Optional[str]
name: str
pending: "Queue[JobCommand]"
active_pid: "Value[int]"
@ -41,6 +43,7 @@ class WorkerContext:
timeout: float,
):
self.job = None
self.job_type = None
self.name = name
self.device = device
self.cancel = cancel
@ -54,9 +57,15 @@ class WorkerContext:
self.retries = retries
self.timeout = timeout
def start(self, job: str) -> None:
self.job = job
def start(self, job: JobCommand) -> None:
# set job name and type
self.job = job.name
self.job_type = job.job_type
# reset retries
self.retries = self.initial_retries
# clear flags
self.set_cancel(cancel=False)
self.set_idle(idle=False)
@ -81,7 +90,7 @@ class WorkerContext:
def get_progress(self) -> int:
if self.last_progress is not None:
return self.last_progress.progress
return self.last_progress.steps
return 0
@ -112,13 +121,11 @@ class WorkerContext:
logger.debug("setting progress for job %s to %s", self.job, progress)
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
JobStatus.RUNNING,
steps=progress,
)
self.progress.put(
self.last_progress,
block=False,
@ -131,11 +138,10 @@ class WorkerContext:
logger.debug("setting finished for job %s", self.job)
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
JobStatus.SUCCESS, # TODO: FAILED
steps=self.get_progress(),
)
self.progress.put(
self.last_progress,
@ -150,11 +156,10 @@ class WorkerContext:
try:
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
JobStatus.FAILED,
steps=self.get_progress(),
)
self.progress.put(
self.last_progress,
@ -162,25 +167,3 @@ class WorkerContext:
)
except Exception:
logger.exception("error setting failure on job %s", self.job)
class JobStatus:
name: str
device: str
progress: int
cancelled: bool
finished: bool
def __init__(
self,
name: str,
device: DeviceParams,
progress: int = 0,
cancelled: bool = False,
finished: bool = False,
) -> None:
self.name = name
self.device = device.device
self.progress = progress
self.cancelled = cancelled
self.finished = finished

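start now receives the whole JobCommand, so the context can record both the job name and its type before clearing the cancel and idle flags. A minimal sketch, given a constructed WorkerContext (see the tests below for full construction; the set_progress helper name is inferred from the method bodies above):

    job = JobCommand("test-job", "cpu", JobType.TXT2IMG, lambda *args, **kwargs: None, [], {})
    worker.start(job)       # sets worker.job = "test-job" and worker.job_type = "txt2img"
    worker.set_progress(5)  # emits a ProgressCommand with JobStatus.RUNNING and steps=5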
View File

@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams
from ..server import ServerContext
from .command import JobCommand, ProgressCommand
from .command import JobCommand, JobStatus, ProgressCommand
from .context import WorkerContext
from .utils import Interval
from .worker import worker_main
@ -201,6 +201,10 @@ class DevicePoolExecutor:
should be cancelled on the next progress callback.
"""
if key in self.cancelled_jobs:
logger.debug("cancelling already cancelled job: %s", key)
return True
for job in self.finished_jobs:
if job.job == key:
logger.debug("cannot cancel finished job: %s", key)
@ -209,6 +213,9 @@ class DevicePoolExecutor:
for job in self.pending_jobs:
if job.name == key:
self.pending_jobs.remove(job)
self.cancelled_jobs.append(
key
) # ensure workers never pick up this job and the status endpoint knows about it later
logger.info("cancelled pending job: %s", key)
return True
@ -221,28 +228,31 @@ class DevicePoolExecutor:
self.cancelled_jobs.append(key)
return True
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
"""
Check the status of a job and report its last progress update, if any.
Pending, cancelled, and unknown jobs have no ProgressCommand.
"""
if key in self.cancelled_jobs:
logger.debug("checking status for cancelled job: %s", key)
return (JobStatus.CANCELLED, None)
if key in self.running_jobs:
logger.debug("checking status for running job: %s", key)
return (False, self.running_jobs[key])
return (JobStatus.RUNNING, self.running_jobs[key])
for job in self.finished_jobs:
if job.job == key:
logger.debug("checking status for finished job: %s", key)
return (False, job)
return (job.status, job)
for job in self.pending_jobs:
if job.name == key:
logger.debug("checking status for pending job: %s", key)
return (True, None)
return (JobStatus.PENDING, None)
logger.trace("checking status for unknown job: %s", key)
return (False, None)
return (JobStatus.UNKNOWN, None)
def join(self):
logger.info("stopping worker pool")
@ -383,6 +393,7 @@ class DevicePoolExecutor:
def submit(
self,
key: str,
job_type: str,
fn: Callable[..., None],
/,
*args,
@ -399,56 +410,63 @@ class DevicePoolExecutor:
)
# build and queue job
job = JobCommand(key, device, fn, args, kwargs)
job = JobCommand(key, device, job_type, fn, args, kwargs)
self.pending_jobs.append(job)
def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]:
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
"""
Returns a dict with a list of jobs as (name, steps, status) tuples and a list of workers.
"""
return {
"cancelled": [],
"finished": [
jobs: List[Tuple[str, int, JobStatus]] = []
jobs.extend(
[
(
job.job,
job.progress,
False,
job.finished,
job.cancelled,
job.failed,
job,
0,
JobStatus.CANCELLED,
)
for job in self.finished_jobs
],
"pending": [
for job in self.cancelled_jobs
]
)
jobs.extend(
[
(
job.name,
0,
True,
False,
False,
False,
JobStatus.PENDING,
)
for job in self.pending_jobs
],
"running": [
]
)
jobs.extend(
[
(
name,
job.progress,
False,
job.finished,
job.cancelled,
job.failed,
job.steps,
job.status,
)
for name, job in self.running_jobs.items()
],
"total": [
]
)
jobs.extend(
[
(
job.job,
job.steps,
job.status,
)
for job in self.finished_jobs
]
)
return {
"jobs": jobs,
"workers": [
(
device,
total,
self.workers[device].is_alive(),
False,
False,
False,
)
for device, total in self.total_jobs.items()
],
@ -476,20 +494,18 @@ class DevicePoolExecutor:
self.cancelled_jobs.remove(progress.job)
def update_job(self, progress: ProgressCommand):
if progress.finished:
if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]:
return self.finish_job(progress)
# move from pending to running
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
logger.debug("progress update for job: %s to %s", progress.job, progress.steps)
self.running_jobs[progress.job] = progress
self.pending_jobs[:] = [
job for job in self.pending_jobs if job.name != progress.job
]
# increment job counter if this is the start of a new job
if progress.progress == 0:
if progress.steps == 0:
if progress.device in self.total_jobs:
self.total_jobs[progress.device] += 1
else:

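Taken together: submit now records a job type, and status/summary report JobStatus values instead of boolean tuples. A hedged usage sketch, assuming the pool, server, and params objects from the surrounding code:

    pool.submit("example-job", JobType.TXT2IMG, run_txt2img_pipeline, server, params, size, upscale, highres)

    status, progress = pool.status("example-job")
    if status == JobStatus.RUNNING and progress is not None:
        print(progress.steps)

    print(pool.summary()["jobs"])  # [(name, steps, JobStatus), ...]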
View File

@ -57,7 +57,7 @@ def worker_main(
logger.info("worker %s got job: %s", worker.device.device, job.name)
# clear flags and save the job name
worker.start(job.name)
worker.start(job)
logger.info("starting job: %s", job.name)
# reset progress, which does a final check for cancellation

View File

@ -1,12 +1,6 @@
from diffusers import OnnxStableDiffusionPipeline
from os import path
import cv2
import numpy as np
import onnxruntime as ort
import torch
import time
cfg = 8
steps = 22
height = 512

View File

@ -22,6 +22,7 @@ from onnx_web.params import (
UpscaleParams,
)
from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobCommand
from onnx_web.worker.context import WorkerContext
from tests.helpers import (
TEST_MODEL_DIFFUSION_SD15,
@ -57,7 +58,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline(
worker,
@ -72,7 +73,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
1,
),
Size(256, 256),
["test-txt2img-basic.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
@ -103,7 +103,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline(
worker,
@ -119,7 +119,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
batch=2,
),
Size(256, 256),
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
)
@ -152,7 +151,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline(
worker,
@ -168,7 +167,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
unet_tile=256,
),
Size(256, 256),
["test-txt2img-highres.png"],
UpscaleParams("test", scale=2, outscale=2),
HighresParams(True, 2, 0, 0),
)
@ -198,7 +196,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
run_txt2img_pipeline(
worker,
@ -214,7 +212,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
batch=2,
),
Size(256, 256),
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
UpscaleParams("test"),
HighresParams(True, 2, 0, 0),
)
@ -230,7 +227,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
def test_basic(self):
worker = test_worker()
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black")
run_img2img_pipeline(
@ -245,7 +242,6 @@ class TestImg2ImgPipeline(unittest.TestCase):
1,
1,
),
["test-img2img.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
@ -259,7 +255,7 @@ class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_white(self):
worker = test_worker()
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "white")
@ -277,7 +273,6 @@ class TestInpaintPipeline(unittest.TestCase):
unet_tile=64,
),
Size(*source.size),
["test-inpaint-white.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
@ -296,7 +291,7 @@ class TestInpaintPipeline(unittest.TestCase):
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
def test_basic_black(self):
worker = test_worker()
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black")
mask = Image.new("RGB", (64, 64), "black")
@ -314,7 +309,6 @@ class TestInpaintPipeline(unittest.TestCase):
unet_tile=64,
),
Size(*source.size),
["test-inpaint-black.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
@ -353,7 +347,7 @@ class TestUpscalePipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {}))
source = Image.new("RGB", (64, 64), "black")
run_upscale_pipeline(
@ -369,7 +363,6 @@ class TestUpscalePipeline(unittest.TestCase):
1,
),
Size(256, 256),
["test-upscale.png"],
UpscaleParams("test"),
HighresParams(False, 1, 0, 0),
source,
@ -399,7 +392,7 @@ class TestBlendPipeline(unittest.TestCase):
3,
0.1,
)
worker.start("test")
worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {}))
source = Image.new("RGBA", (64, 64), "black")
mask = Image.new("RGBA", (64, 64), "white")
@ -417,7 +410,6 @@ class TestBlendPipeline(unittest.TestCase):
unet_tile=64,
),
Size(64, 64),
["test-blend.png"],
UpscaleParams("test"),
[source, source],
mask,

View File

@ -5,6 +5,7 @@ from typing import Optional
from onnx_web.params import DeviceParams
from onnx_web.server.context import ServerContext
from onnx_web.worker.command import JobStatus
from onnx_web.worker.pool import DevicePoolExecutor
TEST_JOIN_TIMEOUT = 0.2
@ -50,11 +51,11 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start()
self.pool.submit("test", wait_job, lock=lock)
self.assertEqual(self.pool.done("test"), (True, None))
self.pool.submit("test", "test", wait_job, lock=lock)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
self.assertTrue(self.pool.cancel("test"))
self.assertEqual(self.pool.done("test"), (False, None))
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
def test_cancel_running(self):
pass
@ -88,12 +89,14 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor(
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
lock.clear()
self.pool.start(lock)
self.pool.submit("test", test_job)
self.pool.submit("test", "test", test_job)
sleep(5.0)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
status, _progress = self.pool.status("test")
self.assertEqual(status, JobStatus.RUNNING)
def test_done_pending(self):
device = DeviceParams("cpu", "CPUProvider")
@ -102,9 +105,9 @@ class TestWorkerPool(unittest.TestCase):
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
self.pool.start(lock)
self.pool.submit("test1", test_job)
self.pool.submit("test2", test_job)
self.assertTrue(self.pool.done("test2"), (True, None))
self.pool.submit("test1", "test", test_job)
self.pool.submit("test2", "test", test_job)
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
lock.set()
@ -119,12 +122,12 @@ class TestWorkerPool(unittest.TestCase):
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
)
self.pool.start()
self.pool.submit("test", wait_job)
self.assertEqual(self.pool.done("test"), (True, None))
self.pool.submit("test", "test", wait_job)
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
sleep(5.0)
pending, _progress = self.pool.done("test")
self.assertFalse(pending)
status, _progress = self.pool.status("test")
self.assertEqual(status, JobStatus.SUCCESS)
def test_recycle_live(self):
pass

View File

@ -40,7 +40,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_interrupt, [], {})
job = JobCommand("test", "test", "test", main_interrupt, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
@ -75,7 +75,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_retry, [], {})
job = JobCommand("test", "test", "test", main_retry, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()
@ -144,7 +144,7 @@ class WorkerMainTests(unittest.TestCase):
nonlocal status
status = exit_status
job = JobCommand("test", "test", main_memory, [], {})
job = JobCommand("test", "test", "test", main_memory, [], {})
cancel = Value("L", False)
logs = Queue()
pending = Queue()

View File

@ -3,7 +3,7 @@
onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you
can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and
Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and
servers with a seamless user experience.
multi-GPU servers with a seamless user experience.
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers,
including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each
@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features:
- includes both the API and GUI bundle in a single container
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
## Contents
- [onnx-web](#onnx-web)
- [Features](#features)
- [Contents](#contents)
- [Setup](#setup)
- [Adding your own models](#adding-your-own-models)
- [Usage](#usage)
- [Known errors and solutions](#known-errors-and-solutions)
- [Running the containers](#running-the-containers)
- [Credits](#credits)
## Setup
There are a few ways to run onnx-web:

View File

@ -4,10 +4,7 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
import { ServerParams } from '../config.js';
import {
FilterResponse,
ImageResponse,
ImageResponseWithRetry,
ModelResponse,
ReadyResponse,
RetryParams,
WriteExtrasResponse,
} from '../types/api.js';
@ -27,6 +24,7 @@ import {
} from '../types/params.js';
import { range } from '../utils.js';
import { ApiClient } from './base.js';
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
/**
* Fixed precision for integer parameters.
@ -43,8 +41,9 @@ export const FIXED_INTEGER = 0;
export const FIXED_FLOAT = 2;
export const STATUS_SUCCESS = 200;
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean {
return a.outputs === b.outputs;
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
return a.name === b.name && a.status === b.status && a.type === b.type;
// return a.outputs === b.outputs;
}
/**
@ -141,8 +140,8 @@ export function appendHighresToURL(url: URL, highres: HighresParams) {
* Make an API client using the given API root and fetch client.
*/
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
return f(url, options).then((res) => parseApiResponse(root, res));
function parseRequest(url: URL, options: RequestInit): Promise<JobResponse> {
return f(url, options).then((res) => parseJobResponse(root, res));
}
return {
@ -218,7 +217,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const res = await f(path);
return await res.json() as Array<string>;
},
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'img2img', params);
appendModelToURL(url, model);
@ -240,12 +239,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const body = new FormData();
body.append('source', params.source, 'source');
const image = await parseRequest(url, {
const job = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
job,
retry: {
type: 'img2img',
model,
@ -254,7 +253,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
},
};
},
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'txt2img', params);
appendModelToURL(url, model);
@ -274,11 +273,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
appendHighresToURL(url, highres);
}
const image = await parseRequest(url, {
const job = await parseRequest(url, {
method: 'POST',
});
return {
image,
job,
retry: {
type: 'txt2img',
model,
@ -288,7 +287,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
},
};
},
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -309,12 +308,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
const image = await parseRequest(url, {
const job = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
job,
retry: {
type: 'inpaint',
model,
@ -323,7 +322,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
},
};
},
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeImageURL(root, 'inpaint', params);
appendModelToURL(url, model);
@ -361,12 +360,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append('mask', params.mask, 'mask');
body.append('source', params.source, 'source');
const image = await parseRequest(url, {
const job = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
job,
retry: {
type: 'outpaint',
model,
@ -375,7 +374,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
},
};
},
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
const url = makeApiUrl(root, 'upscale');
appendModelToURL(url, model);
@ -396,12 +395,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const body = new FormData();
body.append('source', params.source, 'source');
const image = await parseRequest(url, {
const job = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
job,
retry: {
type: 'upscale',
model,
@ -410,7 +409,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
},
};
},
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry> {
const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model);
@ -426,12 +425,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
body.append(name, params.sources[i], name);
}
const image = await parseRequest(url, {
const job = await parseRequest(url, {
body,
method: 'POST',
});
return {
image,
job,
retry: {
type: 'blend',
model,
@ -440,8 +439,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}
};
},
async chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse> {
const url = makeApiUrl(root, 'chain');
async chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse> {
const url = makeApiUrl(root, 'job');
const body = JSON.stringify({
...chain,
platform: model.platform,
@ -456,23 +455,23 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
method: 'POST',
});
},
async ready(key: string): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
path.searchParams.append('output', key);
async status(keys: Array<string>): Promise<Array<JobResponse>> {
const path = makeApiUrl(root, 'job', 'status');
path.searchParams.append('jobs', keys.join(','));
const res = await f(path);
return await res.json() as ReadyResponse;
return await res.json() as Array<JobResponse>;
},
async cancel(key: string): Promise<boolean> {
const path = makeApiUrl(root, 'cancel');
path.searchParams.append('output', key);
async cancel(keys: Array<string>): Promise<Array<JobResponse>> {
const path = makeApiUrl(root, 'job', 'cancel');
path.searchParams.append('jobs', keys.join(','));
const res = await f(path, {
method: 'PUT',
});
return res.status === STATUS_SUCCESS;
return await res.json() as Array<JobResponse>;
},
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
async retry(retry: RetryParams): Promise<JobResponseWithRetry> {
switch (retry.type) {
case 'blend':
return this.blend(retry.model, retry.params, retry.upscale);
@ -491,7 +490,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}
},
async restart(): Promise<boolean> {
const path = makeApiUrl(root, 'restart');
const path = makeApiUrl(root, 'worker', 'restart');
if (doesExist(token)) {
path.searchParams.append('token', token);
@ -502,8 +501,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
});
return res.status === STATUS_SUCCESS;
},
async status(): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'status');
async workers(): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'worker', 'status');
if (doesExist(token)) {
path.searchParams.append('token', token);
@ -512,6 +511,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
const res = await f(path);
return res.json();
},
outputURL(image: SuccessJobResponse, index: number): string {
return new URL(joinPath('output', image.outputs[index]), root).toString();
},
};
}
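For context, a minimal sketch of how the new batched endpoints might be consumed from calling code; the server root and job names are placeholders, and the import paths assume the GUI's module layout rather than anything in this commit:

import { makeClient } from './client.js';
import { JobStatus } from '../types/api-v2.js';

// placeholder root URL, for illustration only
const client = makeClient('http://localhost:5000');

async function cancelPending(names: Array<string>): Promise<void> {
  // a single status() call now covers every named job
  const jobs = await client.status(names);

  // collect anything that has not started running yet
  const pending = jobs
    .filter((job) => job.status === JobStatus.PENDING)
    .map((job) => job.name);

  // a single cancel() call covers all of them
  if (pending.length > 0) {
    await client.cancel(pending);
  }
}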
@ -521,24 +523,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
* The server sends over the output key, and the client is in the best position to turn
* that into a full URL, since it already knows the root URL of the server.
*/
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
type LimitedResponse = Omit<ImageResponse, 'outputs'> & { outputs: Array<string> };
export async function parseJobResponse(root: string, res: Response): Promise<JobResponse> {
if (res.status === STATUS_SUCCESS) {
const data = await res.json() as LimitedResponse;
const outputs = data.outputs.map((output) => {
const url = new URL(joinPath('output', output), root).toString();
return {
key: output,
url,
};
});
return {
...data,
outputs,
};
return await res.json() as JobResponse;
} else {
throw new Error('request error');
}
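A quick sketch of the simplified parsing path, test-style; the Response construction assumes a standard fetch environment, and the body is a hand-written object matching the JobResponse shape defined later in this diff:

import { parseJobResponse } from './client.js'; // assumed path

// a fake 200 response whose JSON body matches JobResponse
const body = JSON.stringify({
  name: 'job-test',
  status: 'pending',
  type: 'txt2img',
  stages: { current: 0, total: 1 },
  steps: { current: 0, total: 30 },
  tiles: { current: 0, total: 1 },
  queue: { current: 1, total: 2 },
});

const job = await parseJobResponse('http://localhost:5000', new Response(body, { status: 200 }));
// job.name === 'job-test'; a non-success status code throws 'request error' instead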

View File

@ -1,12 +1,19 @@
import { ServerParams } from '../config.js';
import { ExtrasFile } from '../types/model.js';
import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
import { ChainPipeline } from '../types/chain.js';
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
export interface ApiClient {
/**
* Get the first extras file.
*/
extras(): Promise<ExtrasFile>;
/**
* Update the first extras file.
*/
writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;
/**
@ -51,54 +58,60 @@ export interface ApiClient {
translation: Record<string, string>;
}>>;
/**
* Get the available wildcards.
*/
wildcards(): Promise<Array<string>>;
/**
* Start a txt2img pipeline.
*/
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/**
* Start an im2img pipeline.
*/
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/**
* Start an inpaint pipeline.
*/
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/**
* Start an outpaint pipeline.
*/
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/**
* Start an upscale pipeline.
*/
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
/**
* Start a blending pipeline.
*/
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry>;
chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse>;
/**
* Start a custom chain pipeline.
*/
chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse>;
/**
* Check whether one or more jobs have finished and their outputs are ready.
*/
ready(key: string): Promise<ReadyResponse>;
status(keys: Array<string>): Promise<Array<JobResponse>>;
/**
* Cancel one or more existing jobs.
*/
cancel(key: string): Promise<boolean>;
cancel(keys: Array<string>): Promise<Array<JobResponse>>;
/**
* Retry a previous job using the same parameters.
*/
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
retry(params: RetryParams): Promise<JobResponseWithRetry>;
/**
* Restart the image job workers.
@ -108,5 +121,7 @@ export interface ApiClient {
/**
* Check the status of the image job workers.
*/
status(): Promise<Array<unknown>>;
workers(): Promise<Array<unknown>>;
outputURL(image: SuccessJobResponse, index: number): string;
}
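Since status() now takes a batch of keys, one request can cover a whole history page. A sketch of a wait-until-terminal helper built on this interface; the poll interval, the terminal-state set, and the import paths are assumptions, not part of this commit:

import { ApiClient } from './types.js'; // hypothetical path for this interface
import { JobResponse, JobStatus } from '../types/api-v2.js';

const TERMINAL = new Set<JobStatus>([JobStatus.SUCCESS, JobStatus.FAILED, JobStatus.CANCELLED]);

export async function waitForJobs(client: ApiClient, keys: Array<string>, pollTime = 5000): Promise<Array<JobResponse>> {
  for (;;) {
    // one batched request for every tracked job
    const jobs = await client.status(keys);
    if (jobs.every((job) => TERMINAL.has(job.status))) {
      return jobs;
    }
    // sleep between polls, mirroring the GUI's refetchInterval
    await new Promise((resolve) => setTimeout(resolve, pollTime));
  }
}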

View File

@ -48,7 +48,7 @@ export const LOCAL_CLIENT = {
async params() {
throw new NoServerError();
},
async ready(key) {
async status(key) {
throw new NoServerError();
},
async cancel(key) {
@ -78,7 +78,10 @@ export const LOCAL_CLIENT = {
async restart() {
throw new NoServerError();
},
async status() {
async workers() {
throw new NoServerError();
}
},
outputURL(image, index) {
throw new NoServerError();
},
} as ApiClient;

View File

@ -97,11 +97,19 @@ export function expandRanges(range: string): Array<string | number> {
export const GRID_TILE_SIZE = 8192;
// eslint-disable-next-line max-params
export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
export function makeTxt2ImgGridPipeline(
grid: PipelineGrid,
model: ModelParams,
params: Txt2ImgParams,
upscale?: UpscaleParams,
highres?: HighresParams,
): ChainPipeline {
const pipeline: ChainPipeline = {
defaults: {
...model,
...params,
...(upscale || {}),
...(highres || {}),
},
stages: [],
};

View File

@ -10,6 +10,7 @@ import { OnnxState, StateContext } from '../state/full.js';
import { ErrorCard } from './card/ErrorCard.js';
import { ImageCard } from './card/ImageCard.js';
import { LoadingCard } from './card/LoadingCard.js';
import { JobStatus } from '../types/api-v2.js';
export function ImageHistory() {
const store = mustExist(useContext(StateContext));
@ -25,19 +26,19 @@ export function ImageHistory() {
const limited = history.slice(0, limit);
for (const item of limited) {
const key = item.image.outputs[0].key;
if (doesExist(item.ready) && item.ready.ready) {
if (item.ready.cancelled || item.ready.failed) {
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
continue;
}
const key = item.image.name;
switch (item.image.status) {
case JobStatus.SUCCESS:
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
continue;
break;
case JobStatus.FAILED:
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} retry={item.retry} />]);
break;
default:
children.push([key, <LoadingCard key={`history-${key}`} image={item.image} />]);
break;
}
children.push([key, <LoadingCard key={`history-${key}`} index={0} image={item.image} />]);
}
return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;

View File

@ -10,16 +10,15 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
import { FailedJobResponse, RetryParams } from '../../types/api-v2.js';
export interface ErrorCardProps {
image: ImageResponse;
ready: ReadyResponse;
image: FailedJobResponse;
retry: Maybe<RetryParams>;
}
export function ErrorCard(props: ErrorCardProps) {
const { image, ready, retry: retryParams } = props;
const { image, retry: retryParams } = props;
const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));
@ -32,8 +31,8 @@ export function ErrorCard(props: ErrorCardProps) {
removeHistory(image);
if (doesExist(retryParams)) {
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry);
const { job: nextJob, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextJob, nextRetry);
}
}
@ -52,10 +51,11 @@ export function ErrorCard(props: ErrorCardProps) {
spacing={2}
sx={{ alignItems: 'center' }}
>
<Alert severity='error'>{t('loading.progress', {
current: ready.progress,
total: image.params.steps,
})}</Alert>
<Alert severity='error'>
{t('loading.progress', { steps: image.steps, stages: image.stages, tiles: image.tiles })}
<br />
{image.error}
</Alert>
<Stack direction='row' spacing={2}>
<Tooltip title={t('tooltip.retry')}>
<IconButton onClick={() => retry.mutate()}>

View File

@ -2,21 +2,22 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
import * as React from 'react';
import { useContext, useState } from 'react';
import { useContext, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { range, visibleIndex } from '../../utils.js';
import { BLEND_SOURCES } from '../../constants.js';
import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
import { getApiRoot } from '../../config.js';
export interface ImageCardProps {
image: ImageResponse;
image: SuccessJobResponse;
onDelete?: (key: ImageResponse) => void;
onDelete?: (key: JobResponse) => void;
}
export function GridItem(props: { xs: number; children: React.ReactNode }) {
@ -27,18 +28,19 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {
export function ImageCard(props: ImageCardProps) {
const { image } = props;
const { params, outputs, size } = image;
const { metadata, outputs } = image;
const [_hash, setHash] = useHash();
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
const client = mustExist(useContext(ClientContext));
const config = mustExist(useContext(ConfigContext));
const store = mustExist(useContext(StateContext));
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
async function loadSource() {
const req = await fetch(outputs[index].url);
const req = await fetch(url);
return req.blob();
}
@ -84,12 +86,12 @@ export function ImageCard(props: ImageCardProps) {
}
function downloadImage() {
window.open(outputs[index].url, '_blank');
window.open(url, '_blank');
close();
}
function downloadMetadata() {
window.open(outputs[index].url + '.json', '_blank');
window.open(url + '.json', '_blank');
close();
}
@ -106,14 +108,16 @@ export function ImageCard(props: ImageCardProps) {
return mustDefault(t(`${key}.${name}`), name);
}
const model = getLabel('model', params.model);
const scheduler = getLabel('scheduler', params.scheduler);
const url = useMemo(() => client.outputURL(image, index), [image, index]);
const model = getLabel('model', metadata[index].model);
const scheduler = getLabel('scheduler', metadata[index].scheduler);
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
<CardMedia sx={{ height: config.params.height.default }}
component='img'
image={outputs[index].url}
title={params.prompt}
image={url}
title={metadata[index].params.prompt}
/>
<CardContent>
<Box textAlign='center'>
@ -146,12 +150,12 @@ export function ImageCard(props: ImageCardProps) {
</GridItem>
<GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
<GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
<GridItem xs={4}>{t('parameter.seed')}: {params.seed}</GridItem>
<GridItem xs={4}>{t('parameter.cfg')}: {params.cfg}</GridItem>
<GridItem xs={4}>{t('parameter.steps')}: {params.steps}</GridItem>
<GridItem xs={4}>{t('parameter.size')}: {size.width}x{size.height}</GridItem>
<GridItem xs={4}>{t('parameter.seed')}: {metadata[index].params.seed}</GridItem>
<GridItem xs={4}>{t('parameter.cfg')}: {metadata[index].params.cfg}</GridItem>
<GridItem xs={4}>{t('parameter.steps')}: {metadata[index].params.steps}</GridItem>
<GridItem xs={4}>{t('parameter.size')}: {metadata[index].size.width}x{metadata[index].size.height}</GridItem>
<GridItem xs={12}>
<Box textAlign='left'>{params.prompt}</Box>
<Box textAlign='left'>{metadata[index].params.prompt}</Box>
</GridItem>
<GridItem xs={2}>
<Tooltip title={t('tooltip.save')}>

View File

@ -1,4 +1,4 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
import { Stack } from '@mui/system';
import { useMutation, useQuery } from '@tanstack/react-query';
@ -10,19 +10,17 @@ import { shallow } from 'zustand/shallow';
import { POLL_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js';
import { JobResponse, JobStatus } from '../../types/api-v2.js';
const LOADING_PERCENT = 100;
const LOADING_OVERAGE = 99;
export interface LoadingCardProps {
image: ImageResponse;
index: number;
image: JobResponse;
}
export function LoadingCard(props: LoadingCardProps) {
const { image, index } = props;
const { steps } = props.image.params;
const { image } = props;
const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));
@ -31,50 +29,22 @@ export function LoadingCard(props: LoadingCardProps) {
const { removeHistory, setReady } = useStore(store, selectActions, shallow);
const { t } = useTranslation();
const cancel = useMutation(() => client.cancel(image.outputs[index].key));
const ready = useQuery(['ready', image.outputs[index].key], () => client.ready(image.outputs[index].key), {
const cancel = useMutation(() => client.cancel([image.name]));
const ready = useQuery(['ready', image.name], () => client.status([image.name]), {
// data will always be ready without this, even if the API says it's not
cacheTime: 0,
refetchInterval: POLL_TIME,
});
function getProgress() {
if (doesExist(ready.data)) {
return ready.data.progress;
}
return 0;
}
function getPercent() {
const progress = getProgress();
if (progress > steps) {
// steps was not complete, show 99% until done
return LOADING_OVERAGE;
}
const pct = progress / steps;
return Math.ceil(pct * LOADING_PERCENT);
}
function getTotal() {
const progress = getProgress();
if (progress > steps) {
// steps was not complete, show 99% until done
return t('loading.unknown');
}
return steps.toFixed(0);
}
function getReady() {
return doesExist(ready.data) && ready.data.ready;
return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
}
function renderProgress() {
const progress = getProgress();
if (progress > 0 && progress <= steps) {
return <CircularProgress variant='determinate' value={getPercent()} />;
const progress = getProgress(ready.data);
const total = getTotal(ready.data);
if (progress > 0 && progress <= total) {
return <CircularProgress variant='determinate' value={getPercent(progress, total)} />;
} else {
return <CircularProgress />;
}
@ -88,9 +58,9 @@ export function LoadingCard(props: LoadingCardProps) {
useEffect(() => {
if (ready.status === 'success' && getReady()) {
setReady(props.image, ready.data);
setReady(ready.data[0]);
}
}, [ready.status, getReady(), getProgress()]);
}, [ready.status, getReady(), getProgress(ready.data)]);
return <Card sx={{ maxWidth: params.width.default }}>
<CardContent sx={{ height: params.height.default }}>
@ -106,10 +76,7 @@ export function LoadingCard(props: LoadingCardProps) {
sx={{ alignItems: 'center' }}
>
{renderProgress()}
<Typography>{t('loading.progress', {
current: getProgress(),
total: getTotal(),
})}</Typography>
<Typography>{t('loading.progress', selectStatus(ready.data, image))}</Typography>
<Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button>
</Stack>
</Box>
@ -125,3 +92,45 @@ export function selectActions(state: OnnxState) {
setReady: state.setReady,
};
}
export function selectStatus(data: Maybe<Array<JobResponse>>, defaultData: JobResponse) {
if (doesExist(data) && data.length > 0) {
return {
steps: data[0].steps,
stages: data[0].stages,
tiles: data[0].tiles,
};
}
return {
steps: defaultData.steps,
stages: defaultData.stages,
tiles: defaultData.tiles,
};
}
export function getPercent(current: number, total: number): number {
if (current > total) {
// current exceeded the expected total, so hold at 99% until the job reports done
return LOADING_OVERAGE;
}
const pct = current / total;
return Math.ceil(pct * LOADING_PERCENT);
}
export function getProgress(data: Maybe<Array<JobResponse>>) {
if (doesExist(data)) {
return data[0].steps.current;
}
return 0;
}
export function getTotal(data: Maybe<Array<JobResponse>>) {
if (doesExist(data)) {
return data[0].steps.total;
}
return 0;
}
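Pulling these helpers out of the component makes them pure functions that can be exercised directly; a few sample calls with arbitrary values, assuming the import path:

import { getPercent, getProgress, getTotal } from './LoadingCard.js';

getPercent(15, 30); // 50
getPercent(45, 30); // 99: current overran the expected total, held until done
getProgress(undefined); // 0, no status data yet
getTotal(undefined); // 0, so renderProgress falls back to the indeterminate spinner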

View File

@ -20,13 +20,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { blend, blendModel, blendUpscale } = store.getState();
const { image, retry } = await client.blend(blendModel, {
const { job, retry } = await client.blend(blendModel, {
...blend,
mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, blendUpscale);
pushHistory(image, retry);
pushHistory(job, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -27,12 +27,12 @@ export function Img2Img() {
const state = store.getState();
const img2img = selectParams(state);
const { image, retry } = await client.img2img(model, {
const { job, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
pushHistory(job, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -39,22 +39,22 @@ export function Inpaint() {
const inpaint = selectParams(state);
if (outpaint.enabled) {
const { image, retry } = await client.outpaint(model, {
const { job, retry } = await client.outpaint(model, {
...inpaint,
...outpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
pushHistory(job, retry);
} else {
const { image, retry } = await client.inpaint(model, {
const { job, retry } = await client.inpaint(model, {
...inpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
pushHistory(job, retry);
}
}

View File

@ -69,8 +69,8 @@ export function Txt2Img() {
const image = await client.chain(model, chain);
pushHistory(image);
} else {
const { image, retry } = await client.txt2img(model, params2, upscale, highres);
pushHistory(image, retry);
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
pushHistory(job, retry);
}
}

View File

@ -21,12 +21,12 @@ import { PromptInput } from '../input/PromptInput.js';
export function Upscale() {
async function uploadSource() {
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
const { image, retry } = await client.upscale(upscaleModel, {
const { job, retry } = await client.upscale(upscaleModel, {
...upscale,
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
}, upscaleUpscale, upscaleHighres);
pushHistory(image, retry);
pushHistory(job, retry);
}
const client = mustExist(useContext(ClientContext));

View File

@ -2,6 +2,7 @@ import { Maybe } from '@apextoaster/js-utils';
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
import { Slice } from './types.js';
import { DEFAULT_HISTORY } from '../constants.js';
import { JobResponse } from '../types/api-v2.js';
export interface HistoryItem {
image: ImageResponse;
@ -9,14 +10,19 @@ export interface HistoryItem {
retry: Maybe<RetryParams>;
}
export interface HistoryItemV2 {
image: JobResponse;
retry: Maybe<RetryParams>;
}
export interface HistorySlice {
history: Array<HistoryItem>;
history: Array<HistoryItemV2>;
limit: number;
pushHistory(image: ImageResponse, retry?: RetryParams): void;
removeHistory(image: ImageResponse): void;
pushHistory(image: JobResponse, retry?: RetryParams): void;
removeHistory(image: JobResponse): void;
setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void;
setReady(image: JobResponse): void;
}
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
@ -39,7 +45,7 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
removeHistory(image) {
set((prev) => ({
...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
history: prev.history.filter((it) => it.image.name !== image.name),
}));
},
setLimit(limit) {
@ -48,12 +54,12 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
limit,
}));
},
setReady(image, ready) {
setReady(image) {
set((prev) => {
const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
const idx = history.findIndex((it) => it.image.name === image.name);
if (idx >= 0) {
history[idx].ready = ready;
history[idx].image = image;
} else {
// TODO: error
}

View File

@ -67,7 +67,7 @@ export const I18N_STRINGS_EN = {
},
loading: {
cancel: 'Cancel',
progress: '{{current}} of {{total}} steps',
progress: '{{steps.current}} of {{steps.total}} steps, {{tiles.current}} of {{tiles.total}} tiles, {{stages.current}} of {{stages.total}} stages',
server: 'Connecting to server...',
unknown: 'many',
},
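The new string relies on i18next's nested interpolation, so the params object has to carry the same Progress shapes the job responses use; a sketch with arbitrary values, assuming i18next has been initialized with these strings:

import i18next from 'i18next';

// shapes match selectStatus() in LoadingCard
i18next.t('loading.progress', {
  steps: { current: 5, total: 30 },
  tiles: { current: 1, total: 4 },
  stages: { current: 1, total: 2 },
});
// -> '5 of 30 steps, 1 of 4 tiles, 1 of 2 stages'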

gui/src/types/api-v2.ts Normal file
View File

@ -0,0 +1,160 @@
import { RetryParams } from './api.js';
import { BaseImgParams, HighresParams, Img2ImgParams, InpaintParams, Txt2ImgParams, UpscaleParams } from './params.js';
export interface Progress {
current: number;
total: number;
}
export interface Size {
width: number;
height: number;
}
export interface NetworkMetadata {
name: string;
hash: string;
weight: number;
}
export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
input_size: Size;
size: Size;
outputs: Array<string>;
params: TParams;
inversions: Array<NetworkMetadata>;
loras: Array<NetworkMetadata>;
model: string;
scheduler: string;
border: unknown;
highres: HighresParams;
upscale: UpscaleParams;
type: TType;
}
export enum JobStatus {
PENDING = 'pending',
RUNNING = 'running',
SUCCESS = 'success',
FAILED = 'failed',
CANCELLED = 'cancelled',
UNKNOWN = 'unknown',
}
export enum JobType {
TXT2IMG = 'txt2img',
IMG2IMG = 'img2img',
INPAINT = 'inpaint',
UPSCALE = 'upscale',
BLEND = 'blend',
CHAIN = 'chain',
}
export interface BaseJobResponse {
name: string;
status: JobStatus;
type: JobType;
stages: Progress;
steps: Progress;
tiles: Progress;
}
/**
* Pending or running image job.
*/
export interface PendingJobResponse extends BaseJobResponse {
status: JobStatus.PENDING | JobStatus.RUNNING;
queue: Progress;
}
/**
* Failed image job with error information.
*/
export interface FailedJobResponse extends BaseJobResponse {
status: JobStatus.FAILED;
error: string;
}
/**
* Successful txt2img image job with output keys and metadata.
*/
export interface SuccessTxt2ImgJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.TXT2IMG;
outputs: Array<string>;
metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
}
/**
* Successful img2img job with output keys and metadata.
*/
export interface SuccessImg2ImgJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.IMG2IMG;
outputs: Array<string>;
metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
}
/**
* Successful inpaint job with output keys and metadata.
*/
export interface SuccessInpaintJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.INPAINT;
outputs: Array<string>;
metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
}
/**
* Successful upscale job with output keys and metadata.
*/
export interface SuccessUpscaleJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.UPSCALE;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
}
/**
* Successful blend job with output keys and metadata.
*/
export interface SuccessBlendJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.BLEND;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
}
/**
* Successful chain pipeline job with output keys and metadata.
*/
export interface SuccessChainJobResponse extends BaseJobResponse {
status: JobStatus.SUCCESS;
type: JobType.CHAIN;
outputs: Array<string>;
metadata: Array<ImageMetadata<BaseImgParams, JobType>>; // TODO: could be all kinds
}
export type SuccessJobResponse
= SuccessTxt2ImgJobResponse
| SuccessImg2ImgJobResponse
| SuccessInpaintJobResponse
| SuccessUpscaleJobResponse
| SuccessBlendJobResponse
| SuccessChainJobResponse;
export type JobResponse = PendingJobResponse | FailedJobResponse | SuccessJobResponse;
/**
* Status response from the job endpoint, with parameters to retry the job if it fails.
*/
export interface JobResponseWithRetry {
job: JobResponse;
retry: RetryParams;
}
/**
* Re-export `RetryParams` for convenience.
*/
export { RetryParams };
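Because JobResponse is a discriminated union on status, narrowing helpers fall out naturally; these guards are hypothetical (not part of this commit) but show the pattern the cards above rely on:

import { FailedJobResponse, JobResponse, JobStatus, SuccessJobResponse } from './api-v2.js';

export function isSuccess(job: JobResponse): job is SuccessJobResponse {
  return job.status === JobStatus.SUCCESS;
}

export function isFailed(job: JobResponse): job is FailedJobResponse {
  return job.status === JobStatus.FAILED;
}

// narrowing grants typed access to outputs or error without casts
export function describeJob(job: JobResponse): string {
  if (isSuccess(job)) {
    return `${job.outputs.length} outputs ready`;
  }
  if (isFailed(job)) {
    return job.error;
  }
  return `${job.steps.current} of ${job.steps.total} steps`;
}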

View File

@ -14,6 +14,8 @@ import {
/**
* Output image data within the response.
*
* @deprecated
*/
export interface ImageOutput {
key: string;
@ -22,6 +24,8 @@ export interface ImageOutput {
/**
* General response for most image requests.
*
* @deprecated
*/
export interface ImageResponse {
outputs: Array<ImageOutput>;
@ -119,11 +123,19 @@ export type RetryParams = {
upscale?: UpscaleParams;
};
/**
* Status response from the image endpoint, with parameters to retry the job if it fails.
*
* @deprecated
*/
export interface ImageResponseWithRetry {
image: ImageResponse;
retry: RetryParams;
}
/**
* @deprecated
*/
export interface ImageMetadata {
highres: HighresParams;
outputs: string | Array<string>;

View File

@ -43,6 +43,7 @@
"dtype",
"ESRGAN",
"Exif",
"fromarray",
"ftfy",
"gfpgan",
"Heun",
@ -115,6 +116,7 @@
"webp",
"xformers",
"zustand"
]
],
"git.ignoreLimitWarning": true
}
}