diff --git a/README.md b/README.md index ed10a950..e36ae18a 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). This is an incomplete list of new and interesting features, with links to the user guide: -- SDXL support -- LCM support +- supports SDXL and SDXL Turbo +- wide variety of schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more - hardware acceleration on both AMD and Nvidia - tested on CUDA, DirectML, and ROCm - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both diff --git a/api/launch.bat b/api/launch.bat index 0d5933e5..680de5d9 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -1,5 +1,7 @@ call onnx_env\Scripts\Activate.bat +echo "This launch.bat script is deprecated in favor of launch.ps1 and will be removed in a future release." + echo "Downloading and converting models to ONNX format..." IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json) python -m onnx_web.convert ^ @@ -10,6 +12,12 @@ python -m onnx_web.convert ^ --extras=%ONNX_WEB_EXTRA_MODELS% ^ --token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS% +IF NOT EXIST .\gui\index.html ( + echo "Please make sure you have downloaded the web UI files from https://github.com/ssube/onnx-web/tree/gh-pages" + echo "See https://github.com/ssube/onnx-web/blob/main/docs/setup-guide.md#download-the-web-ui-bundle for more information" + pause +) + echo "Launching API server..." waitress-serve ^ --host=0.0.0.0 ^ diff --git a/api/launch.ps1 b/api/launch.ps1 index 19add137..ba898151 100644 --- a/api/launch.ps1 +++ b/api/launch.ps1 @@ -10,6 +10,13 @@ python -m onnx_web.convert ` --extras=$Env:ONNX_WEB_EXTRA_MODELS ` --token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS +if (!(Test-Path -path .\gui\index.html -PathType Leaf)) { + echo "Downloading latest web UI files from Github..." + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/index.html" -OutFile .\gui\index.html + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/config.json" -OutFile .\gui\config.json + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/bundle/main.js" -OutFile .\gui\bundle\main.js +} + echo "Launching API server..." waitress-serve ` --host=0.0.0.0 ` diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 4486bbf6..326ba162 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage): ) -> StageResult: logger.info("blending image using mask") - # TODO: does this need an alpha channel? mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") - mult_mask.alpha_composite(stage_mask) + mult_mask = Image.alpha_composite(mult_mask, stage_mask) mult_mask = mult_mask.convert("L") if is_debug(): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return StageResult( - images=[ + return StageResult.from_images( + [ Image.composite(stage_source, source, mult_mask) for source in sources.as_image() ] diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 5cbe7f07..1a4f0241 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -3,16 +3,17 @@ from argparse import ArgumentParser from logging import getLogger from os import makedirs, path from sys import exit -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse +from typing import Any, Dict, List, Union -from huggingface_hub.file_download import hf_hub_download from jsonschema import ValidationError, validate from onnx import load_model, save_model from transformers import CLIPTokenizer from ..constants import ONNX_MODEL, ONNX_WEIGHTS +from ..server.plugin import load_plugins from ..utils import load_config +from .client import add_model_source, fetch_model +from .client.huggingface import HuggingfaceClient from .correction.gfpgan import convert_correction_gfpgan from .diffusion.control import convert_diffusion_control from .diffusion.diffusion import convert_diffusion_diffusers @@ -25,8 +26,7 @@ from .upscaling.swinir import convert_upscaling_swinir from .utils import ( DEFAULT_OPSET, ConversionContext, - download_progress, - remove_prefix, + fix_diffusion_name, source_format, tuple_to_correction, tuple_to_diffusion, @@ -44,32 +44,34 @@ warnings.filterwarnings( ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", ) -Models = Dict[str, List[Any]] - logger = getLogger(__name__) +ModelDict = Dict[str, Union[float, int, str]] +Models = Dict[str, List[ModelDict]] -model_sources: Dict[str, Tuple[str, str]] = { - "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"), +model_converters: Dict[str, Any] = { + "img2img": convert_diffusion_diffusers, + "img2img-sdxl": convert_diffusion_diffusers_xl, + "inpaint": convert_diffusion_diffusers, + "txt2img": convert_diffusion_diffusers, + "txt2img-sdxl": convert_diffusion_diffusers_xl, } -model_source_huggingface = "huggingface://" - # recommended models base_models: Models = { "diffusion": [ # v1.x ( "stable-diffusion-onnx-v1-5", - model_source_huggingface + "runwayml/stable-diffusion-v1-5", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5", ), ( "stable-diffusion-onnx-v1-inpainting", - model_source_huggingface + "runwayml/stable-diffusion-inpainting", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting", ), ( "upscaling-stable-diffusion-x4", - model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler", + HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler", True, ), ], @@ -200,69 +202,203 @@ base_models: Models = { } -def fetch_model( - conversion: ConversionContext, - name: str, - source: str, - dest: Optional[str] = None, - format: Optional[str] = None, - hf_hub_fetch: bool = False, - hf_hub_filename: Optional[str] = None, -) -> Tuple[str, bool]: - cache_path = dest or conversion.cache_path - cache_name = path.join(cache_path, name) +def convert_model_source(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source = model["source"] - # add an extension if possible, some of the conversion code checks for it - if format is None: - url = urlparse(source) - ext = path.basename(url.path) - _filename, ext = path.splitext(ext) - if ext is not None: - cache_name = cache_name + ext + dest_path = None + if "dest" in model: + dest_path = path.join(conversion.model_path, model["dest"]) + + dest = fetch_model(conversion, name, source, format=model_format, dest=dest_path) + logger.info("finished downloading source: %s -> %s", source, dest) + + +def convert_model_network(conversion: ConversionContext, model): + model_format = source_format(model) + model_type = model["type"] + name = model["name"] + source = model["source"] + + if model_type == "control": + dest = fetch_model( + conversion, + name, + source, + format=model_format, + ) + + convert_diffusion_control( + conversion, + model, + dest, + path.join(conversion.model_path, model_type, name), + ) else: - cache_name = f"{cache_name}.{format}" + model = model.get("model", None) + dest = fetch_model( + conversion, + name, + source, + dest=path.join(conversion.model_path, model_type), + format=model_format, + embeds=(model_type == "inversion" and model == "concept"), + ) - if path.exists(cache_name): - logger.debug("model already exists in cache, skipping fetch") - return cache_name, False + logger.info("finished downloading network: %s -> %s", source, dest) - for proto in model_sources: - api_name, api_root = model_sources.get(proto) - if source.startswith(proto): - api_source = api_root % (remove_prefix(source, proto)) - logger.info( - "downloading model from %s: %s -> %s", api_name, api_source, cache_name + +def convert_model_diffusion(conversion: ConversionContext, model): + # fix up entries with missing prefixes + name = fix_diffusion_name(model["name"]) + if name != model["name"]: + # update the model in-memory if the name changed + model["name"] = name + + model_format = source_format(model) + + pipeline = model.get("pipeline", "txt2img") + converter = model_converters.get(pipeline) + converted, dest = converter( + conversion, + model, + model_format, + ) + + # make sure blending only happens once, not every run + if converted: + # keep track of which models have been blended + blend_models = {} + + inversion_dest = path.join(conversion.model_path, "inversion") + lora_dest = path.join(conversion.model_path, "lora") + + for inversion in model.get("inversions", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "tokenizer" not in blend_models: + blend_models["tokenizer"] = CLIPTokenizer.from_pretrained( + dest, + subfolder="tokenizer", + ) + + inversion_name = inversion["name"] + inversion_source = inversion["source"] + inversion_format = inversion.get("format", None) + inversion_source = fetch_model( + conversion, + inversion_name, + inversion_source, + dest=inversion_dest, ) - return download_progress([(api_source, cache_name)]), False + inversion_token = inversion.get("token", inversion_name) + inversion_weight = inversion.get("weight", 1.0) - if source.startswith(model_source_huggingface): - hub_source = remove_prefix(source, model_source_huggingface) - logger.info("downloading model from Huggingface Hub: %s", hub_source) - # from_pretrained has a bunch of useful logic that snapshot_download by itself down not - if hf_hub_fetch: - return ( - hf_hub_download( - repo_id=hub_source, - filename=hf_hub_filename, - cache_dir=cache_path, - force_filename=f"{name}.bin", - ), - False, + blend_textual_inversions( + conversion, + blend_models["text_encoder"], + blend_models["tokenizer"], + [ + ( + inversion_source, + inversion_weight, + inversion_token, + inversion_format, + ) + ], ) - else: - return hub_source, True - elif source.startswith("https://"): - logger.info("downloading model from: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith("http://"): - logger.warning("downloading model from insecure source: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith(path.sep) or source.startswith("."): - logger.info("using local model: %s", source) - return source, False + + for lora in model.get("loras", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "unet" not in blend_models: + blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL)) + + # load models if not loaded yet + lora_name = lora["name"] + lora_source = lora["source"] + lora_source = fetch_model( + conversion, + f"{name}-lora-{lora_name}", + lora_source, + dest=lora_dest, + ) + lora_weight = lora.get("weight", 1.0) + + blend_loras( + conversion, + blend_models["text_encoder"], + [(lora_source, lora_weight)], + "text_encoder", + ) + + blend_loras( + conversion, + blend_models["unet"], + [(lora_source, lora_weight)], + "unet", + ) + + if "tokenizer" in blend_models: + dest_path = path.join(dest, "tokenizer") + logger.debug("saving blended tokenizer to %s", dest_path) + blend_models["tokenizer"].save_pretrained(dest_path) + + for name in ["text_encoder", "unet"]: + if name in blend_models: + dest_path = path.join(dest, name, ONNX_MODEL) + logger.debug("saving blended %s model to %s", name, dest_path) + save_model( + blend_models[name], + dest_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=ONNX_WEIGHTS, + ) + + +def convert_model_upscaling(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + + source = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "resrgan") + if model_type == "bsrgan": + convert_upscaling_bsrgan(conversion, model, source) + elif model_type == "resrgan": + convert_upscale_resrgan(conversion, model, source) + elif model_type == "swinir": + convert_upscaling_swinir(conversion, model, source) else: - logger.info("unknown model location, using path as provided: %s", source) - return source, False + logger.error("unknown upscaling model type %s for %s", model_type, name) + raise ValueError(name) + + +def convert_model_correction(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "gfpgan") + if model_type == "gfpgan": + convert_correction_gfpgan(conversion, model, source) + else: + logger.error("unknown correction model type %s for %s", model_type, name) + raise ValueError(name) def convert_models(conversion: ConversionContext, args, models: Models): @@ -276,69 +412,21 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping source: %s", name) else: - model_format = source_format(model) - source = model["source"] - try: - dest_path = None - if "dest" in model: - dest_path = path.join(conversion.model_path, model["dest"]) - - dest, hf = fetch_model( - conversion, name, source, format=model_format, dest=dest_path - ) - logger.info("finished downloading source: %s -> %s", source, dest) + convert_model_source(conversion, model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) if args.networks and "networks" in models: - for network in models.get("networks", []): - name = network["name"] + for model in models.get("networks", []): + name = model["name"] if name in args.skip: logger.info("skipping network: %s", name) else: - network_format = source_format(network) - network_model = network.get("model", None) - network_type = network["type"] - source = network["source"] - try: - if network_type == "control": - dest, hf = fetch_model( - conversion, - name, - source, - format=network_format, - ) - - convert_diffusion_control( - conversion, - network, - dest, - path.join(conversion.model_path, network_type, name), - ) - if network_type == "inversion" and network_model == "concept": - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - hf_hub_fetch=True, - hf_hub_filename="learned_embeds.bin", - ) - else: - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - ) - - logger.info("finished downloading network: %s -> %s", source, dest) + convert_model_network(conversion, model) except Exception: logger.exception("error fetching network %s", name) model_errors.append(name) @@ -351,142 +439,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - - pipeline = model.get("pipeline", "txt2img") - if pipeline.endswith("-sdxl"): - converted, dest = convert_diffusion_diffusers_xl( - conversion, - model, - source, - model_format, - hf=hf, - ) - else: - converted, dest = convert_diffusion_diffusers( - conversion, - model, - source, - model_format, - hf=hf, - ) - - # make sure blending only happens once, not every run - if converted: - # keep track of which models have been blended - blend_models = {} - - inversion_dest = path.join(conversion.model_path, "inversion") - lora_dest = path.join(conversion.model_path, "lora") - - for inversion in model.get("inversions", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "tokenizer" not in blend_models: - blend_models[ - "tokenizer" - ] = CLIPTokenizer.from_pretrained( - dest, - subfolder="tokenizer", - ) - - inversion_name = inversion["name"] - inversion_source = inversion["source"] - inversion_format = inversion.get("format", None) - inversion_source, hf = fetch_model( - conversion, - inversion_name, - inversion_source, - dest=inversion_dest, - ) - inversion_token = inversion.get("token", inversion_name) - inversion_weight = inversion.get("weight", 1.0) - - blend_textual_inversions( - conversion, - blend_models["text_encoder"], - blend_models["tokenizer"], - [ - ( - inversion_source, - inversion_weight, - inversion_token, - inversion_format, - ) - ], - ) - - for lora in model.get("loras", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "unet" not in blend_models: - blend_models["unet"] = load_model( - path.join(dest, "unet", ONNX_MODEL) - ) - - # load models if not loaded yet - lora_name = lora["name"] - lora_source = lora["source"] - lora_source, hf = fetch_model( - conversion, - f"{name}-lora-{lora_name}", - lora_source, - dest=lora_dest, - ) - lora_weight = lora.get("weight", 1.0) - - blend_loras( - conversion, - blend_models["text_encoder"], - [(lora_source, lora_weight)], - "text_encoder", - ) - - blend_loras( - conversion, - blend_models["unet"], - [(lora_source, lora_weight)], - "unet", - ) - - if "tokenizer" in blend_models: - dest_path = path.join(dest, "tokenizer") - logger.debug("saving blended tokenizer to %s", dest_path) - blend_models["tokenizer"].save_pretrained(dest_path) - - for name in ["text_encoder", "unet"]: - if name in blend_models: - dest_path = path.join(dest, name, ONNX_MODEL) - logger.debug( - "saving blended %s model to %s", name, dest_path - ) - save_model( - blend_models[name], - dest_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - location=ONNX_WEIGHTS, - ) - + convert_model_diffusion(conversion, model) except Exception: logger.exception( "error converting diffusion model %s", @@ -502,24 +456,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "resrgan") - if model_type == "bsrgan": - convert_upscaling_bsrgan(conversion, model, source) - elif model_type == "resrgan": - convert_upscale_resrgan(conversion, model, source) - elif model_type == "swinir": - convert_upscaling_swinir(conversion, model, source) - else: - logger.error( - "unknown upscaling model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_model_upscaling(conversion, model) except Exception: logger.exception( "error converting upscaling model %s", @@ -535,19 +473,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "gfpgan") - if model_type == "gfpgan": - convert_correction_gfpgan(conversion, model, source) - else: - logger.error( - "unknown correction model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_model_correction(conversion, model) except Exception: logger.exception( "error converting correction model %s", @@ -559,12 +486,26 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.error("error while converting models: %s", model_errors) +def register_plugins(conversion: ConversionContext): + logger.info("loading conversion plugins") + exports = load_plugins(conversion) + + for proto, client in exports.clients: + try: + add_model_source(proto, client) + except Exception: + logger.exception("error loading client for protocol: %s", proto) + + # TODO: add converters + + def main(args=None) -> int: parser = ArgumentParser( prog="onnx-web model converter", description="convert checkpoint models to ONNX" ) # model groups + parser.add_argument("--base", action="store_true", default=True) parser.add_argument("--networks", action="store_true", default=True) parser.add_argument("--sources", action="store_true", default=True) parser.add_argument("--correction", action="store_true", default=False) @@ -602,16 +543,20 @@ def main(args=None) -> int: server.half = args.half or server.has_optimization("onnx-fp16") server.opset = args.opset server.token = args.token + + register_plugins(server) + logger.info( - "converting models in %s using %s", server.model_path, server.training_device + "converting models into %s using %s", server.model_path, server.training_device ) if not path.exists(server.model_path): logger.info("model path does not existing, creating: %s", server.model_path) makedirs(server.model_path) - logger.info("converting base models") - convert_models(server, args, base_models) + if args.base: + logger.info("converting base models") + convert_models(server, args, base_models) extras = [] extras.extend(server.extra_models) diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py new file mode 100644 index 00000000..1d90f04e --- /dev/null +++ b/api/onnx_web/convert/client/__init__.py @@ -0,0 +1,54 @@ +from typing import Callable, Dict, Optional +from logging import getLogger +from os import path + +from ..utils import ConversionContext +from .base import BaseClient +from .civitai import CivitaiClient +from .file import FileClient +from .http import HttpClient +from .huggingface import HuggingfaceClient + +logger = getLogger(__name__) + + +model_sources: Dict[str, Callable[[], BaseClient]] = { + CivitaiClient.protocol: CivitaiClient, + FileClient.protocol: FileClient, + HttpClient.insecure_protocol: HttpClient, + HttpClient.protocol: HttpClient, + HuggingfaceClient.protocol: HuggingfaceClient, +} + + +def add_model_source(proto: str, client: BaseClient): + global model_sources + + if proto in model_sources: + raise ValueError("protocol has already been taken") + + model_sources[proto] = client + + +def fetch_model( + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, + **kwargs, +) -> str: + # TODO: switch to urlparse's default scheme + if source.startswith(path.sep) or source.startswith("."): + logger.info("adding file protocol to local path source: %s", source) + source = FileClient.protocol + source + + for proto, client_type in model_sources.items(): + if source.startswith(proto): + client = client_type(conversion) + return client.download( + conversion, name, source, format=format, dest=dest, **kwargs + ) + + logger.warning("unknown model protocol, using path as provided: %s", source) + return source diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py new file mode 100644 index 00000000..2ec2ae65 --- /dev/null +++ b/api/onnx_web/convert/client/base.py @@ -0,0 +1,16 @@ +from typing import Optional + +from ..utils import ConversionContext + + +class BaseClient: + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, + **kwargs, + ) -> str: + raise NotImplementedError() diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py new file mode 100644 index 00000000..b57b6533 --- /dev/null +++ b/api/onnx_web/convert/client/civitai.py @@ -0,0 +1,64 @@ +from logging import getLogger +from typing import Optional + +from ..utils import ( + ConversionContext, + build_cache_paths, + download_progress, + get_first_exists, + remove_prefix, +) +from .base import BaseClient + +logger = getLogger(__name__) + +CIVITAI_ROOT = "https://civitai.com/api/download/models/%s" + + +class CivitaiClient(BaseClient): + name = "civitai" + protocol = "civitai://" + + root: str + token: Optional[str] + + def __init__( + self, + conversion: ConversionContext, + token: Optional[str] = None, + root: str = CIVITAI_ROOT, + ): + self.root = conversion.get_setting("CIVITAI_ROOT", root) + self.token = conversion.get_setting("CIVITAI_TOKEN", token) + + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, + **kwargs, + ) -> str: + cache_paths = build_cache_paths( + conversion, + name, + client=CivitaiClient.name, + format=format, + dest=dest, + ) + cached = get_first_exists(cache_paths) + if cached: + return cached + + source = self.root % (remove_prefix(source, CivitaiClient.protocol)) + logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0]) + + if self.token: + logger.debug("adding Civitai token authentication") + if "?" in source: + source = f"{source}&token={self.token}" + else: + source = f"{source}?token={self.token}" + + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py new file mode 100644 index 00000000..8ed1a4d9 --- /dev/null +++ b/api/onnx_web/convert/client/file.py @@ -0,0 +1,32 @@ +from logging import getLogger +from os import path +from typing import Optional +from urllib.parse import urlparse + +from ..utils import ConversionContext +from .base import BaseClient + +logger = getLogger(__name__) + + +class FileClient(BaseClient): + protocol = "file://" + + def __init__(self, _conversion: ConversionContext): + """ + Nothing to initialize for this client. + """ + pass + + def download( + self, + conversion: ConversionContext, + _name: str, + uri: str, + format: Optional[str] = None, + dest: Optional[str] = None, + **kwargs, + ) -> str: + parts = urlparse(uri) + logger.info("loading model from: %s", parts.path) + return path.join(dest or conversion.model_path, parts.path) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py new file mode 100644 index 00000000..151ebccd --- /dev/null +++ b/api/onnx_web/convert/client/http.py @@ -0,0 +1,52 @@ +from logging import getLogger +from typing import Dict, Optional + +from ..utils import ( + ConversionContext, + build_cache_paths, + download_progress, + get_first_exists, +) +from .base import BaseClient + +logger = getLogger(__name__) + + +class HttpClient(BaseClient): + name = "http" + protocol = "https://" + insecure_protocol = "http://" + + headers: Dict[str, str] + + def __init__( + self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None + ): + self.headers = headers or {} + + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, + **kwargs, + ) -> str: + cache_paths = build_cache_paths( + conversion, + name, + client=HttpClient.name, + format=format, + dest=dest, + ) + cached = get_first_exists(cache_paths) + if cached: + return cached + + if source.startswith(HttpClient.protocol): + logger.info("downloading model from: %s", source) + elif source.startswith(HttpClient.insecure_protocol): + logger.warning("downloading model from insecure source: %s", source) + + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py new file mode 100644 index 00000000..c35698e7 --- /dev/null +++ b/api/onnx_web/convert/client/huggingface.py @@ -0,0 +1,43 @@ +from logging import getLogger +from typing import Optional + +from huggingface_hub.file_download import hf_hub_download + +from ..utils import ConversionContext, remove_prefix +from .base import BaseClient + +logger = getLogger(__name__) + + +class HuggingfaceClient(BaseClient): + name = "huggingface" + protocol = "huggingface://" + + token: Optional[str] + + def __init__(self, conversion: ConversionContext, token: Optional[str] = None): + self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token) + + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, + embeds: bool = False, + **kwargs, + ) -> str: + source = remove_prefix(source, HuggingfaceClient.protocol) + logger.info("downloading model from Huggingface Hub: %s", source) + + if embeds: + return hf_hub_download( + repo_id=source, + filename="learned_embeds.bin", + cache_dir=dest or conversion.cache_path, + force_filename=f"{name}.bin", + token=self.token, + ) + else: + return source diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index a8ecbbf7..a2e3d807 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc +from ..client import fetch_model +from ..client.huggingface import HuggingfaceClient from ..utils import ( RESOLVE_FORMATS, ConversionContext, @@ -43,6 +45,7 @@ from ..utils import ( is_torch_2_0, load_tensor, onnx_export, + remove_prefix, ) from .checkpoint import convert_extract_checkpoint @@ -267,14 +270,13 @@ def collate_cnet(cnet_path): def convert_diffusion_diffusers( conversion: ConversionContext, model: Dict, - source: str, format: Optional[str], - hf: bool = False, ) -> Tuple[bool, str]: """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() + source = model.get("source") # optional config = model.get("config", None) @@ -320,9 +322,11 @@ def convert_diffusion_diffusers( logger.info("ONNX model already exists, skipping") return (False, dest_path) + cache_path = fetch_model(conversion, name, source, format=format) + pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( - source, conversion.map_location, size=image_size, version=version + cache_path, conversion.map_location, size=image_size, version=version ) is_inpainting = False @@ -334,57 +338,66 @@ def convert_diffusion_diffusers( pipe_args["from_safetensors"] = True torch_source = None - if path.exists(source) and path.isdir(source): - logger.debug("loading pipeline from diffusers directory: %s", source) - pipeline = pipe_class.from_pretrained( - source, - torch_dtype=dtype, - use_auth_token=conversion.token, - ).to(device) - elif path.exists(source) and path.isfile(source): - if conversion.extract: - logger.debug("extracting SD checkpoint to Torch models: %s", source) - torch_source = convert_extract_checkpoint( - conversion, - source, - f"{name}-torch", - is_inpainting=is_inpainting, - config_file=config, - vae_file=replace_vae, - ) - logger.debug("loading pipeline from extracted checkpoint: %s", torch_source) + if path.exists(cache_path): + if path.isdir(cache_path): + logger.debug("loading pipeline from diffusers directory: %s", source) pipeline = pipe_class.from_pretrained( - torch_source, + cache_path, torch_dtype=dtype, + use_auth_token=conversion.token, ).to(device) - - # VAE replacement already happened during extraction, skip - replace_vae = None else: - logger.debug("loading pipeline from SD checkpoint: %s", source) - pipeline = download_from_original_stable_diffusion_ckpt( - source, - original_config_file=config_path, - pipeline_class=pipe_class, - **pipe_args, - ).to(device, torch_dtype=dtype) - elif hf: - logger.debug("downloading pretrained model from Huggingface hub: %s", source) + if conversion.extract: + logger.debug("extracting SD checkpoint to Torch models: %s", source) + torch_source = convert_extract_checkpoint( + conversion, + source, + f"{name}-torch", + is_inpainting=is_inpainting, + config_file=config, + vae_file=replace_vae, + ) + logger.debug( + "loading pipeline from extracted checkpoint: %s", torch_source + ) + pipeline = pipe_class.from_pretrained( + torch_source, + torch_dtype=dtype, + ).to(device) + + # VAE replacement already happened during extraction, skip + replace_vae = None + else: + logger.debug("loading pipeline from SD checkpoint: %s", source) + pipeline = download_from_original_stable_diffusion_ckpt( + cache_path, + original_config_file=config_path, + pipeline_class=pipe_class, + **pipe_args, + ).to(device, torch_dtype=dtype) + elif source.startswith(HuggingfaceClient.protocol): + hf_path = remove_prefix(source, HuggingfaceClient.protocol) + logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path) pipeline = pipe_class.from_pretrained( - source, + hf_path, torch_dtype=dtype, use_auth_token=conversion.token, ).to(device) else: - logger.warning("pipeline source not found or not recognized: %s", source) - raise ValueError(f"pipeline source not found or not recognized: {source}") + logger.warning( + "pipeline source not found and protocol not recognized: %s", source + ) + raise ValueError( + f"pipeline source not found and protocol not recognized: {source}" + ) if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(replace_vae, RESOLVE_FORMATS): + vae_file = check_ext(vae_path, RESOLVE_FORMATS) + if vae_file[0]: pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: - pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + pipeline.vae = AutoencoderKL.from_pretrained(replace_vae) if is_torch_2_0: pipeline.unet.set_attn_processor(AttnProcessor()) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 7a03b700..c09b9440 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16 from optimum.exporters.onnx import main_export from ...constants import ONNX_MODEL +from ..client import fetch_model from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext logger = getLogger(__name__) @@ -19,14 +20,13 @@ logger = getLogger(__name__) def convert_diffusion_diffusers_xl( conversion: ConversionContext, model: Dict, - source: str, format: Optional[str], - hf: bool = False, ) -> Tuple[bool, str]: """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() + source = model.get("source") replace_vae = model.get("vae", None) device = conversion.training_device @@ -52,24 +52,26 @@ def convert_diffusion_diffusers_xl( return (False, dest_path) + cache_path = fetch_model(conversion, name, model["source"], format=format) # safetensors -> diffusers directory with torch models temp_path = path.join(conversion.cache_path, f"{name}-torch") if format == "safetensors": pipeline = StableDiffusionXLPipeline.from_single_file( - source, use_safetensors=True + cache_path, use_safetensors=True ) else: - pipeline = StableDiffusionXLPipeline.from_pretrained(source) + pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path) if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(replace_vae, RESOLVE_FORMATS): + vae_file = check_ext(vae_path, RESOLVE_FORMATS) + if vae_file[0]: logger.debug("loading VAE from single tensor file: %s", vae_path) pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: - logger.debug("loading pretrained VAE from path: %s", vae_path) - pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + logger.debug("loading pretrained VAE from path: %s", replace_vae) + pipeline.vae = AutoencoderKL.from_pretrained(replace_vae) if path.exists(temp_path): logger.debug("torch model already exists for %s: %s", source, temp_path) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index ef44ba20..7785007b 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -31,6 +31,15 @@ ModelDict = Dict[str, Union[str, int]] LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]] DEFAULT_OPSET = 14 +DIFFUSION_PREFIX = [ + "diffusion-", + "diffusion/", + "diffusion\\", + "stable-diffusion-", + "upscaling-", # SD upscaling +] +MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] +RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"] class ConversionContext(ServerContext): @@ -85,38 +94,37 @@ class ConversionContext(ServerContext): return torch.device(self.training_device) -def download_progress(urls: List[Tuple[str, str]]): - for url, dest in urls: - dest_path = Path(dest).expanduser().resolve() - dest_path.parent.mkdir(parents=True, exist_ok=True) - - if dest_path.exists(): - logger.debug("destination already exists: %s", dest_path) - return str(dest_path.absolute()) - - req = requests.get( - url, - stream=True, - allow_redirects=True, - headers={ - "User-Agent": "onnx-web-api", - }, - ) - if req.status_code != 200: - req.raise_for_status() # Only works for 4xx errors, per SO answer - raise RequestException( - "request to %s failed with status code: %s" % (url, req.status_code) - ) - - total = int(req.headers.get("Content-Length", 0)) - desc = "unknown" if total == 0 else "" - req.raw.read = partial(req.raw.read, decode_content=True) - with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: - with dest_path.open("wb") as f: - shutil.copyfileobj(data, f) +def download_progress(source: str, dest: str): + dest_path = Path(dest).expanduser().resolve() + dest_path.parent.mkdir(parents=True, exist_ok=True) + if dest_path.exists(): + logger.debug("destination already exists: %s", dest_path) return str(dest_path.absolute()) + req = requests.get( + source, + stream=True, + allow_redirects=True, + headers={ + "User-Agent": "onnx-web-api", + }, + ) + if req.status_code != 200: + req.raise_for_status() # Only works for 4xx errors, per SO answer + raise RequestException( + "request to %s failed with status code: %s" % (source, req.status_code) + ) + + total = int(req.headers.get("Content-Length", 0)) + desc = "unknown" if total == 0 else "" + req.raw.read = partial(req.raw.read, decode_content=True) + with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: + with dest_path.open("wb") as f: + shutil.copyfileobj(data, f) + + return str(dest_path.absolute()) + def tuple_to_source(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): @@ -184,10 +192,6 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): return model -MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] -RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"] - - def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]: _name, ext = path.splitext(name) ext = ext.strip(".") @@ -215,6 +219,9 @@ def remove_prefix(name: str, prefix: str) -> str: def load_torch(name: str, map_location=None) -> Optional[Dict]: + """ + TODO: move out of convert + """ try: logger.debug("loading tensor with Torch: %s", name) checkpoint = torch.load(name, map_location=map_location) @@ -228,6 +235,9 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]: def load_tensor(name: str, map_location=None) -> Optional[Dict]: + """ + TODO: move out of convert + """ logger.debug("loading tensor: %s", name) _, extension = path.splitext(name) extension = extension[1:].lower() @@ -285,6 +295,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]: def resolve_tensor(name: str) -> Optional[str]: + """ + TODO: move out of convert + """ logger.debug("searching for tensors with known extensions: %s", name) for next_extension in RESOLVE_FORMATS: next_name = f"{name}.{next_extension}" @@ -345,3 +358,52 @@ def onnx_export( all_tensors_to_one_file=True, location=ONNX_WEIGHTS, ) + + +def fix_diffusion_name(name: str): + if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): + logger.warning( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + return f"diffusion-{name}" + + return name + + +def build_cache_paths( + conversion: ConversionContext, + name: str, + client: Optional[str] = None, + dest: Optional[str] = None, + format: Optional[str] = None, +) -> List[str]: + cache_path = dest or conversion.cache_path + + # add an extension if possible, some of the conversion code checks for it + if format is not None: + basename = path.basename(name) + _filename, ext = path.splitext(basename) + if ext is None or ext == "": + name = f"{name}.{format}" + + paths = [ + path.join(cache_path, name), + ] + + if client is not None: + client_path = path.join(cache_path, client) + paths.append(path.join(client_path, name)) + + return paths + + +def get_first_exists( + paths: List[str], +) -> Optional[str]: + for name in paths: + if path.exists(name): + logger.debug("model already exists in cache, skipping fetch: %s", name) + return name + + return None diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 30a87863..b6620e8b 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -30,12 +30,12 @@ from .version_safe_diffusers import ( DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler, + DPMSolverSDEScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, - KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, @@ -71,6 +71,7 @@ pipeline_schedulers = { "ddpm": DDPMScheduler, "deis-multi": DEISMultistepScheduler, "dpm-multi": DPMSolverMultistepScheduler, + "dpm-sde": DPMSolverSDEScheduler, "dpm-single": DPMSolverSinglestepScheduler, "euler": EulerDiscreteScheduler, "euler-a": EulerAncestralDiscreteScheduler, @@ -78,7 +79,6 @@ pipeline_schedulers = { "ipndm": IPNDMScheduler, "k-dpm-2-a": KDPM2AncestralDiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler, - "karras-ve": KarrasVeScheduler, "lcm": LCMScheduler, "lms-discrete": LMSDiscreteScheduler, "pndm": PNDMScheduler, diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py index d256d615..8c4da406 100644 --- a/api/onnx_web/diffusers/version_safe_diffusers.py +++ b/api/onnx_web/diffusers/version_safe_diffusers.py @@ -12,6 +12,11 @@ try: except ImportError: from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler +try: + from diffusers import DPMSolverSDEScheduler +except ImportError: + from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler + try: from diffusers import LCMScheduler except ImportError: diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 034fc3c6..e02cebd5 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -94,45 +94,44 @@ class ServerContext: self.cache = ModelCache(self.cache_limit) @classmethod - def from_environ(cls): - memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None) + def from_environ(cls, env=environ): + memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None) if memory_limit is not None: memory_limit = int(memory_limit) return cls( - bundle_path=environ.get( - "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") - ), - model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), - output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), - params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"), + bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")), + model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), + output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), + params_path=env.get("ONNX_WEB_PARAMS_PATH", "."), + cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"), any_platform=get_boolean( - environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM + env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM ), - block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), - default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), - cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), + block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"), + default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None), + image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), + cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), show_progress=get_boolean( - environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS + env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS ), - optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), - extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), - job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), + optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"), + extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"), + job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, - admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), - server_version=environ.get( - "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION - ), + admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None), + server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION), worker_retries=int( - environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) + env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), - plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), - debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), + feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"), + plugins=get_list(env, "ONNX_WEB_PLUGINS", ""), + debug=get_boolean(env, "ONNX_WEB_DEBUG", False), ) + def get_setting(self, flag: str, default: str) -> Optional[str]: + return environ.get(f"ONNX_WEB_{flag}", default) + def has_feature(self, flag: str) -> bool: return flag in self.feature_flags diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 6bf1de2d..2dd37157 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union import torch from jsonschema import ValidationError, validate +from ..convert.utils import fix_diffusion_name from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -189,6 +190,9 @@ def load_extras(server: ServerContext): for model in data[model_type]: model_name = model["name"] + if model_type == "diffusion": + model_name = fix_diffusion_name(model_name) + if "hash" in model: logger.debug( "collecting hash for model %s from %s", diff --git a/api/onnx_web/server/plugin.py b/api/onnx_web/server/plugin.py index 022047df..2adcabb3 100644 --- a/api/onnx_web/server/plugin.py +++ b/api/onnx_web/server/plugin.py @@ -2,18 +2,24 @@ from importlib import import_module from logging import getLogger from typing import Any, Callable, Dict -from onnx_web.chain.stages import add_stage -from onnx_web.diffusers.load import add_pipeline -from onnx_web.server.context import ServerContext +from ..chain.stages import add_stage +from ..diffusers.load import add_pipeline +from ..server.context import ServerContext logger = getLogger(__name__) class PluginExports: + clients: Dict[str, Any] + converter: Dict[str, Any] pipelines: Dict[str, Any] stages: Dict[str, Any] - def __init__(self, pipelines=None, stages=None) -> None: + def __init__( + self, clients=None, converter=None, pipelines=None, stages=None + ) -> None: + self.clients = clients or {} + self.converter = converter or {} self.pipelines = pipelines or {} self.stages = stages or {} diff --git a/api/params.json b/api/params.json index d9cd6fa0..7b0c746f 100644 --- a/api/params.json +++ b/api/params.json @@ -14,7 +14,7 @@ }, "cfg": { "default": 6, - "min": 1, + "min": 0, "max": 30, "step": 0.1 }, diff --git a/api/requirements/base.txt b/api/requirements/base.txt index 9d2b8ddf..822ef4e7 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -3,21 +3,21 @@ numpy==1.23.5 protobuf==3.20.3 ### SD packages ### -accelerate==0.22.0 +accelerate==0.25.0 coloredlogs==15.0.1 controlnet_aux==0.0.2 -datasets==2.14.3 -diffusers==0.20.0 +datasets==2.15.0 +diffusers==0.24.0 huggingface-hub==0.16.4 invisible-watermark==0.2.0 mediapipe==0.9.2.1 omegaconf==2.3.0 onnx==1.13.0 # onnxruntime has many platform-specific packages -optimum==1.12.0 -safetensors==0.3.1 -timm==0.6.13 -transformers==4.32.0 +optimum==1.16.0 +safetensors==0.4.1 +timm==0.9.12 +transformers==4.36.1 #### Upscaling and face correction basicsr==1.4.2 @@ -29,10 +29,11 @@ realesrgan==0.3.0 ### Server packages ### arpeggio==2.0.0 boto3==1.26.69 -flask==2.2.2 +flask==3.0.0 flask-cors==3.0.10 jsonschema==4.17.3 piexif==1.1.3 pyyaml==6.0 setproctitle==1.3.2 waitress==2.1.2 +werkzeug==3.0.1 diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 4281adbc..34b0bf9b 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase): class DownloadProgressTests(unittest.TestCase): def test_download_example(self): - path = download_progress([("https://example.com", "/tmp/example-dot-com")]) + path = download_progress("https://example.com", "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com") diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 26578f3e..322712e4 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -414,6 +414,7 @@ class TestBlendPipeline(unittest.TestCase): 3.0, 1, 1, + unet_tile=64, ), Size(64, 64), ["test-blend.png"], diff --git a/docs/getting-started.md b/docs/getting-started.md new file mode 100644 index 00000000..fc596089 --- /dev/null +++ b/docs/getting-started.md @@ -0,0 +1,198 @@ +# Getting Started With onnx-web + +onnx-web is a tool for generating images with Stable Diffusion pipelines, including SDXL. + +## Contents + +- [Getting Started With onnx-web](#getting-started-with-onnx-web) + - [Contents](#contents) + - [Setup](#setup) + - [Windows bundle setup](#windows-bundle-setup) + - [Other setup methods](#other-setup-methods) + - [Running](#running) + - [Running the server](#running-the-server) + - [Running the web UI](#running-the-web-ui) + - [Tabs](#tabs) + - [Txt2img Tab](#txt2img-tab) + - [Img2img Tab](#img2img-tab) + - [Inpaint Tab](#inpaint-tab) + - [Upscale Tab](#upscale-tab) + - [Blend Tab](#blend-tab) + - [Models Tab](#models-tab) + - [Settings Tab](#settings-tab) + - [Image parameters](#image-parameters) + - [Common image parameters](#common-image-parameters) + - [Unique image parameters](#unique-image-parameters) + - [Prompt syntax](#prompt-syntax) + - [LoRAs and embeddings](#loras-and-embeddings) + - [CLIP skip](#clip-skip) + - [Highres](#highres) + - [Highres prompt](#highres-prompt) + - [Highres iterations](#highres-iterations) + - [Profiles](#profiles) + - [Loading from files](#loading-from-files) + - [Saving profiles in the web UI](#saving-profiles-in-the-web-ui) + - [Sharing parameters profiles](#sharing-parameters-profiles) + - [Panorama pipeline](#panorama-pipeline) + - [Region prompts](#region-prompts) + - [Region seeds](#region-seeds) + - [Grid mode](#grid-mode) + - [Grid tokens](#grid-tokens) + - [Memory optimizations](#memory-optimizations) + - [Converting to fp16](#converting-to-fp16) + - [Moving models to the CPU](#moving-models-to-the-cpu) + +## Setup + +### Windows bundle setup + +1. Download +2. Extract +3. Security flags +4. Run + +### Other setup methods + +Link to the other methods. + +## Running + +### Running the server + +Run server or bundle. + +### Running the web UI + +Open web UI. + +Use it from your phone. + +## Tabs + +There are 5 tabs, which do different things. + +### Txt2img Tab + +Words go in, pictures come out. + +### Img2img Tab + +Pictures go in, better pictures come out. + +ControlNet lives here. + +### Inpaint Tab + +Pictures go in, parts of the same picture come out. + +### Upscale Tab + +Just highres and super resolution. + +### Blend Tab + +Use the mask tool to combine two images. + +### Models Tab + +Add and manage models. + +### Settings Tab + +Manage web UI settings. + +Reset buttons. + +## Image parameters + +### Common image parameters + +- Scheduler +- Eta + - for DDIM +- CFG +- Steps +- Seed +- Batch size +- Prompt +- Negative prompt +- Width, height + +### Unique image parameters + +- UNet tile size +- UNet overlap +- Tiled VAE +- VAE tile size +- VAE overlap + +See the complete user guide for details about the highres, upscale, and correction parameters. + +## Prompt syntax + +### LoRAs and embeddings + +`` and ``. + +### CLIP skip + +`` for anime. + +## Highres + +### Highres prompt + +`txt2img prompt || img2img prompt` + +### Highres iterations + +Highres will apply the upscaler and highres prompt (img2img pipeline) for each iteration. + +The final size will be `scale ** iterations`. + +## Profiles + +Saved sets of parameters for later use. + +### Loading from files + +- load from images +- load from JSON + +### Saving profiles in the web UI + +Use the save button. + +### Sharing parameters profiles + +Use the download button. + +Share profiles in the Discord channel. + +## Panorama pipeline + +### Region prompts + +`` + +### Region seeds + +`` + +## Grid mode + +Makes many images. Takes many time. + +### Grid tokens + +`__column__` and `__row__` if you pick token in the menu. + +## Memory optimizations + +### Converting to fp16 + +Enable the option. + +### Moving models to the CPU + +Option for each model. diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index 0f6b0f2f..20a520ed 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ErrorCard } from './card/ErrorCard.js'; import { ImageCard } from './card/ImageCard.js'; import { LoadingCard } from './card/LoadingCard.js'; diff --git a/gui/src/components/OnnxError.tsx b/gui/src/components/OnnxError.tsx index da6d7147..05186768 100644 --- a/gui/src/components/OnnxError.tsx +++ b/gui/src/components/OnnxError.tsx @@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material'; import * as React from 'react'; import { ReactNode } from 'react'; -import { STATE_KEY } from '../state.js'; +import { STATE_KEY } from '../state/full.js'; import { Logo } from './Logo.js'; export interface OnnxErrorProps { diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index 918ee84f..69c2db20 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageHistory } from './ImageHistory.js'; import { Logo } from './Logo.js'; import { Blend } from './tab/Blend.js'; diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index baf022b3..de8f2263 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageMetadata } from '../types/api.js'; import { DeepPartial } from '../types/model.js'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index bb3ac6c9..f7106584 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; export interface ErrorCardProps { diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index 7c35acb2..e614d430 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -8,9 +8,10 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; import { range, visibleIndex } from '../../utils.js'; +import { BLEND_SOURCES } from '../../constants.js'; export interface ImageCardProps { image: ImageResponse; diff --git a/gui/src/components/card/LoadingCard.tsx b/gui/src/components/card/LoadingCard.tsx index e0fcdb68..71bfb5f0 100644 --- a/gui/src/components/card/LoadingCard.tsx +++ b/gui/src/components/card/LoadingCard.tsx @@ -9,7 +9,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { POLL_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; const LOADING_PERCENT = 100; diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index 91525b21..62fa63fd 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { HighresParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 8271c700..8877ae50 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -11,7 +11,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { BaseImgParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; import { PromptInput } from '../input/PromptInput.js'; diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index ba08998f..54b99846 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -6,7 +6,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { STALE_TIME } from '../../config.js'; -import { ClientContext } from '../../state.js'; +import { ClientContext } from '../../state/full.js'; import { ModelParams } from '../../types/params.js'; import { QueryList } from '../input/QueryList.js'; diff --git a/gui/src/components/control/OutpaintControl.tsx b/gui/src/components/control/OutpaintControl.tsx index 5b70cc50..69523477 100644 --- a/gui/src/components/control/OutpaintControl.tsx +++ b/gui/src/components/control/OutpaintControl.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { NumericField } from '../input/NumericField.js'; export function OutpaintControl() { diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index a90dd330..5d840a04 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { UpscaleParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/VariableControl.tsx b/gui/src/components/control/VariableControl.tsx index 32a66660..bf5139e1 100644 --- a/gui/src/components/control/VariableControl.tsx +++ b/gui/src/components/control/VariableControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useStore } from 'zustand'; import { PipelineGrid } from '../../client/utils.js'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; import { VARIABLE_PARAMETERS } from '../../types/chain.js'; export interface VariableControlProps { diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index 3910e394..a6d45aca 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -4,7 +4,7 @@ import * as React from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; const { useContext, useState, memo, useMemo } = React; diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index ae7f2723..2e605423 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next'; import { SAVE_TIME } from '../../config.js'; -import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; +import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js'; import { BrushParams } from '../../types/params.js'; import { imageFromBlob } from '../../utils.js'; import { NumericField } from './NumericField'; diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index cdaa6751..bc522569 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { QueryMenu } from '../input/QueryMenu.js'; import { ModelResponse } from '../../types/api.js'; diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index a1a304d1..21fe47b6 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -8,7 +8,9 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { BLEND_SOURCES } from '../../constants.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { range } from '../../utils.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index cddb8539..26bda5cf 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index b783f5b5..ca83835b 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index e634485a..86e96ecd 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { CorrectionModel, DiffusionModel, diff --git a/gui/src/components/tab/Settings.tsx b/gui/src/components/tab/Settings.tsx index 5f09c1a2..7b25ed3b 100644 --- a/gui/src/components/tab/Settings.tsx +++ b/gui/src/components/tab/Settings.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { getApiRoot } from '../../config.js'; -import { ConfigContext, StateContext, STATE_KEY } from '../../state.js'; +import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js'; import { getTheme } from '../utils.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 921def1f..81286ec2 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Upscale.tsx b/gui/src/components/tab/Upscale.tsx index 06314579..e9a1c482 100644 --- a/gui/src/components/tab/Upscale.tsx +++ b/gui/src/components/tab/Upscale.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/utils.ts b/gui/src/components/utils.ts index 58e106ec..39ba9bae 100644 --- a/gui/src/components/utils.ts +++ b/gui/src/components/utils.ts @@ -1,6 +1,6 @@ import { PaletteMode } from '@mui/material'; -import { Theme } from '../state.js'; +import { Theme } from '../state/types.js'; import { trimHash } from '../utils.js'; export const TAB_LABELS = [ diff --git a/gui/src/config.json b/gui/src/config.json index 91cccb8e..d2d63098 100644 --- a/gui/src/config.json +++ b/gui/src/config.json @@ -18,7 +18,7 @@ }, "cfg": { "default": 6, - "min": 1, + "min": 0, "max": 30, "step": 0.1 }, diff --git a/gui/src/constants.ts b/gui/src/constants.ts new file mode 100644 index 00000000..eb718447 --- /dev/null +++ b/gui/src/constants.ts @@ -0,0 +1,31 @@ + +export const BLEND_SOURCES = 2; + +/** + * Default parameters for the inpaint brush. + * + * Not provided by the server yet. + */ +export const DEFAULT_BRUSH = { + color: 255, + size: 8, + strength: 0.5, +}; + +/** + * Default parameters for the image history. + * + * Not provided by the server yet. + */ +export const DEFAULT_HISTORY = { + /** + * The number of images to be shown. + */ + limit: 4, + + /** + * The number of additional images to be kept in history, so they can scroll + * back into view when you delete one. Does not include deleted images. + */ + scrollback: 2, +}; diff --git a/gui/src/main.tsx b/gui/src/main.tsx index d0b83a11..81aeecf0 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -28,8 +28,9 @@ import { STATE_KEY, STATE_VERSION, StateContext, -} from './state.js'; +} from './state/full.js'; import { I18N_STRINGS } from './strings/all.js'; +import { applyStateMigrations, UnknownState } from './state/migration/default.js'; export const INITIAL_LOAD_TIMEOUT = 5_000; @@ -70,6 +71,9 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo ...createResetSlice(...slice), ...createProfileSlice(...slice), }), { + migrate(persistedState, version) { + return applyStateMigrations(params, persistedState as UnknownState, version, logger); + }, name: STATE_KEY, partialize(s) { return { diff --git a/gui/src/state.ts b/gui/src/state.ts deleted file mode 100644 index ad2a3e3a..00000000 --- a/gui/src/state.ts +++ /dev/null @@ -1,965 +0,0 @@ -/* eslint-disable camelcase */ -/* eslint-disable max-lines */ -/* eslint-disable no-null/no-null */ -import { Maybe } from '@apextoaster/js-utils'; -import { PaletteMode } from '@mui/material'; -import { Logger } from 'noicejs'; -import { createContext } from 'react'; -import { StateCreator, StoreApi } from 'zustand'; - -import { - ApiClient, -} from './client/base.js'; -import { PipelineGrid } from './client/utils.js'; -import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js'; -import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types/model.js'; -import { ImageResponse, ReadyResponse, RetryParams } from './types/api.js'; -import { - BaseImgParams, - BlendParams, - BrushParams, - HighresParams, - Img2ImgParams, - InpaintParams, - ModelParams, - OutpaintPixels, - Txt2ImgParams, - UpscaleParams, - UpscaleReqParams, -} from './types/params.js'; - -export const MISSING_INDEX = -1; - -export type Theme = PaletteMode | ''; // tri-state, '' is unset - -/** - * Combine optional files and required ranges. - */ -export type TabState = ConfigFiles> & ConfigState>; - -export interface HistoryItem { - image: ImageResponse; - ready: Maybe; - retry: Maybe; -} - -export interface ProfileItem { - name: string; - params: BaseImgParams | Txt2ImgParams; - highres?: Maybe; - upscale?: Maybe; -} - -interface DefaultSlice { - defaults: TabState; - theme: Theme; - - setDefaults(param: Partial): void; - setTheme(theme: Theme): void; -} - -interface HistorySlice { - history: Array; - limit: number; - - pushHistory(image: ImageResponse, retry?: RetryParams): void; - removeHistory(image: ImageResponse): void; - setLimit(limit: number): void; - setReady(image: ImageResponse, ready: ReadyResponse): void; -} - -interface ModelSlice { - extras: ExtrasFile; - - removeCorrectionModel(model: CorrectionModel): void; - removeDiffusionModel(model: DiffusionModel): void; - removeExtraNetwork(model: ExtraNetwork): void; - removeExtraSource(model: ExtraSource): void; - removeUpscalingModel(model: UpscalingModel): void; - - setExtras(extras: Partial): void; - - setCorrectionModel(model: CorrectionModel): void; - setDiffusionModel(model: DiffusionModel): void; - setExtraNetwork(model: ExtraNetwork): void; - setExtraSource(model: ExtraSource): void; - setUpscalingModel(model: UpscalingModel): void; -} - -// #region tab slices -interface Txt2ImgSlice { - txt2img: TabState; - txt2imgModel: ModelParams; - txt2imgHighres: HighresParams; - txt2imgUpscale: UpscaleParams; - txt2imgVariable: PipelineGrid; - - resetTxt2Img(): void; - - setTxt2Img(params: Partial): void; - setTxt2ImgModel(params: Partial): void; - setTxt2ImgHighres(params: Partial): void; - setTxt2ImgUpscale(params: Partial): void; - setTxt2ImgVariable(params: Partial): void; -} - -interface Img2ImgSlice { - img2img: TabState; - img2imgModel: ModelParams; - img2imgHighres: HighresParams; - img2imgUpscale: UpscaleParams; - - resetImg2Img(): void; - - setImg2Img(params: Partial): void; - setImg2ImgModel(params: Partial): void; - setImg2ImgHighres(params: Partial): void; - setImg2ImgUpscale(params: Partial): void; -} - -interface InpaintSlice { - inpaint: TabState; - inpaintBrush: BrushParams; - inpaintModel: ModelParams; - inpaintHighres: HighresParams; - inpaintUpscale: UpscaleParams; - outpaint: OutpaintPixels; - - resetInpaint(): void; - - setInpaint(params: Partial): void; - setInpaintBrush(brush: Partial): void; - setInpaintModel(params: Partial): void; - setInpaintHighres(params: Partial): void; - setInpaintUpscale(params: Partial): void; - setOutpaint(pixels: Partial): void; -} - -interface UpscaleSlice { - upscale: TabState; - upscaleHighres: HighresParams; - upscaleModel: ModelParams; - upscaleUpscale: UpscaleParams; - - resetUpscale(): void; - - setUpscale(params: Partial): void; - setUpscaleHighres(params: Partial): void; - setUpscaleModel(params: Partial): void; - setUpscaleUpscale(params: Partial): void; -} - -interface BlendSlice { - blend: TabState; - blendBrush: BrushParams; - blendModel: ModelParams; - blendUpscale: UpscaleParams; - - resetBlend(): void; - - setBlend(blend: Partial): void; - setBlendBrush(brush: Partial): void; - setBlendModel(model: Partial): void; - setBlendUpscale(params: Partial): void; -} - -interface ResetSlice { - resetAll(): void; -} - -interface ProfileSlice { - profiles: Array; - - removeProfile(profileName: string): void; - - saveProfile(profile: ProfileItem): void; -} -// #endregion - -/** - * Full merged state including all slices. - */ -export type OnnxState - = DefaultSlice - & HistorySlice - & Img2ImgSlice - & InpaintSlice - & ModelSlice - & Txt2ImgSlice - & UpscaleSlice - & BlendSlice - & ResetSlice - & ModelSlice - & ProfileSlice; - -/** - * Shorthand for state creator to reduce repeated arguments. - */ -export type Slice = StateCreator; - -/** - * React context binding for API client. - */ -export const ClientContext = createContext>(undefined); - -/** - * React context binding for merged config, including server parameters. - */ -export const ConfigContext = createContext>>(undefined); - -/** - * React context binding for bunyan logger. - */ -export const LoggerContext = createContext>(undefined); - -/** - * React context binding for zustand state store. - */ -export const StateContext = createContext>>(undefined); - -/** - * Key for zustand persistence, typically local storage. - */ -export const STATE_KEY = 'onnx-web'; - -/** - * Current state version for zustand persistence. - */ -export const STATE_VERSION = 7; - -export const BLEND_SOURCES = 2; - -/** - * Default parameters for the inpaint brush. - * - * Not provided by the server yet. - */ -export const DEFAULT_BRUSH = { - color: 255, - size: 8, - strength: 0.5, -}; - -/** - * Default parameters for the image history. - * - * Not provided by the server yet. - */ -export const DEFAULT_HISTORY = { - /** - * The number of images to be shown. - */ - limit: 4, - - /** - * The number of additional images to be kept in history, so they can scroll - * back into view when you delete one. Does not include deleted images. - */ - scrollback: 2, -}; - -export function baseParamsFromServer(defaults: ServerParams): Required { - return { - batch: defaults.batch.default, - cfg: defaults.cfg.default, - eta: defaults.eta.default, - negativePrompt: defaults.negativePrompt.default, - prompt: defaults.prompt.default, - scheduler: defaults.scheduler.default, - steps: defaults.steps.default, - seed: defaults.seed.default, - tiled_vae: defaults.tiled_vae.default, - unet_overlap: defaults.unet_overlap.default, - unet_tile: defaults.unet_tile.default, - vae_overlap: defaults.vae_overlap.default, - vae_tile: defaults.vae_tile.default, - }; -} - -/** - * Prepare the state slice constructors. - * - * In the default state, image sources should be null and booleans should be false. Everything - * else should be initialized from the default value in the base parameters. - */ -export function createStateSlices(server: ServerParams) { - const defaultParams = baseParamsFromServer(server); - const defaultHighres: HighresParams = { - enabled: false, - highresIterations: server.highresIterations.default, - highresMethod: '', - highresSteps: server.highresSteps.default, - highresScale: server.highresScale.default, - highresStrength: server.highresStrength.default, - }; - const defaultModel: ModelParams = { - control: server.control.default, - correction: server.correction.default, - model: server.model.default, - pipeline: server.pipeline.default, - platform: server.platform.default, - upscaling: server.upscaling.default, - }; - const defaultUpscale: UpscaleParams = { - denoise: server.denoise.default, - enabled: false, - faces: false, - faceOutscale: server.faceOutscale.default, - faceStrength: server.faceStrength.default, - outscale: server.outscale.default, - scale: server.scale.default, - upscaleOrder: server.upscaleOrder.default, - }; - const defaultGrid: PipelineGrid = { - enabled: false, - columns: { - parameter: 'seed', - value: '', - }, - rows: { - parameter: 'seed', - value: '', - }, - }; - - const createTxt2ImgSlice: Slice = (set) => ({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - txt2imgHighres: { - ...defaultHighres, - }, - txt2imgModel: { - ...defaultModel, - }, - txt2imgUpscale: { - ...defaultUpscale, - }, - txt2imgVariable: { - ...defaultGrid, - }, - setTxt2Img(params) { - set((prev) => ({ - txt2img: { - ...prev.txt2img, - ...params, - }, - })); - }, - setTxt2ImgHighres(params) { - set((prev) => ({ - txt2imgHighres: { - ...prev.txt2imgHighres, - ...params, - }, - })); - }, - setTxt2ImgModel(params) { - set((prev) => ({ - txt2imgModel: { - ...prev.txt2imgModel, - ...params, - }, - })); - }, - setTxt2ImgUpscale(params) { - set((prev) => ({ - txt2imgUpscale: { - ...prev.txt2imgUpscale, - ...params, - }, - })); - }, - setTxt2ImgVariable(params) { - set((prev) => ({ - txt2imgVariable: { - ...prev.txt2imgVariable, - ...params, - }, - })); - }, - resetTxt2Img() { - set({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - }); - }, - }); - - const createImg2ImgSlice: Slice = (set) => ({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - img2imgHighres: { - ...defaultHighres, - }, - img2imgModel: { - ...defaultModel, - }, - img2imgUpscale: { - ...defaultUpscale, - }, - resetImg2Img() { - set({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - }); - }, - setImg2Img(params) { - set((prev) => ({ - img2img: { - ...prev.img2img, - ...params, - }, - })); - }, - setImg2ImgHighres(params) { - set((prev) => ({ - img2imgHighres: { - ...prev.img2imgHighres, - ...params, - }, - })); - }, - setImg2ImgModel(params) { - set((prev) => ({ - img2imgModel: { - ...prev.img2imgModel, - ...params, - }, - })); - }, - setImg2ImgUpscale(params) { - set((prev) => ({ - img2imgUpscale: { - ...prev.img2imgUpscale, - ...params, - }, - })); - }, - }); - - const createInpaintSlice: Slice = (set) => ({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - inpaintBrush: { - ...DEFAULT_BRUSH, - }, - inpaintHighres: { - ...defaultHighres, - }, - inpaintModel: { - ...defaultModel, - }, - inpaintUpscale: { - ...defaultUpscale, - }, - outpaint: { - enabled: false, - left: server.left.default, - right: server.right.default, - top: server.top.default, - bottom: server.bottom.default, - }, - resetInpaint() { - set({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - }); - }, - setInpaint(params) { - set((prev) => ({ - inpaint: { - ...prev.inpaint, - ...params, - }, - })); - }, - setInpaintBrush(brush) { - set((prev) => ({ - inpaintBrush: { - ...prev.inpaintBrush, - ...brush, - }, - })); - }, - setInpaintHighres(params) { - set((prev) => ({ - inpaintHighres: { - ...prev.inpaintHighres, - ...params, - }, - })); - }, - setInpaintModel(params) { - set((prev) => ({ - inpaintModel: { - ...prev.inpaintModel, - ...params, - }, - })); - }, - setInpaintUpscale(params) { - set((prev) => ({ - inpaintUpscale: { - ...prev.inpaintUpscale, - ...params, - }, - })); - }, - setOutpaint(pixels) { - set((prev) => ({ - outpaint: { - ...prev.outpaint, - ...pixels, - } - })); - }, - }); - - const createHistorySlice: Slice = (set) => ({ - history: [], - limit: DEFAULT_HISTORY.limit, - pushHistory(image, retry) { - set((prev) => ({ - ...prev, - history: [ - { - image, - ready: undefined, - retry, - }, - ...prev.history, - ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), - })); - }, - removeHistory(image) { - set((prev) => ({ - ...prev, - history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), - })); - }, - setLimit(limit) { - set((prev) => ({ - ...prev, - limit, - })); - }, - setReady(image, ready) { - set((prev) => { - const history = [...prev.history]; - const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); - if (idx >= 0) { - history[idx].ready = ready; - } else { - // TODO: error - } - - return { - ...prev, - history, - }; - }); - }, - }); - - const createUpscaleSlice: Slice = (set) => ({ - upscale: { - ...defaultParams, - source: null, - }, - upscaleHighres: { - ...defaultHighres, - }, - upscaleModel: { - ...defaultModel, - }, - upscaleUpscale: { - ...defaultUpscale, - }, - resetUpscale() { - set({ - upscale: { - ...defaultParams, - source: null, - }, - }); - }, - setUpscale(source) { - set((prev) => ({ - upscale: { - ...prev.upscale, - ...source, - }, - })); - }, - setUpscaleHighres(params) { - set((prev) => ({ - upscaleHighres: { - ...prev.upscaleHighres, - ...params, - }, - })); - }, - setUpscaleModel(params) { - set((prev) => ({ - upscaleModel: { - ...prev.upscaleModel, - ...defaultModel, - }, - })); - }, - setUpscaleUpscale(params) { - set((prev) => ({ - upscaleUpscale: { - ...prev.upscaleUpscale, - ...params, - }, - })); - }, - }); - - const createBlendSlice: Slice = (set) => ({ - blend: { - mask: null, - sources: [], - }, - blendBrush: { - ...DEFAULT_BRUSH, - }, - blendModel: { - ...defaultModel, - }, - blendUpscale: { - ...defaultUpscale, - }, - resetBlend() { - set({ - blend: { - mask: null, - sources: [], - }, - }); - }, - setBlend(blend) { - set((prev) => ({ - blend: { - ...prev.blend, - ...blend, - }, - })); - }, - setBlendBrush(brush) { - set((prev) => ({ - blendBrush: { - ...prev.blendBrush, - ...brush, - }, - })); - }, - setBlendModel(model) { - set((prev) => ({ - blendModel: { - ...prev.blendModel, - ...model, - }, - })); - }, - setBlendUpscale(params) { - set((prev) => ({ - blendUpscale: { - ...prev.blendUpscale, - ...params, - }, - })); - }, - }); - - const createDefaultSlice: Slice = (set) => ({ - defaults: { - ...defaultParams, - }, - theme: '', - setDefaults(params) { - set((prev) => ({ - defaults: { - ...prev.defaults, - ...params, - } - })); - }, - setTheme(theme) { - set((prev) => ({ - theme, - })); - } - }); - - const createResetSlice: Slice = (set) => ({ - resetAll() { - set((prev) => { - const next = { ...prev }; - next.resetImg2Img(); - next.resetInpaint(); - next.resetTxt2Img(); - next.resetUpscale(); - next.resetBlend(); - return next; - }); - }, - }); - - const createProfileSlice: Slice = (set) => ({ - profiles: [], - saveProfile(profile: ProfileItem) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profile.name); - if (idx >= 0) { - profiles[idx] = profile; - } else { - profiles.push(profile); - } - return { - ...prev, - profiles, - }; - }); - }, - removeProfile(profileName: string) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profileName); - if (idx >= 0) { - profiles.splice(idx, 1); - } - return { - ...prev, - profiles, - }; - }); - } - }); - - // eslint-disable-next-line sonarjs/cognitive-complexity - const createModelSlice: Slice = (set) => ({ - extras: { - correction: [], - diffusion: [], - networks: [], - sources: [], - upscaling: [], - }, - setExtras(extras) { - set((prev) => ({ - ...prev, - extras: { - ...prev.extras, - ...extras, - }, - })); - }, - setCorrectionModel(model) { - set((prev) => { - const correction = [...prev.extras.correction]; - const exists = correction.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - correction.push(model); - } else { - correction[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - }, - setDiffusionModel(model) { - set((prev) => { - const diffusion = [...prev.extras.diffusion]; - const exists = diffusion.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - diffusion.push(model); - } else { - diffusion[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - }, - setExtraNetwork(model) { - set((prev) => { - const networks = [...prev.extras.networks]; - const exists = networks.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - networks.push(model); - } else { - networks[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - }, - setExtraSource(model) { - set((prev) => { - const sources = [...prev.extras.sources]; - const exists = sources.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - sources.push(model); - } else { - sources[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - }, - setUpscalingModel(model) { - set((prev) => { - const upscaling = [...prev.extras.upscaling]; - const exists = upscaling.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - upscaling.push(model); - } else { - upscaling[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - removeCorrectionModel(model) { - set((prev) => { - const correction = prev.extras.correction.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - - }, - removeDiffusionModel(model) { - set((prev) => { - const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - - }, - removeExtraNetwork(model) { - set((prev) => { - const networks = prev.extras.networks.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - - }, - removeExtraSource(model) { - set((prev) => { - const sources = prev.extras.sources.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - - }, - removeUpscalingModel(model) { - set((prev) => { - const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - }); - - return { - createDefaultSlice, - createHistorySlice, - createImg2ImgSlice, - createInpaintSlice, - createTxt2ImgSlice, - createUpscaleSlice, - createBlendSlice, - createResetSlice, - createModelSlice, - createProfileSlice, - }; -} diff --git a/gui/src/state/blend.ts b/gui/src/state/blend.ts new file mode 100644 index 00000000..60cd5f88 --- /dev/null +++ b/gui/src/state/blend.ts @@ -0,0 +1,85 @@ +import { DEFAULT_BRUSH } from '../constants.js'; +import { + BlendParams, + BrushParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; +import { Slice, TabState } from './types.js'; + +export interface BlendSlice { + blend: TabState; + blendBrush: BrushParams; + blendModel: ModelParams; + blendUpscale: UpscaleParams; + + resetBlend(): void; + + setBlend(blend: Partial): void; + setBlendBrush(brush: Partial): void; + setBlendModel(model: Partial): void; + setBlendUpscale(params: Partial): void; +} + +export function createBlendSlice( + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [], + }, + blendBrush: { + ...DEFAULT_BRUSH, + }, + blendModel: { + ...defaultModel, + }, + blendUpscale: { + ...defaultUpscale, + }, + resetBlend() { + set((prev) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [] as Array, + }, + } as Partial)); + }, + setBlend(blend) { + set((prev) => ({ + blend: { + ...prev.blend, + ...blend, + }, + } as Partial)); + }, + setBlendBrush(brush) { + set((prev) => ({ + blendBrush: { + ...prev.blendBrush, + ...brush, + }, + } as Partial)); + }, + setBlendModel(model) { + set((prev) => ({ + blendModel: { + ...prev.blendModel, + ...model, + }, + } as Partial)); + }, + setBlendUpscale(params) { + set((prev) => ({ + blendUpscale: { + ...prev.blendUpscale, + ...params, + }, + } as Partial)); + }, + }); +} diff --git a/gui/src/state/default.ts b/gui/src/state/default.ts new file mode 100644 index 00000000..c578263b --- /dev/null +++ b/gui/src/state/default.ts @@ -0,0 +1,34 @@ +import { + BaseImgParams, +} from '../types/params.js'; +import { Slice, TabState, Theme } from './types.js'; + +export interface DefaultSlice { + defaults: TabState; + theme: Theme; + + setDefaults(param: Partial): void; + setTheme(theme: Theme): void; +} + +export function createDefaultSlice(defaultParams: Required): Slice { + return (set) => ({ + defaults: { + ...defaultParams, + }, + theme: '', + setDefaults(params) { + set((prev) => ({ + defaults: { + ...prev.defaults, + ...params, + } + } as Partial)); + }, + setTheme(theme) { + set((prev) => ({ + theme, + } as Partial)); + } + }); +} diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts new file mode 100644 index 00000000..990161ae --- /dev/null +++ b/gui/src/state/full.ts @@ -0,0 +1,150 @@ +/* eslint-disable camelcase */ +import { Maybe } from '@apextoaster/js-utils'; +import { Logger } from 'noicejs'; +import { createContext } from 'react'; +import { StoreApi } from 'zustand'; + +import { + ApiClient, +} from '../client/base.js'; +import { PipelineGrid } from '../client/utils.js'; +import { Config, ServerParams } from '../config.js'; +import { BlendSlice, createBlendSlice } from './blend.js'; +import { DefaultSlice, createDefaultSlice } from './default.js'; +import { HistorySlice, createHistorySlice } from './history.js'; +import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js'; +import { InpaintSlice, createInpaintSlice } from './inpaint.js'; +import { ModelSlice, createModelSlice } from './model.js'; +import { ProfileSlice, createProfileSlice } from './profile.js'; +import { ResetSlice, createResetSlice } from './reset.js'; +import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js'; +import { UpscaleSlice, createUpscaleSlice } from './upscale.js'; +import { + BaseImgParams, + HighresParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; + +/** + * Full merged state including all slices. + */ +export type OnnxState + = DefaultSlice + & HistorySlice + & Img2ImgSlice + & InpaintSlice + & ModelSlice + & Txt2ImgSlice + & UpscaleSlice + & BlendSlice + & ResetSlice + & ProfileSlice; + +/** + * React context binding for API client. + */ +export const ClientContext = createContext>(undefined); + +/** + * React context binding for merged config, including server parameters. + */ +export const ConfigContext = createContext>>(undefined); + +/** + * React context binding for bunyan logger. + */ +export const LoggerContext = createContext>(undefined); + +/** + * React context binding for zustand state store. + */ +export const StateContext = createContext>>(undefined); + +/** + * Key for zustand persistence, typically local storage. + */ +export const STATE_KEY = 'onnx-web'; + +/** + * Current state version for zustand persistence. + */ +export const STATE_VERSION = 11; + +export function baseParamsFromServer(defaults: ServerParams): Required { + return { + batch: defaults.batch.default, + cfg: defaults.cfg.default, + eta: defaults.eta.default, + negativePrompt: defaults.negativePrompt.default, + prompt: defaults.prompt.default, + scheduler: defaults.scheduler.default, + steps: defaults.steps.default, + seed: defaults.seed.default, + tiled_vae: defaults.tiled_vae.default, + unet_overlap: defaults.unet_overlap.default, + unet_tile: defaults.unet_tile.default, + vae_overlap: defaults.vae_overlap.default, + vae_tile: defaults.vae_tile.default, + }; +} + +/** + * Prepare the state slice constructors. + * + * In the default state, image sources should be null and booleans should be false. Everything + * else should be initialized from the default value in the base parameters. + */ +export function createStateSlices(server: ServerParams) { + const defaultParams = baseParamsFromServer(server); + const defaultHighres: HighresParams = { + enabled: false, + highresIterations: server.highresIterations.default, + highresMethod: '', + highresSteps: server.highresSteps.default, + highresScale: server.highresScale.default, + highresStrength: server.highresStrength.default, + }; + const defaultModel: ModelParams = { + control: server.control.default, + correction: server.correction.default, + model: server.model.default, + pipeline: server.pipeline.default, + platform: server.platform.default, + upscaling: server.upscaling.default, + }; + const defaultUpscale: UpscaleParams = { + denoise: server.denoise.default, + enabled: false, + faces: false, + faceOutscale: server.faceOutscale.default, + faceStrength: server.faceStrength.default, + outscale: server.outscale.default, + scale: server.scale.default, + upscaleOrder: server.upscaleOrder.default, + }; + const defaultGrid: PipelineGrid = { + enabled: false, + columns: { + parameter: 'seed', + value: '', + }, + rows: { + parameter: 'seed', + value: '', + }, + }; + + return { + createBlendSlice: createBlendSlice(defaultModel, defaultUpscale), + createDefaultSlice: createDefaultSlice(defaultParams), + createHistorySlice: createHistorySlice(), + createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createModelSlice: createModelSlice(), + createProfileSlice: createProfileSlice(), + createResetSlice: createResetSlice(), + createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid), + createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale), + }; +} diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts new file mode 100644 index 00000000..de71ef72 --- /dev/null +++ b/gui/src/state/history.ts @@ -0,0 +1,68 @@ +import { Maybe } from '@apextoaster/js-utils'; +import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; +import { Slice } from './types.js'; +import { DEFAULT_HISTORY } from '../constants.js'; + +export interface HistoryItem { + image: ImageResponse; + ready: Maybe; + retry: Maybe; +} + +export interface HistorySlice { + history: Array; + limit: number; + + pushHistory(image: ImageResponse, retry?: RetryParams): void; + removeHistory(image: ImageResponse): void; + setLimit(limit: number): void; + setReady(image: ImageResponse, ready: ReadyResponse): void; +} + +export function createHistorySlice(): Slice { + return (set) => ({ + history: [], + limit: DEFAULT_HISTORY.limit, + pushHistory(image, retry) { + set((prev) => ({ + ...prev, + history: [ + { + image, + ready: undefined, + retry, + }, + ...prev.history, + ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), + })); + }, + removeHistory(image) { + set((prev) => ({ + ...prev, + history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), + })); + }, + setLimit(limit) { + set((prev) => ({ + ...prev, + limit, + })); + }, + setReady(image, ready) { + set((prev) => { + const history = [...prev.history]; + const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); + if (idx >= 0) { + history[idx].ready = ready; + } else { + // TODO: error + } + + return { + ...prev, + history, + }; + }); + }, + }); +} diff --git a/gui/src/state/img2img.ts b/gui/src/state/img2img.ts new file mode 100644 index 00000000..b8f986b0 --- /dev/null +++ b/gui/src/state/img2img.ts @@ -0,0 +1,98 @@ + +import { ServerParams } from '../config.js'; +import { + BaseImgParams, + HighresParams, + Img2ImgParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; +import { Slice, TabState } from './types.js'; + +export interface Img2ImgSlice { + img2img: TabState; + img2imgModel: ModelParams; + img2imgHighres: HighresParams; + img2imgUpscale: UpscaleParams; + + resetImg2Img(): void; + + setImg2Img(params: Partial): void; + setImg2ImgModel(params: Partial): void; + setImg2ImgHighres(params: Partial): void; + setImg2ImgUpscale(params: Partial): void; +} + +// eslint-disable-next-line max-params +export function createImg2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams +): Slice { + return (set) => ({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + img2imgHighres: { + ...defaultHighres, + }, + img2imgModel: { + ...defaultModel, + }, + img2imgUpscale: { + ...defaultUpscale, + }, + resetImg2Img() { + set({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + } as Partial); + }, + setImg2Img(params) { + set((prev) => ({ + img2img: { + ...prev.img2img, + ...params, + }, + } as Partial)); + }, + setImg2ImgHighres(params) { + set((prev) => ({ + img2imgHighres: { + ...prev.img2imgHighres, + ...params, + }, + } as Partial)); + }, + setImg2ImgModel(params) { + set((prev) => ({ + img2imgModel: { + ...prev.img2imgModel, + ...params, + }, + } as Partial)); + }, + setImg2ImgUpscale(params) { + set((prev) => ({ + img2imgUpscale: { + ...prev.img2imgUpscale, + ...params, + }, + } as Partial)); + }, + }); + +} diff --git a/gui/src/state/inpaint.ts b/gui/src/state/inpaint.ts new file mode 100644 index 00000000..7dab5af3 --- /dev/null +++ b/gui/src/state/inpaint.ts @@ -0,0 +1,136 @@ +import { ServerParams } from '../config.js'; +import { DEFAULT_BRUSH } from '../constants.js'; +import { + BaseImgParams, + BrushParams, + HighresParams, + InpaintParams, + ModelParams, + OutpaintPixels, + UpscaleParams, +} from '../types/params.js'; +import { Slice, TabState } from './types.js'; +export interface InpaintSlice { + inpaint: TabState; + inpaintBrush: BrushParams; + inpaintModel: ModelParams; + inpaintHighres: HighresParams; + inpaintUpscale: UpscaleParams; + outpaint: OutpaintPixels; + + resetInpaint(): void; + + setInpaint(params: Partial): void; + setInpaintBrush(brush: Partial): void; + setInpaintModel(params: Partial): void; + setInpaintHighres(params: Partial): void; + setInpaintUpscale(params: Partial): void; + setOutpaint(pixels: Partial): void; +} + +// eslint-disable-next-line max-params +export function createInpaintSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + inpaintBrush: { + ...DEFAULT_BRUSH, + }, + inpaintHighres: { + ...defaultHighres, + }, + inpaintModel: { + ...defaultModel, + }, + inpaintUpscale: { + ...defaultUpscale, + }, + outpaint: { + enabled: false, + left: server.left.default, + right: server.right.default, + top: server.top.default, + bottom: server.bottom.default, + }, + resetInpaint() { + set({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + } as Partial); + }, + setInpaint(params) { + set((prev) => ({ + inpaint: { + ...prev.inpaint, + ...params, + }, + } as Partial)); + }, + setInpaintBrush(brush) { + set((prev) => ({ + inpaintBrush: { + ...prev.inpaintBrush, + ...brush, + }, + } as Partial)); + }, + setInpaintHighres(params) { + set((prev) => ({ + inpaintHighres: { + ...prev.inpaintHighres, + ...params, + }, + } as Partial)); + }, + setInpaintModel(params) { + set((prev) => ({ + inpaintModel: { + ...prev.inpaintModel, + ...params, + }, + } as Partial)); + }, + setInpaintUpscale(params) { + set((prev) => ({ + inpaintUpscale: { + ...prev.inpaintUpscale, + ...params, + }, + } as Partial)); + }, + setOutpaint(pixels) { + set((prev) => ({ + outpaint: { + ...prev.outpaint, + ...pixels, + } + } as Partial)); + }, + }); +} diff --git a/gui/src/state/migration/default.ts b/gui/src/state/migration/default.ts new file mode 100644 index 00000000..65a4353b --- /dev/null +++ b/gui/src/state/migration/default.ts @@ -0,0 +1,93 @@ +/* eslint-disable camelcase */ +import { Logger } from 'browser-bunyan'; +import { ServerParams } from '../../config.js'; +import { BaseImgParams } from '../../types/params.js'; +import { OnnxState, STATE_VERSION } from '../full.js'; +import { Img2ImgSlice } from '../img2img.js'; +import { InpaintSlice } from '../inpaint.js'; +import { Txt2ImgSlice } from '../txt2img.js'; +import { UpscaleSlice } from '../upscale.js'; + +// #region V7 +export const V7 = 7; + +export type BaseImgParamsV7 = Omit & { + overlap: number; + tile: number; +}; + +export type OnnxStateV7 = Omit & { + img2img: BaseImgParamsV7; + inpaint: BaseImgParamsV7; + txt2img: BaseImgParamsV7; + upscale: BaseImgParamsV7; +}; +// #endregion + +// #region V11 +export const REMOVED_KEYS_V11 = ['tile', 'overlap'] as const; + +export type RemovedKeysV11 = typeof REMOVED_KEYS_V11[number]; + +// TODO: can the compiler calculate this? +export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap'; +// #endregion + +// add versions to this list as they are replaced +export type PreviousState = OnnxStateV7; + +// always the latest version +export type CurrentState = OnnxState; + +// any version of state +export type UnknownState = PreviousState | CurrentState; + +export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number, logger: Logger): OnnxState { + logger.info('applying state migrations from version %s to version %s', version, STATE_VERSION); + + if (version <= V7) { + return migrateV7ToV11(params, previousState as PreviousState); + } + + return previousState as CurrentState; +} + +export function migrateV7ToV11(params: ServerParams, previousState: PreviousState): CurrentState { + // add any missing keys + const result: CurrentState = { + ...params, + ...previousState, + img2img: { + ...previousState.img2img, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + inpaint: { + ...previousState.inpaint, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + txt2img: { + ...previousState.txt2img, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + upscale: { + ...previousState.upscale, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + }; + + // TODO: remove extra keys + + return result; +} diff --git a/gui/src/state/model.ts b/gui/src/state/model.ts new file mode 100644 index 00000000..3a473182 --- /dev/null +++ b/gui/src/state/model.ts @@ -0,0 +1,202 @@ +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; +import { MISSING_INDEX, Slice } from './types.js'; + +export interface ModelSlice { + extras: ExtrasFile; + + removeCorrectionModel(model: CorrectionModel): void; + removeDiffusionModel(model: DiffusionModel): void; + removeExtraNetwork(model: ExtraNetwork): void; + removeExtraSource(model: ExtraSource): void; + removeUpscalingModel(model: UpscalingModel): void; + + setExtras(extras: Partial): void; + + setCorrectionModel(model: CorrectionModel): void; + setDiffusionModel(model: DiffusionModel): void; + setExtraNetwork(model: ExtraNetwork): void; + setExtraSource(model: ExtraSource): void; + setUpscalingModel(model: UpscalingModel): void; +} + +// eslint-disable-next-line sonarjs/cognitive-complexity +export function createModelSlice(): Slice { + // eslint-disable-next-line sonarjs/cognitive-complexity + return (set) => ({ + extras: { + correction: [], + diffusion: [], + networks: [], + sources: [], + upscaling: [], + }, + setExtras(extras) { + set((prev) => ({ + ...prev, + extras: { + ...prev.extras, + ...extras, + }, + })); + }, + setCorrectionModel(model) { + set((prev) => { + const correction = [...prev.extras.correction]; + const exists = correction.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + correction.push(model); + } else { + correction[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + }, + setDiffusionModel(model) { + set((prev) => { + const diffusion = [...prev.extras.diffusion]; + const exists = diffusion.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + diffusion.push(model); + } else { + diffusion[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + }, + setExtraNetwork(model) { + set((prev) => { + const networks = [...prev.extras.networks]; + const exists = networks.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + networks.push(model); + } else { + networks[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + }, + setExtraSource(model) { + set((prev) => { + const sources = [...prev.extras.sources]; + const exists = sources.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + sources.push(model); + } else { + sources[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + }, + setUpscalingModel(model) { + set((prev) => { + const upscaling = [...prev.extras.upscaling]; + const exists = upscaling.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + upscaling.push(model); + } else { + upscaling[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + removeCorrectionModel(model) { + set((prev) => { + const correction = prev.extras.correction.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + + }, + removeDiffusionModel(model) { + set((prev) => { + const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + + }, + removeExtraNetwork(model) { + set((prev) => { + const networks = prev.extras.networks.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + + }, + removeExtraSource(model) { + set((prev) => { + const sources = prev.extras.sources.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + + }, + removeUpscalingModel(model) { + set((prev) => { + const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + }); +} diff --git a/gui/src/state/profile.ts b/gui/src/state/profile.ts new file mode 100644 index 00000000..73d52eab --- /dev/null +++ b/gui/src/state/profile.ts @@ -0,0 +1,52 @@ +import { Maybe } from '@apextoaster/js-utils'; +import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; +import { Slice } from './types.js'; + +export interface ProfileItem { + name: string; + params: BaseImgParams | Txt2ImgParams; + highres?: Maybe; + upscale?: Maybe; +} + +export interface ProfileSlice { + profiles: Array; + + removeProfile(profileName: string): void; + + saveProfile(profile: ProfileItem): void; +} + +export function createProfileSlice(): Slice { + return (set) => ({ + profiles: [], + saveProfile(profile: ProfileItem) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profile.name); + if (idx >= 0) { + profiles[idx] = profile; + } else { + profiles.push(profile); + } + return { + ...prev, + profiles, + }; + }); + }, + removeProfile(profileName: string) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profileName); + if (idx >= 0) { + profiles.splice(idx, 1); + } + return { + ...prev, + profiles, + }; + }); + } + }); +} diff --git a/gui/src/state/reset.ts b/gui/src/state/reset.ts new file mode 100644 index 00000000..53272e5c --- /dev/null +++ b/gui/src/state/reset.ts @@ -0,0 +1,28 @@ +import { BlendSlice } from './blend.js'; +import { Img2ImgSlice } from './img2img.js'; +import { InpaintSlice } from './inpaint.js'; +import { Txt2ImgSlice } from './txt2img.js'; +import { Slice } from './types.js'; +import { UpscaleSlice } from './upscale.js'; + +export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice; + +export interface ResetSlice { + resetAll(): void; +} + +export function createResetSlice(): Slice { + return (set) => ({ + resetAll() { + set((prev) => { + const next = { ...prev }; + next.resetImg2Img(); + next.resetInpaint(); + next.resetTxt2Img(); + next.resetUpscale(); + next.resetBlend(); + return next; + }); + }, + }); +} diff --git a/gui/src/state/txt2img.ts b/gui/src/state/txt2img.ts new file mode 100644 index 00000000..8bec3273 --- /dev/null +++ b/gui/src/state/txt2img.ts @@ -0,0 +1,105 @@ +import { PipelineGrid } from '../client/utils.js'; +import { ServerParams } from '../config.js'; +import { + BaseImgParams, + HighresParams, + ModelParams, + Txt2ImgParams, + UpscaleParams, +} from '../types/params.js'; +import { Slice, TabState } from './types.js'; + +export interface Txt2ImgSlice { + txt2img: TabState; + txt2imgModel: ModelParams; + txt2imgHighres: HighresParams; + txt2imgUpscale: UpscaleParams; + txt2imgVariable: PipelineGrid; + + resetTxt2Img(): void; + + setTxt2Img(params: Partial): void; + setTxt2ImgModel(params: Partial): void; + setTxt2ImgHighres(params: Partial): void; + setTxt2ImgUpscale(params: Partial): void; + setTxt2ImgVariable(params: Partial): void; +} + +// eslint-disable-next-line max-params +export function createTxt2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, + defaultGrid: PipelineGrid, +): Slice { + return (set) => ({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + txt2imgHighres: { + ...defaultHighres, + }, + txt2imgModel: { + ...defaultModel, + }, + txt2imgUpscale: { + ...defaultUpscale, + }, + txt2imgVariable: { + ...defaultGrid, + }, + setTxt2Img(params) { + set((prev) => ({ + txt2img: { + ...prev.txt2img, + ...params, + }, + } as Partial)); + }, + setTxt2ImgHighres(params) { + set((prev) => ({ + txt2imgHighres: { + ...prev.txt2imgHighres, + ...params, + }, + } as Partial)); + }, + setTxt2ImgModel(params) { + set((prev) => ({ + txt2imgModel: { + ...prev.txt2imgModel, + ...params, + }, + } as Partial)); + }, + setTxt2ImgUpscale(params) { + set((prev) => ({ + txt2imgUpscale: { + ...prev.txt2imgUpscale, + ...params, + }, + } as Partial)); + }, + setTxt2ImgVariable(params) { + set((prev) => ({ + txt2imgVariable: { + ...prev.txt2imgVariable, + ...params, + }, + } as Partial)); + }, + resetTxt2Img() { + set({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + } as Partial); + }, + }); +} diff --git a/gui/src/state/types.ts b/gui/src/state/types.ts new file mode 100644 index 00000000..b1e9d3ad --- /dev/null +++ b/gui/src/state/types.ts @@ -0,0 +1,17 @@ +import { PaletteMode } from '@mui/material'; +import { StateCreator } from 'zustand'; +import { ConfigFiles, ConfigState } from '../config.js'; + +export const MISSING_INDEX = -1; + +export type Theme = PaletteMode | ''; // tri-state, '' is unset + +/** + * Combine optional files and required ranges. + */ +export type TabState = ConfigFiles> & ConfigState>; + +/** + * Shorthand for state creator to reduce repeated arguments. + */ +export type Slice = StateCreator; diff --git a/gui/src/state/upscale.ts b/gui/src/state/upscale.ts new file mode 100644 index 00000000..e78d689a --- /dev/null +++ b/gui/src/state/upscale.ts @@ -0,0 +1,87 @@ +import { + BaseImgParams, + HighresParams, + ModelParams, + UpscaleParams, + UpscaleReqParams, +} from '../types/params.js'; +import { Slice, TabState } from './types.js'; + +export interface UpscaleSlice { + upscale: TabState; + upscaleHighres: HighresParams; + upscaleModel: ModelParams; + upscaleUpscale: UpscaleParams; + + resetUpscale(): void; + + setUpscale(params: Partial): void; + setUpscaleHighres(params: Partial): void; + setUpscaleModel(params: Partial): void; + setUpscaleUpscale(params: Partial): void; +} + +export function createUpscaleSlice( + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + upscaleHighres: { + ...defaultHighres, + }, + upscaleModel: { + ...defaultModel, + }, + upscaleUpscale: { + ...defaultUpscale, + }, + resetUpscale() { + set({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + } as Partial); + }, + setUpscale(source) { + set((prev) => ({ + upscale: { + ...prev.upscale, + ...source, + }, + } as Partial)); + }, + setUpscaleHighres(params) { + set((prev) => ({ + upscaleHighres: { + ...prev.upscaleHighres, + ...params, + }, + } as Partial)); + }, + setUpscaleModel(params) { + set((prev) => ({ + upscaleModel: { + ...prev.upscaleModel, + ...defaultModel, + }, + } as Partial)); + }, + setUpscaleUpscale(params) { + set((prev) => ({ + upscaleUpscale: { + ...prev.upscaleUpscale, + ...params, + }, + } as Partial)); + }, + }); +} diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index c9d901f9..304c3335 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -287,6 +287,7 @@ export const I18N_STRINGS_EN = { 'ddpm': 'DDPM', 'deis-multi': 'DEIS Multistep', 'dpm-multi': 'DPM Multistep', + 'dpm-sde': 'DPM SDE (Turbo)', 'dpm-single': 'DPM Singlestep', 'euler': 'Euler', 'euler-a': 'Euler Ancestral',