1
0
Fork 0

Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards

This commit is contained in:
BZLibby 2023-12-16 13:25:38 -06:00
commit 591a63edc8
66 changed files with 2153 additions and 1409 deletions

View File

@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
This is an incomplete list of new and interesting features, with links to the user guide:
- SDXL support
- LCM support
- supports SDXL and SDXL Turbo
- wide variety of schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more
- hardware acceleration on both AMD and Nvidia
- tested on CUDA, DirectML, and ROCm
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both

View File

@ -1,5 +1,7 @@
call onnx_env\Scripts\Activate.bat
echo "This launch.bat script is deprecated in favor of launch.ps1 and will be removed in a future release."
echo "Downloading and converting models to ONNX format..."
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json)
python -m onnx_web.convert ^
@ -10,6 +12,12 @@ python -m onnx_web.convert ^
--extras=%ONNX_WEB_EXTRA_MODELS% ^
--token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS%
IF NOT EXIST .\gui\index.html (
echo "Please make sure you have downloaded the web UI files from https://github.com/ssube/onnx-web/tree/gh-pages"
echo "See https://github.com/ssube/onnx-web/blob/main/docs/setup-guide.md#download-the-web-ui-bundle for more information"
pause
)
echo "Launching API server..."
waitress-serve ^
--host=0.0.0.0 ^

View File

@ -10,6 +10,13 @@ python -m onnx_web.convert `
--extras=$Env:ONNX_WEB_EXTRA_MODELS `
--token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS
if (!(Test-Path -path .\gui\index.html -PathType Leaf)) {
echo "Downloading latest web UI files from Github..."
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/index.html" -OutFile .\gui\index.html
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/config.json" -OutFile .\gui\config.json
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/bundle/main.js" -OutFile .\gui\bundle\main.js
}
echo "Launching API server..."
waitress-serve `
--host=0.0.0.0 `

View File

