Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards
This commit is contained in:
commit
591a63edc8
|
@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
|
||||||
|
|
||||||
This is an incomplete list of new and interesting features, with links to the user guide:
|
This is an incomplete list of new and interesting features, with links to the user guide:
|
||||||
|
|
||||||
- SDXL support
|
- supports SDXL and SDXL Turbo
|
||||||
- LCM support
|
- wide variety of schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more
|
||||||
- hardware acceleration on both AMD and Nvidia
|
- hardware acceleration on both AMD and Nvidia
|
||||||
- tested on CUDA, DirectML, and ROCm
|
- tested on CUDA, DirectML, and ROCm
|
||||||
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
|
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
call onnx_env\Scripts\Activate.bat
|
call onnx_env\Scripts\Activate.bat
|
||||||
|
|
||||||
|
echo "This launch.bat script is deprecated in favor of launch.ps1 and will be removed in a future release."
|
||||||
|
|
||||||
echo "Downloading and converting models to ONNX format..."
|
echo "Downloading and converting models to ONNX format..."
|
||||||
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json)
|
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json)
|
||||||
python -m onnx_web.convert ^
|
python -m onnx_web.convert ^
|
||||||
|
@ -10,6 +12,12 @@ python -m onnx_web.convert ^
|
||||||
--extras=%ONNX_WEB_EXTRA_MODELS% ^
|
--extras=%ONNX_WEB_EXTRA_MODELS% ^
|
||||||
--token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS%
|
--token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS%
|
||||||
|
|
||||||
|
IF NOT EXIST .\gui\index.html (
|
||||||
|
echo "Please make sure you have downloaded the web UI files from https://github.com/ssube/onnx-web/tree/gh-pages"
|
||||||
|
echo "See https://github.com/ssube/onnx-web/blob/main/docs/setup-guide.md#download-the-web-ui-bundle for more information"
|
||||||
|
pause
|
||||||
|
)
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
waitress-serve ^
|
waitress-serve ^
|
||||||
--host=0.0.0.0 ^
|
--host=0.0.0.0 ^
|
||||||
|
|
|
@ -10,6 +10,13 @@ python -m onnx_web.convert `
|
||||||
--extras=$Env:ONNX_WEB_EXTRA_MODELS `
|
--extras=$Env:ONNX_WEB_EXTRA_MODELS `
|
||||||
--token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS
|
--token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS
|
||||||
|
|
||||||
|
if (!(Test-Path -path .\gui\index.html -PathType Leaf)) {
|
||||||
|
echo "Downloading latest web UI files from Github..."
|
||||||
|
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/index.html" -OutFile .\gui\index.html
|
||||||
|
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/config.json" -OutFile .\gui\config.json
|
||||||
|
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/bundle/main.js" -OutFile .\gui\bundle\main.js
|
||||||
|
}
|
||||||
|
|
||||||
echo "Launching API server..."
|
echo "Launching API server..."
|
||||||
waitress-serve `
|
waitress-serve `
|
||||||
--host=0.0.0.0 `
|
--host=0.0.0.0 `
|
||||||
|
|
|
@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage):
|
||||||
) -> StageResult:
|
) -> StageResult:
|
||||||
logger.info("blending image using mask")
|
logger.info("blending image using mask")
|
||||||
|
|
||||||
# TODO: does this need an alpha channel?
|
|
||||||
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
|
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
|
||||||
mult_mask.alpha_composite(stage_mask)
|
mult_mask = Image.alpha_composite(mult_mask, stage_mask)
|
||||||
mult_mask = mult_mask.convert("L")
|
mult_mask = mult_mask.convert("L")
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
save_image(server, "last-mask.png", stage_mask)
|
save_image(server, "last-mask.png", stage_mask)
|
||||||
save_image(server, "last-mult-mask.png", mult_mask)
|
save_image(server, "last-mult-mask.png", mult_mask)
|
||||||
|
|
||||||
return StageResult(
|
return StageResult.from_images(
|
||||||
images=[
|
[
|
||||||
Image.composite(stage_source, source, mult_mask)
|
Image.composite(stage_source, source, mult_mask)
|
||||||
for source in sources.as_image()
|
for source in sources.as_image()
|
||||||
]
|
]
|
||||||
|
|
|
@ -3,16 +3,17 @@ from argparse import ArgumentParser
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import makedirs, path
|
from os import makedirs, path
|
||||||
from sys import exit
|
from sys import exit
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Union
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from huggingface_hub.file_download import hf_hub_download
|
|
||||||
from jsonschema import ValidationError, validate
|
from jsonschema import ValidationError, validate
|
||||||
from onnx import load_model, save_model
|
from onnx import load_model, save_model
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
|
from ..server.plugin import load_plugins
|
||||||
from ..utils import load_config
|
from ..utils import load_config
|
||||||
|
from .client import add_model_source, fetch_model
|
||||||
|
from .client.huggingface import HuggingfaceClient
|
||||||
from .correction.gfpgan import convert_correction_gfpgan
|
from .correction.gfpgan import convert_correction_gfpgan
|
||||||
from .diffusion.control import convert_diffusion_control
|
from .diffusion.control import convert_diffusion_control
|
||||||
from .diffusion.diffusion import convert_diffusion_diffusers
|
from .diffusion.diffusion import convert_diffusion_diffusers
|
||||||
|
@ -25,8 +26,7 @@ from .upscaling.swinir import convert_upscaling_swinir
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_OPSET,
|
DEFAULT_OPSET,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
download_progress,
|
fix_diffusion_name,
|
||||||
remove_prefix,
|
|
||||||
source_format,
|
source_format,
|
||||||
tuple_to_correction,
|
tuple_to_correction,
|
||||||
tuple_to_diffusion,
|
tuple_to_diffusion,
|
||||||
|
@ -44,32 +44,34 @@ warnings.filterwarnings(
|
||||||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||||
)
|
)
|
||||||
|
|
||||||
Models = Dict[str, List[Any]]
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
ModelDict = Dict[str, Union[float, int, str]]
|
||||||
|
Models = Dict[str, List[ModelDict]]
|
||||||
|
|
||||||
model_sources: Dict[str, Tuple[str, str]] = {
|
model_converters: Dict[str, Any] = {
|
||||||
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
|
"img2img": convert_diffusion_diffusers,
|
||||||
|
"img2img-sdxl": convert_diffusion_diffusers_xl,
|
||||||
|
"inpaint": convert_diffusion_diffusers,
|
||||||
|
"txt2img": convert_diffusion_diffusers,
|
||||||
|
"txt2img-sdxl": convert_diffusion_diffusers_xl,
|
||||||
}
|
}
|
||||||
|
|
||||||
model_source_huggingface = "huggingface://"
|
|
||||||
|
|
||||||
# recommended models
|
# recommended models
|
||||||
base_models: Models = {
|
base_models: Models = {
|
||||||
"diffusion": [
|
"diffusion": [
|
||||||
# v1.x
|
# v1.x
|
||||||
(
|
(
|
||||||
"stable-diffusion-onnx-v1-5",
|
"stable-diffusion-onnx-v1-5",
|
||||||
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
|
HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"stable-diffusion-onnx-v1-inpainting",
|
"stable-diffusion-onnx-v1-inpainting",
|
||||||
model_source_huggingface + "runwayml/stable-diffusion-inpainting",
|
HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"upscaling-stable-diffusion-x4",
|
"upscaling-stable-diffusion-x4",
|
||||||
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
|
HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
|
||||||
True,
|
True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
@ -200,180 +202,68 @@ base_models: Models = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def fetch_model(
|
def convert_model_source(conversion: ConversionContext, model):
|
||||||
conversion: ConversionContext,
|
|
||||||
name: str,
|
|
||||||
source: str,
|
|
||||||
dest: Optional[str] = None,
|
|
||||||
format: Optional[str] = None,
|
|
||||||
hf_hub_fetch: bool = False,
|
|
||||||
hf_hub_filename: Optional[str] = None,
|
|
||||||
) -> Tuple[str, bool]:
|
|
||||||
cache_path = dest or conversion.cache_path
|
|
||||||
cache_name = path.join(cache_path, name)
|
|
||||||
|
|
||||||
# add an extension if possible, some of the conversion code checks for it
|
|
||||||
if format is None:
|
|
||||||
url = urlparse(source)
|
|
||||||
ext = path.basename(url.path)
|
|
||||||
_filename, ext = path.splitext(ext)
|
|
||||||
if ext is not None:
|
|
||||||
cache_name = cache_name + ext
|
|
||||||
else:
|
|
||||||
cache_name = f"{cache_name}.{format}"
|
|
||||||
|
|
||||||
if path.exists(cache_name):
|
|
||||||
logger.debug("model already exists in cache, skipping fetch")
|
|
||||||
return cache_name, False
|
|
||||||
|
|
||||||
for proto in model_sources:
|
|
||||||
api_name, api_root = model_sources.get(proto)
|
|
||||||
if source.startswith(proto):
|
|
||||||
api_source = api_root % (remove_prefix(source, proto))
|
|
||||||
logger.info(
|
|
||||||
"downloading model from %s: %s -> %s", api_name, api_source, cache_name
|
|
||||||
)
|
|
||||||
return download_progress([(api_source, cache_name)]), False
|
|
||||||
|
|
||||||
if source.startswith(model_source_huggingface):
|
|
||||||
hub_source = remove_prefix(source, model_source_huggingface)
|
|
||||||
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
|
||||||
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
|
||||||
if hf_hub_fetch:
|
|
||||||
return (
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id=hub_source,
|
|
||||||
filename=hf_hub_filename,
|
|
||||||
cache_dir=cache_path,
|
|
||||||
force_filename=f"{name}.bin",
|
|
||||||
),
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return hub_source, True
|
|
||||||
elif source.startswith("https://"):
|
|
||||||
logger.info("downloading model from: %s", source)
|
|
||||||
return download_progress([(source, cache_name)]), False
|
|
||||||
elif source.startswith("http://"):
|
|
||||||
logger.warning("downloading model from insecure source: %s", source)
|
|
||||||
return download_progress([(source, cache_name)]), False
|
|
||||||
elif source.startswith(path.sep) or source.startswith("."):
|
|
||||||
logger.info("using local model: %s", source)
|
|
||||||
return source, False
|
|
||||||
else:
|
|
||||||
logger.info("unknown model location, using path as provided: %s", source)
|
|
||||||
return source, False
|
|
||||||
|
|
||||||
|
|
||||||
def convert_models(conversion: ConversionContext, args, models: Models):
|
|
||||||
model_errors = []
|
|
||||||
|
|
||||||
if args.sources and "sources" in models:
|
|
||||||
for model in models.get("sources", []):
|
|
||||||
model = tuple_to_source(model)
|
|
||||||
name = model.get("name")
|
|
||||||
|
|
||||||
if name in args.skip:
|
|
||||||
logger.info("skipping source: %s", name)
|
|
||||||
else:
|
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
|
name = model["name"]
|
||||||
source = model["source"]
|
source = model["source"]
|
||||||
|
|
||||||
try:
|
|
||||||
dest_path = None
|
dest_path = None
|
||||||
if "dest" in model:
|
if "dest" in model:
|
||||||
dest_path = path.join(conversion.model_path, model["dest"])
|
dest_path = path.join(conversion.model_path, model["dest"])
|
||||||
|
|
||||||
dest, hf = fetch_model(
|
dest = fetch_model(conversion, name, source, format=model_format, dest=dest_path)
|
||||||
conversion, name, source, format=model_format, dest=dest_path
|
|
||||||
)
|
|
||||||
logger.info("finished downloading source: %s -> %s", source, dest)
|
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||||
except Exception:
|
|
||||||
logger.exception("error fetching source %s", name)
|
|
||||||
model_errors.append(name)
|
|
||||||
|
|
||||||
if args.networks and "networks" in models:
|
|
||||||
for network in models.get("networks", []):
|
|
||||||
name = network["name"]
|
|
||||||
|
|
||||||
if name in args.skip:
|
def convert_model_network(conversion: ConversionContext, model):
|
||||||
logger.info("skipping network: %s", name)
|
model_format = source_format(model)
|
||||||
else:
|
model_type = model["type"]
|
||||||
network_format = source_format(network)
|
name = model["name"]
|
||||||
network_model = network.get("model", None)
|
source = model["source"]
|
||||||
network_type = network["type"]
|
|
||||||
source = network["source"]
|
|
||||||
|
|
||||||
try:
|
if model_type == "control":
|
||||||
if network_type == "control":
|
dest = fetch_model(
|
||||||
dest, hf = fetch_model(
|
|
||||||
conversion,
|
conversion,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
format=network_format,
|
format=model_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_diffusion_control(
|
convert_diffusion_control(
|
||||||
conversion,
|
conversion,
|
||||||
network,
|
model,
|
||||||
dest,
|
dest,
|
||||||
path.join(conversion.model_path, network_type, name),
|
path.join(conversion.model_path, model_type, name),
|
||||||
)
|
|
||||||
if network_type == "inversion" and network_model == "concept":
|
|
||||||
dest, hf = fetch_model(
|
|
||||||
conversion,
|
|
||||||
name,
|
|
||||||
source,
|
|
||||||
dest=path.join(conversion.model_path, network_type),
|
|
||||||
format=network_format,
|
|
||||||
hf_hub_fetch=True,
|
|
||||||
hf_hub_filename="learned_embeds.bin",
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
dest, hf = fetch_model(
|
model = model.get("model", None)
|
||||||
|
dest = fetch_model(
|
||||||
conversion,
|
conversion,
|
||||||
name,
|
name,
|
||||||
source,
|
source,
|
||||||
dest=path.join(conversion.model_path, network_type),
|
dest=path.join(conversion.model_path, model_type),
|
||||||
format=network_format,
|
format=model_format,
|
||||||
|
embeds=(model_type == "inversion" and model == "concept"),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||||
except Exception:
|
|
||||||
logger.exception("error fetching network %s", name)
|
|
||||||
model_errors.append(name)
|
|
||||||
|
|
||||||
if args.diffusion and "diffusion" in models:
|
|
||||||
for model in models.get("diffusion", []):
|
|
||||||
model = tuple_to_diffusion(model)
|
|
||||||
name = model.get("name")
|
|
||||||
|
|
||||||
if name in args.skip:
|
def convert_model_diffusion(conversion: ConversionContext, model):
|
||||||
logger.info("skipping model: %s", name)
|
# fix up entries with missing prefixes
|
||||||
else:
|
name = fix_diffusion_name(model["name"])
|
||||||
|
if name != model["name"]:
|
||||||
|
# update the model in-memory if the name changed
|
||||||
|
model["name"] = name
|
||||||
|
|
||||||
model_format = source_format(model)
|
model_format = source_format(model)
|
||||||
|
|
||||||
try:
|
|
||||||
source, hf = fetch_model(
|
|
||||||
conversion, name, model["source"], format=model_format
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = model.get("pipeline", "txt2img")
|
pipeline = model.get("pipeline", "txt2img")
|
||||||
if pipeline.endswith("-sdxl"):
|
converter = model_converters.get(pipeline)
|
||||||
converted, dest = convert_diffusion_diffusers_xl(
|
converted, dest = converter(
|
||||||
conversion,
|
conversion,
|
||||||
model,
|
model,
|
||||||
source,
|
|
||||||
model_format,
|
model_format,
|
||||||
hf=hf,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
converted, dest = convert_diffusion_diffusers(
|
|
||||||
conversion,
|
|
||||||
model,
|
|
||||||
source,
|
|
||||||
model_format,
|
|
||||||
hf=hf,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure blending only happens once, not every run
|
# make sure blending only happens once, not every run
|
||||||
|
@ -395,9 +285,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
)
|
)
|
||||||
|
|
||||||
if "tokenizer" not in blend_models:
|
if "tokenizer" not in blend_models:
|
||||||
blend_models[
|
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||||
"tokenizer"
|
|
||||||
] = CLIPTokenizer.from_pretrained(
|
|
||||||
dest,
|
dest,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
)
|
)
|
||||||
|
@ -405,7 +293,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
inversion_name = inversion["name"]
|
inversion_name = inversion["name"]
|
||||||
inversion_source = inversion["source"]
|
inversion_source = inversion["source"]
|
||||||
inversion_format = inversion.get("format", None)
|
inversion_format = inversion.get("format", None)
|
||||||
inversion_source, hf = fetch_model(
|
inversion_source = fetch_model(
|
||||||
conversion,
|
conversion,
|
||||||
inversion_name,
|
inversion_name,
|
||||||
inversion_source,
|
inversion_source,
|
||||||
|
@ -439,14 +327,12 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
)
|
)
|
||||||
|
|
||||||
if "unet" not in blend_models:
|
if "unet" not in blend_models:
|
||||||
blend_models["unet"] = load_model(
|
blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
|
||||||
path.join(dest, "unet", ONNX_MODEL)
|
|
||||||
)
|
|
||||||
|
|
||||||
# load models if not loaded yet
|
# load models if not loaded yet
|
||||||
lora_name = lora["name"]
|
lora_name = lora["name"]
|
||||||
lora_source = lora["source"]
|
lora_source = lora["source"]
|
||||||
lora_source, hf = fetch_model(
|
lora_source = fetch_model(
|
||||||
conversion,
|
conversion,
|
||||||
f"{name}-lora-{lora_name}",
|
f"{name}-lora-{lora_name}",
|
||||||
lora_source,
|
lora_source,
|
||||||
|
@ -476,9 +362,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
for name in ["text_encoder", "unet"]:
|
for name in ["text_encoder", "unet"]:
|
||||||
if name in blend_models:
|
if name in blend_models:
|
||||||
dest_path = path.join(dest, name, ONNX_MODEL)
|
dest_path = path.join(dest, name, ONNX_MODEL)
|
||||||
logger.debug(
|
logger.debug("saving blended %s model to %s", name, dest_path)
|
||||||
"saving blended %s model to %s", name, dest_path
|
|
||||||
)
|
|
||||||
save_model(
|
save_model(
|
||||||
blend_models[name],
|
blend_models[name],
|
||||||
dest_path,
|
dest_path,
|
||||||
|
@ -487,6 +371,76 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
location=ONNX_WEIGHTS,
|
location=ONNX_WEIGHTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model_upscaling(conversion: ConversionContext, model):
|
||||||
|
model_format = source_format(model)
|
||||||
|
name = model["name"]
|
||||||
|
|
||||||
|
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||||
|
model_type = model.get("model", "resrgan")
|
||||||
|
if model_type == "bsrgan":
|
||||||
|
convert_upscaling_bsrgan(conversion, model, source)
|
||||||
|
elif model_type == "resrgan":
|
||||||
|
convert_upscale_resrgan(conversion, model, source)
|
||||||
|
elif model_type == "swinir":
|
||||||
|
convert_upscaling_swinir(conversion, model, source)
|
||||||
|
else:
|
||||||
|
logger.error("unknown upscaling model type %s for %s", model_type, name)
|
||||||
|
raise ValueError(name)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_model_correction(conversion: ConversionContext, model):
|
||||||
|
model_format = source_format(model)
|
||||||
|
name = model["name"]
|
||||||
|
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||||
|
model_type = model.get("model", "gfpgan")
|
||||||
|
if model_type == "gfpgan":
|
||||||
|
convert_correction_gfpgan(conversion, model, source)
|
||||||
|
else:
|
||||||
|
logger.error("unknown correction model type %s for %s", model_type, name)
|
||||||
|
raise ValueError(name)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
|
model_errors = []
|
||||||
|
|
||||||
|
if args.sources and "sources" in models:
|
||||||
|
for model in models.get("sources", []):
|
||||||
|
model = tuple_to_source(model)
|
||||||
|
name = model.get("name")
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("skipping source: %s", name)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
convert_model_source(conversion, model)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error fetching source %s", name)
|
||||||
|
model_errors.append(name)
|
||||||
|
|
||||||
|
if args.networks and "networks" in models:
|
||||||
|
for model in models.get("networks", []):
|
||||||
|
name = model["name"]
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("skipping network: %s", name)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
convert_model_network(conversion, model)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error fetching network %s", name)
|
||||||
|
model_errors.append(name)
|
||||||
|
|
||||||
|
if args.diffusion and "diffusion" in models:
|
||||||
|
for model in models.get("diffusion", []):
|
||||||
|
model = tuple_to_diffusion(model)
|
||||||
|
name = model.get("name")
|
||||||
|
|
||||||
|
if name in args.skip:
|
||||||
|
logger.info("skipping model: %s", name)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
convert_model_diffusion(conversion, model)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting diffusion model %s",
|
"error converting diffusion model %s",
|
||||||
|
@ -502,24 +456,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("skipping model: %s", name)
|
logger.info("skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
model_format = source_format(model)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source, hf = fetch_model(
|
convert_model_upscaling(conversion, model)
|
||||||
conversion, name, model["source"], format=model_format
|
|
||||||
)
|
|
||||||
model_type = model.get("model", "resrgan")
|
|
||||||
if model_type == "bsrgan":
|
|
||||||
convert_upscaling_bsrgan(conversion, model, source)
|
|
||||||
elif model_type == "resrgan":
|
|
||||||
convert_upscale_resrgan(conversion, model, source)
|
|
||||||
elif model_type == "swinir":
|
|
||||||
convert_upscaling_swinir(conversion, model, source)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"unknown upscaling model type %s for %s", model_type, name
|
|
||||||
)
|
|
||||||
model_errors.append(name)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting upscaling model %s",
|
"error converting upscaling model %s",
|
||||||
|
@ -535,19 +473,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
if name in args.skip:
|
if name in args.skip:
|
||||||
logger.info("skipping model: %s", name)
|
logger.info("skipping model: %s", name)
|
||||||
else:
|
else:
|
||||||
model_format = source_format(model)
|
|
||||||
try:
|
try:
|
||||||
source, hf = fetch_model(
|
convert_model_correction(conversion, model)
|
||||||
conversion, name, model["source"], format=model_format
|
|
||||||
)
|
|
||||||
model_type = model.get("model", "gfpgan")
|
|
||||||
if model_type == "gfpgan":
|
|
||||||
convert_correction_gfpgan(conversion, model, source)
|
|
||||||
else:
|
|
||||||
logger.error(
|
|
||||||
"unknown correction model type %s for %s", model_type, name
|
|
||||||
)
|
|
||||||
model_errors.append(name)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error converting correction model %s",
|
"error converting correction model %s",
|
||||||
|
@ -559,12 +486,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
||||||
logger.error("error while converting models: %s", model_errors)
|
logger.error("error while converting models: %s", model_errors)
|
||||||
|
|
||||||
|
|
||||||
|
def register_plugins(conversion: ConversionContext):
|
||||||
|
logger.info("loading conversion plugins")
|
||||||
|
exports = load_plugins(conversion)
|
||||||
|
|
||||||
|
for proto, client in exports.clients:
|
||||||
|
try:
|
||||||
|
add_model_source(proto, client)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("error loading client for protocol: %s", proto)
|
||||||
|
|
||||||
|
# TODO: add converters
|
||||||
|
|
||||||
|
|
||||||
def main(args=None) -> int:
|
def main(args=None) -> int:
|
||||||
parser = ArgumentParser(
|
parser = ArgumentParser(
|
||||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||||
)
|
)
|
||||||
|
|
||||||
# model groups
|
# model groups
|
||||||
|
parser.add_argument("--base", action="store_true", default=True)
|
||||||
parser.add_argument("--networks", action="store_true", default=True)
|
parser.add_argument("--networks", action="store_true", default=True)
|
||||||
parser.add_argument("--sources", action="store_true", default=True)
|
parser.add_argument("--sources", action="store_true", default=True)
|
||||||
parser.add_argument("--correction", action="store_true", default=False)
|
parser.add_argument("--correction", action="store_true", default=False)
|
||||||
|
@ -602,14 +543,18 @@ def main(args=None) -> int:
|
||||||
server.half = args.half or server.has_optimization("onnx-fp16")
|
server.half = args.half or server.has_optimization("onnx-fp16")
|
||||||
server.opset = args.opset
|
server.opset = args.opset
|
||||||
server.token = args.token
|
server.token = args.token
|
||||||
|
|
||||||
|
register_plugins(server)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"converting models in %s using %s", server.model_path, server.training_device
|
"converting models into %s using %s", server.model_path, server.training_device
|
||||||
)
|
)
|
||||||
|
|
||||||
if not path.exists(server.model_path):
|
if not path.exists(server.model_path):
|
||||||
logger.info("model path does not existing, creating: %s", server.model_path)
|
logger.info("model path does not existing, creating: %s", server.model_path)
|
||||||
makedirs(server.model_path)
|
makedirs(server.model_path)
|
||||||
|
|
||||||
|
if args.base:
|
||||||
logger.info("converting base models")
|
logger.info("converting base models")
|
||||||
convert_models(server, args, base_models)
|
convert_models(server, args, base_models)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
from ..utils import ConversionContext
|
||||||
|
from .base import BaseClient
|
||||||
|
from .civitai import CivitaiClient
|
||||||
|
from .file import FileClient
|
||||||
|
from .http import HttpClient
|
||||||
|
from .huggingface import HuggingfaceClient
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
model_sources: Dict[str, Callable[[], BaseClient]] = {
|
||||||
|
CivitaiClient.protocol: CivitaiClient,
|
||||||
|
FileClient.protocol: FileClient,
|
||||||
|
HttpClient.insecure_protocol: HttpClient,
|
||||||
|
HttpClient.protocol: HttpClient,
|
||||||
|
HuggingfaceClient.protocol: HuggingfaceClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_source(proto: str, client: BaseClient):
|
||||||
|
global model_sources
|
||||||
|
|
||||||
|
if proto in model_sources:
|
||||||
|
raise ValueError("protocol has already been taken")
|
||||||
|
|
||||||
|
model_sources[proto] = client
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_model(
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
# TODO: switch to urlparse's default scheme
|
||||||
|
if source.startswith(path.sep) or source.startswith("."):
|
||||||
|
logger.info("adding file protocol to local path source: %s", source)
|
||||||
|
source = FileClient.protocol + source
|
||||||
|
|
||||||
|
for proto, client_type in model_sources.items():
|
||||||
|
if source.startswith(proto):
|
||||||
|
client = client_type(conversion)
|
||||||
|
return client.download(
|
||||||
|
conversion, name, source, format=format, dest=dest, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.warning("unknown model protocol, using path as provided: %s", source)
|
||||||
|
return source
|
|
@ -0,0 +1,16 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..utils import ConversionContext
|
||||||
|
|
||||||
|
|
||||||
|
class BaseClient:
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
raise NotImplementedError()
|
|
@ -0,0 +1,64 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
ConversionContext,
|
||||||
|
build_cache_paths,
|
||||||
|
download_progress,
|
||||||
|
get_first_exists,
|
||||||
|
remove_prefix,
|
||||||
|
)
|
||||||
|
from .base import BaseClient
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
|
||||||
|
|
||||||
|
|
||||||
|
class CivitaiClient(BaseClient):
|
||||||
|
name = "civitai"
|
||||||
|
protocol = "civitai://"
|
||||||
|
|
||||||
|
root: str
|
||||||
|
token: Optional[str]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
root: str = CIVITAI_ROOT,
|
||||||
|
):
|
||||||
|
self.root = conversion.get_setting("CIVITAI_ROOT", root)
|
||||||
|
self.token = conversion.get_setting("CIVITAI_TOKEN", token)
|
||||||
|
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
cache_paths = build_cache_paths(
|
||||||
|
conversion,
|
||||||
|
name,
|
||||||
|
client=CivitaiClient.name,
|
||||||
|
format=format,
|
||||||
|
dest=dest,
|
||||||
|
)
|
||||||
|
cached = get_first_exists(cache_paths)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
|
||||||
|
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
|
||||||
|
|
||||||
|
if self.token:
|
||||||
|
logger.debug("adding Civitai token authentication")
|
||||||
|
if "?" in source:
|
||||||
|
source = f"{source}&token={self.token}"
|
||||||
|
else:
|
||||||
|
source = f"{source}?token={self.token}"
|
||||||
|
|
||||||
|
return download_progress(source, cache_paths[0])
|
|
@ -0,0 +1,32 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from os import path
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from ..utils import ConversionContext
|
||||||
|
from .base import BaseClient
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FileClient(BaseClient):
|
||||||
|
protocol = "file://"
|
||||||
|
|
||||||
|
def __init__(self, _conversion: ConversionContext):
|
||||||
|
"""
|
||||||
|
Nothing to initialize for this client.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
_name: str,
|
||||||
|
uri: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
parts = urlparse(uri)
|
||||||
|
logger.info("loading model from: %s", parts.path)
|
||||||
|
return path.join(dest or conversion.model_path, parts.path)
|
|
@ -0,0 +1,52 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from ..utils import (
|
||||||
|
ConversionContext,
|
||||||
|
build_cache_paths,
|
||||||
|
download_progress,
|
||||||
|
get_first_exists,
|
||||||
|
)
|
||||||
|
from .base import BaseClient
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpClient(BaseClient):
|
||||||
|
name = "http"
|
||||||
|
protocol = "https://"
|
||||||
|
insecure_protocol = "http://"
|
||||||
|
|
||||||
|
headers: Dict[str, str]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None
|
||||||
|
):
|
||||||
|
self.headers = headers or {}
|
||||||
|
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
cache_paths = build_cache_paths(
|
||||||
|
conversion,
|
||||||
|
name,
|
||||||
|
client=HttpClient.name,
|
||||||
|
format=format,
|
||||||
|
dest=dest,
|
||||||
|
)
|
||||||
|
cached = get_first_exists(cache_paths)
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
if source.startswith(HttpClient.protocol):
|
||||||
|
logger.info("downloading model from: %s", source)
|
||||||
|
elif source.startswith(HttpClient.insecure_protocol):
|
||||||
|
logger.warning("downloading model from insecure source: %s", source)
|
||||||
|
|
||||||
|
return download_progress(source, cache_paths[0])
|
|
@ -0,0 +1,43 @@
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from huggingface_hub.file_download import hf_hub_download
|
||||||
|
|
||||||
|
from ..utils import ConversionContext, remove_prefix
|
||||||
|
from .base import BaseClient
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingfaceClient(BaseClient):
|
||||||
|
name = "huggingface"
|
||||||
|
protocol = "huggingface://"
|
||||||
|
|
||||||
|
token: Optional[str]
|
||||||
|
|
||||||
|
def __init__(self, conversion: ConversionContext, token: Optional[str] = None):
|
||||||
|
self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token)
|
||||||
|
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
source: str,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
embeds: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
source = remove_prefix(source, HuggingfaceClient.protocol)
|
||||||
|
logger.info("downloading model from Huggingface Hub: %s", source)
|
||||||
|
|
||||||
|
if embeds:
|
||||||
|
return hf_hub_download(
|
||||||
|
repo_id=source,
|
||||||
|
filename="learned_embeds.bin",
|
||||||
|
cache_dir=dest or conversion.cache_path,
|
||||||
|
force_filename=f"{name}.bin",
|
||||||
|
token=self.token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return source
|
|
@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
||||||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||||
from ...models.cnet import UNet2DConditionModel_CNet
|
from ...models.cnet import UNet2DConditionModel_CNet
|
||||||
from ...utils import run_gc
|
from ...utils import run_gc
|
||||||
|
from ..client import fetch_model
|
||||||
|
from ..client.huggingface import HuggingfaceClient
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
RESOLVE_FORMATS,
|
RESOLVE_FORMATS,
|
||||||
ConversionContext,
|
ConversionContext,
|
||||||
|
@ -43,6 +45,7 @@ from ..utils import (
|
||||||
is_torch_2_0,
|
is_torch_2_0,
|
||||||
load_tensor,
|
load_tensor,
|
||||||
onnx_export,
|
onnx_export,
|
||||||
|
remove_prefix,
|
||||||
)
|
)
|
||||||
from .checkpoint import convert_extract_checkpoint
|
from .checkpoint import convert_extract_checkpoint
|
||||||
|
|
||||||
|
@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
|
||||||
def convert_diffusion_diffusers(
|
def convert_diffusion_diffusers(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
|
||||||
format: Optional[str],
|
format: Optional[str],
|
||||||
hf: bool = False,
|
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = str(model.get("name")).strip()
|
||||||
|
source = model.get("source")
|
||||||
|
|
||||||
# optional
|
# optional
|
||||||
config = model.get("config", None)
|
config = model.get("config", None)
|
||||||
|
@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
|
cache_path = fetch_model(conversion, name, source, format=format)
|
||||||
|
|
||||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||||
v2, pipe_args = get_model_version(
|
v2, pipe_args = get_model_version(
|
||||||
source, conversion.map_location, size=image_size, version=version
|
cache_path, conversion.map_location, size=image_size, version=version
|
||||||
)
|
)
|
||||||
|
|
||||||
is_inpainting = False
|
is_inpainting = False
|
||||||
|
@ -334,14 +338,15 @@ def convert_diffusion_diffusers(
|
||||||
pipe_args["from_safetensors"] = True
|
pipe_args["from_safetensors"] = True
|
||||||
|
|
||||||
torch_source = None
|
torch_source = None
|
||||||
if path.exists(source) and path.isdir(source):
|
if path.exists(cache_path):
|
||||||
|
if path.isdir(cache_path):
|
||||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
source,
|
cache_path,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=conversion.token,
|
use_auth_token=conversion.token,
|
||||||
).to(device)
|
).to(device)
|
||||||
elif path.exists(source) and path.isfile(source):
|
else:
|
||||||
if conversion.extract:
|
if conversion.extract:
|
||||||
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
||||||
torch_source = convert_extract_checkpoint(
|
torch_source = convert_extract_checkpoint(
|
||||||
|
@ -352,7 +357,9 @@ def convert_diffusion_diffusers(
|
||||||
config_file=config,
|
config_file=config,
|
||||||
vae_file=replace_vae,
|
vae_file=replace_vae,
|
||||||
)
|
)
|
||||||
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
|
logger.debug(
|
||||||
|
"loading pipeline from extracted checkpoint: %s", torch_source
|
||||||
|
)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
torch_source,
|
torch_source,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
|
@ -363,28 +370,34 @@ def convert_diffusion_diffusers(
|
||||||
else:
|
else:
|
||||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||||
source,
|
cache_path,
|
||||||
original_config_file=config_path,
|
original_config_file=config_path,
|
||||||
pipeline_class=pipe_class,
|
pipeline_class=pipe_class,
|
||||||
**pipe_args,
|
**pipe_args,
|
||||||
).to(device, torch_dtype=dtype)
|
).to(device, torch_dtype=dtype)
|
||||||
elif hf:
|
elif source.startswith(HuggingfaceClient.protocol):
|
||||||
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
|
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
|
||||||
|
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
|
||||||
pipeline = pipe_class.from_pretrained(
|
pipeline = pipe_class.from_pretrained(
|
||||||
source,
|
hf_path,
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
use_auth_token=conversion.token,
|
use_auth_token=conversion.token,
|
||||||
).to(device)
|
).to(device)
|
||||||
else:
|
else:
|
||||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
logger.warning(
|
||||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
"pipeline source not found and protocol not recognized: %s", source
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"pipeline source not found and protocol not recognized: {source}"
|
||||||
|
)
|
||||||
|
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
vae_path = path.join(conversion.model_path, replace_vae)
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
if check_ext(replace_vae, RESOLVE_FORMATS):
|
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
|
||||||
|
if vae_file[0]:
|
||||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||||
else:
|
else:
|
||||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||||
|
|
||||||
if is_torch_2_0:
|
if is_torch_2_0:
|
||||||
pipeline.unet.set_attn_processor(AttnProcessor())
|
pipeline.unet.set_attn_processor(AttnProcessor())
|
||||||
|
|
|
@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
from optimum.exporters.onnx import main_export
|
from optimum.exporters.onnx import main_export
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL
|
from ...constants import ONNX_MODEL
|
||||||
|
from ..client import fetch_model
|
||||||
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -19,14 +20,13 @@ logger = getLogger(__name__)
|
||||||
def convert_diffusion_diffusers_xl(
|
def convert_diffusion_diffusers_xl(
|
||||||
conversion: ConversionContext,
|
conversion: ConversionContext,
|
||||||
model: Dict,
|
model: Dict,
|
||||||
source: str,
|
|
||||||
format: Optional[str],
|
format: Optional[str],
|
||||||
hf: bool = False,
|
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = str(model.get("name")).strip()
|
||||||
|
source = model.get("source")
|
||||||
replace_vae = model.get("vae", None)
|
replace_vae = model.get("vae", None)
|
||||||
|
|
||||||
device = conversion.training_device
|
device = conversion.training_device
|
||||||
|
@ -52,24 +52,26 @@ def convert_diffusion_diffusers_xl(
|
||||||
|
|
||||||
return (False, dest_path)
|
return (False, dest_path)
|
||||||
|
|
||||||
|
cache_path = fetch_model(conversion, name, model["source"], format=format)
|
||||||
# safetensors -> diffusers directory with torch models
|
# safetensors -> diffusers directory with torch models
|
||||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||||
|
|
||||||
if format == "safetensors":
|
if format == "safetensors":
|
||||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||||
source, use_safetensors=True
|
cache_path, use_safetensors=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
|
||||||
|
|
||||||
if replace_vae is not None:
|
if replace_vae is not None:
|
||||||
vae_path = path.join(conversion.model_path, replace_vae)
|
vae_path = path.join(conversion.model_path, replace_vae)
|
||||||
if check_ext(replace_vae, RESOLVE_FORMATS):
|
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
|
||||||
|
if vae_file[0]:
|
||||||
logger.debug("loading VAE from single tensor file: %s", vae_path)
|
logger.debug("loading VAE from single tensor file: %s", vae_path)
|
||||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||||
else:
|
else:
|
||||||
logger.debug("loading pretrained VAE from path: %s", vae_path)
|
logger.debug("loading pretrained VAE from path: %s", replace_vae)
|
||||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||||
|
|
||||||
if path.exists(temp_path):
|
if path.exists(temp_path):
|
||||||
logger.debug("torch model already exists for %s: %s", source, temp_path)
|
logger.debug("torch model already exists for %s: %s", source, temp_path)
|
||||||
|
|
|
@ -31,6 +31,15 @@ ModelDict = Dict[str, Union[str, int]]
|
||||||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||||
|
|
||||||
DEFAULT_OPSET = 14
|
DEFAULT_OPSET = 14
|
||||||
|
DIFFUSION_PREFIX = [
|
||||||
|
"diffusion-",
|
||||||
|
"diffusion/",
|
||||||
|
"diffusion\\",
|
||||||
|
"stable-diffusion-",
|
||||||
|
"upscaling-", # SD upscaling
|
||||||
|
]
|
||||||
|
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||||
|
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
||||||
|
|
||||||
|
|
||||||
class ConversionContext(ServerContext):
|
class ConversionContext(ServerContext):
|
||||||
|
@ -85,8 +94,7 @@ class ConversionContext(ServerContext):
|
||||||
return torch.device(self.training_device)
|
return torch.device(self.training_device)
|
||||||
|
|
||||||
|
|
||||||
def download_progress(urls: List[Tuple[str, str]]):
|
def download_progress(source: str, dest: str):
|
||||||
for url, dest in urls:
|
|
||||||
dest_path = Path(dest).expanduser().resolve()
|
dest_path = Path(dest).expanduser().resolve()
|
||||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@ -95,7 +103,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
||||||
return str(dest_path.absolute())
|
return str(dest_path.absolute())
|
||||||
|
|
||||||
req = requests.get(
|
req = requests.get(
|
||||||
url,
|
source,
|
||||||
stream=True,
|
stream=True,
|
||||||
allow_redirects=True,
|
allow_redirects=True,
|
||||||
headers={
|
headers={
|
||||||
|
@ -105,7 +113,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
||||||
if req.status_code != 200:
|
if req.status_code != 200:
|
||||||
req.raise_for_status() # Only works for 4xx errors, per SO answer
|
req.raise_for_status() # Only works for 4xx errors, per SO answer
|
||||||
raise RequestException(
|
raise RequestException(
|
||||||
"request to %s failed with status code: %s" % (url, req.status_code)
|
"request to %s failed with status code: %s" % (source, req.status_code)
|
||||||
)
|
)
|
||||||
|
|
||||||
total = int(req.headers.get("Content-Length", 0))
|
total = int(req.headers.get("Content-Length", 0))
|
||||||
|
@ -184,10 +192,6 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
|
||||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
|
||||||
|
|
||||||
|
|
||||||
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
|
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
|
||||||
_name, ext = path.splitext(name)
|
_name, ext = path.splitext(name)
|
||||||
ext = ext.strip(".")
|
ext = ext.strip(".")
|
||||||
|
@ -215,6 +219,9 @@ def remove_prefix(name: str, prefix: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
TODO: move out of convert
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
logger.debug("loading tensor with Torch: %s", name)
|
logger.debug("loading tensor with Torch: %s", name)
|
||||||
checkpoint = torch.load(name, map_location=map_location)
|
checkpoint = torch.load(name, map_location=map_location)
|
||||||
|
@ -228,6 +235,9 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
|
||||||
|
|
||||||
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
TODO: move out of convert
|
||||||
|
"""
|
||||||
logger.debug("loading tensor: %s", name)
|
logger.debug("loading tensor: %s", name)
|
||||||
_, extension = path.splitext(name)
|
_, extension = path.splitext(name)
|
||||||
extension = extension[1:].lower()
|
extension = extension[1:].lower()
|
||||||
|
@ -285,6 +295,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||||
|
|
||||||
|
|
||||||
def resolve_tensor(name: str) -> Optional[str]:
|
def resolve_tensor(name: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
TODO: move out of convert
|
||||||
|
"""
|
||||||
logger.debug("searching for tensors with known extensions: %s", name)
|
logger.debug("searching for tensors with known extensions: %s", name)
|
||||||
for next_extension in RESOLVE_FORMATS:
|
for next_extension in RESOLVE_FORMATS:
|
||||||
next_name = f"{name}.{next_extension}"
|
next_name = f"{name}.{next_extension}"
|
||||||
|
@ -345,3 +358,52 @@ def onnx_export(
|
||||||
all_tensors_to_one_file=True,
|
all_tensors_to_one_file=True,
|
||||||
location=ONNX_WEIGHTS,
|
location=ONNX_WEIGHTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_diffusion_name(name: str):
|
||||||
|
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
|
||||||
|
logger.warning(
|
||||||
|
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
return f"diffusion-{name}"
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def build_cache_paths(
|
||||||
|
conversion: ConversionContext,
|
||||||
|
name: str,
|
||||||
|
client: Optional[str] = None,
|
||||||
|
dest: Optional[str] = None,
|
||||||
|
format: Optional[str] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
cache_path = dest or conversion.cache_path
|
||||||
|
|
||||||
|
# add an extension if possible, some of the conversion code checks for it
|
||||||
|
if format is not None:
|
||||||
|
basename = path.basename(name)
|
||||||
|
_filename, ext = path.splitext(basename)
|
||||||
|
if ext is None or ext == "":
|
||||||
|
name = f"{name}.{format}"
|
||||||
|
|
||||||
|
paths = [
|
||||||
|
path.join(cache_path, name),
|
||||||
|
]
|
||||||
|
|
||||||
|
if client is not None:
|
||||||
|
client_path = path.join(cache_path, client)
|
||||||
|
paths.append(path.join(client_path, name))
|
||||||
|
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def get_first_exists(
|
||||||
|
paths: List[str],
|
||||||
|
) -> Optional[str]:
|
||||||
|
for name in paths:
|
||||||
|
if path.exists(name):
|
||||||
|
logger.debug("model already exists in cache, skipping fetch: %s", name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -30,12 +30,12 @@ from .version_safe_diffusers import (
|
||||||
DDPMScheduler,
|
DDPMScheduler,
|
||||||
DEISMultistepScheduler,
|
DEISMultistepScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
|
DPMSolverSDEScheduler,
|
||||||
DPMSolverSinglestepScheduler,
|
DPMSolverSinglestepScheduler,
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
EulerDiscreteScheduler,
|
EulerDiscreteScheduler,
|
||||||
HeunDiscreteScheduler,
|
HeunDiscreteScheduler,
|
||||||
IPNDMScheduler,
|
IPNDMScheduler,
|
||||||
KarrasVeScheduler,
|
|
||||||
KDPM2AncestralDiscreteScheduler,
|
KDPM2AncestralDiscreteScheduler,
|
||||||
KDPM2DiscreteScheduler,
|
KDPM2DiscreteScheduler,
|
||||||
LCMScheduler,
|
LCMScheduler,
|
||||||
|
@ -71,6 +71,7 @@ pipeline_schedulers = {
|
||||||
"ddpm": DDPMScheduler,
|
"ddpm": DDPMScheduler,
|
||||||
"deis-multi": DEISMultistepScheduler,
|
"deis-multi": DEISMultistepScheduler,
|
||||||
"dpm-multi": DPMSolverMultistepScheduler,
|
"dpm-multi": DPMSolverMultistepScheduler,
|
||||||
|
"dpm-sde": DPMSolverSDEScheduler,
|
||||||
"dpm-single": DPMSolverSinglestepScheduler,
|
"dpm-single": DPMSolverSinglestepScheduler,
|
||||||
"euler": EulerDiscreteScheduler,
|
"euler": EulerDiscreteScheduler,
|
||||||
"euler-a": EulerAncestralDiscreteScheduler,
|
"euler-a": EulerAncestralDiscreteScheduler,
|
||||||
|
@ -78,7 +79,6 @@ pipeline_schedulers = {
|
||||||
"ipndm": IPNDMScheduler,
|
"ipndm": IPNDMScheduler,
|
||||||
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||||
"k-dpm-2": KDPM2DiscreteScheduler,
|
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||||
"karras-ve": KarrasVeScheduler,
|
|
||||||
"lcm": LCMScheduler,
|
"lcm": LCMScheduler,
|
||||||
"lms-discrete": LMSDiscreteScheduler,
|
"lms-discrete": LMSDiscreteScheduler,
|
||||||
"pndm": PNDMScheduler,
|
"pndm": PNDMScheduler,
|
||||||
|
|
|
@ -12,6 +12,11 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
|
||||||
|
|
||||||
|
try:
|
||||||
|
from diffusers import DPMSolverSDEScheduler
|
||||||
|
except ImportError:
|
||||||
|
from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from diffusers import LCMScheduler
|
from diffusers import LCMScheduler
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
|
@ -94,45 +94,44 @@ class ServerContext:
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls, env=environ):
|
||||||
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
|
memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None)
|
||||||
if memory_limit is not None:
|
if memory_limit is not None:
|
||||||
memory_limit = int(memory_limit)
|
memory_limit = int(memory_limit)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
bundle_path=environ.get(
|
bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")),
|
||||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||||
),
|
output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||||
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
params_path=env.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||||
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
||||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
|
||||||
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
|
||||||
any_platform=get_boolean(
|
any_platform=get_boolean(
|
||||||
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||||
),
|
),
|
||||||
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
|
block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
||||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||||
show_progress=get_boolean(
|
show_progress=get_boolean(
|
||||||
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||||
),
|
),
|
||||||
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
|
optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"),
|
||||||
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
|
extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"),
|
||||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||||
memory_limit=memory_limit,
|
memory_limit=memory_limit,
|
||||||
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
|
admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None),
|
||||||
server_version=environ.get(
|
server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION),
|
||||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
|
||||||
),
|
|
||||||
worker_retries=int(
|
worker_retries=int(
|
||||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||||
),
|
),
|
||||||
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
|
feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"),
|
||||||
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
|
plugins=get_list(env, "ONNX_WEB_PLUGINS", ""),
|
||||||
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
|
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_setting(self, flag: str, default: str) -> Optional[str]:
|
||||||
|
return environ.get(f"ONNX_WEB_{flag}", default)
|
||||||
|
|
||||||
def has_feature(self, flag: str) -> bool:
|
def has_feature(self, flag: str) -> bool:
|
||||||
return flag in self.feature_flags
|
return flag in self.feature_flags
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
from jsonschema import ValidationError, validate
|
from jsonschema import ValidationError, validate
|
||||||
|
|
||||||
|
from ..convert.utils import fix_diffusion_name
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
|
||||||
for model in data[model_type]:
|
for model in data[model_type]:
|
||||||
model_name = model["name"]
|
model_name = model["name"]
|
||||||
|
|
||||||
|
if model_type == "diffusion":
|
||||||
|
model_name = fix_diffusion_name(model_name)
|
||||||
|
|
||||||
if "hash" in model:
|
if "hash" in model:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"collecting hash for model %s from %s",
|
"collecting hash for model %s from %s",
|
||||||
|
|
|
@ -2,18 +2,24 @@ from importlib import import_module
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
from onnx_web.chain.stages import add_stage
|
from ..chain.stages import add_stage
|
||||||
from onnx_web.diffusers.load import add_pipeline
|
from ..diffusers.load import add_pipeline
|
||||||
from onnx_web.server.context import ServerContext
|
from ..server.context import ServerContext
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PluginExports:
|
class PluginExports:
|
||||||
|
clients: Dict[str, Any]
|
||||||
|
converter: Dict[str, Any]
|
||||||
pipelines: Dict[str, Any]
|
pipelines: Dict[str, Any]
|
||||||
stages: Dict[str, Any]
|
stages: Dict[str, Any]
|
||||||
|
|
||||||
def __init__(self, pipelines=None, stages=None) -> None:
|
def __init__(
|
||||||
|
self, clients=None, converter=None, pipelines=None, stages=None
|
||||||
|
) -> None:
|
||||||
|
self.clients = clients or {}
|
||||||
|
self.converter = converter or {}
|
||||||
self.pipelines = pipelines or {}
|
self.pipelines = pipelines or {}
|
||||||
self.stages = stages or {}
|
self.stages = stages or {}
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
},
|
},
|
||||||
"cfg": {
|
"cfg": {
|
||||||
"default": 6,
|
"default": 6,
|
||||||
"min": 1,
|
"min": 0,
|
||||||
"max": 30,
|
"max": 30,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
},
|
},
|
||||||
|
|
|
@ -3,21 +3,21 @@ numpy==1.23.5
|
||||||
protobuf==3.20.3
|
protobuf==3.20.3
|
||||||
|
|
||||||
### SD packages ###
|
### SD packages ###
|
||||||
accelerate==0.22.0
|
accelerate==0.25.0
|
||||||
coloredlogs==15.0.1
|
coloredlogs==15.0.1
|
||||||
controlnet_aux==0.0.2
|
controlnet_aux==0.0.2
|
||||||
datasets==2.14.3
|
datasets==2.15.0
|
||||||
diffusers==0.20.0
|
diffusers==0.24.0
|
||||||
huggingface-hub==0.16.4
|
huggingface-hub==0.16.4
|
||||||
invisible-watermark==0.2.0
|
invisible-watermark==0.2.0
|
||||||
mediapipe==0.9.2.1
|
mediapipe==0.9.2.1
|
||||||
omegaconf==2.3.0
|
omegaconf==2.3.0
|
||||||
onnx==1.13.0
|
onnx==1.13.0
|
||||||
# onnxruntime has many platform-specific packages
|
# onnxruntime has many platform-specific packages
|
||||||
optimum==1.12.0
|
optimum==1.16.0
|
||||||
safetensors==0.3.1
|
safetensors==0.4.1
|
||||||
timm==0.6.13
|
timm==0.9.12
|
||||||
transformers==4.32.0
|
transformers==4.36.1
|
||||||
|
|
||||||
#### Upscaling and face correction
|
#### Upscaling and face correction
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
|
@ -29,10 +29,11 @@ realesrgan==0.3.0
|
||||||
### Server packages ###
|
### Server packages ###
|
||||||
arpeggio==2.0.0
|
arpeggio==2.0.0
|
||||||
boto3==1.26.69
|
boto3==1.26.69
|
||||||
flask==2.2.2
|
flask==3.0.0
|
||||||
flask-cors==3.0.10
|
flask-cors==3.0.10
|
||||||
jsonschema==4.17.3
|
jsonschema==4.17.3
|
||||||
piexif==1.1.3
|
piexif==1.1.3
|
||||||
pyyaml==6.0
|
pyyaml==6.0
|
||||||
setproctitle==1.3.2
|
setproctitle==1.3.2
|
||||||
waitress==2.1.2
|
waitress==2.1.2
|
||||||
|
werkzeug==3.0.1
|
||||||
|
|
|
@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
|
||||||
|
|
||||||
class DownloadProgressTests(unittest.TestCase):
|
class DownloadProgressTests(unittest.TestCase):
|
||||||
def test_download_example(self):
|
def test_download_example(self):
|
||||||
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
path = download_progress("https://example.com", "/tmp/example-dot-com")
|
||||||
self.assertEqual(path, "/tmp/example-dot-com")
|
self.assertEqual(path, "/tmp/example-dot-com")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -414,6 +414,7 @@ class TestBlendPipeline(unittest.TestCase):
|
||||||
3.0,
|
3.0,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
|
unet_tile=64,
|
||||||
),
|
),
|
||||||
Size(64, 64),
|
Size(64, 64),
|
||||||
["test-blend.png"],
|
["test-blend.png"],
|
||||||
|
|
|
@ -0,0 +1,198 @@
|
||||||
|
# Getting Started With onnx-web
|
||||||
|
|
||||||
|
onnx-web is a tool for generating images with Stable Diffusion pipelines, including SDXL.
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
|
||||||
|
- [Getting Started With onnx-web](#getting-started-with-onnx-web)
|
||||||
|
- [Contents](#contents)
|
||||||
|
- [Setup](#setup)
|
||||||
|
- [Windows bundle setup](#windows-bundle-setup)
|
||||||
|
- [Other setup methods](#other-setup-methods)
|
||||||
|
- [Running](#running)
|
||||||
|
- [Running the server](#running-the-server)
|
||||||
|
- [Running the web UI](#running-the-web-ui)
|
||||||
|
- [Tabs](#tabs)
|
||||||
|
- [Txt2img Tab](#txt2img-tab)
|
||||||
|
- [Img2img Tab](#img2img-tab)
|
||||||
|
- [Inpaint Tab](#inpaint-tab)
|
||||||
|
- [Upscale Tab](#upscale-tab)
|
||||||
|
- [Blend Tab](#blend-tab)
|
||||||
|
- [Models Tab](#models-tab)
|
||||||
|
- [Settings Tab](#settings-tab)
|
||||||
|
- [Image parameters](#image-parameters)
|
||||||
|
- [Common image parameters](#common-image-parameters)
|
||||||
|
- [Unique image parameters](#unique-image-parameters)
|
||||||
|
- [Prompt syntax](#prompt-syntax)
|
||||||
|
- [LoRAs and embeddings](#loras-and-embeddings)
|
||||||
|
- [CLIP skip](#clip-skip)
|
||||||
|
- [Highres](#highres)
|
||||||
|
- [Highres prompt](#highres-prompt)
|
||||||
|
- [Highres iterations](#highres-iterations)
|
||||||
|
- [Profiles](#profiles)
|
||||||
|
- [Loading from files](#loading-from-files)
|
||||||
|
- [Saving profiles in the web UI](#saving-profiles-in-the-web-ui)
|
||||||
|
- [Sharing parameters profiles](#sharing-parameters-profiles)
|
||||||
|
- [Panorama pipeline](#panorama-pipeline)
|
||||||
|
- [Region prompts](#region-prompts)
|
||||||
|
- [Region seeds](#region-seeds)
|
||||||
|
- [Grid mode](#grid-mode)
|
||||||
|
- [Grid tokens](#grid-tokens)
|
||||||
|
- [Memory optimizations](#memory-optimizations)
|
||||||
|
- [Converting to fp16](#converting-to-fp16)
|
||||||
|
- [Moving models to the CPU](#moving-models-to-the-cpu)
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
### Windows bundle setup
|
||||||
|
|
||||||
|
1. Download
|
||||||
|
2. Extract
|
||||||
|
3. Security flags
|
||||||
|
4. Run
|
||||||
|
|
||||||
|
### Other setup methods
|
||||||
|
|
||||||
|
Link to the other methods.
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
### Running the server
|
||||||
|
|
||||||
|
Run server or bundle.
|
||||||
|
|
||||||
|
### Running the web UI
|
||||||
|
|
||||||
|
Open web UI.
|
||||||
|
|
||||||
|
Use it from your phone.
|
||||||
|
|
||||||
|
## Tabs
|
||||||
|
|
||||||
|
There are 5 tabs, which do different things.
|
||||||
|
|
||||||
|
### Txt2img Tab
|
||||||
|
|
||||||
|
Words go in, pictures come out.
|
||||||
|
|
||||||
|
### Img2img Tab
|
||||||
|
|
||||||
|
Pictures go in, better pictures come out.
|
||||||
|
|
||||||
|
ControlNet lives here.
|
||||||
|
|
||||||
|
### Inpaint Tab
|
||||||
|
|
||||||
|
Pictures go in, parts of the same picture come out.
|
||||||
|
|
||||||
|
### Upscale Tab
|
||||||
|
|
||||||
|
Just highres and super resolution.
|
||||||
|
|
||||||
|
### Blend Tab
|
||||||
|
|
||||||
|
Use the mask tool to combine two images.
|
||||||
|
|
||||||
|
### Models Tab
|
||||||
|
|
||||||
|
Add and manage models.
|
||||||
|
|
||||||
|
### Settings Tab
|
||||||
|
|
||||||
|
Manage web UI settings.
|
||||||
|
|
||||||
|
Reset buttons.
|
||||||
|
|
||||||
|
## Image parameters
|
||||||
|
|
||||||
|
### Common image parameters
|
||||||
|
|
||||||
|
- Scheduler
|
||||||
|
- Eta
|
||||||
|
- for DDIM
|
||||||
|
- CFG
|
||||||
|
- Steps
|
||||||
|
- Seed
|
||||||
|
- Batch size
|
||||||
|
- Prompt
|
||||||
|
- Negative prompt
|
||||||
|
- Width, height
|
||||||
|
|
||||||
|
### Unique image parameters
|
||||||
|
|
||||||
|
- UNet tile size
|
||||||
|
- UNet overlap
|
||||||
|
- Tiled VAE
|
||||||
|
- VAE tile size
|
||||||
|
- VAE overlap
|
||||||
|
|
||||||
|
See the complete user guide for details about the highres, upscale, and correction parameters.
|
||||||
|
|
||||||
|
## Prompt syntax
|
||||||
|
|
||||||
|
### LoRAs and embeddings
|
||||||
|
|
||||||
|
`<lora:filename:1.0>` and `<inversion:filename:1.0>`.
|
||||||
|
|
||||||
|
### CLIP skip
|
||||||
|
|
||||||
|
`<clip:skip:2>` for anime.
|
||||||
|
|
||||||
|
## Highres
|
||||||
|
|
||||||
|
### Highres prompt
|
||||||
|
|
||||||
|
`txt2img prompt || img2img prompt`
|
||||||
|
|
||||||
|
### Highres iterations
|
||||||
|
|
||||||
|
Highres will apply the upscaler and highres prompt (img2img pipeline) for each iteration.
|
||||||
|
|
||||||
|
The final size will be `scale ** iterations`.
|
||||||
|
|
||||||
|
## Profiles
|
||||||
|
|
||||||
|
Saved sets of parameters for later use.
|
||||||
|
|
||||||
|
### Loading from files
|
||||||
|
|
||||||
|
- load from images
|
||||||
|
- load from JSON
|
||||||
|
|
||||||
|
### Saving profiles in the web UI
|
||||||
|
|
||||||
|
Use the save button.
|
||||||
|
|
||||||
|
### Sharing parameters profiles
|
||||||
|
|
||||||
|
Use the download button.
|
||||||
|
|
||||||
|
Share profiles in the Discord channel.
|
||||||
|
|
||||||
|
## Panorama pipeline
|
||||||
|
|
||||||
|
### Region prompts
|
||||||
|
|
||||||
|
`<region:X:Y:W:H:S:F_TLBR:prompt+>`
|
||||||
|
|
||||||
|
### Region seeds
|
||||||
|
|
||||||
|
`<reseed:X:Y:W:H:?:F_TLBR:seed>`
|
||||||
|
|
||||||
|
## Grid mode
|
||||||
|
|
||||||
|
Makes many images. Takes many time.
|
||||||
|
|
||||||
|
### Grid tokens
|
||||||
|
|
||||||
|
`__column__` and `__row__` if you pick token in the menu.
|
||||||
|
|
||||||
|
## Memory optimizations
|
||||||
|
|
||||||
|
### Converting to fp16
|
||||||
|
|
||||||
|
Enable the option.
|
||||||
|
|
||||||
|
### Moving models to the CPU
|
||||||
|
|
||||||
|
Option for each model.
|
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { OnnxState, StateContext } from '../state.js';
|
import { OnnxState, StateContext } from '../state/full.js';
|
||||||
import { ErrorCard } from './card/ErrorCard.js';
|
import { ErrorCard } from './card/ErrorCard.js';
|
||||||
import { ImageCard } from './card/ImageCard.js';
|
import { ImageCard } from './card/ImageCard.js';
|
||||||
import { LoadingCard } from './card/LoadingCard.js';
|
import { LoadingCard } from './card/LoadingCard.js';
|
||||||
|
|
|
@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { ReactNode } from 'react';
|
import { ReactNode } from 'react';
|
||||||
|
|
||||||
import { STATE_KEY } from '../state.js';
|
import { STATE_KEY } from '../state/full.js';
|
||||||
import { Logo } from './Logo.js';
|
import { Logo } from './Logo.js';
|
||||||
|
|
||||||
export interface OnnxErrorProps {
|
export interface OnnxErrorProps {
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react';
|
||||||
import { useHash } from 'react-use/lib/useHash';
|
import { useHash } from 'react-use/lib/useHash';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { OnnxState, StateContext } from '../state.js';
|
import { OnnxState, StateContext } from '../state/full.js';
|
||||||
import { ImageHistory } from './ImageHistory.js';
|
import { ImageHistory } from './ImageHistory.js';
|
||||||
import { Logo } from './Logo.js';
|
import { Logo } from './Logo.js';
|
||||||
import { Blend } from './tab/Blend.js';
|
import { Blend } from './tab/Blend.js';
|
||||||
|
|
|
@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { OnnxState, StateContext } from '../state.js';
|
import { OnnxState, StateContext } from '../state/full.js';
|
||||||
import { ImageMetadata } from '../types/api.js';
|
import { ImageMetadata } from '../types/api.js';
|
||||||
import { DeepPartial } from '../types/model.js';
|
import { DeepPartial } from '../types/model.js';
|
||||||
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
|
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
|
||||||
|
|
|
@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
|
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
|
||||||
|
|
||||||
export interface ErrorCardProps {
|
export interface ErrorCardProps {
|
||||||
|
|
|
@ -8,9 +8,10 @@ import { useHash } from 'react-use/lib/useHash';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse } from '../../types/api.js';
|
import { ImageResponse } from '../../types/api.js';
|
||||||
import { range, visibleIndex } from '../../utils.js';
|
import { range, visibleIndex } from '../../utils.js';
|
||||||
|
import { BLEND_SOURCES } from '../../constants.js';
|
||||||
|
|
||||||
export interface ImageCardProps {
|
export interface ImageCardProps {
|
||||||
image: ImageResponse;
|
image: ImageResponse;
|
||||||
|
|
|
@ -9,7 +9,7 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { POLL_TIME } from '../../config.js';
|
import { POLL_TIME } from '../../config.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { ImageResponse } from '../../types/api.js';
|
import { ImageResponse } from '../../types/api.js';
|
||||||
|
|
||||||
const LOADING_PERCENT = 100;
|
const LOADING_PERCENT = 100;
|
||||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { HighresParams } from '../../types/params.js';
|
import { HighresParams } from '../../types/params.js';
|
||||||
import { NumericField } from '../input/NumericField.js';
|
import { NumericField } from '../input/NumericField.js';
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { BaseImgParams } from '../../types/params.js';
|
import { BaseImgParams } from '../../types/params.js';
|
||||||
import { NumericField } from '../input/NumericField.js';
|
import { NumericField } from '../input/NumericField.js';
|
||||||
import { PromptInput } from '../input/PromptInput.js';
|
import { PromptInput } from '../input/PromptInput.js';
|
||||||
|
|
|
@ -6,7 +6,7 @@ import { useContext } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext } from '../../state.js';
|
import { ClientContext } from '../../state/full.js';
|
||||||
import { ModelParams } from '../../types/params.js';
|
import { ModelParams } from '../../types/params.js';
|
||||||
import { QueryList } from '../input/QueryList.js';
|
import { QueryList } from '../input/QueryList.js';
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { NumericField } from '../input/NumericField.js';
|
import { NumericField } from '../input/NumericField.js';
|
||||||
|
|
||||||
export function OutpaintControl() {
|
export function OutpaintControl() {
|
||||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { UpscaleParams } from '../../types/params.js';
|
import { UpscaleParams } from '../../types/params.js';
|
||||||
import { NumericField } from '../input/NumericField.js';
|
import { NumericField } from '../input/NumericField.js';
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { PipelineGrid } from '../../client/utils.js';
|
import { PipelineGrid } from '../../client/utils.js';
|
||||||
import { OnnxState, StateContext } from '../../state.js';
|
import { OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { VARIABLE_PARAMETERS } from '../../types/chain.js';
|
import { VARIABLE_PARAMETERS } from '../../types/chain.js';
|
||||||
|
|
||||||
export interface VariableControlProps {
|
export interface VariableControlProps {
|
||||||
|
|
|
@ -4,7 +4,7 @@ import * as React from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { OnnxState, StateContext } from '../../state.js';
|
import { OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
|
||||||
const { useContext, useState, memo, useMemo } = React;
|
const { useContext, useState, memo, useMemo } = React;
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { SAVE_TIME } from '../../config.js';
|
import { SAVE_TIME } from '../../config.js';
|
||||||
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
|
import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js';
|
||||||
import { BrushParams } from '../../types/params.js';
|
import { BrushParams } from '../../types/params.js';
|
||||||
import { imageFromBlob } from '../../utils.js';
|
import { imageFromBlob } from '../../utils.js';
|
||||||
import { NumericField } from './NumericField';
|
import { NumericField } from './NumericField';
|
||||||
|
|
|
@ -8,7 +8,7 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import { QueryMenu } from '../input/QueryMenu.js';
|
import { QueryMenu } from '../input/QueryMenu.js';
|
||||||
import { ModelResponse } from '../../types/api.js';
|
import { ModelResponse } from '../../types/api.js';
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,9 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { IMAGE_FILTER } from '../../config.js';
|
import { IMAGE_FILTER } from '../../config.js';
|
||||||
import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
|
import { BLEND_SOURCES } from '../../constants.js';
|
||||||
|
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
import { TabState } from '../../state/types.js';
|
||||||
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||||
import { range } from '../../utils.js';
|
import { range } from '../../utils.js';
|
||||||
import { UpscaleControl } from '../control/UpscaleControl.js';
|
import { UpscaleControl } from '../control/UpscaleControl.js';
|
||||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
import { TabState } from '../../state/types.js';
|
||||||
import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||||
import { Profiles } from '../Profiles.js';
|
import { Profiles } from '../Profiles.js';
|
||||||
import { HighresControl } from '../control/HighresControl.js';
|
import { HighresControl } from '../control/HighresControl.js';
|
||||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
import { TabState } from '../../state/types.js';
|
||||||
import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||||
import { Profiles } from '../Profiles.js';
|
import { Profiles } from '../Profiles.js';
|
||||||
import { HighresControl } from '../control/HighresControl.js';
|
import { HighresControl } from '../control/HighresControl.js';
|
||||||
|
|
|
@ -8,7 +8,7 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { STALE_TIME } from '../../config.js';
|
import { STALE_TIME } from '../../config.js';
|
||||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
import {
|
import {
|
||||||
CorrectionModel,
|
CorrectionModel,
|
||||||
DiffusionModel,
|
DiffusionModel,
|
||||||
|
|
|
@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
|
||||||
import { useStore } from 'zustand';
|
import { useStore } from 'zustand';
|
||||||
|
|
||||||
import { getApiRoot } from '../../config.js';
|
import { getApiRoot } from '../../config.js';
|
||||||
import { ConfigContext, StateContext, STATE_KEY } from '../../state.js';
|
import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js';
|
||||||
import { getTheme } from '../utils.js';
|
import { getTheme } from '../utils.js';
|
||||||
import { NumericField } from '../input/NumericField.js';
|
import { NumericField } from '../input/NumericField.js';
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js';
|
import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js';
|
||||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
import { TabState } from '../../state/types.js';
|
||||||
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js';
|
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js';
|
||||||
import { Profiles } from '../Profiles.js';
|
import { Profiles } from '../Profiles.js';
|
||||||
import { HighresControl } from '../control/HighresControl.js';
|
import { HighresControl } from '../control/HighresControl.js';
|
||||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
||||||
import { shallow } from 'zustand/shallow';
|
import { shallow } from 'zustand/shallow';
|
||||||
|
|
||||||
import { IMAGE_FILTER } from '../../config.js';
|
import { IMAGE_FILTER } from '../../config.js';
|
||||||
import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
|
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||||
|
import { TabState } from '../../state/types.js';
|
||||||
import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js';
|
import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js';
|
||||||
import { Profiles } from '../Profiles.js';
|
import { Profiles } from '../Profiles.js';
|
||||||
import { HighresControl } from '../control/HighresControl.js';
|
import { HighresControl } from '../control/HighresControl.js';
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import { PaletteMode } from '@mui/material';
|
import { PaletteMode } from '@mui/material';
|
||||||
|
|
||||||
import { Theme } from '../state.js';
|
import { Theme } from '../state/types.js';
|
||||||
import { trimHash } from '../utils.js';
|
import { trimHash } from '../utils.js';
|
||||||
|
|
||||||
export const TAB_LABELS = [
|
export const TAB_LABELS = [
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
},
|
},
|
||||||
"cfg": {
|
"cfg": {
|
||||||
"default": 6,
|
"default": 6,
|
||||||
"min": 1,
|
"min": 0,
|
||||||
"max": 30,
|
"max": 30,
|
||||||
"step": 0.1
|
"step": 0.1
|
||||||
},
|
},
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
|
||||||
|
export const BLEND_SOURCES = 2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default parameters for the inpaint brush.
|
||||||
|
*
|
||||||
|
* Not provided by the server yet.
|
||||||
|
*/
|
||||||
|
export const DEFAULT_BRUSH = {
|
||||||
|
color: 255,
|
||||||
|
size: 8,
|
||||||
|
strength: 0.5,
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default parameters for the image history.
|
||||||
|
*
|
||||||
|
* Not provided by the server yet.
|
||||||
|
*/
|
||||||
|
export const DEFAULT_HISTORY = {
|
||||||
|
/**
|
||||||
|
* The number of images to be shown.
|
||||||
|
*/
|
||||||
|
limit: 4,
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of additional images to be kept in history, so they can scroll
|
||||||
|
* back into view when you delete one. Does not include deleted images.
|
||||||
|
*/
|
||||||
|
scrollback: 2,
|
||||||
|
};
|
|
@ -28,8 +28,9 @@ import {
|
||||||
STATE_KEY,
|
STATE_KEY,
|
||||||
STATE_VERSION,
|
STATE_VERSION,
|
||||||
StateContext,
|
StateContext,
|
||||||
} from './state.js';
|
} from './state/full.js';
|
||||||
import { I18N_STRINGS } from './strings/all.js';
|
import { I18N_STRINGS } from './strings/all.js';
|
||||||
|
import { applyStateMigrations, UnknownState } from './state/migration/default.js';
|
||||||
|
|
||||||
export const INITIAL_LOAD_TIMEOUT = 5_000;
|
export const INITIAL_LOAD_TIMEOUT = 5_000;
|
||||||
|
|
||||||
|
@ -70,6 +71,9 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
||||||
...createResetSlice(...slice),
|
...createResetSlice(...slice),
|
||||||
...createProfileSlice(...slice),
|
...createProfileSlice(...slice),
|
||||||
}), {
|
}), {
|
||||||
|
migrate(persistedState, version) {
|
||||||
|
return applyStateMigrations(params, persistedState as UnknownState, version, logger);
|
||||||
|
},
|
||||||
name: STATE_KEY,
|
name: STATE_KEY,
|
||||||
partialize(s) {
|
partialize(s) {
|
||||||
return {
|
return {
|
||||||
|
|
965
gui/src/state.ts
965
gui/src/state.ts
|
@ -1,965 +0,0 @@
|
||||||
/* eslint-disable camelcase */
|
|
||||||
/* eslint-disable max-lines */
|
|
||||||
/* eslint-disable no-null/no-null */
|
|
||||||
import { Maybe } from '@apextoaster/js-utils';
|
|
||||||
import { PaletteMode } from '@mui/material';
|
|
||||||
import { Logger } from 'noicejs';
|
|
||||||
import { createContext } from 'react';
|
|
||||||
import { StateCreator, StoreApi } from 'zustand';
|
|
||||||
|
|
||||||
import {
|
|
||||||
ApiClient,
|
|
||||||
} from './client/base.js';
|
|
||||||
import { PipelineGrid } from './client/utils.js';
|
|
||||||
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
|
|
||||||
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types/model.js';
|
|
||||||
import { ImageResponse, ReadyResponse, RetryParams } from './types/api.js';
|
|
||||||
import {
|
|
||||||
BaseImgParams,
|
|
||||||
BlendParams,
|
|
||||||
BrushParams,
|
|
||||||
HighresParams,
|
|
||||||
Img2ImgParams,
|
|
||||||
InpaintParams,
|
|
||||||
ModelParams,
|
|
||||||
OutpaintPixels,
|
|
||||||
Txt2ImgParams,
|
|
||||||
UpscaleParams,
|
|
||||||
UpscaleReqParams,
|
|
||||||
} from './types/params.js';
|
|
||||||
|
|
||||||
export const MISSING_INDEX = -1;
|
|
||||||
|
|
||||||
export type Theme = PaletteMode | ''; // tri-state, '' is unset
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Combine optional files and required ranges.
|
|
||||||
*/
|
|
||||||
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
|
||||||
|
|
||||||
export interface HistoryItem {
|
|
||||||
image: ImageResponse;
|
|
||||||
ready: Maybe<ReadyResponse>;
|
|
||||||
retry: Maybe<RetryParams>;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface ProfileItem {
|
|
||||||
name: string;
|
|
||||||
params: BaseImgParams | Txt2ImgParams;
|
|
||||||
highres?: Maybe<HighresParams>;
|
|
||||||
upscale?: Maybe<UpscaleParams>;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface DefaultSlice {
|
|
||||||
defaults: TabState<BaseImgParams>;
|
|
||||||
theme: Theme;
|
|
||||||
|
|
||||||
setDefaults(param: Partial<BaseImgParams>): void;
|
|
||||||
setTheme(theme: Theme): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface HistorySlice {
|
|
||||||
history: Array<HistoryItem>;
|
|
||||||
limit: number;
|
|
||||||
|
|
||||||
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
|
||||||
removeHistory(image: ImageResponse): void;
|
|
||||||
setLimit(limit: number): void;
|
|
||||||
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ModelSlice {
|
|
||||||
extras: ExtrasFile;
|
|
||||||
|
|
||||||
removeCorrectionModel(model: CorrectionModel): void;
|
|
||||||
removeDiffusionModel(model: DiffusionModel): void;
|
|
||||||
removeExtraNetwork(model: ExtraNetwork): void;
|
|
||||||
removeExtraSource(model: ExtraSource): void;
|
|
||||||
removeUpscalingModel(model: UpscalingModel): void;
|
|
||||||
|
|
||||||
setExtras(extras: Partial<ExtrasFile>): void;
|
|
||||||
|
|
||||||
setCorrectionModel(model: CorrectionModel): void;
|
|
||||||
setDiffusionModel(model: DiffusionModel): void;
|
|
||||||
setExtraNetwork(model: ExtraNetwork): void;
|
|
||||||
setExtraSource(model: ExtraSource): void;
|
|
||||||
setUpscalingModel(model: UpscalingModel): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
// #region tab slices
|
|
||||||
interface Txt2ImgSlice {
|
|
||||||
txt2img: TabState<Txt2ImgParams>;
|
|
||||||
txt2imgModel: ModelParams;
|
|
||||||
txt2imgHighres: HighresParams;
|
|
||||||
txt2imgUpscale: UpscaleParams;
|
|
||||||
txt2imgVariable: PipelineGrid;
|
|
||||||
|
|
||||||
resetTxt2Img(): void;
|
|
||||||
|
|
||||||
setTxt2Img(params: Partial<Txt2ImgParams>): void;
|
|
||||||
setTxt2ImgModel(params: Partial<ModelParams>): void;
|
|
||||||
setTxt2ImgHighres(params: Partial<HighresParams>): void;
|
|
||||||
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
|
|
||||||
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface Img2ImgSlice {
|
|
||||||
img2img: TabState<Img2ImgParams>;
|
|
||||||
img2imgModel: ModelParams;
|
|
||||||
img2imgHighres: HighresParams;
|
|
||||||
img2imgUpscale: UpscaleParams;
|
|
||||||
|
|
||||||
resetImg2Img(): void;
|
|
||||||
|
|
||||||
setImg2Img(params: Partial<Img2ImgParams>): void;
|
|
||||||
setImg2ImgModel(params: Partial<ModelParams>): void;
|
|
||||||
setImg2ImgHighres(params: Partial<HighresParams>): void;
|
|
||||||
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface InpaintSlice {
|
|
||||||
inpaint: TabState<InpaintParams>;
|
|
||||||
inpaintBrush: BrushParams;
|
|
||||||
inpaintModel: ModelParams;
|
|
||||||
inpaintHighres: HighresParams;
|
|
||||||
inpaintUpscale: UpscaleParams;
|
|
||||||
outpaint: OutpaintPixels;
|
|
||||||
|
|
||||||
resetInpaint(): void;
|
|
||||||
|
|
||||||
setInpaint(params: Partial<InpaintParams>): void;
|
|
||||||
setInpaintBrush(brush: Partial<BrushParams>): void;
|
|
||||||
setInpaintModel(params: Partial<ModelParams>): void;
|
|
||||||
setInpaintHighres(params: Partial<HighresParams>): void;
|
|
||||||
setInpaintUpscale(params: Partial<UpscaleParams>): void;
|
|
||||||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * State slice for the upscale tab: source/prompt parameters plus per-tab
 * model, highres, and upscale settings.
 */
interface UpscaleSlice {
  upscale: TabState<UpscaleReqParams>;
  upscaleHighres: HighresParams;
  upscaleModel: ModelParams;
  upscaleUpscale: UpscaleParams;

  // Restore the tab's parameters to the server-provided defaults.
  resetUpscale(): void;

  // Each setter merges a partial update into the matching sub-state.
  setUpscale(params: Partial<UpscaleReqParams>): void;
  setUpscaleHighres(params: Partial<HighresParams>): void;
  setUpscaleModel(params: Partial<ModelParams>): void;
  setUpscaleUpscale(params: Partial<UpscaleParams>): void;
}
|
|
||||||
|
|
||||||
/**
 * State slice for the blend tab: mask/source parameters, brush settings,
 * and per-tab model and upscale settings.
 */
interface BlendSlice {
  blend: TabState<BlendParams>;
  blendBrush: BrushParams;
  blendModel: ModelParams;
  blendUpscale: UpscaleParams;

  // Clear the mask and source images.
  resetBlend(): void;

  // Each setter merges a partial update into the matching sub-state.
  setBlend(blend: Partial<BlendParams>): void;
  setBlendBrush(brush: Partial<BrushParams>): void;
  setBlendModel(model: Partial<ModelParams>): void;
  setBlendUpscale(params: Partial<UpscaleParams>): void;
}
|
|
||||||
|
|
||||||
/**
 * State slice exposing a single action that resets every tab at once.
 */
interface ResetSlice {
  resetAll(): void;
}
|
|
||||||
|
|
||||||
/**
 * State slice holding the user's saved parameter profiles.
 */
interface ProfileSlice {
  profiles: Array<ProfileItem>;

  // Remove a saved profile by name; no-op when the name is not found.
  removeProfile(profileName: string): void;

  // Add a profile, replacing any existing profile with the same name.
  saveProfile(profile: ProfileItem): void;
}
|
|
||||||
// #endregion
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Full merged state including all slices.
|
|
||||||
*/
|
|
||||||
export type OnnxState
|
|
||||||
= DefaultSlice
|
|
||||||
& HistorySlice
|
|
||||||
& Img2ImgSlice
|
|
||||||
& InpaintSlice
|
|
||||||
& ModelSlice
|
|
||||||
& Txt2ImgSlice
|
|
||||||
& UpscaleSlice
|
|
||||||
& BlendSlice
|
|
||||||
& ResetSlice
|
|
||||||
& ModelSlice
|
|
||||||
& ProfileSlice;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Shorthand for state creator to reduce repeated arguments.
|
|
||||||
*/
|
|
||||||
export type Slice<T> = StateCreator<OnnxState, [], [], T>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* React context binding for API client.
|
|
||||||
*/
|
|
||||||
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* React context binding for merged config, including server parameters.
|
|
||||||
*/
|
|
||||||
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* React context binding for bunyan logger.
|
|
||||||
*/
|
|
||||||
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* React context binding for zustand state store.
|
|
||||||
*/
|
|
||||||
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Key for zustand persistence, typically local storage.
|
|
||||||
*/
|
|
||||||
export const STATE_KEY = 'onnx-web';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Current state version for zustand persistence.
|
|
||||||
*/
|
|
||||||
export const STATE_VERSION = 7;
|
|
||||||
|
|
||||||
export const BLEND_SOURCES = 2;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Default parameters for the inpaint brush.
|
|
||||||
*
|
|
||||||
* Not provided by the server yet.
|
|
||||||
*/
|
|
||||||
export const DEFAULT_BRUSH = {
  color: 255, // brush color channel value; 255 fully masks the area painted
  size: 8, // brush radius in pixels
  strength: 0.5,
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Default parameters for the image history.
|
|
||||||
*
|
|
||||||
* Not provided by the server yet.
|
|
||||||
*/
|
|
||||||
export const DEFAULT_HISTORY = {
|
|
||||||
/**
|
|
||||||
* The number of images to be shown.
|
|
||||||
*/
|
|
||||||
limit: 4,
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of additional images to be kept in history, so they can scroll
|
|
||||||
* back into view when you delete one. Does not include deleted images.
|
|
||||||
*/
|
|
||||||
scrollback: 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
 * Build a complete set of base image parameters from the server's
 * parameter metadata, taking the `default` of each field.
 */
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
  return {
    batch: defaults.batch.default,
    cfg: defaults.cfg.default,
    eta: defaults.eta.default,
    negativePrompt: defaults.negativePrompt.default,
    prompt: defaults.prompt.default,
    scheduler: defaults.scheduler.default,
    steps: defaults.steps.default,
    seed: defaults.seed.default,
    tiled_vae: defaults.tiled_vae.default,
    unet_overlap: defaults.unet_overlap.default,
    unet_tile: defaults.unet_tile.default,
    vae_overlap: defaults.vae_overlap.default,
    vae_tile: defaults.vae_tile.default,
  };
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Prepare the state slice constructors.
|
|
||||||
*
|
|
||||||
* In the default state, image sources should be null and booleans should be false. Everything
|
|
||||||
* else should be initialized from the default value in the base parameters.
|
|
||||||
*/
|
|
||||||
export function createStateSlices(server: ServerParams) {
|
|
||||||
const defaultParams = baseParamsFromServer(server);
|
|
||||||
const defaultHighres: HighresParams = {
|
|
||||||
enabled: false,
|
|
||||||
highresIterations: server.highresIterations.default,
|
|
||||||
highresMethod: '',
|
|
||||||
highresSteps: server.highresSteps.default,
|
|
||||||
highresScale: server.highresScale.default,
|
|
||||||
highresStrength: server.highresStrength.default,
|
|
||||||
};
|
|
||||||
const defaultModel: ModelParams = {
|
|
||||||
control: server.control.default,
|
|
||||||
correction: server.correction.default,
|
|
||||||
model: server.model.default,
|
|
||||||
pipeline: server.pipeline.default,
|
|
||||||
platform: server.platform.default,
|
|
||||||
upscaling: server.upscaling.default,
|
|
||||||
};
|
|
||||||
const defaultUpscale: UpscaleParams = {
|
|
||||||
denoise: server.denoise.default,
|
|
||||||
enabled: false,
|
|
||||||
faces: false,
|
|
||||||
faceOutscale: server.faceOutscale.default,
|
|
||||||
faceStrength: server.faceStrength.default,
|
|
||||||
outscale: server.outscale.default,
|
|
||||||
scale: server.scale.default,
|
|
||||||
upscaleOrder: server.upscaleOrder.default,
|
|
||||||
};
|
|
||||||
const defaultGrid: PipelineGrid = {
|
|
||||||
enabled: false,
|
|
||||||
columns: {
|
|
||||||
parameter: 'seed',
|
|
||||||
value: '',
|
|
||||||
},
|
|
||||||
rows: {
|
|
||||||
parameter: 'seed',
|
|
||||||
value: '',
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
// Construct the txt2img tab slice: initial values come from the server
// defaults captured above; setters merge partial updates into sub-states.
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
  txt2img: {
    ...defaultParams,
    width: server.width.default,
    height: server.height.default,
  },
  txt2imgHighres: {
    ...defaultHighres,
  },
  txt2imgModel: {
    ...defaultModel,
  },
  txt2imgUpscale: {
    ...defaultUpscale,
  },
  txt2imgVariable: {
    ...defaultGrid,
  },
  setTxt2Img(params) {
    set((prev) => ({
      txt2img: {
        ...prev.txt2img,
        ...params,
      },
    }));
  },
  setTxt2ImgHighres(params) {
    set((prev) => ({
      txt2imgHighres: {
        ...prev.txt2imgHighres,
        ...params,
      },
    }));
  },
  setTxt2ImgModel(params) {
    set((prev) => ({
      txt2imgModel: {
        ...prev.txt2imgModel,
        ...params,
      },
    }));
  },
  setTxt2ImgUpscale(params) {
    set((prev) => ({
      txt2imgUpscale: {
        ...prev.txt2imgUpscale,
        ...params,
      },
    }));
  },
  setTxt2ImgVariable(params) {
    set((prev) => ({
      txt2imgVariable: {
        ...prev.txt2imgVariable,
        ...params,
      },
    }));
  },
  // Restore only the image parameters; model/highres/upscale are untouched.
  resetTxt2Img() {
    set({
      txt2img: {
        ...defaultParams,
        width: server.width.default,
        height: server.height.default,
      },
    });
  },
});
|
|
||||||
|
|
||||||
// Construct the img2img tab slice; the source image starts empty and the
// remaining values come from the server defaults captured above.
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
  img2img: {
    ...defaultParams,
    loopback: server.loopback.default,
    source: null,
    sourceFilter: '',
    strength: server.strength.default,
  },
  img2imgHighres: {
    ...defaultHighres,
  },
  img2imgModel: {
    ...defaultModel,
  },
  img2imgUpscale: {
    ...defaultUpscale,
  },
  // Restore only the image parameters, clearing the source image.
  resetImg2Img() {
    set({
      img2img: {
        ...defaultParams,
        loopback: server.loopback.default,
        source: null,
        sourceFilter: '',
        strength: server.strength.default,
      },
    });
  },
  // The setters below merge a partial update into the existing sub-state.
  setImg2Img(params) {
    set((prev) => ({
      img2img: {
        ...prev.img2img,
        ...params,
      },
    }));
  },
  setImg2ImgHighres(params) {
    set((prev) => ({
      img2imgHighres: {
        ...prev.img2imgHighres,
        ...params,
      },
    }));
  },
  setImg2ImgModel(params) {
    set((prev) => ({
      img2imgModel: {
        ...prev.img2imgModel,
        ...params,
      },
    }));
  },
  setImg2ImgUpscale(params) {
    set((prev) => ({
      img2imgUpscale: {
        ...prev.img2imgUpscale,
        ...params,
      },
    }));
  },
});
|
|
||||||
|
|
||||||
// Construct the inpaint tab slice; mask and source images start empty,
// outpaint starts disabled, and the rest comes from server defaults.
const createInpaintSlice: Slice<InpaintSlice> = (set) => ({
  inpaint: {
    ...defaultParams,
    fillColor: server.fillColor.default,
    filter: server.filter.default,
    mask: null,
    noise: server.noise.default,
    source: null,
    strength: server.strength.default,
    tileOrder: server.tileOrder.default,
  },
  inpaintBrush: {
    ...DEFAULT_BRUSH,
  },
  inpaintHighres: {
    ...defaultHighres,
  },
  inpaintModel: {
    ...defaultModel,
  },
  inpaintUpscale: {
    ...defaultUpscale,
  },
  outpaint: {
    enabled: false,
    left: server.left.default,
    right: server.right.default,
    top: server.top.default,
    bottom: server.bottom.default,
  },
  // Restore only the image parameters, clearing the mask and source.
  resetInpaint() {
    set({
      inpaint: {
        ...defaultParams,
        fillColor: server.fillColor.default,
        filter: server.filter.default,
        mask: null,
        noise: server.noise.default,
        source: null,
        strength: server.strength.default,
        tileOrder: server.tileOrder.default,
      },
    });
  },
  // The setters below merge a partial update into the existing sub-state.
  setInpaint(params) {
    set((prev) => ({
      inpaint: {
        ...prev.inpaint,
        ...params,
      },
    }));
  },
  setInpaintBrush(brush) {
    set((prev) => ({
      inpaintBrush: {
        ...prev.inpaintBrush,
        ...brush,
      },
    }));
  },
  setInpaintHighres(params) {
    set((prev) => ({
      inpaintHighres: {
        ...prev.inpaintHighres,
        ...params,
      },
    }));
  },
  setInpaintModel(params) {
    set((prev) => ({
      inpaintModel: {
        ...prev.inpaintModel,
        ...params,
      },
    }));
  },
  setInpaintUpscale(params) {
    set((prev) => ({
      inpaintUpscale: {
        ...prev.inpaintUpscale,
        ...params,
      },
    }));
  },
  setOutpaint(pixels) {
    set((prev) => ({
      outpaint: {
        ...prev.outpaint,
        ...pixels,
      }
    }));
  },
});
|
|
||||||
|
|
||||||
// Construct the image history slice: a most-recent-first list of generated
// images with a display limit.
const createHistorySlice: Slice<HistorySlice> = (set) => ({
  history: [],
  limit: DEFAULT_HISTORY.limit,
  // Prepend a new image and trim the list to limit + scrollback entries.
  pushHistory(image, retry) {
    set((prev) => ({
      ...prev,
      history: [
        {
          image,
          ready: undefined,
          retry,
        },
        ...prev.history,
      ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
    }));
  },
  // Remove an image, matching on its first output's key.
  removeHistory(image) {
    set((prev) => ({
      ...prev,
      history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
    }));
  },
  setLimit(limit) {
    set((prev) => ({
      ...prev,
      limit,
    }));
  },
  // Record the ready status for an image already in history; images that
  // are no longer present are silently ignored.
  setReady(image, ready) {
    set((prev) => {
      const history = [...prev.history];
      const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
      if (idx >= 0) {
        history[idx].ready = ready;
      } else {
        // TODO: error
      }

      return {
        ...prev,
        history,
      };
    });
  },
});
|
|
||||||
|
|
||||||
const createUpscaleSlice: Slice<UpscaleSlice> = (set) => ({
|
|
||||||
upscale: {
|
|
||||||
...defaultParams,
|
|
||||||
source: null,
|
|
||||||
},
|
|
||||||
upscaleHighres: {
|
|
||||||
...defaultHighres,
|
|
||||||
},
|
|
||||||
upscaleModel: {
|
|
||||||
...defaultModel,
|
|
||||||
},
|
|
||||||
upscaleUpscale: {
|
|
||||||
...defaultUpscale,
|
|
||||||
},
|
|
||||||
resetUpscale() {
|
|
||||||
set({
|
|
||||||
upscale: {
|
|
||||||
...defaultParams,
|
|
||||||
source: null,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
},
|
|
||||||
setUpscale(source) {
|
|
||||||
set((prev) => ({
|
|
||||||
upscale: {
|
|
||||||
...prev.upscale,
|
|
||||||
...source,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
},
|
|
||||||
setUpscaleHighres(params) {
|
|
||||||
set((prev) => ({
|
|
||||||
upscaleHighres: {
|
|
||||||
...prev.upscaleHighres,
|
|
||||||
...params,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
},
|
|
||||||
setUpscaleModel(params) {
|
|
||||||
set((prev) => ({
|
|
||||||
upscaleModel: {
|
|
||||||
...prev.upscaleModel,
|
|
||||||
...defaultModel,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
},
|
|
||||||
setUpscaleUpscale(params) {
|
|
||||||
set((prev) => ({
|
|
||||||
upscaleUpscale: {
|
|
||||||
...prev.upscaleUpscale,
|
|
||||||
...params,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Construct the blend tab slice; the mask and source images start empty.
const createBlendSlice: Slice<BlendSlice> = (set) => ({
  blend: {
    mask: null,
    sources: [],
  },
  blendBrush: {
    ...DEFAULT_BRUSH,
  },
  blendModel: {
    ...defaultModel,
  },
  blendUpscale: {
    ...defaultUpscale,
  },
  // Clear the mask and all source images.
  resetBlend() {
    set({
      blend: {
        mask: null,
        sources: [],
      },
    });
  },
  // The setters below merge a partial update into the existing sub-state.
  setBlend(blend) {
    set((prev) => ({
      blend: {
        ...prev.blend,
        ...blend,
      },
    }));
  },
  setBlendBrush(brush) {
    set((prev) => ({
      blendBrush: {
        ...prev.blendBrush,
        ...brush,
      },
    }));
  },
  setBlendModel(model) {
    set((prev) => ({
      blendModel: {
        ...prev.blendModel,
        ...model,
      },
    }));
  },
  setBlendUpscale(params) {
    set((prev) => ({
      blendUpscale: {
        ...prev.blendUpscale,
        ...params,
      },
    }));
  },
});
|
|
||||||
|
|
||||||
// Construct the defaults slice: shared base parameters applied to every
// tab, plus the UI theme name.
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
  defaults: {
    ...defaultParams,
  },
  theme: '', // empty string falls back to the browser's preferred theme
  setDefaults(params) {
    set((prev) => ({
      defaults: {
        ...prev.defaults,
        ...params,
      }
    }));
  },
  setTheme(theme) {
    set((prev) => ({
      theme,
    }));
  }
});
|
|
||||||
|
|
||||||
// Construct the global reset slice. resetAll delegates to each tab's own
// reset action; those actions call set themselves, so the copy returned
// here only serves to notify subscribers once more at the end.
const createResetSlice: Slice<ResetSlice> = (set) => ({
  resetAll() {
    set((prev) => {
      const next = { ...prev };
      next.resetImg2Img();
      next.resetInpaint();
      next.resetTxt2Img();
      next.resetUpscale();
      next.resetBlend();
      return next;
    });
  },
});
|
|
||||||
|
|
||||||
// Construct the saved-profiles slice.
const createProfileSlice: Slice<ProfileSlice> = (set) => ({
  profiles: [],
  // Add a profile, replacing any existing profile with the same name.
  saveProfile(profile: ProfileItem) {
    set((prev) => {
      const profiles = [...prev.profiles];
      const idx = profiles.findIndex((it) => it.name === profile.name);
      if (idx >= 0) {
        profiles[idx] = profile;
      } else {
        profiles.push(profile);
      }
      return {
        ...prev,
        profiles,
      };
    });
  },
  // Remove a profile by name; no-op when the name is not found.
  removeProfile(profileName: string) {
    set((prev) => {
      const profiles = [...prev.profiles];
      const idx = profiles.findIndex((it) => it.name === profileName);
      if (idx >= 0) {
        profiles.splice(idx, 1);
      }
      return {
        ...prev,
        profiles,
      };
    });
  }
});
|
|
||||||
|
|
||||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
|
||||||
const createModelSlice: Slice<ModelSlice> = (set) => ({
|
|
||||||
extras: {
|
|
||||||
correction: [],
|
|
||||||
diffusion: [],
|
|
||||||
networks: [],
|
|
||||||
sources: [],
|
|
||||||
upscaling: [],
|
|
||||||
},
|
|
||||||
setExtras(extras) {
|
|
||||||
set((prev) => ({
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
...extras,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
},
|
|
||||||
setCorrectionModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const correction = [...prev.extras.correction];
|
|
||||||
const exists = correction.findIndex((it) => model.name === it.name);
|
|
||||||
if (exists === MISSING_INDEX) {
|
|
||||||
correction.push(model);
|
|
||||||
} else {
|
|
||||||
correction[exists] = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
correction,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
setDiffusionModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const diffusion = [...prev.extras.diffusion];
|
|
||||||
const exists = diffusion.findIndex((it) => model.name === it.name);
|
|
||||||
if (exists === MISSING_INDEX) {
|
|
||||||
diffusion.push(model);
|
|
||||||
} else {
|
|
||||||
diffusion[exists] = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
diffusion,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
setExtraNetwork(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const networks = [...prev.extras.networks];
|
|
||||||
const exists = networks.findIndex((it) => model.name === it.name);
|
|
||||||
if (exists === MISSING_INDEX) {
|
|
||||||
networks.push(model);
|
|
||||||
} else {
|
|
||||||
networks[exists] = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
networks,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
setExtraSource(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const sources = [...prev.extras.sources];
|
|
||||||
const exists = sources.findIndex((it) => model.name === it.name);
|
|
||||||
if (exists === MISSING_INDEX) {
|
|
||||||
sources.push(model);
|
|
||||||
} else {
|
|
||||||
sources[exists] = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
sources,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
setUpscalingModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const upscaling = [...prev.extras.upscaling];
|
|
||||||
const exists = upscaling.findIndex((it) => model.name === it.name);
|
|
||||||
if (exists === MISSING_INDEX) {
|
|
||||||
upscaling.push(model);
|
|
||||||
} else {
|
|
||||||
upscaling[exists] = model;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
upscaling,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
removeCorrectionModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
correction,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
},
|
|
||||||
removeDiffusionModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
diffusion,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
},
|
|
||||||
removeExtraNetwork(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
networks,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
},
|
|
||||||
removeExtraSource(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
sources,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
},
|
|
||||||
removeUpscalingModel(model) {
|
|
||||||
set((prev) => {
|
|
||||||
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
|
|
||||||
return {
|
|
||||||
...prev,
|
|
||||||
extras: {
|
|
||||||
...prev.extras,
|
|
||||||
upscaling,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
createDefaultSlice,
|
|
||||||
createHistorySlice,
|
|
||||||
createImg2ImgSlice,
|
|
||||||
createInpaintSlice,
|
|
||||||
createTxt2ImgSlice,
|
|
||||||
createUpscaleSlice,
|
|
||||||
createBlendSlice,
|
|
||||||
createResetSlice,
|
|
||||||
createModelSlice,
|
|
||||||
createProfileSlice,
|
|
||||||
};
|
|
||||||
}
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
import { DEFAULT_BRUSH } from '../constants.js';
|
||||||
|
import {
|
||||||
|
BlendParams,
|
||||||
|
BrushParams,
|
||||||
|
ModelParams,
|
||||||
|
UpscaleParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState } from './types.js';
|
||||||
|
|
||||||
|
/**
 * State slice for the blend tab: mask/source parameters, brush settings,
 * and per-tab model and upscale settings.
 */
export interface BlendSlice {
  blend: TabState<BlendParams>;
  blendBrush: BrushParams;
  blendModel: ModelParams;
  blendUpscale: UpscaleParams;

  // Clear the mask and source images.
  resetBlend(): void;

  // Each setter merges a partial update into the matching sub-state.
  setBlend(blend: Partial<BlendParams>): void;
  setBlendBrush(brush: Partial<BrushParams>): void;
  setBlendModel(model: Partial<ModelParams>): void;
  setBlendUpscale(params: Partial<UpscaleParams>): void;
}
|
||||||
|
|
||||||
|
/**
 * Build the blend-tab slice creator, bound to the given default model and
 * upscale settings. The mask and source images start empty.
 */
export function createBlendSlice<TState extends BlendSlice>(
  defaultModel: ModelParams,
  defaultUpscale: UpscaleParams,
): Slice<TState, BlendSlice> {
  return (set) => ({
    blend: {
      // eslint-disable-next-line no-null/no-null
      mask: null,
      sources: [],
    },
    blendBrush: {
      ...DEFAULT_BRUSH,
    },
    blendModel: {
      ...defaultModel,
    },
    blendUpscale: {
      ...defaultUpscale,
    },
    // Clear the mask and all source images.
    resetBlend() {
      set((prev) => ({
        blend: {
          // eslint-disable-next-line no-null/no-null
          mask: null,
          sources: [] as Array<Blob>,
        },
      } as Partial<TState>));
    },
    // The setters below merge a partial update into the existing sub-state.
    setBlend(blend) {
      set((prev) => ({
        blend: {
          ...prev.blend,
          ...blend,
        },
      } as Partial<TState>));
    },
    setBlendBrush(brush) {
      set((prev) => ({
        blendBrush: {
          ...prev.blendBrush,
          ...brush,
        },
      } as Partial<TState>));
    },
    setBlendModel(model) {
      set((prev) => ({
        blendModel: {
          ...prev.blendModel,
          ...model,
        },
      } as Partial<TState>));
    },
    setBlendUpscale(params) {
      set((prev) => ({
        blendUpscale: {
          ...prev.blendUpscale,
          ...params,
        },
      } as Partial<TState>));
    },
  });
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState, Theme } from './types.js';
|
||||||
|
|
||||||
|
/**
 * State slice for the shared defaults applied to every tab, plus the
 * UI theme selection.
 */
export interface DefaultSlice {
  defaults: TabState<BaseImgParams>;
  theme: Theme;

  // Merge a partial update into the shared default parameters.
  setDefaults(param: Partial<BaseImgParams>): void;
  setTheme(theme: Theme): void;
}
|
||||||
|
|
||||||
|
/**
 * Build the defaults slice creator, seeded with the base parameters
 * derived from the server's metadata.
 */
export function createDefaultSlice<TState extends DefaultSlice>(defaultParams: Required<BaseImgParams>): Slice<TState, DefaultSlice> {
  return (set) => ({
    defaults: {
      ...defaultParams,
    },
    theme: '', // empty string falls back to the browser's preferred theme
    setDefaults(params) {
      set((prev) => ({
        defaults: {
          ...prev.defaults,
          ...params,
        }
      } as Partial<TState>));
    },
    setTheme(theme) {
      set((prev) => ({
        theme,
      } as Partial<TState>));
    }
  });
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
/* eslint-disable camelcase */
|
||||||
|
import { Maybe } from '@apextoaster/js-utils';
|
||||||
|
import { Logger } from 'noicejs';
|
||||||
|
import { createContext } from 'react';
|
||||||
|
import { StoreApi } from 'zustand';
|
||||||
|
|
||||||
|
import {
|
||||||
|
ApiClient,
|
||||||
|
} from '../client/base.js';
|
||||||
|
import { PipelineGrid } from '../client/utils.js';
|
||||||
|
import { Config, ServerParams } from '../config.js';
|
||||||
|
import { BlendSlice, createBlendSlice } from './blend.js';
|
||||||
|
import { DefaultSlice, createDefaultSlice } from './default.js';
|
||||||
|
import { HistorySlice, createHistorySlice } from './history.js';
|
||||||
|
import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js';
|
||||||
|
import { InpaintSlice, createInpaintSlice } from './inpaint.js';
|
||||||
|
import { ModelSlice, createModelSlice } from './model.js';
|
||||||
|
import { ProfileSlice, createProfileSlice } from './profile.js';
|
||||||
|
import { ResetSlice, createResetSlice } from './reset.js';
|
||||||
|
import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js';
|
||||||
|
import { UpscaleSlice, createUpscaleSlice } from './upscale.js';
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
HighresParams,
|
||||||
|
ModelParams,
|
||||||
|
UpscaleParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Full merged state including all slices.
|
||||||
|
*/
|
||||||
|
export type OnnxState
|
||||||
|
= DefaultSlice
|
||||||
|
& HistorySlice
|
||||||
|
& Img2ImgSlice
|
||||||
|
& InpaintSlice
|
||||||
|
& ModelSlice
|
||||||
|
& Txt2ImgSlice
|
||||||
|
& UpscaleSlice
|
||||||
|
& BlendSlice
|
||||||
|
& ResetSlice
|
||||||
|
& ProfileSlice;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React context binding for API client.
|
||||||
|
*/
|
||||||
|
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React context binding for merged config, including server parameters.
|
||||||
|
*/
|
||||||
|
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React context binding for bunyan logger.
|
||||||
|
*/
|
||||||
|
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React context binding for zustand state store.
|
||||||
|
*/
|
||||||
|
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Key for zustand persistence, typically local storage.
|
||||||
|
*/
|
||||||
|
export const STATE_KEY = 'onnx-web';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Current state version for zustand persistence.
|
||||||
|
*/
|
||||||
|
export const STATE_VERSION = 11;
|
||||||
|
|
||||||
|
/**
 * Build a complete set of base image parameters from the server's
 * parameter metadata, taking the `default` of each field.
 */
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
  return {
    batch: defaults.batch.default,
    cfg: defaults.cfg.default,
    eta: defaults.eta.default,
    negativePrompt: defaults.negativePrompt.default,
    prompt: defaults.prompt.default,
    scheduler: defaults.scheduler.default,
    steps: defaults.steps.default,
    seed: defaults.seed.default,
    tiled_vae: defaults.tiled_vae.default,
    unet_overlap: defaults.unet_overlap.default,
    unet_tile: defaults.unet_tile.default,
    vae_overlap: defaults.vae_overlap.default,
    vae_tile: defaults.vae_tile.default,
  };
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepare the state slice constructors.
|
||||||
|
*
|
||||||
|
* In the default state, image sources should be null and booleans should be false. Everything
|
||||||
|
* else should be initialized from the default value in the base parameters.
|
||||||
|
*/
|
||||||
|
export function createStateSlices(server: ServerParams) {
  const defaultParams = baseParamsFromServer(server);
  // Highres and upscale start disabled; numeric fields take server defaults.
  const defaultHighres: HighresParams = {
    enabled: false,
    highresIterations: server.highresIterations.default,
    highresMethod: '',
    highresSteps: server.highresSteps.default,
    highresScale: server.highresScale.default,
    highresStrength: server.highresStrength.default,
  };
  const defaultModel: ModelParams = {
    control: server.control.default,
    correction: server.correction.default,
    model: server.model.default,
    pipeline: server.pipeline.default,
    platform: server.platform.default,
    upscaling: server.upscaling.default,
  };
  const defaultUpscale: UpscaleParams = {
    denoise: server.denoise.default,
    enabled: false,
    faces: false,
    faceOutscale: server.faceOutscale.default,
    faceStrength: server.faceStrength.default,
    outscale: server.outscale.default,
    scale: server.scale.default,
    upscaleOrder: server.upscaleOrder.default,
  };
  // The variable grid starts disabled with the seed parameter on both axes.
  const defaultGrid: PipelineGrid = {
    enabled: false,
    columns: {
      parameter: 'seed',
      value: '',
    },
    rows: {
      parameter: 'seed',
      value: '',
    },
  };

  // Bind the shared defaults into each per-module slice factory.
  return {
    createBlendSlice: createBlendSlice(defaultModel, defaultUpscale),
    createDefaultSlice: createDefaultSlice(defaultParams),
    createHistorySlice: createHistorySlice(),
    createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
    createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
    createModelSlice: createModelSlice(),
    createProfileSlice: createProfileSlice(),
    createResetSlice: createResetSlice(),
    createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid),
    createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale),
  };
}
|
|
@ -0,0 +1,68 @@
|
||||||
|
import { Maybe } from '@apextoaster/js-utils';
|
||||||
|
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
||||||
|
import { Slice } from './types.js';
|
||||||
|
import { DEFAULT_HISTORY } from '../constants.js';
|
||||||
|
|
||||||
|
export interface HistoryItem {
|
||||||
|
image: ImageResponse;
|
||||||
|
ready: Maybe<ReadyResponse>;
|
||||||
|
retry: Maybe<RetryParams>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface HistorySlice {
|
||||||
|
history: Array<HistoryItem>;
|
||||||
|
limit: number;
|
||||||
|
|
||||||
|
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
||||||
|
removeHistory(image: ImageResponse): void;
|
||||||
|
setLimit(limit: number): void;
|
||||||
|
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
|
||||||
|
return (set) => ({
|
||||||
|
history: [],
|
||||||
|
limit: DEFAULT_HISTORY.limit,
|
||||||
|
pushHistory(image, retry) {
|
||||||
|
set((prev) => ({
|
||||||
|
...prev,
|
||||||
|
history: [
|
||||||
|
{
|
||||||
|
image,
|
||||||
|
ready: undefined,
|
||||||
|
retry,
|
||||||
|
},
|
||||||
|
...prev.history,
|
||||||
|
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
removeHistory(image) {
|
||||||
|
set((prev) => ({
|
||||||
|
...prev,
|
||||||
|
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
setLimit(limit) {
|
||||||
|
set((prev) => ({
|
||||||
|
...prev,
|
||||||
|
limit,
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
setReady(image, ready) {
|
||||||
|
set((prev) => {
|
||||||
|
const history = [...prev.history];
|
||||||
|
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
|
||||||
|
if (idx >= 0) {
|
||||||
|
history[idx].ready = ready;
|
||||||
|
} else {
|
||||||
|
// TODO: error
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
history,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,98 @@
|
||||||
|
|
||||||
|
import { ServerParams } from '../config.js';
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
HighresParams,
|
||||||
|
Img2ImgParams,
|
||||||
|
ModelParams,
|
||||||
|
UpscaleParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState } from './types.js';
|
||||||
|
|
||||||
|
export interface Img2ImgSlice {
|
||||||
|
img2img: TabState<Img2ImgParams>;
|
||||||
|
img2imgModel: ModelParams;
|
||||||
|
img2imgHighres: HighresParams;
|
||||||
|
img2imgUpscale: UpscaleParams;
|
||||||
|
|
||||||
|
resetImg2Img(): void;
|
||||||
|
|
||||||
|
setImg2Img(params: Partial<Img2ImgParams>): void;
|
||||||
|
setImg2ImgModel(params: Partial<ModelParams>): void;
|
||||||
|
setImg2ImgHighres(params: Partial<HighresParams>): void;
|
||||||
|
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line max-params
|
||||||
|
export function createImg2ImgSlice<TState extends Img2ImgSlice>(
|
||||||
|
server: ServerParams,
|
||||||
|
defaultParams: Required<BaseImgParams>,
|
||||||
|
defaultHighres: HighresParams,
|
||||||
|
defaultModel: ModelParams,
|
||||||
|
defaultUpscale: UpscaleParams
|
||||||
|
): Slice<TState, Img2ImgSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
img2img: {
|
||||||
|
...defaultParams,
|
||||||
|
loopback: server.loopback.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
sourceFilter: '',
|
||||||
|
strength: server.strength.default,
|
||||||
|
},
|
||||||
|
img2imgHighres: {
|
||||||
|
...defaultHighres,
|
||||||
|
},
|
||||||
|
img2imgModel: {
|
||||||
|
...defaultModel,
|
||||||
|
},
|
||||||
|
img2imgUpscale: {
|
||||||
|
...defaultUpscale,
|
||||||
|
},
|
||||||
|
resetImg2Img() {
|
||||||
|
set({
|
||||||
|
img2img: {
|
||||||
|
...defaultParams,
|
||||||
|
loopback: server.loopback.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
sourceFilter: '',
|
||||||
|
strength: server.strength.default,
|
||||||
|
},
|
||||||
|
} as Partial<TState>);
|
||||||
|
},
|
||||||
|
setImg2Img(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
img2img: {
|
||||||
|
...prev.img2img,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setImg2ImgHighres(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
img2imgHighres: {
|
||||||
|
...prev.img2imgHighres,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setImg2ImgModel(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
img2imgModel: {
|
||||||
|
...prev.img2imgModel,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setImg2ImgUpscale(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
img2imgUpscale: {
|
||||||
|
...prev.img2imgUpscale,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
import { ServerParams } from '../config.js';
|
||||||
|
import { DEFAULT_BRUSH } from '../constants.js';
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
BrushParams,
|
||||||
|
HighresParams,
|
||||||
|
InpaintParams,
|
||||||
|
ModelParams,
|
||||||
|
OutpaintPixels,
|
||||||
|
UpscaleParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState } from './types.js';
|
||||||
|
export interface InpaintSlice {
|
||||||
|
inpaint: TabState<InpaintParams>;
|
||||||
|
inpaintBrush: BrushParams;
|
||||||
|
inpaintModel: ModelParams;
|
||||||
|
inpaintHighres: HighresParams;
|
||||||
|
inpaintUpscale: UpscaleParams;
|
||||||
|
outpaint: OutpaintPixels;
|
||||||
|
|
||||||
|
resetInpaint(): void;
|
||||||
|
|
||||||
|
setInpaint(params: Partial<InpaintParams>): void;
|
||||||
|
setInpaintBrush(brush: Partial<BrushParams>): void;
|
||||||
|
setInpaintModel(params: Partial<ModelParams>): void;
|
||||||
|
setInpaintHighres(params: Partial<HighresParams>): void;
|
||||||
|
setInpaintUpscale(params: Partial<UpscaleParams>): void;
|
||||||
|
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line max-params
|
||||||
|
export function createInpaintSlice<TState extends InpaintSlice>(
|
||||||
|
server: ServerParams,
|
||||||
|
defaultParams: Required<BaseImgParams>,
|
||||||
|
defaultHighres: HighresParams,
|
||||||
|
defaultModel: ModelParams,
|
||||||
|
defaultUpscale: UpscaleParams,
|
||||||
|
): Slice<TState, InpaintSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
inpaint: {
|
||||||
|
...defaultParams,
|
||||||
|
fillColor: server.fillColor.default,
|
||||||
|
filter: server.filter.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
mask: null,
|
||||||
|
noise: server.noise.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
strength: server.strength.default,
|
||||||
|
tileOrder: server.tileOrder.default,
|
||||||
|
},
|
||||||
|
inpaintBrush: {
|
||||||
|
...DEFAULT_BRUSH,
|
||||||
|
},
|
||||||
|
inpaintHighres: {
|
||||||
|
...defaultHighres,
|
||||||
|
},
|
||||||
|
inpaintModel: {
|
||||||
|
...defaultModel,
|
||||||
|
},
|
||||||
|
inpaintUpscale: {
|
||||||
|
...defaultUpscale,
|
||||||
|
},
|
||||||
|
outpaint: {
|
||||||
|
enabled: false,
|
||||||
|
left: server.left.default,
|
||||||
|
right: server.right.default,
|
||||||
|
top: server.top.default,
|
||||||
|
bottom: server.bottom.default,
|
||||||
|
},
|
||||||
|
resetInpaint() {
|
||||||
|
set({
|
||||||
|
inpaint: {
|
||||||
|
...defaultParams,
|
||||||
|
fillColor: server.fillColor.default,
|
||||||
|
filter: server.filter.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
mask: null,
|
||||||
|
noise: server.noise.default,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
strength: server.strength.default,
|
||||||
|
tileOrder: server.tileOrder.default,
|
||||||
|
},
|
||||||
|
} as Partial<TState>);
|
||||||
|
},
|
||||||
|
setInpaint(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
inpaint: {
|
||||||
|
...prev.inpaint,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setInpaintBrush(brush) {
|
||||||
|
set((prev) => ({
|
||||||
|
inpaintBrush: {
|
||||||
|
...prev.inpaintBrush,
|
||||||
|
...brush,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setInpaintHighres(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
inpaintHighres: {
|
||||||
|
...prev.inpaintHighres,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setInpaintModel(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
inpaintModel: {
|
||||||
|
...prev.inpaintModel,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setInpaintUpscale(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
inpaintUpscale: {
|
||||||
|
...prev.inpaintUpscale,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setOutpaint(pixels) {
|
||||||
|
set((prev) => ({
|
||||||
|
outpaint: {
|
||||||
|
...prev.outpaint,
|
||||||
|
...pixels,
|
||||||
|
}
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,93 @@
|
||||||
|
/* eslint-disable camelcase */
|
||||||
|
import { Logger } from 'browser-bunyan';
|
||||||
|
import { ServerParams } from '../../config.js';
|
||||||
|
import { BaseImgParams } from '../../types/params.js';
|
||||||
|
import { OnnxState, STATE_VERSION } from '../full.js';
|
||||||
|
import { Img2ImgSlice } from '../img2img.js';
|
||||||
|
import { InpaintSlice } from '../inpaint.js';
|
||||||
|
import { Txt2ImgSlice } from '../txt2img.js';
|
||||||
|
import { UpscaleSlice } from '../upscale.js';
|
||||||
|
|
||||||
|
// #region V7
|
||||||
|
export const V7 = 7;
|
||||||
|
|
||||||
|
export type BaseImgParamsV7<T extends BaseImgParams> = Omit<T, AddedKeysV11> & {
|
||||||
|
overlap: number;
|
||||||
|
tile: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type OnnxStateV7 = Omit<OnnxState, 'img2img' | 'txt2img'> & {
|
||||||
|
img2img: BaseImgParamsV7<Img2ImgSlice['img2img']>;
|
||||||
|
inpaint: BaseImgParamsV7<InpaintSlice['inpaint']>;
|
||||||
|
txt2img: BaseImgParamsV7<Txt2ImgSlice['txt2img']>;
|
||||||
|
upscale: BaseImgParamsV7<UpscaleSlice['upscale']>;
|
||||||
|
};
|
||||||
|
// #endregion
|
||||||
|
|
||||||
|
// #region V11
|
||||||
|
export const REMOVED_KEYS_V11 = ['tile', 'overlap'] as const;
|
||||||
|
|
||||||
|
export type RemovedKeysV11 = typeof REMOVED_KEYS_V11[number];
|
||||||
|
|
||||||
|
// TODO: can the compiler calculate this?
|
||||||
|
export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap';
|
||||||
|
// #endregion
|
||||||
|
|
||||||
|
// add versions to this list as they are replaced
|
||||||
|
export type PreviousState = OnnxStateV7;
|
||||||
|
|
||||||
|
// always the latest version
|
||||||
|
export type CurrentState = OnnxState;
|
||||||
|
|
||||||
|
// any version of state
|
||||||
|
export type UnknownState = PreviousState | CurrentState;
|
||||||
|
|
||||||
|
export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number, logger: Logger): OnnxState {
|
||||||
|
logger.info('applying state migrations from version %s to version %s', version, STATE_VERSION);
|
||||||
|
|
||||||
|
if (version <= V7) {
|
||||||
|
return migrateV7ToV11(params, previousState as PreviousState);
|
||||||
|
}
|
||||||
|
|
||||||
|
return previousState as CurrentState;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function migrateV7ToV11(params: ServerParams, previousState: PreviousState): CurrentState {
|
||||||
|
// add any missing keys
|
||||||
|
const result: CurrentState = {
|
||||||
|
...params,
|
||||||
|
...previousState,
|
||||||
|
img2img: {
|
||||||
|
...previousState.img2img,
|
||||||
|
unet_overlap: params.unet_overlap.default,
|
||||||
|
unet_tile: params.unet_tile.default,
|
||||||
|
vae_overlap: params.vae_overlap.default,
|
||||||
|
vae_tile: params.vae_tile.default,
|
||||||
|
},
|
||||||
|
inpaint: {
|
||||||
|
...previousState.inpaint,
|
||||||
|
unet_overlap: params.unet_overlap.default,
|
||||||
|
unet_tile: params.unet_tile.default,
|
||||||
|
vae_overlap: params.vae_overlap.default,
|
||||||
|
vae_tile: params.vae_tile.default,
|
||||||
|
},
|
||||||
|
txt2img: {
|
||||||
|
...previousState.txt2img,
|
||||||
|
unet_overlap: params.unet_overlap.default,
|
||||||
|
unet_tile: params.unet_tile.default,
|
||||||
|
vae_overlap: params.vae_overlap.default,
|
||||||
|
vae_tile: params.vae_tile.default,
|
||||||
|
},
|
||||||
|
upscale: {
|
||||||
|
...previousState.upscale,
|
||||||
|
unet_overlap: params.unet_overlap.default,
|
||||||
|
unet_tile: params.unet_tile.default,
|
||||||
|
vae_overlap: params.vae_overlap.default,
|
||||||
|
vae_tile: params.vae_tile.default,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: remove extra keys
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
|
@ -0,0 +1,202 @@
|
||||||
|
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js';
|
||||||
|
import { MISSING_INDEX, Slice } from './types.js';
|
||||||
|
|
||||||
|
export interface ModelSlice {
|
||||||
|
extras: ExtrasFile;
|
||||||
|
|
||||||
|
removeCorrectionModel(model: CorrectionModel): void;
|
||||||
|
removeDiffusionModel(model: DiffusionModel): void;
|
||||||
|
removeExtraNetwork(model: ExtraNetwork): void;
|
||||||
|
removeExtraSource(model: ExtraSource): void;
|
||||||
|
removeUpscalingModel(model: UpscalingModel): void;
|
||||||
|
|
||||||
|
setExtras(extras: Partial<ExtrasFile>): void;
|
||||||
|
|
||||||
|
setCorrectionModel(model: CorrectionModel): void;
|
||||||
|
setDiffusionModel(model: DiffusionModel): void;
|
||||||
|
setExtraNetwork(model: ExtraNetwork): void;
|
||||||
|
setExtraSource(model: ExtraSource): void;
|
||||||
|
setUpscalingModel(model: UpscalingModel): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||||
|
export function createModelSlice<TState extends ModelSlice>(): Slice<TState, ModelSlice> {
|
||||||
|
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||||
|
return (set) => ({
|
||||||
|
extras: {
|
||||||
|
correction: [],
|
||||||
|
diffusion: [],
|
||||||
|
networks: [],
|
||||||
|
sources: [],
|
||||||
|
upscaling: [],
|
||||||
|
},
|
||||||
|
setExtras(extras) {
|
||||||
|
set((prev) => ({
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
...extras,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
},
|
||||||
|
setCorrectionModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const correction = [...prev.extras.correction];
|
||||||
|
const exists = correction.findIndex((it) => model.name === it.name);
|
||||||
|
if (exists === MISSING_INDEX) {
|
||||||
|
correction.push(model);
|
||||||
|
} else {
|
||||||
|
correction[exists] = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
correction,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
setDiffusionModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const diffusion = [...prev.extras.diffusion];
|
||||||
|
const exists = diffusion.findIndex((it) => model.name === it.name);
|
||||||
|
if (exists === MISSING_INDEX) {
|
||||||
|
diffusion.push(model);
|
||||||
|
} else {
|
||||||
|
diffusion[exists] = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
diffusion,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
setExtraNetwork(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const networks = [...prev.extras.networks];
|
||||||
|
const exists = networks.findIndex((it) => model.name === it.name);
|
||||||
|
if (exists === MISSING_INDEX) {
|
||||||
|
networks.push(model);
|
||||||
|
} else {
|
||||||
|
networks[exists] = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
networks,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
setExtraSource(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const sources = [...prev.extras.sources];
|
||||||
|
const exists = sources.findIndex((it) => model.name === it.name);
|
||||||
|
if (exists === MISSING_INDEX) {
|
||||||
|
sources.push(model);
|
||||||
|
} else {
|
||||||
|
sources[exists] = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
sources,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
setUpscalingModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const upscaling = [...prev.extras.upscaling];
|
||||||
|
const exists = upscaling.findIndex((it) => model.name === it.name);
|
||||||
|
if (exists === MISSING_INDEX) {
|
||||||
|
upscaling.push(model);
|
||||||
|
} else {
|
||||||
|
upscaling[exists] = model;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
upscaling,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
removeCorrectionModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
correction,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
},
|
||||||
|
removeDiffusionModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
diffusion,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
},
|
||||||
|
removeExtraNetwork(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
networks,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
},
|
||||||
|
removeExtraSource(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
sources,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
},
|
||||||
|
removeUpscalingModel(model) {
|
||||||
|
set((prev) => {
|
||||||
|
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
extras: {
|
||||||
|
...prev.extras,
|
||||||
|
upscaling,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
import { Maybe } from '@apextoaster/js-utils';
|
||||||
|
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
|
||||||
|
import { Slice } from './types.js';
|
||||||
|
|
||||||
|
export interface ProfileItem {
|
||||||
|
name: string;
|
||||||
|
params: BaseImgParams | Txt2ImgParams;
|
||||||
|
highres?: Maybe<HighresParams>;
|
||||||
|
upscale?: Maybe<UpscaleParams>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ProfileSlice {
|
||||||
|
profiles: Array<ProfileItem>;
|
||||||
|
|
||||||
|
removeProfile(profileName: string): void;
|
||||||
|
|
||||||
|
saveProfile(profile: ProfileItem): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createProfileSlice<TState extends ProfileSlice>(): Slice<TState, ProfileSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
profiles: [],
|
||||||
|
saveProfile(profile: ProfileItem) {
|
||||||
|
set((prev) => {
|
||||||
|
const profiles = [...prev.profiles];
|
||||||
|
const idx = profiles.findIndex((it) => it.name === profile.name);
|
||||||
|
if (idx >= 0) {
|
||||||
|
profiles[idx] = profile;
|
||||||
|
} else {
|
||||||
|
profiles.push(profile);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
profiles,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
},
|
||||||
|
removeProfile(profileName: string) {
|
||||||
|
set((prev) => {
|
||||||
|
const profiles = [...prev.profiles];
|
||||||
|
const idx = profiles.findIndex((it) => it.name === profileName);
|
||||||
|
if (idx >= 0) {
|
||||||
|
profiles.splice(idx, 1);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...prev,
|
||||||
|
profiles,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
import { BlendSlice } from './blend.js';
|
||||||
|
import { Img2ImgSlice } from './img2img.js';
|
||||||
|
import { InpaintSlice } from './inpaint.js';
|
||||||
|
import { Txt2ImgSlice } from './txt2img.js';
|
||||||
|
import { Slice } from './types.js';
|
||||||
|
import { UpscaleSlice } from './upscale.js';
|
||||||
|
|
||||||
|
export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice;
|
||||||
|
|
||||||
|
export interface ResetSlice {
|
||||||
|
resetAll(): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createResetSlice<TState extends ResetSlice & SlicesWithReset>(): Slice<TState, ResetSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
resetAll() {
|
||||||
|
set((prev) => {
|
||||||
|
const next = { ...prev };
|
||||||
|
next.resetImg2Img();
|
||||||
|
next.resetInpaint();
|
||||||
|
next.resetTxt2Img();
|
||||||
|
next.resetUpscale();
|
||||||
|
next.resetBlend();
|
||||||
|
return next;
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,105 @@
|
||||||
|
import { PipelineGrid } from '../client/utils.js';
|
||||||
|
import { ServerParams } from '../config.js';
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
HighresParams,
|
||||||
|
ModelParams,
|
||||||
|
Txt2ImgParams,
|
||||||
|
UpscaleParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState } from './types.js';
|
||||||
|
|
||||||
|
export interface Txt2ImgSlice {
|
||||||
|
txt2img: TabState<Txt2ImgParams>;
|
||||||
|
txt2imgModel: ModelParams;
|
||||||
|
txt2imgHighres: HighresParams;
|
||||||
|
txt2imgUpscale: UpscaleParams;
|
||||||
|
txt2imgVariable: PipelineGrid;
|
||||||
|
|
||||||
|
resetTxt2Img(): void;
|
||||||
|
|
||||||
|
setTxt2Img(params: Partial<Txt2ImgParams>): void;
|
||||||
|
setTxt2ImgModel(params: Partial<ModelParams>): void;
|
||||||
|
setTxt2ImgHighres(params: Partial<HighresParams>): void;
|
||||||
|
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||||
|
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line max-params
|
||||||
|
export function createTxt2ImgSlice<TState extends Txt2ImgSlice>(
|
||||||
|
server: ServerParams,
|
||||||
|
defaultParams: Required<BaseImgParams>,
|
||||||
|
defaultHighres: HighresParams,
|
||||||
|
defaultModel: ModelParams,
|
||||||
|
defaultUpscale: UpscaleParams,
|
||||||
|
defaultGrid: PipelineGrid,
|
||||||
|
): Slice<TState, Txt2ImgSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
txt2img: {
|
||||||
|
...defaultParams,
|
||||||
|
width: server.width.default,
|
||||||
|
height: server.height.default,
|
||||||
|
},
|
||||||
|
txt2imgHighres: {
|
||||||
|
...defaultHighres,
|
||||||
|
},
|
||||||
|
txt2imgModel: {
|
||||||
|
...defaultModel,
|
||||||
|
},
|
||||||
|
txt2imgUpscale: {
|
||||||
|
...defaultUpscale,
|
||||||
|
},
|
||||||
|
txt2imgVariable: {
|
||||||
|
...defaultGrid,
|
||||||
|
},
|
||||||
|
setTxt2Img(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
txt2img: {
|
||||||
|
...prev.txt2img,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setTxt2ImgHighres(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
txt2imgHighres: {
|
||||||
|
...prev.txt2imgHighres,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setTxt2ImgModel(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
txt2imgModel: {
|
||||||
|
...prev.txt2imgModel,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setTxt2ImgUpscale(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
txt2imgUpscale: {
|
||||||
|
...prev.txt2imgUpscale,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setTxt2ImgVariable(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
txt2imgVariable: {
|
||||||
|
...prev.txt2imgVariable,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
resetTxt2Img() {
|
||||||
|
set({
|
||||||
|
txt2img: {
|
||||||
|
...defaultParams,
|
||||||
|
width: server.width.default,
|
||||||
|
height: server.height.default,
|
||||||
|
},
|
||||||
|
} as Partial<TState>);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -0,0 +1,17 @@
|
||||||
|
import { PaletteMode } from '@mui/material';
|
||||||
|
import { StateCreator } from 'zustand';
|
||||||
|
import { ConfigFiles, ConfigState } from '../config.js';
|
||||||
|
|
||||||
|
export const MISSING_INDEX = -1;
|
||||||
|
|
||||||
|
export type Theme = PaletteMode | ''; // tri-state, '' is unset
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Combine optional files and required ranges.
|
||||||
|
*/
|
||||||
|
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Shorthand for state creator to reduce repeated arguments.
|
||||||
|
*/
|
||||||
|
export type Slice<TState, TValue> = StateCreator<TState, [], [], TValue>;
|
|
@ -0,0 +1,87 @@
|
||||||
|
import {
|
||||||
|
BaseImgParams,
|
||||||
|
HighresParams,
|
||||||
|
ModelParams,
|
||||||
|
UpscaleParams,
|
||||||
|
UpscaleReqParams,
|
||||||
|
} from '../types/params.js';
|
||||||
|
import { Slice, TabState } from './types.js';
|
||||||
|
|
||||||
|
export interface UpscaleSlice {
|
||||||
|
upscale: TabState<UpscaleReqParams>;
|
||||||
|
upscaleHighres: HighresParams;
|
||||||
|
upscaleModel: ModelParams;
|
||||||
|
upscaleUpscale: UpscaleParams;
|
||||||
|
|
||||||
|
resetUpscale(): void;
|
||||||
|
|
||||||
|
setUpscale(params: Partial<UpscaleReqParams>): void;
|
||||||
|
setUpscaleHighres(params: Partial<HighresParams>): void;
|
||||||
|
setUpscaleModel(params: Partial<ModelParams>): void;
|
||||||
|
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createUpscaleSlice<TState extends UpscaleSlice>(
|
||||||
|
defaultParams: Required<BaseImgParams>,
|
||||||
|
defaultHighres: HighresParams,
|
||||||
|
defaultModel: ModelParams,
|
||||||
|
defaultUpscale: UpscaleParams,
|
||||||
|
): Slice<TState, UpscaleSlice> {
|
||||||
|
return (set) => ({
|
||||||
|
upscale: {
|
||||||
|
...defaultParams,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
},
|
||||||
|
upscaleHighres: {
|
||||||
|
...defaultHighres,
|
||||||
|
},
|
||||||
|
upscaleModel: {
|
||||||
|
...defaultModel,
|
||||||
|
},
|
||||||
|
upscaleUpscale: {
|
||||||
|
...defaultUpscale,
|
||||||
|
},
|
||||||
|
resetUpscale() {
|
||||||
|
set({
|
||||||
|
upscale: {
|
||||||
|
...defaultParams,
|
||||||
|
// eslint-disable-next-line no-null/no-null
|
||||||
|
source: null,
|
||||||
|
},
|
||||||
|
} as Partial<TState>);
|
||||||
|
},
|
||||||
|
setUpscale(source) {
|
||||||
|
set((prev) => ({
|
||||||
|
upscale: {
|
||||||
|
...prev.upscale,
|
||||||
|
...source,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setUpscaleHighres(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
upscaleHighres: {
|
||||||
|
...prev.upscaleHighres,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setUpscaleModel(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
upscaleModel: {
|
||||||
|
...prev.upscaleModel,
|
||||||
|
...defaultModel,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
setUpscaleUpscale(params) {
|
||||||
|
set((prev) => ({
|
||||||
|
upscaleUpscale: {
|
||||||
|
...prev.upscaleUpscale,
|
||||||
|
...params,
|
||||||
|
},
|
||||||
|
} as Partial<TState>));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
|
@ -287,6 +287,7 @@ export const I18N_STRINGS_EN = {
|
||||||
'ddpm': 'DDPM',
|
'ddpm': 'DDPM',
|
||||||
'deis-multi': 'DEIS Multistep',
|
'deis-multi': 'DEIS Multistep',
|
||||||
'dpm-multi': 'DPM Multistep',
|
'dpm-multi': 'DPM Multistep',
|
||||||
|
'dpm-sde': 'DPM SDE (Turbo)',
|
||||||
'dpm-single': 'DPM Singlestep',
|
'dpm-single': 'DPM Singlestep',
|
||||||
'euler': 'Euler',
|
'euler': 'Euler',
|
||||||
'euler-a': 'Euler Ancestral',
|
'euler-a': 'Euler Ancestral',
|
||||||
|
|
Loading…
Reference in New Issue