From 46d9fc0dd40a2b1c657dd7733b30ffa11101bb01 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 8 Dec 2023 18:49:18 -0600 Subject: [PATCH 01/42] fix(api): make sure diffusion models have a valid prefix --- api/onnx_web/convert/__main__.py | 20 +++++++++++++++++++ .../convert/diffusion/diffusion_xl.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 5cbe7f07..fb302ae9 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -265,6 +265,20 @@ def fetch_model( return source, False +DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-"] + + +def fix_diffusion_name(name: str): + if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): + logger.warning( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + return f"diffusion-{name}" + + return name + + def convert_models(conversion: ConversionContext, args, models: Models): model_errors = [] @@ -351,6 +365,12 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: + # fix up entries with missing prefixes + name = fix_diffusion_name(name) + if name != model["name"]: + # update the model in-memory if the name changed + model["name"] = name + model_format = source_format(model) try: diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 7a03b700..18fa8493 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -64,7 +64,7 @@ def convert_diffusion_diffusers_xl( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(replace_vae, RESOLVE_FORMATS): + if check_ext(vae_path, RESOLVE_FORMATS): logger.debug("loading VAE from single tensor file: %s", vae_path) pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: From 293a1bb18462e8ad475fa2d3f80210aaa5dd1540 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 8 Dec 2023 22:26:19 -0600 Subject: [PATCH 02/42] fix(api): include SD upscaling in diffusion prefixes --- api/onnx_web/convert/__main__.py | 15 +-------------- api/onnx_web/convert/utils.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index fb302ae9..262bbf44 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -26,6 +26,7 @@ from .utils import ( DEFAULT_OPSET, ConversionContext, download_progress, + fix_diffusion_name, remove_prefix, source_format, tuple_to_correction, @@ -265,20 +266,6 @@ def fetch_model( return source, False -DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-"] - - -def fix_diffusion_name(name: str): - if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): - logger.warning( - "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", - name, - ) - return f"diffusion-{name}" - - return name - - def convert_models(conversion: ConversionContext, args, models: Models): model_errors = [] diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index ef44ba20..e56e05ae 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -345,3 +345,17 @@ def onnx_export( all_tensors_to_one_file=True, location=ONNX_WEIGHTS, ) + + +DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"] + + +def fix_diffusion_name(name: str): + if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): + logger.warning( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + return f"diffusion-{name}" + + return name From 4fd50984f00647605e1bcf1dfc860a513ad307c4 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 8 Dec 2023 22:55:51 -0600 Subject: [PATCH 03/42] fix(api): correct VAE extension check during conversion --- api/onnx_web/convert/diffusion/diffusion.py | 3 ++- api/onnx_web/convert/diffusion/diffusion_xl.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index a8ecbbf7..ade270c0 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -381,7 +381,8 @@ def convert_diffusion_diffusers( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(replace_vae, RESOLVE_FORMATS): + vae_file = check_ext(replace_vae, RESOLVE_FORMATS) + if vae_file[0]: pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: pipeline.vae = AutoencoderKL.from_pretrained(vae_path) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 18fa8493..26f50bd8 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -64,7 +64,8 @@ def convert_diffusion_diffusers_xl( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(vae_path, RESOLVE_FORMATS): + vae_file = check_ext(vae_path, RESOLVE_FORMATS) + if vae_file[0]: logger.debug("loading VAE from single tensor file: %s", vae_path) pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: From e7da2cf8a6d0a3d6d455ac4248453a55aa9d7034 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 8 Dec 2023 23:19:52 -0600 Subject: [PATCH 04/42] fix(api): load pretrained VAE from original path --- api/onnx_web/convert/diffusion/diffusion.py | 4 ++-- api/onnx_web/convert/diffusion/diffusion_xl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index ade270c0..dbb50610 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -381,11 +381,11 @@ def convert_diffusion_diffusers( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - vae_file = check_ext(replace_vae, RESOLVE_FORMATS) + vae_file = check_ext(vae_path, RESOLVE_FORMATS) if vae_file[0]: pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: - pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + pipeline.vae = AutoencoderKL.from_pretrained(replace_vae) if is_torch_2_0: pipeline.unet.set_attn_processor(AttnProcessor()) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 26f50bd8..8370d302 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -69,8 +69,8 @@ def convert_diffusion_diffusers_xl( logger.debug("loading VAE from single tensor file: %s", vae_path) pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: - logger.debug("loading pretrained VAE from path: %s", vae_path) - pipeline.vae = AutoencoderKL.from_pretrained(vae_path) + logger.debug("loading pretrained VAE from path: %s", replace_vae) + pipeline.vae = AutoencoderKL.from_pretrained(replace_vae) if path.exists(temp_path): logger.debug("torch model already exists for %s: %s", source, temp_path) From cc8e564d26a2678b5feffbb72ca0e3e04d17f34f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 18:04:34 -0600 Subject: [PATCH 05/42] split up download clients, add to plugin exports --- api/onnx_web/convert/__main__.py | 557 ++++++++++----------- api/onnx_web/convert/client/__init__.py | 0 api/onnx_web/convert/client/base.py | 13 + api/onnx_web/convert/client/civitai.py | 45 ++ api/onnx_web/convert/client/file.py | 20 + api/onnx_web/convert/client/http.py | 39 ++ api/onnx_web/convert/client/huggingface.py | 61 +++ api/onnx_web/convert/utils.py | 125 +++-- api/onnx_web/server/plugin.py | 14 +- api/tests/convert/test_utils.py | 2 +- 10 files changed, 550 insertions(+), 326 deletions(-) create mode 100644 api/onnx_web/convert/client/__init__.py create mode 100644 api/onnx_web/convert/client/base.py create mode 100644 api/onnx_web/convert/client/civitai.py create mode 100644 api/onnx_web/convert/client/file.py create mode 100644 api/onnx_web/convert/client/http.py create mode 100644 api/onnx_web/convert/client/huggingface.py diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 262bbf44..3bb66b3c 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -3,16 +3,21 @@ from argparse import ArgumentParser from logging import getLogger from os import makedirs, path from sys import exit -from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse +from typing import Any, Dict, List, Union -from huggingface_hub.file_download import hf_hub_download from jsonschema import ValidationError, validate from onnx import load_model, save_model from transformers import CLIPTokenizer +from onnx_web.server.plugin import load_plugins + from ..constants import ONNX_MODEL, ONNX_WEIGHTS from ..utils import load_config +from .client.base import BaseClient +from .client.civitai import CivitaiClient +from .client.file import FileClient +from .client.http import HttpClient +from .client.huggingface import HuggingfaceClient from .correction.gfpgan import convert_correction_gfpgan from .diffusion.control import convert_diffusion_control from .diffusion.diffusion import convert_diffusion_diffusers @@ -25,9 +30,8 @@ from .upscaling.swinir import convert_upscaling_swinir from .utils import ( DEFAULT_OPSET, ConversionContext, - download_progress, + fetch_model, fix_diffusion_name, - remove_prefix, source_format, tuple_to_correction, tuple_to_diffusion, @@ -45,16 +49,27 @@ warnings.filterwarnings( ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", ) -Models = Dict[str, List[Any]] - logger = getLogger(__name__) +ModelDict = Dict[str, Union[float, int, str]] +Models = Dict[str, List[ModelDict]] -model_sources: Dict[str, Tuple[str, str]] = { - "civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"), + +model_sources: Dict[str, BaseClient] = { + CivitaiClient.protocol: CivitaiClient, + FileClient.protocol: FileClient, + HttpClient.insecure_protocol: HttpClient, + HttpClient.protocol: HttpClient, + HuggingfaceClient.protocol: HuggingfaceClient, } -model_source_huggingface = "huggingface://" +model_converters: Dict[str, Any] = { + "img2img": convert_diffusion_diffusers, + "img2img-sdxl": convert_diffusion_diffusers_xl, + "inpaint": convert_diffusion_diffusers, + "txt2img": convert_diffusion_diffusers, + "txt2img-sdxl": convert_diffusion_diffusers_xl, +} # recommended models base_models: Models = { @@ -62,15 +77,15 @@ base_models: Models = { # v1.x ( "stable-diffusion-onnx-v1-5", - model_source_huggingface + "runwayml/stable-diffusion-v1-5", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5", ), ( "stable-diffusion-onnx-v1-inpainting", - model_source_huggingface + "runwayml/stable-diffusion-inpainting", + HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting", ), ( "upscaling-stable-diffusion-x4", - model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler", + HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler", True, ), ], @@ -201,69 +216,226 @@ base_models: Models = { } -def fetch_model( - conversion: ConversionContext, - name: str, - source: str, - dest: Optional[str] = None, - format: Optional[str] = None, - hf_hub_fetch: bool = False, - hf_hub_filename: Optional[str] = None, -) -> Tuple[str, bool]: - cache_path = dest or conversion.cache_path - cache_name = path.join(cache_path, name) +def add_model_source(proto: str, client: BaseClient): + global model_sources - # add an extension if possible, some of the conversion code checks for it - if format is None: - url = urlparse(source) - ext = path.basename(url.path) - _filename, ext = path.splitext(ext) - if ext is not None: - cache_name = cache_name + ext + if proto in model_sources: + raise ValueError("protocol has already been taken") + + model_sources[proto] = client + + +def convert_source_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source = model["source"] + + dest_path = None + if "dest" in model: + dest_path = path.join(conversion.model_path, model["dest"]) + + dest, hf = fetch_model( + conversion, name, source, format=model_format, dest=dest_path + ) + logger.info("finished downloading source: %s -> %s", source, dest) + + +def convert_network_model(conversion: ConversionContext, network): + network_format = source_format(network) + network_model = network.get("model", None) + name = network["name"] + network_type = network["type"] + source = network["source"] + + if network_type == "control": + dest, hf = fetch_model( + conversion, + name, + source, + format=network_format, + ) + + convert_diffusion_control( + conversion, + network, + dest, + path.join(conversion.model_path, network_type, name), + ) + if network_type == "inversion" and network_model == "concept": + dest, hf = fetch_model( + conversion, + name, + source, + dest=path.join(conversion.model_path, network_type), + format=network_format, + hf_hub_fetch=True, + hf_hub_filename="learned_embeds.bin", + ) else: - cache_name = f"{cache_name}.{format}" + dest, hf = fetch_model( + conversion, + name, + source, + dest=path.join(conversion.model_path, network_type), + format=network_format, + ) - if path.exists(cache_name): - logger.debug("model already exists in cache, skipping fetch") - return cache_name, False + logger.info("finished downloading network: %s -> %s", source, dest) - for proto in model_sources: - api_name, api_root = model_sources.get(proto) - if source.startswith(proto): - api_source = api_root % (remove_prefix(source, proto)) - logger.info( - "downloading model from %s: %s -> %s", api_name, api_source, cache_name + +def convert_diffusion_model(conversion: ConversionContext, model): + # fix up entries with missing prefixes + name = fix_diffusion_name(model["name"]) + if name != model["name"]: + # update the model in-memory if the name changed + model["name"] = name + + model_format = source_format(model) + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + + pipeline = model.get("pipeline", "txt2img") + converter = model_converters.get(pipeline) + converted, dest = converter( + conversion, + model, + source, + model_format, + hf=hf, + ) + + # make sure blending only happens once, not every run + if converted: + # keep track of which models have been blended + blend_models = {} + + inversion_dest = path.join(conversion.model_path, "inversion") + lora_dest = path.join(conversion.model_path, "lora") + + for inversion in model.get("inversions", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "tokenizer" not in blend_models: + blend_models["tokenizer"] = CLIPTokenizer.from_pretrained( + dest, + subfolder="tokenizer", + ) + + inversion_name = inversion["name"] + inversion_source = inversion["source"] + inversion_format = inversion.get("format", None) + inversion_source, hf = fetch_model( + conversion, + inversion_name, + inversion_source, + dest=inversion_dest, ) - return download_progress([(api_source, cache_name)]), False + inversion_token = inversion.get("token", inversion_name) + inversion_weight = inversion.get("weight", 1.0) - if source.startswith(model_source_huggingface): - hub_source = remove_prefix(source, model_source_huggingface) - logger.info("downloading model from Huggingface Hub: %s", hub_source) - # from_pretrained has a bunch of useful logic that snapshot_download by itself down not - if hf_hub_fetch: - return ( - hf_hub_download( - repo_id=hub_source, - filename=hf_hub_filename, - cache_dir=cache_path, - force_filename=f"{name}.bin", - ), - False, + blend_textual_inversions( + conversion, + blend_models["text_encoder"], + blend_models["tokenizer"], + [ + ( + inversion_source, + inversion_weight, + inversion_token, + inversion_format, + ) + ], ) - else: - return hub_source, True - elif source.startswith("https://"): - logger.info("downloading model from: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith("http://"): - logger.warning("downloading model from insecure source: %s", source) - return download_progress([(source, cache_name)]), False - elif source.startswith(path.sep) or source.startswith("."): - logger.info("using local model: %s", source) - return source, False + + for lora in model.get("loras", []): + if "text_encoder" not in blend_models: + blend_models["text_encoder"] = load_model( + path.join( + dest, + "text_encoder", + ONNX_MODEL, + ) + ) + + if "unet" not in blend_models: + blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL)) + + # load models if not loaded yet + lora_name = lora["name"] + lora_source = lora["source"] + lora_source, hf = fetch_model( + conversion, + f"{name}-lora-{lora_name}", + lora_source, + dest=lora_dest, + ) + lora_weight = lora.get("weight", 1.0) + + blend_loras( + conversion, + blend_models["text_encoder"], + [(lora_source, lora_weight)], + "text_encoder", + ) + + blend_loras( + conversion, + blend_models["unet"], + [(lora_source, lora_weight)], + "unet", + ) + + if "tokenizer" in blend_models: + dest_path = path.join(dest, "tokenizer") + logger.debug("saving blended tokenizer to %s", dest_path) + blend_models["tokenizer"].save_pretrained(dest_path) + + for name in ["text_encoder", "unet"]: + if name in blend_models: + dest_path = path.join(dest, name, ONNX_MODEL) + logger.debug("saving blended %s model to %s", name, dest_path) + save_model( + blend_models[name], + dest_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=ONNX_WEIGHTS, + ) + + +def convert_upscaling_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "resrgan") + if model_type == "bsrgan": + convert_upscaling_bsrgan(conversion, model, source) + elif model_type == "resrgan": + convert_upscale_resrgan(conversion, model, source) + elif model_type == "swinir": + convert_upscaling_swinir(conversion, model, source) else: - logger.info("unknown model location, using path as provided: %s", source) - return source, False + logger.error("unknown upscaling model type %s for %s", model_type, name) + model_errors.append(name) + + +def convert_correction_model(conversion: ConversionContext, model): + model_format = source_format(model) + name = model["name"] + source, hf = fetch_model(conversion, name, model["source"], format=model_format) + model_type = model.get("model", "gfpgan") + if model_type == "gfpgan": + convert_correction_gfpgan(conversion, model, source) + else: + logger.error("unknown correction model type %s for %s", model_type, name) + model_errors.append(name) def convert_models(conversion: ConversionContext, args, models: Models): @@ -277,18 +449,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping source: %s", name) else: - model_format = source_format(model) - source = model["source"] - try: - dest_path = None - if "dest" in model: - dest_path = path.join(conversion.model_path, model["dest"]) - - dest, hf = fetch_model( - conversion, name, source, format=model_format, dest=dest_path - ) - logger.info("finished downloading source: %s -> %s", source, dest) + convert_source_model(model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) @@ -300,46 +462,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping network: %s", name) else: - network_format = source_format(network) - network_model = network.get("model", None) - network_type = network["type"] - source = network["source"] - try: - if network_type == "control": - dest, hf = fetch_model( - conversion, - name, - source, - format=network_format, - ) - - convert_diffusion_control( - conversion, - network, - dest, - path.join(conversion.model_path, network_type, name), - ) - if network_type == "inversion" and network_model == "concept": - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - hf_hub_fetch=True, - hf_hub_filename="learned_embeds.bin", - ) - else: - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - ) - - logger.info("finished downloading network: %s -> %s", source, dest) + convert_network_model(conversion, model) except Exception: logger.exception("error fetching network %s", name) model_errors.append(name) @@ -352,148 +476,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - # fix up entries with missing prefixes - name = fix_diffusion_name(name) - if name != model["name"]: - # update the model in-memory if the name changed - model["name"] = name - - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - - pipeline = model.get("pipeline", "txt2img") - if pipeline.endswith("-sdxl"): - converted, dest = convert_diffusion_diffusers_xl( - conversion, - model, - source, - model_format, - hf=hf, - ) - else: - converted, dest = convert_diffusion_diffusers( - conversion, - model, - source, - model_format, - hf=hf, - ) - - # make sure blending only happens once, not every run - if converted: - # keep track of which models have been blended - blend_models = {} - - inversion_dest = path.join(conversion.model_path, "inversion") - lora_dest = path.join(conversion.model_path, "lora") - - for inversion in model.get("inversions", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "tokenizer" not in blend_models: - blend_models[ - "tokenizer" - ] = CLIPTokenizer.from_pretrained( - dest, - subfolder="tokenizer", - ) - - inversion_name = inversion["name"] - inversion_source = inversion["source"] - inversion_format = inversion.get("format", None) - inversion_source, hf = fetch_model( - conversion, - inversion_name, - inversion_source, - dest=inversion_dest, - ) - inversion_token = inversion.get("token", inversion_name) - inversion_weight = inversion.get("weight", 1.0) - - blend_textual_inversions( - conversion, - blend_models["text_encoder"], - blend_models["tokenizer"], - [ - ( - inversion_source, - inversion_weight, - inversion_token, - inversion_format, - ) - ], - ) - - for lora in model.get("loras", []): - if "text_encoder" not in blend_models: - blend_models["text_encoder"] = load_model( - path.join( - dest, - "text_encoder", - ONNX_MODEL, - ) - ) - - if "unet" not in blend_models: - blend_models["unet"] = load_model( - path.join(dest, "unet", ONNX_MODEL) - ) - - # load models if not loaded yet - lora_name = lora["name"] - lora_source = lora["source"] - lora_source, hf = fetch_model( - conversion, - f"{name}-lora-{lora_name}", - lora_source, - dest=lora_dest, - ) - lora_weight = lora.get("weight", 1.0) - - blend_loras( - conversion, - blend_models["text_encoder"], - [(lora_source, lora_weight)], - "text_encoder", - ) - - blend_loras( - conversion, - blend_models["unet"], - [(lora_source, lora_weight)], - "unet", - ) - - if "tokenizer" in blend_models: - dest_path = path.join(dest, "tokenizer") - logger.debug("saving blended tokenizer to %s", dest_path) - blend_models["tokenizer"].save_pretrained(dest_path) - - for name in ["text_encoder", "unet"]: - if name in blend_models: - dest_path = path.join(dest, name, ONNX_MODEL) - logger.debug( - "saving blended %s model to %s", name, dest_path - ) - save_model( - blend_models[name], - dest_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - location=ONNX_WEIGHTS, - ) - + convert_diffusion_model(conversion, model) except Exception: logger.exception( "error converting diffusion model %s", @@ -509,24 +493,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) - try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "resrgan") - if model_type == "bsrgan": - convert_upscaling_bsrgan(conversion, model, source) - elif model_type == "resrgan": - convert_upscale_resrgan(conversion, model, source) - elif model_type == "swinir": - convert_upscaling_swinir(conversion, model, source) - else: - logger.error( - "unknown upscaling model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_upscaling_model(conversion, model) except Exception: logger.exception( "error converting upscaling model %s", @@ -542,19 +510,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: - model_format = source_format(model) try: - source, hf = fetch_model( - conversion, name, model["source"], format=model_format - ) - model_type = model.get("model", "gfpgan") - if model_type == "gfpgan": - convert_correction_gfpgan(conversion, model, source) - else: - logger.error( - "unknown correction model type %s for %s", model_type, name - ) - model_errors.append(name) + convert_correction_model(conversion, model) except Exception: logger.exception( "error converting correction model %s", @@ -566,12 +523,26 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.error("error while converting models: %s", model_errors) +def register_plugins(conversion: ConversionContext): + logger.info("loading conversion plugins") + exports = load_plugins(conversion) + + for proto, client in exports.clients: + try: + add_model_source(proto, client) + except Exception: + logger.exception("error loading client for protocol: %s", proto) + + # TODO: add converters + + def main(args=None) -> int: parser = ArgumentParser( prog="onnx-web model converter", description="convert checkpoint models to ONNX" ) # model groups + parser.add_argument("--base", action="store_true", default=True) parser.add_argument("--networks", action="store_true", default=True) parser.add_argument("--sources", action="store_true", default=True) parser.add_argument("--correction", action="store_true", default=False) @@ -609,16 +580,20 @@ def main(args=None) -> int: server.half = args.half or server.has_optimization("onnx-fp16") server.opset = args.opset server.token = args.token + + register_plugins(server) + logger.info( - "converting models in %s using %s", server.model_path, server.training_device + "converting models into %s using %s", server.model_path, server.training_device ) if not path.exists(server.model_path): logger.info("model path does not existing, creating: %s", server.model_path) makedirs(server.model_path) - logger.info("converting base models") - convert_models(server, args, base_models) + if args.base: + logger.info("converting base models") + convert_models(server, args, base_models) extras = [] extras.extend(server.extra_models) diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py new file mode 100644 index 00000000..c38ec082 --- /dev/null +++ b/api/onnx_web/convert/client/base.py @@ -0,0 +1,13 @@ +from typing import Optional +from ..utils import ConversionContext + + +class BaseClient: + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str], + ) -> str: + raise NotImplementedError() diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py new file mode 100644 index 00000000..1bc8c4a9 --- /dev/null +++ b/api/onnx_web/convert/client/civitai.py @@ -0,0 +1,45 @@ +from ..utils import ( + ConversionContext, + build_cache_paths, + download_progress, + get_first_exists, + remove_prefix, +) +from .base import BaseClient +from typing import Optional +from logging import getLogger + +logger = getLogger(__name__) + +CIVITAI_ROOT = "https://civitai.com/api/download/models/%s" + + +class CivitaiClient(BaseClient): + protocol = "civitai://" + root: str + token: Optional[str] + + def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT): + self.root = root + self.token = token + + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str], + ) -> str: + """ + TODO: download with auth token + """ + cache_paths = build_cache_paths( + conversion, name, client=CivitaiClient.name, format=format + ) + cached = get_first_exists(cache_paths) + if cached: + return cached + + source = self.root % (remove_prefix(source, CivitaiClient.protocol)) + logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0]) + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py new file mode 100644 index 00000000..8c44c0d0 --- /dev/null +++ b/api/onnx_web/convert/client/file.py @@ -0,0 +1,20 @@ +from .base import BaseClient +from logging import getLogger +from os import path +from urllib.parse import urlparse + +logger = getLogger(__name__) + + +class FileClient(BaseClient): + protocol = "file://" + + root: str + + def __init__(self, root: str): + self.root = root + + def download(self, uri: str) -> str: + parts = urlparse(uri) + logger.info("loading model from: %s", parts.path) + return path.join(self.root, parts.path) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py new file mode 100644 index 00000000..5259402f --- /dev/null +++ b/api/onnx_web/convert/client/http.py @@ -0,0 +1,39 @@ +from ..utils import ( + ConversionContext, + build_cache_paths, + download_progress, + get_first_exists, + remove_prefix, +) +from .base import BaseClient +from typing import Dict, Optional +from logging import getLogger + +logger = getLogger(__name__) + + +class HttpClient(BaseClient): + name = "http" + protocol = "https://" + insecure_protocol = "http://" + + headers: Dict[str, str] + + def __init__(self, headers: Optional[Dict[str, str]] = None): + self.headers = headers or {} + + def download(self, conversion: ConversionContext, name: str, uri: str) -> str: + cache_paths = build_cache_paths( + conversion, name, client=HttpClient.name, format=format + ) + cached = get_first_exists(cache_paths) + if cached: + return cached + + if uri.startswith(HttpClient.protocol): + source = remove_prefix(uri, HttpClient.protocol) + logger.info("downloading model from: %s", source) + elif uri.startswith(HttpClient.insecure_protocol): + logger.warning("downloading model from insecure source: %s", source) + + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py new file mode 100644 index 00000000..61796e99 --- /dev/null +++ b/api/onnx_web/convert/client/huggingface.py @@ -0,0 +1,61 @@ +from ..utils import ( + ConversionContext, + build_cache_paths, + get_first_exists, + remove_prefix, +) +from .base import BaseClient +from typing import Optional, Any +from logging import getLogger +from huggingface_hub.file_download import hf_hub_download + +logger = getLogger(__name__) + + +class HuggingfaceClient(BaseClient): + name = "huggingface" + protocol = "huggingface://" + + download: Any + token: Optional[str] + + def __init__(self, token: Optional[str] = None, download=hf_hub_download): + self.download = download + self.token = token + + def download( + self, + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str], + ) -> str: + """ + TODO: download with auth + """ + hf_hub_fetch = TODO + hf_hub_filename = TODO + + cache_paths = build_cache_paths( + conversion, name, client=HuggingfaceClient.name, format=format + ) + cached = get_first_exists(cache_paths) + if cached: + return cached + + source = remove_prefix(source, HuggingfaceClient.protocol) + logger.info("downloading model from Huggingface Hub: %s", source) + + # from_pretrained has a bunch of useful logic that snapshot_download by itself down not + if hf_hub_fetch: + return ( + hf_hub_download( + repo_id=source, + filename=hf_hub_filename, + cache_dir=cache_paths[0], + force_filename=f"{name}.bin", + ), + False, + ) + else: + return source diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index e56e05ae..19e6a309 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -15,6 +15,8 @@ from onnxruntime.transformers.float16 import convert_float_to_float16 from packaging import version from torch.onnx import export +from onnx_web.convert.client.file import FileClient + from ..constants import ONNX_WEIGHTS from ..errors import RequestException from ..server import ServerContext @@ -85,38 +87,37 @@ class ConversionContext(ServerContext): return torch.device(self.training_device) -def download_progress(urls: List[Tuple[str, str]]): - for url, dest in urls: - dest_path = Path(dest).expanduser().resolve() - dest_path.parent.mkdir(parents=True, exist_ok=True) - - if dest_path.exists(): - logger.debug("destination already exists: %s", dest_path) - return str(dest_path.absolute()) - - req = requests.get( - url, - stream=True, - allow_redirects=True, - headers={ - "User-Agent": "onnx-web-api", - }, - ) - if req.status_code != 200: - req.raise_for_status() # Only works for 4xx errors, per SO answer - raise RequestException( - "request to %s failed with status code: %s" % (url, req.status_code) - ) - - total = int(req.headers.get("Content-Length", 0)) - desc = "unknown" if total == 0 else "" - req.raw.read = partial(req.raw.read, decode_content=True) - with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: - with dest_path.open("wb") as f: - shutil.copyfileobj(data, f) +def download_progress(source: str, dest: str): + dest_path = Path(dest).expanduser().resolve() + dest_path.parent.mkdir(parents=True, exist_ok=True) + if dest_path.exists(): + logger.debug("destination already exists: %s", dest_path) return str(dest_path.absolute()) + req = requests.get( + source, + stream=True, + allow_redirects=True, + headers={ + "User-Agent": "onnx-web-api", + }, + ) + if req.status_code != 200: + req.raise_for_status() # Only works for 4xx errors, per SO answer + raise RequestException( + "request to %s failed with status code: %s" % (source, req.status_code) + ) + + total = int(req.headers.get("Content-Length", 0)) + desc = "unknown" if total == 0 else "" + req.raw.read = partial(req.raw.read, decode_content=True) + with tqdm.wrapattr(req.raw, "read", total=total, desc=desc) as data: + with dest_path.open("wb") as f: + shutil.copyfileobj(data, f) + + return str(dest_path.absolute()) + def tuple_to_source(model: Union[ModelDict, LegacyModel]): if isinstance(model, list) or isinstance(model, tuple): @@ -347,7 +348,13 @@ def onnx_export( ) -DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-", "upscaling-"] +DIFFUSION_PREFIX = [ + "diffusion-", + "diffusion/", + "diffusion\\", + "stable-diffusion-", + "upscaling-", +] def fix_diffusion_name(name: str): @@ -359,3 +366,61 @@ def fix_diffusion_name(name: str): return f"diffusion-{name}" return name + + +def fetch_model( + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str], +) -> Tuple[str, bool]: + # TODO: switch to urlparse's default scheme + if source.startswith(path.sep) or source.startswith("."): + logger.info("adding file protocol to local path source: %s", source) + source = FileClient.protocol + source + + for proto, client_type in model_sources.items(): + if source.startswith(proto): + client = client_type() + return client.download(proto) + + logger.warning("unknown model protocol, using path as provided: %s", source) + return source, False + + +def build_cache_paths( + conversion: ConversionContext, + name: str, + client: Optional[str] = None, + dest: Optional[str] = None, + format: Optional[str] = None, +) -> List[str]: + cache_path = dest or conversion.cache_path + + # add an extension if possible, some of the conversion code checks for it + if format is not None: + basename = path.basename(name) + _filename, ext = path.splitext(basename) + if ext is None: + name = f"{name}.{format}" + + paths = [ + path.join(cache_path, name), + ] + + if client is not None: + client_path = path.join(cache_path, client) + paths.append(path.join(client_path, name)) + + return paths + + +def get_first_exists( + paths: List[str], +) -> Optional[str]: + for name in paths: + if path.exists(name): + logger.debug("model already exists in cache, skipping fetch: %s", name) + return name + + return None diff --git a/api/onnx_web/server/plugin.py b/api/onnx_web/server/plugin.py index 022047df..2adcabb3 100644 --- a/api/onnx_web/server/plugin.py +++ b/api/onnx_web/server/plugin.py @@ -2,18 +2,24 @@ from importlib import import_module from logging import getLogger from typing import Any, Callable, Dict -from onnx_web.chain.stages import add_stage -from onnx_web.diffusers.load import add_pipeline -from onnx_web.server.context import ServerContext +from ..chain.stages import add_stage +from ..diffusers.load import add_pipeline +from ..server.context import ServerContext logger = getLogger(__name__) class PluginExports: + clients: Dict[str, Any] + converter: Dict[str, Any] pipelines: Dict[str, Any] stages: Dict[str, Any] - def __init__(self, pipelines=None, stages=None) -> None: + def __init__( + self, clients=None, converter=None, pipelines=None, stages=None + ) -> None: + self.clients = clients or {} + self.converter = converter or {} self.pipelines = pipelines or {} self.stages = stages or {} diff --git a/api/tests/convert/test_utils.py b/api/tests/convert/test_utils.py index 4281adbc..34b0bf9b 100644 --- a/api/tests/convert/test_utils.py +++ b/api/tests/convert/test_utils.py @@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase): class DownloadProgressTests(unittest.TestCase): def test_download_example(self): - path = download_progress([("https://example.com", "/tmp/example-dot-com")]) + path = download_progress("https://example.com", "/tmp/example-dot-com") self.assertEqual(path, "/tmp/example-dot-com") From 7496613f4e70efb540711e53df85e5d61dbacef1 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 18:46:47 -0600 Subject: [PATCH 06/42] fix fetch calls --- api/onnx_web/convert/__main__.py | 67 +++++----------------- api/onnx_web/convert/client/__init__.py | 51 ++++++++++++++++ api/onnx_web/convert/client/base.py | 4 +- api/onnx_web/convert/client/civitai.py | 14 +++-- api/onnx_web/convert/client/file.py | 14 ++++- api/onnx_web/convert/client/http.py | 20 +++++-- api/onnx_web/convert/client/huggingface.py | 25 +++++--- api/onnx_web/convert/utils.py | 53 ++++++----------- api/tests/test_diffusers/test_run.py | 1 + 9 files changed, 143 insertions(+), 106 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 3bb66b3c..e05a46b9 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -9,14 +9,10 @@ from jsonschema import ValidationError, validate from onnx import load_model, save_model from transformers import CLIPTokenizer -from onnx_web.server.plugin import load_plugins - from ..constants import ONNX_MODEL, ONNX_WEIGHTS +from ..server.plugin import load_plugins from ..utils import load_config -from .client.base import BaseClient -from .client.civitai import CivitaiClient -from .client.file import FileClient -from .client.http import HttpClient +from .client import add_model_source, fetch_model from .client.huggingface import HuggingfaceClient from .correction.gfpgan import convert_correction_gfpgan from .diffusion.control import convert_diffusion_control @@ -30,7 +26,6 @@ from .upscaling.swinir import convert_upscaling_swinir from .utils import ( DEFAULT_OPSET, ConversionContext, - fetch_model, fix_diffusion_name, source_format, tuple_to_correction, @@ -54,15 +49,6 @@ logger = getLogger(__name__) ModelDict = Dict[str, Union[float, int, str]] Models = Dict[str, List[ModelDict]] - -model_sources: Dict[str, BaseClient] = { - CivitaiClient.protocol: CivitaiClient, - FileClient.protocol: FileClient, - HttpClient.insecure_protocol: HttpClient, - HttpClient.protocol: HttpClient, - HuggingfaceClient.protocol: HuggingfaceClient, -} - model_converters: Dict[str, Any] = { "img2img": convert_diffusion_diffusers, "img2img-sdxl": convert_diffusion_diffusers_xl, @@ -216,15 +202,6 @@ base_models: Models = { } -def add_model_source(proto: str, client: BaseClient): - global model_sources - - if proto in model_sources: - raise ValueError("protocol has already been taken") - - model_sources[proto] = client - - def convert_source_model(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] @@ -234,21 +211,18 @@ def convert_source_model(conversion: ConversionContext, model): if "dest" in model: dest_path = path.join(conversion.model_path, model["dest"]) - dest, hf = fetch_model( - conversion, name, source, format=model_format, dest=dest_path - ) + dest = fetch_model(conversion, name, source, format=model_format, dest=dest_path) logger.info("finished downloading source: %s -> %s", source, dest) def convert_network_model(conversion: ConversionContext, network): network_format = source_format(network) - network_model = network.get("model", None) name = network["name"] network_type = network["type"] source = network["source"] if network_type == "control": - dest, hf = fetch_model( + dest = fetch_model( conversion, name, source, @@ -261,18 +235,8 @@ def convert_network_model(conversion: ConversionContext, network): dest, path.join(conversion.model_path, network_type, name), ) - if network_type == "inversion" and network_model == "concept": - dest, hf = fetch_model( - conversion, - name, - source, - dest=path.join(conversion.model_path, network_type), - format=network_format, - hf_hub_fetch=True, - hf_hub_filename="learned_embeds.bin", - ) else: - dest, hf = fetch_model( + dest = fetch_model( conversion, name, source, @@ -290,17 +254,16 @@ def convert_diffusion_model(conversion: ConversionContext, model): # update the model in-memory if the name changed model["name"] = name - model_format = source_format(model) - source, hf = fetch_model(conversion, name, model["source"], format=model_format) + format = source_format(model) + dest = fetch_model(conversion, name, model["source"], format=format) pipeline = model.get("pipeline", "txt2img") converter = model_converters.get(pipeline) converted, dest = converter( conversion, model, - source, - model_format, - hf=hf, + dest, + format, ) # make sure blending only happens once, not every run @@ -330,7 +293,7 @@ def convert_diffusion_model(conversion: ConversionContext, model): inversion_name = inversion["name"] inversion_source = inversion["source"] inversion_format = inversion.get("format", None) - inversion_source, hf = fetch_model( + inversion_source = fetch_model( conversion, inversion_name, inversion_source, @@ -369,7 +332,7 @@ def convert_diffusion_model(conversion: ConversionContext, model): # load models if not loaded yet lora_name = lora["name"] lora_source = lora["source"] - lora_source, hf = fetch_model( + lora_source = fetch_model( conversion, f"{name}-lora-{lora_name}", lora_source, @@ -413,7 +376,7 @@ def convert_upscaling_model(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] - source, hf = fetch_model(conversion, name, model["source"], format=model_format) + source = fetch_model(conversion, name, model["source"], format=model_format) model_type = model.get("model", "resrgan") if model_type == "bsrgan": convert_upscaling_bsrgan(conversion, model, source) @@ -423,19 +386,19 @@ def convert_upscaling_model(conversion: ConversionContext, model): convert_upscaling_swinir(conversion, model, source) else: logger.error("unknown upscaling model type %s for %s", model_type, name) - model_errors.append(name) + raise ValueError(name) def convert_correction_model(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] - source, hf = fetch_model(conversion, name, model["source"], format=model_format) + source = fetch_model(conversion, name, model["source"], format=model_format) model_type = model.get("model", "gfpgan") if model_type == "gfpgan": convert_correction_gfpgan(conversion, model, source) else: logger.error("unknown correction model type %s for %s", model_type, name) - model_errors.append(name) + raise ValueError(name) def convert_models(conversion: ConversionContext, args, models: Models): diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py index e69de29b..166905e6 100644 --- a/api/onnx_web/convert/client/__init__.py +++ b/api/onnx_web/convert/client/__init__.py @@ -0,0 +1,51 @@ +from .base import BaseClient +from .civitai import CivitaiClient +from .file import FileClient +from .http import HttpClient +from .huggingface import HuggingfaceClient +from ..utils import ConversionContext +from typing import Dict, Optional +from logging import getLogger +from os import path + +logger = getLogger(__name__) + + +model_sources: Dict[str, BaseClient] = { + CivitaiClient.protocol: CivitaiClient, + FileClient.protocol: FileClient, + HttpClient.insecure_protocol: HttpClient, + HttpClient.protocol: HttpClient, + HuggingfaceClient.protocol: HuggingfaceClient, +} + + +def add_model_source(proto: str, client: BaseClient): + global model_sources + + if proto in model_sources: + raise ValueError("protocol has already been taken") + + model_sources[proto] = client + + +def fetch_model( + conversion: ConversionContext, + name: str, + source: str, + format: Optional[str] = None, + dest: Optional[str] = None, +) -> str: + # TODO: switch to urlparse's default scheme + if source.startswith(path.sep) or source.startswith("."): + logger.info("adding file protocol to local path source: %s", source) + source = FileClient.protocol + source + + for proto, client_type in model_sources.items(): + if source.startswith(proto): + # TODO: fix type of client_type + client: BaseClient = client_type() + return client.download(conversion, name, source, format=format, dest=dest) + + logger.warning("unknown model protocol, using path as provided: %s", source) + return source diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py index c38ec082..0e91dab7 100644 --- a/api/onnx_web/convert/client/base.py +++ b/api/onnx_web/convert/client/base.py @@ -1,4 +1,5 @@ from typing import Optional + from ..utils import ConversionContext @@ -8,6 +9,7 @@ class BaseClient: conversion: ConversionContext, name: str, source: str, - format: Optional[str], + format: Optional[str] = None, + dest: Optional[str] = None, ) -> str: raise NotImplementedError() diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index 1bc8c4a9..e56b49e5 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -1,3 +1,6 @@ +from logging import getLogger +from typing import Optional + from ..utils import ( ConversionContext, build_cache_paths, @@ -6,8 +9,6 @@ from ..utils import ( remove_prefix, ) from .base import BaseClient -from typing import Optional -from logging import getLogger logger = getLogger(__name__) @@ -28,13 +29,18 @@ class CivitaiClient(BaseClient): conversion: ConversionContext, name: str, source: str, - format: Optional[str], + format: Optional[str] = None, + dest: Optional[str] = None, ) -> str: """ TODO: download with auth token """ cache_paths = build_cache_paths( - conversion, name, client=CivitaiClient.name, format=format + conversion, + name, + client=CivitaiClient.name, + format=format, + dest=dest, ) cached = get_first_exists(cache_paths) if cached: diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 8c44c0d0..142e503e 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -1,8 +1,11 @@ -from .base import BaseClient from logging import getLogger from os import path +from typing import Optional from urllib.parse import urlparse +from ..utils import ConversionContext +from .base import BaseClient + logger = getLogger(__name__) @@ -14,7 +17,14 @@ class FileClient(BaseClient): def __init__(self, root: str): self.root = root - def download(self, uri: str) -> str: + def download( + self, + _conversion: ConversionContext, + _name: str, + uri: str, + format: Optional[str] = None, + dest: Optional[str] = None, + ) -> str: parts = urlparse(uri) logger.info("loading model from: %s", parts.path) return path.join(self.root, parts.path) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index 5259402f..d994dabe 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -1,3 +1,6 @@ +from logging import getLogger +from typing import Dict, Optional + from ..utils import ( ConversionContext, build_cache_paths, @@ -6,8 +9,6 @@ from ..utils import ( remove_prefix, ) from .base import BaseClient -from typing import Dict, Optional -from logging import getLogger logger = getLogger(__name__) @@ -22,9 +23,20 @@ class HttpClient(BaseClient): def __init__(self, headers: Optional[Dict[str, str]] = None): self.headers = headers or {} - def download(self, conversion: ConversionContext, name: str, uri: str) -> str: + def download( + self, + conversion: ConversionContext, + name: str, + uri: str, + format: Optional[str] = None, + dest: Optional[str] = None, + ) -> str: cache_paths = build_cache_paths( - conversion, name, client=HttpClient.name, format=format + conversion, + name, + client=HttpClient.name, + format=format, + dest=dest, ) cached = get_first_exists(cache_paths) if cached: diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 61796e99..6d1064fd 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -1,3 +1,8 @@ +from logging import getLogger +from typing import Any, Optional + +from huggingface_hub.file_download import hf_hub_download + from ..utils import ( ConversionContext, build_cache_paths, @@ -5,9 +10,6 @@ from ..utils import ( remove_prefix, ) from .base import BaseClient -from typing import Optional, Any -from logging import getLogger -from huggingface_hub.file_download import hf_hub_download logger = getLogger(__name__) @@ -28,16 +30,23 @@ class HuggingfaceClient(BaseClient): conversion: ConversionContext, name: str, source: str, - format: Optional[str], + format: Optional[str] = None, + dest: Optional[str] = None, ) -> str: """ TODO: download with auth + TODO: set fetch and filename + if network_type == "inversion" and network_model == "concept": """ - hf_hub_fetch = TODO - hf_hub_filename = TODO + hf_hub_fetch = True + hf_hub_filename = "learned_embeds.bin" cache_paths = build_cache_paths( - conversion, name, client=HuggingfaceClient.name, format=format + conversion, + name, + client=HuggingfaceClient.name, + format=format, + dest=dest, ) cached = get_first_exists(cache_paths) if cached: @@ -46,7 +55,6 @@ class HuggingfaceClient(BaseClient): source = remove_prefix(source, HuggingfaceClient.protocol) logger.info("downloading model from Huggingface Hub: %s", source) - # from_pretrained has a bunch of useful logic that snapshot_download by itself down not if hf_hub_fetch: return ( hf_hub_download( @@ -58,4 +66,5 @@ class HuggingfaceClient(BaseClient): False, ) else: + # TODO: download pretrained because load doesn't call from_pretrained anymore return source diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 19e6a309..fd02f671 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -15,8 +15,6 @@ from onnxruntime.transformers.float16 import convert_float_to_float16 from packaging import version from torch.onnx import export -from onnx_web.convert.client.file import FileClient - from ..constants import ONNX_WEIGHTS from ..errors import RequestException from ..server import ServerContext @@ -33,6 +31,15 @@ ModelDict = Dict[str, Union[str, int]] LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]] DEFAULT_OPSET = 14 +DIFFUSION_PREFIX = [ + "diffusion-", + "diffusion/", + "diffusion\\", + "stable-diffusion-", + "upscaling-", # SD upscaling +] +MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] +RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"] class ConversionContext(ServerContext): @@ -185,10 +192,6 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]): return model -MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"] -RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"] - - def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]: _name, ext = path.splitext(name) ext = ext.strip(".") @@ -216,6 +219,9 @@ def remove_prefix(name: str, prefix: str) -> str: def load_torch(name: str, map_location=None) -> Optional[Dict]: + """ + TODO: move out of convert + """ try: logger.debug("loading tensor with Torch: %s", name) checkpoint = torch.load(name, map_location=map_location) @@ -229,6 +235,9 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]: def load_tensor(name: str, map_location=None) -> Optional[Dict]: + """ + TODO: move out of convert + """ logger.debug("loading tensor: %s", name) _, extension = path.splitext(name) extension = extension[1:].lower() @@ -286,6 +295,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]: def resolve_tensor(name: str) -> Optional[str]: + """ + TODO: move out of convert + """ logger.debug("searching for tensors with known extensions: %s", name) for next_extension in RESOLVE_FORMATS: next_name = f"{name}.{next_extension}" @@ -348,15 +360,6 @@ def onnx_export( ) -DIFFUSION_PREFIX = [ - "diffusion-", - "diffusion/", - "diffusion\\", - "stable-diffusion-", - "upscaling-", -] - - def fix_diffusion_name(name: str): if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): logger.warning( @@ -368,26 +371,6 @@ def fix_diffusion_name(name: str): return name -def fetch_model( - conversion: ConversionContext, - name: str, - source: str, - format: Optional[str], -) -> Tuple[str, bool]: - # TODO: switch to urlparse's default scheme - if source.startswith(path.sep) or source.startswith("."): - logger.info("adding file protocol to local path source: %s", source) - source = FileClient.protocol + source - - for proto, client_type in model_sources.items(): - if source.startswith(proto): - client = client_type() - return client.download(proto) - - logger.warning("unknown model protocol, using path as provided: %s", source) - return source, False - - def build_cache_paths( conversion: ConversionContext, name: str, diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index 26578f3e..322712e4 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -414,6 +414,7 @@ class TestBlendPipeline(unittest.TestCase): 3.0, 1, 1, + unet_tile=64, ), Size(64, 64), ["test-blend.png"], From 20b719fdff31436b41220e962ca52144f21f4d01 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 19:15:28 -0600 Subject: [PATCH 07/42] special downloading for embeds --- api/onnx_web/convert/__main__.py | 28 ++++++----- api/onnx_web/convert/client/base.py | 1 + api/onnx_web/convert/client/civitai.py | 1 + api/onnx_web/convert/client/file.py | 1 + api/onnx_web/convert/client/http.py | 2 + api/onnx_web/convert/client/huggingface.py | 57 +++++++--------------- 6 files changed, 37 insertions(+), 53 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index e05a46b9..618ea6a5 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -202,7 +202,7 @@ base_models: Models = { } -def convert_source_model(conversion: ConversionContext, model): +def convert_model_source(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] source = model["source"] @@ -215,9 +215,10 @@ def convert_source_model(conversion: ConversionContext, model): logger.info("finished downloading source: %s -> %s", source, dest) -def convert_network_model(conversion: ConversionContext, network): - network_format = source_format(network) +def convert_model_network(conversion: ConversionContext, network): + format = source_format(network) name = network["name"] + model = network["model"] network_type = network["type"] source = network["source"] @@ -226,7 +227,7 @@ def convert_network_model(conversion: ConversionContext, network): conversion, name, source, - format=network_format, + format=format, ) convert_diffusion_control( @@ -241,13 +242,14 @@ def convert_network_model(conversion: ConversionContext, network): name, source, dest=path.join(conversion.model_path, network_type), - format=network_format, + format=format, + embeds=(network_type == "inversion" and model == "concept"), ) logger.info("finished downloading network: %s -> %s", source, dest) -def convert_diffusion_model(conversion: ConversionContext, model): +def convert_model_diffusion(conversion: ConversionContext, model): # fix up entries with missing prefixes name = fix_diffusion_name(model["name"]) if name != model["name"]: @@ -372,7 +374,7 @@ def convert_diffusion_model(conversion: ConversionContext, model): ) -def convert_upscaling_model(conversion: ConversionContext, model): +def convert_model_upscaling(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] @@ -389,7 +391,7 @@ def convert_upscaling_model(conversion: ConversionContext, model): raise ValueError(name) -def convert_correction_model(conversion: ConversionContext, model): +def convert_model_correction(conversion: ConversionContext, model): model_format = source_format(model) name = model["name"] source = fetch_model(conversion, name, model["source"], format=model_format) @@ -413,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping source: %s", name) else: try: - convert_source_model(model) + convert_model_source(model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) @@ -426,7 +428,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping network: %s", name) else: try: - convert_network_model(conversion, model) + convert_model_network(conversion, model) except Exception: logger.exception("error fetching network %s", name) model_errors.append(name) @@ -440,7 +442,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_diffusion_model(conversion, model) + convert_model_diffusion(conversion, model) except Exception: logger.exception( "error converting diffusion model %s", @@ -457,7 +459,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_upscaling_model(conversion, model) + convert_model_upscaling(conversion, model) except Exception: logger.exception( "error converting upscaling model %s", @@ -474,7 +476,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping model: %s", name) else: try: - convert_correction_model(conversion, model) + convert_model_correction(conversion, model) except Exception: logger.exception( "error converting correction model %s", diff --git a/api/onnx_web/convert/client/base.py b/api/onnx_web/convert/client/base.py index 0e91dab7..2ec2ae65 100644 --- a/api/onnx_web/convert/client/base.py +++ b/api/onnx_web/convert/client/base.py @@ -11,5 +11,6 @@ class BaseClient: source: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: raise NotImplementedError() diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index e56b49e5..706e5c34 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -31,6 +31,7 @@ class CivitaiClient(BaseClient): source: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: """ TODO: download with auth token diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 142e503e..5619b35b 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -24,6 +24,7 @@ class FileClient(BaseClient): uri: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: parts = urlparse(uri) logger.info("loading model from: %s", parts.path) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index d994dabe..9d9f840c 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -30,6 +30,7 @@ class HttpClient(BaseClient): uri: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: cache_paths = build_cache_paths( conversion, @@ -46,6 +47,7 @@ class HttpClient(BaseClient): source = remove_prefix(uri, HttpClient.protocol) logger.info("downloading model from: %s", source) elif uri.startswith(HttpClient.insecure_protocol): + source = remove_prefix(uri, HttpClient.insecure_protocol) logger.warning("downloading model from insecure source: %s", source) return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 6d1064fd..6db97928 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -1,14 +1,10 @@ from logging import getLogger -from typing import Any, Optional +from typing import Optional +from huggingface_hub import snapshot_download from huggingface_hub.file_download import hf_hub_download -from ..utils import ( - ConversionContext, - build_cache_paths, - get_first_exists, - remove_prefix, -) +from ..utils import ConversionContext, remove_prefix from .base import BaseClient logger = getLogger(__name__) @@ -18,11 +14,9 @@ class HuggingfaceClient(BaseClient): name = "huggingface" protocol = "huggingface://" - download: Any token: Optional[str] - def __init__(self, token: Optional[str] = None, download=hf_hub_download): - self.download = download + def __init__(self, token: Optional[str] = None): self.token = token def download( @@ -32,39 +26,22 @@ class HuggingfaceClient(BaseClient): source: str, format: Optional[str] = None, dest: Optional[str] = None, + embeds: bool = False, + **kwargs, ) -> str: - """ - TODO: download with auth - TODO: set fetch and filename - if network_type == "inversion" and network_model == "concept": - """ - hf_hub_fetch = True - hf_hub_filename = "learned_embeds.bin" - - cache_paths = build_cache_paths( - conversion, - name, - client=HuggingfaceClient.name, - format=format, - dest=dest, - ) - cached = get_first_exists(cache_paths) - if cached: - return cached - source = remove_prefix(source, HuggingfaceClient.protocol) logger.info("downloading model from Huggingface Hub: %s", source) - if hf_hub_fetch: - return ( - hf_hub_download( - repo_id=source, - filename=hf_hub_filename, - cache_dir=cache_paths[0], - force_filename=f"{name}.bin", - ), - False, + if embeds: + return hf_hub_download( + repo_id=source, + filename="learned_embeds.bin", + cache_dir=conversion.cache_path, + force_filename=f"{name}.bin", + token=self.token, ) else: - # TODO: download pretrained because load doesn't call from_pretrained anymore - return source + return snapshot_download( + repo_id=source, + token=self.token, + ) From 4c5bb906e89d6e47478529fc86203dd76074cadf Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 20:36:04 -0600 Subject: [PATCH 08/42] pass conversion context to generic sources --- api/onnx_web/convert/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 618ea6a5..f3c02953 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -415,7 +415,7 @@ def convert_models(conversion: ConversionContext, args, models: Models): logger.info("skipping source: %s", name) else: try: - convert_model_source(model) + convert_model_source(conversion, model) except Exception: logger.exception("error fetching source %s", name) model_errors.append(name) From 8bb76f1fee1afb841525de85c7897db30c1b3413 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 20:50:41 -0600 Subject: [PATCH 09/42] keep protocol when downloading from http sources --- api/onnx_web/convert/client/http.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index 9d9f840c..90a5e659 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -27,7 +27,7 @@ class HttpClient(BaseClient): self, conversion: ConversionContext, name: str, - uri: str, + source: str, format: Optional[str] = None, dest: Optional[str] = None, **kwargs, @@ -43,11 +43,9 @@ class HttpClient(BaseClient): if cached: return cached - if uri.startswith(HttpClient.protocol): - source = remove_prefix(uri, HttpClient.protocol) + if source.startswith(HttpClient.protocol): logger.info("downloading model from: %s", source) - elif uri.startswith(HttpClient.insecure_protocol): - source = remove_prefix(uri, HttpClient.insecure_protocol) + elif source.startswith(HttpClient.insecure_protocol): logger.warning("downloading model from insecure source: %s", source) return download_progress(source, cache_paths[0]) From 6cdba4cebbd0f21449283607b05b64114e8ad68f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 21:05:16 -0600 Subject: [PATCH 10/42] handle extensions correctly --- api/onnx_web/convert/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index fd02f671..015b13cd 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -384,7 +384,7 @@ def build_cache_paths( if format is not None: basename = path.basename(name) _filename, ext = path.splitext(basename) - if ext is None: + if ext is None or ext == '': name = f"{name}.{format}" paths = [ From ebb5a586ce83103ac5e8c55def2143603f576619 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 22:53:29 -0600 Subject: [PATCH 11/42] get file root path from context, avoid downloading entire HF repos --- api/onnx_web/convert/client/file.py | 9 ++------- api/onnx_web/convert/client/huggingface.py | 5 +---- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 5619b35b..7457e28b 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -12,14 +12,9 @@ logger = getLogger(__name__) class FileClient(BaseClient): protocol = "file://" - root: str - - def __init__(self, root: str): - self.root = root - def download( self, - _conversion: ConversionContext, + conversion: ConversionContext, _name: str, uri: str, format: Optional[str] = None, @@ -28,4 +23,4 @@ class FileClient(BaseClient): ) -> str: parts = urlparse(uri) logger.info("loading model from: %s", parts.path) - return path.join(self.root, parts.path) + return path.join(conversion.model_path, parts.path) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 6db97928..60f9e08f 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -41,7 +41,4 @@ class HuggingfaceClient(BaseClient): token=self.token, ) else: - return snapshot_download( - repo_id=source, - token=self.token, - ) + return source From e052578a20d24537293f7375545d7564f7fda112 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:03:41 -0600 Subject: [PATCH 12/42] lint, give civitai client a name --- api/onnx_web/convert/client/civitai.py | 2 ++ api/onnx_web/convert/client/http.py | 1 - api/onnx_web/convert/client/huggingface.py | 1 - api/onnx_web/convert/utils.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index 706e5c34..1ae5ecfd 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -16,7 +16,9 @@ CIVITAI_ROOT = "https://civitai.com/api/download/models/%s" class CivitaiClient(BaseClient): + name = "civitai" protocol = "civitai://" + root: str token: Optional[str] diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index 90a5e659..591e3a6c 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -6,7 +6,6 @@ from ..utils import ( build_cache_paths, download_progress, get_first_exists, - remove_prefix, ) from .base import BaseClient diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 60f9e08f..68a6a853 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -1,7 +1,6 @@ from logging import getLogger from typing import Optional -from huggingface_hub import snapshot_download from huggingface_hub.file_download import hf_hub_download from ..utils import ConversionContext, remove_prefix diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 015b13cd..7785007b 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -384,7 +384,7 @@ def build_cache_paths( if format is not None: basename = path.basename(name) _filename, ext = path.splitext(basename) - if ext is None or ext == '': + if ext is None or ext == "": name = f"{name}.{format}" paths = [ From 52e78748c874ed11c27010b2c3f2248a062fae48 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:05:41 -0600 Subject: [PATCH 13/42] keep model field optional --- api/onnx_web/convert/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index f3c02953..eefa6052 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -218,7 +218,6 @@ def convert_model_source(conversion: ConversionContext, model): def convert_model_network(conversion: ConversionContext, network): format = source_format(network) name = network["name"] - model = network["model"] network_type = network["type"] source = network["source"] @@ -237,6 +236,7 @@ def convert_model_network(conversion: ConversionContext, network): path.join(conversion.model_path, network_type, name), ) else: + model = network.get("model", None) dest = fetch_model( conversion, name, From 35c973e55f1736363415f2e9826118f6d1bbe4da Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:07:56 -0600 Subject: [PATCH 14/42] pass correct metadata to network converter --- api/onnx_web/convert/__main__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index eefa6052..bc9b1c75 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -215,11 +215,11 @@ def convert_model_source(conversion: ConversionContext, model): logger.info("finished downloading source: %s -> %s", source, dest) -def convert_model_network(conversion: ConversionContext, network): - format = source_format(network) - name = network["name"] - network_type = network["type"] - source = network["source"] +def convert_model_network(conversion: ConversionContext, model): + format = source_format(model) + name = model["name"] + network_type = model["type"] + source = model["source"] if network_type == "control": dest = fetch_model( @@ -231,12 +231,12 @@ def convert_model_network(conversion: ConversionContext, network): convert_diffusion_control( conversion, - network, + model, dest, path.join(conversion.model_path, network_type, name), ) else: - model = network.get("model", None) + model = model.get("model", None) dest = fetch_model( conversion, name, @@ -421,8 +421,8 @@ def convert_models(conversion: ConversionContext, args, models: Models): model_errors.append(name) if args.networks and "networks" in models: - for network in models.get("networks", []): - name = network["name"] + for model in models.get("networks", []): + name = model["name"] if name in args.skip: logger.info("skipping network: %s", name) From 22e05979169a1a52a4593b809d5755739763e526 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:09:00 -0600 Subject: [PATCH 15/42] pass kwargs on to client --- api/onnx_web/convert/client/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py index 166905e6..23a71177 100644 --- a/api/onnx_web/convert/client/__init__.py +++ b/api/onnx_web/convert/client/__init__.py @@ -35,6 +35,7 @@ def fetch_model( source: str, format: Optional[str] = None, dest: Optional[str] = None, + **kwargs, ) -> str: # TODO: switch to urlparse's default scheme if source.startswith(path.sep) or source.startswith("."): @@ -45,7 +46,7 @@ def fetch_model( if source.startswith(proto): # TODO: fix type of client_type client: BaseClient = client_type() - return client.download(conversion, name, source, format=format, dest=dest) + return client.download(conversion, name, source, format=format, dest=dest, **kwargs) logger.warning("unknown model protocol, using path as provided: %s", source) return source From 419b2811ef5cbcfe2b960e44171d2b6cd24cd575 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:40:27 -0600 Subject: [PATCH 16/42] feat: update min CFG for SDXL turbo --- README.md | 4 ++-- api/params.json | 2 +- gui/src/config.json | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ed10a950..14bc1351 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). This is an incomplete list of new and interesting features, with links to the user guide: -- SDXL support -- LCM support +- supports SDXL and SDXL Turbo support +- supports most schedulers: DDIM, DEIS, Euler, LCM, UniPC, and more - hardware acceleration on both AMD and Nvidia - tested on CUDA, DirectML, and ROCm - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both diff --git a/api/params.json b/api/params.json index d9cd6fa0..7b0c746f 100644 --- a/api/params.json +++ b/api/params.json @@ -14,7 +14,7 @@ }, "cfg": { "default": 6, - "min": 1, + "min": 0, "max": 30, "step": 0.1 }, diff --git a/gui/src/config.json b/gui/src/config.json index 91cccb8e..d2d63098 100644 --- a/gui/src/config.json +++ b/gui/src/config.json @@ -18,7 +18,7 @@ }, "cfg": { "default": 6, - "min": 1, + "min": 0, "max": 30, "step": 0.1 }, From 50db19922abfd88294bcbeb7159f19dcc93012fe Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:40:47 -0600 Subject: [PATCH 17/42] fix type of client instance --- api/onnx_web/convert/client/__init__.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py index 23a71177..6047b2b3 100644 --- a/api/onnx_web/convert/client/__init__.py +++ b/api/onnx_web/convert/client/__init__.py @@ -1,17 +1,18 @@ +from typing import Callable, Dict, Optional +from logging import getLogger +from os import path + +from ..utils import ConversionContext from .base import BaseClient from .civitai import CivitaiClient from .file import FileClient from .http import HttpClient from .huggingface import HuggingfaceClient -from ..utils import ConversionContext -from typing import Dict, Optional -from logging import getLogger -from os import path logger = getLogger(__name__) -model_sources: Dict[str, BaseClient] = { +model_sources: Dict[str, Callable[[], BaseClient]] = { CivitaiClient.protocol: CivitaiClient, FileClient.protocol: FileClient, HttpClient.insecure_protocol: HttpClient, @@ -44,9 +45,10 @@ def fetch_model( for proto, client_type in model_sources.items(): if source.startswith(proto): - # TODO: fix type of client_type - client: BaseClient = client_type() - return client.download(conversion, name, source, format=format, dest=dest, **kwargs) + client = client_type() + return client.download( + conversion, name, source, format=format, dest=dest, **kwargs + ) logger.warning("unknown model protocol, using path as provided: %s", source) return source From 2fc5ec930c30021dfdbe5b9e5b8b40da9346ae6f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:51:09 -0600 Subject: [PATCH 18/42] feat(api): add support for DPM SDE scheduler --- api/onnx_web/diffusers/load.py | 2 ++ api/onnx_web/diffusers/version_safe_diffusers.py | 5 +++++ gui/src/strings/en.ts | 1 + 3 files changed, 8 insertions(+) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 30a87863..2aaf2154 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -30,6 +30,7 @@ from .version_safe_diffusers import ( DDPMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler, + DPMSolverSDEScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, @@ -71,6 +72,7 @@ pipeline_schedulers = { "ddpm": DDPMScheduler, "deis-multi": DEISMultistepScheduler, "dpm-multi": DPMSolverMultistepScheduler, + "dpm-sde": DPMSolverSDEScheduler, "dpm-single": DPMSolverSinglestepScheduler, "euler": EulerDiscreteScheduler, "euler-a": EulerAncestralDiscreteScheduler, diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py index d256d615..3d51ea4b 100644 --- a/api/onnx_web/diffusers/version_safe_diffusers.py +++ b/api/onnx_web/diffusers/version_safe_diffusers.py @@ -12,6 +12,11 @@ try: except ImportError: from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler +try: + from diffusers import DPMSolverSDEScheduler +except: + from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler + try: from diffusers import LCMScheduler except ImportError: diff --git a/gui/src/strings/en.ts b/gui/src/strings/en.ts index c9d901f9..304c3335 100644 --- a/gui/src/strings/en.ts +++ b/gui/src/strings/en.ts @@ -287,6 +287,7 @@ export const I18N_STRINGS_EN = { 'ddpm': 'DDPM', 'deis-multi': 'DEIS Multistep', 'dpm-multi': 'DPM Multistep', + 'dpm-sde': 'DPM SDE (Turbo)', 'dpm-single': 'DPM Singlestep', 'euler': 'Euler', 'euler-a': 'Euler Ancestral', From d9c5c1bd45c5a2250e025afdd1bba5c955b41b42 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 9 Dec 2023 23:51:43 -0600 Subject: [PATCH 19/42] lint --- api/onnx_web/diffusers/version_safe_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/diffusers/version_safe_diffusers.py b/api/onnx_web/diffusers/version_safe_diffusers.py index 3d51ea4b..8c4da406 100644 --- a/api/onnx_web/diffusers/version_safe_diffusers.py +++ b/api/onnx_web/diffusers/version_safe_diffusers.py @@ -14,7 +14,7 @@ except ImportError: try: from diffusers import DPMSolverSDEScheduler -except: +except ImportError: from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler try: From b22e54eb27beb16590a3422a42dd95a5a282f55b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 00:10:34 -0600 Subject: [PATCH 20/42] add DPM SDE to readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 14bc1351..d4ca0164 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). This is an incomplete list of new and interesting features, with links to the user guide: - supports SDXL and SDXL Turbo support -- supports most schedulers: DDIM, DEIS, Euler, LCM, UniPC, and more +- supports most schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more - hardware acceleration on both AMD and Nvidia - tested on CUDA, DirectML, and ROCm - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both From 0f330652b523a3853e562a1bcea632dada15de27 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 00:19:20 -0600 Subject: [PATCH 21/42] phrasing and grammar --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d4ca0164..e36ae18a 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md). This is an incomplete list of new and interesting features, with links to the user guide: -- supports SDXL and SDXL Turbo support -- supports most schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more +- supports SDXL and SDXL Turbo +- wide variety of schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more - hardware acceleration on both AMD and Nvidia - tested on CUDA, DirectML, and ROCm - [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both From e6978243cc182f3ac858402731c7d0f682041e60 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 10:57:25 -0600 Subject: [PATCH 22/42] fix download location for networks --- api/onnx_web/convert/client/huggingface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 68a6a853..02044e50 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -35,7 +35,7 @@ class HuggingfaceClient(BaseClient): return hf_hub_download( repo_id=source, filename="learned_embeds.bin", - cache_dir=conversion.cache_path, + cache_dir=dest or conversion.cache_path, force_filename=f"{name}.bin", token=self.token, ) From 4edb32aaac1a4881d962d2caf1e90b947558f58f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 11:08:02 -0600 Subject: [PATCH 23/42] feat(api): add env vars for Civitai and Huggingface auth tokens --- api/onnx_web/convert/client/civitai.py | 19 +++++--- api/onnx_web/convert/client/file.py | 3 ++ api/onnx_web/convert/client/http.py | 4 +- api/onnx_web/convert/client/huggingface.py | 4 +- api/onnx_web/server/context.py | 56 +++++++++++----------- 5 files changed, 50 insertions(+), 36 deletions(-) diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index 1ae5ecfd..ebccb223 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -22,9 +22,14 @@ class CivitaiClient(BaseClient): root: str token: Optional[str] - def __init__(self, token: Optional[str] = None, root=CIVITAI_ROOT): - self.root = root - self.token = token + def __init__( + self, + conversion: ConversionContext, + token: Optional[str] = None, + root=CIVITAI_ROOT, + ): + self.root = conversion.get_setting("CIVITAI_ROOT", root) + self.token = conversion.get_setting("CIVITAI_TOKEN", token) def download( self, @@ -35,9 +40,6 @@ class CivitaiClient(BaseClient): dest: Optional[str] = None, **kwargs, ) -> str: - """ - TODO: download with auth token - """ cache_paths = build_cache_paths( conversion, name, @@ -51,4 +53,9 @@ class CivitaiClient(BaseClient): source = self.root % (remove_prefix(source, CivitaiClient.protocol)) logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0]) + + if self.token: + logger.debug("adding Civitai token authentication") + source = f"{source}?token={self.token}" + return download_progress(source, cache_paths[0]) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 7457e28b..0ae41b46 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -12,6 +12,9 @@ logger = getLogger(__name__) class FileClient(BaseClient): protocol = "file://" + def __init__(self, _conversion: ConversionContext): + pass + def download( self, conversion: ConversionContext, diff --git a/api/onnx_web/convert/client/http.py b/api/onnx_web/convert/client/http.py index 591e3a6c..151ebccd 100644 --- a/api/onnx_web/convert/client/http.py +++ b/api/onnx_web/convert/client/http.py @@ -19,7 +19,9 @@ class HttpClient(BaseClient): headers: Dict[str, str] - def __init__(self, headers: Optional[Dict[str, str]] = None): + def __init__( + self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None + ): self.headers = headers or {} def download( diff --git a/api/onnx_web/convert/client/huggingface.py b/api/onnx_web/convert/client/huggingface.py index 02044e50..c35698e7 100644 --- a/api/onnx_web/convert/client/huggingface.py +++ b/api/onnx_web/convert/client/huggingface.py @@ -15,8 +15,8 @@ class HuggingfaceClient(BaseClient): token: Optional[str] - def __init__(self, token: Optional[str] = None): - self.token = token + def __init__(self, conversion: ConversionContext, token: Optional[str] = None): + self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token) def download( self, diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 034fc3c6..66a8946a 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,7 +1,7 @@ from logging import getLogger from os import environ, path from secrets import token_urlsafe -from typing import List, Optional +from typing import Dict, List, Optional import torch @@ -42,6 +42,7 @@ class ServerContext: feature_flags: List[str] plugins: List[str] debug: bool + env: Dict[str, str] def __init__( self, @@ -67,6 +68,7 @@ class ServerContext: feature_flags: Optional[List[str]] = None, plugins: Optional[List[str]] = None, debug: bool = False, + env: Dict[str, str] = environ, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -90,49 +92,49 @@ class ServerContext: self.feature_flags = feature_flags or [] self.plugins = plugins or [] self.debug = debug + self.env = env self.cache = ModelCache(self.cache_limit) @classmethod - def from_environ(cls): - memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None) + def from_environ(cls, env=environ): + memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None) if memory_limit is not None: memory_limit = int(memory_limit) return cls( - bundle_path=environ.get( - "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") - ), - model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), - output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), - params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"), + bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")), + model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), + output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), + params_path=env.get("ONNX_WEB_PARAMS_PATH", "."), + cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"), any_platform=get_boolean( - environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM + env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM ), - block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), - default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), - cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), + block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"), + default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None), + image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), + cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), show_progress=get_boolean( - environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS + env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS ), - optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), - extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), - job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), + optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"), + extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"), + job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, - admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), - server_version=environ.get( - "ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION - ), + admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None), + server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION), worker_retries=int( - environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) + env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), - plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), - debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), + feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"), + plugins=get_list(env, "ONNX_WEB_PLUGINS", ""), + debug=get_boolean(env, "ONNX_WEB_DEBUG", False), ) + def get_setting(self, flag: str, default: str) -> Optional[str]: + return self.env.get(f"ONNX_WEB_{flag}", default) + def has_feature(self, flag: str) -> bool: return flag in self.feature_flags From 3b32cd4ac316351f308b390795280fc7f1940dbd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 11:45:52 -0600 Subject: [PATCH 24/42] sonar lint --- api/onnx_web/convert/__main__.py | 22 +++++++++++----------- api/onnx_web/convert/client/file.py | 3 +++ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index bc9b1c75..34be1420 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -216,24 +216,24 @@ def convert_model_source(conversion: ConversionContext, model): def convert_model_network(conversion: ConversionContext, model): - format = source_format(model) + model_format = source_format(model) + model_type = model["type"] name = model["name"] - network_type = model["type"] source = model["source"] - if network_type == "control": + if model_type == "control": dest = fetch_model( conversion, name, source, - format=format, + format=model_format, ) convert_diffusion_control( conversion, model, dest, - path.join(conversion.model_path, network_type, name), + path.join(conversion.model_path, model_type, name), ) else: model = model.get("model", None) @@ -241,9 +241,9 @@ def convert_model_network(conversion: ConversionContext, model): conversion, name, source, - dest=path.join(conversion.model_path, network_type), - format=format, - embeds=(network_type == "inversion" and model == "concept"), + dest=path.join(conversion.model_path, model_type), + format=model_format, + embeds=(model_type == "inversion" and model == "concept"), ) logger.info("finished downloading network: %s -> %s", source, dest) @@ -256,8 +256,8 @@ def convert_model_diffusion(conversion: ConversionContext, model): # update the model in-memory if the name changed model["name"] = name - format = source_format(model) - dest = fetch_model(conversion, name, model["source"], format=format) + model_format = source_format(model) + dest = fetch_model(conversion, name, model["source"], format=model_format) pipeline = model.get("pipeline", "txt2img") converter = model_converters.get(pipeline) @@ -265,7 +265,7 @@ def convert_model_diffusion(conversion: ConversionContext, model): conversion, model, dest, - format, + model_format, ) # make sure blending only happens once, not every run diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 0ae41b46..63077774 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -13,6 +13,9 @@ class FileClient(BaseClient): protocol = "file://" def __init__(self, _conversion: ConversionContext): + """ + Nothing to initialize for this client. + """ pass def download( From 9b883de1cb1d8f2a0e9f78b28d9bb4f14961812f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 12:16:01 -0600 Subject: [PATCH 25/42] pass context to client ctor --- api/onnx_web/convert/client/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/client/__init__.py b/api/onnx_web/convert/client/__init__.py index 6047b2b3..1d90f04e 100644 --- a/api/onnx_web/convert/client/__init__.py +++ b/api/onnx_web/convert/client/__init__.py @@ -45,7 +45,7 @@ def fetch_model( for proto, client_type in model_sources.items(): if source.startswith(proto): - client = client_type() + client = client_type(conversion) return client.download( conversion, name, source, format=format, dest=dest, **kwargs ) From e200fe9186e666a201bdc79296e9b08ded3cc8ac Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 13:07:02 -0600 Subject: [PATCH 26/42] avoid pickling environ --- api/onnx_web/server/context.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 66a8946a..e02cebd5 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -1,7 +1,7 @@ from logging import getLogger from os import environ, path from secrets import token_urlsafe -from typing import Dict, List, Optional +from typing import List, Optional import torch @@ -42,7 +42,6 @@ class ServerContext: feature_flags: List[str] plugins: List[str] debug: bool - env: Dict[str, str] def __init__( self, @@ -68,7 +67,6 @@ class ServerContext: feature_flags: Optional[List[str]] = None, plugins: Optional[List[str]] = None, debug: bool = False, - env: Dict[str, str] = environ, ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -92,7 +90,6 @@ class ServerContext: self.feature_flags = feature_flags or [] self.plugins = plugins or [] self.debug = debug - self.env = env self.cache = ModelCache(self.cache_limit) @@ -133,7 +130,7 @@ class ServerContext: ) def get_setting(self, flag: str, default: str) -> Optional[str]: - return self.env.get(f"ONNX_WEB_{flag}", default) + return environ.get(f"ONNX_WEB_{flag}", default) def has_feature(self, flag: str) -> bool: return flag in self.feature_flags From e91e08484bf333d5400ab4fcfe916f829039a8a1 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 13:17:37 -0600 Subject: [PATCH 27/42] handle query params better in civitai client --- api/onnx_web/convert/client/civitai.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index ebccb223..6bbb1be5 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -56,6 +56,9 @@ class CivitaiClient(BaseClient): if self.token: logger.debug("adding Civitai token authentication") - source = f"{source}?token={self.token}" + if "?" in source: + source = f"{source}&token={self.token}" + else: + source = f"{source}?token={self.token}" return download_progress(source, cache_paths[0]) From c9b1df9fdd0517b33230cac9b0e271ef66da3404 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 13:30:27 -0600 Subject: [PATCH 28/42] use dest in file client if provided --- api/onnx_web/convert/client/civitai.py | 2 +- api/onnx_web/convert/client/file.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert/client/civitai.py b/api/onnx_web/convert/client/civitai.py index 6bbb1be5..b57b6533 100644 --- a/api/onnx_web/convert/client/civitai.py +++ b/api/onnx_web/convert/client/civitai.py @@ -26,7 +26,7 @@ class CivitaiClient(BaseClient): self, conversion: ConversionContext, token: Optional[str] = None, - root=CIVITAI_ROOT, + root: str = CIVITAI_ROOT, ): self.root = conversion.get_setting("CIVITAI_ROOT", root) self.token = conversion.get_setting("CIVITAI_TOKEN", token) diff --git a/api/onnx_web/convert/client/file.py b/api/onnx_web/convert/client/file.py index 63077774..8ed1a4d9 100644 --- a/api/onnx_web/convert/client/file.py +++ b/api/onnx_web/convert/client/file.py @@ -29,4 +29,4 @@ class FileClient(BaseClient): ) -> str: parts = urlparse(uri) logger.info("loading model from: %s", parts.path) - return path.join(conversion.model_path, parts.path) + return path.join(dest or conversion.model_path, parts.path) From 9c1fcd16fa31318b85aa5e42175c93e4421d1f89 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 13:52:52 -0600 Subject: [PATCH 29/42] fix(api): only fetch diffusion models if they have not already been converted (#398) --- api/onnx_web/convert/__main__.py | 2 - api/onnx_web/convert/diffusion/diffusion.py | 88 +++++++++++-------- .../convert/diffusion/diffusion_xl.py | 9 +- 3 files changed, 55 insertions(+), 44 deletions(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 34be1420..1a4f0241 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -257,14 +257,12 @@ def convert_model_diffusion(conversion: ConversionContext, model): model["name"] = name model_format = source_format(model) - dest = fetch_model(conversion, name, model["source"], format=model_format) pipeline = model.get("pipeline", "txt2img") converter = model_converters.get(pipeline) converted, dest = converter( conversion, model, - dest, model_format, ) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index dbb50610..aea560d9 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline from ...diffusers.version_safe_diffusers import AttnProcessor from ...models.cnet import UNet2DConditionModel_CNet from ...utils import run_gc +from ..client import fetch_model +from ..client.huggingface import HuggingfaceClient from ..utils import ( RESOLVE_FORMATS, ConversionContext, @@ -43,6 +45,7 @@ from ..utils import ( is_torch_2_0, load_tensor, onnx_export, + remove_prefix, ) from .checkpoint import convert_extract_checkpoint @@ -267,14 +270,13 @@ def collate_cnet(cnet_path): def convert_diffusion_diffusers( conversion: ConversionContext, model: Dict, - source: str, format: Optional[str], - hf: bool = False, ) -> Tuple[bool, str]: """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ name = model.get("name") + source = model.get("source") # optional config = model.get("config", None) @@ -320,9 +322,11 @@ def convert_diffusion_diffusers( logger.info("ONNX model already exists, skipping") return (False, dest_path) + cache_path = fetch_model(conversion, name, source, format=format) + pipe_class = CONVERT_PIPELINES.get(pipe_type) v2, pipe_args = get_model_version( - source, conversion.map_location, size=image_size, version=version + cache_path, conversion.map_location, size=image_size, version=version ) is_inpainting = False @@ -334,50 +338,58 @@ def convert_diffusion_diffusers( pipe_args["from_safetensors"] = True torch_source = None - if path.exists(source) and path.isdir(source): - logger.debug("loading pipeline from diffusers directory: %s", source) - pipeline = pipe_class.from_pretrained( - source, - torch_dtype=dtype, - use_auth_token=conversion.token, - ).to(device) - elif path.exists(source) and path.isfile(source): - if conversion.extract: - logger.debug("extracting SD checkpoint to Torch models: %s", source) - torch_source = convert_extract_checkpoint( - conversion, - source, - f"{name}-torch", - is_inpainting=is_inpainting, - config_file=config, - vae_file=replace_vae, - ) - logger.debug("loading pipeline from extracted checkpoint: %s", torch_source) + if path.exists(cache_path): + if path.isdir(cache_path): + logger.debug("loading pipeline from diffusers directory: %s", source) pipeline = pipe_class.from_pretrained( - torch_source, + cache_path, torch_dtype=dtype, + use_auth_token=conversion.token, ).to(device) + elif path.isfile(source): + if conversion.extract: + logger.debug("extracting SD checkpoint to Torch models: %s", source) + torch_source = convert_extract_checkpoint( + conversion, + source, + f"{name}-torch", + is_inpainting=is_inpainting, + config_file=config, + vae_file=replace_vae, + ) + logger.debug( + "loading pipeline from extracted checkpoint: %s", torch_source + ) + pipeline = pipe_class.from_pretrained( + torch_source, + torch_dtype=dtype, + ).to(device) - # VAE replacement already happened during extraction, skip - replace_vae = None - else: - logger.debug("loading pipeline from SD checkpoint: %s", source) - pipeline = download_from_original_stable_diffusion_ckpt( - source, - original_config_file=config_path, - pipeline_class=pipe_class, - **pipe_args, - ).to(device, torch_dtype=dtype) - elif hf: - logger.debug("downloading pretrained model from Huggingface hub: %s", source) + # VAE replacement already happened during extraction, skip + replace_vae = None + else: + logger.debug("loading pipeline from SD checkpoint: %s", source) + pipeline = download_from_original_stable_diffusion_ckpt( + source, + original_config_file=config_path, + pipeline_class=pipe_class, + **pipe_args, + ).to(device, torch_dtype=dtype) + elif source.startswith(HuggingfaceClient.protocol): + hf_path = remove_prefix(source, HuggingfaceClient.protocol) + logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path) pipeline = pipe_class.from_pretrained( - source, + hf_path, torch_dtype=dtype, use_auth_token=conversion.token, ).to(device) else: - logger.warning("pipeline source not found or not recognized: %s", source) - raise ValueError(f"pipeline source not found or not recognized: {source}") + logger.warning( + "pipeline source not found and protocol not recognized: %s", source + ) + raise ValueError( + f"pipeline source not found and protocol not recognized: {source}" + ) if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 8370d302..d9319596 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16 from optimum.exporters.onnx import main_export from ...constants import ONNX_MODEL +from ..client import fetch_model from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext logger = getLogger(__name__) @@ -19,14 +20,13 @@ logger = getLogger(__name__) def convert_diffusion_diffusers_xl( conversion: ConversionContext, model: Dict, - source: str, format: Optional[str], - hf: bool = False, ) -> Tuple[bool, str]: """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ name = model.get("name") + source = model.get("source") replace_vae = model.get("vae", None) device = conversion.training_device @@ -52,15 +52,16 @@ def convert_diffusion_diffusers_xl( return (False, dest_path) + cache_path = fetch_model(conversion, name, model["source"], format=format) # safetensors -> diffusers directory with torch models temp_path = path.join(conversion.cache_path, f"{name}-torch") if format == "safetensors": pipeline = StableDiffusionXLPipeline.from_single_file( - source, use_safetensors=True + cache_path, use_safetensors=True ) else: - pipeline = StableDiffusionXLPipeline.from_pretrained(source) + pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path) if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) From 4da4cd95a55bd52f1b733d087b251156aafb4000 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 13:59:47 -0600 Subject: [PATCH 30/42] fix(api): trim whitespace from model names because it breaks things (#376) --- api/onnx_web/convert/diffusion/diffusion.py | 2 +- api/onnx_web/convert/diffusion/diffusion_xl.py | 2 +- api/onnx_web/server/load.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index aea560d9..296f7351 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -275,7 +275,7 @@ def convert_diffusion_diffusers( """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() source = model.get("source") # optional diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index d9319596..c09b9440 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -25,7 +25,7 @@ def convert_diffusion_diffusers_xl( """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() source = model.get("source") replace_vae = model.get("vae", None) diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 6bf1de2d..2dd37157 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union import torch from jsonschema import ValidationError, validate +from ..convert.utils import fix_diffusion_name from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -189,6 +190,9 @@ def load_extras(server: ServerContext): for model in data[model_type]: model_name = model["name"] + if model_type == "diffusion": + model_name = fix_diffusion_name(model_name) + if "hash" in model: logger.debug( "collecting hash for model %s from %s", From 0155236744ae8bb2d036f9a90987fe1e1043f871 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 14:10:09 -0600 Subject: [PATCH 31/42] fix(api): update flask and pin werkzeug (#414) --- api/requirements/base.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/requirements/base.txt b/api/requirements/base.txt index 9d2b8ddf..f7bb85aa 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -29,10 +29,11 @@ realesrgan==0.3.0 ### Server packages ### arpeggio==2.0.0 boto3==1.26.69 -flask==2.2.2 +flask==3.0.0 flask-cors==3.0.10 jsonschema==4.17.3 piexif==1.1.3 pyyaml==6.0 setproctitle==1.3.2 waitress==2.1.2 +werkzeug==3.0.1 From 2a641b111ec827477bf811134246163979fb6d7e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 10 Dec 2023 17:03:54 -0600 Subject: [PATCH 32/42] fix(api): check for web UI files in Windows launch scripts --- api/launch.bat | 8 ++++++++ api/launch.ps1 | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/api/launch.bat b/api/launch.bat index 0d5933e5..680de5d9 100644 --- a/api/launch.bat +++ b/api/launch.bat @@ -1,5 +1,7 @@ call onnx_env\Scripts\Activate.bat +echo "This launch.bat script is deprecated in favor of launch.ps1 and will be removed in a future release." + echo "Downloading and converting models to ONNX format..." IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json) python -m onnx_web.convert ^ @@ -10,6 +12,12 @@ python -m onnx_web.convert ^ --extras=%ONNX_WEB_EXTRA_MODELS% ^ --token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS% +IF NOT EXIST .\gui\index.html ( + echo "Please make sure you have downloaded the web UI files from https://github.com/ssube/onnx-web/tree/gh-pages" + echo "See https://github.com/ssube/onnx-web/blob/main/docs/setup-guide.md#download-the-web-ui-bundle for more information" + pause +) + echo "Launching API server..." waitress-serve ^ --host=0.0.0.0 ^ diff --git a/api/launch.ps1 b/api/launch.ps1 index 19add137..ba898151 100644 --- a/api/launch.ps1 +++ b/api/launch.ps1 @@ -10,6 +10,13 @@ python -m onnx_web.convert ` --extras=$Env:ONNX_WEB_EXTRA_MODELS ` --token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS +if (!(Test-Path -path .\gui\index.html -PathType Leaf)) { + echo "Downloading latest web UI files from Github..." + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/index.html" -OutFile .\gui\index.html + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/config.json" -OutFile .\gui\config.json + Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/bundle/main.js" -OutFile .\gui\bundle\main.js +} + echo "Launching API server..." waitress-serve ` --host=0.0.0.0 ` From 95d8f4a5986f27255828481ff0b7c5f485584ce3 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 20:04:08 -0600 Subject: [PATCH 33/42] fix(api): make sure all file types are covered (#432) --- api/onnx_web/convert/diffusion/diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 296f7351..da16a5e9 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -346,7 +346,7 @@ def convert_diffusion_diffusers( torch_dtype=dtype, use_auth_token=conversion.token, ).to(device) - elif path.isfile(source): + else: if conversion.extract: logger.debug("extracting SD checkpoint to Torch models: %s", source) torch_source = convert_extract_checkpoint( From 6c4f4f334f4705ff8d91b0ead2b7d0caea6ce53a Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 20:12:39 -0600 Subject: [PATCH 34/42] always use alpha in blend stage --- api/onnx_web/chain/blend_mask.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 4486bbf6..326ba162 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage): ) -> StageResult: logger.info("blending image using mask") - # TODO: does this need an alpha channel? mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") - mult_mask.alpha_composite(stage_mask) + mult_mask = Image.alpha_composite(mult_mask, stage_mask) mult_mask = mult_mask.convert("L") if is_debug(): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return StageResult( - images=[ + return StageResult.from_images( + [ Image.composite(stage_source, source, mult_mask) for source in sources.as_image() ] From c6de25682dbe9cdafffa5aed3e47510c84a6ce9b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 20:13:58 -0600 Subject: [PATCH 35/42] feat(api): upgrade to latest diffusers, optimum, transformers (#433) --- api/requirements/base.txt | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/api/requirements/base.txt b/api/requirements/base.txt index f7bb85aa..822ef4e7 100644 --- a/api/requirements/base.txt +++ b/api/requirements/base.txt @@ -3,21 +3,21 @@ numpy==1.23.5 protobuf==3.20.3 ### SD packages ### -accelerate==0.22.0 +accelerate==0.25.0 coloredlogs==15.0.1 controlnet_aux==0.0.2 -datasets==2.14.3 -diffusers==0.20.0 +datasets==2.15.0 +diffusers==0.24.0 huggingface-hub==0.16.4 invisible-watermark==0.2.0 mediapipe==0.9.2.1 omegaconf==2.3.0 onnx==1.13.0 # onnxruntime has many platform-specific packages -optimum==1.12.0 -safetensors==0.3.1 -timm==0.6.13 -transformers==4.32.0 +optimum==1.16.0 +safetensors==0.4.1 +timm==0.9.12 +transformers==4.36.1 #### Upscaling and face correction basicsr==1.4.2 From d86286ab1e956935f200f2ab65bb82811a7ab7bf Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 20:46:43 -0600 Subject: [PATCH 36/42] lint(gui): split up state types --- gui/src/state.ts | 172 +++------------------------------------ gui/src/state/blend.ts | 21 +++++ gui/src/state/default.ts | 12 +++ gui/src/state/history.ts | 18 ++++ gui/src/state/img2img.ts | 22 +++++ gui/src/state/inpaint.ts | 27 ++++++ gui/src/state/models.ts | 19 +++++ gui/src/state/profile.ts | 17 ++++ gui/src/state/reset.ts | 3 + gui/src/state/txt2img.ts | 24 ++++++ gui/src/state/types.ts | 11 +++ gui/src/state/upscale.ts | 21 +++++ 12 files changed, 207 insertions(+), 160 deletions(-) create mode 100644 gui/src/state/blend.ts create mode 100644 gui/src/state/default.ts create mode 100644 gui/src/state/history.ts create mode 100644 gui/src/state/img2img.ts create mode 100644 gui/src/state/inpaint.ts create mode 100644 gui/src/state/models.ts create mode 100644 gui/src/state/profile.ts create mode 100644 gui/src/state/reset.ts create mode 100644 gui/src/state/txt2img.ts create mode 100644 gui/src/state/types.ts create mode 100644 gui/src/state/upscale.ts diff --git a/gui/src/state.ts b/gui/src/state.ts index ad2a3e3a..fd9a21a0 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -2,7 +2,6 @@ /* eslint-disable max-lines */ /* eslint-disable no-null/no-null */ import { Maybe } from '@apextoaster/js-utils'; -import { PaletteMode } from '@mui/material'; import { Logger } from 'noicejs'; import { createContext } from 'react'; import { StateCreator, StoreApi } from 'zustand'; @@ -11,170 +10,24 @@ import { ApiClient, } from './client/base.js'; import { PipelineGrid } from './client/utils.js'; -import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js'; -import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types/model.js'; -import { ImageResponse, ReadyResponse, RetryParams } from './types/api.js'; +import { Config, ServerParams } from './config.js'; import { BaseImgParams, - BlendParams, - BrushParams, HighresParams, - Img2ImgParams, - InpaintParams, ModelParams, - OutpaintPixels, - Txt2ImgParams, UpscaleParams, - UpscaleReqParams, } from './types/params.js'; - -export const MISSING_INDEX = -1; - -export type Theme = PaletteMode | ''; // tri-state, '' is unset - -/** - * Combine optional files and required ranges. - */ -export type TabState = ConfigFiles> & ConfigState>; - -export interface HistoryItem { - image: ImageResponse; - ready: Maybe; - retry: Maybe; -} - -export interface ProfileItem { - name: string; - params: BaseImgParams | Txt2ImgParams; - highres?: Maybe; - upscale?: Maybe; -} - -interface DefaultSlice { - defaults: TabState; - theme: Theme; - - setDefaults(param: Partial): void; - setTheme(theme: Theme): void; -} - -interface HistorySlice { - history: Array; - limit: number; - - pushHistory(image: ImageResponse, retry?: RetryParams): void; - removeHistory(image: ImageResponse): void; - setLimit(limit: number): void; - setReady(image: ImageResponse, ready: ReadyResponse): void; -} - -interface ModelSlice { - extras: ExtrasFile; - - removeCorrectionModel(model: CorrectionModel): void; - removeDiffusionModel(model: DiffusionModel): void; - removeExtraNetwork(model: ExtraNetwork): void; - removeExtraSource(model: ExtraSource): void; - removeUpscalingModel(model: UpscalingModel): void; - - setExtras(extras: Partial): void; - - setCorrectionModel(model: CorrectionModel): void; - setDiffusionModel(model: DiffusionModel): void; - setExtraNetwork(model: ExtraNetwork): void; - setExtraSource(model: ExtraSource): void; - setUpscalingModel(model: UpscalingModel): void; -} - -// #region tab slices -interface Txt2ImgSlice { - txt2img: TabState; - txt2imgModel: ModelParams; - txt2imgHighres: HighresParams; - txt2imgUpscale: UpscaleParams; - txt2imgVariable: PipelineGrid; - - resetTxt2Img(): void; - - setTxt2Img(params: Partial): void; - setTxt2ImgModel(params: Partial): void; - setTxt2ImgHighres(params: Partial): void; - setTxt2ImgUpscale(params: Partial): void; - setTxt2ImgVariable(params: Partial): void; -} - -interface Img2ImgSlice { - img2img: TabState; - img2imgModel: ModelParams; - img2imgHighres: HighresParams; - img2imgUpscale: UpscaleParams; - - resetImg2Img(): void; - - setImg2Img(params: Partial): void; - setImg2ImgModel(params: Partial): void; - setImg2ImgHighres(params: Partial): void; - setImg2ImgUpscale(params: Partial): void; -} - -interface InpaintSlice { - inpaint: TabState; - inpaintBrush: BrushParams; - inpaintModel: ModelParams; - inpaintHighres: HighresParams; - inpaintUpscale: UpscaleParams; - outpaint: OutpaintPixels; - - resetInpaint(): void; - - setInpaint(params: Partial): void; - setInpaintBrush(brush: Partial): void; - setInpaintModel(params: Partial): void; - setInpaintHighres(params: Partial): void; - setInpaintUpscale(params: Partial): void; - setOutpaint(pixels: Partial): void; -} - -interface UpscaleSlice { - upscale: TabState; - upscaleHighres: HighresParams; - upscaleModel: ModelParams; - upscaleUpscale: UpscaleParams; - - resetUpscale(): void; - - setUpscale(params: Partial): void; - setUpscaleHighres(params: Partial): void; - setUpscaleModel(params: Partial): void; - setUpscaleUpscale(params: Partial): void; -} - -interface BlendSlice { - blend: TabState; - blendBrush: BrushParams; - blendModel: ModelParams; - blendUpscale: UpscaleParams; - - resetBlend(): void; - - setBlend(blend: Partial): void; - setBlendBrush(brush: Partial): void; - setBlendModel(model: Partial): void; - setBlendUpscale(params: Partial): void; -} - -interface ResetSlice { - resetAll(): void; -} - -interface ProfileSlice { - profiles: Array; - - removeProfile(profileName: string): void; - - saveProfile(profile: ProfileItem): void; -} -// #endregion +import { DefaultSlice } from './state/default.js'; +import { HistorySlice } from './state/history.js'; +import { Img2ImgSlice } from './state/img2img.js'; +import { InpaintSlice } from './state/inpaint.js'; +import { ModelSlice } from './state/models.js'; +import { Txt2ImgSlice } from './state/txt2img.js'; +import { UpscaleSlice } from './state/upscale.js'; +import { ResetSlice } from './state/reset.js'; +import { ProfileItem, ProfileSlice } from './state/profile.js'; +import { BlendSlice } from './state/blend.js'; +import { MISSING_INDEX } from './state/types.js'; /** * Full merged state including all slices. @@ -189,7 +42,6 @@ export type OnnxState & UpscaleSlice & BlendSlice & ResetSlice - & ModelSlice & ProfileSlice; /** diff --git a/gui/src/state/blend.ts b/gui/src/state/blend.ts new file mode 100644 index 00000000..42d292f3 --- /dev/null +++ b/gui/src/state/blend.ts @@ -0,0 +1,21 @@ +import { + BlendParams, + BrushParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; +import { TabState } from './types.js'; + +export interface BlendSlice { + blend: TabState; + blendBrush: BrushParams; + blendModel: ModelParams; + blendUpscale: UpscaleParams; + + resetBlend(): void; + + setBlend(blend: Partial): void; + setBlendBrush(brush: Partial): void; + setBlendModel(model: Partial): void; + setBlendUpscale(params: Partial): void; +} diff --git a/gui/src/state/default.ts b/gui/src/state/default.ts new file mode 100644 index 00000000..b43b7ceb --- /dev/null +++ b/gui/src/state/default.ts @@ -0,0 +1,12 @@ +import { + BaseImgParams, +} from '../types/params.js'; +import { TabState, Theme } from './types.js'; + +export interface DefaultSlice { + defaults: TabState; + theme: Theme; + + setDefaults(param: Partial): void; + setTheme(theme: Theme): void; +} diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts new file mode 100644 index 00000000..44f7b2f3 --- /dev/null +++ b/gui/src/state/history.ts @@ -0,0 +1,18 @@ +import { Maybe } from '@apextoaster/js-utils'; +import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; + +export interface HistoryItem { + image: ImageResponse; + ready: Maybe; + retry: Maybe; +} + +export interface HistorySlice { + history: Array; + limit: number; + + pushHistory(image: ImageResponse, retry?: RetryParams): void; + removeHistory(image: ImageResponse): void; + setLimit(limit: number): void; + setReady(image: ImageResponse, ready: ReadyResponse): void; +} diff --git a/gui/src/state/img2img.ts b/gui/src/state/img2img.ts new file mode 100644 index 00000000..cbe204d1 --- /dev/null +++ b/gui/src/state/img2img.ts @@ -0,0 +1,22 @@ + +import { + HighresParams, + Img2ImgParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; +import { TabState } from './types.js'; + +export interface Img2ImgSlice { + img2img: TabState; + img2imgModel: ModelParams; + img2imgHighres: HighresParams; + img2imgUpscale: UpscaleParams; + + resetImg2Img(): void; + + setImg2Img(params: Partial): void; + setImg2ImgModel(params: Partial): void; + setImg2ImgHighres(params: Partial): void; + setImg2ImgUpscale(params: Partial): void; +} diff --git a/gui/src/state/inpaint.ts b/gui/src/state/inpaint.ts new file mode 100644 index 00000000..12756ac2 --- /dev/null +++ b/gui/src/state/inpaint.ts @@ -0,0 +1,27 @@ +import { + BrushParams, + HighresParams, + InpaintParams, + ModelParams, + OutpaintPixels, + UpscaleParams, +} from '../types/params.js'; +import { TabState } from './types.js'; + +export interface InpaintSlice { + inpaint: TabState; + inpaintBrush: BrushParams; + inpaintModel: ModelParams; + inpaintHighres: HighresParams; + inpaintUpscale: UpscaleParams; + outpaint: OutpaintPixels; + + resetInpaint(): void; + + setInpaint(params: Partial): void; + setInpaintBrush(brush: Partial): void; + setInpaintModel(params: Partial): void; + setInpaintHighres(params: Partial): void; + setInpaintUpscale(params: Partial): void; + setOutpaint(pixels: Partial): void; +} diff --git a/gui/src/state/models.ts b/gui/src/state/models.ts new file mode 100644 index 00000000..e0faa27d --- /dev/null +++ b/gui/src/state/models.ts @@ -0,0 +1,19 @@ +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; + +export interface ModelSlice { + extras: ExtrasFile; + + removeCorrectionModel(model: CorrectionModel): void; + removeDiffusionModel(model: DiffusionModel): void; + removeExtraNetwork(model: ExtraNetwork): void; + removeExtraSource(model: ExtraSource): void; + removeUpscalingModel(model: UpscalingModel): void; + + setExtras(extras: Partial): void; + + setCorrectionModel(model: CorrectionModel): void; + setDiffusionModel(model: DiffusionModel): void; + setExtraNetwork(model: ExtraNetwork): void; + setExtraSource(model: ExtraSource): void; + setUpscalingModel(model: UpscalingModel): void; +} diff --git a/gui/src/state/profile.ts b/gui/src/state/profile.ts new file mode 100644 index 00000000..7ffdfac9 --- /dev/null +++ b/gui/src/state/profile.ts @@ -0,0 +1,17 @@ +import { Maybe } from '@apextoaster/js-utils'; +import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; + +export interface ProfileItem { + name: string; + params: BaseImgParams | Txt2ImgParams; + highres?: Maybe; + upscale?: Maybe; +} + +export interface ProfileSlice { + profiles: Array; + + removeProfile(profileName: string): void; + + saveProfile(profile: ProfileItem): void; +} diff --git a/gui/src/state/reset.ts b/gui/src/state/reset.ts new file mode 100644 index 00000000..66b545a5 --- /dev/null +++ b/gui/src/state/reset.ts @@ -0,0 +1,3 @@ +export interface ResetSlice { + resetAll(): void; +} diff --git a/gui/src/state/txt2img.ts b/gui/src/state/txt2img.ts new file mode 100644 index 00000000..d8cd95eb --- /dev/null +++ b/gui/src/state/txt2img.ts @@ -0,0 +1,24 @@ +import { PipelineGrid } from '../client/utils.js'; +import { + HighresParams, + ModelParams, + Txt2ImgParams, + UpscaleParams, +} from '../types/params.js'; +import { TabState } from './types.js'; + +export interface Txt2ImgSlice { + txt2img: TabState; + txt2imgModel: ModelParams; + txt2imgHighres: HighresParams; + txt2imgUpscale: UpscaleParams; + txt2imgVariable: PipelineGrid; + + resetTxt2Img(): void; + + setTxt2Img(params: Partial): void; + setTxt2ImgModel(params: Partial): void; + setTxt2ImgHighres(params: Partial): void; + setTxt2ImgUpscale(params: Partial): void; + setTxt2ImgVariable(params: Partial): void; +} diff --git a/gui/src/state/types.ts b/gui/src/state/types.ts new file mode 100644 index 00000000..98843c86 --- /dev/null +++ b/gui/src/state/types.ts @@ -0,0 +1,11 @@ +import { PaletteMode } from '@mui/material'; +import { ConfigFiles, ConfigState } from '../config.js'; + +export const MISSING_INDEX = -1; + +export type Theme = PaletteMode | ''; // tri-state, '' is unset + +/** + * Combine optional files and required ranges. + */ +export type TabState = ConfigFiles> & ConfigState>; diff --git a/gui/src/state/upscale.ts b/gui/src/state/upscale.ts new file mode 100644 index 00000000..af0a344a --- /dev/null +++ b/gui/src/state/upscale.ts @@ -0,0 +1,21 @@ +import { + HighresParams, + ModelParams, + UpscaleParams, + UpscaleReqParams, +} from '../types/params.js'; +import { TabState } from './types.js'; + +export interface UpscaleSlice { + upscale: TabState; + upscaleHighres: HighresParams; + upscaleModel: ModelParams; + upscaleUpscale: UpscaleParams; + + resetUpscale(): void; + + setUpscale(params: Partial): void; + setUpscaleHighres(params: Partial): void; + setUpscaleModel(params: Partial): void; + setUpscaleUpscale(params: Partial): void; +} From 0dfc1b61d20eb3323c045878773b3953726a7d2f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 21:30:06 -0600 Subject: [PATCH 37/42] break up state slice factories --- gui/src/components/ImageHistory.tsx | 2 +- gui/src/components/OnnxError.tsx | 2 +- gui/src/components/OnnxWeb.tsx | 2 +- gui/src/components/Profiles.tsx | 2 +- gui/src/components/card/ErrorCard.tsx | 2 +- gui/src/components/card/ImageCard.tsx | 2 +- gui/src/components/card/LoadingCard.tsx | 2 +- gui/src/components/control/HighresControl.tsx | 2 +- gui/src/components/control/ImageControl.tsx | 2 +- gui/src/components/control/ModelControl.tsx | 2 +- .../components/control/OutpaintControl.tsx | 2 +- gui/src/components/control/UpscaleControl.tsx | 2 +- .../components/control/VariableControl.tsx | 2 +- gui/src/components/input/EditableList.tsx | 2 +- gui/src/components/input/MaskCanvas.tsx | 2 +- gui/src/components/input/PromptInput.tsx | 2 +- gui/src/components/tab/Blend.tsx | 3 +- gui/src/components/tab/Img2Img.tsx | 3 +- gui/src/components/tab/Inpaint.tsx | 3 +- gui/src/components/tab/Models.tsx | 2 +- gui/src/components/tab/Settings.tsx | 2 +- gui/src/components/tab/Txt2Img.tsx | 3 +- gui/src/components/tab/Upscale.tsx | 3 +- gui/src/components/utils.ts | 2 +- gui/src/main.tsx | 2 +- gui/src/state.ts | 817 ------------------ gui/src/state/blend.ts | 65 +- gui/src/state/default.ts | 24 +- gui/src/state/full.ts | 152 ++++ gui/src/state/history.ts | 49 ++ gui/src/state/img2img.ts | 78 +- gui/src/state/inpaint.ts | 112 ++- gui/src/state/model.ts | 202 +++++ gui/src/state/models.ts | 19 - gui/src/state/profile.ts | 35 + gui/src/state/reset.ts | 25 + gui/src/state/txt2img.ts | 83 +- gui/src/state/types.ts | 35 + gui/src/state/upscale.ts | 68 +- 39 files changed, 951 insertions(+), 868 deletions(-) delete mode 100644 gui/src/state.ts create mode 100644 gui/src/state/full.ts create mode 100644 gui/src/state/model.ts delete mode 100644 gui/src/state/models.ts diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index 0f6b0f2f..20a520ed 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ErrorCard } from './card/ErrorCard.js'; import { ImageCard } from './card/ImageCard.js'; import { LoadingCard } from './card/LoadingCard.js'; diff --git a/gui/src/components/OnnxError.tsx b/gui/src/components/OnnxError.tsx index da6d7147..05186768 100644 --- a/gui/src/components/OnnxError.tsx +++ b/gui/src/components/OnnxError.tsx @@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material'; import * as React from 'react'; import { ReactNode } from 'react'; -import { STATE_KEY } from '../state.js'; +import { STATE_KEY } from '../state/full.js'; import { Logo } from './Logo.js'; export interface OnnxErrorProps { diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index 918ee84f..69c2db20 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageHistory } from './ImageHistory.js'; import { Logo } from './Logo.js'; import { Blend } from './tab/Blend.js'; diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index baf022b3..de8f2263 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageMetadata } from '../types/api.js'; import { DeepPartial } from '../types/model.js'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index bb3ac6c9..f7106584 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; export interface ErrorCardProps { diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index 7c35acb2..44f5dd91 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -8,7 +8,7 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; import { range, visibleIndex } from '../../utils.js'; diff --git a/gui/src/components/card/LoadingCard.tsx b/gui/src/components/card/LoadingCard.tsx index e0fcdb68..71bfb5f0 100644 --- a/gui/src/components/card/LoadingCard.tsx +++ b/gui/src/components/card/LoadingCard.tsx @@ -9,7 +9,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { POLL_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; const LOADING_PERCENT = 100; diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index 91525b21..62fa63fd 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { HighresParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 8271c700..8877ae50 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -11,7 +11,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { BaseImgParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; import { PromptInput } from '../input/PromptInput.js'; diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index ba08998f..54b99846 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -6,7 +6,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { STALE_TIME } from '../../config.js'; -import { ClientContext } from '../../state.js'; +import { ClientContext } from '../../state/full.js'; import { ModelParams } from '../../types/params.js'; import { QueryList } from '../input/QueryList.js'; diff --git a/gui/src/components/control/OutpaintControl.tsx b/gui/src/components/control/OutpaintControl.tsx index 5b70cc50..69523477 100644 --- a/gui/src/components/control/OutpaintControl.tsx +++ b/gui/src/components/control/OutpaintControl.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { NumericField } from '../input/NumericField.js'; export function OutpaintControl() { diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index a90dd330..5d840a04 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { UpscaleParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/VariableControl.tsx b/gui/src/components/control/VariableControl.tsx index 32a66660..bf5139e1 100644 --- a/gui/src/components/control/VariableControl.tsx +++ b/gui/src/components/control/VariableControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useStore } from 'zustand'; import { PipelineGrid } from '../../client/utils.js'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; import { VARIABLE_PARAMETERS } from '../../types/chain.js'; export interface VariableControlProps { diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index 3910e394..a6d45aca 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -4,7 +4,7 @@ import * as React from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; const { useContext, useState, memo, useMemo } = React; diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index ae7f2723..2e605423 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next'; import { SAVE_TIME } from '../../config.js'; -import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; +import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js'; import { BrushParams } from '../../types/params.js'; import { imageFromBlob } from '../../utils.js'; import { NumericField } from './NumericField'; diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index cdaa6751..bc522569 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { QueryMenu } from '../input/QueryMenu.js'; import { ModelResponse } from '../../types/api.js'; diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index a1a304d1..3cd70449 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { BLEND_SOURCES, ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { range } from '../../utils.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index cddb8539..26bda5cf 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index b783f5b5..ca83835b 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index e634485a..86e96ecd 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { CorrectionModel, DiffusionModel, diff --git a/gui/src/components/tab/Settings.tsx b/gui/src/components/tab/Settings.tsx index 5f09c1a2..7b25ed3b 100644 --- a/gui/src/components/tab/Settings.tsx +++ b/gui/src/components/tab/Settings.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { getApiRoot } from '../../config.js'; -import { ConfigContext, StateContext, STATE_KEY } from '../../state.js'; +import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js'; import { getTheme } from '../utils.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 921def1f..81286ec2 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Upscale.tsx b/gui/src/components/tab/Upscale.tsx index 06314579..e9a1c482 100644 --- a/gui/src/components/tab/Upscale.tsx +++ b/gui/src/components/tab/Upscale.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/utils.ts b/gui/src/components/utils.ts index 58e106ec..39ba9bae 100644 --- a/gui/src/components/utils.ts +++ b/gui/src/components/utils.ts @@ -1,6 +1,6 @@ import { PaletteMode } from '@mui/material'; -import { Theme } from '../state.js'; +import { Theme } from '../state/types.js'; import { trimHash } from '../utils.js'; export const TAB_LABELS = [ diff --git a/gui/src/main.tsx b/gui/src/main.tsx index d0b83a11..73b43f8f 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -28,7 +28,7 @@ import { STATE_KEY, STATE_VERSION, StateContext, -} from './state.js'; +} from './state/full.js'; import { I18N_STRINGS } from './strings/all.js'; export const INITIAL_LOAD_TIMEOUT = 5_000; diff --git a/gui/src/state.ts b/gui/src/state.ts deleted file mode 100644 index fd9a21a0..00000000 --- a/gui/src/state.ts +++ /dev/null @@ -1,817 +0,0 @@ -/* eslint-disable camelcase */ -/* eslint-disable max-lines */ -/* eslint-disable no-null/no-null */ -import { Maybe } from '@apextoaster/js-utils'; -import { Logger } from 'noicejs'; -import { createContext } from 'react'; -import { StateCreator, StoreApi } from 'zustand'; - -import { - ApiClient, -} from './client/base.js'; -import { PipelineGrid } from './client/utils.js'; -import { Config, ServerParams } from './config.js'; -import { - BaseImgParams, - HighresParams, - ModelParams, - UpscaleParams, -} from './types/params.js'; -import { DefaultSlice } from './state/default.js'; -import { HistorySlice } from './state/history.js'; -import { Img2ImgSlice } from './state/img2img.js'; -import { InpaintSlice } from './state/inpaint.js'; -import { ModelSlice } from './state/models.js'; -import { Txt2ImgSlice } from './state/txt2img.js'; -import { UpscaleSlice } from './state/upscale.js'; -import { ResetSlice } from './state/reset.js'; -import { ProfileItem, ProfileSlice } from './state/profile.js'; -import { BlendSlice } from './state/blend.js'; -import { MISSING_INDEX } from './state/types.js'; - -/** - * Full merged state including all slices. - */ -export type OnnxState - = DefaultSlice - & HistorySlice - & Img2ImgSlice - & InpaintSlice - & ModelSlice - & Txt2ImgSlice - & UpscaleSlice - & BlendSlice - & ResetSlice - & ProfileSlice; - -/** - * Shorthand for state creator to reduce repeated arguments. - */ -export type Slice = StateCreator; - -/** - * React context binding for API client. - */ -export const ClientContext = createContext>(undefined); - -/** - * React context binding for merged config, including server parameters. - */ -export const ConfigContext = createContext>>(undefined); - -/** - * React context binding for bunyan logger. - */ -export const LoggerContext = createContext>(undefined); - -/** - * React context binding for zustand state store. - */ -export const StateContext = createContext>>(undefined); - -/** - * Key for zustand persistence, typically local storage. - */ -export const STATE_KEY = 'onnx-web'; - -/** - * Current state version for zustand persistence. - */ -export const STATE_VERSION = 7; - -export const BLEND_SOURCES = 2; - -/** - * Default parameters for the inpaint brush. - * - * Not provided by the server yet. - */ -export const DEFAULT_BRUSH = { - color: 255, - size: 8, - strength: 0.5, -}; - -/** - * Default parameters for the image history. - * - * Not provided by the server yet. - */ -export const DEFAULT_HISTORY = { - /** - * The number of images to be shown. - */ - limit: 4, - - /** - * The number of additional images to be kept in history, so they can scroll - * back into view when you delete one. Does not include deleted images. - */ - scrollback: 2, -}; - -export function baseParamsFromServer(defaults: ServerParams): Required { - return { - batch: defaults.batch.default, - cfg: defaults.cfg.default, - eta: defaults.eta.default, - negativePrompt: defaults.negativePrompt.default, - prompt: defaults.prompt.default, - scheduler: defaults.scheduler.default, - steps: defaults.steps.default, - seed: defaults.seed.default, - tiled_vae: defaults.tiled_vae.default, - unet_overlap: defaults.unet_overlap.default, - unet_tile: defaults.unet_tile.default, - vae_overlap: defaults.vae_overlap.default, - vae_tile: defaults.vae_tile.default, - }; -} - -/** - * Prepare the state slice constructors. - * - * In the default state, image sources should be null and booleans should be false. Everything - * else should be initialized from the default value in the base parameters. - */ -export function createStateSlices(server: ServerParams) { - const defaultParams = baseParamsFromServer(server); - const defaultHighres: HighresParams = { - enabled: false, - highresIterations: server.highresIterations.default, - highresMethod: '', - highresSteps: server.highresSteps.default, - highresScale: server.highresScale.default, - highresStrength: server.highresStrength.default, - }; - const defaultModel: ModelParams = { - control: server.control.default, - correction: server.correction.default, - model: server.model.default, - pipeline: server.pipeline.default, - platform: server.platform.default, - upscaling: server.upscaling.default, - }; - const defaultUpscale: UpscaleParams = { - denoise: server.denoise.default, - enabled: false, - faces: false, - faceOutscale: server.faceOutscale.default, - faceStrength: server.faceStrength.default, - outscale: server.outscale.default, - scale: server.scale.default, - upscaleOrder: server.upscaleOrder.default, - }; - const defaultGrid: PipelineGrid = { - enabled: false, - columns: { - parameter: 'seed', - value: '', - }, - rows: { - parameter: 'seed', - value: '', - }, - }; - - const createTxt2ImgSlice: Slice = (set) => ({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - txt2imgHighres: { - ...defaultHighres, - }, - txt2imgModel: { - ...defaultModel, - }, - txt2imgUpscale: { - ...defaultUpscale, - }, - txt2imgVariable: { - ...defaultGrid, - }, - setTxt2Img(params) { - set((prev) => ({ - txt2img: { - ...prev.txt2img, - ...params, - }, - })); - }, - setTxt2ImgHighres(params) { - set((prev) => ({ - txt2imgHighres: { - ...prev.txt2imgHighres, - ...params, - }, - })); - }, - setTxt2ImgModel(params) { - set((prev) => ({ - txt2imgModel: { - ...prev.txt2imgModel, - ...params, - }, - })); - }, - setTxt2ImgUpscale(params) { - set((prev) => ({ - txt2imgUpscale: { - ...prev.txt2imgUpscale, - ...params, - }, - })); - }, - setTxt2ImgVariable(params) { - set((prev) => ({ - txt2imgVariable: { - ...prev.txt2imgVariable, - ...params, - }, - })); - }, - resetTxt2Img() { - set({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - }); - }, - }); - - const createImg2ImgSlice: Slice = (set) => ({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - img2imgHighres: { - ...defaultHighres, - }, - img2imgModel: { - ...defaultModel, - }, - img2imgUpscale: { - ...defaultUpscale, - }, - resetImg2Img() { - set({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - }); - }, - setImg2Img(params) { - set((prev) => ({ - img2img: { - ...prev.img2img, - ...params, - }, - })); - }, - setImg2ImgHighres(params) { - set((prev) => ({ - img2imgHighres: { - ...prev.img2imgHighres, - ...params, - }, - })); - }, - setImg2ImgModel(params) { - set((prev) => ({ - img2imgModel: { - ...prev.img2imgModel, - ...params, - }, - })); - }, - setImg2ImgUpscale(params) { - set((prev) => ({ - img2imgUpscale: { - ...prev.img2imgUpscale, - ...params, - }, - })); - }, - }); - - const createInpaintSlice: Slice = (set) => ({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - inpaintBrush: { - ...DEFAULT_BRUSH, - }, - inpaintHighres: { - ...defaultHighres, - }, - inpaintModel: { - ...defaultModel, - }, - inpaintUpscale: { - ...defaultUpscale, - }, - outpaint: { - enabled: false, - left: server.left.default, - right: server.right.default, - top: server.top.default, - bottom: server.bottom.default, - }, - resetInpaint() { - set({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - }); - }, - setInpaint(params) { - set((prev) => ({ - inpaint: { - ...prev.inpaint, - ...params, - }, - })); - }, - setInpaintBrush(brush) { - set((prev) => ({ - inpaintBrush: { - ...prev.inpaintBrush, - ...brush, - }, - })); - }, - setInpaintHighres(params) { - set((prev) => ({ - inpaintHighres: { - ...prev.inpaintHighres, - ...params, - }, - })); - }, - setInpaintModel(params) { - set((prev) => ({ - inpaintModel: { - ...prev.inpaintModel, - ...params, - }, - })); - }, - setInpaintUpscale(params) { - set((prev) => ({ - inpaintUpscale: { - ...prev.inpaintUpscale, - ...params, - }, - })); - }, - setOutpaint(pixels) { - set((prev) => ({ - outpaint: { - ...prev.outpaint, - ...pixels, - } - })); - }, - }); - - const createHistorySlice: Slice = (set) => ({ - history: [], - limit: DEFAULT_HISTORY.limit, - pushHistory(image, retry) { - set((prev) => ({ - ...prev, - history: [ - { - image, - ready: undefined, - retry, - }, - ...prev.history, - ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), - })); - }, - removeHistory(image) { - set((prev) => ({ - ...prev, - history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), - })); - }, - setLimit(limit) { - set((prev) => ({ - ...prev, - limit, - })); - }, - setReady(image, ready) { - set((prev) => { - const history = [...prev.history]; - const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); - if (idx >= 0) { - history[idx].ready = ready; - } else { - // TODO: error - } - - return { - ...prev, - history, - }; - }); - }, - }); - - const createUpscaleSlice: Slice = (set) => ({ - upscale: { - ...defaultParams, - source: null, - }, - upscaleHighres: { - ...defaultHighres, - }, - upscaleModel: { - ...defaultModel, - }, - upscaleUpscale: { - ...defaultUpscale, - }, - resetUpscale() { - set({ - upscale: { - ...defaultParams, - source: null, - }, - }); - }, - setUpscale(source) { - set((prev) => ({ - upscale: { - ...prev.upscale, - ...source, - }, - })); - }, - setUpscaleHighres(params) { - set((prev) => ({ - upscaleHighres: { - ...prev.upscaleHighres, - ...params, - }, - })); - }, - setUpscaleModel(params) { - set((prev) => ({ - upscaleModel: { - ...prev.upscaleModel, - ...defaultModel, - }, - })); - }, - setUpscaleUpscale(params) { - set((prev) => ({ - upscaleUpscale: { - ...prev.upscaleUpscale, - ...params, - }, - })); - }, - }); - - const createBlendSlice: Slice = (set) => ({ - blend: { - mask: null, - sources: [], - }, - blendBrush: { - ...DEFAULT_BRUSH, - }, - blendModel: { - ...defaultModel, - }, - blendUpscale: { - ...defaultUpscale, - }, - resetBlend() { - set({ - blend: { - mask: null, - sources: [], - }, - }); - }, - setBlend(blend) { - set((prev) => ({ - blend: { - ...prev.blend, - ...blend, - }, - })); - }, - setBlendBrush(brush) { - set((prev) => ({ - blendBrush: { - ...prev.blendBrush, - ...brush, - }, - })); - }, - setBlendModel(model) { - set((prev) => ({ - blendModel: { - ...prev.blendModel, - ...model, - }, - })); - }, - setBlendUpscale(params) { - set((prev) => ({ - blendUpscale: { - ...prev.blendUpscale, - ...params, - }, - })); - }, - }); - - const createDefaultSlice: Slice = (set) => ({ - defaults: { - ...defaultParams, - }, - theme: '', - setDefaults(params) { - set((prev) => ({ - defaults: { - ...prev.defaults, - ...params, - } - })); - }, - setTheme(theme) { - set((prev) => ({ - theme, - })); - } - }); - - const createResetSlice: Slice = (set) => ({ - resetAll() { - set((prev) => { - const next = { ...prev }; - next.resetImg2Img(); - next.resetInpaint(); - next.resetTxt2Img(); - next.resetUpscale(); - next.resetBlend(); - return next; - }); - }, - }); - - const createProfileSlice: Slice = (set) => ({ - profiles: [], - saveProfile(profile: ProfileItem) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profile.name); - if (idx >= 0) { - profiles[idx] = profile; - } else { - profiles.push(profile); - } - return { - ...prev, - profiles, - }; - }); - }, - removeProfile(profileName: string) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profileName); - if (idx >= 0) { - profiles.splice(idx, 1); - } - return { - ...prev, - profiles, - }; - }); - } - }); - - // eslint-disable-next-line sonarjs/cognitive-complexity - const createModelSlice: Slice = (set) => ({ - extras: { - correction: [], - diffusion: [], - networks: [], - sources: [], - upscaling: [], - }, - setExtras(extras) { - set((prev) => ({ - ...prev, - extras: { - ...prev.extras, - ...extras, - }, - })); - }, - setCorrectionModel(model) { - set((prev) => { - const correction = [...prev.extras.correction]; - const exists = correction.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - correction.push(model); - } else { - correction[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - }, - setDiffusionModel(model) { - set((prev) => { - const diffusion = [...prev.extras.diffusion]; - const exists = diffusion.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - diffusion.push(model); - } else { - diffusion[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - }, - setExtraNetwork(model) { - set((prev) => { - const networks = [...prev.extras.networks]; - const exists = networks.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - networks.push(model); - } else { - networks[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - }, - setExtraSource(model) { - set((prev) => { - const sources = [...prev.extras.sources]; - const exists = sources.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - sources.push(model); - } else { - sources[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - }, - setUpscalingModel(model) { - set((prev) => { - const upscaling = [...prev.extras.upscaling]; - const exists = upscaling.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - upscaling.push(model); - } else { - upscaling[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - removeCorrectionModel(model) { - set((prev) => { - const correction = prev.extras.correction.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - - }, - removeDiffusionModel(model) { - set((prev) => { - const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - - }, - removeExtraNetwork(model) { - set((prev) => { - const networks = prev.extras.networks.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - - }, - removeExtraSource(model) { - set((prev) => { - const sources = prev.extras.sources.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - - }, - removeUpscalingModel(model) { - set((prev) => { - const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - }); - - return { - createDefaultSlice, - createHistorySlice, - createImg2ImgSlice, - createInpaintSlice, - createTxt2ImgSlice, - createUpscaleSlice, - createBlendSlice, - createResetSlice, - createModelSlice, - createProfileSlice, - }; -} diff --git a/gui/src/state/blend.ts b/gui/src/state/blend.ts index 42d292f3..e433a07e 100644 --- a/gui/src/state/blend.ts +++ b/gui/src/state/blend.ts @@ -4,7 +4,7 @@ import { ModelParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; export interface BlendSlice { blend: TabState; @@ -19,3 +19,66 @@ export interface BlendSlice { setBlendModel(model: Partial): void; setBlendUpscale(params: Partial): void; } + +export function createBlendSlice( + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [], + }, + blendBrush: { + ...DEFAULT_BRUSH, + }, + blendModel: { + ...defaultModel, + }, + blendUpscale: { + ...defaultUpscale, + }, + resetBlend() { + set((prev) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [] as Array, + }, + } as Partial)); + }, + setBlend(blend) { + set((prev) => ({ + blend: { + ...prev.blend, + ...blend, + }, + } as Partial)); + }, + setBlendBrush(brush) { + set((prev) => ({ + blendBrush: { + ...prev.blendBrush, + ...brush, + }, + } as Partial)); + }, + setBlendModel(model) { + set((prev) => ({ + blendModel: { + ...prev.blendModel, + ...model, + }, + } as Partial)); + }, + setBlendUpscale(params) { + set((prev) => ({ + blendUpscale: { + ...prev.blendUpscale, + ...params, + }, + } as Partial)); + }, + }); +} diff --git a/gui/src/state/default.ts b/gui/src/state/default.ts index b43b7ceb..c578263b 100644 --- a/gui/src/state/default.ts +++ b/gui/src/state/default.ts @@ -1,7 +1,7 @@ import { BaseImgParams, } from '../types/params.js'; -import { TabState, Theme } from './types.js'; +import { Slice, TabState, Theme } from './types.js'; export interface DefaultSlice { defaults: TabState; @@ -10,3 +10,25 @@ export interface DefaultSlice { setDefaults(param: Partial): void; setTheme(theme: Theme): void; } + +export function createDefaultSlice(defaultParams: Required): Slice { + return (set) => ({ + defaults: { + ...defaultParams, + }, + theme: '', + setDefaults(params) { + set((prev) => ({ + defaults: { + ...prev.defaults, + ...params, + } + } as Partial)); + }, + setTheme(theme) { + set((prev) => ({ + theme, + } as Partial)); + } + }); +} diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts new file mode 100644 index 00000000..e46c328a --- /dev/null +++ b/gui/src/state/full.ts @@ -0,0 +1,152 @@ +/* eslint-disable camelcase */ +import { Maybe } from '@apextoaster/js-utils'; +import { Logger } from 'noicejs'; +import { createContext } from 'react'; +import { StoreApi } from 'zustand'; + +import { + ApiClient, +} from '../client/base.js'; +import { PipelineGrid } from '../client/utils.js'; +import { Config, ServerParams } from '../config.js'; +import { BlendSlice, createBlendSlice } from './blend.js'; +import { DefaultSlice, createDefaultSlice } from './default.js'; +import { HistorySlice, createHistorySlice } from './history.js'; +import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js'; +import { InpaintSlice, createInpaintSlice } from './inpaint.js'; +import { ModelSlice, createModelSlice } from './model.js'; +import { ProfileSlice, createProfileSlice } from './profile.js'; +import { ResetSlice, createResetSlice } from './reset.js'; +import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js'; +import { UpscaleSlice, createUpscaleSlice } from './upscale.js'; +import { + BaseImgParams, + HighresParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; + +/** + * Full merged state including all slices. + */ +export type OnnxState + = DefaultSlice + & HistorySlice + & Img2ImgSlice + & InpaintSlice + & ModelSlice + & Txt2ImgSlice + & UpscaleSlice + & BlendSlice + & ResetSlice + & ProfileSlice; + +/** + * React context binding for API client. + */ +export const ClientContext = createContext>(undefined); + +/** + * React context binding for merged config, including server parameters. + */ +export const ConfigContext = createContext>>(undefined); + +/** + * React context binding for bunyan logger. + */ +export const LoggerContext = createContext>(undefined); + +/** + * React context binding for zustand state store. + */ +export const StateContext = createContext>>(undefined); + +/** + * Key for zustand persistence, typically local storage. + */ +export const STATE_KEY = 'onnx-web'; + +/** + * Current state version for zustand persistence. + */ +export const STATE_VERSION = 7; + +export const BLEND_SOURCES = 2; + +export function baseParamsFromServer(defaults: ServerParams): Required { + return { + batch: defaults.batch.default, + cfg: defaults.cfg.default, + eta: defaults.eta.default, + negativePrompt: defaults.negativePrompt.default, + prompt: defaults.prompt.default, + scheduler: defaults.scheduler.default, + steps: defaults.steps.default, + seed: defaults.seed.default, + tiled_vae: defaults.tiled_vae.default, + unet_overlap: defaults.unet_overlap.default, + unet_tile: defaults.unet_tile.default, + vae_overlap: defaults.vae_overlap.default, + vae_tile: defaults.vae_tile.default, + }; +} + +/** + * Prepare the state slice constructors. + * + * In the default state, image sources should be null and booleans should be false. Everything + * else should be initialized from the default value in the base parameters. + */ +export function createStateSlices(server: ServerParams) { + const defaultParams = baseParamsFromServer(server); + const defaultHighres: HighresParams = { + enabled: false, + highresIterations: server.highresIterations.default, + highresMethod: '', + highresSteps: server.highresSteps.default, + highresScale: server.highresScale.default, + highresStrength: server.highresStrength.default, + }; + const defaultModel: ModelParams = { + control: server.control.default, + correction: server.correction.default, + model: server.model.default, + pipeline: server.pipeline.default, + platform: server.platform.default, + upscaling: server.upscaling.default, + }; + const defaultUpscale: UpscaleParams = { + denoise: server.denoise.default, + enabled: false, + faces: false, + faceOutscale: server.faceOutscale.default, + faceStrength: server.faceStrength.default, + outscale: server.outscale.default, + scale: server.scale.default, + upscaleOrder: server.upscaleOrder.default, + }; + const defaultGrid: PipelineGrid = { + enabled: false, + columns: { + parameter: 'seed', + value: '', + }, + rows: { + parameter: 'seed', + value: '', + }, + }; + + return { + createBlendSlice: createBlendSlice(defaultModel, defaultUpscale), + createDefaultSlice: createDefaultSlice(defaultParams), + createHistorySlice: createHistorySlice(), + createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createModelSlice: createModelSlice(), + createProfileSlice: createProfileSlice(), + createResetSlice: createResetSlice(), + createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid), + createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale), + }; +} diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts index 44f7b2f3..4eb58271 100644 --- a/gui/src/state/history.ts +++ b/gui/src/state/history.ts @@ -1,5 +1,6 @@ import { Maybe } from '@apextoaster/js-utils'; import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; +import { DEFAULT_HISTORY, Slice } from './types.js'; export interface HistoryItem { image: ImageResponse; @@ -16,3 +17,51 @@ export interface HistorySlice { setLimit(limit: number): void; setReady(image: ImageResponse, ready: ReadyResponse): void; } + +export function createHistorySlice(): Slice { + return (set) => ({ + history: [], + limit: DEFAULT_HISTORY.limit, + pushHistory(image, retry) { + set((prev) => ({ + ...prev, + history: [ + { + image, + ready: undefined, + retry, + }, + ...prev.history, + ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), + })); + }, + removeHistory(image) { + set((prev) => ({ + ...prev, + history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), + })); + }, + setLimit(limit) { + set((prev) => ({ + ...prev, + limit, + })); + }, + setReady(image, ready) { + set((prev) => { + const history = [...prev.history]; + const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); + if (idx >= 0) { + history[idx].ready = ready; + } else { + // TODO: error + } + + return { + ...prev, + history, + }; + }); + }, + }); +} diff --git a/gui/src/state/img2img.ts b/gui/src/state/img2img.ts index cbe204d1..b8f986b0 100644 --- a/gui/src/state/img2img.ts +++ b/gui/src/state/img2img.ts @@ -1,11 +1,13 @@ +import { ServerParams } from '../config.js'; import { + BaseImgParams, HighresParams, Img2ImgParams, ModelParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface Img2ImgSlice { img2img: TabState; @@ -20,3 +22,77 @@ export interface Img2ImgSlice { setImg2ImgHighres(params: Partial): void; setImg2ImgUpscale(params: Partial): void; } + +// eslint-disable-next-line max-params +export function createImg2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams +): Slice { + return (set) => ({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + img2imgHighres: { + ...defaultHighres, + }, + img2imgModel: { + ...defaultModel, + }, + img2imgUpscale: { + ...defaultUpscale, + }, + resetImg2Img() { + set({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + } as Partial); + }, + setImg2Img(params) { + set((prev) => ({ + img2img: { + ...prev.img2img, + ...params, + }, + } as Partial)); + }, + setImg2ImgHighres(params) { + set((prev) => ({ + img2imgHighres: { + ...prev.img2imgHighres, + ...params, + }, + } as Partial)); + }, + setImg2ImgModel(params) { + set((prev) => ({ + img2imgModel: { + ...prev.img2imgModel, + ...params, + }, + } as Partial)); + }, + setImg2ImgUpscale(params) { + set((prev) => ({ + img2imgUpscale: { + ...prev.img2imgUpscale, + ...params, + }, + } as Partial)); + }, + }); + +} diff --git a/gui/src/state/inpaint.ts b/gui/src/state/inpaint.ts index 12756ac2..3eac9113 100644 --- a/gui/src/state/inpaint.ts +++ b/gui/src/state/inpaint.ts @@ -1,4 +1,6 @@ +import { ServerParams } from '../config.js'; import { + BaseImgParams, BrushParams, HighresParams, InpaintParams, @@ -6,8 +8,7 @@ import { OutpaintPixels, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; - +import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; export interface InpaintSlice { inpaint: TabState; inpaintBrush: BrushParams; @@ -25,3 +26,110 @@ export interface InpaintSlice { setInpaintUpscale(params: Partial): void; setOutpaint(pixels: Partial): void; } + +// eslint-disable-next-line max-params +export function createInpaintSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + inpaintBrush: { + ...DEFAULT_BRUSH, + }, + inpaintHighres: { + ...defaultHighres, + }, + inpaintModel: { + ...defaultModel, + }, + inpaintUpscale: { + ...defaultUpscale, + }, + outpaint: { + enabled: false, + left: server.left.default, + right: server.right.default, + top: server.top.default, + bottom: server.bottom.default, + }, + resetInpaint() { + set({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + } as Partial); + }, + setInpaint(params) { + set((prev) => ({ + inpaint: { + ...prev.inpaint, + ...params, + }, + } as Partial)); + }, + setInpaintBrush(brush) { + set((prev) => ({ + inpaintBrush: { + ...prev.inpaintBrush, + ...brush, + }, + } as Partial)); + }, + setInpaintHighres(params) { + set((prev) => ({ + inpaintHighres: { + ...prev.inpaintHighres, + ...params, + }, + } as Partial)); + }, + setInpaintModel(params) { + set((prev) => ({ + inpaintModel: { + ...prev.inpaintModel, + ...params, + }, + } as Partial)); + }, + setInpaintUpscale(params) { + set((prev) => ({ + inpaintUpscale: { + ...prev.inpaintUpscale, + ...params, + }, + } as Partial)); + }, + setOutpaint(pixels) { + set((prev) => ({ + outpaint: { + ...prev.outpaint, + ...pixels, + } + } as Partial)); + }, + }); +} diff --git a/gui/src/state/model.ts b/gui/src/state/model.ts new file mode 100644 index 00000000..3a473182 --- /dev/null +++ b/gui/src/state/model.ts @@ -0,0 +1,202 @@ +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; +import { MISSING_INDEX, Slice } from './types.js'; + +export interface ModelSlice { + extras: ExtrasFile; + + removeCorrectionModel(model: CorrectionModel): void; + removeDiffusionModel(model: DiffusionModel): void; + removeExtraNetwork(model: ExtraNetwork): void; + removeExtraSource(model: ExtraSource): void; + removeUpscalingModel(model: UpscalingModel): void; + + setExtras(extras: Partial): void; + + setCorrectionModel(model: CorrectionModel): void; + setDiffusionModel(model: DiffusionModel): void; + setExtraNetwork(model: ExtraNetwork): void; + setExtraSource(model: ExtraSource): void; + setUpscalingModel(model: UpscalingModel): void; +} + +// eslint-disable-next-line sonarjs/cognitive-complexity +export function createModelSlice(): Slice { + // eslint-disable-next-line sonarjs/cognitive-complexity + return (set) => ({ + extras: { + correction: [], + diffusion: [], + networks: [], + sources: [], + upscaling: [], + }, + setExtras(extras) { + set((prev) => ({ + ...prev, + extras: { + ...prev.extras, + ...extras, + }, + })); + }, + setCorrectionModel(model) { + set((prev) => { + const correction = [...prev.extras.correction]; + const exists = correction.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + correction.push(model); + } else { + correction[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + }, + setDiffusionModel(model) { + set((prev) => { + const diffusion = [...prev.extras.diffusion]; + const exists = diffusion.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + diffusion.push(model); + } else { + diffusion[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + }, + setExtraNetwork(model) { + set((prev) => { + const networks = [...prev.extras.networks]; + const exists = networks.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + networks.push(model); + } else { + networks[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + }, + setExtraSource(model) { + set((prev) => { + const sources = [...prev.extras.sources]; + const exists = sources.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + sources.push(model); + } else { + sources[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + }, + setUpscalingModel(model) { + set((prev) => { + const upscaling = [...prev.extras.upscaling]; + const exists = upscaling.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + upscaling.push(model); + } else { + upscaling[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + removeCorrectionModel(model) { + set((prev) => { + const correction = prev.extras.correction.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + + }, + removeDiffusionModel(model) { + set((prev) => { + const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + + }, + removeExtraNetwork(model) { + set((prev) => { + const networks = prev.extras.networks.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + + }, + removeExtraSource(model) { + set((prev) => { + const sources = prev.extras.sources.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + + }, + removeUpscalingModel(model) { + set((prev) => { + const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + }); +} diff --git a/gui/src/state/models.ts b/gui/src/state/models.ts deleted file mode 100644 index e0faa27d..00000000 --- a/gui/src/state/models.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; - -export interface ModelSlice { - extras: ExtrasFile; - - removeCorrectionModel(model: CorrectionModel): void; - removeDiffusionModel(model: DiffusionModel): void; - removeExtraNetwork(model: ExtraNetwork): void; - removeExtraSource(model: ExtraSource): void; - removeUpscalingModel(model: UpscalingModel): void; - - setExtras(extras: Partial): void; - - setCorrectionModel(model: CorrectionModel): void; - setDiffusionModel(model: DiffusionModel): void; - setExtraNetwork(model: ExtraNetwork): void; - setExtraSource(model: ExtraSource): void; - setUpscalingModel(model: UpscalingModel): void; -} diff --git a/gui/src/state/profile.ts b/gui/src/state/profile.ts index 7ffdfac9..73d52eab 100644 --- a/gui/src/state/profile.ts +++ b/gui/src/state/profile.ts @@ -1,5 +1,6 @@ import { Maybe } from '@apextoaster/js-utils'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; +import { Slice } from './types.js'; export interface ProfileItem { name: string; @@ -15,3 +16,37 @@ export interface ProfileSlice { saveProfile(profile: ProfileItem): void; } + +export function createProfileSlice(): Slice { + return (set) => ({ + profiles: [], + saveProfile(profile: ProfileItem) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profile.name); + if (idx >= 0) { + profiles[idx] = profile; + } else { + profiles.push(profile); + } + return { + ...prev, + profiles, + }; + }); + }, + removeProfile(profileName: string) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profileName); + if (idx >= 0) { + profiles.splice(idx, 1); + } + return { + ...prev, + profiles, + }; + }); + } + }); +} diff --git a/gui/src/state/reset.ts b/gui/src/state/reset.ts index 66b545a5..53272e5c 100644 --- a/gui/src/state/reset.ts +++ b/gui/src/state/reset.ts @@ -1,3 +1,28 @@ +import { BlendSlice } from './blend.js'; +import { Img2ImgSlice } from './img2img.js'; +import { InpaintSlice } from './inpaint.js'; +import { Txt2ImgSlice } from './txt2img.js'; +import { Slice } from './types.js'; +import { UpscaleSlice } from './upscale.js'; + +export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice; + export interface ResetSlice { resetAll(): void; } + +export function createResetSlice(): Slice { + return (set) => ({ + resetAll() { + set((prev) => { + const next = { ...prev }; + next.resetImg2Img(); + next.resetInpaint(); + next.resetTxt2Img(); + next.resetUpscale(); + next.resetBlend(); + return next; + }); + }, + }); +} diff --git a/gui/src/state/txt2img.ts b/gui/src/state/txt2img.ts index d8cd95eb..8bec3273 100644 --- a/gui/src/state/txt2img.ts +++ b/gui/src/state/txt2img.ts @@ -1,11 +1,13 @@ import { PipelineGrid } from '../client/utils.js'; +import { ServerParams } from '../config.js'; import { + BaseImgParams, HighresParams, ModelParams, Txt2ImgParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface Txt2ImgSlice { txt2img: TabState; @@ -22,3 +24,82 @@ export interface Txt2ImgSlice { setTxt2ImgUpscale(params: Partial): void; setTxt2ImgVariable(params: Partial): void; } + +// eslint-disable-next-line max-params +export function createTxt2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, + defaultGrid: PipelineGrid, +): Slice { + return (set) => ({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + txt2imgHighres: { + ...defaultHighres, + }, + txt2imgModel: { + ...defaultModel, + }, + txt2imgUpscale: { + ...defaultUpscale, + }, + txt2imgVariable: { + ...defaultGrid, + }, + setTxt2Img(params) { + set((prev) => ({ + txt2img: { + ...prev.txt2img, + ...params, + }, + } as Partial)); + }, + setTxt2ImgHighres(params) { + set((prev) => ({ + txt2imgHighres: { + ...prev.txt2imgHighres, + ...params, + }, + } as Partial)); + }, + setTxt2ImgModel(params) { + set((prev) => ({ + txt2imgModel: { + ...prev.txt2imgModel, + ...params, + }, + } as Partial)); + }, + setTxt2ImgUpscale(params) { + set((prev) => ({ + txt2imgUpscale: { + ...prev.txt2imgUpscale, + ...params, + }, + } as Partial)); + }, + setTxt2ImgVariable(params) { + set((prev) => ({ + txt2imgVariable: { + ...prev.txt2imgVariable, + ...params, + }, + } as Partial)); + }, + resetTxt2Img() { + set({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + } as Partial); + }, + }); +} diff --git a/gui/src/state/types.ts b/gui/src/state/types.ts index 98843c86..3b2144bd 100644 --- a/gui/src/state/types.ts +++ b/gui/src/state/types.ts @@ -1,4 +1,5 @@ import { PaletteMode } from '@mui/material'; +import { StateCreator } from 'zustand'; import { ConfigFiles, ConfigState } from '../config.js'; export const MISSING_INDEX = -1; @@ -9,3 +10,37 @@ export type Theme = PaletteMode | ''; // tri-state, '' is unset * Combine optional files and required ranges. */ export type TabState = ConfigFiles> & ConfigState>; + +/** + * Shorthand for state creator to reduce repeated arguments. + */ +export type Slice = StateCreator; + +/** + * Default parameters for the inpaint brush. + * + * Not provided by the server yet. + */ +export const DEFAULT_BRUSH = { + color: 255, + size: 8, + strength: 0.5, +}; + +/** + * Default parameters for the image history. + * + * Not provided by the server yet. + */ +export const DEFAULT_HISTORY = { + /** + * The number of images to be shown. + */ + limit: 4, + + /** + * The number of additional images to be kept in history, so they can scroll + * back into view when you delete one. Does not include deleted images. + */ + scrollback: 2, +}; diff --git a/gui/src/state/upscale.ts b/gui/src/state/upscale.ts index af0a344a..e78d689a 100644 --- a/gui/src/state/upscale.ts +++ b/gui/src/state/upscale.ts @@ -1,10 +1,11 @@ import { + BaseImgParams, HighresParams, ModelParams, UpscaleParams, UpscaleReqParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface UpscaleSlice { upscale: TabState; @@ -19,3 +20,68 @@ export interface UpscaleSlice { setUpscaleModel(params: Partial): void; setUpscaleUpscale(params: Partial): void; } + +export function createUpscaleSlice( + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + upscaleHighres: { + ...defaultHighres, + }, + upscaleModel: { + ...defaultModel, + }, + upscaleUpscale: { + ...defaultUpscale, + }, + resetUpscale() { + set({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + } as Partial); + }, + setUpscale(source) { + set((prev) => ({ + upscale: { + ...prev.upscale, + ...source, + }, + } as Partial)); + }, + setUpscaleHighres(params) { + set((prev) => ({ + upscaleHighres: { + ...prev.upscaleHighres, + ...params, + }, + } as Partial)); + }, + setUpscaleModel(params) { + set((prev) => ({ + upscaleModel: { + ...prev.upscaleModel, + ...defaultModel, + }, + } as Partial)); + }, + setUpscaleUpscale(params) { + set((prev) => ({ + upscaleUpscale: { + ...prev.upscaleUpscale, + ...params, + }, + } as Partial)); + }, + }); +} From 5680dd704ef8a9a309aba328ff9c9d56903f3db5 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 22:24:09 -0600 Subject: [PATCH 38/42] fix(gui): add state migrations for new unet/vae params (#427) --- gui/src/components/card/ImageCard.tsx | 3 +- gui/src/components/tab/Blend.tsx | 3 +- gui/src/constants.ts | 31 ++++++++++ gui/src/main.tsx | 4 ++ gui/src/state/blend.ts | 3 +- gui/src/state/full.ts | 4 +- gui/src/state/history.ts | 3 +- gui/src/state/inpaint.ts | 3 +- gui/src/state/migration/default.ts | 82 +++++++++++++++++++++++++++ gui/src/state/types.ts | 29 ---------- 10 files changed, 128 insertions(+), 37 deletions(-) create mode 100644 gui/src/constants.ts create mode 100644 gui/src/state/migration/default.ts diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index 44f5dd91..e614d430 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -8,9 +8,10 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; import { range, visibleIndex } from '../../utils.js'; +import { BLEND_SOURCES } from '../../constants.js'; export interface ImageCardProps { image: ImageResponse; diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index 3cd70449..21fe47b6 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { BLEND_SOURCES, ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { BLEND_SOURCES } from '../../constants.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { TabState } from '../../state/types.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { range } from '../../utils.js'; diff --git a/gui/src/constants.ts b/gui/src/constants.ts new file mode 100644 index 00000000..eb718447 --- /dev/null +++ b/gui/src/constants.ts @@ -0,0 +1,31 @@ + +export const BLEND_SOURCES = 2; + +/** + * Default parameters for the inpaint brush. + * + * Not provided by the server yet. + */ +export const DEFAULT_BRUSH = { + color: 255, + size: 8, + strength: 0.5, +}; + +/** + * Default parameters for the image history. + * + * Not provided by the server yet. + */ +export const DEFAULT_HISTORY = { + /** + * The number of images to be shown. + */ + limit: 4, + + /** + * The number of additional images to be kept in history, so they can scroll + * back into view when you delete one. Does not include deleted images. + */ + scrollback: 2, +}; diff --git a/gui/src/main.tsx b/gui/src/main.tsx index 73b43f8f..edfc2150 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -30,6 +30,7 @@ import { StateContext, } from './state/full.js'; import { I18N_STRINGS } from './strings/all.js'; +import { applyStateMigrations, UnknownState } from './state/migration/default.js'; export const INITIAL_LOAD_TIMEOUT = 5_000; @@ -70,6 +71,9 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo ...createResetSlice(...slice), ...createProfileSlice(...slice), }), { + migrate(persistedState, version) { + return applyStateMigrations(params, persistedState as UnknownState, version); + }, name: STATE_KEY, partialize(s) { return { diff --git a/gui/src/state/blend.ts b/gui/src/state/blend.ts index e433a07e..60cd5f88 100644 --- a/gui/src/state/blend.ts +++ b/gui/src/state/blend.ts @@ -1,10 +1,11 @@ +import { DEFAULT_BRUSH } from '../constants.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams, } from '../types/params.js'; -import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface BlendSlice { blend: TabState; diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts index e46c328a..990161ae 100644 --- a/gui/src/state/full.ts +++ b/gui/src/state/full.ts @@ -69,9 +69,7 @@ export const STATE_KEY = 'onnx-web'; /** * Current state version for zustand persistence. */ -export const STATE_VERSION = 7; - -export const BLEND_SOURCES = 2; +export const STATE_VERSION = 11; export function baseParamsFromServer(defaults: ServerParams): Required { return { diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts index 4eb58271..de71ef72 100644 --- a/gui/src/state/history.ts +++ b/gui/src/state/history.ts @@ -1,6 +1,7 @@ import { Maybe } from '@apextoaster/js-utils'; import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; -import { DEFAULT_HISTORY, Slice } from './types.js'; +import { Slice } from './types.js'; +import { DEFAULT_HISTORY } from '../constants.js'; export interface HistoryItem { image: ImageResponse; diff --git a/gui/src/state/inpaint.ts b/gui/src/state/inpaint.ts index 3eac9113..7dab5af3 100644 --- a/gui/src/state/inpaint.ts +++ b/gui/src/state/inpaint.ts @@ -1,4 +1,5 @@ import { ServerParams } from '../config.js'; +import { DEFAULT_BRUSH } from '../constants.js'; import { BaseImgParams, BrushParams, @@ -8,7 +9,7 @@ import { OutpaintPixels, UpscaleParams, } from '../types/params.js'; -import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface InpaintSlice { inpaint: TabState; inpaintBrush: BrushParams; diff --git a/gui/src/state/migration/default.ts b/gui/src/state/migration/default.ts new file mode 100644 index 00000000..425a18fb --- /dev/null +++ b/gui/src/state/migration/default.ts @@ -0,0 +1,82 @@ +/* eslint-disable camelcase */ +import { ServerParams } from '../../config.js'; +import { BaseImgParams } from '../../types/params.js'; +import { OnnxState, STATE_VERSION } from '../full.js'; +import { Img2ImgSlice } from '../img2img.js'; +import { InpaintSlice } from '../inpaint.js'; +import { Txt2ImgSlice } from '../txt2img.js'; +import { UpscaleSlice } from '../upscale.js'; + +export const REMOVE_KEYS = ['tile', 'overlap'] as const; + +export type RemovedKeys = typeof REMOVE_KEYS[number]; + +// TODO: can the compiler calculate this? +export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap'; + +export type BaseImgParamsV7 = Omit & { + overlap: number; + tile: number; +}; + +export type OnnxStateV7 = Omit & { + img2img: BaseImgParamsV7; + inpaint: BaseImgParamsV7; + txt2img: BaseImgParamsV7; + upscale: BaseImgParamsV7; +}; + +export type PreviousState = OnnxStateV7; +export type CurrentState = OnnxState; +export type UnknownState = PreviousState | CurrentState; + +export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number): OnnxState { + // eslint-disable-next-line no-console + console.log('applying migrations from %s to %s', version, STATE_VERSION); + + if (version < STATE_VERSION) { + return migrateDefaults(params, previousState as PreviousState); + } + + return previousState as CurrentState; +} + +export function migrateDefaults(params: ServerParams, previousState: PreviousState): CurrentState { + // add any missing keys + const result: CurrentState = { + ...params, + ...previousState, + img2img: { + ...previousState.img2img, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + inpaint: { + ...previousState.inpaint, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + txt2img: { + ...previousState.txt2img, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + upscale: { + ...previousState.upscale, + unet_overlap: params.unet_overlap.default, + unet_tile: params.unet_tile.default, + vae_overlap: params.vae_overlap.default, + vae_tile: params.vae_tile.default, + }, + }; + + // TODO: remove extra keys + + return result; +} diff --git a/gui/src/state/types.ts b/gui/src/state/types.ts index 3b2144bd..b1e9d3ad 100644 --- a/gui/src/state/types.ts +++ b/gui/src/state/types.ts @@ -15,32 +15,3 @@ export type TabState = ConfigFiles> & ConfigState * Shorthand for state creator to reduce repeated arguments. */ export type Slice = StateCreator; - -/** - * Default parameters for the inpaint brush. - * - * Not provided by the server yet. - */ -export const DEFAULT_BRUSH = { - color: 255, - size: 8, - strength: 0.5, -}; - -/** - * Default parameters for the image history. - * - * Not provided by the server yet. - */ -export const DEFAULT_HISTORY = { - /** - * The number of images to be shown. - */ - limit: 4, - - /** - * The number of additional images to be kept in history, so they can scroll - * back into view when you delete one. Does not include deleted images. - */ - scrollback: 2, -}; From eefe9fe22ec76bbc88e37970a012d753021bb76f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 22:31:47 -0600 Subject: [PATCH 39/42] pass logger to migrations, include version numbers in migration name --- gui/src/main.tsx | 2 +- gui/src/state/migration/default.ts | 35 ++++++++++++++++++++---------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/gui/src/main.tsx b/gui/src/main.tsx index edfc2150..81aeecf0 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -72,7 +72,7 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo ...createProfileSlice(...slice), }), { migrate(persistedState, version) { - return applyStateMigrations(params, persistedState as UnknownState, version); + return applyStateMigrations(params, persistedState as UnknownState, version, logger); }, name: STATE_KEY, partialize(s) { diff --git a/gui/src/state/migration/default.ts b/gui/src/state/migration/default.ts index 425a18fb..65a4353b 100644 --- a/gui/src/state/migration/default.ts +++ b/gui/src/state/migration/default.ts @@ -1,4 +1,5 @@ /* eslint-disable camelcase */ +import { Logger } from 'browser-bunyan'; import { ServerParams } from '../../config.js'; import { BaseImgParams } from '../../types/params.js'; import { OnnxState, STATE_VERSION } from '../full.js'; @@ -7,12 +8,8 @@ import { InpaintSlice } from '../inpaint.js'; import { Txt2ImgSlice } from '../txt2img.js'; import { UpscaleSlice } from '../upscale.js'; -export const REMOVE_KEYS = ['tile', 'overlap'] as const; - -export type RemovedKeys = typeof REMOVE_KEYS[number]; - -// TODO: can the compiler calculate this? -export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap'; +// #region V7 +export const V7 = 7; export type BaseImgParamsV7 = Omit & { overlap: number; @@ -25,23 +22,37 @@ export type OnnxStateV7 = Omit & { txt2img: BaseImgParamsV7; upscale: BaseImgParamsV7; }; +// #endregion +// #region V11 +export const REMOVED_KEYS_V11 = ['tile', 'overlap'] as const; + +export type RemovedKeysV11 = typeof REMOVED_KEYS_V11[number]; + +// TODO: can the compiler calculate this? +export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap'; +// #endregion + +// add versions to this list as they are replaced export type PreviousState = OnnxStateV7; + +// always the latest version export type CurrentState = OnnxState; + +// any version of state export type UnknownState = PreviousState | CurrentState; -export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number): OnnxState { - // eslint-disable-next-line no-console - console.log('applying migrations from %s to %s', version, STATE_VERSION); +export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number, logger: Logger): OnnxState { + logger.info('applying state migrations from version %s to version %s', version, STATE_VERSION); - if (version < STATE_VERSION) { - return migrateDefaults(params, previousState as PreviousState); + if (version <= V7) { + return migrateV7ToV11(params, previousState as PreviousState); } return previousState as CurrentState; } -export function migrateDefaults(params: ServerParams, previousState: PreviousState): CurrentState { +export function migrateV7ToV11(params: ServerParams, previousState: PreviousState): CurrentState { // add any missing keys const result: CurrentState = { ...params, From 695eeaf30331ad4f24401e5873c7e861844b3b79 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 22:36:52 -0600 Subject: [PATCH 40/42] feat(api): remove deprecated Karras Ve scheduler (#189) --- api/onnx_web/diffusers/load.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 2aaf2154..b6620e8b 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -36,7 +36,6 @@ from .version_safe_diffusers import ( EulerDiscreteScheduler, HeunDiscreteScheduler, IPNDMScheduler, - KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, LCMScheduler, @@ -80,7 +79,6 @@ pipeline_schedulers = { "ipndm": IPNDMScheduler, "k-dpm-2-a": KDPM2AncestralDiscreteScheduler, "k-dpm-2": KDPM2DiscreteScheduler, - "karras-ve": KarrasVeScheduler, "lcm": LCMScheduler, "lms-discrete": LMSDiscreteScheduler, "pndm": PNDMScheduler, From 5b382afc6cae543a1cd85b216f01b8452732bcdd Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 15 Dec 2023 08:41:02 -0600 Subject: [PATCH 41/42] outline getting started guide --- docs/getting-started.md | 198 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 docs/getting-started.md diff --git a/docs/getting-started.md b/docs/getting-started.md new file mode 100644 index 00000000..fc596089 --- /dev/null +++ b/docs/getting-started.md @@ -0,0 +1,198 @@ +# Getting Started With onnx-web + +onnx-web is a tool for generating images with Stable Diffusion pipelines, including SDXL. + +## Contents + +- [Getting Started With onnx-web](#getting-started-with-onnx-web) + - [Contents](#contents) + - [Setup](#setup) + - [Windows bundle setup](#windows-bundle-setup) + - [Other setup methods](#other-setup-methods) + - [Running](#running) + - [Running the server](#running-the-server) + - [Running the web UI](#running-the-web-ui) + - [Tabs](#tabs) + - [Txt2img Tab](#txt2img-tab) + - [Img2img Tab](#img2img-tab) + - [Inpaint Tab](#inpaint-tab) + - [Upscale Tab](#upscale-tab) + - [Blend Tab](#blend-tab) + - [Models Tab](#models-tab) + - [Settings Tab](#settings-tab) + - [Image parameters](#image-parameters) + - [Common image parameters](#common-image-parameters) + - [Unique image parameters](#unique-image-parameters) + - [Prompt syntax](#prompt-syntax) + - [LoRAs and embeddings](#loras-and-embeddings) + - [CLIP skip](#clip-skip) + - [Highres](#highres) + - [Highres prompt](#highres-prompt) + - [Highres iterations](#highres-iterations) + - [Profiles](#profiles) + - [Loading from files](#loading-from-files) + - [Saving profiles in the web UI](#saving-profiles-in-the-web-ui) + - [Sharing parameters profiles](#sharing-parameters-profiles) + - [Panorama pipeline](#panorama-pipeline) + - [Region prompts](#region-prompts) + - [Region seeds](#region-seeds) + - [Grid mode](#grid-mode) + - [Grid tokens](#grid-tokens) + - [Memory optimizations](#memory-optimizations) + - [Converting to fp16](#converting-to-fp16) + - [Moving models to the CPU](#moving-models-to-the-cpu) + +## Setup + +### Windows bundle setup + +1. Download +2. Extract +3. Security flags +4. Run + +### Other setup methods + +Link to the other methods. + +## Running + +### Running the server + +Run server or bundle. + +### Running the web UI + +Open web UI. + +Use it from your phone. + +## Tabs + +There are 5 tabs, which do different things. + +### Txt2img Tab + +Words go in, pictures come out. + +### Img2img Tab + +Pictures go in, better pictures come out. + +ControlNet lives here. + +### Inpaint Tab + +Pictures go in, parts of the same picture come out. + +### Upscale Tab + +Just highres and super resolution. + +### Blend Tab + +Use the mask tool to combine two images. + +### Models Tab + +Add and manage models. + +### Settings Tab + +Manage web UI settings. + +Reset buttons. + +## Image parameters + +### Common image parameters + +- Scheduler +- Eta + - for DDIM +- CFG +- Steps +- Seed +- Batch size +- Prompt +- Negative prompt +- Width, height + +### Unique image parameters + +- UNet tile size +- UNet overlap +- Tiled VAE +- VAE tile size +- VAE overlap + +See the complete user guide for details about the highres, upscale, and correction parameters. + +## Prompt syntax + +### LoRAs and embeddings + +`` and ``. + +### CLIP skip + +`` for anime. + +## Highres + +### Highres prompt + +`txt2img prompt || img2img prompt` + +### Highres iterations + +Highres will apply the upscaler and highres prompt (img2img pipeline) for each iteration. + +The final size will be `scale ** iterations`. + +## Profiles + +Saved sets of parameters for later use. + +### Loading from files + +- load from images +- load from JSON + +### Saving profiles in the web UI + +Use the save button. + +### Sharing parameters profiles + +Use the download button. + +Share profiles in the Discord channel. + +## Panorama pipeline + +### Region prompts + +`` + +### Region seeds + +`` + +## Grid mode + +Makes many images. Takes many time. + +### Grid tokens + +`__column__` and `__row__` if you pick token in the menu. + +## Memory optimizations + +### Converting to fp16 + +Enable the option. + +### Moving models to the CPU + +Option for each model. From 133c4a20bdaa08c89ea7078fa7f973b623cc938c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 16 Dec 2023 13:17:24 -0600 Subject: [PATCH 42/42] fix(api): use correct post-fetch path when converting from checkpoints (#432) --- api/onnx_web/convert/diffusion/diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index da16a5e9..a2e3d807 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -370,7 +370,7 @@ def convert_diffusion_diffusers( else: logger.debug("loading pipeline from SD checkpoint: %s", source) pipeline = download_from_original_stable_diffusion_ckpt( - source, + cache_path, original_config_file=config_path, pipeline_class=pipe_class, **pipe_args,