feat: add batch endpoints for cancel and status, update responses
This commit is contained in:
parent
19c91f70f5
commit
44a8d61082
|
@ -36,15 +36,14 @@ class CorrectCodeformerStage(BaseStage):
|
||||||
# https://pypi.org/project/codeformer-perceptor/
|
# https://pypi.org/project/codeformer-perceptor/
|
||||||
|
|
||||||
# import must be within the load function for patches to take effect
|
# import must be within the load function for patches to take effect
|
||||||
# TODO: rewrite and remove
|
from codeformer.basicsr.archs.codeformer_arch import CodeFormer
|
||||||
from codeformer.basicsr.utils import img2tensor, tensor2img
|
from codeformer.basicsr.utils import img2tensor, tensor2img
|
||||||
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
|
||||||
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
upscale = upscale.with_args(**kwargs)
|
upscale = upscale.with_args(**kwargs)
|
||||||
device = worker.get_device()
|
device = worker.get_device()
|
||||||
|
|
||||||
net = ARCH_REGISTRY.get("CodeFormer")(
|
net = CodeFormer(
|
||||||
dim_embd=512,
|
dim_embd=512,
|
||||||
codebook_size=1024,
|
codebook_size=1024,
|
||||||
n_head=8,
|
n_head=8,
|
||||||
|
|
|
@ -1,10 +1,28 @@
|
||||||
from typing import Any, List, Optional, Tuple
|
from json import dumps
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..output import json_params
|
from ..convert.utils import resolve_tensor
|
||||||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||||
|
from ..server.load import get_extra_hashes
|
||||||
|
from ..utils import hash_file
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkMetadata:
|
||||||
|
name: str
|
||||||
|
hash: str
|
||||||
|
weight: float
|
||||||
|
|
||||||
|
def __init__(self, name: str, hash: str, weight: float) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.hash = hash
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata:
|
class ImageMetadata:
|
||||||
|
@ -13,8 +31,9 @@ class ImageMetadata:
|
||||||
params: ImageParams
|
params: ImageParams
|
||||||
size: Size
|
size: Size
|
||||||
upscale: UpscaleParams
|
upscale: UpscaleParams
|
||||||
inversions: Optional[List[Tuple[str, float]]]
|
inversions: Optional[List[NetworkMetadata]]
|
||||||
loras: Optional[List[Tuple[str, float]]]
|
loras: Optional[List[NetworkMetadata]]
|
||||||
|
models: Optional[List[NetworkMetadata]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -23,8 +42,9 @@ class ImageMetadata:
|
||||||
upscale: Optional[UpscaleParams] = None,
|
upscale: Optional[UpscaleParams] = None,
|
||||||
border: Optional[Border] = None,
|
border: Optional[Border] = None,
|
||||||
highres: Optional[HighresParams] = None,
|
highres: Optional[HighresParams] = None,
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
inversions: Optional[List[NetworkMetadata]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[NetworkMetadata]] = None,
|
||||||
|
models: Optional[List[NetworkMetadata]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.params = params
|
self.params = params
|
||||||
self.size = size
|
self.size = size
|
||||||
|
@ -33,19 +53,108 @@ class ImageMetadata:
|
||||||
self.highres = highres
|
self.highres = highres
|
||||||
self.inversions = inversions
|
self.inversions = inversions
|
||||||
self.loras = loras
|
self.loras = loras
|
||||||
|
self.models = models
|
||||||
|
|
||||||
|
def to_auto1111(self, server, outputs) -> str:
|
||||||
|
model_name = path.basename(path.normpath(self.params.model))
|
||||||
|
logger.debug("getting model hash for %s", model_name)
|
||||||
|
|
||||||
|
model_hash = get_extra_hashes().get(model_name, None)
|
||||||
|
if model_hash is None:
|
||||||
|
model_hash_path = path.join(self.params.model, "hash.txt")
|
||||||
|
if path.exists(model_hash_path):
|
||||||
|
with open(model_hash_path, "r") as f:
|
||||||
|
model_hash = f.readline().rstrip(",. \n\t\r")
|
||||||
|
|
||||||
|
model_hash = model_hash or "unknown"
|
||||||
|
hash_map = {
|
||||||
|
model_name: model_hash,
|
||||||
|
}
|
||||||
|
|
||||||
|
inversion_hashes = ""
|
||||||
|
if self.inversions is not None:
|
||||||
|
inversion_pairs = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||||
|
).upper(),
|
||||||
|
)
|
||||||
|
for name, _weight in self.inversions
|
||||||
|
]
|
||||||
|
inversion_hashes = ",".join(
|
||||||
|
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
||||||
|
)
|
||||||
|
hash_map.update(dict(inversion_pairs))
|
||||||
|
|
||||||
|
lora_hashes = ""
|
||||||
|
if self.loras is not None:
|
||||||
|
lora_pairs = [
|
||||||
|
(
|
||||||
|
name,
|
||||||
|
hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||||
|
).upper(),
|
||||||
|
)
|
||||||
|
for name, _weight in self.loras
|
||||||
|
]
|
||||||
|
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
||||||
|
hash_map.update(dict(lora_pairs))
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
|
||||||
|
f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
|
||||||
|
f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
|
||||||
|
f"Model hash: {model_hash}, Model: {model_name}, "
|
||||||
|
f"Tool: onnx-web, Version: {server.server_version}, "
|
||||||
|
f'Inversion hashes: "{inversion_hashes}", '
|
||||||
|
f'Lora hashes: "{lora_hashes}", '
|
||||||
|
f"Hashes: {dumps(hash_map)}"
|
||||||
|
)
|
||||||
|
|
||||||
def tojson(self, server, outputs):
|
def tojson(self, server, outputs):
|
||||||
return json_params(
|
json = {
|
||||||
server,
|
"input_size": self.size.tojson(),
|
||||||
outputs,
|
"outputs": outputs,
|
||||||
self.params,
|
"params": self.params.tojson(),
|
||||||
self.size,
|
"inversions": {},
|
||||||
upscale=self.upscale,
|
"loras": {},
|
||||||
border=self.border,
|
}
|
||||||
highres=self.highres,
|
|
||||||
inversions=self.inversions,
|
json["params"]["model"] = path.basename(self.params.model)
|
||||||
loras=self.loras,
|
json["params"]["scheduler"] = self.params.scheduler # TODO: why tho?
|
||||||
)
|
|
||||||
|
# calculate final output size
|
||||||
|
output_size = self.size
|
||||||
|
if self.border is not None:
|
||||||
|
json["border"] = self.border.tojson()
|
||||||
|
output_size = output_size.add_border(self.border)
|
||||||
|
|
||||||
|
if self.highres is not None:
|
||||||
|
json["highres"] = self.highres.tojson()
|
||||||
|
output_size = self.highres.resize(output_size)
|
||||||
|
|
||||||
|
if self.upscale is not None:
|
||||||
|
json["upscale"] = self.upscale.tojson()
|
||||||
|
output_size = self.upscale.resize(output_size)
|
||||||
|
|
||||||
|
json["size"] = output_size.tojson()
|
||||||
|
|
||||||
|
if self.inversions is not None:
|
||||||
|
for name, weight in self.inversions:
|
||||||
|
hash = hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||||
|
).upper()
|
||||||
|
json["inversions"][name] = {"weight": weight, "hash": hash}
|
||||||
|
|
||||||
|
if self.loras is not None:
|
||||||
|
for name, weight in self.loras:
|
||||||
|
hash = hash_file(
|
||||||
|
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||||
|
).upper()
|
||||||
|
json["loras"][name] = {"weight": weight, "hash": hash}
|
||||||
|
|
||||||
|
return json
|
||||||
|
|
||||||
|
|
||||||
class StageResult:
|
class StageResult:
|
||||||
|
@ -86,6 +195,7 @@ class StageResult:
|
||||||
self.arrays = arrays
|
self.arrays = arrays
|
||||||
self.images = images
|
self.images = images
|
||||||
self.source = source
|
self.source = source
|
||||||
|
self.metadata = []
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
if self.arrays is not None:
|
if self.arrays is not None:
|
||||||
|
@ -117,7 +227,7 @@ class StageResult:
|
||||||
elif self.images is not None:
|
elif self.images is not None:
|
||||||
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
|
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid stage result")
|
self.arrays = [array]
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.append(metadata)
|
self.metadata.append(metadata)
|
||||||
|
@ -130,13 +240,45 @@ class StageResult:
|
||||||
elif self.arrays is not None:
|
elif self.arrays is not None:
|
||||||
self.arrays.append(np.array(image))
|
self.arrays.append(np.array(image))
|
||||||
else:
|
else:
|
||||||
raise ValueError("invalid stage result")
|
self.images = [image]
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
self.metadata.append(metadata)
|
self.metadata.append(metadata)
|
||||||
else:
|
else:
|
||||||
self.metadata.append(ImageMetadata())
|
self.metadata.append(ImageMetadata())
|
||||||
|
|
||||||
|
def insert_array(
|
||||||
|
self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
|
||||||
|
):
|
||||||
|
if self.arrays is not None:
|
||||||
|
self.arrays.insert(index, array)
|
||||||
|
elif self.images is not None:
|
||||||
|
self.images.insert(
|
||||||
|
index, Image.fromarray(np.uint8(array), shape_mode(array))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.arrays = [array]
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
self.metadata.insert(index, metadata)
|
||||||
|
else:
|
||||||
|
self.metadata.insert(index, ImageMetadata())
|
||||||
|
|
||||||
|
def insert_image(
|
||||||
|
self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
|
||||||
|
):
|
||||||
|
if self.images is not None:
|
||||||
|
self.images.insert(index, image)
|
||||||
|
elif self.arrays is not None:
|
||||||
|
self.arrays.insert(index, np.array(image))
|
||||||
|
else:
|
||||||
|
self.images = [image]
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
self.metadata.insert(index, metadata)
|
||||||
|
else:
|
||||||
|
self.metadata.insert(index, ImageMetadata())
|
||||||
|
|
||||||
|
|
||||||
def shape_mode(arr: np.ndarray) -> str:
|
def shape_mode(arr: np.ndarray) -> str:
|
||||||
if len(arr.shape) != 3:
|
if len(arr.shape) != 3:
|
||||||
|
|
|
@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
|
||||||
from ..chain.result import StageResult
|
from ..chain.result import StageResult
|
||||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||||
from ..image import expand_image
|
from ..image import expand_image
|
||||||
from ..output import save_image
|
from ..output import save_image, save_result
|
||||||
from ..params import (
|
from ..params import (
|
||||||
Border,
|
Border,
|
||||||
HighresParams,
|
HighresParams,
|
||||||
|
@ -29,7 +29,7 @@ from ..server import ServerContext
|
||||||
from ..server.load import get_source_filters
|
from ..server.load import get_source_filters
|
||||||
from ..utils import is_debug, run_gc, show_system_toast
|
from ..utils import is_debug, run_gc, show_system_toast
|
||||||
from ..worker import WorkerContext
|
from ..worker import WorkerContext
|
||||||
from .utils import get_latents_from_seed, parse_prompt
|
from .utils import get_latents_from_seed
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -57,7 +57,6 @@ def run_txt2img_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
outputs: List[str],
|
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -114,50 +113,34 @@ def run_txt2img_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain(
|
||||||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||||
)
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
|
||||||
|
|
||||||
# add a thumbnail, if requested
|
# add a thumbnail, if requested
|
||||||
cover = images[0]
|
cover = images.as_image()[0]
|
||||||
if params.thumbnail and (
|
if params.thumbnail and (
|
||||||
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
||||||
):
|
):
|
||||||
thumbnail = cover.copy()
|
thumbnail = cover.copy()
|
||||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||||
|
|
||||||
images.insert(0, thumbnail)
|
images.insert_image(0, thumbnail)
|
||||||
outputs.insert(0, f"{worker.name}-thumb.{server.image_format}")
|
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
save_result(server, images, worker.job)
|
||||||
logger.trace("saving output image %s: %s", output, image.size)
|
|
||||||
dest = save_image(
|
|
||||||
server,
|
|
||||||
output,
|
|
||||||
image,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
highres=highres,
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished txt2img job: {dest}")
|
show_system_toast(f"finished txt2img job: {worker.job}")
|
||||||
logger.info("finished txt2img job: %s", dest)
|
logger.info("finished txt2img job: %s", worker.job)
|
||||||
|
|
||||||
|
|
||||||
def run_img2img_pipeline(
|
def run_img2img_pipeline(
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
outputs: List[str],
|
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
@ -228,36 +211,21 @@ def run_img2img_pipeline(
|
||||||
|
|
||||||
# run and append the filtered source
|
# run and append the filtered source
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=[source]), callback=progress
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
)
|
)
|
||||||
|
|
||||||
if source_filter is not None and source_filter != "none":
|
if source_filter is not None and source_filter != "none":
|
||||||
images.append(source)
|
images.push_image(source)
|
||||||
|
|
||||||
# save with metadata
|
save_result(server, images, worker.job)
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
|
||||||
size = Size(*source.size)
|
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
|
||||||
dest = save_image(
|
|
||||||
server,
|
|
||||||
output,
|
|
||||||
image,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
highres=highres,
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished img2img job: {dest}")
|
show_system_toast(f"finished img2img job: {worker.job}")
|
||||||
logger.info("finished img2img job: %s", dest)
|
logger.info("finished img2img job: %s", worker.job)
|
||||||
|
|
||||||
|
|
||||||
def run_inpaint_pipeline(
|
def run_inpaint_pipeline(
|
||||||
|
@ -265,7 +233,6 @@ def run_inpaint_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
outputs: List[str],
|
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
@ -290,7 +257,7 @@ def run_inpaint_pipeline(
|
||||||
mask = ImageOps.contain(mask, (mask_max, mask_max))
|
mask = ImageOps.contain(mask, (mask_max, mask_max))
|
||||||
mask = mask.crop((0, 0, source.width, source.height))
|
mask = mask.crop((0, 0, source.width, source.height))
|
||||||
|
|
||||||
source, mask, noise, full_size = expand_image(
|
source, mask, noise, _full_size = expand_image(
|
||||||
source,
|
source,
|
||||||
mask,
|
mask,
|
||||||
border,
|
border,
|
||||||
|
@ -414,7 +381,7 @@ def run_inpaint_pipeline(
|
||||||
# run and save
|
# run and save
|
||||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain(
|
||||||
worker,
|
worker,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
|
@ -423,33 +390,28 @@ def run_inpaint_pipeline(
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
|
||||||
for image, output in zip(images, outputs):
|
|
||||||
if full_res_inpaint:
|
if full_res_inpaint:
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "adjusted-output.png", image)
|
save_image(server, "adjusted-output.png", image)
|
||||||
|
|
||||||
mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
|
mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
|
||||||
image = original_source
|
image = original_source
|
||||||
image.paste(mini_image, box=adj_mask_border)
|
image.paste(mini_image, box=adj_mask_border)
|
||||||
dest = save_image(
|
|
||||||
|
save_image(
|
||||||
server,
|
server,
|
||||||
output,
|
f"{worker.job}_{i}.{server.image_format}",
|
||||||
image,
|
image,
|
||||||
params,
|
metadata,
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
border=border,
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished inpaint job: {dest}")
|
show_system_toast(f"finished inpaint job: {worker.job}")
|
||||||
logger.info("finished inpaint job: %s", dest)
|
logger.info("finished inpaint job: %s", worker.job)
|
||||||
|
|
||||||
|
|
||||||
def run_upscale_pipeline(
|
def run_upscale_pipeline(
|
||||||
|
@ -457,7 +419,6 @@ def run_upscale_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
outputs: List[str],
|
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
highres: HighresParams,
|
highres: HighresParams,
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
|
@ -497,30 +458,18 @@ def run_upscale_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=[source]), callback=progress
|
worker, server, params, StageResult(images=[source]), callback=progress
|
||||||
)
|
)
|
||||||
|
|
||||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
save_result(server, images, worker.job)
|
||||||
for image, output in zip(images, outputs):
|
|
||||||
dest = save_image(
|
|
||||||
server,
|
|
||||||
output,
|
|
||||||
image,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
inversions=inversions,
|
|
||||||
loras=loras,
|
|
||||||
)
|
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished upscale job: {dest}")
|
show_system_toast(f"finished upscale job: {worker.job}")
|
||||||
logger.info("finished upscale job: %s", dest)
|
logger.info("finished upscale job: %s", worker.job)
|
||||||
|
|
||||||
|
|
||||||
def run_blend_pipeline(
|
def run_blend_pipeline(
|
||||||
|
@ -528,7 +477,6 @@ def run_blend_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
size: Size,
|
size: Size,
|
||||||
outputs: List[str],
|
|
||||||
upscale: UpscaleParams,
|
upscale: UpscaleParams,
|
||||||
# highres: HighresParams,
|
# highres: HighresParams,
|
||||||
sources: List[Image.Image],
|
sources: List[Image.Image],
|
||||||
|
@ -559,17 +507,15 @@ def run_blend_pipeline(
|
||||||
|
|
||||||
# run and save
|
# run and save
|
||||||
progress = worker.get_progress_callback()
|
progress = worker.get_progress_callback()
|
||||||
images = chain.run(
|
images = chain(
|
||||||
worker, server, params, StageResult(images=sources), callback=progress
|
worker, server, params, StageResult(images=sources), callback=progress
|
||||||
)
|
)
|
||||||
|
|
||||||
for image, output in zip(images, outputs):
|
save_result(server, images, worker.job)
|
||||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
del image
|
|
||||||
run_gc([worker.get_device()])
|
run_gc([worker.get_device()])
|
||||||
|
|
||||||
# notify the user
|
# notify the user
|
||||||
show_system_toast(f"finished blend job: {dest}")
|
show_system_toast(f"finished blend job: {worker.job}")
|
||||||
logger.info("finished blend job: %s", dest)
|
logger.info("finished blend job: %s", worker.job)
|
||||||
|
|
|
@ -1,173 +1,20 @@
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
|
||||||
from struct import pack
|
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import List, Optional
|
||||||
|
|
||||||
from piexif import ExifIFD, ImageIFD, dump
|
from piexif import ExifIFD, ImageIFD, dump
|
||||||
from piexif.helper import UserComment
|
from piexif.helper import UserComment
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
|
|
||||||
from .convert.utils import resolve_tensor
|
from .chain.result import ImageMetadata, StageResult
|
||||||
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
from .params import ImageParams, Param, Size
|
||||||
from .server import ServerContext
|
from .server import ServerContext
|
||||||
from .server.load import get_extra_hashes
|
from .utils import base_join, hash_value
|
||||||
from .utils import base_join
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
HASH_BUFFER_SIZE = 2**22 # 4MB
|
|
||||||
|
|
||||||
|
|
||||||
def hash_file(name: str):
|
|
||||||
sha = sha256()
|
|
||||||
with open(name, "rb") as f:
|
|
||||||
while True:
|
|
||||||
data = f.read(HASH_BUFFER_SIZE)
|
|
||||||
if not data:
|
|
||||||
break
|
|
||||||
|
|
||||||
sha.update(data)
|
|
||||||
|
|
||||||
return sha.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def hash_value(sha, param: Optional[Param]):
|
|
||||||
if param is None:
|
|
||||||
return
|
|
||||||
elif isinstance(param, bool):
|
|
||||||
sha.update(bytearray(pack("!B", param)))
|
|
||||||
elif isinstance(param, float):
|
|
||||||
sha.update(bytearray(pack("!f", param)))
|
|
||||||
elif isinstance(param, int):
|
|
||||||
sha.update(bytearray(pack("!I", param)))
|
|
||||||
elif isinstance(param, str):
|
|
||||||
sha.update(param.encode("utf-8"))
|
|
||||||
else:
|
|
||||||
logger.warning("cannot hash param: %s, %s", param, type(param))
|
|
||||||
|
|
||||||
|
|
||||||
def json_params(
|
|
||||||
server: ServerContext,
|
|
||||||
outputs: List[str],
|
|
||||||
params: ImageParams,
|
|
||||||
size: Size,
|
|
||||||
upscale: Optional[UpscaleParams] = None,
|
|
||||||
border: Optional[Border] = None,
|
|
||||||
highres: Optional[HighresParams] = None,
|
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
|
||||||
parent: Optional[Dict] = None,
|
|
||||||
) -> Any:
|
|
||||||
json = {
|
|
||||||
"input_size": size.tojson(),
|
|
||||||
"outputs": outputs,
|
|
||||||
"params": params.tojson(),
|
|
||||||
"inversions": {},
|
|
||||||
"loras": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
json["params"]["model"] = path.basename(params.model)
|
|
||||||
json["params"]["scheduler"] = params.scheduler
|
|
||||||
|
|
||||||
# calculate final output size
|
|
||||||
output_size = size
|
|
||||||
if border is not None:
|
|
||||||
json["border"] = border.tojson()
|
|
||||||
output_size = output_size.add_border(border)
|
|
||||||
|
|
||||||
if highres is not None:
|
|
||||||
json["highres"] = highres.tojson()
|
|
||||||
output_size = highres.resize(output_size)
|
|
||||||
|
|
||||||
if upscale is not None:
|
|
||||||
json["upscale"] = upscale.tojson()
|
|
||||||
output_size = upscale.resize(output_size)
|
|
||||||
|
|
||||||
json["size"] = output_size.tojson()
|
|
||||||
|
|
||||||
if inversions is not None:
|
|
||||||
for name, weight in inversions:
|
|
||||||
hash = hash_file(
|
|
||||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
|
||||||
).upper()
|
|
||||||
json["inversions"][name] = {"weight": weight, "hash": hash}
|
|
||||||
|
|
||||||
if loras is not None:
|
|
||||||
for name, weight in loras:
|
|
||||||
hash = hash_file(
|
|
||||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
|
||||||
).upper()
|
|
||||||
json["loras"][name] = {"weight": weight, "hash": hash}
|
|
||||||
|
|
||||||
return json
|
|
||||||
|
|
||||||
|
|
||||||
def str_params(
|
|
||||||
server: ServerContext,
|
|
||||||
params: ImageParams,
|
|
||||||
size: Size,
|
|
||||||
inversions: List[Tuple[str, float]] = None,
|
|
||||||
loras: List[Tuple[str, float]] = None,
|
|
||||||
) -> str:
|
|
||||||
model_name = path.basename(path.normpath(params.model))
|
|
||||||
logger.debug("getting model hash for %s", model_name)
|
|
||||||
|
|
||||||
model_hash = get_extra_hashes().get(model_name, None)
|
|
||||||
if model_hash is None:
|
|
||||||
model_hash_path = path.join(params.model, "hash.txt")
|
|
||||||
if path.exists(model_hash_path):
|
|
||||||
with open(model_hash_path, "r") as f:
|
|
||||||
model_hash = f.readline().rstrip(",. \n\t\r")
|
|
||||||
|
|
||||||
model_hash = model_hash or "unknown"
|
|
||||||
hash_map = {
|
|
||||||
model_name: model_hash,
|
|
||||||
}
|
|
||||||
|
|
||||||
inversion_hashes = ""
|
|
||||||
if inversions is not None:
|
|
||||||
inversion_pairs = [
|
|
||||||
(
|
|
||||||
name,
|
|
||||||
hash_file(
|
|
||||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
|
||||||
).upper(),
|
|
||||||
)
|
|
||||||
for name, _weight in inversions
|
|
||||||
]
|
|
||||||
inversion_hashes = ",".join(
|
|
||||||
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
|
||||||
)
|
|
||||||
hash_map.update(dict(inversion_pairs))
|
|
||||||
|
|
||||||
lora_hashes = ""
|
|
||||||
if loras is not None:
|
|
||||||
lora_pairs = [
|
|
||||||
(
|
|
||||||
name,
|
|
||||||
hash_file(
|
|
||||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
|
||||||
).upper(),
|
|
||||||
)
|
|
||||||
for name, _weight in loras
|
|
||||||
]
|
|
||||||
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
|
||||||
hash_map.update(dict(lora_pairs))
|
|
||||||
|
|
||||||
return (
|
|
||||||
f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n"
|
|
||||||
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
|
|
||||||
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
|
|
||||||
f"Model hash: {model_hash}, Model: {model_name}, "
|
|
||||||
f"Tool: onnx-web, Version: {server.server_version}, "
|
|
||||||
f'Inversion hashes: "{inversion_hashes}", '
|
|
||||||
f'Lora hashes: "{lora_hashes}", '
|
|
||||||
f"Hashes: {dumps(hash_map)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_output_name(
|
def make_output_name(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
|
@ -179,6 +26,19 @@ def make_output_name(
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
count = count or params.batch
|
count = count or params.batch
|
||||||
|
job_name = make_job_name(mode, params, size, extras)
|
||||||
|
|
||||||
|
return [
|
||||||
|
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_job_name(
|
||||||
|
mode: str,
|
||||||
|
params: ImageParams,
|
||||||
|
size: Size,
|
||||||
|
extras: Optional[List[Optional[Param]]] = None,
|
||||||
|
) -> str:
|
||||||
now = int(time())
|
now = int(time())
|
||||||
sha = sha256()
|
sha = sha256()
|
||||||
|
|
||||||
|
@ -200,49 +60,49 @@ def make_output_name(
|
||||||
for param in extras:
|
for param in extras:
|
||||||
hash_value(sha, param)
|
hash_value(sha, param)
|
||||||
|
|
||||||
return [
|
return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}"
|
||||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
|
||||||
for i in range(offset, count + offset)
|
|
||||||
]
|
def save_result(
|
||||||
|
server: ServerContext,
|
||||||
|
result: StageResult,
|
||||||
|
base_name: str,
|
||||||
|
) -> List[str]:
|
||||||
|
results = []
|
||||||
|
for i, image, metadata in enumerate(zip(result.as_image(), result.metadata)):
|
||||||
|
results.append(
|
||||||
|
save_image(
|
||||||
|
server,
|
||||||
|
base_name + f"_{i}.{server.image_format}",
|
||||||
|
image,
|
||||||
|
metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def save_image(
|
def save_image(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
output: str,
|
output: str,
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
params: Optional[ImageParams] = None,
|
metadata: ImageMetadata,
|
||||||
size: Optional[Size] = None,
|
|
||||||
upscale: Optional[UpscaleParams] = None,
|
|
||||||
border: Optional[Border] = None,
|
|
||||||
highres: Optional[HighresParams] = None,
|
|
||||||
inversions: List[Tuple[str, float]] = None,
|
|
||||||
loras: List[Tuple[str, float]] = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(server.output_path, output)
|
path = base_join(server.output_path, output)
|
||||||
|
|
||||||
if server.image_format == "png":
|
if server.image_format == "png":
|
||||||
exif = PngImagePlugin.PngInfo()
|
exif = PngImagePlugin.PngInfo()
|
||||||
|
|
||||||
if params is not None:
|
if metadata is not None:
|
||||||
exif.add_text("make", "onnx-web")
|
exif.add_text("make", "onnx-web")
|
||||||
exif.add_text(
|
exif.add_text(
|
||||||
"maker note",
|
"maker note",
|
||||||
dumps(
|
dumps(metadata.tojson(server, [output])),
|
||||||
json_params(
|
|
||||||
server,
|
|
||||||
[output],
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
border=border,
|
|
||||||
highres=highres,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
exif.add_text("model", server.server_version)
|
exif.add_text("model", server.server_version)
|
||||||
exif.add_text(
|
exif.add_text(
|
||||||
"parameters",
|
"parameters",
|
||||||
str_params(server, params, size, inversions=inversions, loras=loras),
|
metadata.to_auto1111(server, [output]),
|
||||||
)
|
)
|
||||||
|
|
||||||
image.save(path, format=server.image_format, pnginfo=exif)
|
image.save(path, format=server.image_format, pnginfo=exif)
|
||||||
|
@ -251,23 +111,11 @@ def save_image(
|
||||||
{
|
{
|
||||||
"0th": {
|
"0th": {
|
||||||
ExifIFD.MakerNote: UserComment.dump(
|
ExifIFD.MakerNote: UserComment.dump(
|
||||||
dumps(
|
dumps(metadata.tojson(server, [output])),
|
||||||
json_params(
|
|
||||||
server,
|
|
||||||
[output],
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
border=border,
|
|
||||||
highres=highres,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
encoding="unicode",
|
encoding="unicode",
|
||||||
),
|
),
|
||||||
ExifIFD.UserComment: UserComment.dump(
|
ExifIFD.UserComment: UserComment.dump(
|
||||||
str_params(
|
metadata.to_auto1111(server, [output]),
|
||||||
server, params, size, inversions=inversions, loras=loras
|
|
||||||
),
|
|
||||||
encoding="unicode",
|
encoding="unicode",
|
||||||
),
|
),
|
||||||
ImageIFD.Make: "onnx-web",
|
ImageIFD.Make: "onnx-web",
|
||||||
|
@ -277,34 +125,23 @@ def save_image(
|
||||||
)
|
)
|
||||||
image.save(path, format=server.image_format, exif=exif)
|
image.save(path, format=server.image_format, exif=exif)
|
||||||
|
|
||||||
if params is not None:
|
if metadata is not None:
|
||||||
save_params(
|
save_metadata(
|
||||||
server,
|
server,
|
||||||
output,
|
output,
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
border=border,
|
|
||||||
highres=highres,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("saved output image to: %s", path)
|
logger.debug("saved output image to: %s", path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def save_params(
|
def save_metadata(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
output: str,
|
output: str,
|
||||||
params: ImageParams,
|
metadata: ImageMetadata,
|
||||||
size: Size,
|
|
||||||
upscale: Optional[UpscaleParams] = None,
|
|
||||||
border: Optional[Border] = None,
|
|
||||||
highres: Optional[HighresParams] = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
path = base_join(server.output_path, f"{output}.json")
|
path = base_join(server.output_path, f"{output}.json")
|
||||||
json = json_params(
|
json = metadata.tojson(server, [output])
|
||||||
server, output, params, size, upscale=upscale, border=border, highres=highres
|
|
||||||
)
|
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(dumps(json))
|
f.write(dumps(json))
|
||||||
logger.debug("saved image params to: %s", path)
|
logger.debug("saved image params to: %s", path)
|
||||||
|
|
|
@ -13,6 +13,24 @@ Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
|
||||||
|
class Progress:
|
||||||
|
current: int
|
||||||
|
total: int
|
||||||
|
|
||||||
|
def __init__(self, current: int, total: int) -> None:
|
||||||
|
self.current = current
|
||||||
|
self.total = total
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return "%s/%s" % (self.current, self.total)
|
||||||
|
|
||||||
|
def tojson(self):
|
||||||
|
return {
|
||||||
|
"current": self.current,
|
||||||
|
"total": self.total,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class SizeChart(IntEnum):
|
class SizeChart(IntEnum):
|
||||||
micro = 64
|
micro = 64
|
||||||
mini = 128 # small tile for very expensive models
|
mini = 128 # small tile for very expensive models
|
||||||
|
|
|
@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
pool.recycle(recycle_all=True)
|
pool.recycle(recycle_all=True)
|
||||||
logger.info("restarted worker pool")
|
logger.info("restarted worker pool")
|
||||||
|
|
||||||
return jsonify(pool.status())
|
return jsonify(pool.summary())
|
||||||
|
|
||||||
|
|
||||||
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
|
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
if not check_admin(server):
|
if not check_admin(server):
|
||||||
return make_response(jsonify({})), 401
|
return make_response(jsonify({})), 401
|
||||||
|
|
||||||
return jsonify(pool.status())
|
return jsonify(pool.summary())
|
||||||
|
|
||||||
|
|
||||||
def get_extra_models(server: ServerContext):
|
def get_extra_models(server: ServerContext):
|
||||||
|
@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExe
|
||||||
app.route("/api/extras", methods=["PUT"])(
|
app.route("/api/extras", methods=["PUT"])(
|
||||||
wrap_route(update_extra_models, server)
|
wrap_route(update_extra_models, server)
|
||||||
),
|
),
|
||||||
app.route("/api/restart", methods=["POST"])(
|
app.route("/api/worker/restart", methods=["POST"])(
|
||||||
wrap_route(restart_workers, server, pool=pool)
|
wrap_route(restart_workers, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/status")(wrap_route(worker_status, server, pool=pool)),
|
app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)),
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from flask import Flask, jsonify, make_response, request, url_for
|
from flask import Flask, jsonify, make_response, request, url_for
|
||||||
from jsonschema import validate
|
from jsonschema import validate
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ..chain import CHAIN_STAGES, ChainPipeline
|
from ..chain import CHAIN_STAGES, ChainPipeline
|
||||||
from ..chain.result import StageResult
|
from ..chain.result import ImageMetadata, StageResult
|
||||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
||||||
from ..diffusers.run import (
|
from ..diffusers.run import (
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
|
@ -18,8 +18,8 @@ from ..diffusers.run import (
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
)
|
)
|
||||||
from ..diffusers.utils import replace_wildcards
|
from ..diffusers.utils import replace_wildcards
|
||||||
from ..output import json_params, make_output_name
|
from ..output import make_job_name
|
||||||
from ..params import Size, StageParams, TileOrder
|
from ..params import Progress, Size, StageParams, TileOrder
|
||||||
from ..transformers.run import run_txt2txt_pipeline
|
from ..transformers.run import run_txt2txt_pipeline
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -34,6 +34,7 @@ from ..utils import (
|
||||||
load_config_str,
|
load_config_str,
|
||||||
sanitize_name,
|
sanitize_name,
|
||||||
)
|
)
|
||||||
|
from ..worker.command import JobType
|
||||||
from ..worker.pool import DevicePoolExecutor
|
from ..worker.pool import DevicePoolExecutor
|
||||||
from .context import ServerContext
|
from .context import ServerContext
|
||||||
from .load import (
|
from .load import (
|
||||||
|
@ -92,6 +93,64 @@ def error_reply(err: str):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def job_reply(name: str):
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def image_reply(
|
||||||
|
name: str,
|
||||||
|
status: str,
|
||||||
|
job_type: str,
|
||||||
|
stages: Progress = None,
|
||||||
|
steps: Progress = None,
|
||||||
|
tiles: Progress = None,
|
||||||
|
outputs: List[str] = None,
|
||||||
|
metadata: List[ImageMetadata] = None,
|
||||||
|
):
|
||||||
|
if stages is None:
|
||||||
|
stages = Progress()
|
||||||
|
|
||||||
|
if steps is None:
|
||||||
|
steps = Progress()
|
||||||
|
|
||||||
|
if tiles is None:
|
||||||
|
tiles = Progress()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"name": name,
|
||||||
|
"status": status,
|
||||||
|
"type": job_type,
|
||||||
|
"stages": stages.tojson(),
|
||||||
|
"steps": steps.tojson(),
|
||||||
|
"tiles": tiles.tojson(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(metadata) != len(outputs):
|
||||||
|
logger.error("metadata and outputs must be the same length")
|
||||||
|
return error_reply("metadata and outputs must be the same length")
|
||||||
|
|
||||||
|
if outputs is not None:
|
||||||
|
data["outputs"] = outputs
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
data["metadata"] = metadata
|
||||||
|
|
||||||
|
return jsonify(data)
|
||||||
|
|
||||||
|
|
||||||
|
def multi_image_reply(results: Dict[str, Any]):
|
||||||
|
# TODO: not that
|
||||||
|
return jsonify(
|
||||||
|
{
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def url_from_rule(rule) -> str:
|
def url_from_rule(rule) -> str:
|
||||||
options = {}
|
options = {}
|
||||||
for arg in rule.arguments:
|
for arg in rule.arguments:
|
||||||
|
@ -197,17 +256,15 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
output_count += 1
|
output_count += 1
|
||||||
|
|
||||||
output = make_output_name(
|
job_name = make_job_name(
|
||||||
server, "img2img", params, size, extras=[strength], count=output_count
|
server, "img2img", params, size, extras=[strength], count=output_count
|
||||||
)
|
)
|
||||||
|
|
||||||
job_name = output[0]
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.IMG2IMG,
|
||||||
run_img2img_pipeline,
|
run_img2img_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
output,
|
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
source,
|
source,
|
||||||
|
@ -218,9 +275,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("img2img job queued for: %s", job_name)
|
logger.info("img2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(
|
return job_reply(job_name)
|
||||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -230,16 +285,15 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output = make_output_name(server, "txt2img", params, size, count=params.batch)
|
job_name = make_job_name(server, "txt2img", params, size, count=params.batch)
|
||||||
|
|
||||||
job_name = output[0]
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.TXT2IMG,
|
||||||
run_txt2img_pipeline,
|
run_txt2img_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
|
@ -247,9 +301,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("txt2img job queued for: %s", job_name)
|
logger.info("txt2img job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(
|
return job_reply(job_name)
|
||||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -295,7 +347,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output = make_output_name(
|
job_name = make_job_name(
|
||||||
server,
|
server,
|
||||||
"inpaint",
|
"inpaint",
|
||||||
params,
|
params,
|
||||||
|
@ -312,14 +364,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
job_name = output[0]
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.INPAINT,
|
||||||
run_inpaint_pipeline,
|
run_inpaint_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
source,
|
source,
|
||||||
|
@ -336,17 +387,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("inpaint job queued for: %s", job_name)
|
logger.info("inpaint job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(
|
return job_reply(job_name)
|
||||||
json_params(
|
|
||||||
server,
|
|
||||||
output,
|
|
||||||
params,
|
|
||||||
size,
|
|
||||||
upscale=upscale,
|
|
||||||
border=expand,
|
|
||||||
highres=highres,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -362,16 +403,14 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
replace_wildcards(params, get_wildcard_data())
|
replace_wildcards(params, get_wildcard_data())
|
||||||
|
|
||||||
output = make_output_name(server, "upscale", params, size)
|
job_name = make_job_name(server, "upscale", params, size)
|
||||||
|
|
||||||
job_name = output[0]
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.UPSCALE,
|
||||||
run_upscale_pipeline,
|
run_upscale_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
|
||||||
upscale,
|
upscale,
|
||||||
highres,
|
highres,
|
||||||
source,
|
source,
|
||||||
|
@ -380,9 +419,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(
|
return job_reply(job_name)
|
||||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# keys that are specially parsed by params and should not show up in with_args
|
# keys that are specially parsed by params and should not show up in with_args
|
||||||
|
@ -478,25 +515,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
output = make_output_name(
|
job_name = make_job_name(server, "chain", base_params, base_size)
|
||||||
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
|
|
||||||
)
|
|
||||||
job_name = output[0]
|
|
||||||
|
|
||||||
# build and run chain pipeline
|
# build and run chain pipeline
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.CHAIN,
|
||||||
pipeline,
|
pipeline,
|
||||||
server,
|
server,
|
||||||
base_params,
|
base_params,
|
||||||
StageResult.empty(),
|
StageResult.empty(),
|
||||||
output=output,
|
|
||||||
size=base_size,
|
size=base_size,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
return job_reply(job_name)
|
||||||
return jsonify(json_params(server, output, step_params, base_size))
|
|
||||||
|
|
||||||
|
|
||||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -520,15 +553,14 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
upscale = build_upscale()
|
upscale = build_upscale()
|
||||||
|
|
||||||
output = make_output_name(server, "upscale", params, size)
|
job_name = make_job_name(server, "blend", params, size)
|
||||||
job_name = output[0]
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.BLEND,
|
||||||
run_blend_pipeline,
|
run_blend_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
|
||||||
upscale,
|
upscale,
|
||||||
# TODO: highres
|
# TODO: highres
|
||||||
sources,
|
sources,
|
||||||
|
@ -538,27 +570,26 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
return jsonify(json_params(server, output, params, size, upscale=upscale))
|
return job_reply(job_name)
|
||||||
|
|
||||||
|
|
||||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
device, params, size = pipeline_from_request(server)
|
device, params, size = pipeline_from_request(server)
|
||||||
|
|
||||||
output = make_output_name(server, "txt2txt", params, size)
|
job_name = make_job_name(server, "txt2txt", params, size)
|
||||||
job_name = output[0]
|
|
||||||
logger.info("upscale job queued for: %s", job_name)
|
logger.info("upscale job queued for: %s", job_name)
|
||||||
|
|
||||||
pool.submit(
|
pool.submit(
|
||||||
job_name,
|
job_name,
|
||||||
|
JobType.TXT2TXT,
|
||||||
run_txt2txt_pipeline,
|
run_txt2txt_pipeline,
|
||||||
server,
|
server,
|
||||||
params,
|
params,
|
||||||
size,
|
size,
|
||||||
output,
|
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify(json_params(server, output, params, size))
|
return job_reply(job_name)
|
||||||
|
|
||||||
|
|
||||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -601,9 +632,64 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
legacy_job_name = request.args.get("job", None)
|
||||||
|
job_list = request.args.get("jobs", "").split(",")
|
||||||
|
|
||||||
|
if legacy_job_name is not None:
|
||||||
|
job_list.append(legacy_job_name)
|
||||||
|
|
||||||
|
if len(job_list) == 0:
|
||||||
|
return error_reply("at least one job name is required")
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for job_name in job_list:
|
||||||
|
job_name = sanitize_name(job_name)
|
||||||
|
cancelled = pool.cancel(job_name)
|
||||||
|
results[job_name] = cancelled
|
||||||
|
|
||||||
|
return multi_image_reply(results)
|
||||||
|
|
||||||
|
|
||||||
|
def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
legacy_job_name = request.args.get("job", None)
|
||||||
|
job_list = request.args.get("jobs", "").split(",")
|
||||||
|
|
||||||
|
if legacy_job_name is not None:
|
||||||
|
job_list.append(legacy_job_name)
|
||||||
|
|
||||||
|
if len(job_list) == 0:
|
||||||
|
return error_reply("at least one job name is required")
|
||||||
|
|
||||||
|
for job_name in job_list:
|
||||||
|
job_name = sanitize_name(job_name)
|
||||||
|
status, progress = pool.status(job_name)
|
||||||
|
|
||||||
|
# TODO: accumulate results
|
||||||
|
if progress is not None:
|
||||||
|
# TODO: add output paths based on progress.results counter
|
||||||
|
return image_reply(
|
||||||
|
job_name,
|
||||||
|
status,
|
||||||
|
"TODO",
|
||||||
|
stages=Progress(progress.stages, 0),
|
||||||
|
steps=Progress(progress.steps, 0),
|
||||||
|
tiles=Progress(progress.tiles, 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_reply(job_name, status, "TODO")
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return [
|
return [
|
||||||
app.route("/api")(wrap_route(introspect, server, app=app)),
|
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||||
|
# job routes
|
||||||
|
app.route("/api/job", methods=["POST"])(wrap_route(chain, server, pool=pool)),
|
||||||
|
app.route("/api/job/cancel", methods=["PUT"])(
|
||||||
|
wrap_route(job_cancel, server, pool=pool)
|
||||||
|
),
|
||||||
|
app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
|
||||||
|
# settings routes
|
||||||
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
|
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
|
||||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
|
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
|
||||||
app.route("/api/settings/models")(wrap_route(list_models, server)),
|
app.route("/api/settings/models")(wrap_route(list_models, server)),
|
||||||
|
@ -614,6 +700,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
||||||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
|
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
|
||||||
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
|
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
|
||||||
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
|
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
|
||||||
|
# legacy job routes
|
||||||
app.route("/api/img2img", methods=["POST"])(
|
app.route("/api/img2img", methods=["POST"])(
|
||||||
wrap_route(img2img, server, pool=pool)
|
wrap_route(img2img, server, pool=pool)
|
||||||
),
|
),
|
||||||
|
@ -631,6 +718,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
||||||
),
|
),
|
||||||
app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
|
app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
|
||||||
app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
|
app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
|
||||||
|
# deprecated routes
|
||||||
app.route("/api/cancel", methods=["PUT"])(
|
app.route("/api/cancel", methods=["PUT"])(
|
||||||
wrap_route(cancel, server, pool=pool)
|
wrap_route(cancel, server, pool=pool)
|
||||||
),
|
),
|
||||||
|
|
|
@ -12,7 +12,6 @@ def run_txt2txt_pipeline(
|
||||||
_server: ServerContext,
|
_server: ServerContext,
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
_size: Size,
|
_size: Size,
|
||||||
output: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
from transformers import AutoTokenizer, GPTJForCausalLM
|
from transformers import AutoTokenizer, GPTJForCausalLM
|
||||||
|
|
||||||
|
@ -38,4 +37,4 @@ def run_txt2txt_pipeline(
|
||||||
|
|
||||||
print("Server says: %s" % result_text)
|
print("Server says: %s" % result_text)
|
||||||
|
|
||||||
logger.info("finished txt2txt job: %s", output)
|
logger.info("finished txt2txt job: %s", worker.job)
|
||||||
|
|
|
@ -2,16 +2,18 @@ import gc
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import threading
|
import threading
|
||||||
|
from hashlib import sha256
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from platform import system
|
from platform import system
|
||||||
|
from struct import pack
|
||||||
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
from .params import DeviceParams, SizeChart
|
from .params import DeviceParams, Param, SizeChart
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -218,3 +220,34 @@ def load_config_str(raw: str) -> Dict:
|
||||||
return json.loads(raw)
|
return json.loads(raw)
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return safe_load(raw)
|
return safe_load(raw)
|
||||||
|
|
||||||
|
|
||||||
|
HASH_BUFFER_SIZE = 2**22 # 4MB
|
||||||
|
|
||||||
|
|
||||||
|
def hash_file(name: str):
|
||||||
|
sha = sha256()
|
||||||
|
with open(name, "rb") as f:
|
||||||
|
while True:
|
||||||
|
data = f.read(HASH_BUFFER_SIZE)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
|
||||||
|
sha.update(data)
|
||||||
|
|
||||||
|
return sha.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def hash_value(sha, param: Optional[Param]):
|
||||||
|
if param is None:
|
||||||
|
return
|
||||||
|
elif isinstance(param, bool):
|
||||||
|
sha.update(bytearray(pack("!B", param)))
|
||||||
|
elif isinstance(param, float):
|
||||||
|
sha.update(bytearray(pack("!f", param)))
|
||||||
|
elif isinstance(param, int):
|
||||||
|
sha.update(bytearray(pack("!I", param)))
|
||||||
|
elif isinstance(param, str):
|
||||||
|
sha.update(param.encode("utf-8"))
|
||||||
|
else:
|
||||||
|
logger.warning("cannot hash param: %s, %s", param, type(param))
|
||||||
|
|
|
@ -1,34 +1,61 @@
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(str, Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCESS = "success"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class JobType(str, Enum):
|
||||||
|
TXT2TXT = "txt2txt"
|
||||||
|
TXT2IMG = "txt2img"
|
||||||
|
IMG2IMG = "img2img"
|
||||||
|
INPAINT = "inpaint"
|
||||||
|
UPSCALE = "upscale"
|
||||||
|
BLEND = "blend"
|
||||||
|
CHAIN = "chain"
|
||||||
|
|
||||||
|
|
||||||
class ProgressCommand:
|
class ProgressCommand:
|
||||||
device: str
|
device: str
|
||||||
job: str
|
job: str
|
||||||
finished: bool
|
job_type: str
|
||||||
progress: int
|
status: JobStatus
|
||||||
cancelled: bool
|
results: int
|
||||||
failed: bool
|
steps: int
|
||||||
|
stages: int
|
||||||
|
tiles: int
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
job: str,
|
job: str,
|
||||||
|
job_type: str,
|
||||||
device: str,
|
device: str,
|
||||||
finished: bool,
|
status: JobStatus,
|
||||||
progress: int,
|
results: int = 0,
|
||||||
cancelled: bool = False,
|
steps: int = 0,
|
||||||
failed: bool = False,
|
stages: int = 0,
|
||||||
|
tiles: int = 0,
|
||||||
):
|
):
|
||||||
self.job = job
|
self.job = job
|
||||||
|
self.job_type = job_type
|
||||||
self.device = device
|
self.device = device
|
||||||
self.finished = finished
|
self.status = status
|
||||||
self.progress = progress
|
self.results = results
|
||||||
self.cancelled = cancelled
|
self.steps = steps
|
||||||
self.failed = failed
|
self.stages = stages
|
||||||
|
self.tiles = tiles
|
||||||
|
|
||||||
|
|
||||||
class JobCommand:
|
class JobCommand:
|
||||||
device: str
|
device: str
|
||||||
name: str
|
name: str
|
||||||
|
job_type: str
|
||||||
fn: Callable[..., None]
|
fn: Callable[..., None]
|
||||||
args: Any
|
args: Any
|
||||||
kwargs: Dict[str, Any]
|
kwargs: Dict[str, Any]
|
||||||
|
@ -37,12 +64,14 @@ class JobCommand:
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
device: str,
|
device: str,
|
||||||
|
job_type: str,
|
||||||
fn: Callable[..., None],
|
fn: Callable[..., None],
|
||||||
args: Any,
|
args: Any,
|
||||||
kwargs: Dict[str, Any],
|
kwargs: Dict[str, Any],
|
||||||
):
|
):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.job_type = job_type
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
|
@ -2,21 +2,23 @@ from logging import getLogger
|
||||||
from os import getpid
|
from os import getpid
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from torch.multiprocessing import Queue, Value
|
from torch.multiprocessing import Queue, Value
|
||||||
|
|
||||||
from ..errors import CancelledException
|
from ..errors import CancelledException
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from .command import JobCommand, ProgressCommand
|
from .command import JobCommand, JobStatus, ProgressCommand
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
ProgressCallback = Callable[[int, int, Any], None]
|
ProgressCallback = Callable[[int, int, np.ndarray], None]
|
||||||
|
|
||||||
|
|
||||||
class WorkerContext:
|
class WorkerContext:
|
||||||
cancel: "Value[bool]"
|
cancel: "Value[bool]"
|
||||||
job: Optional[str]
|
job: Optional[str]
|
||||||
|
job_type: Optional[str]
|
||||||
name: str
|
name: str
|
||||||
pending: "Queue[JobCommand]"
|
pending: "Queue[JobCommand]"
|
||||||
active_pid: "Value[int]"
|
active_pid: "Value[int]"
|
||||||
|
@ -41,6 +43,7 @@ class WorkerContext:
|
||||||
timeout: float,
|
timeout: float,
|
||||||
):
|
):
|
||||||
self.job = None
|
self.job = None
|
||||||
|
self.job_type = None
|
||||||
self.name = name
|
self.name = name
|
||||||
self.device = device
|
self.device = device
|
||||||
self.cancel = cancel
|
self.cancel = cancel
|
||||||
|
@ -54,9 +57,15 @@ class WorkerContext:
|
||||||
self.retries = retries
|
self.retries = retries
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
def start(self, job: str) -> None:
|
def start(self, job: JobCommand) -> None:
|
||||||
self.job = job
|
# set job name and type
|
||||||
|
self.job = job.name
|
||||||
|
self.job_type = job.job_type
|
||||||
|
|
||||||
|
# reset retries
|
||||||
self.retries = self.initial_retries
|
self.retries = self.initial_retries
|
||||||
|
|
||||||
|
# clear flags
|
||||||
self.set_cancel(cancel=False)
|
self.set_cancel(cancel=False)
|
||||||
self.set_idle(idle=False)
|
self.set_idle(idle=False)
|
||||||
|
|
||||||
|
@ -81,7 +90,7 @@ class WorkerContext:
|
||||||
|
|
||||||
def get_progress(self) -> int:
|
def get_progress(self) -> int:
|
||||||
if self.last_progress is not None:
|
if self.last_progress is not None:
|
||||||
return self.last_progress.progress
|
return self.last_progress.steps
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
@ -112,13 +121,11 @@ class WorkerContext:
|
||||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
False,
|
JobStatus.RUNNING,
|
||||||
progress,
|
steps=progress,
|
||||||
self.is_cancelled(),
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
block=False,
|
block=False,
|
||||||
|
@ -131,11 +138,10 @@ class WorkerContext:
|
||||||
logger.debug("setting finished for job %s", self.job)
|
logger.debug("setting finished for job %s", self.job)
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
True,
|
JobStatus.SUCCESS, # TODO: FAILED
|
||||||
self.get_progress(),
|
steps=self.get_progress(),
|
||||||
self.is_cancelled(),
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
@ -150,11 +156,10 @@ class WorkerContext:
|
||||||
try:
|
try:
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
True,
|
JobStatus.FAILED,
|
||||||
self.get_progress(),
|
steps=self.get_progress(),
|
||||||
self.is_cancelled(),
|
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
@ -162,25 +167,3 @@ class WorkerContext:
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error setting failure on job %s", self.job)
|
logger.exception("error setting failure on job %s", self.job)
|
||||||
|
|
||||||
|
|
||||||
class JobStatus:
|
|
||||||
name: str
|
|
||||||
device: str
|
|
||||||
progress: int
|
|
||||||
cancelled: bool
|
|
||||||
finished: bool
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
device: DeviceParams,
|
|
||||||
progress: int = 0,
|
|
||||||
cancelled: bool = False,
|
|
||||||
finished: bool = False,
|
|
||||||
) -> None:
|
|
||||||
self.name = name
|
|
||||||
self.device = device.device
|
|
||||||
self.progress = progress
|
|
||||||
self.cancelled = cancelled
|
|
||||||
self.finished = finished
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from .command import JobCommand, ProgressCommand
|
from .command import JobCommand, JobStatus, ProgressCommand
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
from .utils import Interval
|
from .utils import Interval
|
||||||
from .worker import worker_main
|
from .worker import worker_main
|
||||||
|
@ -201,6 +201,10 @@ class DevicePoolExecutor:
|
||||||
should be cancelled on the next progress callback.
|
should be cancelled on the next progress callback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if key in self.cancelled_jobs:
|
||||||
|
logger.debug("cancelling already cancelled job: %s", key)
|
||||||
|
return True
|
||||||
|
|
||||||
for job in self.finished_jobs:
|
for job in self.finished_jobs:
|
||||||
if job.job == key:
|
if job.job == key:
|
||||||
logger.debug("cannot cancel finished job: %s", key)
|
logger.debug("cannot cancel finished job: %s", key)
|
||||||
|
@ -209,6 +213,9 @@ class DevicePoolExecutor:
|
||||||
for job in self.pending_jobs:
|
for job in self.pending_jobs:
|
||||||
if job.name == key:
|
if job.name == key:
|
||||||
self.pending_jobs.remove(job)
|
self.pending_jobs.remove(job)
|
||||||
|
self.cancelled_jobs.append(
|
||||||
|
key
|
||||||
|
) # ensure workers never pick up this job and the status endpoint knows about it later
|
||||||
logger.info("cancelled pending job: %s", key)
|
logger.info("cancelled pending job: %s", key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -221,28 +228,31 @@ class DevicePoolExecutor:
|
||||||
self.cancelled_jobs.append(key)
|
self.cancelled_jobs.append(key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
|
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
|
||||||
"""
|
"""
|
||||||
Check if a job has been finished and report the last progress update.
|
Check if a job has been finished and report the last progress update.
|
||||||
|
|
||||||
If the job is still pending, the first item will be True and there will be no ProgressCommand.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if key in self.cancelled_jobs:
|
||||||
|
logger.debug("checking status for cancelled job: %s", key)
|
||||||
|
return (JobStatus.CANCELLED, None)
|
||||||
|
|
||||||
if key in self.running_jobs:
|
if key in self.running_jobs:
|
||||||
logger.debug("checking status for running job: %s", key)
|
logger.debug("checking status for running job: %s", key)
|
||||||
return (False, self.running_jobs[key])
|
return (JobStatus.RUNNING, self.running_jobs[key])
|
||||||
|
|
||||||
for job in self.finished_jobs:
|
for job in self.finished_jobs:
|
||||||
if job.job == key:
|
if job.job == key:
|
||||||
logger.debug("checking status for finished job: %s", key)
|
logger.debug("checking status for finished job: %s", key)
|
||||||
return (False, job)
|
return (job.status, job)
|
||||||
|
|
||||||
for job in self.pending_jobs:
|
for job in self.pending_jobs:
|
||||||
if job.name == key:
|
if job.name == key:
|
||||||
logger.debug("checking status for pending job: %s", key)
|
logger.debug("checking status for pending job: %s", key)
|
||||||
return (True, None)
|
return (JobStatus.PENDING, None)
|
||||||
|
|
||||||
logger.trace("checking status for unknown job: %s", key)
|
logger.trace("checking status for unknown job: %s", key)
|
||||||
return (False, None)
|
return (JobStatus.UNKNOWN, None)
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
logger.info("stopping worker pool")
|
logger.info("stopping worker pool")
|
||||||
|
@ -383,6 +393,7 @@ class DevicePoolExecutor:
|
||||||
def submit(
|
def submit(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
|
job_type: str,
|
||||||
fn: Callable[..., None],
|
fn: Callable[..., None],
|
||||||
/,
|
/,
|
||||||
*args,
|
*args,
|
||||||
|
@ -399,56 +410,63 @@ class DevicePoolExecutor:
|
||||||
)
|
)
|
||||||
|
|
||||||
# build and queue job
|
# build and queue job
|
||||||
job = JobCommand(key, device, fn, args, kwargs)
|
job = JobCommand(key, device, job_type, fn, args, kwargs)
|
||||||
self.pending_jobs.append(job)
|
self.pending_jobs.append(job)
|
||||||
|
|
||||||
def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]:
|
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
|
||||||
"""
|
"""
|
||||||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||||
"""
|
"""
|
||||||
return {
|
|
||||||
"cancelled": [],
|
jobs: Tuple[str, int, JobStatus] = []
|
||||||
"finished": [
|
jobs.extend(
|
||||||
|
[
|
||||||
(
|
(
|
||||||
job.job,
|
job,
|
||||||
job.progress,
|
0,
|
||||||
False,
|
JobStatus.CANCELLED,
|
||||||
job.finished,
|
|
||||||
job.cancelled,
|
|
||||||
job.failed,
|
|
||||||
)
|
)
|
||||||
for job in self.finished_jobs
|
for job in self.cancelled_jobs
|
||||||
],
|
]
|
||||||
"pending": [
|
)
|
||||||
|
jobs.extend(
|
||||||
|
[
|
||||||
(
|
(
|
||||||
job.name,
|
job.name,
|
||||||
0,
|
0,
|
||||||
True,
|
JobStatus.PENDING,
|
||||||
False,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
for job in self.pending_jobs
|
for job in self.pending_jobs
|
||||||
],
|
]
|
||||||
"running": [
|
)
|
||||||
|
jobs.extend(
|
||||||
|
[
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
job.progress,
|
job.steps,
|
||||||
False,
|
job.status,
|
||||||
job.finished,
|
|
||||||
job.cancelled,
|
|
||||||
job.failed,
|
|
||||||
)
|
)
|
||||||
for name, job in self.running_jobs.items()
|
for name, job in self.running_jobs.items()
|
||||||
],
|
]
|
||||||
"total": [
|
)
|
||||||
|
jobs.extend(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
job.job,
|
||||||
|
job.steps,
|
||||||
|
job.status,
|
||||||
|
)
|
||||||
|
for job in self.finished_jobs
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"jobs": jobs,
|
||||||
|
"workers": [
|
||||||
(
|
(
|
||||||
device,
|
device,
|
||||||
total,
|
total,
|
||||||
self.workers[device].is_alive(),
|
self.workers[device].is_alive(),
|
||||||
False,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
)
|
)
|
||||||
for device, total in self.total_jobs.items()
|
for device, total in self.total_jobs.items()
|
||||||
],
|
],
|
||||||
|
@ -476,20 +494,18 @@ class DevicePoolExecutor:
|
||||||
self.cancelled_jobs.remove(progress.job)
|
self.cancelled_jobs.remove(progress.job)
|
||||||
|
|
||||||
def update_job(self, progress: ProgressCommand):
|
def update_job(self, progress: ProgressCommand):
|
||||||
if progress.finished:
|
if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]:
|
||||||
return self.finish_job(progress)
|
return self.finish_job(progress)
|
||||||
|
|
||||||
# move from pending to running
|
# move from pending to running
|
||||||
logger.debug(
|
logger.debug("progress update for job: %s to %s", progress.job, progress.steps)
|
||||||
"progress update for job: %s to %s", progress.job, progress.progress
|
|
||||||
)
|
|
||||||
self.running_jobs[progress.job] = progress
|
self.running_jobs[progress.job] = progress
|
||||||
self.pending_jobs[:] = [
|
self.pending_jobs[:] = [
|
||||||
job for job in self.pending_jobs if job.name != progress.job
|
job for job in self.pending_jobs if job.name != progress.job
|
||||||
]
|
]
|
||||||
|
|
||||||
# increment job counter if this is the start of a new job
|
# increment job counter if this is the start of a new job
|
||||||
if progress.progress == 0:
|
if progress.steps == 0:
|
||||||
if progress.device in self.total_jobs:
|
if progress.device in self.total_jobs:
|
||||||
self.total_jobs[progress.device] += 1
|
self.total_jobs[progress.device] += 1
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -57,7 +57,7 @@ def worker_main(
|
||||||
logger.info("worker %s got job: %s", worker.device.device, job.name)
|
logger.info("worker %s got job: %s", worker.device.device, job.name)
|
||||||
|
|
||||||
# clear flags and save the job name
|
# clear flags and save the job name
|
||||||
worker.start(job.name)
|
worker.start(job)
|
||||||
logger.info("starting job: %s", job.name)
|
logger.info("starting job: %s", job.name)
|
||||||
|
|
||||||
# reset progress, which does a final check for cancellation
|
# reset progress, which does a final check for cancellation
|
||||||
|
|
|
@ -1,12 +1,6 @@
|
||||||
from diffusers import OnnxStableDiffusionPipeline
|
from diffusers import OnnxStableDiffusionPipeline
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import onnxruntime as ort
|
|
||||||
import torch
|
|
||||||
import time
|
|
||||||
|
|
||||||
cfg = 8
|
cfg = 8
|
||||||
steps = 22
|
steps = 22
|
||||||
height = 512
|
height = 512
|
||||||
|
|
|
@ -22,6 +22,7 @@ from onnx_web.params import (
|
||||||
UpscaleParams,
|
UpscaleParams,
|
||||||
)
|
)
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.command import JobCommand
|
||||||
from onnx_web.worker.context import WorkerContext
|
from onnx_web.worker.context import WorkerContext
|
||||||
from tests.helpers import (
|
from tests.helpers import (
|
||||||
TEST_MODEL_DIFFUSION_SD15,
|
TEST_MODEL_DIFFUSION_SD15,
|
||||||
|
@ -57,7 +58,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
run_txt2img_pipeline(
|
run_txt2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
|
@ -72,7 +73,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
Size(256, 256),
|
Size(256, 256),
|
||||||
["test-txt2img-basic.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
)
|
)
|
||||||
|
@ -103,7 +103,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
run_txt2img_pipeline(
|
run_txt2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
|
@ -119,7 +119,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
batch=2,
|
batch=2,
|
||||||
),
|
),
|
||||||
Size(256, 256),
|
Size(256, 256),
|
||||||
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
)
|
)
|
||||||
|
@ -152,7 +151,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
run_txt2img_pipeline(
|
run_txt2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
|
@ -168,7 +167,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
unet_tile=256,
|
unet_tile=256,
|
||||||
),
|
),
|
||||||
Size(256, 256),
|
Size(256, 256),
|
||||||
["test-txt2img-highres.png"],
|
|
||||||
UpscaleParams("test", scale=2, outscale=2),
|
UpscaleParams("test", scale=2, outscale=2),
|
||||||
HighresParams(True, 2, 0, 0),
|
HighresParams(True, 2, 0, 0),
|
||||||
)
|
)
|
||||||
|
@ -198,7 +196,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
run_txt2img_pipeline(
|
run_txt2img_pipeline(
|
||||||
worker,
|
worker,
|
||||||
|
@ -214,7 +212,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
||||||
batch=2,
|
batch=2,
|
||||||
),
|
),
|
||||||
Size(256, 256),
|
Size(256, 256),
|
||||||
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(True, 2, 0, 0),
|
HighresParams(True, 2, 0, 0),
|
||||||
)
|
)
|
||||||
|
@ -230,7 +227,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
worker = test_worker()
|
worker = test_worker()
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
run_img2img_pipeline(
|
run_img2img_pipeline(
|
||||||
|
@ -245,7 +242,6 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
["test-img2img.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
source,
|
source,
|
||||||
|
@ -259,7 +255,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||||
def test_basic_white(self):
|
def test_basic_white(self):
|
||||||
worker = test_worker()
|
worker = test_worker()
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
mask = Image.new("RGB", (64, 64), "white")
|
mask = Image.new("RGB", (64, 64), "white")
|
||||||
|
@ -277,7 +273,6 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
unet_tile=64,
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(*source.size),
|
Size(*source.size),
|
||||||
["test-inpaint-white.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
source,
|
source,
|
||||||
|
@ -296,7 +291,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||||
def test_basic_black(self):
|
def test_basic_black(self):
|
||||||
worker = test_worker()
|
worker = test_worker()
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
mask = Image.new("RGB", (64, 64), "black")
|
mask = Image.new("RGB", (64, 64), "black")
|
||||||
|
@ -314,7 +309,6 @@ class TestInpaintPipeline(unittest.TestCase):
|
||||||
unet_tile=64,
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(*source.size),
|
Size(*source.size),
|
||||||
["test-inpaint-black.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
source,
|
source,
|
||||||
|
@ -353,7 +347,7 @@ class TestUpscalePipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {}))
|
||||||
|
|
||||||
source = Image.new("RGB", (64, 64), "black")
|
source = Image.new("RGB", (64, 64), "black")
|
||||||
run_upscale_pipeline(
|
run_upscale_pipeline(
|
||||||
|
@ -369,7 +363,6 @@ class TestUpscalePipeline(unittest.TestCase):
|
||||||
1,
|
1,
|
||||||
),
|
),
|
||||||
Size(256, 256),
|
Size(256, 256),
|
||||||
["test-upscale.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
HighresParams(False, 1, 0, 0),
|
HighresParams(False, 1, 0, 0),
|
||||||
source,
|
source,
|
||||||
|
@ -399,7 +392,7 @@ class TestBlendPipeline(unittest.TestCase):
|
||||||
3,
|
3,
|
||||||
0.1,
|
0.1,
|
||||||
)
|
)
|
||||||
worker.start("test")
|
worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {}))
|
||||||
|
|
||||||
source = Image.new("RGBA", (64, 64), "black")
|
source = Image.new("RGBA", (64, 64), "black")
|
||||||
mask = Image.new("RGBA", (64, 64), "white")
|
mask = Image.new("RGBA", (64, 64), "white")
|
||||||
|
@ -417,7 +410,6 @@ class TestBlendPipeline(unittest.TestCase):
|
||||||
unet_tile=64,
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(64, 64),
|
Size(64, 64),
|
||||||
["test-blend.png"],
|
|
||||||
UpscaleParams("test"),
|
UpscaleParams("test"),
|
||||||
[source, source],
|
[source, source],
|
||||||
mask,
|
mask,
|
||||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||||
|
|
||||||
from onnx_web.params import DeviceParams
|
from onnx_web.params import DeviceParams
|
||||||
from onnx_web.server.context import ServerContext
|
from onnx_web.server.context import ServerContext
|
||||||
|
from onnx_web.worker.command import JobStatus
|
||||||
from onnx_web.worker.pool import DevicePoolExecutor
|
from onnx_web.worker.pool import DevicePoolExecutor
|
||||||
|
|
||||||
TEST_JOIN_TIMEOUT = 0.2
|
TEST_JOIN_TIMEOUT = 0.2
|
||||||
|
@ -50,11 +51,11 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
|
|
||||||
self.pool.submit("test", wait_job, lock=lock)
|
self.pool.submit("test", "test", wait_job, lock=lock)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
self.assertTrue(self.pool.cancel("test"))
|
self.assertTrue(self.pool.cancel("test"))
|
||||||
self.assertEqual(self.pool.done("test"), (False, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
|
||||||
|
|
||||||
def test_cancel_running(self):
|
def test_cancel_running(self):
|
||||||
pass
|
pass
|
||||||
|
@ -88,12 +89,14 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool = DevicePoolExecutor(
|
self.pool = DevicePoolExecutor(
|
||||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
lock.clear()
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
self.pool.submit("test", test_job)
|
self.pool.submit("test", "test", test_job)
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
|
|
||||||
pending, _progress = self.pool.done("test")
|
status, _progress = self.pool.status("test")
|
||||||
self.assertFalse(pending)
|
self.assertEqual(status, JobStatus.RUNNING)
|
||||||
|
|
||||||
def test_done_pending(self):
|
def test_done_pending(self):
|
||||||
device = DeviceParams("cpu", "CPUProvider")
|
device = DeviceParams("cpu", "CPUProvider")
|
||||||
|
@ -102,9 +105,9 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||||
self.pool.start(lock)
|
self.pool.start(lock)
|
||||||
|
|
||||||
self.pool.submit("test1", test_job)
|
self.pool.submit("test1", "test", test_job)
|
||||||
self.pool.submit("test2", test_job)
|
self.pool.submit("test2", "test", test_job)
|
||||||
self.assertTrue(self.pool.done("test2"), (True, None))
|
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
lock.set()
|
lock.set()
|
||||||
|
|
||||||
|
@ -119,12 +122,12 @@ class TestWorkerPool(unittest.TestCase):
|
||||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||||
)
|
)
|
||||||
self.pool.start()
|
self.pool.start()
|
||||||
self.pool.submit("test", wait_job)
|
self.pool.submit("test", "test", wait_job)
|
||||||
self.assertEqual(self.pool.done("test"), (True, None))
|
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||||
|
|
||||||
sleep(5.0)
|
sleep(5.0)
|
||||||
pending, _progress = self.pool.done("test")
|
status, _progress = self.pool.status("test")
|
||||||
self.assertFalse(pending)
|
self.assertEqual(status, JobStatus.SUCCESS)
|
||||||
|
|
||||||
def test_recycle_live(self):
|
def test_recycle_live(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -40,7 +40,7 @@ class WorkerMainTests(unittest.TestCase):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_interrupt, [], {})
|
job = JobCommand("test", "test", "test", main_interrupt, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
|
@ -75,7 +75,7 @@ class WorkerMainTests(unittest.TestCase):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_retry, [], {})
|
job = JobCommand("test", "test", "test", main_retry, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
|
@ -144,7 +144,7 @@ class WorkerMainTests(unittest.TestCase):
|
||||||
nonlocal status
|
nonlocal status
|
||||||
status = exit_status
|
status = exit_status
|
||||||
|
|
||||||
job = JobCommand("test", "test", main_memory, [], {})
|
job = JobCommand("test", "test", "test", main_memory, [], {})
|
||||||
cancel = Value("L", False)
|
cancel = Value("L", False)
|
||||||
logs = Queue()
|
logs = Queue()
|
||||||
pending = Queue()
|
pending = Queue()
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you
|
onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you
|
||||||
can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and
|
can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and
|
||||||
Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and
|
Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and
|
||||||
servers with a seamless user experience.
|
multi-GPU servers with a seamless user experience.
|
||||||
|
|
||||||
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers,
|
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers,
|
||||||
including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each
|
including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each
|
||||||
|
@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features:
|
||||||
- includes both the API and GUI bundle in a single container
|
- includes both the API and GUI bundle in a single container
|
||||||
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
|
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
|
||||||
|
|
||||||
## Contents
|
|
||||||
|
|
||||||
- [onnx-web](#onnx-web)
|
|
||||||
- [Features](#features)
|
|
||||||
- [Contents](#contents)
|
|
||||||
- [Setup](#setup)
|
|
||||||
- [Adding your own models](#adding-your-own-models)
|
|
||||||
- [Usage](#usage)
|
|
||||||
- [Known errors and solutions](#known-errors-and-solutions)
|
|
||||||
- [Running the containers](#running-the-containers)
|
|
||||||
- [Credits](#credits)
|
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
|
||||||
There are a few ways to run onnx-web:
|
There are a few ways to run onnx-web:
|
||||||
|
|
|
@ -4,10 +4,7 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
||||||
import { ServerParams } from '../config.js';
|
import { ServerParams } from '../config.js';
|
||||||
import {
|
import {
|
||||||
FilterResponse,
|
FilterResponse,
|
||||||
ImageResponse,
|
|
||||||
ImageResponseWithRetry,
|
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ReadyResponse,
|
|
||||||
RetryParams,
|
RetryParams,
|
||||||
WriteExtrasResponse,
|
WriteExtrasResponse,
|
||||||
} from '../types/api.js';
|
} from '../types/api.js';
|
||||||
|
@ -27,6 +24,7 @@ import {
|
||||||
} from '../types/params.js';
|
} from '../types/params.js';
|
||||||
import { range } from '../utils.js';
|
import { range } from '../utils.js';
|
||||||
import { ApiClient } from './base.js';
|
import { ApiClient } from './base.js';
|
||||||
|
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fixed precision for integer parameters.
|
* Fixed precision for integer parameters.
|
||||||
|
@ -43,8 +41,9 @@ export const FIXED_INTEGER = 0;
|
||||||
export const FIXED_FLOAT = 2;
|
export const FIXED_FLOAT = 2;
|
||||||
export const STATUS_SUCCESS = 200;
|
export const STATUS_SUCCESS = 200;
|
||||||
|
|
||||||
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean {
|
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
|
||||||
return a.outputs === b.outputs;
|
return a.name === b.name && a.status === b.status && a.type === b.type;
|
||||||
|
// return a.outputs === b.outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -141,8 +140,8 @@ export function appendHighresToURL(url: URL, highres: HighresParams) {
|
||||||
* Make an API client using the given API root and fetch client.
|
* Make an API client using the given API root and fetch client.
|
||||||
*/
|
*/
|
||||||
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
|
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
|
||||||
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
|
function parseRequest(url: URL, options: RequestInit): Promise<JobResponse> {
|
||||||
return f(url, options).then((res) => parseApiResponse(root, res));
|
return f(url, options).then((res) => parseJobResponse(root, res));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
@ -218,7 +217,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as Array<string>;
|
return await res.json() as Array<string>;
|
||||||
},
|
},
|
||||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeImageURL(root, 'img2img', params);
|
const url = makeImageURL(root, 'img2img', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -240,12 +239,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
const body = new FormData();
|
const body = new FormData();
|
||||||
body.append('source', params.source, 'source');
|
body.append('source', params.source, 'source');
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body,
|
body,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'img2img',
|
type: 'img2img',
|
||||||
model,
|
model,
|
||||||
|
@ -254,7 +253,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeImageURL(root, 'txt2img', params);
|
const url = makeImageURL(root, 'txt2img', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -274,11 +273,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
appendHighresToURL(url, highres);
|
appendHighresToURL(url, highres);
|
||||||
}
|
}
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'txt2img',
|
type: 'txt2img',
|
||||||
model,
|
model,
|
||||||
|
@ -288,7 +287,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeImageURL(root, 'inpaint', params);
|
const url = makeImageURL(root, 'inpaint', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -309,12 +308,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
body.append('mask', params.mask, 'mask');
|
body.append('mask', params.mask, 'mask');
|
||||||
body.append('source', params.source, 'source');
|
body.append('source', params.source, 'source');
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body,
|
body,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'inpaint',
|
type: 'inpaint',
|
||||||
model,
|
model,
|
||||||
|
@ -323,7 +322,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeImageURL(root, 'inpaint', params);
|
const url = makeImageURL(root, 'inpaint', params);
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -361,12 +360,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
body.append('mask', params.mask, 'mask');
|
body.append('mask', params.mask, 'mask');
|
||||||
body.append('source', params.source, 'source');
|
body.append('source', params.source, 'source');
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body,
|
body,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'outpaint',
|
type: 'outpaint',
|
||||||
model,
|
model,
|
||||||
|
@ -375,7 +374,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiUrl(root, 'upscale');
|
const url = makeApiUrl(root, 'upscale');
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -396,12 +395,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
const body = new FormData();
|
const body = new FormData();
|
||||||
body.append('source', params.source, 'source');
|
body.append('source', params.source, 'source');
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body,
|
body,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'upscale',
|
type: 'upscale',
|
||||||
model,
|
model,
|
||||||
|
@ -410,7 +409,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
|
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry> {
|
||||||
const url = makeApiUrl(root, 'blend');
|
const url = makeApiUrl(root, 'blend');
|
||||||
appendModelToURL(url, model);
|
appendModelToURL(url, model);
|
||||||
|
|
||||||
|
@ -426,12 +425,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
body.append(name, params.sources[i], name);
|
body.append(name, params.sources[i], name);
|
||||||
}
|
}
|
||||||
|
|
||||||
const image = await parseRequest(url, {
|
const job = await parseRequest(url, {
|
||||||
body,
|
body,
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
image,
|
job,
|
||||||
retry: {
|
retry: {
|
||||||
type: 'blend',
|
type: 'blend',
|
||||||
model,
|
model,
|
||||||
|
@ -440,8 +439,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
async chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse> {
|
async chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse> {
|
||||||
const url = makeApiUrl(root, 'chain');
|
const url = makeApiUrl(root, 'job');
|
||||||
const body = JSON.stringify({
|
const body = JSON.stringify({
|
||||||
...chain,
|
...chain,
|
||||||
platform: model.platform,
|
platform: model.platform,
|
||||||
|
@ -456,23 +455,23 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
async ready(key: string): Promise<ReadyResponse> {
|
async status(keys: Array<string>): Promise<Array<JobResponse>> {
|
||||||
const path = makeApiUrl(root, 'ready');
|
const path = makeApiUrl(root, 'job', 'status');
|
||||||
path.searchParams.append('output', key);
|
path.searchParams.append('jobs', keys.join(','));
|
||||||
|
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return await res.json() as ReadyResponse;
|
return await res.json() as Array<JobResponse>;
|
||||||
},
|
},
|
||||||
async cancel(key: string): Promise<boolean> {
|
async cancel(keys: Array<string>): Promise<Array<JobResponse>> {
|
||||||
const path = makeApiUrl(root, 'cancel');
|
const path = makeApiUrl(root, 'job', 'cancel');
|
||||||
path.searchParams.append('output', key);
|
path.searchParams.append('jobs', keys.join(','));
|
||||||
|
|
||||||
const res = await f(path, {
|
const res = await f(path, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
});
|
});
|
||||||
return res.status === STATUS_SUCCESS;
|
return await res.json() as Array<JobResponse>;
|
||||||
},
|
},
|
||||||
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
|
async retry(retry: RetryParams): Promise<JobResponseWithRetry> {
|
||||||
switch (retry.type) {
|
switch (retry.type) {
|
||||||
case 'blend':
|
case 'blend':
|
||||||
return this.blend(retry.model, retry.params, retry.upscale);
|
return this.blend(retry.model, retry.params, retry.upscale);
|
||||||
|
@ -491,7 +490,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
async restart(): Promise<boolean> {
|
async restart(): Promise<boolean> {
|
||||||
const path = makeApiUrl(root, 'restart');
|
const path = makeApiUrl(root, 'worker', 'restart');
|
||||||
|
|
||||||
if (doesExist(token)) {
|
if (doesExist(token)) {
|
||||||
path.searchParams.append('token', token);
|
path.searchParams.append('token', token);
|
||||||
|
@ -502,8 +501,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
});
|
});
|
||||||
return res.status === STATUS_SUCCESS;
|
return res.status === STATUS_SUCCESS;
|
||||||
},
|
},
|
||||||
async status(): Promise<Array<unknown>> {
|
async workers(): Promise<Array<unknown>> {
|
||||||
const path = makeApiUrl(root, 'status');
|
const path = makeApiUrl(root, 'worker', 'status');
|
||||||
|
|
||||||
if (doesExist(token)) {
|
if (doesExist(token)) {
|
||||||
path.searchParams.append('token', token);
|
path.searchParams.append('token', token);
|
||||||
|
@ -512,6 +511,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
const res = await f(path);
|
const res = await f(path);
|
||||||
return res.json();
|
return res.json();
|
||||||
},
|
},
|
||||||
|
outputURL(image: SuccessJobResponse, index: number): string {
|
||||||
|
return new URL(joinPath('output', image.outputs[index]), root).toString();
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -521,24 +523,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
||||||
* The server sends over the output key, and the client is in the best position to turn
|
* The server sends over the output key, and the client is in the best position to turn
|
||||||
* that into a full URL, since it already knows the root URL of the server.
|
* that into a full URL, since it already knows the root URL of the server.
|
||||||
*/
|
*/
|
||||||
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
|
export async function parseJobResponse(root: string, res: Response): Promise<JobResponse> {
|
||||||
type LimitedResponse = Omit<ImageResponse, 'outputs'> & { outputs: Array<string> };
|
|
||||||
|
|
||||||
if (res.status === STATUS_SUCCESS) {
|
if (res.status === STATUS_SUCCESS) {
|
||||||
const data = await res.json() as LimitedResponse;
|
return await res.json() as JobResponse;
|
||||||
|
|
||||||
const outputs = data.outputs.map((output) => {
|
|
||||||
const url = new URL(joinPath('output', output), root).toString();
|
|
||||||
return {
|
|
||||||
key: output,
|
|
||||||
url,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
...data,
|
|
||||||
outputs,
|
|
||||||
};
|
|
||||||
} else {
|
} else {
|
||||||
throw new Error('request error');
|
throw new Error('request error');
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
import { ServerParams } from '../config.js';
|
import { ServerParams } from '../config.js';
|
||||||
import { ExtrasFile } from '../types/model.js';
|
import { ExtrasFile } from '../types/model.js';
|
||||||
import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
|
||||||
import { ChainPipeline } from '../types/chain.js';
|
import { ChainPipeline } from '../types/chain.js';
|
||||||
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
|
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
|
||||||
|
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
|
||||||
|
|
||||||
export interface ApiClient {
|
export interface ApiClient {
|
||||||
|
/**
|
||||||
|
* Get the first extras file.
|
||||||
|
*/
|
||||||
extras(): Promise<ExtrasFile>;
|
extras(): Promise<ExtrasFile>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the first extras file.
|
||||||
|
*/
|
||||||
writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;
|
writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -51,54 +58,60 @@ export interface ApiClient {
|
||||||
translation: Record<string, string>;
|
translation: Record<string, string>;
|
||||||
}>>;
|
}>>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the available wildcards.
|
||||||
|
*/
|
||||||
wildcards(): Promise<Array<string>>;
|
wildcards(): Promise<Array<string>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start a txt2img pipeline.
|
* Start a txt2img pipeline.
|
||||||
*/
|
*/
|
||||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an im2img pipeline.
|
* Start an im2img pipeline.
|
||||||
*/
|
*/
|
||||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an inpaint pipeline.
|
* Start an inpaint pipeline.
|
||||||
*/
|
*/
|
||||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an outpaint pipeline.
|
* Start an outpaint pipeline.
|
||||||
*/
|
*/
|
||||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start an upscale pipeline.
|
* Start an upscale pipeline.
|
||||||
*/
|
*/
|
||||||
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Start a blending pipeline.
|
* Start a blending pipeline.
|
||||||
*/
|
*/
|
||||||
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse>;
|
/**
|
||||||
|
* Start a custom chain pipeline.
|
||||||
|
*/
|
||||||
|
chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check whether job has finished and its output is ready.
|
* Check whether job has finished and its output is ready.
|
||||||
*/
|
*/
|
||||||
ready(key: string): Promise<ReadyResponse>;
|
status(keys: Array<string>): Promise<Array<JobResponse>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cancel an existing job.
|
* Cancel an existing job.
|
||||||
*/
|
*/
|
||||||
cancel(key: string): Promise<boolean>;
|
cancel(keys: Array<string>): Promise<Array<JobResponse>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retry a previous job using the same parameters.
|
* Retry a previous job using the same parameters.
|
||||||
*/
|
*/
|
||||||
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
|
retry(params: RetryParams): Promise<JobResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Restart the image job workers.
|
* Restart the image job workers.
|
||||||
|
@ -108,5 +121,7 @@ export interface ApiClient {
|
||||||
/**
|
/**
|
||||||
* Check the status of the image job workers.
|
* Check the status of the image job workers.
|
||||||
*/
|
*/
|
||||||
status(): Promise<Array<unknown>>;
|
workers(): Promise<Array<unknown>>;
|
||||||
|
|
||||||
|
outputURL(image: SuccessJobResponse, index: number): string;
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ export const LOCAL_CLIENT = {
|
||||||
async params() {
|
async params() {
|
||||||
throw new NoServerError();
|
throw new NoServerError();
|
||||||
},
|
},
|
||||||
async ready(key) {
|
async status(key) {
|
||||||
throw new NoServerError();
|
throw new NoServerError();
|
||||||
},
|
},
|
||||||
async cancel(key) {
|
async cancel(key) {
|
||||||
|
@ -78,7 +78,10 @@ export const LOCAL_CLIENT = {
|
||||||
async restart() {
|
async restart() {
|
||||||
throw new NoServerError();
|
throw new NoServerError();
|
||||||
},
|
},
|
||||||
async status() {
|
async workers() {
|
||||||
throw new NoServerError();
|
throw new NoServerError();
|
||||||
}
|
},
|
||||||
|
outputURL(image, index) {
|
||||||
|
throw new NoServerError();
|
||||||
|
},
|
||||||
} as ApiClient;
|
} as ApiClient;
|
||||||
|
|
|
@ -97,11 +97,19 @@ export function expandRanges(range: string): Array<string | number> {
|
||||||
export const GRID_TILE_SIZE = 8192;
|
export const GRID_TILE_SIZE = 8192;
|
||||||
|
|
||||||
// eslint-disable-next-line max-params
|
// eslint-disable-next-line max-params
|
||||||
export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
|
export function makeTxt2ImgGridPipeline(
|
||||||
|
grid: PipelineGrid,
|
||||||
|
model: ModelParams,
|
||||||
|
params: Txt2ImgParams,
|
||||||
|
upscale?: UpscaleParams,
|
||||||
|
highres?: HighresParams,
|
||||||
|
): ChainPipeline {
|
||||||
const pipeline: ChainPipeline = {
|
const pipeline: ChainPipeline = {
|
||||||
defaults: {
|
defaults: {
|
||||||
...model,
|
...model,
|
||||||
...params,
|
...params,
|
||||||
|
...(upscale || {}),
|
||||||
|
...(highres || {}),
|
||||||
},
|
},
|
||||||
stages: [],
|
stages: [],
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,6 +10,7 @@ import { OnnxState, StateContext } from '../state/full.js';
|
||||||
import { ErrorCard } from './card/ErrorCard.js';
|
import { ErrorCard } from './card/ErrorCard.js';
|
||||||
import { ImageCard } from './card/ImageCard.js';
|
import { ImageCard } from './card/ImageCard.js';
|
||||||
import { LoadingCard } from './card/LoadingCard.js';
|
import { LoadingCard } from './card/LoadingCard.js';
|
||||||
|
import { JobStatus } from '../types/api-v2.js';
|
||||||
|
|
||||||
export function ImageHistory() {
|
export function ImageHistory() {
|
||||||
const store = mustExist(useContext(StateContext));
|
const store = mustExist(useContext(StateContext));
|
||||||
|
@ -25,19 +26,19 @@ export function ImageHistory() {
|
||||||
|
|
||||||
const limited = history.slice(0, limit);
|
const limited = history.slice(0, limit);
|
||||||
for (const item of limited) {
|
for (const item of limited) {
|
||||||
const key = item.image.outputs[0].key;
|
const key = item.image.name;
|
||||||
|
|
||||||
if (doesExist(item.ready) && item.ready.ready) {
|
|
||||||
if (item.ready.cancelled || item.ready.failed) {
|
|
||||||
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
switch (item.image.status) {
|
||||||
|
case JobStatus.SUCCESS:
|
||||||
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
|
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
|
||||||
continue;
|
break;
|
||||||
|
case JobStatus.FAILED:
|
||||||
|
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} retry={item.retry} />]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
children.push([key, <LoadingCard key={`history-${key}`} image={item.image} />]);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
children.push([key, <LoadingCard key={`history-${key}`} index={0} image={item.image} />]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;
|
return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;
|
||||||
|
|
|
@ -10,16 +10,15 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
|
import { FailedJobResponse, RetryParams } from '../../types/api-v2.js';
|
||||||
|
|
||||||
export interface ErrorCardProps {
|
export interface ErrorCardProps {
|
||||||
image: ImageResponse;
|
image: FailedJobResponse;
|
||||||
ready: ReadyResponse;
|
|
||||||
retry: Maybe<RetryParams>;
|
retry: Maybe<RetryParams>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ErrorCard(props: ErrorCardProps) {
|
export function ErrorCard(props: ErrorCardProps) {
|
||||||
const { image, ready, retry: retryParams } = props;
|
const { image, retry: retryParams } = props;
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
const { params } = mustExist(useContext(ConfigContext));
|
const { params } = mustExist(useContext(ConfigContext));
|
||||||
|
@ -32,8 +31,8 @@ export function ErrorCard(props: ErrorCardProps) {
|
||||||
removeHistory(image);
|
removeHistory(image);
|
||||||
|
|
||||||
if (doesExist(retryParams)) {
|
if (doesExist(retryParams)) {
|
||||||
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
|
const { job: nextJob, retry: nextRetry } = await client.retry(retryParams);
|
||||||
pushHistory(nextImage, nextRetry);
|
pushHistory(nextJob, nextRetry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,10 +51,11 @@ export function ErrorCard(props: ErrorCardProps) {
|
||||||
spacing={2}
|
spacing={2}
|
||||||
sx={{ alignItems: 'center' }}
|
sx={{ alignItems: 'center' }}
|
||||||
>
|
>
|
||||||
<Alert severity='error'>{t('loading.progress', {
|
<Alert severity='error'>
|
||||||
current: ready.progress,
|
{t('loading.progress', image.steps)}
|
||||||
total: image.params.steps,
|
<br />
|
||||||
})}</Alert>
|
{image.error}
|
||||||
|
</Alert>
|
||||||
<Stack direction='row' spacing={2}>
|
<Stack direction='row' spacing={2}>
|
||||||
<Tooltip title={t('tooltip.retry')}>
|
<Tooltip title={t('tooltip.retry')}>
|
||||||
<IconButton onClick={() => retry.mutate()}>
|
<IconButton onClick={() => retry.mutate()}>
|
||||||
|
|
|
@ -2,21 +2,22 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'
|
||||||
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
|
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
|
||||||
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
|
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useContext, useState } from 'react';
|
import { useContext, useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useHash } from 'react-use/lib/useHash';
|
import { useHash } from 'react-use/lib/useHash';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse } from '../../types/api.js';
|
|
||||||
import { range, visibleIndex } from '../../utils.js';
|
import { range, visibleIndex } from '../../utils.js';
|
||||||
import { BLEND_SOURCES } from '../../constants.js';
|
import { BLEND_SOURCES } from '../../constants.js';
|
||||||
|
import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
|
||||||
|
import { getApiRoot } from '../../config.js';
|
||||||
|
|
||||||
export interface ImageCardProps {
|
export interface ImageCardProps {
|
||||||
image: ImageResponse;
|
image: SuccessJobResponse;
|
||||||
|
|
||||||
onDelete?: (key: ImageResponse) => void;
|
onDelete?: (key: JobResponse) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
||||||
|
@ -27,18 +28,19 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
||||||
|
|
||||||
export function ImageCard(props: ImageCardProps) {
|
export function ImageCard(props: ImageCardProps) {
|
||||||
const { image } = props;
|
const { image } = props;
|
||||||
const { params, outputs, size } = image;
|
const { metadata, outputs } = image;
|
||||||
|
|
||||||
const [_hash, setHash] = useHash();
|
const [_hash, setHash] = useHash();
|
||||||
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
|
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
|
||||||
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
|
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
|
||||||
|
|
||||||
|
const client = mustExist(useContext(ClientContext));
|
||||||
const config = mustExist(useContext(ConfigContext));
|
const config = mustExist(useContext(ConfigContext));
|
||||||
const store = mustExist(useContext(StateContext));
|
const store = mustExist(useContext(StateContext));
|
||||||
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
|
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
|
||||||
|
|
||||||
async function loadSource() {
|
async function loadSource() {
|
||||||
const req = await fetch(outputs[index].url);
|
const req = await fetch(url);
|
||||||
return req.blob();
|
return req.blob();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,12 +86,12 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
function downloadImage() {
|
function downloadImage() {
|
||||||
window.open(outputs[index].url, '_blank');
|
window.open(url, '_blank');
|
||||||
close();
|
close();
|
||||||
}
|
}
|
||||||
|
|
||||||
function downloadMetadata() {
|
function downloadMetadata() {
|
||||||
window.open(outputs[index].url + '.json', '_blank');
|
window.open(url + '.json', '_blank');
|
||||||
close();
|
close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,14 +108,16 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
return mustDefault(t(`${key}.${name}`), name);
|
return mustDefault(t(`${key}.${name}`), name);
|
||||||
}
|
}
|
||||||
|
|
||||||
const model = getLabel('model', params.model);
|
const url = useMemo(() => client.outputURL(image, index), [image, index]);
|
||||||
const scheduler = getLabel('scheduler', params.scheduler);
|
|
||||||
|
const model = getLabel('model', metadata[index].model);
|
||||||
|
const scheduler = getLabel('scheduler', metadata[index].scheduler);
|
||||||
|
|
||||||
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
|
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
|
||||||
<CardMedia sx={{ height: config.params.height.default }}
|
<CardMedia sx={{ height: config.params.height.default }}
|
||||||
component='img'
|
component='img'
|
||||||
image={outputs[index].url}
|
image={url}
|
||||||
title={params.prompt}
|
title={metadata[index].params.prompt}
|
||||||
/>
|
/>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
<Box textAlign='center'>
|
<Box textAlign='center'>
|
||||||
|
@ -146,12 +150,12 @@ export function ImageCard(props: ImageCardProps) {
|
||||||
</GridItem>
|
</GridItem>
|
||||||
<GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
|
<GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
|
||||||
<GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
|
<GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
|
||||||
<GridItem xs={4}>{t('parameter.seed')}: {params.seed}</GridItem>
|
<GridItem xs={4}>{t('parameter.seed')}: {metadata[index].params.seed}</GridItem>
|
||||||
<GridItem xs={4}>{t('parameter.cfg')}: {params.cfg}</GridItem>
|
<GridItem xs={4}>{t('parameter.cfg')}: {metadata[index].params.cfg}</GridItem>
|
||||||
<GridItem xs={4}>{t('parameter.steps')}: {params.steps}</GridItem>
|
<GridItem xs={4}>{t('parameter.steps')}: {metadata[index].params.steps}</GridItem>
|
||||||
<GridItem xs={4}>{t('parameter.size')}: {size.width}x{size.height}</GridItem>
|
<GridItem xs={4}>{t('parameter.size')}: {metadata[index].size.width}x{metadata[index].size.height}</GridItem>
|
||||||
<GridItem xs={12}>
|
<GridItem xs={12}>
|
||||||
<Box textAlign='left'>{params.prompt}</Box>
|
<Box textAlign='left'>{metadata[index].params.prompt}</Box>
|
||||||
</GridItem>
|
</GridItem>
|
||||||
<GridItem xs={2}>
|
<GridItem xs={2}>
|
||||||
<Tooltip title={t('tooltip.save')}>
|
<Tooltip title={t('tooltip.save')}>
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
|
||||||
import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
|
import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
|
||||||
import { Stack } from '@mui/system';
|
import { Stack } from '@mui/system';
|
||||||
import { useMutation, useQuery } from '@tanstack/react-query';
|
import { useMutation, useQuery } from '@tanstack/react-query';
|
||||||
|
@ -10,19 +10,17 @@ import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { POLL_TIME } from '../../config.js';
|
import { POLL_TIME } from '../../config.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse } from '../../types/api.js';
|
import { JobResponse, JobStatus } from '../../types/api-v2.js';
|
||||||
|
|
||||||
const LOADING_PERCENT = 100;
|
const LOADING_PERCENT = 100;
|
||||||
const LOADING_OVERAGE = 99;
|
const LOADING_OVERAGE = 99;
|
||||||
|
|
||||||
export interface LoadingCardProps {
|
export interface LoadingCardProps {
|
||||||
image: ImageResponse;
|
image: JobResponse;
|
||||||
index: number;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function LoadingCard(props: LoadingCardProps) {
|
export function LoadingCard(props: LoadingCardProps) {
|
||||||
const { image, index } = props;
|
const { image } = props;
|
||||||
const { steps } = props.image.params;
|
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
const { params } = mustExist(useContext(ConfigContext));
|
const { params } = mustExist(useContext(ConfigContext));
|
||||||
|
@ -31,50 +29,22 @@ export function LoadingCard(props: LoadingCardProps) {
|
||||||
const { removeHistory, setReady } = useStore(store, selectActions, shallow);
|
const { removeHistory, setReady } = useStore(store, selectActions, shallow);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const cancel = useMutation(() => client.cancel(image.outputs[index].key));
|
const cancel = useMutation(() => client.cancel([image.name]));
|
||||||
const ready = useQuery(['ready', image.outputs[index].key], () => client.ready(image.outputs[index].key), {
|
const ready = useQuery(['ready', image.name], () => client.status([image.name]), {
|
||||||
// data will always be ready without this, even if the API says its not
|
// data will always be ready without this, even if the API says its not
|
||||||
cacheTime: 0,
|
cacheTime: 0,
|
||||||
refetchInterval: POLL_TIME,
|
refetchInterval: POLL_TIME,
|
||||||
});
|
});
|
||||||
|
|
||||||
function getProgress() {
|
|
||||||
if (doesExist(ready.data)) {
|
|
||||||
return ready.data.progress;
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
function getPercent() {
|
|
||||||
const progress = getProgress();
|
|
||||||
if (progress > steps) {
|
|
||||||
// steps was not complete, show 99% until done
|
|
||||||
return LOADING_OVERAGE;
|
|
||||||
}
|
|
||||||
|
|
||||||
const pct = progress / steps;
|
|
||||||
return Math.ceil(pct * LOADING_PERCENT);
|
|
||||||
}
|
|
||||||
|
|
||||||
function getTotal() {
|
|
||||||
const progress = getProgress();
|
|
||||||
if (progress > steps) {
|
|
||||||
// steps was not complete, show 99% until done
|
|
||||||
return t('loading.unknown');
|
|
||||||
}
|
|
||||||
|
|
||||||
return steps.toFixed(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
function getReady() {
|
function getReady() {
|
||||||
return doesExist(ready.data) && ready.data.ready;
|
return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderProgress() {
|
function renderProgress() {
|
||||||
const progress = getProgress();
|
const progress = getProgress(ready.data);
|
||||||
if (progress > 0 && progress <= steps) {
|
const total = getTotal(ready.data);
|
||||||
return <CircularProgress variant='determinate' value={getPercent()} />;
|
if (progress > 0 && progress <= total) {
|
||||||
|
return <CircularProgress variant='determinate' value={getPercent(progress, total)} />;
|
||||||
} else {
|
} else {
|
||||||
return <CircularProgress />;
|
return <CircularProgress />;
|
||||||
}
|
}
|
||||||
|
@ -88,9 +58,9 @@ export function LoadingCard(props: LoadingCardProps) {
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (ready.status === 'success' && getReady()) {
|
if (ready.status === 'success' && getReady()) {
|
||||||
setReady(props.image, ready.data);
|
setReady(ready.data[0]);
|
||||||
}
|
}
|
||||||
}, [ready.status, getReady(), getProgress()]);
|
}, [ready.status, getReady(), getProgress(ready.data)]);
|
||||||
|
|
||||||
return <Card sx={{ maxWidth: params.width.default }}>
|
return <Card sx={{ maxWidth: params.width.default }}>
|
||||||
<CardContent sx={{ height: params.height.default }}>
|
<CardContent sx={{ height: params.height.default }}>
|
||||||
|
@ -106,10 +76,7 @@ export function LoadingCard(props: LoadingCardProps) {
|
||||||
sx={{ alignItems: 'center' }}
|
sx={{ alignItems: 'center' }}
|
||||||
>
|
>
|
||||||
{renderProgress()}
|
{renderProgress()}
|
||||||
<Typography>{t('loading.progress', {
|
<Typography>{t('loading.progress', selectStatus(ready.data, image))}</Typography>
|
||||||
current: getProgress(),
|
|
||||||
total: getTotal(),
|
|
||||||
})}</Typography>
|
|
||||||
<Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button>
|
<Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -125,3 +92,45 @@ export function selectActions(state: OnnxState) {
|
||||||
setReady: state.setReady,
|
setReady: state.setReady,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function selectStatus(data: Maybe<Array<JobResponse>>, defaultData: JobResponse) {
|
||||||
|
if (doesExist(data) && data.length > 0) {
|
||||||
|
return {
|
||||||
|
steps: data[0].steps,
|
||||||
|
stages: data[0].stages,
|
||||||
|
tiles: data[0].tiles,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
steps: defaultData.steps,
|
||||||
|
stages: defaultData.stages,
|
||||||
|
tiles: defaultData.tiles,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getPercent(current: number, total: number): number {
|
||||||
|
if (current > total) {
|
||||||
|
// steps was not complete, show 99% until done
|
||||||
|
return LOADING_OVERAGE;
|
||||||
|
}
|
||||||
|
|
||||||
|
const pct = current / total;
|
||||||
|
return Math.ceil(pct * LOADING_PERCENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getProgress(data: Maybe<Array<JobResponse>>) {
|
||||||
|
if (doesExist(data)) {
|
||||||
|
return data[0].steps.current;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getTotal(data: Maybe<Array<JobResponse>>) {
|
||||||
|
if (doesExist(data)) {
|
||||||
|
return data[0].steps.total;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
|
@ -20,13 +20,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
|
||||||
export function Blend() {
|
export function Blend() {
|
||||||
async function uploadSource() {
|
async function uploadSource() {
|
||||||
const { blend, blendModel, blendUpscale } = store.getState();
|
const { blend, blendModel, blendUpscale } = store.getState();
|
||||||
const { image, retry } = await client.blend(blendModel, {
|
const { job, retry } = await client.blend(blendModel, {
|
||||||
...blend,
|
...blend,
|
||||||
mask: mustExist(blend.mask),
|
mask: mustExist(blend.mask),
|
||||||
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
|
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
|
||||||
}, blendUpscale);
|
}, blendUpscale);
|
||||||
|
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
|
|
@ -27,12 +27,12 @@ export function Img2Img() {
|
||||||
const state = store.getState();
|
const state = store.getState();
|
||||||
const img2img = selectParams(state);
|
const img2img = selectParams(state);
|
||||||
|
|
||||||
const { image, retry } = await client.img2img(model, {
|
const { job, retry } = await client.img2img(model, {
|
||||||
...img2img,
|
...img2img,
|
||||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state));
|
||||||
|
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
|
|
@ -39,22 +39,22 @@ export function Inpaint() {
|
||||||
const inpaint = selectParams(state);
|
const inpaint = selectParams(state);
|
||||||
|
|
||||||
if (outpaint.enabled) {
|
if (outpaint.enabled) {
|
||||||
const { image, retry } = await client.outpaint(model, {
|
const { job, retry } = await client.outpaint(model, {
|
||||||
...inpaint,
|
...inpaint,
|
||||||
...outpaint,
|
...outpaint,
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state));
|
||||||
|
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
} else {
|
} else {
|
||||||
const { image, retry } = await client.inpaint(model, {
|
const { job, retry } = await client.inpaint(model, {
|
||||||
...inpaint,
|
...inpaint,
|
||||||
mask: mustExist(mask),
|
mask: mustExist(mask),
|
||||||
source: mustExist(source),
|
source: mustExist(source),
|
||||||
}, selectUpscale(state), selectHighres(state));
|
}, selectUpscale(state), selectHighres(state));
|
||||||
|
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -69,8 +69,8 @@ export function Txt2Img() {
|
||||||
const image = await client.chain(model, chain);
|
const image = await client.chain(model, chain);
|
||||||
pushHistory(image);
|
pushHistory(image);
|
||||||
} else {
|
} else {
|
||||||
const { image, retry } = await client.txt2img(model, params2, upscale, highres);
|
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,12 +21,12 @@ import { PromptInput } from '../input/PromptInput.js';
|
||||||
export function Upscale() {
|
export function Upscale() {
|
||||||
async function uploadSource() {
|
async function uploadSource() {
|
||||||
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
|
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
|
||||||
const { image, retry } = await client.upscale(upscaleModel, {
|
const { job, retry } = await client.upscale(upscaleModel, {
|
||||||
...upscale,
|
...upscale,
|
||||||
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
|
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
|
||||||
}, upscaleUpscale, upscaleHighres);
|
}, upscaleUpscale, upscaleHighres);
|
||||||
|
|
||||||
pushHistory(image, retry);
|
pushHistory(job, retry);
|
||||||
}
|
}
|
||||||
|
|
||||||
const client = mustExist(useContext(ClientContext));
|
const client = mustExist(useContext(ClientContext));
|
||||||
|
|
|
@ -2,6 +2,7 @@ import { Maybe } from '@apextoaster/js-utils';
|
||||||
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
||||||
import { Slice } from './types.js';
|
import { Slice } from './types.js';
|
||||||
import { DEFAULT_HISTORY } from '../constants.js';
|
import { DEFAULT_HISTORY } from '../constants.js';
|
||||||
|
import { JobResponse } from '../types/api-v2.js';
|
||||||
|
|
||||||
export interface HistoryItem {
|
export interface HistoryItem {
|
||||||
image: ImageResponse;
|
image: ImageResponse;
|
||||||
|
@ -9,14 +10,19 @@ export interface HistoryItem {
|
||||||
retry: Maybe<RetryParams>;
|
retry: Maybe<RetryParams>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface HistoryItemV2 {
|
||||||
|
image: JobResponse;
|
||||||
|
retry: Maybe<RetryParams>;
|
||||||
|
}
|
||||||
|
|
||||||
export interface HistorySlice {
|
export interface HistorySlice {
|
||||||
history: Array<HistoryItem>;
|
history: Array<HistoryItemV2>;
|
||||||
limit: number;
|
limit: number;
|
||||||
|
|
||||||
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
pushHistory(image: JobResponse, retry?: RetryParams): void;
|
||||||
removeHistory(image: ImageResponse): void;
|
removeHistory(image: JobResponse): void;
|
||||||
setLimit(limit: number): void;
|
setLimit(limit: number): void;
|
||||||
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
setReady(image: JobResponse): void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
|
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
|
||||||
|
@ -39,7 +45,7 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
|
||||||
removeHistory(image) {
|
removeHistory(image) {
|
||||||
set((prev) => ({
|
set((prev) => ({
|
||||||
...prev,
|
...prev,
|
||||||
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
|
history: prev.history.filter((it) => it.image.name !== image.name),
|
||||||
}));
|
}));
|
||||||
},
|
},
|
||||||
setLimit(limit) {
|
setLimit(limit) {
|
||||||
|
@ -48,12 +54,12 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
|
||||||
limit,
|
limit,
|
||||||
}));
|
}));
|
||||||
},
|
},
|
||||||
setReady(image, ready) {
|
setReady(image) {
|
||||||
set((prev) => {
|
set((prev) => {
|
||||||
const history = [...prev.history];
|
const history = [...prev.history];
|
||||||
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
|
const idx = history.findIndex((it) => it.image.name === image.name);
|
||||||
if (idx >= 0) {
|
if (idx >= 0) {
|
||||||
history[idx].ready = ready;
|
history[idx].image = image;
|
||||||
} else {
|
} else {
|
||||||
// TODO: error
|
// TODO: error
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ export const I18N_STRINGS_EN = {
|
||||||
},
|
},
|
||||||
loading: {
|
loading: {
|
||||||
cancel: 'Cancel',
|
cancel: 'Cancel',
|
||||||
progress: '{{current}} of {{total}} steps',
|
progress: '{{steps.current}} of {{steps.total}} steps, {{tiles.current}} of {{tiles.total}} tiles, {{stages.current}} of {{stages.total}} stages',
|
||||||
server: 'Connecting to server...',
|
server: 'Connecting to server...',
|
||||||
unknown: 'many',
|
unknown: 'many',
|
||||||
},
|
},
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
import { RetryParams } from './api.js';
|
||||||
|
import { BaseImgParams, HighresParams, Img2ImgParams, InpaintParams, Txt2ImgParams, UpscaleParams } from './params.js';
|
||||||
|
|
||||||
|
export interface Progress {
|
||||||
|
current: number;
|
||||||
|
total: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface Size {
|
||||||
|
width: number;
|
||||||
|
height: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface NetworkMetadata {
|
||||||
|
name: string;
|
||||||
|
hash: string;
|
||||||
|
weight: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
|
||||||
|
input_size: Size;
|
||||||
|
size: Size;
|
||||||
|
outputs: Array<string>;
|
||||||
|
params: TParams;
|
||||||
|
inversions: Array<NetworkMetadata>;
|
||||||
|
loras: Array<NetworkMetadata>;
|
||||||
|
model: string;
|
||||||
|
scheduler: string;
|
||||||
|
border: unknown;
|
||||||
|
highres: HighresParams;
|
||||||
|
upscale: UpscaleParams;
|
||||||
|
type: TType;
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum JobStatus {
|
||||||
|
PENDING = 'pending',
|
||||||
|
RUNNING = 'running',
|
||||||
|
SUCCESS = 'success',
|
||||||
|
FAILED = 'failed',
|
||||||
|
CANCELLED = 'cancelled',
|
||||||
|
UNKNOWN = 'unknown',
|
||||||
|
}
|
||||||
|
|
||||||
|
export enum JobType {
|
||||||
|
TXT2IMG = 'txt2img',
|
||||||
|
IMG2IMG = 'img2img',
|
||||||
|
INPAINT = 'inpaint',
|
||||||
|
UPSCALE = 'upscale',
|
||||||
|
BLEND = 'blend',
|
||||||
|
CHAIN = 'chain',
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BaseJobResponse {
|
||||||
|
name: string;
|
||||||
|
status: JobStatus;
|
||||||
|
type: JobType;
|
||||||
|
|
||||||
|
stages: Progress;
|
||||||
|
steps: Progress;
|
||||||
|
tiles: Progress;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Pending image job.
|
||||||
|
*/
|
||||||
|
export interface PendingJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.PENDING | JobStatus.RUNNING;
|
||||||
|
queue: Progress;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Failed image job with error information.
|
||||||
|
*/
|
||||||
|
export interface FailedJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.FAILED;
|
||||||
|
error: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful txt2img image job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessTxt2ImgJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.TXT2IMG;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful img2img job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessImg2ImgJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.IMG2IMG;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful inpaint job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessInpaintJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.INPAINT;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful upscale job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessUpscaleJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.UPSCALE;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful blend job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessBlendJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.BLEND;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Successful chain pipeline job with output keys and metadata.
|
||||||
|
*/
|
||||||
|
export interface SuccessChainJobResponse extends BaseJobResponse {
|
||||||
|
status: JobStatus.SUCCESS;
|
||||||
|
type: JobType.CHAIN;
|
||||||
|
outputs: Array<string>;
|
||||||
|
metadata: Array<ImageMetadata<BaseImgParams, JobType>>; // TODO: could be all kinds
|
||||||
|
}
|
||||||
|
|
||||||
|
export type SuccessJobResponse
|
||||||
|
= SuccessTxt2ImgJobResponse
|
||||||
|
| SuccessImg2ImgJobResponse
|
||||||
|
| SuccessInpaintJobResponse
|
||||||
|
| SuccessUpscaleJobResponse
|
||||||
|
| SuccessBlendJobResponse
|
||||||
|
| SuccessChainJobResponse;
|
||||||
|
|
||||||
|
export type JobResponse = PendingJobResponse | FailedJobResponse | SuccessJobResponse;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Status response from the job endpoint, with parameters to retry the job if it fails.
|
||||||
|
*/
|
||||||
|
export interface JobResponseWithRetry {
|
||||||
|
job: JobResponse;
|
||||||
|
retry: RetryParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Re-export `RetryParams` for convenience.
|
||||||
|
*/
|
||||||
|
export { RetryParams };
|
|
@ -14,6 +14,8 @@ import {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Output image data within the response.
|
* Output image data within the response.
|
||||||
|
*
|
||||||
|
* @deprecated
|
||||||
*/
|
*/
|
||||||
export interface ImageOutput {
|
export interface ImageOutput {
|
||||||
key: string;
|
key: string;
|
||||||
|
@ -22,6 +24,8 @@ export interface ImageOutput {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* General response for most image requests.
|
* General response for most image requests.
|
||||||
|
*
|
||||||
|
* @deprecated
|
||||||
*/
|
*/
|
||||||
export interface ImageResponse {
|
export interface ImageResponse {
|
||||||
outputs: Array<ImageOutput>;
|
outputs: Array<ImageOutput>;
|
||||||
|
@ -119,11 +123,19 @@ export type RetryParams = {
|
||||||
upscale?: UpscaleParams;
|
upscale?: UpscaleParams;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Status response from the image endpoint, with parameters to retry the job if it fails.
|
||||||
|
*
|
||||||
|
* @deprecated
|
||||||
|
*/
|
||||||
export interface ImageResponseWithRetry {
|
export interface ImageResponseWithRetry {
|
||||||
image: ImageResponse;
|
image: ImageResponse;
|
||||||
retry: RetryParams;
|
retry: RetryParams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated
|
||||||
|
*/
|
||||||
export interface ImageMetadata {
|
export interface ImageMetadata {
|
||||||
highres: HighresParams;
|
highres: HighresParams;
|
||||||
outputs: string | Array<string>;
|
outputs: string | Array<string>;
|
||||||
|
|
|
@ -43,6 +43,7 @@
|
||||||
"dtype",
|
"dtype",
|
||||||
"ESRGAN",
|
"ESRGAN",
|
||||||
"Exif",
|
"Exif",
|
||||||
|
"fromarray",
|
||||||
"ftfy",
|
"ftfy",
|
||||||
"gfpgan",
|
"gfpgan",
|
||||||
"Heun",
|
"Heun",
|
||||||
|
@ -115,6 +116,7 @@
|
||||||
"webp",
|
"webp",
|
||||||
"xformers",
|
"xformers",
|
||||||
"zustand"
|
"zustand"
|
||||||
]
|
],
|
||||||
|
"git.ignoreLimitWarning": true
|
||||||
}
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue