feat: add batch endpoints for cancel and status, update responses
This commit is contained in:
parent
19c91f70f5
commit
44a8d61082
|
@ -36,15 +36,14 @@ class CorrectCodeformerStage(BaseStage):
|
|||
# https://pypi.org/project/codeformer-perceptor/
|
||||
|
||||
# import must be within the load function for patches to take effect
|
||||
# TODO: rewrite and remove
|
||||
from codeformer.basicsr.archs.codeformer_arch import CodeFormer
|
||||
from codeformer.basicsr.utils import img2tensor, tensor2img
|
||||
from codeformer.basicsr.utils.registry import ARCH_REGISTRY
|
||||
from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
upscale = upscale.with_args(**kwargs)
|
||||
device = worker.get_device()
|
||||
|
||||
net = ARCH_REGISTRY.get("CodeFormer")(
|
||||
net = CodeFormer(
|
||||
dim_embd=512,
|
||||
codebook_size=1024,
|
||||
n_head=8,
|
||||
|
|
|
@ -1,10 +1,28 @@
|
|||
from typing import Any, List, Optional, Tuple
|
||||
from json import dumps
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..output import json_params
|
||||
from ..convert.utils import resolve_tensor
|
||||
from ..params import Border, HighresParams, ImageParams, Size, UpscaleParams
|
||||
from ..server.load import get_extra_hashes
|
||||
from ..utils import hash_file
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class NetworkMetadata:
|
||||
name: str
|
||||
hash: str
|
||||
weight: float
|
||||
|
||||
def __init__(self, name: str, hash: str, weight: float) -> None:
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.weight = weight
|
||||
|
||||
|
||||
class ImageMetadata:
|
||||
|
@ -13,8 +31,9 @@ class ImageMetadata:
|
|||
params: ImageParams
|
||||
size: Size
|
||||
upscale: UpscaleParams
|
||||
inversions: Optional[List[Tuple[str, float]]]
|
||||
loras: Optional[List[Tuple[str, float]]]
|
||||
inversions: Optional[List[NetworkMetadata]]
|
||||
loras: Optional[List[NetworkMetadata]]
|
||||
models: Optional[List[NetworkMetadata]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -23,8 +42,9 @@ class ImageMetadata:
|
|||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
inversions: Optional[List[NetworkMetadata]] = None,
|
||||
loras: Optional[List[NetworkMetadata]] = None,
|
||||
models: Optional[List[NetworkMetadata]] = None,
|
||||
) -> None:
|
||||
self.params = params
|
||||
self.size = size
|
||||
|
@ -33,19 +53,108 @@ class ImageMetadata:
|
|||
self.highres = highres
|
||||
self.inversions = inversions
|
||||
self.loras = loras
|
||||
self.models = models
|
||||
|
||||
def to_auto1111(self, server, outputs) -> str:
|
||||
model_name = path.basename(path.normpath(self.params.model))
|
||||
logger.debug("getting model hash for %s", model_name)
|
||||
|
||||
model_hash = get_extra_hashes().get(model_name, None)
|
||||
if model_hash is None:
|
||||
model_hash_path = path.join(self.params.model, "hash.txt")
|
||||
if path.exists(model_hash_path):
|
||||
with open(model_hash_path, "r") as f:
|
||||
model_hash = f.readline().rstrip(",. \n\t\r")
|
||||
|
||||
model_hash = model_hash or "unknown"
|
||||
hash_map = {
|
||||
model_name: model_hash,
|
||||
}
|
||||
|
||||
inversion_hashes = ""
|
||||
if self.inversions is not None:
|
||||
inversion_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in self.inversions
|
||||
]
|
||||
inversion_hashes = ",".join(
|
||||
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
||||
)
|
||||
hash_map.update(dict(inversion_pairs))
|
||||
|
||||
lora_hashes = ""
|
||||
if self.loras is not None:
|
||||
lora_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in self.loras
|
||||
]
|
||||
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
||||
hash_map.update(dict(lora_pairs))
|
||||
|
||||
return (
|
||||
f"{self.params.prompt or ''}\nNegative prompt: {self.params.negative_prompt or ''}\n"
|
||||
f"Steps: {self.params.steps}, Sampler: {self.params.scheduler}, CFG scale: {self.params.cfg}, "
|
||||
f"Seed: {self.params.seed}, Size: {self.size.width}x{self.size.height}, "
|
||||
f"Model hash: {model_hash}, Model: {model_name}, "
|
||||
f"Tool: onnx-web, Version: {server.server_version}, "
|
||||
f'Inversion hashes: "{inversion_hashes}", '
|
||||
f'Lora hashes: "{lora_hashes}", '
|
||||
f"Hashes: {dumps(hash_map)}"
|
||||
)
|
||||
|
||||
def tojson(self, server, outputs):
|
||||
return json_params(
|
||||
server,
|
||||
outputs,
|
||||
self.params,
|
||||
self.size,
|
||||
upscale=self.upscale,
|
||||
border=self.border,
|
||||
highres=self.highres,
|
||||
inversions=self.inversions,
|
||||
loras=self.loras,
|
||||
)
|
||||
json = {
|
||||
"input_size": self.size.tojson(),
|
||||
"outputs": outputs,
|
||||
"params": self.params.tojson(),
|
||||
"inversions": {},
|
||||
"loras": {},
|
||||
}
|
||||
|
||||
json["params"]["model"] = path.basename(self.params.model)
|
||||
json["params"]["scheduler"] = self.params.scheduler # TODO: why tho?
|
||||
|
||||
# calculate final output size
|
||||
output_size = self.size
|
||||
if self.border is not None:
|
||||
json["border"] = self.border.tojson()
|
||||
output_size = output_size.add_border(self.border)
|
||||
|
||||
if self.highres is not None:
|
||||
json["highres"] = self.highres.tojson()
|
||||
output_size = self.highres.resize(output_size)
|
||||
|
||||
if self.upscale is not None:
|
||||
json["upscale"] = self.upscale.tojson()
|
||||
output_size = self.upscale.resize(output_size)
|
||||
|
||||
json["size"] = output_size.tojson()
|
||||
|
||||
if self.inversions is not None:
|
||||
for name, weight in self.inversions:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper()
|
||||
json["inversions"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
if self.loras is not None:
|
||||
for name, weight in self.loras:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper()
|
||||
json["loras"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
return json
|
||||
|
||||
|
||||
class StageResult:
|
||||
|
@ -86,6 +195,7 @@ class StageResult:
|
|||
self.arrays = arrays
|
||||
self.images = images
|
||||
self.source = source
|
||||
self.metadata = []
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.arrays is not None:
|
||||
|
@ -117,7 +227,7 @@ class StageResult:
|
|||
elif self.images is not None:
|
||||
self.images.append(Image.fromarray(np.uint8(array), shape_mode(array)))
|
||||
else:
|
||||
raise ValueError("invalid stage result")
|
||||
self.arrays = [array]
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
|
@ -130,13 +240,45 @@ class StageResult:
|
|||
elif self.arrays is not None:
|
||||
self.arrays.append(np.array(image))
|
||||
else:
|
||||
raise ValueError("invalid stage result")
|
||||
self.images = [image]
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.append(metadata)
|
||||
else:
|
||||
self.metadata.append(ImageMetadata())
|
||||
|
||||
def insert_array(
|
||||
self, index: int, array: np.ndarray, metadata: Optional[ImageMetadata]
|
||||
):
|
||||
if self.arrays is not None:
|
||||
self.arrays.insert(index, array)
|
||||
elif self.images is not None:
|
||||
self.images.insert(
|
||||
index, Image.fromarray(np.uint8(array), shape_mode(array))
|
||||
)
|
||||
else:
|
||||
self.arrays = [array]
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.insert(index, metadata)
|
||||
else:
|
||||
self.metadata.insert(index, ImageMetadata())
|
||||
|
||||
def insert_image(
|
||||
self, index: int, image: Image.Image, metadata: Optional[ImageMetadata]
|
||||
):
|
||||
if self.images is not None:
|
||||
self.images.insert(index, image)
|
||||
elif self.arrays is not None:
|
||||
self.arrays.insert(index, np.array(image))
|
||||
else:
|
||||
self.images = [image]
|
||||
|
||||
if metadata is not None:
|
||||
self.metadata.insert(index, metadata)
|
||||
else:
|
||||
self.metadata.insert(index, ImageMetadata())
|
||||
|
||||
|
||||
def shape_mode(arr: np.ndarray) -> str:
|
||||
if len(arr.shape) != 3:
|
||||
|
|
|
@ -16,7 +16,7 @@ from ..chain.highres import stage_highres
|
|||
from ..chain.result import StageResult
|
||||
from ..chain.upscale import split_upscale, stage_upscale_correction
|
||||
from ..image import expand_image
|
||||
from ..output import save_image
|
||||
from ..output import save_image, save_result
|
||||
from ..params import (
|
||||
Border,
|
||||
HighresParams,
|
||||
|
@ -29,7 +29,7 @@ from ..server import ServerContext
|
|||
from ..server.load import get_source_filters
|
||||
from ..utils import is_debug, run_gc, show_system_toast
|
||||
from ..worker import WorkerContext
|
||||
from .utils import get_latents_from_seed, parse_prompt
|
||||
from .utils import get_latents_from_seed
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -57,7 +57,6 @@ def run_txt2img_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
highres: HighresParams,
|
||||
) -> None:
|
||||
|
@ -114,50 +113,34 @@ def run_txt2img_pipeline(
|
|||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(
|
||||
images = chain(
|
||||
worker, server, params, StageResult.empty(), callback=progress, latents=latents
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
|
||||
# add a thumbnail, if requested
|
||||
cover = images[0]
|
||||
cover = images.as_image()[0]
|
||||
if params.thumbnail and (
|
||||
cover.width > server.thumbnail_size or cover.height > server.thumbnail_size
|
||||
):
|
||||
thumbnail = cover.copy()
|
||||
thumbnail.thumbnail((server.thumbnail_size, server.thumbnail_size))
|
||||
|
||||
images.insert(0, thumbnail)
|
||||
outputs.insert(0, f"{worker.name}-thumb.{server.image_format}")
|
||||
images.insert_image(0, thumbnail)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
logger.trace("saving output image %s: %s", output, image.size)
|
||||
dest = save_image(
|
||||
server,
|
||||
output,
|
||||
image,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
highres=highres,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished txt2img job: {dest}")
|
||||
logger.info("finished txt2img job: %s", dest)
|
||||
show_system_toast(f"finished txt2img job: {worker.job}")
|
||||
logger.info("finished txt2img job: %s", worker.job)
|
||||
|
||||
|
||||
def run_img2img_pipeline(
|
||||
worker: WorkerContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
highres: HighresParams,
|
||||
source: Image.Image,
|
||||
|
@ -228,36 +211,21 @@ def run_img2img_pipeline(
|
|||
|
||||
# run and append the filtered source
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
)
|
||||
|
||||
if source_filter is not None and source_filter != "none":
|
||||
images.append(source)
|
||||
images.push_image(source)
|
||||
|
||||
# save with metadata
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
size = Size(*source.size)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
dest = save_image(
|
||||
server,
|
||||
output,
|
||||
image,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
highres=highres,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished img2img job: {dest}")
|
||||
logger.info("finished img2img job: %s", dest)
|
||||
show_system_toast(f"finished img2img job: {worker.job}")
|
||||
logger.info("finished img2img job: %s", worker.job)
|
||||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
|
@ -265,7 +233,6 @@ def run_inpaint_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
highres: HighresParams,
|
||||
source: Image.Image,
|
||||
|
@ -290,7 +257,7 @@ def run_inpaint_pipeline(
|
|||
mask = ImageOps.contain(mask, (mask_max, mask_max))
|
||||
mask = mask.crop((0, 0, source.width, source.height))
|
||||
|
||||
source, mask, noise, full_size = expand_image(
|
||||
source, mask, noise, _full_size = expand_image(
|
||||
source,
|
||||
mask,
|
||||
border,
|
||||
|
@ -414,7 +381,7 @@ def run_inpaint_pipeline(
|
|||
# run and save
|
||||
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(
|
||||
images = chain(
|
||||
worker,
|
||||
server,
|
||||
params,
|
||||
|
@ -423,33 +390,28 @@ def run_inpaint_pipeline(
|
|||
latents=latents,
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
for i, image, metadata in enumerate(zip(images.as_image(), images.metadata)):
|
||||
if full_res_inpaint:
|
||||
if is_debug():
|
||||
save_image(server, "adjusted-output.png", image)
|
||||
|
||||
mini_image = ImageOps.contain(image, (adj_mask_size, adj_mask_size))
|
||||
image = original_source
|
||||
image.paste(mini_image, box=adj_mask_border)
|
||||
dest = save_image(
|
||||
|
||||
save_image(
|
||||
server,
|
||||
output,
|
||||
f"{worker.job}_{i}.{server.image_format}",
|
||||
image,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=border,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
metadata,
|
||||
)
|
||||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished inpaint job: {dest}")
|
||||
logger.info("finished inpaint job: %s", dest)
|
||||
show_system_toast(f"finished inpaint job: {worker.job}")
|
||||
logger.info("finished inpaint job: %s", worker.job)
|
||||
|
||||
|
||||
def run_upscale_pipeline(
|
||||
|
@ -457,7 +419,6 @@ def run_upscale_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
highres: HighresParams,
|
||||
source: Image.Image,
|
||||
|
@ -497,30 +458,18 @@ def run_upscale_pipeline(
|
|||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=[source]), callback=progress
|
||||
)
|
||||
|
||||
_pairs, loras, inversions, _rest = parse_prompt(params)
|
||||
for image, output in zip(images, outputs):
|
||||
dest = save_image(
|
||||
server,
|
||||
output,
|
||||
image,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
inversions=inversions,
|
||||
loras=loras,
|
||||
)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished upscale job: {dest}")
|
||||
logger.info("finished upscale job: %s", dest)
|
||||
show_system_toast(f"finished upscale job: {worker.job}")
|
||||
logger.info("finished upscale job: %s", worker.job)
|
||||
|
||||
|
||||
def run_blend_pipeline(
|
||||
|
@ -528,7 +477,6 @@ def run_blend_pipeline(
|
|||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
outputs: List[str],
|
||||
upscale: UpscaleParams,
|
||||
# highres: HighresParams,
|
||||
sources: List[Image.Image],
|
||||
|
@ -559,17 +507,15 @@ def run_blend_pipeline(
|
|||
|
||||
# run and save
|
||||
progress = worker.get_progress_callback()
|
||||
images = chain.run(
|
||||
images = chain(
|
||||
worker, server, params, StageResult(images=sources), callback=progress
|
||||
)
|
||||
|
||||
for image, output in zip(images, outputs):
|
||||
dest = save_image(server, output, image, params, size, upscale=upscale)
|
||||
save_result(server, images, worker.job)
|
||||
|
||||
# clean up
|
||||
del image
|
||||
run_gc([worker.get_device()])
|
||||
|
||||
# notify the user
|
||||
show_system_toast(f"finished blend job: {dest}")
|
||||
logger.info("finished blend job: %s", dest)
|
||||
show_system_toast(f"finished blend job: {worker.job}")
|
||||
logger.info("finished blend job: %s", worker.job)
|
||||
|
|
|
@ -1,173 +1,20 @@
|
|||
from hashlib import sha256
|
||||
from json import dumps
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from struct import pack
|
||||
from time import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from piexif import ExifIFD, ImageIFD, dump
|
||||
from piexif.helper import UserComment
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from .convert.utils import resolve_tensor
|
||||
from .params import Border, HighresParams, ImageParams, Param, Size, UpscaleParams
|
||||
from .chain.result import ImageMetadata, StageResult
|
||||
from .params import ImageParams, Param, Size
|
||||
from .server import ServerContext
|
||||
from .server.load import get_extra_hashes
|
||||
from .utils import base_join
|
||||
from .utils import base_join, hash_value
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
HASH_BUFFER_SIZE = 2**22 # 4MB
|
||||
|
||||
|
||||
def hash_file(name: str):
|
||||
sha = sha256()
|
||||
with open(name, "rb") as f:
|
||||
while True:
|
||||
data = f.read(HASH_BUFFER_SIZE)
|
||||
if not data:
|
||||
break
|
||||
|
||||
sha.update(data)
|
||||
|
||||
return sha.hexdigest()
|
||||
|
||||
|
||||
def hash_value(sha, param: Optional[Param]):
|
||||
if param is None:
|
||||
return
|
||||
elif isinstance(param, bool):
|
||||
sha.update(bytearray(pack("!B", param)))
|
||||
elif isinstance(param, float):
|
||||
sha.update(bytearray(pack("!f", param)))
|
||||
elif isinstance(param, int):
|
||||
sha.update(bytearray(pack("!I", param)))
|
||||
elif isinstance(param, str):
|
||||
sha.update(param.encode("utf-8"))
|
||||
else:
|
||||
logger.warning("cannot hash param: %s, %s", param, type(param))
|
||||
|
||||
|
||||
def json_params(
|
||||
server: ServerContext,
|
||||
outputs: List[str],
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
parent: Optional[Dict] = None,
|
||||
) -> Any:
|
||||
json = {
|
||||
"input_size": size.tojson(),
|
||||
"outputs": outputs,
|
||||
"params": params.tojson(),
|
||||
"inversions": {},
|
||||
"loras": {},
|
||||
}
|
||||
|
||||
json["params"]["model"] = path.basename(params.model)
|
||||
json["params"]["scheduler"] = params.scheduler
|
||||
|
||||
# calculate final output size
|
||||
output_size = size
|
||||
if border is not None:
|
||||
json["border"] = border.tojson()
|
||||
output_size = output_size.add_border(border)
|
||||
|
||||
if highres is not None:
|
||||
json["highres"] = highres.tojson()
|
||||
output_size = highres.resize(output_size)
|
||||
|
||||
if upscale is not None:
|
||||
json["upscale"] = upscale.tojson()
|
||||
output_size = upscale.resize(output_size)
|
||||
|
||||
json["size"] = output_size.tojson()
|
||||
|
||||
if inversions is not None:
|
||||
for name, weight in inversions:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper()
|
||||
json["inversions"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
if loras is not None:
|
||||
for name, weight in loras:
|
||||
hash = hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper()
|
||||
json["loras"][name] = {"weight": weight, "hash": hash}
|
||||
|
||||
return json
|
||||
|
||||
|
||||
def str_params(
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
inversions: List[Tuple[str, float]] = None,
|
||||
loras: List[Tuple[str, float]] = None,
|
||||
) -> str:
|
||||
model_name = path.basename(path.normpath(params.model))
|
||||
logger.debug("getting model hash for %s", model_name)
|
||||
|
||||
model_hash = get_extra_hashes().get(model_name, None)
|
||||
if model_hash is None:
|
||||
model_hash_path = path.join(params.model, "hash.txt")
|
||||
if path.exists(model_hash_path):
|
||||
with open(model_hash_path, "r") as f:
|
||||
model_hash = f.readline().rstrip(",. \n\t\r")
|
||||
|
||||
model_hash = model_hash or "unknown"
|
||||
hash_map = {
|
||||
model_name: model_hash,
|
||||
}
|
||||
|
||||
inversion_hashes = ""
|
||||
if inversions is not None:
|
||||
inversion_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "inversion", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in inversions
|
||||
]
|
||||
inversion_hashes = ",".join(
|
||||
[f"{name}: {hash}" for name, hash in inversion_pairs]
|
||||
)
|
||||
hash_map.update(dict(inversion_pairs))
|
||||
|
||||
lora_hashes = ""
|
||||
if loras is not None:
|
||||
lora_pairs = [
|
||||
(
|
||||
name,
|
||||
hash_file(
|
||||
resolve_tensor(path.join(server.model_path, "lora", name))
|
||||
).upper(),
|
||||
)
|
||||
for name, _weight in loras
|
||||
]
|
||||
lora_hashes = ",".join([f"{name}: {hash}" for name, hash in lora_pairs])
|
||||
hash_map.update(dict(lora_pairs))
|
||||
|
||||
return (
|
||||
f"{params.prompt or ''}\nNegative prompt: {params.negative_prompt or ''}\n"
|
||||
f"Steps: {params.steps}, Sampler: {params.scheduler}, CFG scale: {params.cfg}, "
|
||||
f"Seed: {params.seed}, Size: {size.width}x{size.height}, "
|
||||
f"Model hash: {model_hash}, Model: {model_name}, "
|
||||
f"Tool: onnx-web, Version: {server.server_version}, "
|
||||
f'Inversion hashes: "{inversion_hashes}", '
|
||||
f'Lora hashes: "{lora_hashes}", '
|
||||
f"Hashes: {dumps(hash_map)}"
|
||||
)
|
||||
|
||||
|
||||
def make_output_name(
|
||||
server: ServerContext,
|
||||
|
@ -179,6 +26,19 @@ def make_output_name(
|
|||
offset: int = 0,
|
||||
) -> List[str]:
|
||||
count = count or params.batch
|
||||
job_name = make_job_name(mode, params, size, extras)
|
||||
|
||||
return [
|
||||
f"{job_name}_{i}.{server.image_format}" for i in range(offset, count + offset)
|
||||
]
|
||||
|
||||
|
||||
def make_job_name(
|
||||
mode: str,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
extras: Optional[List[Optional[Param]]] = None,
|
||||
) -> str:
|
||||
now = int(time())
|
||||
sha = sha256()
|
||||
|
||||
|
@ -200,49 +60,49 @@ def make_output_name(
|
|||
for param in extras:
|
||||
hash_value(sha, param)
|
||||
|
||||
return [
|
||||
f"{mode}_{params.seed}_{sha.hexdigest()}_{now}_{i}.{server.image_format}"
|
||||
for i in range(offset, count + offset)
|
||||
]
|
||||
return f"{mode}_{params.seed}_{sha.hexdigest()}_{now}"
|
||||
|
||||
|
||||
def save_result(
|
||||
server: ServerContext,
|
||||
result: StageResult,
|
||||
base_name: str,
|
||||
) -> List[str]:
|
||||
results = []
|
||||
for i, image, metadata in enumerate(zip(result.as_image(), result.metadata)):
|
||||
results.append(
|
||||
save_image(
|
||||
server,
|
||||
base_name + f"_{i}.{server.image_format}",
|
||||
image,
|
||||
metadata,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def save_image(
|
||||
server: ServerContext,
|
||||
output: str,
|
||||
image: Image.Image,
|
||||
params: Optional[ImageParams] = None,
|
||||
size: Optional[Size] = None,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
inversions: List[Tuple[str, float]] = None,
|
||||
loras: List[Tuple[str, float]] = None,
|
||||
metadata: ImageMetadata,
|
||||
) -> str:
|
||||
path = base_join(server.output_path, output)
|
||||
|
||||
if server.image_format == "png":
|
||||
exif = PngImagePlugin.PngInfo()
|
||||
|
||||
if params is not None:
|
||||
if metadata is not None:
|
||||
exif.add_text("make", "onnx-web")
|
||||
exif.add_text(
|
||||
"maker note",
|
||||
dumps(
|
||||
json_params(
|
||||
server,
|
||||
[output],
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=border,
|
||||
highres=highres,
|
||||
)
|
||||
),
|
||||
dumps(metadata.tojson(server, [output])),
|
||||
)
|
||||
exif.add_text("model", server.server_version)
|
||||
exif.add_text(
|
||||
"parameters",
|
||||
str_params(server, params, size, inversions=inversions, loras=loras),
|
||||
metadata.to_auto1111(server, [output]),
|
||||
)
|
||||
|
||||
image.save(path, format=server.image_format, pnginfo=exif)
|
||||
|
@ -251,23 +111,11 @@ def save_image(
|
|||
{
|
||||
"0th": {
|
||||
ExifIFD.MakerNote: UserComment.dump(
|
||||
dumps(
|
||||
json_params(
|
||||
server,
|
||||
[output],
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=border,
|
||||
highres=highres,
|
||||
)
|
||||
),
|
||||
dumps(metadata.tojson(server, [output])),
|
||||
encoding="unicode",
|
||||
),
|
||||
ExifIFD.UserComment: UserComment.dump(
|
||||
str_params(
|
||||
server, params, size, inversions=inversions, loras=loras
|
||||
),
|
||||
metadata.to_auto1111(server, [output]),
|
||||
encoding="unicode",
|
||||
),
|
||||
ImageIFD.Make: "onnx-web",
|
||||
|
@ -277,34 +125,23 @@ def save_image(
|
|||
)
|
||||
image.save(path, format=server.image_format, exif=exif)
|
||||
|
||||
if params is not None:
|
||||
save_params(
|
||||
if metadata is not None:
|
||||
save_metadata(
|
||||
server,
|
||||
output,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=border,
|
||||
highres=highres,
|
||||
)
|
||||
|
||||
logger.debug("saved output image to: %s", path)
|
||||
return path
|
||||
|
||||
|
||||
def save_params(
|
||||
def save_metadata(
|
||||
server: ServerContext,
|
||||
output: str,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
upscale: Optional[UpscaleParams] = None,
|
||||
border: Optional[Border] = None,
|
||||
highres: Optional[HighresParams] = None,
|
||||
metadata: ImageMetadata,
|
||||
) -> str:
|
||||
path = base_join(server.output_path, f"{output}.json")
|
||||
json = json_params(
|
||||
server, output, params, size, upscale=upscale, border=border, highres=highres
|
||||
)
|
||||
json = metadata.tojson(server, [output])
|
||||
with open(path, "w") as f:
|
||||
f.write(dumps(json))
|
||||
logger.debug("saved image params to: %s", path)
|
||||
|
|
|
@ -13,6 +13,24 @@ Param = Union[str, int, float]
|
|||
Point = Tuple[int, int]
|
||||
|
||||
|
||||
class Progress:
|
||||
current: int
|
||||
total: int
|
||||
|
||||
def __init__(self, current: int, total: int) -> None:
|
||||
self.current = current
|
||||
self.total = total
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s/%s" % (self.current, self.total)
|
||||
|
||||
def tojson(self):
|
||||
return {
|
||||
"current": self.current,
|
||||
"total": self.total,
|
||||
}
|
||||
|
||||
|
||||
class SizeChart(IntEnum):
|
||||
micro = 64
|
||||
mini = 128 # small tile for very expensive models
|
||||
|
|
|
@ -26,14 +26,14 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
|
|||
pool.recycle(recycle_all=True)
|
||||
logger.info("restarted worker pool")
|
||||
|
||||
return jsonify(pool.status())
|
||||
return jsonify(pool.summary())
|
||||
|
||||
|
||||
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
if not check_admin(server):
|
||||
return make_response(jsonify({})), 401
|
||||
|
||||
return jsonify(pool.status())
|
||||
return jsonify(pool.summary())
|
||||
|
||||
|
||||
def get_extra_models(server: ServerContext):
|
||||
|
@ -102,8 +102,8 @@ def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExe
|
|||
app.route("/api/extras", methods=["PUT"])(
|
||||
wrap_route(update_extra_models, server)
|
||||
),
|
||||
app.route("/api/restart", methods=["POST"])(
|
||||
app.route("/api/worker/restart", methods=["POST"])(
|
||||
wrap_route(restart_workers, server, pool=pool)
|
||||
),
|
||||
app.route("/api/status")(wrap_route(worker_status, server, pool=pool)),
|
||||
app.route("/api/worker/status")(wrap_route(worker_status, server, pool=pool)),
|
||||
]
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import Flask, jsonify, make_response, request, url_for
|
||||
from jsonschema import validate
|
||||
from PIL import Image
|
||||
|
||||
from ..chain import CHAIN_STAGES, ChainPipeline
|
||||
from ..chain.result import StageResult
|
||||
from ..chain.result import ImageMetadata, StageResult
|
||||
from ..diffusers.load import get_available_pipelines, get_pipeline_schedulers
|
||||
from ..diffusers.run import (
|
||||
run_blend_pipeline,
|
||||
|
@ -18,8 +18,8 @@ from ..diffusers.run import (
|
|||
run_upscale_pipeline,
|
||||
)
|
||||
from ..diffusers.utils import replace_wildcards
|
||||
from ..output import json_params, make_output_name
|
||||
from ..params import Size, StageParams, TileOrder
|
||||
from ..output import make_job_name
|
||||
from ..params import Progress, Size, StageParams, TileOrder
|
||||
from ..transformers.run import run_txt2txt_pipeline
|
||||
from ..utils import (
|
||||
base_join,
|
||||
|
@ -34,6 +34,7 @@ from ..utils import (
|
|||
load_config_str,
|
||||
sanitize_name,
|
||||
)
|
||||
from ..worker.command import JobType
|
||||
from ..worker.pool import DevicePoolExecutor
|
||||
from .context import ServerContext
|
||||
from .load import (
|
||||
|
@ -92,6 +93,64 @@ def error_reply(err: str):
|
|||
return response
|
||||
|
||||
|
||||
def job_reply(name: str):
|
||||
return jsonify(
|
||||
{
|
||||
"name": name,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def image_reply(
|
||||
name: str,
|
||||
status: str,
|
||||
job_type: str,
|
||||
stages: Progress = None,
|
||||
steps: Progress = None,
|
||||
tiles: Progress = None,
|
||||
outputs: List[str] = None,
|
||||
metadata: List[ImageMetadata] = None,
|
||||
):
|
||||
if stages is None:
|
||||
stages = Progress()
|
||||
|
||||
if steps is None:
|
||||
steps = Progress()
|
||||
|
||||
if tiles is None:
|
||||
tiles = Progress()
|
||||
|
||||
data = {
|
||||
"name": name,
|
||||
"status": status,
|
||||
"type": job_type,
|
||||
"stages": stages.tojson(),
|
||||
"steps": steps.tojson(),
|
||||
"tiles": tiles.tojson(),
|
||||
}
|
||||
|
||||
if len(metadata) != len(outputs):
|
||||
logger.error("metadata and outputs must be the same length")
|
||||
return error_reply("metadata and outputs must be the same length")
|
||||
|
||||
if outputs is not None:
|
||||
data["outputs"] = outputs
|
||||
|
||||
if metadata is not None:
|
||||
data["metadata"] = metadata
|
||||
|
||||
return jsonify(data)
|
||||
|
||||
|
||||
def multi_image_reply(results: Dict[str, Any]):
|
||||
# TODO: not that
|
||||
return jsonify(
|
||||
{
|
||||
"results": results,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def url_from_rule(rule) -> str:
|
||||
options = {}
|
||||
for arg in rule.arguments:
|
||||
|
@ -197,17 +256,15 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
output_count += 1
|
||||
|
||||
output = make_output_name(
|
||||
job_name = make_job_name(
|
||||
server, "img2img", params, size, extras=[strength], count=output_count
|
||||
)
|
||||
|
||||
job_name = output[0]
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.IMG2IMG,
|
||||
run_img2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
output,
|
||||
upscale,
|
||||
highres,
|
||||
source,
|
||||
|
@ -218,9 +275,7 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("img2img job queued for: %s", job_name)
|
||||
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -230,16 +285,15 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
output = make_output_name(server, "txt2img", params, size, count=params.batch)
|
||||
job_name = make_job_name(server, "txt2img", params, size, count=params.batch)
|
||||
|
||||
job_name = output[0]
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.TXT2IMG,
|
||||
run_txt2img_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
highres,
|
||||
needs_device=device,
|
||||
|
@ -247,9 +301,7 @@ def txt2img(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("txt2img job queued for: %s", job_name)
|
||||
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -295,7 +347,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
output = make_output_name(
|
||||
job_name = make_job_name(
|
||||
server,
|
||||
"inpaint",
|
||||
params,
|
||||
|
@ -312,14 +364,13 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
],
|
||||
)
|
||||
|
||||
job_name = output[0]
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.INPAINT,
|
||||
run_inpaint_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
highres,
|
||||
source,
|
||||
|
@ -336,17 +387,7 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("inpaint job queued for: %s", job_name)
|
||||
|
||||
return jsonify(
|
||||
json_params(
|
||||
server,
|
||||
output,
|
||||
params,
|
||||
size,
|
||||
upscale=upscale,
|
||||
border=expand,
|
||||
highres=highres,
|
||||
)
|
||||
)
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -362,16 +403,14 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
replace_wildcards(params, get_wildcard_data())
|
||||
|
||||
output = make_output_name(server, "upscale", params, size)
|
||||
|
||||
job_name = output[0]
|
||||
job_name = make_job_name(server, "upscale", params, size)
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.UPSCALE,
|
||||
run_upscale_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
highres,
|
||||
source,
|
||||
|
@ -380,9 +419,7 @@ def upscale(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return jsonify(
|
||||
json_params(server, output, params, size, upscale=upscale, highres=highres)
|
||||
)
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
# keys that are specially parsed by params and should not show up in with_args
|
||||
|
@ -478,25 +515,21 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
output = make_output_name(
|
||||
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)
|
||||
)
|
||||
job_name = output[0]
|
||||
job_name = make_job_name(server, "chain", base_params, base_size)
|
||||
|
||||
# build and run chain pipeline
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.CHAIN,
|
||||
pipeline,
|
||||
server,
|
||||
base_params,
|
||||
StageResult.empty(),
|
||||
output=output,
|
||||
size=base_size,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||
return jsonify(json_params(server, output, step_params, base_size))
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def blend(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -520,15 +553,14 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
device, params, size = pipeline_from_request(server)
|
||||
upscale = build_upscale()
|
||||
|
||||
output = make_output_name(server, "upscale", params, size)
|
||||
job_name = output[0]
|
||||
job_name = make_job_name(server, "blend", params, size)
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.BLEND,
|
||||
run_blend_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
# TODO: highres
|
||||
sources,
|
||||
|
@ -538,27 +570,26 @@ def blend(server: ServerContext, pool: DevicePoolExecutor):
|
|||
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
return jsonify(json_params(server, output, params, size, upscale=upscale))
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def txt2txt(server: ServerContext, pool: DevicePoolExecutor):
|
||||
device, params, size = pipeline_from_request(server)
|
||||
|
||||
output = make_output_name(server, "txt2txt", params, size)
|
||||
job_name = output[0]
|
||||
job_name = make_job_name(server, "txt2txt", params, size)
|
||||
logger.info("upscale job queued for: %s", job_name)
|
||||
|
||||
pool.submit(
|
||||
job_name,
|
||||
JobType.TXT2TXT,
|
||||
run_txt2txt_pipeline,
|
||||
server,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(server, output, params, size))
|
||||
return job_reply(job_name)
|
||||
|
||||
|
||||
def cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -601,9 +632,64 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
|
||||
def job_cancel(server: ServerContext, pool: DevicePoolExecutor):
|
||||
legacy_job_name = request.args.get("job", None)
|
||||
job_list = request.args.get("jobs", "").split(",")
|
||||
|
||||
if legacy_job_name is not None:
|
||||
job_list.append(legacy_job_name)
|
||||
|
||||
if len(job_list) == 0:
|
||||
return error_reply("at least one job name is required")
|
||||
|
||||
results = {}
|
||||
for job_name in job_list:
|
||||
job_name = sanitize_name(job_name)
|
||||
cancelled = pool.cancel(job_name)
|
||||
results[job_name] = cancelled
|
||||
|
||||
return multi_image_reply(results)
|
||||
|
||||
|
||||
def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
legacy_job_name = request.args.get("job", None)
|
||||
job_list = request.args.get("jobs", "").split(",")
|
||||
|
||||
if legacy_job_name is not None:
|
||||
job_list.append(legacy_job_name)
|
||||
|
||||
if len(job_list) == 0:
|
||||
return error_reply("at least one job name is required")
|
||||
|
||||
for job_name in job_list:
|
||||
job_name = sanitize_name(job_name)
|
||||
status, progress = pool.status(job_name)
|
||||
|
||||
# TODO: accumulate results
|
||||
if progress is not None:
|
||||
# TODO: add output paths based on progress.results counter
|
||||
return image_reply(
|
||||
job_name,
|
||||
status,
|
||||
"TODO",
|
||||
stages=Progress(progress.stages, 0),
|
||||
steps=Progress(progress.steps, 0),
|
||||
tiles=Progress(progress.tiles, 0),
|
||||
)
|
||||
|
||||
return image_reply(job_name, status, "TODO")
|
||||
|
||||
|
||||
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||
return [
|
||||
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||
# job routes
|
||||
app.route("/api/job", methods=["POST"])(wrap_route(chain, server, pool=pool)),
|
||||
app.route("/api/job/cancel", methods=["PUT"])(
|
||||
wrap_route(job_cancel, server, pool=pool)
|
||||
),
|
||||
app.route("/api/job/status")(wrap_route(job_status, server, pool=pool)),
|
||||
# settings routes
|
||||
app.route("/api/settings/filters")(wrap_route(list_filters, server)),
|
||||
app.route("/api/settings/masks")(wrap_route(list_mask_filters, server)),
|
||||
app.route("/api/settings/models")(wrap_route(list_models, server)),
|
||||
|
@ -614,6 +700,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
|||
app.route("/api/settings/schedulers")(wrap_route(list_schedulers, server)),
|
||||
app.route("/api/settings/strings")(wrap_route(list_extra_strings, server)),
|
||||
app.route("/api/settings/wildcards")(wrap_route(list_wildcards, server)),
|
||||
# legacy job routes
|
||||
app.route("/api/img2img", methods=["POST"])(
|
||||
wrap_route(img2img, server, pool=pool)
|
||||
),
|
||||
|
@ -631,6 +718,7 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
|||
),
|
||||
app.route("/api/chain", methods=["POST"])(wrap_route(chain, server, pool=pool)),
|
||||
app.route("/api/blend", methods=["POST"])(wrap_route(blend, server, pool=pool)),
|
||||
# deprecated routes
|
||||
app.route("/api/cancel", methods=["PUT"])(
|
||||
wrap_route(cancel, server, pool=pool)
|
||||
),
|
||||
|
|
|
@ -12,7 +12,6 @@ def run_txt2txt_pipeline(
|
|||
_server: ServerContext,
|
||||
params: ImageParams,
|
||||
_size: Size,
|
||||
output: str,
|
||||
) -> None:
|
||||
from transformers import AutoTokenizer, GPTJForCausalLM
|
||||
|
||||
|
@ -38,4 +37,4 @@ def run_txt2txt_pipeline(
|
|||
|
||||
print("Server says: %s" % result_text)
|
||||
|
||||
logger.info("finished txt2txt job: %s", output)
|
||||
logger.info("finished txt2txt job: %s", worker.job)
|
||||
|
|
|
@ -2,16 +2,18 @@ import gc
|
|||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from hashlib import sha256
|
||||
from json import JSONDecodeError
|
||||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from platform import system
|
||||
from struct import pack
|
||||
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from yaml import safe_load
|
||||
|
||||
from .params import DeviceParams, SizeChart
|
||||
from .params import DeviceParams, Param, SizeChart
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -218,3 +220,34 @@ def load_config_str(raw: str) -> Dict:
|
|||
return json.loads(raw)
|
||||
except JSONDecodeError:
|
||||
return safe_load(raw)
|
||||
|
||||
|
||||
HASH_BUFFER_SIZE = 2**22 # 4MB
|
||||
|
||||
|
||||
def hash_file(name: str):
|
||||
sha = sha256()
|
||||
with open(name, "rb") as f:
|
||||
while True:
|
||||
data = f.read(HASH_BUFFER_SIZE)
|
||||
if not data:
|
||||
break
|
||||
|
||||
sha.update(data)
|
||||
|
||||
return sha.hexdigest()
|
||||
|
||||
|
||||
def hash_value(sha, param: Optional[Param]):
|
||||
if param is None:
|
||||
return
|
||||
elif isinstance(param, bool):
|
||||
sha.update(bytearray(pack("!B", param)))
|
||||
elif isinstance(param, float):
|
||||
sha.update(bytearray(pack("!f", param)))
|
||||
elif isinstance(param, int):
|
||||
sha.update(bytearray(pack("!I", param)))
|
||||
elif isinstance(param, str):
|
||||
sha.update(param.encode("utf-8"))
|
||||
else:
|
||||
logger.warning("cannot hash param: %s, %s", param, type(param))
|
||||
|
|
|
@ -1,34 +1,61 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class JobType(str, Enum):
|
||||
TXT2TXT = "txt2txt"
|
||||
TXT2IMG = "txt2img"
|
||||
IMG2IMG = "img2img"
|
||||
INPAINT = "inpaint"
|
||||
UPSCALE = "upscale"
|
||||
BLEND = "blend"
|
||||
CHAIN = "chain"
|
||||
|
||||
|
||||
class ProgressCommand:
|
||||
device: str
|
||||
job: str
|
||||
finished: bool
|
||||
progress: int
|
||||
cancelled: bool
|
||||
failed: bool
|
||||
job_type: str
|
||||
status: JobStatus
|
||||
results: int
|
||||
steps: int
|
||||
stages: int
|
||||
tiles: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job: str,
|
||||
job_type: str,
|
||||
device: str,
|
||||
finished: bool,
|
||||
progress: int,
|
||||
cancelled: bool = False,
|
||||
failed: bool = False,
|
||||
status: JobStatus,
|
||||
results: int = 0,
|
||||
steps: int = 0,
|
||||
stages: int = 0,
|
||||
tiles: int = 0,
|
||||
):
|
||||
self.job = job
|
||||
self.job_type = job_type
|
||||
self.device = device
|
||||
self.finished = finished
|
||||
self.progress = progress
|
||||
self.cancelled = cancelled
|
||||
self.failed = failed
|
||||
self.status = status
|
||||
self.results = results
|
||||
self.steps = steps
|
||||
self.stages = stages
|
||||
self.tiles = tiles
|
||||
|
||||
|
||||
class JobCommand:
|
||||
device: str
|
||||
name: str
|
||||
job_type: str
|
||||
fn: Callable[..., None]
|
||||
args: Any
|
||||
kwargs: Dict[str, Any]
|
||||
|
@ -37,12 +64,14 @@ class JobCommand:
|
|||
self,
|
||||
name: str,
|
||||
device: str,
|
||||
job_type: str,
|
||||
fn: Callable[..., None],
|
||||
args: Any,
|
||||
kwargs: Dict[str, Any],
|
||||
):
|
||||
self.device = device
|
||||
self.name = name
|
||||
self.job_type = job_type
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
|
|
@ -2,21 +2,23 @@ from logging import getLogger
|
|||
from os import getpid
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from torch.multiprocessing import Queue, Value
|
||||
|
||||
from ..errors import CancelledException
|
||||
from ..params import DeviceParams
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from .command import JobCommand, JobStatus, ProgressCommand
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[int, int, Any], None]
|
||||
ProgressCallback = Callable[[int, int, np.ndarray], None]
|
||||
|
||||
|
||||
class WorkerContext:
|
||||
cancel: "Value[bool]"
|
||||
job: Optional[str]
|
||||
job_type: Optional[str]
|
||||
name: str
|
||||
pending: "Queue[JobCommand]"
|
||||
active_pid: "Value[int]"
|
||||
|
@ -41,6 +43,7 @@ class WorkerContext:
|
|||
timeout: float,
|
||||
):
|
||||
self.job = None
|
||||
self.job_type = None
|
||||
self.name = name
|
||||
self.device = device
|
||||
self.cancel = cancel
|
||||
|
@ -54,9 +57,15 @@ class WorkerContext:
|
|||
self.retries = retries
|
||||
self.timeout = timeout
|
||||
|
||||
def start(self, job: str) -> None:
|
||||
self.job = job
|
||||
def start(self, job: JobCommand) -> None:
|
||||
# set job name and type
|
||||
self.job = job.name
|
||||
self.job_type = job.job_type
|
||||
|
||||
# reset retries
|
||||
self.retries = self.initial_retries
|
||||
|
||||
# clear flags
|
||||
self.set_cancel(cancel=False)
|
||||
self.set_idle(idle=False)
|
||||
|
||||
|
@ -81,7 +90,7 @@ class WorkerContext:
|
|||
|
||||
def get_progress(self) -> int:
|
||||
if self.last_progress is not None:
|
||||
return self.last_progress.progress
|
||||
return self.last_progress.steps
|
||||
|
||||
return 0
|
||||
|
||||
|
@ -112,13 +121,11 @@ class WorkerContext:
|
|||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.job_type,
|
||||
self.device.device,
|
||||
False,
|
||||
progress,
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
JobStatus.RUNNING,
|
||||
steps=progress,
|
||||
)
|
||||
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
block=False,
|
||||
|
@ -131,11 +138,10 @@ class WorkerContext:
|
|||
logger.debug("setting finished for job %s", self.job)
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.job_type,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
JobStatus.SUCCESS, # TODO: FAILED
|
||||
steps=self.get_progress(),
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
@ -150,11 +156,10 @@ class WorkerContext:
|
|||
try:
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.job_type,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
True,
|
||||
JobStatus.FAILED,
|
||||
steps=self.get_progress(),
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
@ -162,25 +167,3 @@ class WorkerContext:
|
|||
)
|
||||
except Exception:
|
||||
logger.exception("error setting failure on job %s", self.job)
|
||||
|
||||
|
||||
class JobStatus:
|
||||
name: str
|
||||
device: str
|
||||
progress: int
|
||||
cancelled: bool
|
||||
finished: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
device: DeviceParams,
|
||||
progress: int = 0,
|
||||
cancelled: bool = False,
|
||||
finished: bool = False,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.device = device.device
|
||||
self.progress = progress
|
||||
self.cancelled = cancelled
|
||||
self.finished = finished
|
||||
|
|
|
@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
|||
|
||||
from ..params import DeviceParams
|
||||
from ..server import ServerContext
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from .command import JobCommand, JobStatus, ProgressCommand
|
||||
from .context import WorkerContext
|
||||
from .utils import Interval
|
||||
from .worker import worker_main
|
||||
|
@ -201,6 +201,10 @@ class DevicePoolExecutor:
|
|||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
|
||||
if key in self.cancelled_jobs:
|
||||
logger.debug("cancelling already cancelled job: %s", key)
|
||||
return True
|
||||
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
logger.debug("cannot cancel finished job: %s", key)
|
||||
|
@ -209,6 +213,9 @@ class DevicePoolExecutor:
|
|||
for job in self.pending_jobs:
|
||||
if job.name == key:
|
||||
self.pending_jobs.remove(job)
|
||||
self.cancelled_jobs.append(
|
||||
key
|
||||
) # ensure workers never pick up this job and the status endpoint knows about it later
|
||||
logger.info("cancelled pending job: %s", key)
|
||||
return True
|
||||
|
||||
|
@ -221,28 +228,31 @@ class DevicePoolExecutor:
|
|||
self.cancelled_jobs.append(key)
|
||||
return True
|
||||
|
||||
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
|
||||
def status(self, key: str) -> Tuple[JobStatus, Optional[ProgressCommand]]:
|
||||
"""
|
||||
Check if a job has been finished and report the last progress update.
|
||||
|
||||
If the job is still pending, the first item will be True and there will be no ProgressCommand.
|
||||
"""
|
||||
|
||||
if key in self.cancelled_jobs:
|
||||
logger.debug("checking status for cancelled job: %s", key)
|
||||
return (JobStatus.CANCELLED, None)
|
||||
|
||||
if key in self.running_jobs:
|
||||
logger.debug("checking status for running job: %s", key)
|
||||
return (False, self.running_jobs[key])
|
||||
return (JobStatus.RUNNING, self.running_jobs[key])
|
||||
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
logger.debug("checking status for finished job: %s", key)
|
||||
return (False, job)
|
||||
return (job.status, job)
|
||||
|
||||
for job in self.pending_jobs:
|
||||
if job.name == key:
|
||||
logger.debug("checking status for pending job: %s", key)
|
||||
return (True, None)
|
||||
return (JobStatus.PENDING, None)
|
||||
|
||||
logger.trace("checking status for unknown job: %s", key)
|
||||
return (False, None)
|
||||
return (JobStatus.UNKNOWN, None)
|
||||
|
||||
def join(self):
|
||||
logger.info("stopping worker pool")
|
||||
|
@ -383,6 +393,7 @@ class DevicePoolExecutor:
|
|||
def submit(
|
||||
self,
|
||||
key: str,
|
||||
job_type: str,
|
||||
fn: Callable[..., None],
|
||||
/,
|
||||
*args,
|
||||
|
@ -399,56 +410,63 @@ class DevicePoolExecutor:
|
|||
)
|
||||
|
||||
# build and queue job
|
||||
job = JobCommand(key, device, fn, args, kwargs)
|
||||
job = JobCommand(key, device, job_type, fn, args, kwargs)
|
||||
self.pending_jobs.append(job)
|
||||
|
||||
def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]:
|
||||
def summary(self) -> Dict[str, List[Tuple[str, int, JobStatus]]]:
|
||||
"""
|
||||
Returns a tuple of: job/device, progress, progress, finished, cancelled, failed
|
||||
"""
|
||||
return {
|
||||
"cancelled": [],
|
||||
"finished": [
|
||||
|
||||
jobs: Tuple[str, int, JobStatus] = []
|
||||
jobs.extend(
|
||||
[
|
||||
(
|
||||
job.job,
|
||||
job.progress,
|
||||
False,
|
||||
job.finished,
|
||||
job.cancelled,
|
||||
job.failed,
|
||||
job,
|
||||
0,
|
||||
JobStatus.CANCELLED,
|
||||
)
|
||||
for job in self.finished_jobs
|
||||
],
|
||||
"pending": [
|
||||
for job in self.cancelled_jobs
|
||||
]
|
||||
)
|
||||
jobs.extend(
|
||||
[
|
||||
(
|
||||
job.name,
|
||||
0,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
JobStatus.PENDING,
|
||||
)
|
||||
for job in self.pending_jobs
|
||||
],
|
||||
"running": [
|
||||
]
|
||||
)
|
||||
jobs.extend(
|
||||
[
|
||||
(
|
||||
name,
|
||||
job.progress,
|
||||
False,
|
||||
job.finished,
|
||||
job.cancelled,
|
||||
job.failed,
|
||||
job.steps,
|
||||
job.status,
|
||||
)
|
||||
for name, job in self.running_jobs.items()
|
||||
],
|
||||
"total": [
|
||||
]
|
||||
)
|
||||
jobs.extend(
|
||||
[
|
||||
(
|
||||
job.job,
|
||||
job.steps,
|
||||
job.status,
|
||||
)
|
||||
for job in self.finished_jobs
|
||||
]
|
||||
)
|
||||
|
||||
return {
|
||||
"jobs": jobs,
|
||||
"workers": [
|
||||
(
|
||||
device,
|
||||
total,
|
||||
self.workers[device].is_alive(),
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
for device, total in self.total_jobs.items()
|
||||
],
|
||||
|
@ -476,20 +494,18 @@ class DevicePoolExecutor:
|
|||
self.cancelled_jobs.remove(progress.job)
|
||||
|
||||
def update_job(self, progress: ProgressCommand):
|
||||
if progress.finished:
|
||||
if progress.status in [JobStatus.SUCCESS, JobStatus.FAILED]:
|
||||
return self.finish_job(progress)
|
||||
|
||||
# move from pending to running
|
||||
logger.debug(
|
||||
"progress update for job: %s to %s", progress.job, progress.progress
|
||||
)
|
||||
logger.debug("progress update for job: %s to %s", progress.job, progress.steps)
|
||||
self.running_jobs[progress.job] = progress
|
||||
self.pending_jobs[:] = [
|
||||
job for job in self.pending_jobs if job.name != progress.job
|
||||
]
|
||||
|
||||
# increment job counter if this is the start of a new job
|
||||
if progress.progress == 0:
|
||||
if progress.steps == 0:
|
||||
if progress.device in self.total_jobs:
|
||||
self.total_jobs[progress.device] += 1
|
||||
else:
|
||||
|
|
|
@ -57,7 +57,7 @@ def worker_main(
|
|||
logger.info("worker %s got job: %s", worker.device.device, job.name)
|
||||
|
||||
# clear flags and save the job name
|
||||
worker.start(job.name)
|
||||
worker.start(job)
|
||||
logger.info("starting job: %s", job.name)
|
||||
|
||||
# reset progress, which does a final check for cancellation
|
||||
|
|
|
@ -1,12 +1,6 @@
|
|||
from diffusers import OnnxStableDiffusionPipeline
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
import time
|
||||
|
||||
cfg = 8
|
||||
steps = 22
|
||||
height = 512
|
||||
|
|
|
@ -22,6 +22,7 @@ from onnx_web.params import (
|
|||
UpscaleParams,
|
||||
)
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobCommand
|
||||
from onnx_web.worker.context import WorkerContext
|
||||
from tests.helpers import (
|
||||
TEST_MODEL_DIFFUSION_SD15,
|
||||
|
@ -57,7 +58,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
|
@ -72,7 +73,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
1,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-basic.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
)
|
||||
|
@ -103,7 +103,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
|
@ -119,7 +119,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
batch=2,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-batch-0.png", "test-txt2img-batch-1.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
)
|
||||
|
@ -152,7 +151,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
|
@ -168,7 +167,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
unet_tile=256,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-highres.png"],
|
||||
UpscaleParams("test", scale=2, outscale=2),
|
||||
HighresParams(True, 2, 0, 0),
|
||||
)
|
||||
|
@ -198,7 +196,7 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
run_txt2img_pipeline(
|
||||
worker,
|
||||
|
@ -214,7 +212,6 @@ class TestTxt2ImgPipeline(unittest.TestCase):
|
|||
batch=2,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-txt2img-highres-batch-0.png", "test-txt2img-highres-batch-1.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(True, 2, 0, 0),
|
||||
)
|
||||
|
@ -230,7 +227,7 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
|||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15])
|
||||
def test_basic(self):
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
run_img2img_pipeline(
|
||||
|
@ -245,7 +242,6 @@ class TestImg2ImgPipeline(unittest.TestCase):
|
|||
1,
|
||||
1,
|
||||
),
|
||||
["test-img2img.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
|
@ -259,7 +255,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
|||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||
def test_basic_white(self):
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
mask = Image.new("RGB", (64, 64), "white")
|
||||
|
@ -277,7 +273,6 @@ class TestInpaintPipeline(unittest.TestCase):
|
|||
unet_tile=64,
|
||||
),
|
||||
Size(*source.size),
|
||||
["test-inpaint-white.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
|
@ -296,7 +291,7 @@ class TestInpaintPipeline(unittest.TestCase):
|
|||
@test_needs_models([TEST_MODEL_DIFFUSION_SD15_INPAINT])
|
||||
def test_basic_black(self):
|
||||
worker = test_worker()
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_txt2img_pipeline, [], {}))
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
mask = Image.new("RGB", (64, 64), "black")
|
||||
|
@ -314,7 +309,6 @@ class TestInpaintPipeline(unittest.TestCase):
|
|||
unet_tile=64,
|
||||
),
|
||||
Size(*source.size),
|
||||
["test-inpaint-black.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
|
@ -353,7 +347,7 @@ class TestUpscalePipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_upscale_pipeline, [], {}))
|
||||
|
||||
source = Image.new("RGB", (64, 64), "black")
|
||||
run_upscale_pipeline(
|
||||
|
@ -369,7 +363,6 @@ class TestUpscalePipeline(unittest.TestCase):
|
|||
1,
|
||||
),
|
||||
Size(256, 256),
|
||||
["test-upscale.png"],
|
||||
UpscaleParams("test"),
|
||||
HighresParams(False, 1, 0, 0),
|
||||
source,
|
||||
|
@ -399,7 +392,7 @@ class TestBlendPipeline(unittest.TestCase):
|
|||
3,
|
||||
0.1,
|
||||
)
|
||||
worker.start("test")
|
||||
worker.start(JobCommand("test", "test", "test", run_blend_pipeline, [], {}))
|
||||
|
||||
source = Image.new("RGBA", (64, 64), "black")
|
||||
mask = Image.new("RGBA", (64, 64), "white")
|
||||
|
@ -417,7 +410,6 @@ class TestBlendPipeline(unittest.TestCase):
|
|||
unet_tile=64,
|
||||
),
|
||||
Size(64, 64),
|
||||
["test-blend.png"],
|
||||
UpscaleParams("test"),
|
||||
[source, source],
|
||||
mask,
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
|
||||
from onnx_web.params import DeviceParams
|
||||
from onnx_web.server.context import ServerContext
|
||||
from onnx_web.worker.command import JobStatus
|
||||
from onnx_web.worker.pool import DevicePoolExecutor
|
||||
|
||||
TEST_JOIN_TIMEOUT = 0.2
|
||||
|
@ -50,11 +51,11 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start()
|
||||
|
||||
self.pool.submit("test", wait_job, lock=lock)
|
||||
self.assertEqual(self.pool.done("test"), (True, None))
|
||||
self.pool.submit("test", "test", wait_job, lock=lock)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
|
||||
self.assertTrue(self.pool.cancel("test"))
|
||||
self.assertEqual(self.pool.done("test"), (False, None))
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.CANCELLED, None))
|
||||
|
||||
def test_cancel_running(self):
|
||||
pass
|
||||
|
@ -88,12 +89,14 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool = DevicePoolExecutor(
|
||||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
|
||||
lock.clear()
|
||||
self.pool.start(lock)
|
||||
self.pool.submit("test", test_job)
|
||||
self.pool.submit("test", "test", test_job)
|
||||
sleep(5.0)
|
||||
|
||||
pending, _progress = self.pool.done("test")
|
||||
self.assertFalse(pending)
|
||||
status, _progress = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.RUNNING)
|
||||
|
||||
def test_done_pending(self):
|
||||
device = DeviceParams("cpu", "CPUProvider")
|
||||
|
@ -102,9 +105,9 @@ class TestWorkerPool(unittest.TestCase):
|
|||
self.pool = DevicePoolExecutor(server, [device], join_timeout=TEST_JOIN_TIMEOUT)
|
||||
self.pool.start(lock)
|
||||
|
||||
self.pool.submit("test1", test_job)
|
||||
self.pool.submit("test2", test_job)
|
||||
self.assertTrue(self.pool.done("test2"), (True, None))
|
||||
self.pool.submit("test1", "test", test_job)
|
||||
self.pool.submit("test2", "test", test_job)
|
||||
self.assertEqual(self.pool.status("test2"), (JobStatus.PENDING, None))
|
||||
|
||||
lock.set()
|
||||
|
||||
|
@ -119,12 +122,12 @@ class TestWorkerPool(unittest.TestCase):
|
|||
server, [device], join_timeout=TEST_JOIN_TIMEOUT, progress_interval=0.1
|
||||
)
|
||||
self.pool.start()
|
||||
self.pool.submit("test", wait_job)
|
||||
self.assertEqual(self.pool.done("test"), (True, None))
|
||||
self.pool.submit("test", "test", wait_job)
|
||||
self.assertEqual(self.pool.status("test"), (JobStatus.PENDING, None))
|
||||
|
||||
sleep(5.0)
|
||||
pending, _progress = self.pool.done("test")
|
||||
self.assertFalse(pending)
|
||||
status, _progress = self.pool.status("test")
|
||||
self.assertEqual(status, JobStatus.SUCCESS)
|
||||
|
||||
def test_recycle_live(self):
|
||||
pass
|
||||
|
|
|
@ -40,7 +40,7 @@ class WorkerMainTests(unittest.TestCase):
|
|||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_interrupt, [], {})
|
||||
job = JobCommand("test", "test", "test", main_interrupt, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
|
@ -75,7 +75,7 @@ class WorkerMainTests(unittest.TestCase):
|
|||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_retry, [], {})
|
||||
job = JobCommand("test", "test", "test", main_retry, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
|
@ -144,7 +144,7 @@ class WorkerMainTests(unittest.TestCase):
|
|||
nonlocal status
|
||||
status = exit_status
|
||||
|
||||
job = JobCommand("test", "test", main_memory, [], {})
|
||||
job = JobCommand("test", "test", "test", main_memory, [], {})
|
||||
cancel = Value("L", False)
|
||||
logs = Queue()
|
||||
pending = Queue()
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
onnx-web is designed to simplify the process of running Stable Diffusion and other [ONNX models](https://onnx.ai) so you
|
||||
can focus on making high quality, high resolution art. With the efficiency of hardware acceleration on both AMD and
|
||||
Nvidia GPUs, and offering a reliable CPU software fallback, it offers the full feature set on desktop, laptops, and
|
||||
servers with a seamless user experience.
|
||||
multi-GPU servers with a seamless user experience.
|
||||
|
||||
You can navigate through the user-friendly web UI, hosted on Github Pages and accessible across all major browsers,
|
||||
including your go-to mobile device. Here, you have the flexibility to choose diffusion models and accelerators for each
|
||||
|
@ -84,18 +84,6 @@ This is an incomplete list of new and interesting features:
|
|||
- includes both the API and GUI bundle in a single container
|
||||
- runs well on [RunPod](https://www.runpod.io/), [Vast.ai](https://vast.ai/), and other GPU container hosting services
|
||||
|
||||
## Contents
|
||||
|
||||
- [onnx-web](#onnx-web)
|
||||
- [Features](#features)
|
||||
- [Contents](#contents)
|
||||
- [Setup](#setup)
|
||||
- [Adding your own models](#adding-your-own-models)
|
||||
- [Usage](#usage)
|
||||
- [Known errors and solutions](#known-errors-and-solutions)
|
||||
- [Running the containers](#running-the-containers)
|
||||
- [Credits](#credits)
|
||||
|
||||
## Setup
|
||||
|
||||
There are a few ways to run onnx-web:
|
||||
|
|
|
@ -4,10 +4,7 @@ import { doesExist, InvalidArgumentError, Maybe } from '@apextoaster/js-utils';
|
|||
import { ServerParams } from '../config.js';
|
||||
import {
|
||||
FilterResponse,
|
||||
ImageResponse,
|
||||
ImageResponseWithRetry,
|
||||
ModelResponse,
|
||||
ReadyResponse,
|
||||
RetryParams,
|
||||
WriteExtrasResponse,
|
||||
} from '../types/api.js';
|
||||
|
@ -27,6 +24,7 @@ import {
|
|||
} from '../types/params.js';
|
||||
import { range } from '../utils.js';
|
||||
import { ApiClient } from './base.js';
|
||||
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
|
||||
|
||||
/**
|
||||
* Fixed precision for integer parameters.
|
||||
|
@ -43,8 +41,9 @@ export const FIXED_INTEGER = 0;
|
|||
export const FIXED_FLOAT = 2;
|
||||
export const STATUS_SUCCESS = 200;
|
||||
|
||||
export function equalResponse(a: ImageResponse, b: ImageResponse): boolean {
|
||||
return a.outputs === b.outputs;
|
||||
export function equalResponse(a: JobResponse, b: JobResponse): boolean {
|
||||
return a.name === b.name && a.status === b.status && a.type === b.type;
|
||||
// return a.outputs === b.outputs;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -141,8 +140,8 @@ export function appendHighresToURL(url: URL, highres: HighresParams) {
|
|||
* Make an API client using the given API root and fetch client.
|
||||
*/
|
||||
export function makeClient(root: string, token: Maybe<string> = undefined, f = fetch): ApiClient {
|
||||
function parseRequest(url: URL, options: RequestInit): Promise<ImageResponse> {
|
||||
return f(url, options).then((res) => parseApiResponse(root, res));
|
||||
function parseRequest(url: URL, options: RequestInit): Promise<JobResponse> {
|
||||
return f(url, options).then((res) => parseJobResponse(root, res));
|
||||
}
|
||||
|
||||
return {
|
||||
|
@ -218,7 +217,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
const res = await f(path);
|
||||
return await res.json() as Array<string>;
|
||||
},
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
async img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeImageURL(root, 'img2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -240,12 +239,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
const body = new FormData();
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'img2img',
|
||||
model,
|
||||
|
@ -254,7 +253,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
},
|
||||
};
|
||||
},
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
async txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeImageURL(root, 'txt2img', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -274,11 +273,11 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
appendHighresToURL(url, highres);
|
||||
}
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'txt2img',
|
||||
model,
|
||||
|
@ -288,7 +287,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
},
|
||||
};
|
||||
},
|
||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
async inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -309,12 +308,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
body.append('mask', params.mask, 'mask');
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'inpaint',
|
||||
model,
|
||||
|
@ -323,7 +322,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
},
|
||||
};
|
||||
},
|
||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
async outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeImageURL(root, 'inpaint', params);
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -361,12 +360,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
body.append('mask', params.mask, 'mask');
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'outpaint',
|
||||
model,
|
||||
|
@ -375,7 +374,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
},
|
||||
};
|
||||
},
|
||||
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry> {
|
||||
async upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiUrl(root, 'upscale');
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -396,12 +395,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
const body = new FormData();
|
||||
body.append('source', params.source, 'source');
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'upscale',
|
||||
model,
|
||||
|
@ -410,7 +409,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
},
|
||||
};
|
||||
},
|
||||
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry> {
|
||||
async blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry> {
|
||||
const url = makeApiUrl(root, 'blend');
|
||||
appendModelToURL(url, model);
|
||||
|
||||
|
@ -426,12 +425,12 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
body.append(name, params.sources[i], name);
|
||||
}
|
||||
|
||||
const image = await parseRequest(url, {
|
||||
const job = await parseRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
return {
|
||||
image,
|
||||
job,
|
||||
retry: {
|
||||
type: 'blend',
|
||||
model,
|
||||
|
@ -440,8 +439,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
}
|
||||
};
|
||||
},
|
||||
async chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse> {
|
||||
const url = makeApiUrl(root, 'chain');
|
||||
async chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse> {
|
||||
const url = makeApiUrl(root, 'job');
|
||||
const body = JSON.stringify({
|
||||
...chain,
|
||||
platform: model.platform,
|
||||
|
@ -456,23 +455,23 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
method: 'POST',
|
||||
});
|
||||
},
|
||||
async ready(key: string): Promise<ReadyResponse> {
|
||||
const path = makeApiUrl(root, 'ready');
|
||||
path.searchParams.append('output', key);
|
||||
async status(keys: Array<string>): Promise<Array<JobResponse>> {
|
||||
const path = makeApiUrl(root, 'job', 'status');
|
||||
path.searchParams.append('jobs', keys.join(','));
|
||||
|
||||
const res = await f(path);
|
||||
return await res.json() as ReadyResponse;
|
||||
return await res.json() as Array<JobResponse>;
|
||||
},
|
||||
async cancel(key: string): Promise<boolean> {
|
||||
const path = makeApiUrl(root, 'cancel');
|
||||
path.searchParams.append('output', key);
|
||||
async cancel(keys: Array<string>): Promise<Array<JobResponse>> {
|
||||
const path = makeApiUrl(root, 'job', 'cancel');
|
||||
path.searchParams.append('jobs', keys.join(','));
|
||||
|
||||
const res = await f(path, {
|
||||
method: 'PUT',
|
||||
});
|
||||
return res.status === STATUS_SUCCESS;
|
||||
return await res.json() as Array<JobResponse>;
|
||||
},
|
||||
async retry(retry: RetryParams): Promise<ImageResponseWithRetry> {
|
||||
async retry(retry: RetryParams): Promise<JobResponseWithRetry> {
|
||||
switch (retry.type) {
|
||||
case 'blend':
|
||||
return this.blend(retry.model, retry.params, retry.upscale);
|
||||
|
@ -491,7 +490,7 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
}
|
||||
},
|
||||
async restart(): Promise<boolean> {
|
||||
const path = makeApiUrl(root, 'restart');
|
||||
const path = makeApiUrl(root, 'worker', 'restart');
|
||||
|
||||
if (doesExist(token)) {
|
||||
path.searchParams.append('token', token);
|
||||
|
@ -502,8 +501,8 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
});
|
||||
return res.status === STATUS_SUCCESS;
|
||||
},
|
||||
async status(): Promise<Array<unknown>> {
|
||||
const path = makeApiUrl(root, 'status');
|
||||
async workers(): Promise<Array<unknown>> {
|
||||
const path = makeApiUrl(root, 'worker', 'status');
|
||||
|
||||
if (doesExist(token)) {
|
||||
path.searchParams.append('token', token);
|
||||
|
@ -512,6 +511,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
const res = await f(path);
|
||||
return res.json();
|
||||
},
|
||||
outputURL(image: SuccessJobResponse, index: number): string {
|
||||
return new URL(joinPath('output', image.outputs[index]), root).toString();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -521,24 +523,9 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
|
|||
* The server sends over the output key, and the client is in the best position to turn
|
||||
* that into a full URL, since it already knows the root URL of the server.
|
||||
*/
|
||||
export async function parseApiResponse(root: string, res: Response): Promise<ImageResponse> {
|
||||
type LimitedResponse = Omit<ImageResponse, 'outputs'> & { outputs: Array<string> };
|
||||
|
||||
export async function parseJobResponse(root: string, res: Response): Promise<JobResponse> {
|
||||
if (res.status === STATUS_SUCCESS) {
|
||||
const data = await res.json() as LimitedResponse;
|
||||
|
||||
const outputs = data.outputs.map((output) => {
|
||||
const url = new URL(joinPath('output', output), root).toString();
|
||||
return {
|
||||
key: output,
|
||||
url,
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
...data,
|
||||
outputs,
|
||||
};
|
||||
return await res.json() as JobResponse;
|
||||
} else {
|
||||
throw new Error('request error');
|
||||
}
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
import { ServerParams } from '../config.js';
|
||||
import { ExtrasFile } from '../types/model.js';
|
||||
import { WriteExtrasResponse, FilterResponse, ModelResponse, ImageResponseWithRetry, ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
||||
import { WriteExtrasResponse, FilterResponse, ModelResponse, RetryParams } from '../types/api.js';
|
||||
import { ChainPipeline } from '../types/chain.js';
|
||||
import { ModelParams, Txt2ImgParams, UpscaleParams, HighresParams, Img2ImgParams, InpaintParams, OutpaintParams, UpscaleReqParams, BlendParams } from '../types/params.js';
|
||||
import { JobResponse, JobResponseWithRetry, SuccessJobResponse } from '../types/api-v2.js';
|
||||
|
||||
export interface ApiClient {
|
||||
/**
|
||||
* Get the first extras file.
|
||||
*/
|
||||
extras(): Promise<ExtrasFile>;
|
||||
|
||||
/**
|
||||
* Update the first extras file.
|
||||
*/
|
||||
writeExtras(extras: ExtrasFile): Promise<WriteExtrasResponse>;
|
||||
|
||||
/**
|
||||
|
@ -51,54 +58,60 @@ export interface ApiClient {
|
|||
translation: Record<string, string>;
|
||||
}>>;
|
||||
|
||||
/**
|
||||
* Get the available wildcards.
|
||||
*/
|
||||
wildcards(): Promise<Array<string>>;
|
||||
|
||||
/**
|
||||
* Start a txt2img pipeline.
|
||||
*/
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
txt2img(model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an im2img pipeline.
|
||||
*/
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
img2img(model: ModelParams, params: Img2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an inpaint pipeline.
|
||||
*/
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
inpaint(model: ModelParams, params: InpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an outpaint pipeline.
|
||||
*/
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
outpaint(model: ModelParams, params: OutpaintParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start an upscale pipeline.
|
||||
*/
|
||||
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<ImageResponseWithRetry>;
|
||||
upscale(model: ModelParams, params: UpscaleReqParams, upscale?: UpscaleParams, highres?: HighresParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Start a blending pipeline.
|
||||
*/
|
||||
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
||||
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
chain(model: ModelParams, chain: ChainPipeline): Promise<ImageResponse>;
|
||||
/**
|
||||
* Start a custom chain pipeline.
|
||||
*/
|
||||
chain(model: ModelParams, chain: ChainPipeline): Promise<JobResponse>;
|
||||
|
||||
/**
|
||||
* Check whether job has finished and its output is ready.
|
||||
*/
|
||||
ready(key: string): Promise<ReadyResponse>;
|
||||
status(keys: Array<string>): Promise<Array<JobResponse>>;
|
||||
|
||||
/**
|
||||
* Cancel an existing job.
|
||||
*/
|
||||
cancel(key: string): Promise<boolean>;
|
||||
cancel(keys: Array<string>): Promise<Array<JobResponse>>;
|
||||
|
||||
/**
|
||||
* Retry a previous job using the same parameters.
|
||||
*/
|
||||
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
|
||||
retry(params: RetryParams): Promise<JobResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Restart the image job workers.
|
||||
|
@ -108,5 +121,7 @@ export interface ApiClient {
|
|||
/**
|
||||
* Check the status of the image job workers.
|
||||
*/
|
||||
status(): Promise<Array<unknown>>;
|
||||
workers(): Promise<Array<unknown>>;
|
||||
|
||||
outputURL(image: SuccessJobResponse, index: number): string;
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ export const LOCAL_CLIENT = {
|
|||
async params() {
|
||||
throw new NoServerError();
|
||||
},
|
||||
async ready(key) {
|
||||
async status(key) {
|
||||
throw new NoServerError();
|
||||
},
|
||||
async cancel(key) {
|
||||
|
@ -78,7 +78,10 @@ export const LOCAL_CLIENT = {
|
|||
async restart() {
|
||||
throw new NoServerError();
|
||||
},
|
||||
async status() {
|
||||
async workers() {
|
||||
throw new NoServerError();
|
||||
}
|
||||
},
|
||||
outputURL(image, index) {
|
||||
throw new NoServerError();
|
||||
},
|
||||
} as ApiClient;
|
||||
|
|
|
@ -97,11 +97,19 @@ export function expandRanges(range: string): Array<string | number> {
|
|||
export const GRID_TILE_SIZE = 8192;
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export function makeTxt2ImgGridPipeline(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
|
||||
export function makeTxt2ImgGridPipeline(
|
||||
grid: PipelineGrid,
|
||||
model: ModelParams,
|
||||
params: Txt2ImgParams,
|
||||
upscale?: UpscaleParams,
|
||||
highres?: HighresParams,
|
||||
): ChainPipeline {
|
||||
const pipeline: ChainPipeline = {
|
||||
defaults: {
|
||||
...model,
|
||||
...params,
|
||||
...(upscale || {}),
|
||||
...(highres || {}),
|
||||
},
|
||||
stages: [],
|
||||
};
|
||||
|
|
|
@ -10,6 +10,7 @@ import { OnnxState, StateContext } from '../state/full.js';
|
|||
import { ErrorCard } from './card/ErrorCard.js';
|
||||
import { ImageCard } from './card/ImageCard.js';
|
||||
import { LoadingCard } from './card/LoadingCard.js';
|
||||
import { JobStatus } from '../types/api-v2.js';
|
||||
|
||||
export function ImageHistory() {
|
||||
const store = mustExist(useContext(StateContext));
|
||||
|
@ -25,19 +26,19 @@ export function ImageHistory() {
|
|||
|
||||
const limited = history.slice(0, limit);
|
||||
for (const item of limited) {
|
||||
const key = item.image.outputs[0].key;
|
||||
const key = item.image.name;
|
||||
|
||||
if (doesExist(item.ready) && item.ready.ready) {
|
||||
if (item.ready.cancelled || item.ready.failed) {
|
||||
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} ready={item.ready} retry={item.retry} />]);
|
||||
continue;
|
||||
}
|
||||
|
||||
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
|
||||
continue;
|
||||
switch (item.image.status) {
|
||||
case JobStatus.SUCCESS:
|
||||
children.push([key, <ImageCard key={`history-${key}`} image={item.image} onDelete={removeHistory} />]);
|
||||
break;
|
||||
case JobStatus.FAILED:
|
||||
children.push([key, <ErrorCard key={`history-${key}`} image={item.image} retry={item.retry} />]);
|
||||
break;
|
||||
default:
|
||||
children.push([key, <LoadingCard key={`history-${key}`} image={item.image} />]);
|
||||
break;
|
||||
}
|
||||
|
||||
children.push([key, <LoadingCard key={`history-${key}`} index={0} image={item.image} />]);
|
||||
}
|
||||
|
||||
return <Grid container spacing={2}>{children.map(([key, child]) => <Grid item key={key} xs={6}>{child}</Grid>)}</Grid>;
|
||||
|
|
|
@ -10,16 +10,15 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
|
||||
import { FailedJobResponse, RetryParams } from '../../types/api-v2.js';
|
||||
|
||||
export interface ErrorCardProps {
|
||||
image: ImageResponse;
|
||||
ready: ReadyResponse;
|
||||
image: FailedJobResponse;
|
||||
retry: Maybe<RetryParams>;
|
||||
}
|
||||
|
||||
export function ErrorCard(props: ErrorCardProps) {
|
||||
const { image, ready, retry: retryParams } = props;
|
||||
const { image, retry: retryParams } = props;
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const { params } = mustExist(useContext(ConfigContext));
|
||||
|
@ -32,8 +31,8 @@ export function ErrorCard(props: ErrorCardProps) {
|
|||
removeHistory(image);
|
||||
|
||||
if (doesExist(retryParams)) {
|
||||
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
|
||||
pushHistory(nextImage, nextRetry);
|
||||
const { job: nextJob, retry: nextRetry } = await client.retry(retryParams);
|
||||
pushHistory(nextJob, nextRetry);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -52,10 +51,11 @@ export function ErrorCard(props: ErrorCardProps) {
|
|||
spacing={2}
|
||||
sx={{ alignItems: 'center' }}
|
||||
>
|
||||
<Alert severity='error'>{t('loading.progress', {
|
||||
current: ready.progress,
|
||||
total: image.params.steps,
|
||||
})}</Alert>
|
||||
<Alert severity='error'>
|
||||
{t('loading.progress', image.steps)}
|
||||
<br />
|
||||
{image.error}
|
||||
</Alert>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<Tooltip title={t('tooltip.retry')}>
|
||||
<IconButton onClick={() => retry.mutate()}>
|
||||
|
|
|
@ -2,21 +2,22 @@ import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'
|
|||
import { ArrowLeft, ArrowRight, Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
|
||||
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext, useState } from 'react';
|
||||
import { useContext, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useHash } from 'react-use/lib/useHash';
|
||||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse } from '../../types/api.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { range, visibleIndex } from '../../utils.js';
|
||||
import { BLEND_SOURCES } from '../../constants.js';
|
||||
import { JobResponse, SuccessJobResponse } from '../../types/api-v2.js';
|
||||
import { getApiRoot } from '../../config.js';
|
||||
|
||||
export interface ImageCardProps {
|
||||
image: ImageResponse;
|
||||
image: SuccessJobResponse;
|
||||
|
||||
onDelete?: (key: ImageResponse) => void;
|
||||
onDelete?: (key: JobResponse) => void;
|
||||
}
|
||||
|
||||
export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
||||
|
@ -27,18 +28,19 @@ export function GridItem(props: { xs: number; children: React.ReactNode }) {
|
|||
|
||||
export function ImageCard(props: ImageCardProps) {
|
||||
const { image } = props;
|
||||
const { params, outputs, size } = image;
|
||||
const { metadata, outputs } = image;
|
||||
|
||||
const [_hash, setHash] = useHash();
|
||||
const [blendAnchor, setBlendAnchor] = useState<Maybe<HTMLElement>>();
|
||||
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
const store = mustExist(useContext(StateContext));
|
||||
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
|
||||
|
||||
async function loadSource() {
|
||||
const req = await fetch(outputs[index].url);
|
||||
const req = await fetch(url);
|
||||
return req.blob();
|
||||
}
|
||||
|
||||
|
@ -84,12 +86,12 @@ export function ImageCard(props: ImageCardProps) {
|
|||
}
|
||||
|
||||
function downloadImage() {
|
||||
window.open(outputs[index].url, '_blank');
|
||||
window.open(url, '_blank');
|
||||
close();
|
||||
}
|
||||
|
||||
function downloadMetadata() {
|
||||
window.open(outputs[index].url + '.json', '_blank');
|
||||
window.open(url + '.json', '_blank');
|
||||
close();
|
||||
}
|
||||
|
||||
|
@ -106,14 +108,16 @@ export function ImageCard(props: ImageCardProps) {
|
|||
return mustDefault(t(`${key}.${name}`), name);
|
||||
}
|
||||
|
||||
const model = getLabel('model', params.model);
|
||||
const scheduler = getLabel('scheduler', params.scheduler);
|
||||
const url = useMemo(() => client.outputURL(image, index), [image, index]);
|
||||
|
||||
const model = getLabel('model', metadata[index].model);
|
||||
const scheduler = getLabel('scheduler', metadata[index].scheduler);
|
||||
|
||||
return <Card sx={{ maxWidth: config.params.width.default }} elevation={2}>
|
||||
<CardMedia sx={{ height: config.params.height.default }}
|
||||
component='img'
|
||||
image={outputs[index].url}
|
||||
title={params.prompt}
|
||||
image={url}
|
||||
title={metadata[index].params.prompt}
|
||||
/>
|
||||
<CardContent>
|
||||
<Box textAlign='center'>
|
||||
|
@ -146,12 +150,12 @@ export function ImageCard(props: ImageCardProps) {
|
|||
</GridItem>
|
||||
<GridItem xs={4}>{t('modelType.diffusion', {count: 1})}: {model}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.scheduler')}: {scheduler}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.seed')}: {params.seed}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.cfg')}: {params.cfg}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.steps')}: {params.steps}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.size')}: {size.width}x{size.height}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.seed')}: {metadata[index].params.seed}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.cfg')}: {metadata[index].params.cfg}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.steps')}: {metadata[index].params.steps}</GridItem>
|
||||
<GridItem xs={4}>{t('parameter.size')}: {metadata[index].size.width}x{metadata[index].size.height}</GridItem>
|
||||
<GridItem xs={12}>
|
||||
<Box textAlign='left'>{params.prompt}</Box>
|
||||
<Box textAlign='left'>{metadata[index].params.prompt}</Box>
|
||||
</GridItem>
|
||||
<GridItem xs={2}>
|
||||
<Tooltip title={t('tooltip.save')}>
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Card, CardContent, CircularProgress, Typography } from '@mui/material';
|
||||
import { Stack } from '@mui/system';
|
||||
import { useMutation, useQuery } from '@tanstack/react-query';
|
||||
|
@ -10,19 +10,17 @@ import { shallow } from 'zustand/shallow';
|
|||
|
||||
import { POLL_TIME } from '../../config.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse } from '../../types/api.js';
|
||||
import { JobResponse, JobStatus } from '../../types/api-v2.js';
|
||||
|
||||
const LOADING_PERCENT = 100;
|
||||
const LOADING_OVERAGE = 99;
|
||||
|
||||
export interface LoadingCardProps {
|
||||
image: ImageResponse;
|
||||
index: number;
|
||||
image: JobResponse;
|
||||
}
|
||||
|
||||
export function LoadingCard(props: LoadingCardProps) {
|
||||
const { image, index } = props;
|
||||
const { steps } = props.image.params;
|
||||
const { image } = props;
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
const { params } = mustExist(useContext(ConfigContext));
|
||||
|
@ -31,50 +29,22 @@ export function LoadingCard(props: LoadingCardProps) {
|
|||
const { removeHistory, setReady } = useStore(store, selectActions, shallow);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const cancel = useMutation(() => client.cancel(image.outputs[index].key));
|
||||
const ready = useQuery(['ready', image.outputs[index].key], () => client.ready(image.outputs[index].key), {
|
||||
const cancel = useMutation(() => client.cancel([image.name]));
|
||||
const ready = useQuery(['ready', image.name], () => client.status([image.name]), {
|
||||
// data will always be ready without this, even if the API says its not
|
||||
cacheTime: 0,
|
||||
refetchInterval: POLL_TIME,
|
||||
});
|
||||
|
||||
function getProgress() {
|
||||
if (doesExist(ready.data)) {
|
||||
return ready.data.progress;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
function getPercent() {
|
||||
const progress = getProgress();
|
||||
if (progress > steps) {
|
||||
// steps was not complete, show 99% until done
|
||||
return LOADING_OVERAGE;
|
||||
}
|
||||
|
||||
const pct = progress / steps;
|
||||
return Math.ceil(pct * LOADING_PERCENT);
|
||||
}
|
||||
|
||||
function getTotal() {
|
||||
const progress = getProgress();
|
||||
if (progress > steps) {
|
||||
// steps was not complete, show 99% until done
|
||||
return t('loading.unknown');
|
||||
}
|
||||
|
||||
return steps.toFixed(0);
|
||||
}
|
||||
|
||||
function getReady() {
|
||||
return doesExist(ready.data) && ready.data.ready;
|
||||
return doesExist(ready.data) && ready.data[0].status === JobStatus.SUCCESS;
|
||||
}
|
||||
|
||||
function renderProgress() {
|
||||
const progress = getProgress();
|
||||
if (progress > 0 && progress <= steps) {
|
||||
return <CircularProgress variant='determinate' value={getPercent()} />;
|
||||
const progress = getProgress(ready.data);
|
||||
const total = getTotal(ready.data);
|
||||
if (progress > 0 && progress <= total) {
|
||||
return <CircularProgress variant='determinate' value={getPercent(progress, total)} />;
|
||||
} else {
|
||||
return <CircularProgress />;
|
||||
}
|
||||
|
@ -88,9 +58,9 @@ export function LoadingCard(props: LoadingCardProps) {
|
|||
|
||||
useEffect(() => {
|
||||
if (ready.status === 'success' && getReady()) {
|
||||
setReady(props.image, ready.data);
|
||||
setReady(ready.data[0]);
|
||||
}
|
||||
}, [ready.status, getReady(), getProgress()]);
|
||||
}, [ready.status, getReady(), getProgress(ready.data)]);
|
||||
|
||||
return <Card sx={{ maxWidth: params.width.default }}>
|
||||
<CardContent sx={{ height: params.height.default }}>
|
||||
|
@ -106,10 +76,7 @@ export function LoadingCard(props: LoadingCardProps) {
|
|||
sx={{ alignItems: 'center' }}
|
||||
>
|
||||
{renderProgress()}
|
||||
<Typography>{t('loading.progress', {
|
||||
current: getProgress(),
|
||||
total: getTotal(),
|
||||
})}</Typography>
|
||||
<Typography>{t('loading.progress', selectStatus(ready.data, image))}</Typography>
|
||||
<Button onClick={() => cancel.mutate()}>{t('loading.cancel')}</Button>
|
||||
</Stack>
|
||||
</Box>
|
||||
|
@ -125,3 +92,45 @@ export function selectActions(state: OnnxState) {
|
|||
setReady: state.setReady,
|
||||
};
|
||||
}
|
||||
|
||||
export function selectStatus(data: Maybe<Array<JobResponse>>, defaultData: JobResponse) {
|
||||
if (doesExist(data) && data.length > 0) {
|
||||
return {
|
||||
steps: data[0].steps,
|
||||
stages: data[0].stages,
|
||||
tiles: data[0].tiles,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
steps: defaultData.steps,
|
||||
stages: defaultData.stages,
|
||||
tiles: defaultData.tiles,
|
||||
};
|
||||
}
|
||||
|
||||
export function getPercent(current: number, total: number): number {
|
||||
if (current > total) {
|
||||
// steps was not complete, show 99% until done
|
||||
return LOADING_OVERAGE;
|
||||
}
|
||||
|
||||
const pct = current / total;
|
||||
return Math.ceil(pct * LOADING_PERCENT);
|
||||
}
|
||||
|
||||
export function getProgress(data: Maybe<Array<JobResponse>>) {
|
||||
if (doesExist(data)) {
|
||||
return data[0].steps.current;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
export function getTotal(data: Maybe<Array<JobResponse>>) {
|
||||
if (doesExist(data)) {
|
||||
return data[0].steps.total;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -20,13 +20,13 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
|
|||
export function Blend() {
|
||||
async function uploadSource() {
|
||||
const { blend, blendModel, blendUpscale } = store.getState();
|
||||
const { image, retry } = await client.blend(blendModel, {
|
||||
const { job, retry } = await client.blend(blendModel, {
|
||||
...blend,
|
||||
mask: mustExist(blend.mask),
|
||||
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
|
||||
}, blendUpscale);
|
||||
|
||||
pushHistory(image, retry);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
|
|
@ -27,12 +27,12 @@ export function Img2Img() {
|
|||
const state = store.getState();
|
||||
const img2img = selectParams(state);
|
||||
|
||||
const { image, retry } = await client.img2img(model, {
|
||||
const { job, retry } = await client.img2img(model, {
|
||||
...img2img,
|
||||
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
|
||||
pushHistory(image, retry);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
|
|
@ -39,22 +39,22 @@ export function Inpaint() {
|
|||
const inpaint = selectParams(state);
|
||||
|
||||
if (outpaint.enabled) {
|
||||
const { image, retry } = await client.outpaint(model, {
|
||||
const { job, retry } = await client.outpaint(model, {
|
||||
...inpaint,
|
||||
...outpaint,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
|
||||
pushHistory(image, retry);
|
||||
pushHistory(job, retry);
|
||||
} else {
|
||||
const { image, retry } = await client.inpaint(model, {
|
||||
const { job, retry } = await client.inpaint(model, {
|
||||
...inpaint,
|
||||
mask: mustExist(mask),
|
||||
source: mustExist(source),
|
||||
}, selectUpscale(state), selectHighres(state));
|
||||
|
||||
pushHistory(image, retry);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -69,8 +69,8 @@ export function Txt2Img() {
|
|||
const image = await client.chain(model, chain);
|
||||
pushHistory(image);
|
||||
} else {
|
||||
const { image, retry } = await client.txt2img(model, params2, upscale, highres);
|
||||
pushHistory(image, retry);
|
||||
const { job, retry } = await client.txt2img(model, params2, upscale, highres);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,12 +21,12 @@ import { PromptInput } from '../input/PromptInput.js';
|
|||
export function Upscale() {
|
||||
async function uploadSource() {
|
||||
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
|
||||
const { image, retry } = await client.upscale(upscaleModel, {
|
||||
const { job, retry } = await client.upscale(upscaleModel, {
|
||||
...upscale,
|
||||
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
|
||||
}, upscaleUpscale, upscaleHighres);
|
||||
|
||||
pushHistory(image, retry);
|
||||
pushHistory(job, retry);
|
||||
}
|
||||
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
|
|
@ -2,6 +2,7 @@ import { Maybe } from '@apextoaster/js-utils';
|
|||
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
||||
import { Slice } from './types.js';
|
||||
import { DEFAULT_HISTORY } from '../constants.js';
|
||||
import { JobResponse } from '../types/api-v2.js';
|
||||
|
||||
export interface HistoryItem {
|
||||
image: ImageResponse;
|
||||
|
@ -9,14 +10,19 @@ export interface HistoryItem {
|
|||
retry: Maybe<RetryParams>;
|
||||
}
|
||||
|
||||
export interface HistoryItemV2 {
|
||||
image: JobResponse;
|
||||
retry: Maybe<RetryParams>;
|
||||
}
|
||||
|
||||
export interface HistorySlice {
|
||||
history: Array<HistoryItem>;
|
||||
history: Array<HistoryItemV2>;
|
||||
limit: number;
|
||||
|
||||
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
||||
removeHistory(image: ImageResponse): void;
|
||||
pushHistory(image: JobResponse, retry?: RetryParams): void;
|
||||
removeHistory(image: JobResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
||||
setReady(image: JobResponse): void;
|
||||
}
|
||||
|
||||
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
|
||||
|
@ -39,7 +45,7 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
|
|||
removeHistory(image) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
|
||||
history: prev.history.filter((it) => it.image.name !== image.name),
|
||||
}));
|
||||
},
|
||||
setLimit(limit) {
|
||||
|
@ -48,12 +54,12 @@ export function createHistorySlice<TState extends HistorySlice>(): Slice<TState,
|
|||
limit,
|
||||
}));
|
||||
},
|
||||
setReady(image, ready) {
|
||||
setReady(image) {
|
||||
set((prev) => {
|
||||
const history = [...prev.history];
|
||||
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
|
||||
const idx = history.findIndex((it) => it.image.name === image.name);
|
||||
if (idx >= 0) {
|
||||
history[idx].ready = ready;
|
||||
history[idx].image = image;
|
||||
} else {
|
||||
// TODO: error
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ export const I18N_STRINGS_EN = {
|
|||
},
|
||||
loading: {
|
||||
cancel: 'Cancel',
|
||||
progress: '{{current}} of {{total}} steps',
|
||||
progress: '{{steps.current}} of {{steps.total}} steps, {{tiles.current}} of {{tiles.total}} tiles, {{stages.current}} of {{stages.total}} stages',
|
||||
server: 'Connecting to server...',
|
||||
unknown: 'many',
|
||||
},
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
import { RetryParams } from './api.js';
|
||||
import { BaseImgParams, HighresParams, Img2ImgParams, InpaintParams, Txt2ImgParams, UpscaleParams } from './params.js';
|
||||
|
||||
export interface Progress {
|
||||
current: number;
|
||||
total: number;
|
||||
}
|
||||
|
||||
export interface Size {
|
||||
width: number;
|
||||
height: number;
|
||||
}
|
||||
|
||||
export interface NetworkMetadata {
|
||||
name: string;
|
||||
hash: string;
|
||||
weight: number;
|
||||
}
|
||||
|
||||
export interface ImageMetadata<TParams extends BaseImgParams, TType extends JobType> {
|
||||
input_size: Size;
|
||||
size: Size;
|
||||
outputs: Array<string>;
|
||||
params: TParams;
|
||||
inversions: Array<NetworkMetadata>;
|
||||
loras: Array<NetworkMetadata>;
|
||||
model: string;
|
||||
scheduler: string;
|
||||
border: unknown;
|
||||
highres: HighresParams;
|
||||
upscale: UpscaleParams;
|
||||
type: TType;
|
||||
}
|
||||
|
||||
export enum JobStatus {
|
||||
PENDING = 'pending',
|
||||
RUNNING = 'running',
|
||||
SUCCESS = 'success',
|
||||
FAILED = 'failed',
|
||||
CANCELLED = 'cancelled',
|
||||
UNKNOWN = 'unknown',
|
||||
}
|
||||
|
||||
export enum JobType {
|
||||
TXT2IMG = 'txt2img',
|
||||
IMG2IMG = 'img2img',
|
||||
INPAINT = 'inpaint',
|
||||
UPSCALE = 'upscale',
|
||||
BLEND = 'blend',
|
||||
CHAIN = 'chain',
|
||||
}
|
||||
|
||||
export interface BaseJobResponse {
|
||||
name: string;
|
||||
status: JobStatus;
|
||||
type: JobType;
|
||||
|
||||
stages: Progress;
|
||||
steps: Progress;
|
||||
tiles: Progress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Pending image job.
|
||||
*/
|
||||
export interface PendingJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.PENDING | JobStatus.RUNNING;
|
||||
queue: Progress;
|
||||
}
|
||||
|
||||
/**
|
||||
* Failed image job with error information.
|
||||
*/
|
||||
export interface FailedJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.FAILED;
|
||||
error: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful txt2img image job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessTxt2ImgJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.TXT2IMG;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<Txt2ImgParams, JobType.TXT2IMG>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful img2img job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessImg2ImgJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.IMG2IMG;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<Img2ImgParams, JobType.IMG2IMG>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful inpaint job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessInpaintJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.INPAINT;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<InpaintParams, JobType.INPAINT>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful upscale job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessUpscaleJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.UPSCALE;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<BaseImgParams, JobType.UPSCALE>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful blend job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessBlendJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.BLEND;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<BaseImgParams, JobType.BLEND>>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Successful chain pipeline job with output keys and metadata.
|
||||
*/
|
||||
export interface SuccessChainJobResponse extends BaseJobResponse {
|
||||
status: JobStatus.SUCCESS;
|
||||
type: JobType.CHAIN;
|
||||
outputs: Array<string>;
|
||||
metadata: Array<ImageMetadata<BaseImgParams, JobType>>; // TODO: could be all kinds
|
||||
}
|
||||
|
||||
export type SuccessJobResponse
|
||||
= SuccessTxt2ImgJobResponse
|
||||
| SuccessImg2ImgJobResponse
|
||||
| SuccessInpaintJobResponse
|
||||
| SuccessUpscaleJobResponse
|
||||
| SuccessBlendJobResponse
|
||||
| SuccessChainJobResponse;
|
||||
|
||||
export type JobResponse = PendingJobResponse | FailedJobResponse | SuccessJobResponse;
|
||||
|
||||
/**
|
||||
* Status response from the job endpoint, with parameters to retry the job if it fails.
|
||||
*/
|
||||
export interface JobResponseWithRetry {
|
||||
job: JobResponse;
|
||||
retry: RetryParams;
|
||||
}
|
||||
|
||||
/**
|
||||
* Re-export `RetryParams` for convenience.
|
||||
*/
|
||||
export { RetryParams };
|
|
@ -14,6 +14,8 @@ import {
|
|||
|
||||
/**
|
||||
* Output image data within the response.
|
||||
*
|
||||
* @deprecated
|
||||
*/
|
||||
export interface ImageOutput {
|
||||
key: string;
|
||||
|
@ -22,6 +24,8 @@ export interface ImageOutput {
|
|||
|
||||
/**
|
||||
* General response for most image requests.
|
||||
*
|
||||
* @deprecated
|
||||
*/
|
||||
export interface ImageResponse {
|
||||
outputs: Array<ImageOutput>;
|
||||
|
@ -119,11 +123,19 @@ export type RetryParams = {
|
|||
upscale?: UpscaleParams;
|
||||
};
|
||||
|
||||
/**
|
||||
* Status response from the image endpoint, with parameters to retry the job if it fails.
|
||||
*
|
||||
* @deprecated
|
||||
*/
|
||||
export interface ImageResponseWithRetry {
|
||||
image: ImageResponse;
|
||||
retry: RetryParams;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated
|
||||
*/
|
||||
export interface ImageMetadata {
|
||||
highres: HighresParams;
|
||||
outputs: string | Array<string>;
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
"dtype",
|
||||
"ESRGAN",
|
||||
"Exif",
|
||||
"fromarray",
|
||||
"ftfy",
|
||||
"gfpgan",
|
||||
"Heun",
|
||||
|
@ -115,6 +116,7 @@
|
|||
"webp",
|
||||
"xformers",
|
||||
"zustand"
|
||||
]
|
||||
],
|
||||
"git.ignoreLimitWarning": true
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue