Merge branch 'main' of https://github.com/ssube/onnx-web into feat/dynamic-wildcards
This commit is contained in:
commit
591a63edc8
|
@ -23,8 +23,8 @@ details](https://github.com/ssube/onnx-web/blob/main/docs/user-guide.md).
|
|||
|
||||
This is an incomplete list of new and interesting features, with links to the user guide:
|
||||
|
||||
- SDXL support
|
||||
- LCM support
|
||||
- supports SDXL and SDXL Turbo
|
||||
- wide variety of schedulers: DDIM, DEIS, DPM SDE, Euler Ancestral, LCM, UniPC, and more
|
||||
- hardware acceleration on both AMD and Nvidia
|
||||
- tested on CUDA, DirectML, and ROCm
|
||||
- [half-precision support for low-memory GPUs](docs/user-guide.md#optimizing-models-for-lower-memory-usage) on both
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
call onnx_env\Scripts\Activate.bat
|
||||
|
||||
echo "This launch.bat script is deprecated in favor of launch.ps1 and will be removed in a future release."
|
||||
|
||||
echo "Downloading and converting models to ONNX format..."
|
||||
IF "%ONNX_WEB_EXTRA_MODELS%"=="" (set ONNX_WEB_EXTRA_MODELS=..\models\extras.json)
|
||||
python -m onnx_web.convert ^
|
||||
|
@ -10,6 +12,12 @@ python -m onnx_web.convert ^
|
|||
--extras=%ONNX_WEB_EXTRA_MODELS% ^
|
||||
--token=%HF_TOKEN% %ONNX_WEB_EXTRA_ARGS%
|
||||
|
||||
IF NOT EXIST .\gui\index.html (
|
||||
echo "Please make sure you have downloaded the web UI files from https://github.com/ssube/onnx-web/tree/gh-pages"
|
||||
echo "See https://github.com/ssube/onnx-web/blob/main/docs/setup-guide.md#download-the-web-ui-bundle for more information"
|
||||
pause
|
||||
)
|
||||
|
||||
echo "Launching API server..."
|
||||
waitress-serve ^
|
||||
--host=0.0.0.0 ^
|
||||
|
|
|
@ -10,6 +10,13 @@ python -m onnx_web.convert `
|
|||
--extras=$Env:ONNX_WEB_EXTRA_MODELS `
|
||||
--token=$Env:HF_TOKEN $Env:ONNX_WEB_EXTRA_ARGS
|
||||
|
||||
if (!(Test-Path -path .\gui\index.html -PathType Leaf)) {
|
||||
echo "Downloading latest web UI files from Github..."
|
||||
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/index.html" -OutFile .\gui\index.html
|
||||
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/config.json" -OutFile .\gui\config.json
|
||||
Invoke-WebRequest "https://raw.githubusercontent.com/ssube/onnx-web/gh-pages/v0.11.0/bundle/main.js" -OutFile .\gui\bundle\main.js
|
||||
}
|
||||
|
||||
echo "Launching API server..."
|
||||
waitress-serve `
|
||||
--host=0.0.0.0 `
|
||||
|
|
|
@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("blending image using mask")
|
||||
|
||||
# TODO: does this need an alpha channel?
|
||||
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
|
||||
mult_mask.alpha_composite(stage_mask)
|
||||
mult_mask = Image.alpha_composite(mult_mask, stage_mask)
|
||||
mult_mask = mult_mask.convert("L")
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-mask.png", stage_mask)
|
||||
save_image(server, "last-mult-mask.png", mult_mask)
|
||||
|
||||
return StageResult(
|
||||
images=[
|
||||
return StageResult.from_images(
|
||||
[
|
||||
Image.composite(stage_source, source, mult_mask)
|
||||
for source in sources.as_image()
|
||||
]
|
||||
|
|
|
@ -3,16 +3,17 @@ from argparse import ArgumentParser
|
|||
from logging import getLogger
|
||||
from os import makedirs, path
|
||||
from sys import exit
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
from jsonschema import ValidationError, validate
|
||||
from onnx import load_model, save_model
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from ..server.plugin import load_plugins
|
||||
from ..utils import load_config
|
||||
from .client import add_model_source, fetch_model
|
||||
from .client.huggingface import HuggingfaceClient
|
||||
from .correction.gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.control import convert_diffusion_control
|
||||
from .diffusion.diffusion import convert_diffusion_diffusers
|
||||
|
@ -25,8 +26,7 @@ from .upscaling.swinir import convert_upscaling_swinir
|
|||
from .utils import (
|
||||
DEFAULT_OPSET,
|
||||
ConversionContext,
|
||||
download_progress,
|
||||
remove_prefix,
|
||||
fix_diffusion_name,
|
||||
source_format,
|
||||
tuple_to_correction,
|
||||
tuple_to_diffusion,
|
||||
|
@ -44,32 +44,34 @@ warnings.filterwarnings(
|
|||
".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*",
|
||||
)
|
||||
|
||||
Models = Dict[str, List[Any]]
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
ModelDict = Dict[str, Union[float, int, str]]
|
||||
Models = Dict[str, List[ModelDict]]
|
||||
|
||||
model_sources: Dict[str, Tuple[str, str]] = {
|
||||
"civitai://": ("Civitai", "https://civitai.com/api/download/models/%s"),
|
||||
model_converters: Dict[str, Any] = {
|
||||
"img2img": convert_diffusion_diffusers,
|
||||
"img2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
"inpaint": convert_diffusion_diffusers,
|
||||
"txt2img": convert_diffusion_diffusers,
|
||||
"txt2img-sdxl": convert_diffusion_diffusers_xl,
|
||||
}
|
||||
|
||||
model_source_huggingface = "huggingface://"
|
||||
|
||||
# recommended models
|
||||
base_models: Models = {
|
||||
"diffusion": [
|
||||
# v1.x
|
||||
(
|
||||
"stable-diffusion-onnx-v1-5",
|
||||
model_source_huggingface + "runwayml/stable-diffusion-v1-5",
|
||||
HuggingfaceClient.protocol + "runwayml/stable-diffusion-v1-5",
|
||||
),
|
||||
(
|
||||
"stable-diffusion-onnx-v1-inpainting",
|
||||
model_source_huggingface + "runwayml/stable-diffusion-inpainting",
|
||||
HuggingfaceClient.protocol + "runwayml/stable-diffusion-inpainting",
|
||||
),
|
||||
(
|
||||
"upscaling-stable-diffusion-x4",
|
||||
model_source_huggingface + "stabilityai/stable-diffusion-x4-upscaler",
|
||||
HuggingfaceClient.protocol + "stabilityai/stable-diffusion-x4-upscaler",
|
||||
True,
|
||||
),
|
||||
],
|
||||
|
@ -200,180 +202,68 @@ base_models: Models = {
|
|||
}
|
||||
|
||||
|
||||
def fetch_model(
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
dest: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
hf_hub_fetch: bool = False,
|
||||
hf_hub_filename: Optional[str] = None,
|
||||
) -> Tuple[str, bool]:
|
||||
cache_path = dest or conversion.cache_path
|
||||
cache_name = path.join(cache_path, name)
|
||||
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
if format is None:
|
||||
url = urlparse(source)
|
||||
ext = path.basename(url.path)
|
||||
_filename, ext = path.splitext(ext)
|
||||
if ext is not None:
|
||||
cache_name = cache_name + ext
|
||||
else:
|
||||
cache_name = f"{cache_name}.{format}"
|
||||
|
||||
if path.exists(cache_name):
|
||||
logger.debug("model already exists in cache, skipping fetch")
|
||||
return cache_name, False
|
||||
|
||||
for proto in model_sources:
|
||||
api_name, api_root = model_sources.get(proto)
|
||||
if source.startswith(proto):
|
||||
api_source = api_root % (remove_prefix(source, proto))
|
||||
logger.info(
|
||||
"downloading model from %s: %s -> %s", api_name, api_source, cache_name
|
||||
)
|
||||
return download_progress([(api_source, cache_name)]), False
|
||||
|
||||
if source.startswith(model_source_huggingface):
|
||||
hub_source = remove_prefix(source, model_source_huggingface)
|
||||
logger.info("downloading model from Huggingface Hub: %s", hub_source)
|
||||
# from_pretrained has a bunch of useful logic that snapshot_download by itself down not
|
||||
if hf_hub_fetch:
|
||||
return (
|
||||
hf_hub_download(
|
||||
repo_id=hub_source,
|
||||
filename=hf_hub_filename,
|
||||
cache_dir=cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
),
|
||||
False,
|
||||
)
|
||||
else:
|
||||
return hub_source, True
|
||||
elif source.startswith("https://"):
|
||||
logger.info("downloading model from: %s", source)
|
||||
return download_progress([(source, cache_name)]), False
|
||||
elif source.startswith("http://"):
|
||||
logger.warning("downloading model from insecure source: %s", source)
|
||||
return download_progress([(source, cache_name)]), False
|
||||
elif source.startswith(path.sep) or source.startswith("."):
|
||||
logger.info("using local model: %s", source)
|
||||
return source, False
|
||||
else:
|
||||
logger.info("unknown model location, using path as provided: %s", source)
|
||||
return source, False
|
||||
|
||||
|
||||
def convert_models(conversion: ConversionContext, args, models: Models):
|
||||
model_errors = []
|
||||
|
||||
if args.sources and "sources" in models:
|
||||
for model in models.get("sources", []):
|
||||
model = tuple_to_source(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping source: %s", name)
|
||||
else:
|
||||
def convert_model_source(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
source = model["source"]
|
||||
|
||||
try:
|
||||
dest_path = None
|
||||
if "dest" in model:
|
||||
dest_path = path.join(conversion.model_path, model["dest"])
|
||||
|
||||
dest, hf = fetch_model(
|
||||
conversion, name, source, format=model_format, dest=dest_path
|
||||
)
|
||||
dest = fetch_model(conversion, name, source, format=model_format, dest=dest_path)
|
||||
logger.info("finished downloading source: %s -> %s", source, dest)
|
||||
except Exception:
|
||||
logger.exception("error fetching source %s", name)
|
||||
model_errors.append(name)
|
||||
|
||||
if args.networks and "networks" in models:
|
||||
for network in models.get("networks", []):
|
||||
name = network["name"]
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping network: %s", name)
|
||||
else:
|
||||
network_format = source_format(network)
|
||||
network_model = network.get("model", None)
|
||||
network_type = network["type"]
|
||||
source = network["source"]
|
||||
def convert_model_network(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
model_type = model["type"]
|
||||
name = model["name"]
|
||||
source = model["source"]
|
||||
|
||||
try:
|
||||
if network_type == "control":
|
||||
dest, hf = fetch_model(
|
||||
if model_type == "control":
|
||||
dest = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
format=network_format,
|
||||
format=model_format,
|
||||
)
|
||||
|
||||
convert_diffusion_control(
|
||||
conversion,
|
||||
network,
|
||||
model,
|
||||
dest,
|
||||
path.join(conversion.model_path, network_type, name),
|
||||
)
|
||||
if network_type == "inversion" and network_model == "concept":
|
||||
dest, hf = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
dest=path.join(conversion.model_path, network_type),
|
||||
format=network_format,
|
||||
hf_hub_fetch=True,
|
||||
hf_hub_filename="learned_embeds.bin",
|
||||
path.join(conversion.model_path, model_type, name),
|
||||
)
|
||||
else:
|
||||
dest, hf = fetch_model(
|
||||
model = model.get("model", None)
|
||||
dest = fetch_model(
|
||||
conversion,
|
||||
name,
|
||||
source,
|
||||
dest=path.join(conversion.model_path, network_type),
|
||||
format=network_format,
|
||||
dest=path.join(conversion.model_path, model_type),
|
||||
format=model_format,
|
||||
embeds=(model_type == "inversion" and model == "concept"),
|
||||
)
|
||||
|
||||
logger.info("finished downloading network: %s -> %s", source, dest)
|
||||
except Exception:
|
||||
logger.exception("error fetching network %s", name)
|
||||
model_errors.append(name)
|
||||
|
||||
if args.diffusion and "diffusion" in models:
|
||||
for model in models.get("diffusion", []):
|
||||
model = tuple_to_diffusion(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
def convert_model_diffusion(conversion: ConversionContext, model):
|
||||
# fix up entries with missing prefixes
|
||||
name = fix_diffusion_name(model["name"])
|
||||
if name != model["name"]:
|
||||
# update the model in-memory if the name changed
|
||||
model["name"] = name
|
||||
|
||||
model_format = source_format(model)
|
||||
|
||||
try:
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
|
||||
pipeline = model.get("pipeline", "txt2img")
|
||||
if pipeline.endswith("-sdxl"):
|
||||
converted, dest = convert_diffusion_diffusers_xl(
|
||||
converter = model_converters.get(pipeline)
|
||||
converted, dest = converter(
|
||||
conversion,
|
||||
model,
|
||||
source,
|
||||
model_format,
|
||||
hf=hf,
|
||||
)
|
||||
else:
|
||||
converted, dest = convert_diffusion_diffusers(
|
||||
conversion,
|
||||
model,
|
||||
source,
|
||||
model_format,
|
||||
hf=hf,
|
||||
)
|
||||
|
||||
# make sure blending only happens once, not every run
|
||||
|
@ -395,9 +285,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
)
|
||||
|
||||
if "tokenizer" not in blend_models:
|
||||
blend_models[
|
||||
"tokenizer"
|
||||
] = CLIPTokenizer.from_pretrained(
|
||||
blend_models["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
dest,
|
||||
subfolder="tokenizer",
|
||||
)
|
||||
|
@ -405,7 +293,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
inversion_name = inversion["name"]
|
||||
inversion_source = inversion["source"]
|
||||
inversion_format = inversion.get("format", None)
|
||||
inversion_source, hf = fetch_model(
|
||||
inversion_source = fetch_model(
|
||||
conversion,
|
||||
inversion_name,
|
||||
inversion_source,
|
||||
|
@ -439,14 +327,12 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
)
|
||||
|
||||
if "unet" not in blend_models:
|
||||
blend_models["unet"] = load_model(
|
||||
path.join(dest, "unet", ONNX_MODEL)
|
||||
)
|
||||
blend_models["unet"] = load_model(path.join(dest, "unet", ONNX_MODEL))
|
||||
|
||||
# load models if not loaded yet
|
||||
lora_name = lora["name"]
|
||||
lora_source = lora["source"]
|
||||
lora_source, hf = fetch_model(
|
||||
lora_source = fetch_model(
|
||||
conversion,
|
||||
f"{name}-lora-{lora_name}",
|
||||
lora_source,
|
||||
|
@ -476,9 +362,7 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
for name in ["text_encoder", "unet"]:
|
||||
if name in blend_models:
|
||||
dest_path = path.join(dest, name, ONNX_MODEL)
|
||||
logger.debug(
|
||||
"saving blended %s model to %s", name, dest_path
|
||||
)
|
||||
logger.debug("saving blended %s model to %s", name, dest_path)
|
||||
save_model(
|
||||
blend_models[name],
|
||||
dest_path,
|
||||
|
@ -487,6 +371,76 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
||||
|
||||
def convert_model_upscaling(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
|
||||
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||
model_type = model.get("model", "resrgan")
|
||||
if model_type == "bsrgan":
|
||||
convert_upscaling_bsrgan(conversion, model, source)
|
||||
elif model_type == "resrgan":
|
||||
convert_upscale_resrgan(conversion, model, source)
|
||||
elif model_type == "swinir":
|
||||
convert_upscaling_swinir(conversion, model, source)
|
||||
else:
|
||||
logger.error("unknown upscaling model type %s for %s", model_type, name)
|
||||
raise ValueError(name)
|
||||
|
||||
|
||||
def convert_model_correction(conversion: ConversionContext, model):
|
||||
model_format = source_format(model)
|
||||
name = model["name"]
|
||||
source = fetch_model(conversion, name, model["source"], format=model_format)
|
||||
model_type = model.get("model", "gfpgan")
|
||||
if model_type == "gfpgan":
|
||||
convert_correction_gfpgan(conversion, model, source)
|
||||
else:
|
||||
logger.error("unknown correction model type %s for %s", model_type, name)
|
||||
raise ValueError(name)
|
||||
|
||||
|
||||
def convert_models(conversion: ConversionContext, args, models: Models):
|
||||
model_errors = []
|
||||
|
||||
if args.sources and "sources" in models:
|
||||
for model in models.get("sources", []):
|
||||
model = tuple_to_source(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping source: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_model_source(conversion, model)
|
||||
except Exception:
|
||||
logger.exception("error fetching source %s", name)
|
||||
model_errors.append(name)
|
||||
|
||||
if args.networks and "networks" in models:
|
||||
for model in models.get("networks", []):
|
||||
name = model["name"]
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping network: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_model_network(conversion, model)
|
||||
except Exception:
|
||||
logger.exception("error fetching network %s", name)
|
||||
model_errors.append(name)
|
||||
|
||||
if args.diffusion and "diffusion" in models:
|
||||
for model in models.get("diffusion", []):
|
||||
model = tuple_to_diffusion(model)
|
||||
name = model.get("name")
|
||||
|
||||
if name in args.skip:
|
||||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
try:
|
||||
convert_model_diffusion(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting diffusion model %s",
|
||||
|
@ -502,24 +456,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
if name in args.skip:
|
||||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
model_format = source_format(model)
|
||||
|
||||
try:
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
model_type = model.get("model", "resrgan")
|
||||
if model_type == "bsrgan":
|
||||
convert_upscaling_bsrgan(conversion, model, source)
|
||||
elif model_type == "resrgan":
|
||||
convert_upscale_resrgan(conversion, model, source)
|
||||
elif model_type == "swinir":
|
||||
convert_upscaling_swinir(conversion, model, source)
|
||||
else:
|
||||
logger.error(
|
||||
"unknown upscaling model type %s for %s", model_type, name
|
||||
)
|
||||
model_errors.append(name)
|
||||
convert_model_upscaling(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting upscaling model %s",
|
||||
|
@ -535,19 +473,8 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
if name in args.skip:
|
||||
logger.info("skipping model: %s", name)
|
||||
else:
|
||||
model_format = source_format(model)
|
||||
try:
|
||||
source, hf = fetch_model(
|
||||
conversion, name, model["source"], format=model_format
|
||||
)
|
||||
model_type = model.get("model", "gfpgan")
|
||||
if model_type == "gfpgan":
|
||||
convert_correction_gfpgan(conversion, model, source)
|
||||
else:
|
||||
logger.error(
|
||||
"unknown correction model type %s for %s", model_type, name
|
||||
)
|
||||
model_errors.append(name)
|
||||
convert_model_correction(conversion, model)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"error converting correction model %s",
|
||||
|
@ -559,12 +486,26 @@ def convert_models(conversion: ConversionContext, args, models: Models):
|
|||
logger.error("error while converting models: %s", model_errors)
|
||||
|
||||
|
||||
def register_plugins(conversion: ConversionContext):
|
||||
logger.info("loading conversion plugins")
|
||||
exports = load_plugins(conversion)
|
||||
|
||||
for proto, client in exports.clients:
|
||||
try:
|
||||
add_model_source(proto, client)
|
||||
except Exception:
|
||||
logger.exception("error loading client for protocol: %s", proto)
|
||||
|
||||
# TODO: add converters
|
||||
|
||||
|
||||
def main(args=None) -> int:
|
||||
parser = ArgumentParser(
|
||||
prog="onnx-web model converter", description="convert checkpoint models to ONNX"
|
||||
)
|
||||
|
||||
# model groups
|
||||
parser.add_argument("--base", action="store_true", default=True)
|
||||
parser.add_argument("--networks", action="store_true", default=True)
|
||||
parser.add_argument("--sources", action="store_true", default=True)
|
||||
parser.add_argument("--correction", action="store_true", default=False)
|
||||
|
@ -602,14 +543,18 @@ def main(args=None) -> int:
|
|||
server.half = args.half or server.has_optimization("onnx-fp16")
|
||||
server.opset = args.opset
|
||||
server.token = args.token
|
||||
|
||||
register_plugins(server)
|
||||
|
||||
logger.info(
|
||||
"converting models in %s using %s", server.model_path, server.training_device
|
||||
"converting models into %s using %s", server.model_path, server.training_device
|
||||
)
|
||||
|
||||
if not path.exists(server.model_path):
|
||||
logger.info("model path does not existing, creating: %s", server.model_path)
|
||||
makedirs(server.model_path)
|
||||
|
||||
if args.base:
|
||||
logger.info("converting base models")
|
||||
convert_models(server, args, base_models)
|
||||
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
from typing import Callable, Dict, Optional
|
||||
from logging import getLogger
|
||||
from os import path
|
||||
|
||||
from ..utils import ConversionContext
|
||||
from .base import BaseClient
|
||||
from .civitai import CivitaiClient
|
||||
from .file import FileClient
|
||||
from .http import HttpClient
|
||||
from .huggingface import HuggingfaceClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
model_sources: Dict[str, Callable[[], BaseClient]] = {
|
||||
CivitaiClient.protocol: CivitaiClient,
|
||||
FileClient.protocol: FileClient,
|
||||
HttpClient.insecure_protocol: HttpClient,
|
||||
HttpClient.protocol: HttpClient,
|
||||
HuggingfaceClient.protocol: HuggingfaceClient,
|
||||
}
|
||||
|
||||
|
||||
def add_model_source(proto: str, client: BaseClient):
|
||||
global model_sources
|
||||
|
||||
if proto in model_sources:
|
||||
raise ValueError("protocol has already been taken")
|
||||
|
||||
model_sources[proto] = client
|
||||
|
||||
|
||||
def fetch_model(
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# TODO: switch to urlparse's default scheme
|
||||
if source.startswith(path.sep) or source.startswith("."):
|
||||
logger.info("adding file protocol to local path source: %s", source)
|
||||
source = FileClient.protocol + source
|
||||
|
||||
for proto, client_type in model_sources.items():
|
||||
if source.startswith(proto):
|
||||
client = client_type(conversion)
|
||||
return client.download(
|
||||
conversion, name, source, format=format, dest=dest, **kwargs
|
||||
)
|
||||
|
||||
logger.warning("unknown model protocol, using path as provided: %s", source)
|
||||
return source
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Optional
|
||||
|
||||
from ..utils import ConversionContext
|
||||
|
||||
|
||||
class BaseClient:
|
||||
def download(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
raise NotImplementedError()
|
|
@ -0,0 +1,64 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import (
|
||||
ConversionContext,
|
||||
build_cache_paths,
|
||||
download_progress,
|
||||
get_first_exists,
|
||||
remove_prefix,
|
||||
)
|
||||
from .base import BaseClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
CIVITAI_ROOT = "https://civitai.com/api/download/models/%s"
|
||||
|
||||
|
||||
class CivitaiClient(BaseClient):
|
||||
name = "civitai"
|
||||
protocol = "civitai://"
|
||||
|
||||
root: str
|
||||
token: Optional[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
token: Optional[str] = None,
|
||||
root: str = CIVITAI_ROOT,
|
||||
):
|
||||
self.root = conversion.get_setting("CIVITAI_ROOT", root)
|
||||
self.token = conversion.get_setting("CIVITAI_TOKEN", token)
|
||||
|
||||
def download(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
cache_paths = build_cache_paths(
|
||||
conversion,
|
||||
name,
|
||||
client=CivitaiClient.name,
|
||||
format=format,
|
||||
dest=dest,
|
||||
)
|
||||
cached = get_first_exists(cache_paths)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
source = self.root % (remove_prefix(source, CivitaiClient.protocol))
|
||||
logger.info("downloading model from Civitai: %s -> %s", source, cache_paths[0])
|
||||
|
||||
if self.token:
|
||||
logger.debug("adding Civitai token authentication")
|
||||
if "?" in source:
|
||||
source = f"{source}&token={self.token}"
|
||||
else:
|
||||
source = f"{source}?token={self.token}"
|
||||
|
||||
return download_progress(source, cache_paths[0])
|
|
@ -0,0 +1,32 @@
|
|||
from logging import getLogger
|
||||
from os import path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..utils import ConversionContext
|
||||
from .base import BaseClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class FileClient(BaseClient):
|
||||
protocol = "file://"
|
||||
|
||||
def __init__(self, _conversion: ConversionContext):
|
||||
"""
|
||||
Nothing to initialize for this client.
|
||||
"""
|
||||
pass
|
||||
|
||||
def download(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
_name: str,
|
||||
uri: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
parts = urlparse(uri)
|
||||
logger.info("loading model from: %s", parts.path)
|
||||
return path.join(dest or conversion.model_path, parts.path)
|
|
@ -0,0 +1,52 @@
|
|||
from logging import getLogger
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ..utils import (
|
||||
ConversionContext,
|
||||
build_cache_paths,
|
||||
download_progress,
|
||||
get_first_exists,
|
||||
)
|
||||
from .base import BaseClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class HttpClient(BaseClient):
|
||||
name = "http"
|
||||
protocol = "https://"
|
||||
insecure_protocol = "http://"
|
||||
|
||||
headers: Dict[str, str]
|
||||
|
||||
def __init__(
|
||||
self, _conversion: ConversionContext, headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
self.headers = headers or {}
|
||||
|
||||
def download(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
cache_paths = build_cache_paths(
|
||||
conversion,
|
||||
name,
|
||||
client=HttpClient.name,
|
||||
format=format,
|
||||
dest=dest,
|
||||
)
|
||||
cached = get_first_exists(cache_paths)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
if source.startswith(HttpClient.protocol):
|
||||
logger.info("downloading model from: %s", source)
|
||||
elif source.startswith(HttpClient.insecure_protocol):
|
||||
logger.warning("downloading model from insecure source: %s", source)
|
||||
|
||||
return download_progress(source, cache_paths[0])
|
|
@ -0,0 +1,43 @@
|
|||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub.file_download import hf_hub_download
|
||||
|
||||
from ..utils import ConversionContext, remove_prefix
|
||||
from .base import BaseClient
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class HuggingfaceClient(BaseClient):
|
||||
name = "huggingface"
|
||||
protocol = "huggingface://"
|
||||
|
||||
token: Optional[str]
|
||||
|
||||
def __init__(self, conversion: ConversionContext, token: Optional[str] = None):
|
||||
self.token = conversion.get_setting("HUGGINGFACE_TOKEN", token)
|
||||
|
||||
def download(
|
||||
self,
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
source: str,
|
||||
format: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
embeds: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
source = remove_prefix(source, HuggingfaceClient.protocol)
|
||||
logger.info("downloading model from Huggingface Hub: %s", source)
|
||||
|
||||
if embeds:
|
||||
return hf_hub_download(
|
||||
repo_id=source,
|
||||
filename="learned_embeds.bin",
|
||||
cache_dir=dest or conversion.cache_path,
|
||||
force_filename=f"{name}.bin",
|
||||
token=self.token,
|
||||
)
|
||||
else:
|
||||
return source
|
|
@ -36,6 +36,8 @@ from ...diffusers.pipelines.upscale import OnnxStableDiffusionUpscalePipeline
|
|||
from ...diffusers.version_safe_diffusers import AttnProcessor
|
||||
from ...models.cnet import UNet2DConditionModel_CNet
|
||||
from ...utils import run_gc
|
||||
from ..client import fetch_model
|
||||
from ..client.huggingface import HuggingfaceClient
|
||||
from ..utils import (
|
||||
RESOLVE_FORMATS,
|
||||
ConversionContext,
|
||||
|
@ -43,6 +45,7 @@ from ..utils import (
|
|||
is_torch_2_0,
|
||||
load_tensor,
|
||||
onnx_export,
|
||||
remove_prefix,
|
||||
)
|
||||
from .checkpoint import convert_extract_checkpoint
|
||||
|
||||
|
@ -267,14 +270,13 @@ def collate_cnet(cnet_path):
|
|||
def convert_diffusion_diffusers(
|
||||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
format: Optional[str],
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
name = str(model.get("name")).strip()
|
||||
source = model.get("source")
|
||||
|
||||
# optional
|
||||
config = model.get("config", None)
|
||||
|
@ -320,9 +322,11 @@ def convert_diffusion_diffusers(
|
|||
logger.info("ONNX model already exists, skipping")
|
||||
return (False, dest_path)
|
||||
|
||||
cache_path = fetch_model(conversion, name, source, format=format)
|
||||
|
||||
pipe_class = CONVERT_PIPELINES.get(pipe_type)
|
||||
v2, pipe_args = get_model_version(
|
||||
source, conversion.map_location, size=image_size, version=version
|
||||
cache_path, conversion.map_location, size=image_size, version=version
|
||||
)
|
||||
|
||||
is_inpainting = False
|
||||
|
@ -334,14 +338,15 @@ def convert_diffusion_diffusers(
|
|||
pipe_args["from_safetensors"] = True
|
||||
|
||||
torch_source = None
|
||||
if path.exists(source) and path.isdir(source):
|
||||
if path.exists(cache_path):
|
||||
if path.isdir(cache_path):
|
||||
logger.debug("loading pipeline from diffusers directory: %s", source)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
source,
|
||||
cache_path,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
elif path.exists(source) and path.isfile(source):
|
||||
else:
|
||||
if conversion.extract:
|
||||
logger.debug("extracting SD checkpoint to Torch models: %s", source)
|
||||
torch_source = convert_extract_checkpoint(
|
||||
|
@ -352,7 +357,9 @@ def convert_diffusion_diffusers(
|
|||
config_file=config,
|
||||
vae_file=replace_vae,
|
||||
)
|
||||
logger.debug("loading pipeline from extracted checkpoint: %s", torch_source)
|
||||
logger.debug(
|
||||
"loading pipeline from extracted checkpoint: %s", torch_source
|
||||
)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
torch_source,
|
||||
torch_dtype=dtype,
|
||||
|
@ -363,28 +370,34 @@ def convert_diffusion_diffusers(
|
|||
else:
|
||||
logger.debug("loading pipeline from SD checkpoint: %s", source)
|
||||
pipeline = download_from_original_stable_diffusion_ckpt(
|
||||
source,
|
||||
cache_path,
|
||||
original_config_file=config_path,
|
||||
pipeline_class=pipe_class,
|
||||
**pipe_args,
|
||||
).to(device, torch_dtype=dtype)
|
||||
elif hf:
|
||||
logger.debug("downloading pretrained model from Huggingface hub: %s", source)
|
||||
elif source.startswith(HuggingfaceClient.protocol):
|
||||
hf_path = remove_prefix(source, HuggingfaceClient.protocol)
|
||||
logger.debug("downloading pretrained model from Huggingface hub: %s", hf_path)
|
||||
pipeline = pipe_class.from_pretrained(
|
||||
source,
|
||||
hf_path,
|
||||
torch_dtype=dtype,
|
||||
use_auth_token=conversion.token,
|
||||
).to(device)
|
||||
else:
|
||||
logger.warning("pipeline source not found or not recognized: %s", source)
|
||||
raise ValueError(f"pipeline source not found or not recognized: {source}")
|
||||
logger.warning(
|
||||
"pipeline source not found and protocol not recognized: %s", source
|
||||
)
|
||||
raise ValueError(
|
||||
f"pipeline source not found and protocol not recognized: {source}"
|
||||
)
|
||||
|
||||
if replace_vae is not None:
|
||||
vae_path = path.join(conversion.model_path, replace_vae)
|
||||
if check_ext(replace_vae, RESOLVE_FORMATS):
|
||||
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
|
||||
if vae_file[0]:
|
||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||
else:
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
|
||||
if is_torch_2_0:
|
||||
pipeline.unet.set_attn_processor(AttnProcessor())
|
||||
|
|
|
@ -10,6 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
|
|||
from optimum.exporters.onnx import main_export
|
||||
|
||||
from ...constants import ONNX_MODEL
|
||||
from ..client import fetch_model
|
||||
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -19,14 +20,13 @@ logger = getLogger(__name__)
|
|||
def convert_diffusion_diffusers_xl(
|
||||
conversion: ConversionContext,
|
||||
model: Dict,
|
||||
source: str,
|
||||
format: Optional[str],
|
||||
hf: bool = False,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
name = str(model.get("name")).strip()
|
||||
source = model.get("source")
|
||||
replace_vae = model.get("vae", None)
|
||||
|
||||
device = conversion.training_device
|
||||
|
@ -52,24 +52,26 @@ def convert_diffusion_diffusers_xl(
|
|||
|
||||
return (False, dest_path)
|
||||
|
||||
cache_path = fetch_model(conversion, name, model["source"], format=format)
|
||||
# safetensors -> diffusers directory with torch models
|
||||
temp_path = path.join(conversion.cache_path, f"{name}-torch")
|
||||
|
||||
if format == "safetensors":
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(
|
||||
source, use_safetensors=True
|
||||
cache_path, use_safetensors=True
|
||||
)
|
||||
else:
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(cache_path)
|
||||
|
||||
if replace_vae is not None:
|
||||
vae_path = path.join(conversion.model_path, replace_vae)
|
||||
if check_ext(replace_vae, RESOLVE_FORMATS):
|
||||
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
|
||||
if vae_file[0]:
|
||||
logger.debug("loading VAE from single tensor file: %s", vae_path)
|
||||
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
|
||||
else:
|
||||
logger.debug("loading pretrained VAE from path: %s", vae_path)
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
|
||||
logger.debug("loading pretrained VAE from path: %s", replace_vae)
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
|
||||
if path.exists(temp_path):
|
||||
logger.debug("torch model already exists for %s: %s", source, temp_path)
|
||||
|
|
|
@ -31,6 +31,15 @@ ModelDict = Dict[str, Union[str, int]]
|
|||
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
|
||||
|
||||
DEFAULT_OPSET = 14
|
||||
DIFFUSION_PREFIX = [
|
||||
"diffusion-",
|
||||
"diffusion/",
|
||||
"diffusion\\",
|
||||
"stable-diffusion-",
|
||||
"upscaling-", # SD upscaling
|
||||
]
|
||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
||||
|
||||
|
||||
class ConversionContext(ServerContext):
|
||||
|
@ -85,8 +94,7 @@ class ConversionContext(ServerContext):
|
|||
return torch.device(self.training_device)
|
||||
|
||||
|
||||
def download_progress(urls: List[Tuple[str, str]]):
|
||||
for url, dest in urls:
|
||||
def download_progress(source: str, dest: str):
|
||||
dest_path = Path(dest).expanduser().resolve()
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -95,7 +103,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
|||
return str(dest_path.absolute())
|
||||
|
||||
req = requests.get(
|
||||
url,
|
||||
source,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
headers={
|
||||
|
@ -105,7 +113,7 @@ def download_progress(urls: List[Tuple[str, str]]):
|
|||
if req.status_code != 200:
|
||||
req.raise_for_status() # Only works for 4xx errors, per SO answer
|
||||
raise RequestException(
|
||||
"request to %s failed with status code: %s" % (url, req.status_code)
|
||||
"request to %s failed with status code: %s" % (source, req.status_code)
|
||||
)
|
||||
|
||||
total = int(req.headers.get("Content-Length", 0))
|
||||
|
@ -184,10 +192,6 @@ def tuple_to_upscaling(model: Union[ModelDict, LegacyModel]):
|
|||
return model
|
||||
|
||||
|
||||
MODEL_FORMATS = ["onnx", "pth", "ckpt", "safetensors"]
|
||||
RESOLVE_FORMATS = ["safetensors", "ckpt", "pt", "pth", "bin"]
|
||||
|
||||
|
||||
def check_ext(name: str, exts: List[str]) -> Tuple[bool, str]:
|
||||
_name, ext = path.splitext(name)
|
||||
ext = ext.strip(".")
|
||||
|
@ -215,6 +219,9 @@ def remove_prefix(name: str, prefix: str) -> str:
|
|||
|
||||
|
||||
def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
||||
"""
|
||||
TODO: move out of convert
|
||||
"""
|
||||
try:
|
||||
logger.debug("loading tensor with Torch: %s", name)
|
||||
checkpoint = torch.load(name, map_location=map_location)
|
||||
|
@ -228,6 +235,9 @@ def load_torch(name: str, map_location=None) -> Optional[Dict]:
|
|||
|
||||
|
||||
def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
||||
"""
|
||||
TODO: move out of convert
|
||||
"""
|
||||
logger.debug("loading tensor: %s", name)
|
||||
_, extension = path.splitext(name)
|
||||
extension = extension[1:].lower()
|
||||
|
@ -285,6 +295,9 @@ def load_tensor(name: str, map_location=None) -> Optional[Dict]:
|
|||
|
||||
|
||||
def resolve_tensor(name: str) -> Optional[str]:
|
||||
"""
|
||||
TODO: move out of convert
|
||||
"""
|
||||
logger.debug("searching for tensors with known extensions: %s", name)
|
||||
for next_extension in RESOLVE_FORMATS:
|
||||
next_name = f"{name}.{next_extension}"
|
||||
|
@ -345,3 +358,52 @@ def onnx_export(
|
|||
all_tensors_to_one_file=True,
|
||||
location=ONNX_WEIGHTS,
|
||||
)
|
||||
|
||||
|
||||
def fix_diffusion_name(name: str):
|
||||
if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]):
|
||||
logger.warning(
|
||||
"diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match",
|
||||
name,
|
||||
)
|
||||
return f"diffusion-{name}"
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def build_cache_paths(
|
||||
conversion: ConversionContext,
|
||||
name: str,
|
||||
client: Optional[str] = None,
|
||||
dest: Optional[str] = None,
|
||||
format: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
cache_path = dest or conversion.cache_path
|
||||
|
||||
# add an extension if possible, some of the conversion code checks for it
|
||||
if format is not None:
|
||||
basename = path.basename(name)
|
||||
_filename, ext = path.splitext(basename)
|
||||
if ext is None or ext == "":
|
||||
name = f"{name}.{format}"
|
||||
|
||||
paths = [
|
||||
path.join(cache_path, name),
|
||||
]
|
||||
|
||||
if client is not None:
|
||||
client_path = path.join(cache_path, client)
|
||||
paths.append(path.join(client_path, name))
|
||||
|
||||
return paths
|
||||
|
||||
|
||||
def get_first_exists(
|
||||
paths: List[str],
|
||||
) -> Optional[str]:
|
||||
for name in paths:
|
||||
if path.exists(name):
|
||||
logger.debug("model already exists in cache, skipping fetch: %s", name)
|
||||
return name
|
||||
|
||||
return None
|
||||
|
|
|
@ -30,12 +30,12 @@ from .version_safe_diffusers import (
|
|||
DDPMScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DPMSolverSDEScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LCMScheduler,
|
||||
|
@ -71,6 +71,7 @@ pipeline_schedulers = {
|
|||
"ddpm": DDPMScheduler,
|
||||
"deis-multi": DEISMultistepScheduler,
|
||||
"dpm-multi": DPMSolverMultistepScheduler,
|
||||
"dpm-sde": DPMSolverSDEScheduler,
|
||||
"dpm-single": DPMSolverSinglestepScheduler,
|
||||
"euler": EulerDiscreteScheduler,
|
||||
"euler-a": EulerAncestralDiscreteScheduler,
|
||||
|
@ -78,7 +79,6 @@ pipeline_schedulers = {
|
|||
"ipndm": IPNDMScheduler,
|
||||
"k-dpm-2-a": KDPM2AncestralDiscreteScheduler,
|
||||
"k-dpm-2": KDPM2DiscreteScheduler,
|
||||
"karras-ve": KarrasVeScheduler,
|
||||
"lcm": LCMScheduler,
|
||||
"lms-discrete": LMSDiscreteScheduler,
|
||||
"pndm": PNDMScheduler,
|
||||
|
|
|
@ -12,6 +12,11 @@ try:
|
|||
except ImportError:
|
||||
from ..diffusers.stub_scheduler import StubScheduler as DEISMultistepScheduler
|
||||
|
||||
try:
|
||||
from diffusers import DPMSolverSDEScheduler
|
||||
except ImportError:
|
||||
from ..diffusers.stub_scheduler import StubScheduler as DPMSolverSDEScheduler
|
||||
|
||||
try:
|
||||
from diffusers import LCMScheduler
|
||||
except ImportError:
|
||||
|
|
|
@ -94,45 +94,44 @@ class ServerContext:
|
|||
self.cache = ModelCache(self.cache_limit)
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
memory_limit = environ.get("ONNX_WEB_MEMORY_LIMIT", None)
|
||||
def from_environ(cls, env=environ):
|
||||
memory_limit = env.get("ONNX_WEB_MEMORY_LIMIT", None)
|
||||
if memory_limit is not None:
|
||||
memory_limit = int(memory_limit)
|
||||
|
||||
return cls(
|
||||
bundle_path=environ.get(
|
||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
||||
),
|
||||
model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||
output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||
params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||
cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
||||
bundle_path=env.get("ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")),
|
||||
model_path=env.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
|
||||
output_path=env.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
|
||||
params_path=env.get("ONNX_WEB_PARAMS_PATH", "."),
|
||||
cors_origin=get_list(env, "ONNX_WEB_CORS_ORIGIN", default="*"),
|
||||
any_platform=get_boolean(
|
||||
environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||
env, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM
|
||||
),
|
||||
block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"),
|
||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
||||
cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||
block_platforms=get_list(env, "ONNX_WEB_BLOCK_PLATFORMS"),
|
||||
default_platform=env.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||
image_format=env.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT),
|
||||
cache_limit=int(env.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)),
|
||||
show_progress=get_boolean(
|
||||
environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||
env, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS
|
||||
),
|
||||
optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"),
|
||||
extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"),
|
||||
job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||
optimizations=get_list(env, "ONNX_WEB_OPTIMIZATIONS"),
|
||||
extra_models=get_list(env, "ONNX_WEB_EXTRA_MODELS"),
|
||||
job_limit=int(env.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)),
|
||||
memory_limit=memory_limit,
|
||||
admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None),
|
||||
server_version=environ.get(
|
||||
"ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION
|
||||
),
|
||||
admin_token=env.get("ONNX_WEB_ADMIN_TOKEN", None),
|
||||
server_version=env.get("ONNX_WEB_SERVER_VERSION", DEFAULT_SERVER_VERSION),
|
||||
worker_retries=int(
|
||||
environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||
env.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES)
|
||||
),
|
||||
feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"),
|
||||
plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""),
|
||||
debug=get_boolean(environ, "ONNX_WEB_DEBUG", False),
|
||||
feature_flags=get_list(env, "ONNX_WEB_FEATURE_FLAGS"),
|
||||
plugins=get_list(env, "ONNX_WEB_PLUGINS", ""),
|
||||
debug=get_boolean(env, "ONNX_WEB_DEBUG", False),
|
||||
)
|
||||
|
||||
def get_setting(self, flag: str, default: str) -> Optional[str]:
|
||||
return environ.get(f"ONNX_WEB_{flag}", default)
|
||||
|
||||
def has_feature(self, flag: str) -> bool:
|
||||
return flag in self.feature_flags
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
import torch
|
||||
from jsonschema import ValidationError, validate
|
||||
|
||||
from ..convert.utils import fix_diffusion_name
|
||||
from ..image import ( # mask filters; noise sources
|
||||
mask_filter_gaussian_multiply,
|
||||
mask_filter_gaussian_screen,
|
||||
|
@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
|
|||
for model in data[model_type]:
|
||||
model_name = model["name"]
|
||||
|
||||
if model_type == "diffusion":
|
||||
model_name = fix_diffusion_name(model_name)
|
||||
|
||||
if "hash" in model:
|
||||
logger.debug(
|
||||
"collecting hash for model %s from %s",
|
||||
|
|
|
@ -2,18 +2,24 @@ from importlib import import_module
|
|||
from logging import getLogger
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from onnx_web.chain.stages import add_stage
|
||||
from onnx_web.diffusers.load import add_pipeline
|
||||
from onnx_web.server.context import ServerContext
|
||||
from ..chain.stages import add_stage
|
||||
from ..diffusers.load import add_pipeline
|
||||
from ..server.context import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class PluginExports:
|
||||
clients: Dict[str, Any]
|
||||
converter: Dict[str, Any]
|
||||
pipelines: Dict[str, Any]
|
||||
stages: Dict[str, Any]
|
||||
|
||||
def __init__(self, pipelines=None, stages=None) -> None:
|
||||
def __init__(
|
||||
self, clients=None, converter=None, pipelines=None, stages=None
|
||||
) -> None:
|
||||
self.clients = clients or {}
|
||||
self.converter = converter or {}
|
||||
self.pipelines = pipelines or {}
|
||||
self.stages = stages or {}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
},
|
||||
"cfg": {
|
||||
"default": 6,
|
||||
"min": 1,
|
||||
"min": 0,
|
||||
"max": 30,
|
||||
"step": 0.1
|
||||
},
|
||||
|
|
|
@ -3,21 +3,21 @@ numpy==1.23.5
|
|||
protobuf==3.20.3
|
||||
|
||||
### SD packages ###
|
||||
accelerate==0.22.0
|
||||
accelerate==0.25.0
|
||||
coloredlogs==15.0.1
|
||||
controlnet_aux==0.0.2
|
||||
datasets==2.14.3
|
||||
diffusers==0.20.0
|
||||
datasets==2.15.0
|
||||
diffusers==0.24.0
|
||||
huggingface-hub==0.16.4
|
||||
invisible-watermark==0.2.0
|
||||
mediapipe==0.9.2.1
|
||||
omegaconf==2.3.0
|
||||
onnx==1.13.0
|
||||
# onnxruntime has many platform-specific packages
|
||||
optimum==1.12.0
|
||||
safetensors==0.3.1
|
||||
timm==0.6.13
|
||||
transformers==4.32.0
|
||||
optimum==1.16.0
|
||||
safetensors==0.4.1
|
||||
timm==0.9.12
|
||||
transformers==4.36.1
|
||||
|
||||
#### Upscaling and face correction
|
||||
basicsr==1.4.2
|
||||
|
@ -29,10 +29,11 @@ realesrgan==0.3.0
|
|||
### Server packages ###
|
||||
arpeggio==2.0.0
|
||||
boto3==1.26.69
|
||||
flask==2.2.2
|
||||
flask==3.0.0
|
||||
flask-cors==3.0.10
|
||||
jsonschema==4.17.3
|
||||
piexif==1.1.3
|
||||
pyyaml==6.0
|
||||
setproctitle==1.3.2
|
||||
waitress==2.1.2
|
||||
werkzeug==3.0.1
|
||||
|
|
|
@ -27,7 +27,7 @@ class ConversionContextTests(unittest.TestCase):
|
|||
|
||||
class DownloadProgressTests(unittest.TestCase):
|
||||
def test_download_example(self):
|
||||
path = download_progress([("https://example.com", "/tmp/example-dot-com")])
|
||||
path = download_progress("https://example.com", "/tmp/example-dot-com")
|
||||
self.assertEqual(path, "/tmp/example-dot-com")
|
||||
|
||||
|
||||
|
|
|
@ -414,6 +414,7 @@ class TestBlendPipeline(unittest.TestCase):
|
|||
3.0,
|
||||
1,
|
||||
1,
|
||||
unet_tile=64,
|
||||
),
|
||||
Size(64, 64),
|
||||
["test-blend.png"],
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
# Getting Started With onnx-web
|
||||
|
||||
onnx-web is a tool for generating images with Stable Diffusion pipelines, including SDXL.
|
||||
|
||||
## Contents
|
||||
|
||||
- [Getting Started With onnx-web](#getting-started-with-onnx-web)
|
||||
- [Contents](#contents)
|
||||
- [Setup](#setup)
|
||||
- [Windows bundle setup](#windows-bundle-setup)
|
||||
- [Other setup methods](#other-setup-methods)
|
||||
- [Running](#running)
|
||||
- [Running the server](#running-the-server)
|
||||
- [Running the web UI](#running-the-web-ui)
|
||||
- [Tabs](#tabs)
|
||||
- [Txt2img Tab](#txt2img-tab)
|
||||
- [Img2img Tab](#img2img-tab)
|
||||
- [Inpaint Tab](#inpaint-tab)
|
||||
- [Upscale Tab](#upscale-tab)
|
||||
- [Blend Tab](#blend-tab)
|
||||
- [Models Tab](#models-tab)
|
||||
- [Settings Tab](#settings-tab)
|
||||
- [Image parameters](#image-parameters)
|
||||
- [Common image parameters](#common-image-parameters)
|
||||
- [Unique image parameters](#unique-image-parameters)
|
||||
- [Prompt syntax](#prompt-syntax)
|
||||
- [LoRAs and embeddings](#loras-and-embeddings)
|
||||
- [CLIP skip](#clip-skip)
|
||||
- [Highres](#highres)
|
||||
- [Highres prompt](#highres-prompt)
|
||||
- [Highres iterations](#highres-iterations)
|
||||
- [Profiles](#profiles)
|
||||
- [Loading from files](#loading-from-files)
|
||||
- [Saving profiles in the web UI](#saving-profiles-in-the-web-ui)
|
||||
- [Sharing parameters profiles](#sharing-parameters-profiles)
|
||||
- [Panorama pipeline](#panorama-pipeline)
|
||||
- [Region prompts](#region-prompts)
|
||||
- [Region seeds](#region-seeds)
|
||||
- [Grid mode](#grid-mode)
|
||||
- [Grid tokens](#grid-tokens)
|
||||
- [Memory optimizations](#memory-optimizations)
|
||||
- [Converting to fp16](#converting-to-fp16)
|
||||
- [Moving models to the CPU](#moving-models-to-the-cpu)
|
||||
|
||||
## Setup
|
||||
|
||||
### Windows bundle setup
|
||||
|
||||
1. Download
|
||||
2. Extract
|
||||
3. Security flags
|
||||
4. Run
|
||||
|
||||
### Other setup methods
|
||||
|
||||
Link to the other methods.
|
||||
|
||||
## Running
|
||||
|
||||
### Running the server
|
||||
|
||||
Run server or bundle.
|
||||
|
||||
### Running the web UI
|
||||
|
||||
Open web UI.
|
||||
|
||||
Use it from your phone.
|
||||
|
||||
## Tabs
|
||||
|
||||
There are 5 tabs, which do different things.
|
||||
|
||||
### Txt2img Tab
|
||||
|
||||
Words go in, pictures come out.
|
||||
|
||||
### Img2img Tab
|
||||
|
||||
Pictures go in, better pictures come out.
|
||||
|
||||
ControlNet lives here.
|
||||
|
||||
### Inpaint Tab
|
||||
|
||||
Pictures go in, parts of the same picture come out.
|
||||
|
||||
### Upscale Tab
|
||||
|
||||
Just highres and super resolution.
|
||||
|
||||
### Blend Tab
|
||||
|
||||
Use the mask tool to combine two images.
|
||||
|
||||
### Models Tab
|
||||
|
||||
Add and manage models.
|
||||
|
||||
### Settings Tab
|
||||
|
||||
Manage web UI settings.
|
||||
|
||||
Reset buttons.
|
||||
|
||||
## Image parameters
|
||||
|
||||
### Common image parameters
|
||||
|
||||
- Scheduler
|
||||
- Eta
|
||||
- for DDIM
|
||||
- CFG
|
||||
- Steps
|
||||
- Seed
|
||||
- Batch size
|
||||
- Prompt
|
||||
- Negative prompt
|
||||
- Width, height
|
||||
|
||||
### Unique image parameters
|
||||
|
||||
- UNet tile size
|
||||
- UNet overlap
|
||||
- Tiled VAE
|
||||
- VAE tile size
|
||||
- VAE overlap
|
||||
|
||||
See the complete user guide for details about the highres, upscale, and correction parameters.
|
||||
|
||||
## Prompt syntax
|
||||
|
||||
### LoRAs and embeddings
|
||||
|
||||
`<lora:filename:1.0>` and `<inversion:filename:1.0>`.
|
||||
|
||||
### CLIP skip
|
||||
|
||||
`<clip:skip:2>` for anime.
|
||||
|
||||
## Highres
|
||||
|
||||
### Highres prompt
|
||||
|
||||
`txt2img prompt || img2img prompt`
|
||||
|
||||
### Highres iterations
|
||||
|
||||
Highres will apply the upscaler and highres prompt (img2img pipeline) for each iteration.
|
||||
|
||||
The final size will be `scale ** iterations`.
|
||||
|
||||
## Profiles
|
||||
|
||||
Saved sets of parameters for later use.
|
||||
|
||||
### Loading from files
|
||||
|
||||
- load from images
|
||||
- load from JSON
|
||||
|
||||
### Saving profiles in the web UI
|
||||
|
||||
Use the save button.
|
||||
|
||||
### Sharing parameters profiles
|
||||
|
||||
Use the download button.
|
||||
|
||||
Share profiles in the Discord channel.
|
||||
|
||||
## Panorama pipeline
|
||||
|
||||
### Region prompts
|
||||
|
||||
`<region:X:Y:W:H:S:F_TLBR:prompt+>`
|
||||
|
||||
### Region seeds
|
||||
|
||||
`<reseed:X:Y:W:H:?:F_TLBR:seed>`
|
||||
|
||||
## Grid mode
|
||||
|
||||
Makes many images. Takes many time.
|
||||
|
||||
### Grid tokens
|
||||
|
||||
`__column__` and `__row__` if you pick token in the menu.
|
||||
|
||||
## Memory optimizations
|
||||
|
||||
### Converting to fp16
|
||||
|
||||
Enable the option.
|
||||
|
||||
### Moving models to the CPU
|
||||
|
||||
Option for each model.
|
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { OnnxState, StateContext } from '../state.js';
|
||||
import { OnnxState, StateContext } from '../state/full.js';
|
||||
import { ErrorCard } from './card/ErrorCard.js';
|
||||
import { ImageCard } from './card/ImageCard.js';
|
||||
import { LoadingCard } from './card/LoadingCard.js';
|
||||
|
|
|
@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material';
|
|||
import * as React from 'react';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
import { STATE_KEY } from '../state.js';
|
||||
import { STATE_KEY } from '../state/full.js';
|
||||
import { Logo } from './Logo.js';
|
||||
|
||||
export interface OnnxErrorProps {
|
||||
|
|
|
@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react';
|
|||
import { useHash } from 'react-use/lib/useHash';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { OnnxState, StateContext } from '../state.js';
|
||||
import { OnnxState, StateContext } from '../state/full.js';
|
||||
import { ImageHistory } from './ImageHistory.js';
|
||||
import { Logo } from './Logo.js';
|
||||
import { Blend } from './tab/Blend.js';
|
||||
|
|
|
@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { OnnxState, StateContext } from '../state.js';
|
||||
import { OnnxState, StateContext } from '../state/full.js';
|
||||
import { ImageMetadata } from '../types/api.js';
|
||||
import { DeepPartial } from '../types/model.js';
|
||||
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
|
||||
|
|
|
@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
|
||||
|
||||
export interface ErrorCardProps {
|
||||
|
|
|
@ -8,9 +8,10 @@ import { useHash } from 'react-use/lib/useHash';
|
|||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse } from '../../types/api.js';
|
||||
import { range, visibleIndex } from '../../utils.js';
|
||||
import { BLEND_SOURCES } from '../../constants.js';
|
||||
|
||||
export interface ImageCardProps {
|
||||
image: ImageResponse;
|
||||
|
|
|
@ -9,7 +9,7 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { POLL_TIME } from '../../config.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { ImageResponse } from '../../types/api.js';
|
||||
|
||||
const LOADING_PERCENT = 100;
|
||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
|||
import { useTranslation } from 'react-i18next';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { HighresParams } from '../../types/params.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { BaseImgParams } from '../../types/params.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
import { PromptInput } from '../input/PromptInput.js';
|
||||
|
|
|
@ -6,7 +6,7 @@ import { useContext } from 'react';
|
|||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext } from '../../state.js';
|
||||
import { ClientContext } from '../../state/full.js';
|
||||
import { ModelParams } from '../../types/params.js';
|
||||
import { QueryList } from '../input/QueryList.js';
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { useStore } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
|
||||
export function OutpaintControl() {
|
||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
|||
import { useTranslation } from 'react-i18next';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { UpscaleParams } from '../../types/params.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import { useContext } from 'react';
|
|||
import { useStore } from 'zustand';
|
||||
|
||||
import { PipelineGrid } from '../../client/utils.js';
|
||||
import { OnnxState, StateContext } from '../../state.js';
|
||||
import { OnnxState, StateContext } from '../../state/full.js';
|
||||
import { VARIABLE_PARAMETERS } from '../../types/chain.js';
|
||||
|
||||
export interface VariableControlProps {
|
||||
|
|
|
@ -4,7 +4,7 @@ import * as React from 'react';
|
|||
import { useTranslation } from 'react-i18next';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { OnnxState, StateContext } from '../../state.js';
|
||||
import { OnnxState, StateContext } from '../../state/full.js';
|
||||
|
||||
const { useContext, useState, memo, useMemo } = React;
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'
|
|||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { SAVE_TIME } from '../../config.js';
|
||||
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
|
||||
import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js';
|
||||
import { BrushParams } from '../../types/params.js';
|
||||
import { imageFromBlob } from '../../utils.js';
|
||||
import { NumericField } from './NumericField';
|
||||
|
|
|
@ -8,7 +8,7 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { QueryMenu } from '../input/QueryMenu.js';
|
||||
import { ModelResponse } from '../../types/api.js';
|
||||
|
||||
|
|
|
@ -8,7 +8,9 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { IMAGE_FILTER } from '../../config.js';
|
||||
import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
|
||||
import { BLEND_SOURCES } from '../../constants.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { TabState } from '../../state/types.js';
|
||||
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||
import { range } from '../../utils.js';
|
||||
import { UpscaleControl } from '../control/UpscaleControl.js';
|
||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { TabState } from '../../state/types.js';
|
||||
import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||
import { Profiles } from '../Profiles.js';
|
||||
import { HighresControl } from '../control/HighresControl.js';
|
||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { TabState } from '../../state/types.js';
|
||||
import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js';
|
||||
import { Profiles } from '../Profiles.js';
|
||||
import { HighresControl } from '../control/HighresControl.js';
|
||||
|
|
|
@ -8,7 +8,7 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { STALE_TIME } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import {
|
||||
CorrectionModel,
|
||||
DiffusionModel,
|
||||
|
|
|
@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
|
|||
import { useStore } from 'zustand';
|
||||
|
||||
import { getApiRoot } from '../../config.js';
|
||||
import { ConfigContext, StateContext, STATE_KEY } from '../../state.js';
|
||||
import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js';
|
||||
import { getTheme } from '../utils.js';
|
||||
import { NumericField } from '../input/NumericField.js';
|
||||
|
||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js';
|
||||
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { TabState } from '../../state/types.js';
|
||||
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js';
|
||||
import { Profiles } from '../Profiles.js';
|
||||
import { HighresControl } from '../control/HighresControl.js';
|
||||
|
|
|
@ -8,7 +8,8 @@ import { useStore } from 'zustand';
|
|||
import { shallow } from 'zustand/shallow';
|
||||
|
||||
import { IMAGE_FILTER } from '../../config.js';
|
||||
import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
|
||||
import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
|
||||
import { TabState } from '../../state/types.js';
|
||||
import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js';
|
||||
import { Profiles } from '../Profiles.js';
|
||||
import { HighresControl } from '../control/HighresControl.js';
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { PaletteMode } from '@mui/material';
|
||||
|
||||
import { Theme } from '../state.js';
|
||||
import { Theme } from '../state/types.js';
|
||||
import { trimHash } from '../utils.js';
|
||||
|
||||
export const TAB_LABELS = [
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
},
|
||||
"cfg": {
|
||||
"default": 6,
|
||||
"min": 1,
|
||||
"min": 0,
|
||||
"max": 30,
|
||||
"step": 0.1
|
||||
},
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
export const BLEND_SOURCES = 2;
|
||||
|
||||
/**
|
||||
* Default parameters for the inpaint brush.
|
||||
*
|
||||
* Not provided by the server yet.
|
||||
*/
|
||||
export const DEFAULT_BRUSH = {
|
||||
color: 255,
|
||||
size: 8,
|
||||
strength: 0.5,
|
||||
};
|
||||
|
||||
/**
|
||||
* Default parameters for the image history.
|
||||
*
|
||||
* Not provided by the server yet.
|
||||
*/
|
||||
export const DEFAULT_HISTORY = {
|
||||
/**
|
||||
* The number of images to be shown.
|
||||
*/
|
||||
limit: 4,
|
||||
|
||||
/**
|
||||
* The number of additional images to be kept in history, so they can scroll
|
||||
* back into view when you delete one. Does not include deleted images.
|
||||
*/
|
||||
scrollback: 2,
|
||||
};
|
|
@ -28,8 +28,9 @@ import {
|
|||
STATE_KEY,
|
||||
STATE_VERSION,
|
||||
StateContext,
|
||||
} from './state.js';
|
||||
} from './state/full.js';
|
||||
import { I18N_STRINGS } from './strings/all.js';
|
||||
import { applyStateMigrations, UnknownState } from './state/migration/default.js';
|
||||
|
||||
export const INITIAL_LOAD_TIMEOUT = 5_000;
|
||||
|
||||
|
@ -70,6 +71,9 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
|
|||
...createResetSlice(...slice),
|
||||
...createProfileSlice(...slice),
|
||||
}), {
|
||||
migrate(persistedState, version) {
|
||||
return applyStateMigrations(params, persistedState as UnknownState, version, logger);
|
||||
},
|
||||
name: STATE_KEY,
|
||||
partialize(s) {
|
||||
return {
|
||||
|
|
965
gui/src/state.ts
965
gui/src/state.ts
|
@ -1,965 +0,0 @@
|
|||
/* eslint-disable camelcase */
|
||||
/* eslint-disable max-lines */
|
||||
/* eslint-disable no-null/no-null */
|
||||
import { Maybe } from '@apextoaster/js-utils';
|
||||
import { PaletteMode } from '@mui/material';
|
||||
import { Logger } from 'noicejs';
|
||||
import { createContext } from 'react';
|
||||
import { StateCreator, StoreApi } from 'zustand';
|
||||
|
||||
import {
|
||||
ApiClient,
|
||||
} from './client/base.js';
|
||||
import { PipelineGrid } from './client/utils.js';
|
||||
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
|
||||
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types/model.js';
|
||||
import { ImageResponse, ReadyResponse, RetryParams } from './types/api.js';
|
||||
import {
|
||||
BaseImgParams,
|
||||
BlendParams,
|
||||
BrushParams,
|
||||
HighresParams,
|
||||
Img2ImgParams,
|
||||
InpaintParams,
|
||||
ModelParams,
|
||||
OutpaintPixels,
|
||||
Txt2ImgParams,
|
||||
UpscaleParams,
|
||||
UpscaleReqParams,
|
||||
} from './types/params.js';
|
||||
|
||||
export const MISSING_INDEX = -1;
|
||||
|
||||
export type Theme = PaletteMode | ''; // tri-state, '' is unset
|
||||
|
||||
/**
|
||||
* Combine optional files and required ranges.
|
||||
*/
|
||||
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||
|
||||
export interface HistoryItem {
|
||||
image: ImageResponse;
|
||||
ready: Maybe<ReadyResponse>;
|
||||
retry: Maybe<RetryParams>;
|
||||
}
|
||||
|
||||
export interface ProfileItem {
|
||||
name: string;
|
||||
params: BaseImgParams | Txt2ImgParams;
|
||||
highres?: Maybe<HighresParams>;
|
||||
upscale?: Maybe<UpscaleParams>;
|
||||
}
|
||||
|
||||
interface DefaultSlice {
|
||||
defaults: TabState<BaseImgParams>;
|
||||
theme: Theme;
|
||||
|
||||
setDefaults(param: Partial<BaseImgParams>): void;
|
||||
setTheme(theme: Theme): void;
|
||||
}
|
||||
|
||||
interface HistorySlice {
|
||||
history: Array<HistoryItem>;
|
||||
limit: number;
|
||||
|
||||
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
||||
removeHistory(image: ImageResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
||||
}
|
||||
|
||||
interface ModelSlice {
|
||||
extras: ExtrasFile;
|
||||
|
||||
removeCorrectionModel(model: CorrectionModel): void;
|
||||
removeDiffusionModel(model: DiffusionModel): void;
|
||||
removeExtraNetwork(model: ExtraNetwork): void;
|
||||
removeExtraSource(model: ExtraSource): void;
|
||||
removeUpscalingModel(model: UpscalingModel): void;
|
||||
|
||||
setExtras(extras: Partial<ExtrasFile>): void;
|
||||
|
||||
setCorrectionModel(model: CorrectionModel): void;
|
||||
setDiffusionModel(model: DiffusionModel): void;
|
||||
setExtraNetwork(model: ExtraNetwork): void;
|
||||
setExtraSource(model: ExtraSource): void;
|
||||
setUpscalingModel(model: UpscalingModel): void;
|
||||
}
|
||||
|
||||
// #region tab slices
|
||||
interface Txt2ImgSlice {
|
||||
txt2img: TabState<Txt2ImgParams>;
|
||||
txt2imgModel: ModelParams;
|
||||
txt2imgHighres: HighresParams;
|
||||
txt2imgUpscale: UpscaleParams;
|
||||
txt2imgVariable: PipelineGrid;
|
||||
|
||||
resetTxt2Img(): void;
|
||||
|
||||
setTxt2Img(params: Partial<Txt2ImgParams>): void;
|
||||
setTxt2ImgModel(params: Partial<ModelParams>): void;
|
||||
setTxt2ImgHighres(params: Partial<HighresParams>): void;
|
||||
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
|
||||
}
|
||||
|
||||
interface Img2ImgSlice {
|
||||
img2img: TabState<Img2ImgParams>;
|
||||
img2imgModel: ModelParams;
|
||||
img2imgHighres: HighresParams;
|
||||
img2imgUpscale: UpscaleParams;
|
||||
|
||||
resetImg2Img(): void;
|
||||
|
||||
setImg2Img(params: Partial<Img2ImgParams>): void;
|
||||
setImg2ImgModel(params: Partial<ModelParams>): void;
|
||||
setImg2ImgHighres(params: Partial<HighresParams>): void;
|
||||
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
interface InpaintSlice {
|
||||
inpaint: TabState<InpaintParams>;
|
||||
inpaintBrush: BrushParams;
|
||||
inpaintModel: ModelParams;
|
||||
inpaintHighres: HighresParams;
|
||||
inpaintUpscale: UpscaleParams;
|
||||
outpaint: OutpaintPixels;
|
||||
|
||||
resetInpaint(): void;
|
||||
|
||||
setInpaint(params: Partial<InpaintParams>): void;
|
||||
setInpaintBrush(brush: Partial<BrushParams>): void;
|
||||
setInpaintModel(params: Partial<ModelParams>): void;
|
||||
setInpaintHighres(params: Partial<HighresParams>): void;
|
||||
setInpaintUpscale(params: Partial<UpscaleParams>): void;
|
||||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||
}
|
||||
|
||||
interface UpscaleSlice {
|
||||
upscale: TabState<UpscaleReqParams>;
|
||||
upscaleHighres: HighresParams;
|
||||
upscaleModel: ModelParams;
|
||||
upscaleUpscale: UpscaleParams;
|
||||
|
||||
resetUpscale(): void;
|
||||
|
||||
setUpscale(params: Partial<UpscaleReqParams>): void;
|
||||
setUpscaleHighres(params: Partial<HighresParams>): void;
|
||||
setUpscaleModel(params: Partial<ModelParams>): void;
|
||||
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
interface BlendSlice {
|
||||
blend: TabState<BlendParams>;
|
||||
blendBrush: BrushParams;
|
||||
blendModel: ModelParams;
|
||||
blendUpscale: UpscaleParams;
|
||||
|
||||
resetBlend(): void;
|
||||
|
||||
setBlend(blend: Partial<BlendParams>): void;
|
||||
setBlendBrush(brush: Partial<BrushParams>): void;
|
||||
setBlendModel(model: Partial<ModelParams>): void;
|
||||
setBlendUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
interface ResetSlice {
|
||||
resetAll(): void;
|
||||
}
|
||||
|
||||
interface ProfileSlice {
|
||||
profiles: Array<ProfileItem>;
|
||||
|
||||
removeProfile(profileName: string): void;
|
||||
|
||||
saveProfile(profile: ProfileItem): void;
|
||||
}
|
||||
// #endregion
|
||||
|
||||
/**
|
||||
* Full merged state including all slices.
|
||||
*/
|
||||
export type OnnxState
|
||||
= DefaultSlice
|
||||
& HistorySlice
|
||||
& Img2ImgSlice
|
||||
& InpaintSlice
|
||||
& ModelSlice
|
||||
& Txt2ImgSlice
|
||||
& UpscaleSlice
|
||||
& BlendSlice
|
||||
& ResetSlice
|
||||
& ModelSlice
|
||||
& ProfileSlice;
|
||||
|
||||
/**
|
||||
* Shorthand for state creator to reduce repeated arguments.
|
||||
*/
|
||||
export type Slice<T> = StateCreator<OnnxState, [], [], T>;
|
||||
|
||||
/**
|
||||
* React context binding for API client.
|
||||
*/
|
||||
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for merged config, including server parameters.
|
||||
*/
|
||||
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for bunyan logger.
|
||||
*/
|
||||
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for zustand state store.
|
||||
*/
|
||||
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
||||
|
||||
/**
|
||||
* Key for zustand persistence, typically local storage.
|
||||
*/
|
||||
export const STATE_KEY = 'onnx-web';
|
||||
|
||||
/**
|
||||
* Current state version for zustand persistence.
|
||||
*/
|
||||
export const STATE_VERSION = 7;
|
||||
|
||||
export const BLEND_SOURCES = 2;
|
||||
|
||||
/**
|
||||
* Default parameters for the inpaint brush.
|
||||
*
|
||||
* Not provided by the server yet.
|
||||
*/
|
||||
export const DEFAULT_BRUSH = {
|
||||
color: 255,
|
||||
size: 8,
|
||||
strength: 0.5,
|
||||
};
|
||||
|
||||
/**
|
||||
* Default parameters for the image history.
|
||||
*
|
||||
* Not provided by the server yet.
|
||||
*/
|
||||
export const DEFAULT_HISTORY = {
|
||||
/**
|
||||
* The number of images to be shown.
|
||||
*/
|
||||
limit: 4,
|
||||
|
||||
/**
|
||||
* The number of additional images to be kept in history, so they can scroll
|
||||
* back into view when you delete one. Does not include deleted images.
|
||||
*/
|
||||
scrollback: 2,
|
||||
};
|
||||
|
||||
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
||||
return {
|
||||
batch: defaults.batch.default,
|
||||
cfg: defaults.cfg.default,
|
||||
eta: defaults.eta.default,
|
||||
negativePrompt: defaults.negativePrompt.default,
|
||||
prompt: defaults.prompt.default,
|
||||
scheduler: defaults.scheduler.default,
|
||||
steps: defaults.steps.default,
|
||||
seed: defaults.seed.default,
|
||||
tiled_vae: defaults.tiled_vae.default,
|
||||
unet_overlap: defaults.unet_overlap.default,
|
||||
unet_tile: defaults.unet_tile.default,
|
||||
vae_overlap: defaults.vae_overlap.default,
|
||||
vae_tile: defaults.vae_tile.default,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare the state slice constructors.
|
||||
*
|
||||
* In the default state, image sources should be null and booleans should be false. Everything
|
||||
* else should be initialized from the default value in the base parameters.
|
||||
*/
|
||||
export function createStateSlices(server: ServerParams) {
|
||||
const defaultParams = baseParamsFromServer(server);
|
||||
const defaultHighres: HighresParams = {
|
||||
enabled: false,
|
||||
highresIterations: server.highresIterations.default,
|
||||
highresMethod: '',
|
||||
highresSteps: server.highresSteps.default,
|
||||
highresScale: server.highresScale.default,
|
||||
highresStrength: server.highresStrength.default,
|
||||
};
|
||||
const defaultModel: ModelParams = {
|
||||
control: server.control.default,
|
||||
correction: server.correction.default,
|
||||
model: server.model.default,
|
||||
pipeline: server.pipeline.default,
|
||||
platform: server.platform.default,
|
||||
upscaling: server.upscaling.default,
|
||||
};
|
||||
const defaultUpscale: UpscaleParams = {
|
||||
denoise: server.denoise.default,
|
||||
enabled: false,
|
||||
faces: false,
|
||||
faceOutscale: server.faceOutscale.default,
|
||||
faceStrength: server.faceStrength.default,
|
||||
outscale: server.outscale.default,
|
||||
scale: server.scale.default,
|
||||
upscaleOrder: server.upscaleOrder.default,
|
||||
};
|
||||
const defaultGrid: PipelineGrid = {
|
||||
enabled: false,
|
||||
columns: {
|
||||
parameter: 'seed',
|
||||
value: '',
|
||||
},
|
||||
rows: {
|
||||
parameter: 'seed',
|
||||
value: '',
|
||||
},
|
||||
};
|
||||
|
||||
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
|
||||
txt2img: {
|
||||
...defaultParams,
|
||||
width: server.width.default,
|
||||
height: server.height.default,
|
||||
},
|
||||
txt2imgHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
txt2imgModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
txt2imgUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
txt2imgVariable: {
|
||||
...defaultGrid,
|
||||
},
|
||||
setTxt2Img(params) {
|
||||
set((prev) => ({
|
||||
txt2img: {
|
||||
...prev.txt2img,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setTxt2ImgHighres(params) {
|
||||
set((prev) => ({
|
||||
txt2imgHighres: {
|
||||
...prev.txt2imgHighres,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setTxt2ImgModel(params) {
|
||||
set((prev) => ({
|
||||
txt2imgModel: {
|
||||
...prev.txt2imgModel,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setTxt2ImgUpscale(params) {
|
||||
set((prev) => ({
|
||||
txt2imgUpscale: {
|
||||
...prev.txt2imgUpscale,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setTxt2ImgVariable(params) {
|
||||
set((prev) => ({
|
||||
txt2imgVariable: {
|
||||
...prev.txt2imgVariable,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
resetTxt2Img() {
|
||||
set({
|
||||
txt2img: {
|
||||
...defaultParams,
|
||||
width: server.width.default,
|
||||
height: server.height.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
|
||||
img2img: {
|
||||
...defaultParams,
|
||||
loopback: server.loopback.default,
|
||||
source: null,
|
||||
sourceFilter: '',
|
||||
strength: server.strength.default,
|
||||
},
|
||||
img2imgHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
img2imgModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
img2imgUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetImg2Img() {
|
||||
set({
|
||||
img2img: {
|
||||
...defaultParams,
|
||||
loopback: server.loopback.default,
|
||||
source: null,
|
||||
sourceFilter: '',
|
||||
strength: server.strength.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
setImg2Img(params) {
|
||||
set((prev) => ({
|
||||
img2img: {
|
||||
...prev.img2img,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setImg2ImgHighres(params) {
|
||||
set((prev) => ({
|
||||
img2imgHighres: {
|
||||
...prev.img2imgHighres,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setImg2ImgModel(params) {
|
||||
set((prev) => ({
|
||||
img2imgModel: {
|
||||
...prev.img2imgModel,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setImg2ImgUpscale(params) {
|
||||
set((prev) => ({
|
||||
img2imgUpscale: {
|
||||
...prev.img2imgUpscale,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createInpaintSlice: Slice<InpaintSlice> = (set) => ({
|
||||
inpaint: {
|
||||
...defaultParams,
|
||||
fillColor: server.fillColor.default,
|
||||
filter: server.filter.default,
|
||||
mask: null,
|
||||
noise: server.noise.default,
|
||||
source: null,
|
||||
strength: server.strength.default,
|
||||
tileOrder: server.tileOrder.default,
|
||||
},
|
||||
inpaintBrush: {
|
||||
...DEFAULT_BRUSH,
|
||||
},
|
||||
inpaintHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
inpaintModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
inpaintUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
outpaint: {
|
||||
enabled: false,
|
||||
left: server.left.default,
|
||||
right: server.right.default,
|
||||
top: server.top.default,
|
||||
bottom: server.bottom.default,
|
||||
},
|
||||
resetInpaint() {
|
||||
set({
|
||||
inpaint: {
|
||||
...defaultParams,
|
||||
fillColor: server.fillColor.default,
|
||||
filter: server.filter.default,
|
||||
mask: null,
|
||||
noise: server.noise.default,
|
||||
source: null,
|
||||
strength: server.strength.default,
|
||||
tileOrder: server.tileOrder.default,
|
||||
},
|
||||
});
|
||||
},
|
||||
setInpaint(params) {
|
||||
set((prev) => ({
|
||||
inpaint: {
|
||||
...prev.inpaint,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setInpaintBrush(brush) {
|
||||
set((prev) => ({
|
||||
inpaintBrush: {
|
||||
...prev.inpaintBrush,
|
||||
...brush,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setInpaintHighres(params) {
|
||||
set((prev) => ({
|
||||
inpaintHighres: {
|
||||
...prev.inpaintHighres,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setInpaintModel(params) {
|
||||
set((prev) => ({
|
||||
inpaintModel: {
|
||||
...prev.inpaintModel,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setInpaintUpscale(params) {
|
||||
set((prev) => ({
|
||||
inpaintUpscale: {
|
||||
...prev.inpaintUpscale,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setOutpaint(pixels) {
|
||||
set((prev) => ({
|
||||
outpaint: {
|
||||
...prev.outpaint,
|
||||
...pixels,
|
||||
}
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createHistorySlice: Slice<HistorySlice> = (set) => ({
|
||||
history: [],
|
||||
limit: DEFAULT_HISTORY.limit,
|
||||
pushHistory(image, retry) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: [
|
||||
{
|
||||
image,
|
||||
ready: undefined,
|
||||
retry,
|
||||
},
|
||||
...prev.history,
|
||||
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
|
||||
}));
|
||||
},
|
||||
removeHistory(image) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
|
||||
}));
|
||||
},
|
||||
setLimit(limit) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
limit,
|
||||
}));
|
||||
},
|
||||
setReady(image, ready) {
|
||||
set((prev) => {
|
||||
const history = [...prev.history];
|
||||
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
|
||||
if (idx >= 0) {
|
||||
history[idx].ready = ready;
|
||||
} else {
|
||||
// TODO: error
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
history,
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createUpscaleSlice: Slice<UpscaleSlice> = (set) => ({
|
||||
upscale: {
|
||||
...defaultParams,
|
||||
source: null,
|
||||
},
|
||||
upscaleHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
upscaleModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
upscaleUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetUpscale() {
|
||||
set({
|
||||
upscale: {
|
||||
...defaultParams,
|
||||
source: null,
|
||||
},
|
||||
});
|
||||
},
|
||||
setUpscale(source) {
|
||||
set((prev) => ({
|
||||
upscale: {
|
||||
...prev.upscale,
|
||||
...source,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setUpscaleHighres(params) {
|
||||
set((prev) => ({
|
||||
upscaleHighres: {
|
||||
...prev.upscaleHighres,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setUpscaleModel(params) {
|
||||
set((prev) => ({
|
||||
upscaleModel: {
|
||||
...prev.upscaleModel,
|
||||
...defaultModel,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setUpscaleUpscale(params) {
|
||||
set((prev) => ({
|
||||
upscaleUpscale: {
|
||||
...prev.upscaleUpscale,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createBlendSlice: Slice<BlendSlice> = (set) => ({
|
||||
blend: {
|
||||
mask: null,
|
||||
sources: [],
|
||||
},
|
||||
blendBrush: {
|
||||
...DEFAULT_BRUSH,
|
||||
},
|
||||
blendModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
blendUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetBlend() {
|
||||
set({
|
||||
blend: {
|
||||
mask: null,
|
||||
sources: [],
|
||||
},
|
||||
});
|
||||
},
|
||||
setBlend(blend) {
|
||||
set((prev) => ({
|
||||
blend: {
|
||||
...prev.blend,
|
||||
...blend,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setBlendBrush(brush) {
|
||||
set((prev) => ({
|
||||
blendBrush: {
|
||||
...prev.blendBrush,
|
||||
...brush,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setBlendModel(model) {
|
||||
set((prev) => ({
|
||||
blendModel: {
|
||||
...prev.blendModel,
|
||||
...model,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setBlendUpscale(params) {
|
||||
set((prev) => ({
|
||||
blendUpscale: {
|
||||
...prev.blendUpscale,
|
||||
...params,
|
||||
},
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
|
||||
defaults: {
|
||||
...defaultParams,
|
||||
},
|
||||
theme: '',
|
||||
setDefaults(params) {
|
||||
set((prev) => ({
|
||||
defaults: {
|
||||
...prev.defaults,
|
||||
...params,
|
||||
}
|
||||
}));
|
||||
},
|
||||
setTheme(theme) {
|
||||
set((prev) => ({
|
||||
theme,
|
||||
}));
|
||||
}
|
||||
});
|
||||
|
||||
const createResetSlice: Slice<ResetSlice> = (set) => ({
|
||||
resetAll() {
|
||||
set((prev) => {
|
||||
const next = { ...prev };
|
||||
next.resetImg2Img();
|
||||
next.resetInpaint();
|
||||
next.resetTxt2Img();
|
||||
next.resetUpscale();
|
||||
next.resetBlend();
|
||||
return next;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const createProfileSlice: Slice<ProfileSlice> = (set) => ({
|
||||
profiles: [],
|
||||
saveProfile(profile: ProfileItem) {
|
||||
set((prev) => {
|
||||
const profiles = [...prev.profiles];
|
||||
const idx = profiles.findIndex((it) => it.name === profile.name);
|
||||
if (idx >= 0) {
|
||||
profiles[idx] = profile;
|
||||
} else {
|
||||
profiles.push(profile);
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
profiles,
|
||||
};
|
||||
});
|
||||
},
|
||||
removeProfile(profileName: string) {
|
||||
set((prev) => {
|
||||
const profiles = [...prev.profiles];
|
||||
const idx = profiles.findIndex((it) => it.name === profileName);
|
||||
if (idx >= 0) {
|
||||
profiles.splice(idx, 1);
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
profiles,
|
||||
};
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||
const createModelSlice: Slice<ModelSlice> = (set) => ({
|
||||
extras: {
|
||||
correction: [],
|
||||
diffusion: [],
|
||||
networks: [],
|
||||
sources: [],
|
||||
upscaling: [],
|
||||
},
|
||||
setExtras(extras) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
...extras,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setCorrectionModel(model) {
|
||||
set((prev) => {
|
||||
const correction = [...prev.extras.correction];
|
||||
const exists = correction.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
correction.push(model);
|
||||
} else {
|
||||
correction[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
correction,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setDiffusionModel(model) {
|
||||
set((prev) => {
|
||||
const diffusion = [...prev.extras.diffusion];
|
||||
const exists = diffusion.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
diffusion.push(model);
|
||||
} else {
|
||||
diffusion[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
diffusion,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setExtraNetwork(model) {
|
||||
set((prev) => {
|
||||
const networks = [...prev.extras.networks];
|
||||
const exists = networks.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
networks.push(model);
|
||||
} else {
|
||||
networks[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
networks,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setExtraSource(model) {
|
||||
set((prev) => {
|
||||
const sources = [...prev.extras.sources];
|
||||
const exists = sources.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
sources.push(model);
|
||||
} else {
|
||||
sources[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
sources,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setUpscalingModel(model) {
|
||||
set((prev) => {
|
||||
const upscaling = [...prev.extras.upscaling];
|
||||
const exists = upscaling.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
upscaling.push(model);
|
||||
} else {
|
||||
upscaling[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
upscaling,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
removeCorrectionModel(model) {
|
||||
set((prev) => {
|
||||
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
correction,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeDiffusionModel(model) {
|
||||
set((prev) => {
|
||||
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
diffusion,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraNetwork(model) {
|
||||
set((prev) => {
|
||||
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
networks,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraSource(model) {
|
||||
set((prev) => {
|
||||
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
sources,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeUpscalingModel(model) {
|
||||
set((prev) => {
|
||||
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
upscaling,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
return {
|
||||
createDefaultSlice,
|
||||
createHistorySlice,
|
||||
createImg2ImgSlice,
|
||||
createInpaintSlice,
|
||||
createTxt2ImgSlice,
|
||||
createUpscaleSlice,
|
||||
createBlendSlice,
|
||||
createResetSlice,
|
||||
createModelSlice,
|
||||
createProfileSlice,
|
||||
};
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
import { DEFAULT_BRUSH } from '../constants.js';
|
||||
import {
|
||||
BlendParams,
|
||||
BrushParams,
|
||||
ModelParams,
|
||||
UpscaleParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState } from './types.js';
|
||||
|
||||
export interface BlendSlice {
|
||||
blend: TabState<BlendParams>;
|
||||
blendBrush: BrushParams;
|
||||
blendModel: ModelParams;
|
||||
blendUpscale: UpscaleParams;
|
||||
|
||||
resetBlend(): void;
|
||||
|
||||
setBlend(blend: Partial<BlendParams>): void;
|
||||
setBlendBrush(brush: Partial<BrushParams>): void;
|
||||
setBlendModel(model: Partial<ModelParams>): void;
|
||||
setBlendUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
export function createBlendSlice<TState extends BlendSlice>(
|
||||
defaultModel: ModelParams,
|
||||
defaultUpscale: UpscaleParams,
|
||||
): Slice<TState, BlendSlice> {
|
||||
return (set) => ({
|
||||
blend: {
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
mask: null,
|
||||
sources: [],
|
||||
},
|
||||
blendBrush: {
|
||||
...DEFAULT_BRUSH,
|
||||
},
|
||||
blendModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
blendUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetBlend() {
|
||||
set((prev) => ({
|
||||
blend: {
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
mask: null,
|
||||
sources: [] as Array<Blob>,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setBlend(blend) {
|
||||
set((prev) => ({
|
||||
blend: {
|
||||
...prev.blend,
|
||||
...blend,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setBlendBrush(brush) {
|
||||
set((prev) => ({
|
||||
blendBrush: {
|
||||
...prev.blendBrush,
|
||||
...brush,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setBlendModel(model) {
|
||||
set((prev) => ({
|
||||
blendModel: {
|
||||
...prev.blendModel,
|
||||
...model,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setBlendUpscale(params) {
|
||||
set((prev) => ({
|
||||
blendUpscale: {
|
||||
...prev.blendUpscale,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
import {
|
||||
BaseImgParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState, Theme } from './types.js';
|
||||
|
||||
export interface DefaultSlice {
|
||||
defaults: TabState<BaseImgParams>;
|
||||
theme: Theme;
|
||||
|
||||
setDefaults(param: Partial<BaseImgParams>): void;
|
||||
setTheme(theme: Theme): void;
|
||||
}
|
||||
|
||||
export function createDefaultSlice<TState extends DefaultSlice>(defaultParams: Required<BaseImgParams>): Slice<TState, DefaultSlice> {
|
||||
return (set) => ({
|
||||
defaults: {
|
||||
...defaultParams,
|
||||
},
|
||||
theme: '',
|
||||
setDefaults(params) {
|
||||
set((prev) => ({
|
||||
defaults: {
|
||||
...prev.defaults,
|
||||
...params,
|
||||
}
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setTheme(theme) {
|
||||
set((prev) => ({
|
||||
theme,
|
||||
} as Partial<TState>));
|
||||
}
|
||||
});
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
/* eslint-disable camelcase */
|
||||
import { Maybe } from '@apextoaster/js-utils';
|
||||
import { Logger } from 'noicejs';
|
||||
import { createContext } from 'react';
|
||||
import { StoreApi } from 'zustand';
|
||||
|
||||
import {
|
||||
ApiClient,
|
||||
} from '../client/base.js';
|
||||
import { PipelineGrid } from '../client/utils.js';
|
||||
import { Config, ServerParams } from '../config.js';
|
||||
import { BlendSlice, createBlendSlice } from './blend.js';
|
||||
import { DefaultSlice, createDefaultSlice } from './default.js';
|
||||
import { HistorySlice, createHistorySlice } from './history.js';
|
||||
import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js';
|
||||
import { InpaintSlice, createInpaintSlice } from './inpaint.js';
|
||||
import { ModelSlice, createModelSlice } from './model.js';
|
||||
import { ProfileSlice, createProfileSlice } from './profile.js';
|
||||
import { ResetSlice, createResetSlice } from './reset.js';
|
||||
import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js';
|
||||
import { UpscaleSlice, createUpscaleSlice } from './upscale.js';
|
||||
import {
|
||||
BaseImgParams,
|
||||
HighresParams,
|
||||
ModelParams,
|
||||
UpscaleParams,
|
||||
} from '../types/params.js';
|
||||
|
||||
/**
|
||||
* Full merged state including all slices.
|
||||
*/
|
||||
export type OnnxState
|
||||
= DefaultSlice
|
||||
& HistorySlice
|
||||
& Img2ImgSlice
|
||||
& InpaintSlice
|
||||
& ModelSlice
|
||||
& Txt2ImgSlice
|
||||
& UpscaleSlice
|
||||
& BlendSlice
|
||||
& ResetSlice
|
||||
& ProfileSlice;
|
||||
|
||||
/**
|
||||
* React context binding for API client.
|
||||
*/
|
||||
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for merged config, including server parameters.
|
||||
*/
|
||||
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for bunyan logger.
|
||||
*/
|
||||
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for zustand state store.
|
||||
*/
|
||||
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
|
||||
|
||||
/**
|
||||
* Key for zustand persistence, typically local storage.
|
||||
*/
|
||||
export const STATE_KEY = 'onnx-web';
|
||||
|
||||
/**
|
||||
* Current state version for zustand persistence.
|
||||
*/
|
||||
export const STATE_VERSION = 11;
|
||||
|
||||
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
|
||||
return {
|
||||
batch: defaults.batch.default,
|
||||
cfg: defaults.cfg.default,
|
||||
eta: defaults.eta.default,
|
||||
negativePrompt: defaults.negativePrompt.default,
|
||||
prompt: defaults.prompt.default,
|
||||
scheduler: defaults.scheduler.default,
|
||||
steps: defaults.steps.default,
|
||||
seed: defaults.seed.default,
|
||||
tiled_vae: defaults.tiled_vae.default,
|
||||
unet_overlap: defaults.unet_overlap.default,
|
||||
unet_tile: defaults.unet_tile.default,
|
||||
vae_overlap: defaults.vae_overlap.default,
|
||||
vae_tile: defaults.vae_tile.default,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare the state slice constructors.
|
||||
*
|
||||
* In the default state, image sources should be null and booleans should be false. Everything
|
||||
* else should be initialized from the default value in the base parameters.
|
||||
*/
|
||||
export function createStateSlices(server: ServerParams) {
|
||||
const defaultParams = baseParamsFromServer(server);
|
||||
const defaultHighres: HighresParams = {
|
||||
enabled: false,
|
||||
highresIterations: server.highresIterations.default,
|
||||
highresMethod: '',
|
||||
highresSteps: server.highresSteps.default,
|
||||
highresScale: server.highresScale.default,
|
||||
highresStrength: server.highresStrength.default,
|
||||
};
|
||||
const defaultModel: ModelParams = {
|
||||
control: server.control.default,
|
||||
correction: server.correction.default,
|
||||
model: server.model.default,
|
||||
pipeline: server.pipeline.default,
|
||||
platform: server.platform.default,
|
||||
upscaling: server.upscaling.default,
|
||||
};
|
||||
const defaultUpscale: UpscaleParams = {
|
||||
denoise: server.denoise.default,
|
||||
enabled: false,
|
||||
faces: false,
|
||||
faceOutscale: server.faceOutscale.default,
|
||||
faceStrength: server.faceStrength.default,
|
||||
outscale: server.outscale.default,
|
||||
scale: server.scale.default,
|
||||
upscaleOrder: server.upscaleOrder.default,
|
||||
};
|
||||
const defaultGrid: PipelineGrid = {
|
||||
enabled: false,
|
||||
columns: {
|
||||
parameter: 'seed',
|
||||
value: '',
|
||||
},
|
||||
rows: {
|
||||
parameter: 'seed',
|
||||
value: '',
|
||||
},
|
||||
};
|
||||
|
||||
return {
|
||||
createBlendSlice: createBlendSlice(defaultModel, defaultUpscale),
|
||||
createDefaultSlice: createDefaultSlice(defaultParams),
|
||||
createHistorySlice: createHistorySlice(),
|
||||
createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
|
||||
createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
|
||||
createModelSlice: createModelSlice(),
|
||||
createProfileSlice: createProfileSlice(),
|
||||
createResetSlice: createResetSlice(),
|
||||
createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid),
|
||||
createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale),
|
||||
};
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
import { Maybe } from '@apextoaster/js-utils';
|
||||
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
|
||||
import { Slice } from './types.js';
|
||||
import { DEFAULT_HISTORY } from '../constants.js';
|
||||
|
||||
export interface HistoryItem {
|
||||
image: ImageResponse;
|
||||
ready: Maybe<ReadyResponse>;
|
||||
retry: Maybe<RetryParams>;
|
||||
}
|
||||
|
||||
export interface HistorySlice {
|
||||
history: Array<HistoryItem>;
|
||||
limit: number;
|
||||
|
||||
pushHistory(image: ImageResponse, retry?: RetryParams): void;
|
||||
removeHistory(image: ImageResponse): void;
|
||||
setLimit(limit: number): void;
|
||||
setReady(image: ImageResponse, ready: ReadyResponse): void;
|
||||
}
|
||||
|
||||
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
|
||||
return (set) => ({
|
||||
history: [],
|
||||
limit: DEFAULT_HISTORY.limit,
|
||||
pushHistory(image, retry) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: [
|
||||
{
|
||||
image,
|
||||
ready: undefined,
|
||||
retry,
|
||||
},
|
||||
...prev.history,
|
||||
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
|
||||
}));
|
||||
},
|
||||
removeHistory(image) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
|
||||
}));
|
||||
},
|
||||
setLimit(limit) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
limit,
|
||||
}));
|
||||
},
|
||||
setReady(image, ready) {
|
||||
set((prev) => {
|
||||
const history = [...prev.history];
|
||||
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
|
||||
if (idx >= 0) {
|
||||
history[idx].ready = ready;
|
||||
} else {
|
||||
// TODO: error
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
history,
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
|
||||
import { ServerParams } from '../config.js';
|
||||
import {
|
||||
BaseImgParams,
|
||||
HighresParams,
|
||||
Img2ImgParams,
|
||||
ModelParams,
|
||||
UpscaleParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState } from './types.js';
|
||||
|
||||
export interface Img2ImgSlice {
|
||||
img2img: TabState<Img2ImgParams>;
|
||||
img2imgModel: ModelParams;
|
||||
img2imgHighres: HighresParams;
|
||||
img2imgUpscale: UpscaleParams;
|
||||
|
||||
resetImg2Img(): void;
|
||||
|
||||
setImg2Img(params: Partial<Img2ImgParams>): void;
|
||||
setImg2ImgModel(params: Partial<ModelParams>): void;
|
||||
setImg2ImgHighres(params: Partial<HighresParams>): void;
|
||||
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export function createImg2ImgSlice<TState extends Img2ImgSlice>(
|
||||
server: ServerParams,
|
||||
defaultParams: Required<BaseImgParams>,
|
||||
defaultHighres: HighresParams,
|
||||
defaultModel: ModelParams,
|
||||
defaultUpscale: UpscaleParams
|
||||
): Slice<TState, Img2ImgSlice> {
|
||||
return (set) => ({
|
||||
img2img: {
|
||||
...defaultParams,
|
||||
loopback: server.loopback.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
sourceFilter: '',
|
||||
strength: server.strength.default,
|
||||
},
|
||||
img2imgHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
img2imgModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
img2imgUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetImg2Img() {
|
||||
set({
|
||||
img2img: {
|
||||
...defaultParams,
|
||||
loopback: server.loopback.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
sourceFilter: '',
|
||||
strength: server.strength.default,
|
||||
},
|
||||
} as Partial<TState>);
|
||||
},
|
||||
setImg2Img(params) {
|
||||
set((prev) => ({
|
||||
img2img: {
|
||||
...prev.img2img,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setImg2ImgHighres(params) {
|
||||
set((prev) => ({
|
||||
img2imgHighres: {
|
||||
...prev.img2imgHighres,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setImg2ImgModel(params) {
|
||||
set((prev) => ({
|
||||
img2imgModel: {
|
||||
...prev.img2imgModel,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setImg2ImgUpscale(params) {
|
||||
set((prev) => ({
|
||||
img2imgUpscale: {
|
||||
...prev.img2imgUpscale,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
});
|
||||
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
import { ServerParams } from '../config.js';
|
||||
import { DEFAULT_BRUSH } from '../constants.js';
|
||||
import {
|
||||
BaseImgParams,
|
||||
BrushParams,
|
||||
HighresParams,
|
||||
InpaintParams,
|
||||
ModelParams,
|
||||
OutpaintPixels,
|
||||
UpscaleParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState } from './types.js';
|
||||
export interface InpaintSlice {
|
||||
inpaint: TabState<InpaintParams>;
|
||||
inpaintBrush: BrushParams;
|
||||
inpaintModel: ModelParams;
|
||||
inpaintHighres: HighresParams;
|
||||
inpaintUpscale: UpscaleParams;
|
||||
outpaint: OutpaintPixels;
|
||||
|
||||
resetInpaint(): void;
|
||||
|
||||
setInpaint(params: Partial<InpaintParams>): void;
|
||||
setInpaintBrush(brush: Partial<BrushParams>): void;
|
||||
setInpaintModel(params: Partial<ModelParams>): void;
|
||||
setInpaintHighres(params: Partial<HighresParams>): void;
|
||||
setInpaintUpscale(params: Partial<UpscaleParams>): void;
|
||||
setOutpaint(pixels: Partial<OutpaintPixels>): void;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export function createInpaintSlice<TState extends InpaintSlice>(
|
||||
server: ServerParams,
|
||||
defaultParams: Required<BaseImgParams>,
|
||||
defaultHighres: HighresParams,
|
||||
defaultModel: ModelParams,
|
||||
defaultUpscale: UpscaleParams,
|
||||
): Slice<TState, InpaintSlice> {
|
||||
return (set) => ({
|
||||
inpaint: {
|
||||
...defaultParams,
|
||||
fillColor: server.fillColor.default,
|
||||
filter: server.filter.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
mask: null,
|
||||
noise: server.noise.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
strength: server.strength.default,
|
||||
tileOrder: server.tileOrder.default,
|
||||
},
|
||||
inpaintBrush: {
|
||||
...DEFAULT_BRUSH,
|
||||
},
|
||||
inpaintHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
inpaintModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
inpaintUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
outpaint: {
|
||||
enabled: false,
|
||||
left: server.left.default,
|
||||
right: server.right.default,
|
||||
top: server.top.default,
|
||||
bottom: server.bottom.default,
|
||||
},
|
||||
resetInpaint() {
|
||||
set({
|
||||
inpaint: {
|
||||
...defaultParams,
|
||||
fillColor: server.fillColor.default,
|
||||
filter: server.filter.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
mask: null,
|
||||
noise: server.noise.default,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
strength: server.strength.default,
|
||||
tileOrder: server.tileOrder.default,
|
||||
},
|
||||
} as Partial<TState>);
|
||||
},
|
||||
setInpaint(params) {
|
||||
set((prev) => ({
|
||||
inpaint: {
|
||||
...prev.inpaint,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setInpaintBrush(brush) {
|
||||
set((prev) => ({
|
||||
inpaintBrush: {
|
||||
...prev.inpaintBrush,
|
||||
...brush,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setInpaintHighres(params) {
|
||||
set((prev) => ({
|
||||
inpaintHighres: {
|
||||
...prev.inpaintHighres,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setInpaintModel(params) {
|
||||
set((prev) => ({
|
||||
inpaintModel: {
|
||||
...prev.inpaintModel,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setInpaintUpscale(params) {
|
||||
set((prev) => ({
|
||||
inpaintUpscale: {
|
||||
...prev.inpaintUpscale,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setOutpaint(pixels) {
|
||||
set((prev) => ({
|
||||
outpaint: {
|
||||
...prev.outpaint,
|
||||
...pixels,
|
||||
}
|
||||
} as Partial<TState>));
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/* eslint-disable camelcase */
|
||||
import { Logger } from 'browser-bunyan';
|
||||
import { ServerParams } from '../../config.js';
|
||||
import { BaseImgParams } from '../../types/params.js';
|
||||
import { OnnxState, STATE_VERSION } from '../full.js';
|
||||
import { Img2ImgSlice } from '../img2img.js';
|
||||
import { InpaintSlice } from '../inpaint.js';
|
||||
import { Txt2ImgSlice } from '../txt2img.js';
|
||||
import { UpscaleSlice } from '../upscale.js';
|
||||
|
||||
// #region V7
|
||||
export const V7 = 7;
|
||||
|
||||
export type BaseImgParamsV7<T extends BaseImgParams> = Omit<T, AddedKeysV11> & {
|
||||
overlap: number;
|
||||
tile: number;
|
||||
};
|
||||
|
||||
export type OnnxStateV7 = Omit<OnnxState, 'img2img' | 'txt2img'> & {
|
||||
img2img: BaseImgParamsV7<Img2ImgSlice['img2img']>;
|
||||
inpaint: BaseImgParamsV7<InpaintSlice['inpaint']>;
|
||||
txt2img: BaseImgParamsV7<Txt2ImgSlice['txt2img']>;
|
||||
upscale: BaseImgParamsV7<UpscaleSlice['upscale']>;
|
||||
};
|
||||
// #endregion
|
||||
|
||||
// #region V11
|
||||
export const REMOVED_KEYS_V11 = ['tile', 'overlap'] as const;
|
||||
|
||||
export type RemovedKeysV11 = typeof REMOVED_KEYS_V11[number];
|
||||
|
||||
// TODO: can the compiler calculate this?
|
||||
export type AddedKeysV11 = 'unet_tile' | 'unet_overlap' | 'vae_tile' | 'vae_overlap';
|
||||
// #endregion
|
||||
|
||||
// add versions to this list as they are replaced
|
||||
export type PreviousState = OnnxStateV7;
|
||||
|
||||
// always the latest version
|
||||
export type CurrentState = OnnxState;
|
||||
|
||||
// any version of state
|
||||
export type UnknownState = PreviousState | CurrentState;
|
||||
|
||||
export function applyStateMigrations(params: ServerParams, previousState: UnknownState, version: number, logger: Logger): OnnxState {
|
||||
logger.info('applying state migrations from version %s to version %s', version, STATE_VERSION);
|
||||
|
||||
if (version <= V7) {
|
||||
return migrateV7ToV11(params, previousState as PreviousState);
|
||||
}
|
||||
|
||||
return previousState as CurrentState;
|
||||
}
|
||||
|
||||
export function migrateV7ToV11(params: ServerParams, previousState: PreviousState): CurrentState {
|
||||
// add any missing keys
|
||||
const result: CurrentState = {
|
||||
...params,
|
||||
...previousState,
|
||||
img2img: {
|
||||
...previousState.img2img,
|
||||
unet_overlap: params.unet_overlap.default,
|
||||
unet_tile: params.unet_tile.default,
|
||||
vae_overlap: params.vae_overlap.default,
|
||||
vae_tile: params.vae_tile.default,
|
||||
},
|
||||
inpaint: {
|
||||
...previousState.inpaint,
|
||||
unet_overlap: params.unet_overlap.default,
|
||||
unet_tile: params.unet_tile.default,
|
||||
vae_overlap: params.vae_overlap.default,
|
||||
vae_tile: params.vae_tile.default,
|
||||
},
|
||||
txt2img: {
|
||||
...previousState.txt2img,
|
||||
unet_overlap: params.unet_overlap.default,
|
||||
unet_tile: params.unet_tile.default,
|
||||
vae_overlap: params.vae_overlap.default,
|
||||
vae_tile: params.vae_tile.default,
|
||||
},
|
||||
upscale: {
|
||||
...previousState.upscale,
|
||||
unet_overlap: params.unet_overlap.default,
|
||||
unet_tile: params.unet_tile.default,
|
||||
vae_overlap: params.vae_overlap.default,
|
||||
vae_tile: params.vae_tile.default,
|
||||
},
|
||||
};
|
||||
|
||||
// TODO: remove extra keys
|
||||
|
||||
return result;
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js';
|
||||
import { MISSING_INDEX, Slice } from './types.js';
|
||||
|
||||
export interface ModelSlice {
|
||||
extras: ExtrasFile;
|
||||
|
||||
removeCorrectionModel(model: CorrectionModel): void;
|
||||
removeDiffusionModel(model: DiffusionModel): void;
|
||||
removeExtraNetwork(model: ExtraNetwork): void;
|
||||
removeExtraSource(model: ExtraSource): void;
|
||||
removeUpscalingModel(model: UpscalingModel): void;
|
||||
|
||||
setExtras(extras: Partial<ExtrasFile>): void;
|
||||
|
||||
setCorrectionModel(model: CorrectionModel): void;
|
||||
setDiffusionModel(model: DiffusionModel): void;
|
||||
setExtraNetwork(model: ExtraNetwork): void;
|
||||
setExtraSource(model: ExtraSource): void;
|
||||
setUpscalingModel(model: UpscalingModel): void;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||
export function createModelSlice<TState extends ModelSlice>(): Slice<TState, ModelSlice> {
|
||||
// eslint-disable-next-line sonarjs/cognitive-complexity
|
||||
return (set) => ({
|
||||
extras: {
|
||||
correction: [],
|
||||
diffusion: [],
|
||||
networks: [],
|
||||
sources: [],
|
||||
upscaling: [],
|
||||
},
|
||||
setExtras(extras) {
|
||||
set((prev) => ({
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
...extras,
|
||||
},
|
||||
}));
|
||||
},
|
||||
setCorrectionModel(model) {
|
||||
set((prev) => {
|
||||
const correction = [...prev.extras.correction];
|
||||
const exists = correction.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
correction.push(model);
|
||||
} else {
|
||||
correction[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
correction,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setDiffusionModel(model) {
|
||||
set((prev) => {
|
||||
const diffusion = [...prev.extras.diffusion];
|
||||
const exists = diffusion.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
diffusion.push(model);
|
||||
} else {
|
||||
diffusion[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
diffusion,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setExtraNetwork(model) {
|
||||
set((prev) => {
|
||||
const networks = [...prev.extras.networks];
|
||||
const exists = networks.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
networks.push(model);
|
||||
} else {
|
||||
networks[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
networks,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setExtraSource(model) {
|
||||
set((prev) => {
|
||||
const sources = [...prev.extras.sources];
|
||||
const exists = sources.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
sources.push(model);
|
||||
} else {
|
||||
sources[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
sources,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
setUpscalingModel(model) {
|
||||
set((prev) => {
|
||||
const upscaling = [...prev.extras.upscaling];
|
||||
const exists = upscaling.findIndex((it) => model.name === it.name);
|
||||
if (exists === MISSING_INDEX) {
|
||||
upscaling.push(model);
|
||||
} else {
|
||||
upscaling[exists] = model;
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
upscaling,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
removeCorrectionModel(model) {
|
||||
set((prev) => {
|
||||
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
correction,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeDiffusionModel(model) {
|
||||
set((prev) => {
|
||||
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
diffusion,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraNetwork(model) {
|
||||
set((prev) => {
|
||||
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
networks,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeExtraSource(model) {
|
||||
set((prev) => {
|
||||
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
sources,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
},
|
||||
removeUpscalingModel(model) {
|
||||
set((prev) => {
|
||||
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
|
||||
return {
|
||||
...prev,
|
||||
extras: {
|
||||
...prev.extras,
|
||||
upscaling,
|
||||
},
|
||||
};
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
import { Maybe } from '@apextoaster/js-utils';
|
||||
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
|
||||
import { Slice } from './types.js';
|
||||
|
||||
export interface ProfileItem {
|
||||
name: string;
|
||||
params: BaseImgParams | Txt2ImgParams;
|
||||
highres?: Maybe<HighresParams>;
|
||||
upscale?: Maybe<UpscaleParams>;
|
||||
}
|
||||
|
||||
export interface ProfileSlice {
|
||||
profiles: Array<ProfileItem>;
|
||||
|
||||
removeProfile(profileName: string): void;
|
||||
|
||||
saveProfile(profile: ProfileItem): void;
|
||||
}
|
||||
|
||||
export function createProfileSlice<TState extends ProfileSlice>(): Slice<TState, ProfileSlice> {
|
||||
return (set) => ({
|
||||
profiles: [],
|
||||
saveProfile(profile: ProfileItem) {
|
||||
set((prev) => {
|
||||
const profiles = [...prev.profiles];
|
||||
const idx = profiles.findIndex((it) => it.name === profile.name);
|
||||
if (idx >= 0) {
|
||||
profiles[idx] = profile;
|
||||
} else {
|
||||
profiles.push(profile);
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
profiles,
|
||||
};
|
||||
});
|
||||
},
|
||||
removeProfile(profileName: string) {
|
||||
set((prev) => {
|
||||
const profiles = [...prev.profiles];
|
||||
const idx = profiles.findIndex((it) => it.name === profileName);
|
||||
if (idx >= 0) {
|
||||
profiles.splice(idx, 1);
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
profiles,
|
||||
};
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
import { BlendSlice } from './blend.js';
|
||||
import { Img2ImgSlice } from './img2img.js';
|
||||
import { InpaintSlice } from './inpaint.js';
|
||||
import { Txt2ImgSlice } from './txt2img.js';
|
||||
import { Slice } from './types.js';
|
||||
import { UpscaleSlice } from './upscale.js';
|
||||
|
||||
export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice;
|
||||
|
||||
export interface ResetSlice {
|
||||
resetAll(): void;
|
||||
}
|
||||
|
||||
export function createResetSlice<TState extends ResetSlice & SlicesWithReset>(): Slice<TState, ResetSlice> {
|
||||
return (set) => ({
|
||||
resetAll() {
|
||||
set((prev) => {
|
||||
const next = { ...prev };
|
||||
next.resetImg2Img();
|
||||
next.resetInpaint();
|
||||
next.resetTxt2Img();
|
||||
next.resetUpscale();
|
||||
next.resetBlend();
|
||||
return next;
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
import { PipelineGrid } from '../client/utils.js';
|
||||
import { ServerParams } from '../config.js';
|
||||
import {
|
||||
BaseImgParams,
|
||||
HighresParams,
|
||||
ModelParams,
|
||||
Txt2ImgParams,
|
||||
UpscaleParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState } from './types.js';
|
||||
|
||||
export interface Txt2ImgSlice {
|
||||
txt2img: TabState<Txt2ImgParams>;
|
||||
txt2imgModel: ModelParams;
|
||||
txt2imgHighres: HighresParams;
|
||||
txt2imgUpscale: UpscaleParams;
|
||||
txt2imgVariable: PipelineGrid;
|
||||
|
||||
resetTxt2Img(): void;
|
||||
|
||||
setTxt2Img(params: Partial<Txt2ImgParams>): void;
|
||||
setTxt2ImgModel(params: Partial<ModelParams>): void;
|
||||
setTxt2ImgHighres(params: Partial<HighresParams>): void;
|
||||
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
|
||||
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line max-params
|
||||
export function createTxt2ImgSlice<TState extends Txt2ImgSlice>(
|
||||
server: ServerParams,
|
||||
defaultParams: Required<BaseImgParams>,
|
||||
defaultHighres: HighresParams,
|
||||
defaultModel: ModelParams,
|
||||
defaultUpscale: UpscaleParams,
|
||||
defaultGrid: PipelineGrid,
|
||||
): Slice<TState, Txt2ImgSlice> {
|
||||
return (set) => ({
|
||||
txt2img: {
|
||||
...defaultParams,
|
||||
width: server.width.default,
|
||||
height: server.height.default,
|
||||
},
|
||||
txt2imgHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
txt2imgModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
txt2imgUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
txt2imgVariable: {
|
||||
...defaultGrid,
|
||||
},
|
||||
setTxt2Img(params) {
|
||||
set((prev) => ({
|
||||
txt2img: {
|
||||
...prev.txt2img,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setTxt2ImgHighres(params) {
|
||||
set((prev) => ({
|
||||
txt2imgHighres: {
|
||||
...prev.txt2imgHighres,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setTxt2ImgModel(params) {
|
||||
set((prev) => ({
|
||||
txt2imgModel: {
|
||||
...prev.txt2imgModel,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setTxt2ImgUpscale(params) {
|
||||
set((prev) => ({
|
||||
txt2imgUpscale: {
|
||||
...prev.txt2imgUpscale,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setTxt2ImgVariable(params) {
|
||||
set((prev) => ({
|
||||
txt2imgVariable: {
|
||||
...prev.txt2imgVariable,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
resetTxt2Img() {
|
||||
set({
|
||||
txt2img: {
|
||||
...defaultParams,
|
||||
width: server.width.default,
|
||||
height: server.height.default,
|
||||
},
|
||||
} as Partial<TState>);
|
||||
},
|
||||
});
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
import { PaletteMode } from '@mui/material';
|
||||
import { StateCreator } from 'zustand';
|
||||
import { ConfigFiles, ConfigState } from '../config.js';
|
||||
|
||||
export const MISSING_INDEX = -1;
|
||||
|
||||
export type Theme = PaletteMode | ''; // tri-state, '' is unset
|
||||
|
||||
/**
|
||||
* Combine optional files and required ranges.
|
||||
*/
|
||||
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||
|
||||
/**
|
||||
* Shorthand for state creator to reduce repeated arguments.
|
||||
*/
|
||||
export type Slice<TState, TValue> = StateCreator<TState, [], [], TValue>;
|
|
@ -0,0 +1,87 @@
|
|||
import {
|
||||
BaseImgParams,
|
||||
HighresParams,
|
||||
ModelParams,
|
||||
UpscaleParams,
|
||||
UpscaleReqParams,
|
||||
} from '../types/params.js';
|
||||
import { Slice, TabState } from './types.js';
|
||||
|
||||
export interface UpscaleSlice {
|
||||
upscale: TabState<UpscaleReqParams>;
|
||||
upscaleHighres: HighresParams;
|
||||
upscaleModel: ModelParams;
|
||||
upscaleUpscale: UpscaleParams;
|
||||
|
||||
resetUpscale(): void;
|
||||
|
||||
setUpscale(params: Partial<UpscaleReqParams>): void;
|
||||
setUpscaleHighres(params: Partial<HighresParams>): void;
|
||||
setUpscaleModel(params: Partial<ModelParams>): void;
|
||||
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
|
||||
}
|
||||
|
||||
export function createUpscaleSlice<TState extends UpscaleSlice>(
|
||||
defaultParams: Required<BaseImgParams>,
|
||||
defaultHighres: HighresParams,
|
||||
defaultModel: ModelParams,
|
||||
defaultUpscale: UpscaleParams,
|
||||
): Slice<TState, UpscaleSlice> {
|
||||
return (set) => ({
|
||||
upscale: {
|
||||
...defaultParams,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
},
|
||||
upscaleHighres: {
|
||||
...defaultHighres,
|
||||
},
|
||||
upscaleModel: {
|
||||
...defaultModel,
|
||||
},
|
||||
upscaleUpscale: {
|
||||
...defaultUpscale,
|
||||
},
|
||||
resetUpscale() {
|
||||
set({
|
||||
upscale: {
|
||||
...defaultParams,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
source: null,
|
||||
},
|
||||
} as Partial<TState>);
|
||||
},
|
||||
setUpscale(source) {
|
||||
set((prev) => ({
|
||||
upscale: {
|
||||
...prev.upscale,
|
||||
...source,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setUpscaleHighres(params) {
|
||||
set((prev) => ({
|
||||
upscaleHighres: {
|
||||
...prev.upscaleHighres,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setUpscaleModel(params) {
|
||||
set((prev) => ({
|
||||
upscaleModel: {
|
||||
...prev.upscaleModel,
|
||||
...defaultModel,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
setUpscaleUpscale(params) {
|
||||
set((prev) => ({
|
||||
upscaleUpscale: {
|
||||
...prev.upscaleUpscale,
|
||||
...params,
|
||||
},
|
||||
} as Partial<TState>));
|
||||
},
|
||||
});
|
||||
}
|
|
@ -287,6 +287,7 @@ export const I18N_STRINGS_EN = {
|
|||
'ddpm': 'DDPM',
|
||||
'deis-multi': 'DEIS Multistep',
|
||||
'dpm-multi': 'DPM Multistep',
|
||||
'dpm-sde': 'DPM SDE (Turbo)',
|
||||
'dpm-single': 'DPM Singlestep',
|
||||
'euler': 'Euler',
|
||||
'euler-a': 'Euler Ancestral',
|
||||
|
|
Loading…
Reference in New Issue