@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage):
) -> StageResult:
logger.info("blending image using mask")
# TODO: does this need an alpha channel?
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask)
mult_mask = Image.alpha_composite(mult_mask, stage_mask)
mult_mask = mult_mask.convert("L")
if is_debug():
save_image(server, "last-mask.png", stage_mask)
save_image(server, "last-mult-mask.png", mult_mask)
return StageResult(
images=[
return StageResult.from_images(
[
Image.composite(stage_source, source, mult_mask)
for source in sources.as_image()
]

View File

@ -3,16 +3,17 @@ from argparse import ArgumentParser
from logging import getLogger
from os import makedirs, path
from sys import exit
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from typing import Any, Dict, List, Union
from huggingface_hub.file_download import hf_hub_download
from jsonschema import ValidationError, validate
from onnx import load_model, save_model
from transformers import CLIPTokenizer
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from ..server.plugin import load_plugins
from ..utils import load_config
from .client import add_model_source, fetch_model
from .client.huggingface import HuggingfaceClient
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusion import convert_diffusion_diffusers
@ -25,8 +26,7 @@ from .upscaling.swinir import convert_upscaling_swinir
from .utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
remove_prefix,
fix_diffusion_name,
source_format,
tuple_to_correction,
tuple_to_diffusion,
@ -44,32 +44,34 @@ warnings.filterwarnings(
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
)
Models = Dict[str, List[Any]]
logger = getLogger(__name__)
ModelDict = Dict[str, Union[float, int, str]]
Models = Dict[str, List[ModelDict]]
model_sources: Dict[str, Tuple[str, str]] = {
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
model_converters: Dict[str, Any] = {
"img2img": convert_diffusion_diffusers,
"img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers,
"txt2img": convert_diffusion_diffusers,
"txt2img-sdxl": convert_diffusion_diffusers_xl,
}
model_source_huggingface = "huggingface://"
# recommended models
base_models: Models = {
"diffusion": [
# v1.x
(
"stable-diffusion-onnx-v1-5",
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
),
(
"stable-diffusion-onnx-v1-inpainting",
model_source_huggingface + "runwayml/stable-diffusion-inpainting",
HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
),
(
"upscaling-stable-diffusion-x4",
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
True,
),
],
@ -200,69 +202,203 @@ base_models: Models = {
}
def fetch_model(
conversion: ConversionContext,
name: str,
source: str,
dest: Optional[str] = None,
format: Optional[str] = None,
hf_hub_fetch: bool = False,
hf_hub_filename: Optional[str] = None,
) -> Tuple[str, bool]:
cache_path = dest or conversion.cache_path
cache_name = path.join(cache_path, name)
def convert_model_source(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source = model["source"]
# add an extension if possible, some of the conversion code checks for it
if format is None:
url = urlparse(source)
ext = path.basename(url.path)
_filename, ext = path.splitext(ext)
if ext is not None:
cache_name = cache_name + ext
dest_path = None
if "dest" in model:
dest_path = path.join(conversion.model_path, model["dest"])
dest = fetch_model(conversion, name, source, format=model_format, dest=dest_path)
logger.info("finished downloading source: %s -> %s", source, dest)
def convert_model_network(conversion: ConversionContext, model):
model_format = source_format(model)
model_type = model["type"]
name = model["name"]
source = model["source"]
if model_type == "control":
dest = fetch_model(
conversion,
name,
source,
format=model_format,
)
convert_diffusion_control(
conversion,
model,
dest,
path.join(conversion.model_path, model_type, name),
)
else:
cache_name = f"{cache_name}.{format}"
model = model.get("model", None)
dest = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, model_type),
format=model_format,
embeds=(model_type == "inversion" and model == "concept"),
)
if path.exists(cache_name):
logger.debug("model already exists in cache, skipping fetch")
return cache_name, False
logger.info("finished downloading network: %s -> %s", source, dest)
for proto in model_sources:
api_name, api_root = model_sources.get(proto)
if source.startswith(proto):
api_source = api_root % (remove_prefix(source, proto))
logger.info(
"downloading model from %s: %s -> %s", api_name, api_source, cache_name
def convert_model_diffusion(conversion: ConversionContext, model):
# fix up entries with missing prefixes
name = fix_diffusion_name(model["name"])
if name != model["name"]:
# update the model in-memory if the name changed
model["name"] = name
model_format = source_format(model)
pipeline = model.get("pipeline", "txt2img")
converter = model_converters.get(pipeline)
converted, dest = converter(
conversion,
model,
model_format,
)
# make sure blending only happens once, not every run
if converted:
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(conversion.model_path, "inversion")
lora_dest = path.join(conversion.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "tokenizer" not in blend_models:
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
dest,
subfolder="tokenizer",
)
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", None)
inversion_source = fetch_model(
conversion,
inversion_name,
inversion_source,
dest=inversion_dest,
)
return download_progress([(api_source, cache_name)]), False
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
if source.startswith(model_source_huggingface):
hub_source = remove_prefix(source, model_source_huggingface)
logger.info("downloading model from Huggingface Hub: %s", hub_source)
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
if hf_hub_fetch:
return (
hf_hub_download(
repo_id=hub_source,
filename=hf_hub_filename,
cache_dir=cache_path,
force_filename=f"{name}.bin",
),
False,
blend_textual_inversions(
conversion,
blend_models["text_encoder"],
blend_models["tokenizer"],
[
(
inversion_source,
inversion_weight,
inversion_token,
inversion_format,
)
],
)
else:
return hub_source, True
elif source.startswith("https://"):
logger.info("downloading model from: %s", source)
return download_progress([(source, cache_name)]), False
elif source.startswith("http://"):
logger.warning("downloading model from insecure source: %s", source)
return download_progress([(source, cache_name)]), False
elif source.startswith(path.sep) or source.startswith("."):
logger.info("using local model: %s", source)
return source, False
for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "unet" not in blend_models:
blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source = fetch_model(
conversion,
f"{name}-lora-{lora_name}",
lora_source,
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)
blend_loras(
conversion,
blend_models["text_encoder"],
[(lora_source, lora_weight)],
"text_encoder",
)
blend_loras(
conversion,
blend_models["unet"],
[(lora_source, lora_weight)],
"unet",
)
if "tokenizer" in blend_models:
dest_path = path.join(dest, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)
for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(dest, name, ONNX_MODEL)
logger.debug("saving blended %s model to %s", name, dest_path)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
def convert_model_upscaling(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source = fetch_model(conversion, name, model["source"], format=model_format)
model_type = model.get("model", "resrgan")
if model_type == "bsrgan":
convert_upscaling_bsrgan(conversion, model, source)
elif model_type == "resrgan":
convert_upscale_resrgan(conversion, model, source)
elif model_type == "swinir":
convert_upscaling_swinir(conversion, model, source)
else:
logger.info("unknown model location, using path as provided: %s", source)
return source, False
logger.error("unknown upscaling model type %s for %s", model_type, name)
raise ValueError(name)
def convert_model_correction(conversion: ConversionContext, model):
model_format = source_format(model)
name = model["name"]
source = fetch_model(conversion, name, model["source"], format=model_format)
model_type = model.get("model", "gfpgan")
if model_type == "gfpgan":
convert_correction_gfpgan(conversion, model, source)
else:
logger.error("unknown correction model type %s for %s", model_type, name)
raise ValueError(name)
def convert_models(conversion: ConversionContext, args, models: Models):
@ -276,69 +412,21 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping source: %s", name)
else:
model_format = source_format(model)
source = model["source"]
try:
dest_path = None
if "dest" in model:
dest_path = path.join(conversion.model_path, model["dest"])
dest, hf = fetch_model(
conversion, name, source, format=model_format, dest=dest_path
)
logger.info("finished downloading source: %s -> %s", source, dest)
convert_model_source(conversion, model)
except Exception:
logger.exception("error fetching source %s", name)
model_errors.append(name)
if args.networks and "networks" in models:
for network in models.get("networks", []):
name = network["name"]
for model in models.get("networks", []):
name = model["name"]
if name in args.skip:
logger.info("skipping network: %s", name)
else:
network_format = source_format(network)
network_model = network.get("model", None)
network_type = network["type"]
source = network["source"]
try:
if network_type == "control":
dest, hf = fetch_model(
conversion,
name,
source,
format=network_format,
)
convert_diffusion_control(
conversion,
network,
dest,
path.join(conversion.model_path, network_type, name),
)
if network_type == "inversion" and network_model == "concept":
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
hf_hub_fetch=True,
hf_hub_filename="learned_embeds.bin",
)
else:
dest, hf = fetch_model(
conversion,
name,
source,
dest=path.join(conversion.model_path, network_type),
format=network_format,
)
logger.info("finished downloading network: %s -> %s", source, dest)
convert_model_network(conversion, model)
except Exception:
logger.exception("error fetching network %s", name)
model_errors.append(name)
@ -351,142 +439,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
pipeline = model.get("pipeline", "txt2img")
if pipeline.endswith("-sdxl"):
converted, dest = convert_diffusion_diffusers_xl(
conversion,
model,
source,
model_format,
hf=hf,
)
else:
converted, dest = convert_diffusion_diffusers(
conversion,
model,
source,
model_format,
hf=hf,
)
# make sure blending only happens once, not every run
if converted:
# keep track of which models have been blended
blend_models = {}
inversion_dest = path.join(conversion.model_path, "inversion")
lora_dest = path.join(conversion.model_path, "lora")
for inversion in model.get("inversions", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "tokenizer" not in blend_models:
blend_models[
"tokenizer"
] = CLIPTokenizer.from_pretrained(
dest,
subfolder="tokenizer",
)
inversion_name = inversion["name"]
inversion_source = inversion["source"]
inversion_format = inversion.get("format", None)
inversion_source, hf = fetch_model(
conversion,
inversion_name,
inversion_source,
dest=inversion_dest,
)
inversion_token = inversion.get("token", inversion_name)
inversion_weight = inversion.get("weight", 1.0)
blend_textual_inversions(
conversion,
blend_models["text_encoder"],
blend_models["tokenizer"],
[
(
inversion_source,
inversion_weight,
inversion_token,
inversion_format,
)
],
)
for lora in model.get("loras", []):
if "text_encoder" not in blend_models:
blend_models["text_encoder"] = load_model(
path.join(
dest,
"text_encoder",
ONNX_MODEL,
)
)
if "unet" not in blend_models:
blend_models["unet"] = load_model(
path.join(dest, "unet", ONNX_MODEL)
)
# load models if not loaded yet
lora_name = lora["name"]
lora_source = lora["source"]
lora_source, hf = fetch_model(
conversion,
f"{name}-lora-{lora_name}",
lora_source,
dest=lora_dest,
)
lora_weight = lora.get("weight", 1.0)
blend_loras(
conversion,
blend_models["text_encoder"],
[(lora_source, lora_weight)],
"text_encoder",
)
blend_loras(
conversion,
blend_models["unet"],
[(lora_source, lora_weight)],
"unet",
)
if "tokenizer" in blend_models:
dest_path = path.join(dest, "tokenizer")
logger.debug("saving blended tokenizer to %s", dest_path)
blend_models["tokenizer"].save_pretrained(dest_path)
for name in ["text_encoder", "unet"]:
if name in blend_models:
dest_path = path.join(dest, name, ONNX_MODEL)
logger.debug(
"saving blended %s model to %s", name, dest_path
)
save_model(
blend_models[name],
dest_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
convert_model_diffusion(conversion, model)
except Exception:
logger.exception(
"error converting diffusion model %s",
@ -502,24 +456,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
model_type = model.get("model", "resrgan")
if model_type == "bsrgan":
convert_upscaling_bsrgan(conversion, model, source)
elif model_type == "resrgan":
convert_upscale_resrgan(conversion, model, source)
elif model_type == "swinir":
convert_upscaling_swinir(conversion, model, source)
else:
logger.error(
"unknown upscaling model type %s for %s", model_type, name
)
model_errors.append(name)
convert_model_upscaling(conversion, model)
except Exception:
logger.exception(
"error converting upscaling model %s",
@ -535,19 +473,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
if name in args.skip:
logger.info("skipping model: %s", name)
else:
model_format = source_format(model)
try:
source, hf = fetch_model(
conversion, name, model["source"], format=model_format
)
model_type = model.get("model", "gfpgan")
if model_type == "gfpgan":
convert_correction_gfpgan(conversion, model, source)
else:
logger.error(
"unknown correction model type %s for %s", model_type, name
)
model_errors.append(name)
convert_model_correction(conversion, model)
except Exception:
logger.exception(
"error converting correction model %s",
@ -559,12 +486,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
logger.error("error while converting models: %s", model_errors)
def register_plugins(conversion: ConversionContext):
logger.info("loading conversion plugins")
exports = load_plugins(conversion)
for proto, client in exports.clients:
try:
add_model_source(proto, client)
except Exception:
logger.exception("error loading client for protocol: %s", proto)
# TODO: add converters
def main(args=None) -> int:
parser = ArgumentParser(
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
)
# model groups
parser.add_argument("--base", action="store_true", default=True)
parser.add_argument("--networks", action="store_true", default=True)
parser.add_argument("--sources", action="store_true", default=True)
parser.add_argument("--correction", action="store_true", default=False)
@ -602,16 +543,20 @@ def main(args=None) -> int:
server.half = args.half or server.has_optimization("onnx-fp16")
server.opset = args.opset
server.token = args.token
register_plugins(server)
logger.info(
"converting models in %s using %s", server.model_path, server.training_device
"converting models into %s using %s", server.model_path, server.training_device
)
if not path.exists(server.model_path):
logger.info("model path does not existing, creating: %s", server.model_path)
makedirs(server.model_path)
logger.info("converting base models")
convert_models(server, args, base_models)
if args.base:
logger.info("converting base models")
convert_models(server, args, base_models)
extras = []
extras.extend(server.extra_models)

View File

@ -0,0 +1,54 @@
from typing import Callable, Dict, Optional
from logging import getLogger
from os import path
from ..utils import ConversionContext
from .base import BaseClient
from .civitai import CivitaiClient
from .file import FileClient
from .http import HttpClient
from .huggingface import HuggingfaceClient
logger = getLogger(__name__)
model_sources: Dict[str, Callable[[], BaseClient]] = {
CivitaiClient.protocol: CivitaiClient,
FileClient.protocol: FileClient,
HttpClient.insecure_protocol: HttpClient,
HttpClient.protocol: HttpClient,
HuggingfaceClient.protocol: HuggingfaceClient,
}
def add_model_source(proto: str, client: BaseClient):
global model_sources
if proto in model_sources:
raise ValueError("protocol has already been taken")
model_sources[proto] = client
def fetch_model(
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
**kwargs,
) -> str:
# TODO: switch to urlparse's default scheme
if source.startswith(path.sep) or source.startswith("."):
logger.info("adding file protocol to local path source: %s", source)
source = FileClient.protocol + source
for proto, client_type in model_sources.items():
if source.startswith(proto):
client = client_type(conversion)
return client.download(
conversion, name, source, format=format, dest=dest, **kwargs
)
logger.warning("unknown model protocol, using path as provided: %s", source)
return source

View File

@ -0,0 +1,16 @@
from typing import Optional
from ..utils import ConversionContext
class BaseClient:
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
**kwargs,
) -> str:
raise NotImplementedError()

View File

@ -0,0 +1,64 @@
from logging import getLogger
from typing import Optional
from ..utils import (
ConversionContext,
build_cache_paths,
download_progress,
get_first_exists,
remove_prefix,
)
from .base import BaseClient
logger = getLogger(__name__)
CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
class CivitaiClient(BaseClient):
name = "civitai"
protocol = "civitai://"
root: str
token: Optional[str]
def __init__(
self,
conversion: ConversionContext,
token: Optional[str] = None,
root: str = CIVITAI_ROOT,
):
self.root = conversion.get_setting("CIVITAI_ROOT", root)
self.token = conversion.get_setting("CIVITAI_TOKEN", token)
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
**kwargs,
) -> str:
cache_paths = build_cache_paths(
conversion,
name,
client=CivitaiClient.name,
format=format,
dest=dest,
)
cached = get_first_exists(cache_paths)
if cached:
return cached
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
if self.token:
logger.debug("adding Civitai token authentication")
if "?" in source:
source = f"{source}&token={self.token}"
else:
source = f"{source}?token={self.token}"
return download_progress(source, cache_paths[0])

View File

@ -0,0 +1,32 @@
from logging import getLogger
from os import path
from typing import Optional
from urllib.parse import urlparse
from ..utils import ConversionContext
from .base import BaseClient
logger = getLogger(__name__)
class FileClient(BaseClient):
protocol = "file://"
def __init__(self, _conversion: ConversionContext):
"""
Nothing to initialize for this client.
"""
pass
def download(
self,
conversion: ConversionContext,
_name: str,
uri: str,
format: Optional[str] = None,
dest: Optional[str] = None,
**kwargs,
) -> str:
parts = urlparse(uri)
logger.info("loading model from: %s", parts.path)
return path.join(dest or conversion.model_path, parts.path)

View File

@ -0,0 +1,52 @@
from logging import getLogger
from typing import Dict, Optional
from ..utils import (
ConversionContext,
build_cache_paths,
download_progress,
get_first_exists,
)
from .base import BaseClient
logger = getLogger(__name__)
class HttpClient(BaseClient):
name = "http"
protocol = "https://"
insecure_protocol = "http://"
headers: Dict[str, str]
def __init__(
self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None
):
self.headers = headers or {}
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
**kwargs,
) -> str:
cache_paths = build_cache_paths(
conversion,
name,
client=HttpClient.name,
format=format,
dest=dest,
)
cached = get_first_exists(cache_paths)
if cached:
return cached
if source.startswith(HttpClient.protocol):
logger.info("downloading model from: %s", source)
elif source.startswith(HttpClient.insecure_protocol):
logger.warning("downloading model from insecure source: %s", source)
return download_progress(source, cache_paths[0])

View File

@ -0,0 +1,43 @@
from logging import getLogger
from typing import Optional
from huggingface_hub.file_download import hf_hub_download
from ..utils import ConversionContext, remove_prefix
from .base import BaseClient
logger = getLogger(__name__)
class HuggingfaceClient(BaseClient):
name = "huggingface"
protocol = "huggingface://"
token: Optional[str]
def __init__(self, conversion: ConversionContext, token: Optional[str] = None):
self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token)
def download(
self,
conversion: ConversionContext,
name: str,
source: str,
format: Optional[str] = None,
dest: Optional[str] = None,
embeds: bool = False,
**kwargs,
) -> str:
source = remove_prefix(source, HuggingfaceClient.protocol)
logger.info("downloading model from Huggingface Hub: %s", source)
if embeds:
return hf_hub_download(
repo_id=source,
filename="learned_embeds.bin",
cache_dir=dest or conversion.cache_path,
force_filename=f"{name}.bin",
token=self.token,
)
else:
return source

View File

@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
from ...diffusers.version_safe_diffusers import AttnProcessor
from ...models.cnet import UNet2DConditionModel_CNet
from ...utils import run_gc
from ..client import fetch_model
from ..client.huggingface import HuggingfaceClient
from ..utils import (
RESOLVE_FORMATS,
ConversionContext,
@ -43,6 +45,7 @@ from ..utils import (
is_torch_2_0,
load_tensor,
onnx_export,
remove_prefix,
)
from .checkpoint import convert_extract_checkpoint
@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
def convert_diffusion_diffusers(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
name = str(model.get("name")).strip()
source = model.get("source")
# optional
config = model.get("config", None)
@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
logger.info("ONNX model already exists, skipping")
return (False, dest_path)
cache_path = fetch_model(conversion, name, source, format=format)
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
cache_path, conversion.map_location, size=image_size, version=version
)
is_inpainting = False
@ -334,57 +338,66 @@ def convert_diffusion_diffusers(
pipe_args["from_safetensors"] = True
torch_source = None
if path.exists(source) and path.isdir(source):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
source,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
elif path.exists(source) and path.isfile(source):
if conversion.extract:
logger.debug("extracting SD checkpoint to Torch models: %s", source)
torch_source = convert_extract_checkpoint(
conversion,
source,
f"{name}-torch",
is_inpainting=is_inpainting,
config_file=config,
vae_file=replace_vae,
)
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
if path.exists(cache_path):
if path.isdir(cache_path):
logger.debug("loading pipeline from diffusers directory: %s", source)
pipeline = pipe_class.from_pretrained(
torch_source,
cache_path,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
# VAE replacement already happened during extraction, skip
replace_vae = None
else:
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = download_from_original_stable_diffusion_ckpt(
source,
original_config_file=config_path,
pipeline_class=pipe_class,
**pipe_args,
).to(device, torch_dtype=dtype)
elif hf:
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
if conversion.extract:
logger.debug("extracting SD checkpoint to Torch models: %s", source)
torch_source = convert_extract_checkpoint(
conversion,
source,
f"{name}-torch",
is_inpainting=is_inpainting,
config_file=config,
vae_file=replace_vae,
)
logger.debug(
"loading pipeline from extracted checkpoint: %s", torch_source
)
pipeline = pipe_class.from_pretrained(
torch_source,
torch_dtype=dtype,
).to(device)
# VAE replacement already happened during extraction, skip
replace_vae = None
else:
logger.debug("loading pipeline from SD checkpoint: %s", source)
pipeline = download_from_original_stable_diffusion_ckpt(
cache_path,
original_config_file=config_path,
pipeline_class=pipe_class,
**pipe_args,
).to(device, torch_dtype=dtype)
elif source.startswith(HuggingfaceClient.protocol):
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
pipeline = pipe_class.from_pretrained(
source,
hf_path,
torch_dtype=dtype,
use_auth_token=conversion.token,
).to(device)
else:
logger.warning("pipeline source not found or not recognized: %s", source)
raise ValueError(f"pipeline source not found or not recognized: {source}")
logger.warning(
"pipeline source not found and protocol not recognized: %s", source
)
raise ValueError(
f"pipeline source not found and protocol not recognized: {source}"
)
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
if vae_file[0]:
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())

View File

@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export
from ...constants import ONNX_MODEL
from ..client import fetch_model
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
logger = getLogger(__name__)
@ -19,14 +20,13 @@ logger = getLogger(__name__)
def convert_diffusion_diffusers_xl(
conversion: ConversionContext,
model: Dict,
source: str,
format: Optional[str],
hf: bool = False,
) -> Tuple[bool, str]:
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
"""
name = model.get("name")
name = str(model.get("name")).strip()
source = model.get("source")
replace_vae = model.get("vae", None)
device = conversion.training_device
@ -52,24 +52,26 @@ def convert_diffusion_diffusers_xl(
return (False, dest_path)
cache_path = fetch_model(conversion, name, model["source"], format=format)
# safetensors -> diffusers directory with torch models
temp_path = path.join(conversion.cache_path, f"{name}-torch")
if format == "safetensors":
pipeline = StableDiffusionXLPipeline.from_single_file(
source, use_safetensors=True
cache_path, use_safetensors=True
)
else:
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if check_ext(replace_vae, RESOLVE_FORMATS):
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
if vae_file[0]:
logger.debug("loading VAE from single tensor file: %s", vae_path)
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
logger.debug("loading pretrained VAE from path: %s", vae_path)
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
logger.debug("loading pretrained VAE from path: %s", replace_vae)
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
if path.exists(temp_path):
logger.debug("torch model already exists for %s: %s", source, temp_path)

View File

@ -31,6 +31,15 @@ ModelDict = Dict[str, Union[str, int]]
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
DEFAULT_OPSET = 14
DIFFUSION_PREFIX = [
"diffusion-",
"diffusion/",
"diffusion\\",
"stable-diffusion-",
"upscaling-", # SD upscaling
]
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
class ConversionContext(ServerContext):
@ -85,38 +94,37 @@ class ConversionContext(ServerContext):
return torch.device(self.training_device)
def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
dest_path = Path(dest).expanduser().resolve()
dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists():
logger.debug("destination already exists: %s", dest_path)
return str(dest_path.absolute())
req = requests.get(
url,
stream=True,
allow_redirects=True,
headers={
"User-Agent": "onnx-web-api",
},
)
if req.status_code != 200:
req.raise_for_status() # Only works for 4xx errors, per SO answer
raise RequestException(
"request to %s failed with status code: %s" % (url, req.status_code)
)
total = int(req.headers.get("Content-Length", 0))
desc = "unknown" if total == 0 else ""
req.raw.read = partial(req.raw.read, decode_content=True)
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
with dest_path.open("wb") as f:
shutil.copyfileobj(data, f)
def download_progress(source: str, dest: str):
dest_path = Path(dest).expanduser().resolve()
dest_path.parent.mkdir(parents=True, exist_ok=True)
if dest_path.exists():
logger.debug("destination already exists: %s", dest_path)
return str(dest_path.absolute())
req = requests.get(
source,
stream=True,
allow_redirects=True,
headers={
"User-Agent": "onnx-web-api",
},
)
if req.status_code != 200:
req.raise_for_status() # Only works for 4xx errors, per SO answer
raise RequestException(
"request to %s failed with status code: %s" % (source, req.status_code)
)
total = int(req.headers.get("Content-Length", 0))
desc = "unknown" if total == 0 else ""
req.raw.read = partial(req.raw.read, decode_content=True)
with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data:
with dest_path.open("wb") as f:
shutil.copyfileobj(data, f)
return str(dest_path.absolute())
def tuple_to_source(model: Union[ModelDict, LegacyModel]):
if isinstance(model, list) or isinstance(model, tuple):
@ -184,10 +192,6 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
return model
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
_name, ext = path.splitext(name)
ext = ext.strip(".")
@ -215,6 +219,9 @@ def remove_prefix(name: str, prefix: str) -> str:
def load_torch(name: str, map_location=None) -> Optional[Dict]:
"""
TODO: move out of convert
"""
try:
logger.debug("loading tensor with Torch: %s", name)
checkpoint = torch.load(name, map_location=map_location)
@ -228,6 +235,9 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
"""
TODO: move out of convert
"""
logger.debug("loading tensor: %s", name)
_, extension = path.splitext(name)
extension = extension[1:].lower()
@ -285,6 +295,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
def resolve_tensor(name: str) -> Optional[str]:
"""
TODO: move out of convert
"""
logger.debug("searching for tensors with known extensions: %s", name)
for next_extension in RESOLVE_FORMATS:
next_name = f"{name}.{next_extension}"
@ -345,3 +358,52 @@ def onnx_export(
all_tensors_to_one_file=True,
location=ONNX_WEIGHTS,
)
def fix_diffusion_name(name: str):
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
logger.warning(
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
name,
)
return f"diffusion-{name}"
return name
def build_cache_paths(
conversion: ConversionContext,
name: str,
client: Optional[str] = None,
dest: Optional[str] = None,
format: Optional[str] = None,
) -> List[str]:
cache_path = dest or conversion.cache_path
# add an extension if possible, some of the conversion code checks for it
if format is not None:
basename = path.basename(name)
_filename, ext = path.splitext(basename)
if ext is None or ext == "":
name = f"{name}.{format}"
paths = [
path.join(cache_path, name),
]
if client is not None:
client_path = path.join(cache_path, client)
paths.append(path.join(client_path, name))
return paths
def get_first_exists(
paths: List[str],
) -> Optional[str]:
for name in paths:
if path.exists(name):
logger.debug("model already exists in cache, skipping fetch: %s", name)
return name
return None

View File

@ -30,12 +30,12 @@ from .version_safe_diffusers import (
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSDEScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
@ -71,6 +71,7 @@ pipeline_schedulers = {
"ddpm": DDPMScheduler,
"deis-multi": DEISMultistepScheduler,
"dpm-multi": DPMSolverMultistepScheduler,
"dpm-sde": DPMSolverSDEScheduler,
"dpm-single": DPMSolverSinglestepScheduler,
"euler": EulerDiscreteScheduler,
"euler-a": EulerAncestralDiscreteScheduler,
@ -78,7 +79,6 @@ pipeline_schedulers = {
"ipndm": IPNDMScheduler,
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
"k-dpm-2": KDPM2DiscreteScheduler,
"karras-ve": KarrasVeScheduler,
"lcm": LCMScheduler,
"lms-discrete": LMSDiscreteScheduler,
"pndm": PNDMScheduler,

View File

@ -12,6 +12,11 @@ try:
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
try:
from diffusers import DPMSolverSDEScheduler
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler
try:
from diffusers import LCMScheduler
except ImportError:

View File

@ -94,45 +94,44 @@ class ServerContext:
self.cache = ModelCache(self.cache_limit)
@classmethod
def from_environ(cls):
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
def from_environ(cls, env=environ):
memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None)
if memory_limit is not None:
memory_limit = int(memory_limit)
return cls(
bundle_path=environ.get(
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
),
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")),
model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
params_path=env.get("ONNX_WEB_PARAMS_PATH", "."),
cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"),
any_platform=get_boolean(
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
),
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"),
default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
show_progress=get_boolean(
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
),
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"),
extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"),
job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
memory_limit=memory_limit,
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
server_version=environ.get(
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
),
admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None),
server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION),
worker_retries=int(
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
),
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"),
plugins=get_list(env, "ONNX_WEB_PLUGINS", ""),
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
)
def get_setting(self, flag: str, default: str) -> Optional[str]:
return environ.get(f"ONNX_WEB_{flag}", default)
def has_feature(self, flag: str) -> bool:
return flag in self.feature_flags

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
from jsonschema import ValidationError, validate
from ..convert.utils import fix_diffusion_name
from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply,
mask_filter_gaussian_screen,
@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
for model in data[model_type]:
model_name = model["name"]
if model_type == "diffusion":
model_name = fix_diffusion_name(model_name)
if "hash" in model:
logger.debug(
"collecting hash for model %s from %s",

View File

@ -2,18 +2,24 @@ from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict
from onnx_web.chain.stages import add_stage
from onnx_web.diffusers.load import add_pipeline
from onnx_web.server.context import ServerContext
from ..chain.stages import add_stage
from ..diffusers.load import add_pipeline
from ..server.context import ServerContext
logger = getLogger(__name__)
class PluginExports:
clients: Dict[str, Any]
converter: Dict[str, Any]
pipelines: Dict[str, Any]
stages: Dict[str, Any]
def __init__(self, pipelines=None, stages=None) -> None:
def __init__(
self, clients=None, converter=None, pipelines=None, stages=None
) -> None:
self.clients = clients or {}
self.converter = converter or {}
self.pipelines = pipelines or {}
self.stages = stages or {}

View File

@ -14,7 +14,7 @@
},
"cfg": {
"default": 6,
"min": 1,
"min": 0,
"max": 30,
"step": 0.1
},

View File

@ -3,21 +3,21 @@ numpy==1.23.5
protobuf==3.20.3
### SD packages ###
accelerate==0.22.0
accelerate==0.25.0
coloredlogs==15.0.1
controlnet_aux==0.0.2
datasets==2.14.3
diffusers==0.20.0
datasets==2.15.0
diffusers==0.24.0
huggingface-hub==0.16.4
invisible-watermark==0.2.0
mediapipe==0.9.2.1
omegaconf==2.3.0
onnx==1.13.0
# onnxruntime has many platform-specific packages
optimum==1.12.0
safetensors==0.3.1
timm==0.6.13
transformers==4.32.0
optimum==1.16.0
safetensors==0.4.1
timm==0.9.12
transformers==4.36.1
#### Upscaling and face correction
basicsr==1.4.2
@ -29,10 +29,11 @@ realesrgan==0.3.0
### Server packages ###
arpeggio==2.0.0
boto3==1.26.69
flask==2.2.2
flask==3.0.0
flask-cors==3.0.10
jsonschema==4.17.3
piexif==1.1.3
pyyaml==6.0
setproctitle==1.3.2
waitress==2.1.2
werkzeug==3.0.1

View File

@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
class DownloadProgressTests(unittest.TestCase):
def test_download_example(self):
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
path = download_progress("https://example.com", "/tmp/example-dot-com")
self.assertEqual(path, "/tmp/example-dot-com")

View File

@ -414,6 +414,7 @@ class TestBlendPipeline(unittest.TestCase):
3.0,
1,
1,
unet_tile=64,
),
Size(64, 64),
["test-blend.png"],

198
docs/getting-started.md Normal file
View File

@ -0,0 +1,198 @@
# Getting Started With onnx-web
onnx-web is a tool for generating images with Stable Diffusion pipelines, including SDXL.
## Contents
- [Getting Started With onnx-web](#getting-started-with-onnx-web)
- [Contents](#contents)
- [Setup](#setup)
- [Windows bundle setup](#windows-bundle-setup)
- [Other setup methods](#other-setup-methods)
- [Running](#running)
- [Running the server](#running-the-server)
- [Running the web UI](#running-the-web-ui)
- [Tabs](#tabs)
- [Txt2img Tab](#txt2img-tab)
- [Img2img Tab](#img2img-tab)
- [Inpaint Tab](#inpaint-tab)
- [Upscale Tab](#upscale-tab)
- [Blend Tab](#blend-tab)
- [Models Tab](#models-tab)
- [Settings Tab](#settings-tab)
- [Image parameters](#image-parameters)
- [Common image parameters](#common-image-parameters)
- [Unique image parameters](#unique-image-parameters)
- [Prompt syntax](#prompt-syntax)
- [LoRAs and embeddings](#loras-and-embeddings)
- [CLIP skip](#clip-skip)
- [Highres](#highres)
- [Highres prompt](#highres-prompt)
- [Highres iterations](#highres-iterations)
- [Profiles](#profiles)
- [Loading from files](#loading-from-files)
- [Saving profiles in the web UI](#saving-profiles-in-the-web-ui)
- [Sharing parameters profiles](#sharing-parameters-profiles)
- [Panorama pipeline](#panorama-pipeline)
- [Region prompts](#region-prompts)
- [Region seeds](#region-seeds)
- [Grid mode](#grid-mode)
- [Grid tokens](#grid-tokens)
- [Memory optimizations](#memory-optimizations)
- [Converting to fp16](#converting-to-fp16)
- [Moving models to the CPU](#moving-models-to-the-cpu)
## Setup
### Windows bundle setup
1. Download
2. Extract
3. Security flags
4. Run
### Other setup methods
Link to the other methods.
## Running
### Running the server
Run server or bundle.
### Running the web UI
Open web UI.
Use it from your phone.
## Tabs
There are 5 tabs, which do different things.
### Txt2img Tab
Words go in, pictures come out.
### Img2img Tab
Pictures go in, better pictures come out.
ControlNet lives here.
### Inpaint Tab
Pictures go in, parts of the same picture come out.
### Upscale Tab
Just highres and super resolution.
### Blend Tab
Use the mask tool to combine two images.
### Models Tab
Add and manage models.
### Settings Tab
Manage web UI settings.
Reset buttons.
## Image parameters
### Common image parameters
- Scheduler
- Eta
- for DDIM
- CFG
- Steps
- Seed
- Batch size
- Prompt
- Negative prompt
- Width, height
### Unique image parameters
- UNet tile size
- UNet overlap
- Tiled VAE
- VAE tile size
- VAE overlap
See the complete user guide for details about the highres, upscale, and correction parameters.
## Prompt syntax
### LoRAs and embeddings
`<lora:filename:1.0>` and `<inversion:filename:1.0>`.
### CLIP skip
`<clip:skip:2>` for anime.
## Highres
### Highres prompt
`txt2img prompt || img2img prompt`
### Highres iterations
Highres will apply the upscaler and highres prompt (img2img pipeline) for each iteration.
The final size will be `scale ** iterations`.
## Profiles
Saved sets of parameters for later use.
### Loading from files
- load from images
- load from JSON
### Saving profiles in the web UI
Use the save button.
### Sharing parameters profiles
Use the download button.
Share profiles in the Discord channel.
## Panorama pipeline
### Region prompts
`<region:X:Y:W:H:S:F_TLBR:prompt+>`
### Region seeds
`<reseed:X:Y:W:H:?:F_TLBR:seed>`
## Grid mode
Makes many images. Takes many time.
### Grid tokens
`__column__` and `__row__` if you pick token in the menu.
## Memory optimizations
### Converting to fp16
Enable the option.
### Moving models to the CPU
Option for each model.

View File

@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { OnnxState, StateContext } from '../state.js';
import { OnnxState, StateContext } from '../state/full.js';
import { ErrorCard } from './card/ErrorCard.js';
import { ImageCard } from './card/ImageCard.js';
import { LoadingCard } from './card/LoadingCard.js';

View File

@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material';
import * as React from 'react';
import { ReactNode } from 'react';
import { STATE_KEY } from '../state.js';
import { STATE_KEY } from '../state/full.js';
import { Logo } from './Logo.js';
export interface OnnxErrorProps {

View File

@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { OnnxState, StateContext } from '../state.js';
import { OnnxState, StateContext } from '../state/full.js';
import { ImageHistory } from './ImageHistory.js';
import { Logo } from './Logo.js';
import { Blend } from './tab/Blend.js';

View File

@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { OnnxState, StateContext } from '../state.js';
import { OnnxState, StateContext } from '../state/full.js';
import { ImageMetadata } from '../types/api.js';
import { DeepPartial } from '../types/model.js';
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';

View File

@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
export interface ErrorCardProps {

View File

@ -8,9 +8,10 @@ import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js';
import { range, visibleIndex } from '../../utils.js';
import { BLEND_SOURCES } from '../../constants.js';
export interface ImageCardProps {
image: ImageResponse;

View File

@ -9,7 +9,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { POLL_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js';
const LOADING_PERCENT = 100;

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { HighresParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js';

View File

@ -11,7 +11,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { BaseImgParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js';
import { PromptInput } from '../input/PromptInput.js';

View File

@ -6,7 +6,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { STALE_TIME } from '../../config.js';
import { ClientContext } from '../../state.js';
import { ClientContext } from '../../state/full.js';
import { ModelParams } from '../../types/params.js';
import { QueryList } from '../input/QueryList.js';

View File

@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { NumericField } from '../input/NumericField.js';
export function OutpaintControl() {

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { UpscaleParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js';

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useStore } from 'zustand';
import { PipelineGrid } from '../../client/utils.js';
import { OnnxState, StateContext } from '../../state.js';
import { OnnxState, StateContext } from '../../state/full.js';
import { VARIABLE_PARAMETERS } from '../../types/chain.js';
export interface VariableControlProps {

View File

@ -4,7 +4,7 @@ import * as React from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { OnnxState, StateContext } from '../../state.js';
import { OnnxState, StateContext } from '../../state/full.js';
const { useContext, useState, memo, useMemo } = React;

View File

@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next';
import { SAVE_TIME } from '../../config.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js';
import { BrushParams } from '../../types/params.js';
import { imageFromBlob } from '../../utils.js';
import { NumericField } from './NumericField';

View File

@ -8,7 +8,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js';

View File

@ -8,7 +8,9 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER } from '../../config.js';
import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
import { BLEND_SOURCES } from '../../constants.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { range } from '../../utils.js';
import { UpscaleControl } from '../control/UpscaleControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import {
CorrectionModel,
DiffusionModel,

View File

@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { getApiRoot } from '../../config.js';
import { ConfigContext, StateContext, STATE_KEY } from '../../state.js';
import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js';
import { getTheme } from '../utils.js';
import { NumericField } from '../input/NumericField.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js';

View File

@ -1,6 +1,6 @@
import { PaletteMode } from '@mui/material';
import { Theme } from '../state.js';
import { Theme } from '../state/types.js';
import { trimHash } from '../utils.js';
export const TAB_LABELS = [

View File

@ -18,7 +18,7 @@
},
"cfg": {
"default": 6,
"min": 1,
"min": 0,
"max": 30,
"step": 0.1
},

31
gui/src/constants.ts Normal file
View File

@ -0,0 +1,31 @@
export const BLEND_SOURCES = 2;
/**
* Default parameters for the inpaint brush.
*
* Not provided by the server yet.
*/
export const DEFAULT_BRUSH = {
color: 255,
size: 8,
strength: 0.5,
};
/**
* Default parameters for the image history.
*
* Not provided by the server yet.
*/
export const DEFAULT_HISTORY = {
/**
* The number of images to be shown.
*/
limit: 4,
/**
* The number of additional images to be kept in history, so they can scroll
* back into view when you delete one. Does not include deleted images.
*/
scrollback: 2,
};

View File

@ -28,8 +28,9 @@ import {
STATE_KEY,
STATE_VERSION,
StateContext,
} from './state.js';
} from './state/full.js';
import { I18N_STRINGS } from './strings/all.js';
import { applyStateMigrations, UnknownState } from './state/migration/default.js';
export const INITIAL_LOAD_TIMEOUT = 5_000;
@ -70,6 +71,9 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
...createResetSlice(...slice),
...createProfileSlice(...slice),
}), {
migrate(persistedState, version) {
return applyStateMigrations(params, persistedState as UnknownState, version, logger);
},
name: STATE_KEY,
partialize(s) {
return {

View File

@ -1,965 +0,0 @@
/* eslint-disable camelcase */
/* eslint-disable max-lines */
/* eslint-disable no-null/no-null */
import { Maybe } from '@apextoaster/js-utils';
import { PaletteMode } from '@mui/material';
import { Logger } from 'noicejs';
import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand';
import {
ApiClient,
} from './client/base.js';
import { PipelineGrid } from './client/utils.js';
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types/model.js';
import { ImageResponse, ReadyResponse, RetryParams } from './types/api.js';
import {
BaseImgParams,
BlendParams,
BrushParams,
HighresParams,
Img2ImgParams,
InpaintParams,
ModelParams,
OutpaintPixels,
Txt2ImgParams,
UpscaleParams,
UpscaleReqParams,
} from './types/params.js';
export const MISSING_INDEX = -1;
export type Theme = PaletteMode | ''; // tri-state, '' is unset
/**
* Combine optional files and required ranges.
*/
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
export interface HistoryItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
retry: Maybe<RetryParams>;
}
export interface ProfileItem {
name: string;
params: BaseImgParams | Txt2ImgParams;
highres?: Maybe<HighresParams>;
upscale?: Maybe<UpscaleParams>;
}
interface DefaultSlice {
defaults: TabState<BaseImgParams>;
theme: Theme;
setDefaults(param: Partial<BaseImgParams>): void;
setTheme(theme: Theme): void;
}
interface HistorySlice {
history: Array<HistoryItem>;
limit: number;
pushHistory(image: ImageResponse, retry?: RetryParams): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void;
}
interface ModelSlice {
extras: ExtrasFile;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}
// #region tab slices
interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>;
txt2imgModel: ModelParams;
txt2imgHighres: HighresParams;
txt2imgUpscale: UpscaleParams;
txt2imgVariable: PipelineGrid;
resetTxt2Img(): void;
setTxt2Img(params: Partial<Txt2ImgParams>): void;
setTxt2ImgModel(params: Partial<ModelParams>): void;
setTxt2ImgHighres(params: Partial<HighresParams>): void;
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
}
interface Img2ImgSlice {
img2img: TabState<Img2ImgParams>;
img2imgModel: ModelParams;
img2imgHighres: HighresParams;
img2imgUpscale: UpscaleParams;
resetImg2Img(): void;
setImg2Img(params: Partial<Img2ImgParams>): void;
setImg2ImgModel(params: Partial<ModelParams>): void;
setImg2ImgHighres(params: Partial<HighresParams>): void;
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
}
interface InpaintSlice {
inpaint: TabState<InpaintParams>;
inpaintBrush: BrushParams;
inpaintModel: ModelParams;
inpaintHighres: HighresParams;
inpaintUpscale: UpscaleParams;
outpaint: OutpaintPixels;
resetInpaint(): void;
setInpaint(params: Partial<InpaintParams>): void;
setInpaintBrush(brush: Partial<BrushParams>): void;
setInpaintModel(params: Partial<ModelParams>): void;
setInpaintHighres(params: Partial<HighresParams>): void;
setInpaintUpscale(params: Partial<UpscaleParams>): void;
setOutpaint(pixels: Partial<OutpaintPixels>): void;
}
interface UpscaleSlice {
upscale: TabState<UpscaleReqParams>;
upscaleHighres: HighresParams;
upscaleModel: ModelParams;
upscaleUpscale: UpscaleParams;
resetUpscale(): void;
setUpscale(params: Partial<UpscaleReqParams>): void;
setUpscaleHighres(params: Partial<HighresParams>): void;
setUpscaleModel(params: Partial<ModelParams>): void;
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
}
interface BlendSlice {
blend: TabState<BlendParams>;
blendBrush: BrushParams;
blendModel: ModelParams;
blendUpscale: UpscaleParams;
resetBlend(): void;
setBlend(blend: Partial<BlendParams>): void;
setBlendBrush(brush: Partial<BrushParams>): void;
setBlendModel(model: Partial<ModelParams>): void;
setBlendUpscale(params: Partial<UpscaleParams>): void;
}
interface ResetSlice {
resetAll(): void;
}
interface ProfileSlice {
profiles: Array<ProfileItem>;
removeProfile(profileName: string): void;
saveProfile(profile: ProfileItem): void;
}
// #endregion
/**
* Full merged state including all slices.
*/
export type OnnxState
= DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& Txt2ImgSlice
& UpscaleSlice
& BlendSlice
& ResetSlice
& ModelSlice
& ProfileSlice;
/**
* Shorthand for state creator to reduce repeated arguments.
*/
export type Slice<T> = StateCreator<OnnxState, [], [], T>;
/**
* React context binding for API client.
*/
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
/**
* React context binding for merged config, including server parameters.
*/
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
/**
* React context binding for bunyan logger.
*/
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
/**
* React context binding for zustand state store.
*/
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
/**
* Key for zustand persistence, typically local storage.
*/
export const STATE_KEY = 'onnx-web';
/**
* Current state version for zustand persistence.
*/
export const STATE_VERSION = 7;
export const BLEND_SOURCES = 2;
/**
* Default parameters for the inpaint brush.
*
* Not provided by the server yet.
*/
export const DEFAULT_BRUSH = {
color: 255,
size: 8,
strength: 0.5,
};
/**
* Default parameters for the image history.
*
* Not provided by the server yet.
*/
export const DEFAULT_HISTORY = {
/**
* The number of images to be shown.
*/
limit: 4,
/**
* The number of additional images to be kept in history, so they can scroll
* back into view when you delete one. Does not include deleted images.
*/
scrollback: 2,
};
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
seed: defaults.seed.default,
tiled_vae: defaults.tiled_vae.default,
unet_overlap: defaults.unet_overlap.default,
unet_tile: defaults.unet_tile.default,
vae_overlap: defaults.vae_overlap.default,
vae_tile: defaults.vae_tile.default,
};
}
/**
* Prepare the state slice constructors.
*
* In the default state, image sources should be null and booleans should be false. Everything
* else should be initialized from the default value in the base parameters.
*/
export function createStateSlices(server: ServerParams) {
const defaultParams = baseParamsFromServer(server);
const defaultHighres: HighresParams = {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
};
const defaultModel: ModelParams = {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
};
const defaultUpscale: UpscaleParams = {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const defaultGrid: PipelineGrid = {
enabled: false,
columns: {
parameter: 'seed',
value: '',
},
rows: {
parameter: 'seed',
value: '',
},
};
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
txt2imgHighres: {
...defaultHighres,
},
txt2imgModel: {
...defaultModel,
},
txt2imgUpscale: {
...defaultUpscale,
},
txt2imgVariable: {
...defaultGrid,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
}));
},
setTxt2ImgHighres(params) {
set((prev) => ({
txt2imgHighres: {
...prev.txt2imgHighres,
...params,
},
}));
},
setTxt2ImgModel(params) {
set((prev) => ({
txt2imgModel: {
...prev.txt2imgModel,
...params,
},
}));
},
setTxt2ImgUpscale(params) {
set((prev) => ({
txt2imgUpscale: {
...prev.txt2imgUpscale,
...params,
},
}));
},
setTxt2ImgVariable(params) {
set((prev) => ({
txt2imgVariable: {
...prev.txt2imgVariable,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
});
},
});
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
img2img: {
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
img2imgHighres: {
...defaultHighres,
},
img2imgModel: {
...defaultModel,
},
img2imgUpscale: {
...defaultUpscale,
},
resetImg2Img() {
set({
img2img: {
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
});
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
}));
},
setImg2ImgHighres(params) {
set((prev) => ({
img2imgHighres: {
...prev.img2imgHighres,
...params,
},
}));
},
setImg2ImgModel(params) {
set((prev) => ({
img2imgModel: {
...prev.img2imgModel,
...params,
},
}));
},
setImg2ImgUpscale(params) {
set((prev) => ({
img2imgUpscale: {
...prev.img2imgUpscale,
...params,
},
}));
},
});
const createInpaintSlice: Slice<InpaintSlice> = (set) => ({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
inpaintBrush: {
...DEFAULT_BRUSH,
},
inpaintHighres: {
...defaultHighres,
},
inpaintModel: {
...defaultModel,
},
inpaintUpscale: {
...defaultUpscale,
},
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
resetInpaint() {
set({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
});
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
}));
},
setInpaintBrush(brush) {
set((prev) => ({
inpaintBrush: {
...prev.inpaintBrush,
...brush,
},
}));
},
setInpaintHighres(params) {
set((prev) => ({
inpaintHighres: {
...prev.inpaintHighres,
...params,
},
}));
},
setInpaintModel(params) {
set((prev) => ({
inpaintModel: {
...prev.inpaintModel,
...params,
},
}));
},
setInpaintUpscale(params) {
set((prev) => ({
inpaintUpscale: {
...prev.inpaintUpscale,
...params,
},
}));
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
}));
},
});
const createHistorySlice: Slice<HistorySlice> = (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
pushHistory(image, retry) {
set((prev) => ({
...prev,
history: [
{
image,
ready: undefined,
retry,
},
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
}));
},
removeHistory(image) {
set((prev) => ({
...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
}));
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setReady(image, ready) {
set((prev) => {
const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
if (idx >= 0) {
history[idx].ready = ready;
} else {
// TODO: error
}
return {
...prev,
history,
};
});
},
});
const createUpscaleSlice: Slice<UpscaleSlice> = (set) => ({
upscale: {
...defaultParams,
source: null,
},
upscaleHighres: {
...defaultHighres,
},
upscaleModel: {
...defaultModel,
},
upscaleUpscale: {
...defaultUpscale,
},
resetUpscale() {
set({
upscale: {
...defaultParams,
source: null,
},
});
},
setUpscale(source) {
set((prev) => ({
upscale: {
...prev.upscale,
...source,
},
}));
},
setUpscaleHighres(params) {
set((prev) => ({
upscaleHighres: {
...prev.upscaleHighres,
...params,
},
}));
},
setUpscaleModel(params) {
set((prev) => ({
upscaleModel: {
...prev.upscaleModel,
...defaultModel,
},
}));
},
setUpscaleUpscale(params) {
set((prev) => ({
upscaleUpscale: {
...prev.upscaleUpscale,
...params,
},
}));
},
});
const createBlendSlice: Slice<BlendSlice> = (set) => ({
blend: {
mask: null,
sources: [],
},
blendBrush: {
...DEFAULT_BRUSH,
},
blendModel: {
...defaultModel,
},
blendUpscale: {
...defaultUpscale,
},
resetBlend() {
set({
blend: {
mask: null,
sources: [],
},
});
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
}));
},
setBlendBrush(brush) {
set((prev) => ({
blendBrush: {
...prev.blendBrush,
...brush,
},
}));
},
setBlendModel(model) {
set((prev) => ({
blendModel: {
...prev.blendModel,
...model,
},
}));
},
setBlendUpscale(params) {
set((prev) => ({
blendUpscale: {
...prev.blendUpscale,
...params,
},
}));
},
});
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
defaults: {
...defaultParams,
},
theme: '',
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
}));
},
setTheme(theme) {
set((prev) => ({
theme,
}));
}
});
const createResetSlice: Slice<ResetSlice> = (set) => ({
resetAll() {
set((prev) => {
const next = { ...prev };
next.resetImg2Img();
next.resetInpaint();
next.resetTxt2Img();
next.resetUpscale();
next.resetBlend();
return next;
});
},
});
const createProfileSlice: Slice<ProfileSlice> = (set) => ({
profiles: [],
saveProfile(profile: ProfileItem) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profile.name);
if (idx >= 0) {
profiles[idx] = profile;
} else {
profiles.push(profile);
}
return {
...prev,
profiles,
};
});
},
removeProfile(profileName: string) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profileName);
if (idx >= 0) {
profiles.splice(idx, 1);
}
return {
...prev,
profiles,
};
});
}
});
// eslint-disable-next-line sonarjs/cognitive-complexity
const createModelSlice: Slice<ModelSlice> = (set) => ({
extras: {
correction: [],
diffusion: [],
networks: [],
sources: [],
upscaling: [],
},
setExtras(extras) {
set((prev) => ({
...prev,
extras: {
...prev.extras,
...extras,
},
}));
},
setCorrectionModel(model) {
set((prev) => {
const correction = [...prev.extras.correction];
const exists = correction.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
correction.push(model);
} else {
correction[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
setDiffusionModel(model) {
set((prev) => {
const diffusion = [...prev.extras.diffusion];
const exists = diffusion.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
diffusion.push(model);
} else {
diffusion[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
setExtraNetwork(model) {
set((prev) => {
const networks = [...prev.extras.networks];
const exists = networks.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
networks.push(model);
} else {
networks[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
setExtraSource(model) {
set((prev) => {
const sources = [...prev.extras.sources];
const exists = sources.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
sources.push(model);
} else {
sources[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
setUpscalingModel(model) {
set((prev) => {
const upscaling = [...prev.extras.upscaling];
const exists = upscaling.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
upscaling.push(model);
} else {
upscaling[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
removeCorrectionModel(model) {
set((prev) => {
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
removeDiffusionModel(model) {
set((prev) => {
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
removeExtraNetwork(model) {
set((prev) => {
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
removeExtraSource(model) {
set((prev) => {
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
removeUpscalingModel(model) {
set((prev) => {
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
return {
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createBlendSlice,
createResetSlice,
createModelSlice,
createProfileSlice,
};
}

85
gui/src/state/blend.ts Normal file
View File

@ -0,0 +1,85 @@
import { DEFAULT_BRUSH } from '../constants.js';
import {
BlendParams,
BrushParams,
ModelParams,
UpscaleParams,
} from '../types/params.js';
import { Slice, TabState } from './types.js';
export interface BlendSlice {
blend: TabState<BlendParams>;
blendBrush: BrushParams;
blendModel: ModelParams;
blendUpscale: UpscaleParams;
resetBlend(): void;
setBlend(blend: Partial<BlendParams>): void;
setBlendBrush(brush: Partial<BrushParams>): void;
setBlendModel(model: Partial<ModelParams>): void;
setBlendUpscale(params: Partial<UpscaleParams>): void;
}
export function createBlendSlice<TState extends BlendSlice>(
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, BlendSlice> {
return (set) => ({
blend: {
// eslint-disable-next-line no-null/no-null
mask: null,
sources: [],
},
blendBrush: {
...DEFAULT_BRUSH,
},
blendModel: {
...defaultModel,
},
blendUpscale: {
...defaultUpscale,
},
resetBlend() {
set((prev) => ({
blend: {
// eslint-disable-next-line no-null/no-null
mask: null,
sources: [] as Array<Blob>,
},
} as Partial<TState>));
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
} as Partial<TState>));
},
setBlendBrush(brush) {
set((prev) => ({
blendBrush: {
...prev.blendBrush,
...brush,
},
} as Partial<TState>));
},
setBlendModel(model) {
set((prev) => ({
blendModel: {
...prev.blendModel,
...model,
},
} as Partial<TState>));
},
setBlendUpscale(params) {
set((prev) => ({
blendUpscale: {
...prev.blendUpscale,
...params,
},
} as Partial<TState>));
},
});
}

34
gui/src/state/default.ts Normal file
View File

@ -0,0 +1,34 @@
import {
BaseImgParams,
} from '../types/params.js';
import { Slice, TabState, Theme } from './types.js';
export interface DefaultSlice {
defaults: TabState<BaseImgParams>;
theme: Theme;
setDefaults(param: Partial<BaseImgParams>): void;
setTheme(theme: Theme): void;
}
export function createDefaultSlice<TState extends DefaultSlice>(defaultParams: Required<BaseImgParams>): Slice<TState, DefaultSlice> {
return (set) => ({
defaults: {
...defaultParams,
},
theme: '',
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
} as Partial<TState>));
},
setTheme(theme) {
set((prev) => ({
theme,
} as Partial<TState>));
}
});
}

150
gui/src/state/full.ts Normal file
View File

@ -0,0 +1,150 @@
/* eslint-disable camelcase */
import { Maybe } from '@apextoaster/js-utils';
import { Logger } from 'noicejs';
import { createContext } from 'react';
import { StoreApi } from 'zustand';
import {
ApiClient,
} from '../client/base.js';
import { PipelineGrid } from '../client/utils.js';
import { Config, ServerParams } from '../config.js';
import { BlendSlice, createBlendSlice } from './blend.js';
import { DefaultSlice, createDefaultSlice } from './default.js';
import { HistorySlice, createHistorySlice } from './history.js';
import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js';
import { InpaintSlice, createInpaintSlice } from './inpaint.js';
import { ModelSlice, createModelSlice } from './model.js';
import { ProfileSlice, createProfileSlice } from './profile.js';
import { ResetSlice, createResetSlice } from './reset.js';
import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js';
import { UpscaleSlice, createUpscaleSlice } from './upscale.js';
import {
BaseImgParams,
HighresParams,
ModelParams,
UpscaleParams,
} from '../types/params.js';
/**
* Full merged state including all slices.
*/
export type OnnxState
= DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& Txt2ImgSlice
& UpscaleSlice
& BlendSlice
& ResetSlice
& ProfileSlice;
/**
* React context binding for API client.
*/
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
/**
* React context binding for merged config, including server parameters.
*/
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
/**
* React context binding for bunyan logger.
*/
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
/**
* React context binding for zustand state store.
*/
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
/**
* Key for zustand persistence, typically local storage.
*/
export const STATE_KEY = 'onnx-web';
/**
* Current state version for zustand persistence.
*/
export const STATE_VERSION = 11;
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
seed: defaults.seed.default,
tiled_vae: defaults.tiled_vae.default,
unet_overlap: defaults.unet_overlap.default,
unet_tile: defaults.unet_tile.default,
vae_overlap: defaults.vae_overlap.default,
vae_tile: defaults.vae_tile.default,
};
}
/**
* Prepare the state slice constructors.
*
* In the default state, image sources should be null and booleans should be false. Everything
* else should be initialized from the default value in the base parameters.
*/
export function createStateSlices(server: ServerParams) {
const defaultParams = baseParamsFromServer(server);
const defaultHighres: HighresParams = {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
};
const defaultModel: ModelParams = {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
};
const defaultUpscale: UpscaleParams = {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const defaultGrid: PipelineGrid = {
enabled: false,
columns: {
parameter: 'seed',
value: '',
},
rows: {
parameter: 'seed',
value: '',
},
};
return {
createBlendSlice: createBlendSlice(defaultModel, defaultUpscale),
createDefaultSlice: createDefaultSlice(defaultParams),
createHistorySlice: createHistorySlice(),
createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
createModelSlice: createModelSlice(),
createProfileSlice: createProfileSlice(),
createResetSlice: createResetSlice(),
createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid),
createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale),
};
}

68
gui/src/state/history.ts Normal file
View File

@ -0,0 +1,68 @@
import { Maybe } from '@apextoaster/js-utils';
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
import { Slice } from './types.js';
import { DEFAULT_HISTORY } from '../constants.js';
export interface HistoryItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
retry: Maybe<RetryParams>;
}
export interface HistorySlice {
history: Array<HistoryItem>;
limit: number;
pushHistory(image: ImageResponse, retry?: RetryParams): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void;
}
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
return (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
pushHistory(image, retry) {
set((prev) => ({
...prev,
history: [
{
image,
ready: undefined,
retry,
},
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
}));
},
removeHistory(image) {
set((prev) => ({
...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
}));
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setReady(image, ready) {
set((prev) => {
const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
if (idx >= 0) {
history[idx].ready = ready;
} else {
// TODO: error
}
return {
...prev,
history,
};
});
},
});
}

98
gui/src/state/img2img.ts Normal file
View File

@ -0,0 +1,98 @@
import { ServerParams } from '../config.js';
import {
BaseImgParams,
HighresParams,
Img2ImgParams,
ModelParams,
UpscaleParams,
} from '../types/params.js';
import { Slice, TabState } from './types.js';
export interface Img2ImgSlice {
img2img: TabState<Img2ImgParams>;
img2imgModel: ModelParams;
img2imgHighres: HighresParams;
img2imgUpscale: UpscaleParams;
resetImg2Img(): void;
setImg2Img(params: Partial<Img2ImgParams>): void;
setImg2ImgModel(params: Partial<ModelParams>): void;
setImg2ImgHighres(params: Partial<HighresParams>): void;
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
}
// eslint-disable-next-line max-params
export function createImg2ImgSlice<TState extends Img2ImgSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams
): Slice<TState, Img2ImgSlice> {
return (set) => ({
img2img: {
...defaultParams,
loopback: server.loopback.default,
// eslint-disable-next-line no-null/no-null
source: null,
sourceFilter: '',
strength: server.strength.default,
},
img2imgHighres: {
...defaultHighres,
},
img2imgModel: {
...defaultModel,
},
img2imgUpscale: {
...defaultUpscale,
},
resetImg2Img() {
set({
img2img: {
...defaultParams,
loopback: server.loopback.default,
// eslint-disable-next-line no-null/no-null
source: null,
sourceFilter: '',
strength: server.strength.default,
},
} as Partial<TState>);
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
} as Partial<TState>));
},
setImg2ImgHighres(params) {
set((prev) => ({
img2imgHighres: {
...prev.img2imgHighres,
...params,
},
} as Partial<TState>));
},
setImg2ImgModel(params) {
set((prev) => ({
img2imgModel: {
...prev.img2imgModel,
...params,
},
} as Partial<TState>));
},
setImg2ImgUpscale(params) {
set((prev) => ({
img2imgUpscale: {
...prev.img2imgUpscale,
...params,
},
} as Partial<TState>));
},
});
}

136
gui/src/state/inpaint.ts Normal file
View File

@ -0,0 +1,136 @@
import { ServerParams } from '../config.js';
import { DEFAULT_BRUSH } from '../constants.js';
import {
BaseImgParams,
BrushParams,
HighresParams,
InpaintParams,
ModelParams,
OutpaintPixels,
UpscaleParams,
} from '../types/params.js';
import { Slice, TabState } from './types.js';
export interface InpaintSlice {
inpaint: TabState<InpaintParams>;
inpaintBrush: BrushParams;
inpaintModel: ModelParams;
inpaintHighres: HighresParams;
inpaintUpscale: UpscaleParams;
outpaint: OutpaintPixels;
resetInpaint(): void;
setInpaint(params: Partial<InpaintParams>): void;
setInpaintBrush(brush: Partial<BrushParams>): void;
setInpaintModel(params: Partial<ModelParams>): void;
setInpaintHighres(params: Partial<HighresParams>): void;
setInpaintUpscale(params: Partial<UpscaleParams>): void;
setOutpaint(pixels: Partial<OutpaintPixels>): void;
}
// eslint-disable-next-line max-params
export function createInpaintSlice<TState extends InpaintSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, InpaintSlice> {
return (set) => ({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
// eslint-disable-next-line no-null/no-null
mask: null,
noise: server.noise.default,
// eslint-disable-next-line no-null/no-null
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
inpaintBrush: {
...DEFAULT_BRUSH,
},
inpaintHighres: {
...defaultHighres,
},
inpaintModel: {
...defaultModel,
},
inpaintUpscale: {
...defaultUpscale,
},
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
resetInpaint() {
set({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
// eslint-disable-next-line no-null/no-null
mask: null,
noise: server.noise.default,
// eslint-disable-next-line no-null/no-null
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
} as Partial<TState>);
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
} as Partial<TState>));
},
setInpaintBrush(brush) {
set((prev) => ({
inpaintBrush: {
...prev.inpaintBrush,
...brush,
},
} as Partial<TState>));
},
setInpaintHighres(params) {
set((prev) => ({
inpaintHighres: {
...prev.inpaintHighres,
...params,
},
} as Partial<TState>));
},
setInpaintModel(params) {
set((prev) => ({
inpaintModel: {
...prev.inpaintModel,
...params,
},
} as Partial<TState>));
},
setInpaintUpscale(params) {
set((prev) => ({
inpaintUpscale: {
...prev.inpaintUpscale,
...params,
},
} as Partial<TState>));
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
} as Partial<TState>));
},
});
}

View File

@ -0,0 +1,93 @@
/* eslint-disable camelcase */
import { Logger } from 'browser-bunyan';
import { ServerParams } from '../../config.js';
import { BaseImgParams } from '../../types/params.js';
import { OnnxState, STATE_VERSION } from '../full.js';
import { Img2ImgSlice } from '../img2img.js';
import { InpaintSlice } from '../inpaint.js';
import { Txt2ImgSlice } from '../txt2img.js';
import { UpscaleSlice } from '../upscale.js';
// #region V7
export const V7 = 7;
export type BaseImgParamsV7<T extends BaseImgParams> = Omit<T, AddedKeysV11> & {
overlap: number;
tile: number;
};
export type OnnxStateV7 = Omit<OnnxState, 'img2img' | 'txt2img'> & {
img2img: BaseImgParamsV7<Img2ImgSlice['img2img']>;
inpaint: BaseImgParamsV7<InpaintSlice['inpaint']>;
txt2img: BaseImgParamsV7<Txt2ImgSlice['txt2img']>;
upscale: BaseImgParamsV7<UpscaleSlice['upscale']>;
};
// #endregion
// #region V11
export const REMOVED_KEYS_V11 = ['tile', 'overlap'] as const;
export type RemovedKeysV11 = typeof REMOVED_KEYS_V11[number];
// TODO: can the compiler calculate this?
export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap';
// #endregion
// add versions to this list as they are replaced
export type PreviousState = OnnxStateV7;
// always the latest version
export type CurrentState = OnnxState;
// any version of state
export type UnknownState = PreviousState | CurrentState;
export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number, logger: Logger): OnnxState {
logger.info('applying state migrations from version %s to version %s', version, STATE_VERSION);
if (version <= V7) {
return migrateV7ToV11(params, previousState as PreviousState);
}
return previousState as CurrentState;
}
export function migrateV7ToV11(params: ServerParams, previousState: PreviousState): CurrentState {
// add any missing keys
const result: CurrentState = {
...params,
...previousState,
img2img: {
...previousState.img2img,
unet_overlap: params.unet_overlap.default,
unet_tile: params.unet_tile.default,
vae_overlap: params.vae_overlap.default,
vae_tile: params.vae_tile.default,
},
inpaint: {
...previousState.inpaint,
unet_overlap: params.unet_overlap.default,
unet_tile: params.unet_tile.default,
vae_overlap: params.vae_overlap.default,
vae_tile: params.vae_tile.default,
},
txt2img: {
...previousState.txt2img,
unet_overlap: params.unet_overlap.default,
unet_tile: params.unet_tile.default,
vae_overlap: params.vae_overlap.default,
vae_tile: params.vae_tile.default,
},
upscale: {
...previousState.upscale,
unet_overlap: params.unet_overlap.default,
unet_tile: params.unet_tile.default,
vae_overlap: params.vae_overlap.default,
vae_tile: params.vae_tile.default,
},
};
// TODO: remove extra keys
return result;
}

202
gui/src/state/model.ts Normal file
View File

@ -0,0 +1,202 @@
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js';
import { MISSING_INDEX, Slice } from './types.js';
export interface ModelSlice {
extras: ExtrasFile;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}
// eslint-disable-next-line sonarjs/cognitive-complexity
export function createModelSlice<TState extends ModelSlice>(): Slice<TState, ModelSlice> {
// eslint-disable-next-line sonarjs/cognitive-complexity
return (set) => ({
extras: {
correction: [],
diffusion: [],
networks: [],
sources: [],
upscaling: [],
},
setExtras(extras) {
set((prev) => ({
...prev,
extras: {
...prev.extras,
...extras,
},
}));
},
setCorrectionModel(model) {
set((prev) => {
const correction = [...prev.extras.correction];
const exists = correction.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
correction.push(model);
} else {
correction[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
setDiffusionModel(model) {
set((prev) => {
const diffusion = [...prev.extras.diffusion];
const exists = diffusion.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
diffusion.push(model);
} else {
diffusion[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
setExtraNetwork(model) {
set((prev) => {
const networks = [...prev.extras.networks];
const exists = networks.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
networks.push(model);
} else {
networks[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
setExtraSource(model) {
set((prev) => {
const sources = [...prev.extras.sources];
const exists = sources.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
sources.push(model);
} else {
sources[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
setUpscalingModel(model) {
set((prev) => {
const upscaling = [...prev.extras.upscaling];
const exists = upscaling.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
upscaling.push(model);
} else {
upscaling[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
removeCorrectionModel(model) {
set((prev) => {
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
removeDiffusionModel(model) {
set((prev) => {
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
removeExtraNetwork(model) {
set((prev) => {
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
removeExtraSource(model) {
set((prev) => {
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
removeUpscalingModel(model) {
set((prev) => {
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
}

52
gui/src/state/profile.ts Normal file
View File

@ -0,0 +1,52 @@
import { Maybe } from '@apextoaster/js-utils';
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
import { Slice } from './types.js';
export interface ProfileItem {
name: string;
params: BaseImgParams | Txt2ImgParams;
highres?: Maybe<HighresParams>;
upscale?: Maybe<UpscaleParams>;
}
export interface ProfileSlice {
profiles: Array<ProfileItem>;
removeProfile(profileName: string): void;
saveProfile(profile: ProfileItem): void;
}
export function createProfileSlice<TState extends ProfileSlice>(): Slice<TState, ProfileSlice> {
return (set) => ({
profiles: [],
saveProfile(profile: ProfileItem) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profile.name);
if (idx >= 0) {
profiles[idx] = profile;
} else {
profiles.push(profile);
}
return {
...prev,
profiles,
};
});
},
removeProfile(profileName: string) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profileName);
if (idx >= 0) {
profiles.splice(idx, 1);
}
return {
...prev,
profiles,
};
});
}
});
}

28
gui/src/state/reset.ts Normal file
View File

@ -0,0 +1,28 @@
import { BlendSlice } from './blend.js';
import { Img2ImgSlice } from './img2img.js';
import { InpaintSlice } from './inpaint.js';
import { Txt2ImgSlice } from './txt2img.js';
import { Slice } from './types.js';
import { UpscaleSlice } from './upscale.js';
export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice;
export interface ResetSlice {
resetAll(): void;
}
export function createResetSlice<TState extends ResetSlice & SlicesWithReset>(): Slice<TState, ResetSlice> {
return (set) => ({
resetAll() {
set((prev) => {
const next = { ...prev };
next.resetImg2Img();
next.resetInpaint();
next.resetTxt2Img();
next.resetUpscale();
next.resetBlend();
return next;
});
},
});
}

105
gui/src/state/txt2img.ts Normal file
View File

@ -0,0 +1,105 @@
import { PipelineGrid } from '../client/utils.js';
import { ServerParams } from '../config.js';
import {
BaseImgParams,
HighresParams,
ModelParams,
Txt2ImgParams,
UpscaleParams,
} from '../types/params.js';
import { Slice, TabState } from './types.js';
export interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>;
txt2imgModel: ModelParams;
txt2imgHighres: HighresParams;
txt2imgUpscale: UpscaleParams;
txt2imgVariable: PipelineGrid;
resetTxt2Img(): void;
setTxt2Img(params: Partial<Txt2ImgParams>): void;
setTxt2ImgModel(params: Partial<ModelParams>): void;
setTxt2ImgHighres(params: Partial<HighresParams>): void;
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
}
// eslint-disable-next-line max-params
export function createTxt2ImgSlice<TState extends Txt2ImgSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
defaultGrid: PipelineGrid,
): Slice<TState, Txt2ImgSlice> {
return (set) => ({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
txt2imgHighres: {
...defaultHighres,
},
txt2imgModel: {
...defaultModel,
},
txt2imgUpscale: {
...defaultUpscale,
},
txt2imgVariable: {
...defaultGrid,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
} as Partial<TState>));
},
setTxt2ImgHighres(params) {
set((prev) => ({
txt2imgHighres: {
...prev.txt2imgHighres,
...params,
},
} as Partial<TState>));
},
setTxt2ImgModel(params) {
set((prev) => ({
txt2imgModel: {
...prev.txt2imgModel,
...params,
},
} as Partial<TState>));
},
setTxt2ImgUpscale(params) {
set((prev) => ({
txt2imgUpscale: {
...prev.txt2imgUpscale,
...params,
},
} as Partial<TState>));
},
setTxt2ImgVariable(params) {
set((prev) => ({
txt2imgVariable: {
...prev.txt2imgVariable,
...params,
},
} as Partial<TState>));
},
resetTxt2Img() {
set({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
} as Partial<TState>);
},
});
}

17
gui/src/state/types.ts Normal file
View File

@ -0,0 +1,17 @@
import { PaletteMode } from '@mui/material';
import { StateCreator } from 'zustand';
import { ConfigFiles, ConfigState } from '../config.js';
export const MISSING_INDEX = -1;
export type Theme = PaletteMode | ''; // tri-state, '' is unset
/**
* Combine optional files and required ranges.
*/
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
/**
* Shorthand for state creator to reduce repeated arguments.
*/
export type Slice<TState, TValue> = StateCreator<TState, [], [], TValue>;

87
gui/src/state/upscale.ts Normal file
View File

@ -0,0 +1,87 @@
import {
BaseImgParams,
HighresParams,
ModelParams,
UpscaleParams,
UpscaleReqParams,
} from '../types/params.js';
import { Slice, TabState } from './types.js';
export interface UpscaleSlice {
upscale: TabState<UpscaleReqParams>;
upscaleHighres: HighresParams;
upscaleModel: ModelParams;
upscaleUpscale: UpscaleParams;
resetUpscale(): void;
setUpscale(params: Partial<UpscaleReqParams>): void;
setUpscaleHighres(params: Partial<HighresParams>): void;
setUpscaleModel(params: Partial<ModelParams>): void;
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
}
export function createUpscaleSlice<TState extends UpscaleSlice>(
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, UpscaleSlice> {
return (set) => ({
upscale: {
...defaultParams,
// eslint-disable-next-line no-null/no-null
source: null,
},
upscaleHighres: {
...defaultHighres,
},
upscaleModel: {
...defaultModel,
},
upscaleUpscale: {
...defaultUpscale,
},
resetUpscale() {
set({
upscale: {
...defaultParams,
// eslint-disable-next-line no-null/no-null
source: null,
},
} as Partial<TState>);
},
setUpscale(source) {
set((prev) => ({
upscale: {
...prev.upscale,
...source,
},
} as Partial<TState>));
},
setUpscaleHighres(params) {
set((prev) => ({
upscaleHighres: {
...prev.upscaleHighres,
...params,
},
} as Partial<TState>));
},
setUpscaleModel(params) {
set((prev) => ({
upscaleModel: {
...prev.upscaleModel,
...defaultModel,
},
} as Partial<TState>));
},
setUpscaleUpscale(params) {
set((prev) => ({
upscaleUpscale: {
...prev.upscaleUpscale,
...params,
},
} as Partial<TState>));
},
});
}

View File

@ -287,6 +287,7 @@ export const I18N_STRINGS_EN = {
'ddpm': 'DDPM',
'deis-multi': 'DEIS Multistep',
'dpm-multi': 'DPM Multistep',
'dpm-sde': 'DPM SDE (Turbo)',
'dpm-single': 'DPM Singlestep',
'euler': 'Euler',
'euler-a': 'Euler Ancestral',