From 0dfc1b61d20eb3323c045878773b3953726a7d2f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 14 Dec 2023 21:30:06 -0600 Subject: [PATCH] break up state slice factories --- gui/src/components/ImageHistory.tsx | 2 +- gui/src/components/OnnxError.tsx | 2 +- gui/src/components/OnnxWeb.tsx | 2 +- gui/src/components/Profiles.tsx | 2 +- gui/src/components/card/ErrorCard.tsx | 2 +- gui/src/components/card/ImageCard.tsx | 2 +- gui/src/components/card/LoadingCard.tsx | 2 +- gui/src/components/control/HighresControl.tsx | 2 +- gui/src/components/control/ImageControl.tsx | 2 +- gui/src/components/control/ModelControl.tsx | 2 +- .../components/control/OutpaintControl.tsx | 2 +- gui/src/components/control/UpscaleControl.tsx | 2 +- .../components/control/VariableControl.tsx | 2 +- gui/src/components/input/EditableList.tsx | 2 +- gui/src/components/input/MaskCanvas.tsx | 2 +- gui/src/components/input/PromptInput.tsx | 2 +- gui/src/components/tab/Blend.tsx | 3 +- gui/src/components/tab/Img2Img.tsx | 3 +- gui/src/components/tab/Inpaint.tsx | 3 +- gui/src/components/tab/Models.tsx | 2 +- gui/src/components/tab/Settings.tsx | 2 +- gui/src/components/tab/Txt2Img.tsx | 3 +- gui/src/components/tab/Upscale.tsx | 3 +- gui/src/components/utils.ts | 2 +- gui/src/main.tsx | 2 +- gui/src/state.ts | 817 ------------------ gui/src/state/blend.ts | 65 +- gui/src/state/default.ts | 24 +- gui/src/state/full.ts | 152 ++++ gui/src/state/history.ts | 49 ++ gui/src/state/img2img.ts | 78 +- gui/src/state/inpaint.ts | 112 ++- gui/src/state/model.ts | 202 +++++ gui/src/state/models.ts | 19 - gui/src/state/profile.ts | 35 + gui/src/state/reset.ts | 25 + gui/src/state/txt2img.ts | 83 +- gui/src/state/types.ts | 35 + gui/src/state/upscale.ts | 68 +- 39 files changed, 951 insertions(+), 868 deletions(-) delete mode 100644 gui/src/state.ts create mode 100644 gui/src/state/full.ts create mode 100644 gui/src/state/model.ts delete mode 100644 gui/src/state/models.ts diff --git a/gui/src/components/ImageHistory.tsx b/gui/src/components/ImageHistory.tsx index 0f6b0f2f..20a520ed 100644 --- a/gui/src/components/ImageHistory.tsx +++ b/gui/src/components/ImageHistory.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ErrorCard } from './card/ErrorCard.js'; import { ImageCard } from './card/ImageCard.js'; import { LoadingCard } from './card/LoadingCard.js'; diff --git a/gui/src/components/OnnxError.tsx b/gui/src/components/OnnxError.tsx index da6d7147..05186768 100644 --- a/gui/src/components/OnnxError.tsx +++ b/gui/src/components/OnnxError.tsx @@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material'; import * as React from 'react'; import { ReactNode } from 'react'; -import { STATE_KEY } from '../state.js'; +import { STATE_KEY } from '../state/full.js'; import { Logo } from './Logo.js'; export interface OnnxErrorProps { diff --git a/gui/src/components/OnnxWeb.tsx b/gui/src/components/OnnxWeb.tsx index 918ee84f..69c2db20 100644 --- a/gui/src/components/OnnxWeb.tsx +++ b/gui/src/components/OnnxWeb.tsx @@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageHistory } from './ImageHistory.js'; import { Logo } from './Logo.js'; import { Blend } from './tab/Blend.js'; diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index baf022b3..de8f2263 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { OnnxState, StateContext } from '../state.js'; +import { OnnxState, StateContext } from '../state/full.js'; import { ImageMetadata } from '../types/api.js'; import { DeepPartial } from '../types/model.js'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index bb3ac6c9..f7106584 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; export interface ErrorCardProps { diff --git a/gui/src/components/card/ImageCard.tsx b/gui/src/components/card/ImageCard.tsx index 7c35acb2..44f5dd91 100644 --- a/gui/src/components/card/ImageCard.tsx +++ b/gui/src/components/card/ImageCard.tsx @@ -8,7 +8,7 @@ import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; import { range, visibleIndex } from '../../utils.js'; diff --git a/gui/src/components/card/LoadingCard.tsx b/gui/src/components/card/LoadingCard.tsx index e0fcdb68..71bfb5f0 100644 --- a/gui/src/components/card/LoadingCard.tsx +++ b/gui/src/components/card/LoadingCard.tsx @@ -9,7 +9,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { POLL_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { ImageResponse } from '../../types/api.js'; const LOADING_PERCENT = 100; diff --git a/gui/src/components/control/HighresControl.tsx b/gui/src/components/control/HighresControl.tsx index 91525b21..62fa63fd 100644 --- a/gui/src/components/control/HighresControl.tsx +++ b/gui/src/components/control/HighresControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { HighresParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 8271c700..8877ae50 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -11,7 +11,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { BaseImgParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; import { PromptInput } from '../input/PromptInput.js'; diff --git a/gui/src/components/control/ModelControl.tsx b/gui/src/components/control/ModelControl.tsx index ba08998f..54b99846 100644 --- a/gui/src/components/control/ModelControl.tsx +++ b/gui/src/components/control/ModelControl.tsx @@ -6,7 +6,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { STALE_TIME } from '../../config.js'; -import { ClientContext } from '../../state.js'; +import { ClientContext } from '../../state/full.js'; import { ModelParams } from '../../types/params.js'; import { QueryList } from '../input/QueryList.js'; diff --git a/gui/src/components/control/OutpaintControl.tsx b/gui/src/components/control/OutpaintControl.tsx index 5b70cc50..69523477 100644 --- a/gui/src/components/control/OutpaintControl.tsx +++ b/gui/src/components/control/OutpaintControl.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { NumericField } from '../input/NumericField.js'; export function OutpaintControl() { diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index a90dd330..5d840a04 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { ConfigContext, OnnxState, StateContext } from '../../state.js'; +import { ConfigContext, OnnxState, StateContext } from '../../state/full.js'; import { UpscaleParams } from '../../types/params.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/control/VariableControl.tsx b/gui/src/components/control/VariableControl.tsx index 32a66660..bf5139e1 100644 --- a/gui/src/components/control/VariableControl.tsx +++ b/gui/src/components/control/VariableControl.tsx @@ -5,7 +5,7 @@ import { useContext } from 'react'; import { useStore } from 'zustand'; import { PipelineGrid } from '../../client/utils.js'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; import { VARIABLE_PARAMETERS } from '../../types/chain.js'; export interface VariableControlProps { diff --git a/gui/src/components/input/EditableList.tsx b/gui/src/components/input/EditableList.tsx index 3910e394..a6d45aca 100644 --- a/gui/src/components/input/EditableList.tsx +++ b/gui/src/components/input/EditableList.tsx @@ -4,7 +4,7 @@ import * as React from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { OnnxState, StateContext } from '../../state.js'; +import { OnnxState, StateContext } from '../../state/full.js'; const { useContext, useState, memo, useMemo } = React; diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index ae7f2723..2e605423 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react' import { useTranslation } from 'react-i18next'; import { SAVE_TIME } from '../../config.js'; -import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; +import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js'; import { BrushParams } from '../../types/params.js'; import { imageFromBlob } from '../../utils.js'; import { NumericField } from './NumericField'; diff --git a/gui/src/components/input/PromptInput.tsx b/gui/src/components/input/PromptInput.tsx index cdaa6751..bc522569 100644 --- a/gui/src/components/input/PromptInput.tsx +++ b/gui/src/components/input/PromptInput.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { QueryMenu } from '../input/QueryMenu.js'; import { ModelResponse } from '../../types/api.js'; diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index a1a304d1..3cd70449 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { BLEND_SOURCES, ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { range } from '../../utils.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index cddb8539..26bda5cf 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index b783f5b5..ca83835b 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Models.tsx b/gui/src/components/tab/Models.tsx index e634485a..86e96ecd 100644 --- a/gui/src/components/tab/Models.tsx +++ b/gui/src/components/tab/Models.tsx @@ -8,7 +8,7 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { STALE_TIME } from '../../config.js'; -import { ClientContext, OnnxState, StateContext } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; import { CorrectionModel, DiffusionModel, diff --git a/gui/src/components/tab/Settings.tsx b/gui/src/components/tab/Settings.tsx index 5f09c1a2..7b25ed3b 100644 --- a/gui/src/components/tab/Settings.tsx +++ b/gui/src/components/tab/Settings.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { getApiRoot } from '../../config.js'; -import { ConfigContext, StateContext, STATE_KEY } from '../../state.js'; +import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js'; import { getTheme } from '../utils.js'; import { NumericField } from '../input/NumericField.js'; diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 921def1f..81286ec2 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js'; -import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/tab/Upscale.tsx b/gui/src/components/tab/Upscale.tsx index 06314579..e9a1c482 100644 --- a/gui/src/components/tab/Upscale.tsx +++ b/gui/src/components/tab/Upscale.tsx @@ -8,7 +8,8 @@ import { useStore } from 'zustand'; import { shallow } from 'zustand/shallow'; import { IMAGE_FILTER } from '../../config.js'; -import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { ClientContext, OnnxState, StateContext } from '../../state/full.js'; +import { TabState } from '../../state/types.js'; import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js'; import { Profiles } from '../Profiles.js'; import { HighresControl } from '../control/HighresControl.js'; diff --git a/gui/src/components/utils.ts b/gui/src/components/utils.ts index 58e106ec..39ba9bae 100644 --- a/gui/src/components/utils.ts +++ b/gui/src/components/utils.ts @@ -1,6 +1,6 @@ import { PaletteMode } from '@mui/material'; -import { Theme } from '../state.js'; +import { Theme } from '../state/types.js'; import { trimHash } from '../utils.js'; export const TAB_LABELS = [ diff --git a/gui/src/main.tsx b/gui/src/main.tsx index d0b83a11..73b43f8f 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -28,7 +28,7 @@ import { STATE_KEY, STATE_VERSION, StateContext, -} from './state.js'; +} from './state/full.js'; import { I18N_STRINGS } from './strings/all.js'; export const INITIAL_LOAD_TIMEOUT = 5_000; diff --git a/gui/src/state.ts b/gui/src/state.ts deleted file mode 100644 index fd9a21a0..00000000 --- a/gui/src/state.ts +++ /dev/null @@ -1,817 +0,0 @@ -/* eslint-disable camelcase */ -/* eslint-disable max-lines */ -/* eslint-disable no-null/no-null */ -import { Maybe } from '@apextoaster/js-utils'; -import { Logger } from 'noicejs'; -import { createContext } from 'react'; -import { StateCreator, StoreApi } from 'zustand'; - -import { - ApiClient, -} from './client/base.js'; -import { PipelineGrid } from './client/utils.js'; -import { Config, ServerParams } from './config.js'; -import { - BaseImgParams, - HighresParams, - ModelParams, - UpscaleParams, -} from './types/params.js'; -import { DefaultSlice } from './state/default.js'; -import { HistorySlice } from './state/history.js'; -import { Img2ImgSlice } from './state/img2img.js'; -import { InpaintSlice } from './state/inpaint.js'; -import { ModelSlice } from './state/models.js'; -import { Txt2ImgSlice } from './state/txt2img.js'; -import { UpscaleSlice } from './state/upscale.js'; -import { ResetSlice } from './state/reset.js'; -import { ProfileItem, ProfileSlice } from './state/profile.js'; -import { BlendSlice } from './state/blend.js'; -import { MISSING_INDEX } from './state/types.js'; - -/** - * Full merged state including all slices. - */ -export type OnnxState - = DefaultSlice - & HistorySlice - & Img2ImgSlice - & InpaintSlice - & ModelSlice - & Txt2ImgSlice - & UpscaleSlice - & BlendSlice - & ResetSlice - & ProfileSlice; - -/** - * Shorthand for state creator to reduce repeated arguments. - */ -export type Slice = StateCreator; - -/** - * React context binding for API client. - */ -export const ClientContext = createContext>(undefined); - -/** - * React context binding for merged config, including server parameters. - */ -export const ConfigContext = createContext>>(undefined); - -/** - * React context binding for bunyan logger. - */ -export const LoggerContext = createContext>(undefined); - -/** - * React context binding for zustand state store. - */ -export const StateContext = createContext>>(undefined); - -/** - * Key for zustand persistence, typically local storage. - */ -export const STATE_KEY = 'onnx-web'; - -/** - * Current state version for zustand persistence. - */ -export const STATE_VERSION = 7; - -export const BLEND_SOURCES = 2; - -/** - * Default parameters for the inpaint brush. - * - * Not provided by the server yet. - */ -export const DEFAULT_BRUSH = { - color: 255, - size: 8, - strength: 0.5, -}; - -/** - * Default parameters for the image history. - * - * Not provided by the server yet. - */ -export const DEFAULT_HISTORY = { - /** - * The number of images to be shown. - */ - limit: 4, - - /** - * The number of additional images to be kept in history, so they can scroll - * back into view when you delete one. Does not include deleted images. - */ - scrollback: 2, -}; - -export function baseParamsFromServer(defaults: ServerParams): Required { - return { - batch: defaults.batch.default, - cfg: defaults.cfg.default, - eta: defaults.eta.default, - negativePrompt: defaults.negativePrompt.default, - prompt: defaults.prompt.default, - scheduler: defaults.scheduler.default, - steps: defaults.steps.default, - seed: defaults.seed.default, - tiled_vae: defaults.tiled_vae.default, - unet_overlap: defaults.unet_overlap.default, - unet_tile: defaults.unet_tile.default, - vae_overlap: defaults.vae_overlap.default, - vae_tile: defaults.vae_tile.default, - }; -} - -/** - * Prepare the state slice constructors. - * - * In the default state, image sources should be null and booleans should be false. Everything - * else should be initialized from the default value in the base parameters. - */ -export function createStateSlices(server: ServerParams) { - const defaultParams = baseParamsFromServer(server); - const defaultHighres: HighresParams = { - enabled: false, - highresIterations: server.highresIterations.default, - highresMethod: '', - highresSteps: server.highresSteps.default, - highresScale: server.highresScale.default, - highresStrength: server.highresStrength.default, - }; - const defaultModel: ModelParams = { - control: server.control.default, - correction: server.correction.default, - model: server.model.default, - pipeline: server.pipeline.default, - platform: server.platform.default, - upscaling: server.upscaling.default, - }; - const defaultUpscale: UpscaleParams = { - denoise: server.denoise.default, - enabled: false, - faces: false, - faceOutscale: server.faceOutscale.default, - faceStrength: server.faceStrength.default, - outscale: server.outscale.default, - scale: server.scale.default, - upscaleOrder: server.upscaleOrder.default, - }; - const defaultGrid: PipelineGrid = { - enabled: false, - columns: { - parameter: 'seed', - value: '', - }, - rows: { - parameter: 'seed', - value: '', - }, - }; - - const createTxt2ImgSlice: Slice = (set) => ({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - txt2imgHighres: { - ...defaultHighres, - }, - txt2imgModel: { - ...defaultModel, - }, - txt2imgUpscale: { - ...defaultUpscale, - }, - txt2imgVariable: { - ...defaultGrid, - }, - setTxt2Img(params) { - set((prev) => ({ - txt2img: { - ...prev.txt2img, - ...params, - }, - })); - }, - setTxt2ImgHighres(params) { - set((prev) => ({ - txt2imgHighres: { - ...prev.txt2imgHighres, - ...params, - }, - })); - }, - setTxt2ImgModel(params) { - set((prev) => ({ - txt2imgModel: { - ...prev.txt2imgModel, - ...params, - }, - })); - }, - setTxt2ImgUpscale(params) { - set((prev) => ({ - txt2imgUpscale: { - ...prev.txt2imgUpscale, - ...params, - }, - })); - }, - setTxt2ImgVariable(params) { - set((prev) => ({ - txt2imgVariable: { - ...prev.txt2imgVariable, - ...params, - }, - })); - }, - resetTxt2Img() { - set({ - txt2img: { - ...defaultParams, - width: server.width.default, - height: server.height.default, - }, - }); - }, - }); - - const createImg2ImgSlice: Slice = (set) => ({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - img2imgHighres: { - ...defaultHighres, - }, - img2imgModel: { - ...defaultModel, - }, - img2imgUpscale: { - ...defaultUpscale, - }, - resetImg2Img() { - set({ - img2img: { - ...defaultParams, - loopback: server.loopback.default, - source: null, - sourceFilter: '', - strength: server.strength.default, - }, - }); - }, - setImg2Img(params) { - set((prev) => ({ - img2img: { - ...prev.img2img, - ...params, - }, - })); - }, - setImg2ImgHighres(params) { - set((prev) => ({ - img2imgHighres: { - ...prev.img2imgHighres, - ...params, - }, - })); - }, - setImg2ImgModel(params) { - set((prev) => ({ - img2imgModel: { - ...prev.img2imgModel, - ...params, - }, - })); - }, - setImg2ImgUpscale(params) { - set((prev) => ({ - img2imgUpscale: { - ...prev.img2imgUpscale, - ...params, - }, - })); - }, - }); - - const createInpaintSlice: Slice = (set) => ({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - inpaintBrush: { - ...DEFAULT_BRUSH, - }, - inpaintHighres: { - ...defaultHighres, - }, - inpaintModel: { - ...defaultModel, - }, - inpaintUpscale: { - ...defaultUpscale, - }, - outpaint: { - enabled: false, - left: server.left.default, - right: server.right.default, - top: server.top.default, - bottom: server.bottom.default, - }, - resetInpaint() { - set({ - inpaint: { - ...defaultParams, - fillColor: server.fillColor.default, - filter: server.filter.default, - mask: null, - noise: server.noise.default, - source: null, - strength: server.strength.default, - tileOrder: server.tileOrder.default, - }, - }); - }, - setInpaint(params) { - set((prev) => ({ - inpaint: { - ...prev.inpaint, - ...params, - }, - })); - }, - setInpaintBrush(brush) { - set((prev) => ({ - inpaintBrush: { - ...prev.inpaintBrush, - ...brush, - }, - })); - }, - setInpaintHighres(params) { - set((prev) => ({ - inpaintHighres: { - ...prev.inpaintHighres, - ...params, - }, - })); - }, - setInpaintModel(params) { - set((prev) => ({ - inpaintModel: { - ...prev.inpaintModel, - ...params, - }, - })); - }, - setInpaintUpscale(params) { - set((prev) => ({ - inpaintUpscale: { - ...prev.inpaintUpscale, - ...params, - }, - })); - }, - setOutpaint(pixels) { - set((prev) => ({ - outpaint: { - ...prev.outpaint, - ...pixels, - } - })); - }, - }); - - const createHistorySlice: Slice = (set) => ({ - history: [], - limit: DEFAULT_HISTORY.limit, - pushHistory(image, retry) { - set((prev) => ({ - ...prev, - history: [ - { - image, - ready: undefined, - retry, - }, - ...prev.history, - ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), - })); - }, - removeHistory(image) { - set((prev) => ({ - ...prev, - history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), - })); - }, - setLimit(limit) { - set((prev) => ({ - ...prev, - limit, - })); - }, - setReady(image, ready) { - set((prev) => { - const history = [...prev.history]; - const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); - if (idx >= 0) { - history[idx].ready = ready; - } else { - // TODO: error - } - - return { - ...prev, - history, - }; - }); - }, - }); - - const createUpscaleSlice: Slice = (set) => ({ - upscale: { - ...defaultParams, - source: null, - }, - upscaleHighres: { - ...defaultHighres, - }, - upscaleModel: { - ...defaultModel, - }, - upscaleUpscale: { - ...defaultUpscale, - }, - resetUpscale() { - set({ - upscale: { - ...defaultParams, - source: null, - }, - }); - }, - setUpscale(source) { - set((prev) => ({ - upscale: { - ...prev.upscale, - ...source, - }, - })); - }, - setUpscaleHighres(params) { - set((prev) => ({ - upscaleHighres: { - ...prev.upscaleHighres, - ...params, - }, - })); - }, - setUpscaleModel(params) { - set((prev) => ({ - upscaleModel: { - ...prev.upscaleModel, - ...defaultModel, - }, - })); - }, - setUpscaleUpscale(params) { - set((prev) => ({ - upscaleUpscale: { - ...prev.upscaleUpscale, - ...params, - }, - })); - }, - }); - - const createBlendSlice: Slice = (set) => ({ - blend: { - mask: null, - sources: [], - }, - blendBrush: { - ...DEFAULT_BRUSH, - }, - blendModel: { - ...defaultModel, - }, - blendUpscale: { - ...defaultUpscale, - }, - resetBlend() { - set({ - blend: { - mask: null, - sources: [], - }, - }); - }, - setBlend(blend) { - set((prev) => ({ - blend: { - ...prev.blend, - ...blend, - }, - })); - }, - setBlendBrush(brush) { - set((prev) => ({ - blendBrush: { - ...prev.blendBrush, - ...brush, - }, - })); - }, - setBlendModel(model) { - set((prev) => ({ - blendModel: { - ...prev.blendModel, - ...model, - }, - })); - }, - setBlendUpscale(params) { - set((prev) => ({ - blendUpscale: { - ...prev.blendUpscale, - ...params, - }, - })); - }, - }); - - const createDefaultSlice: Slice = (set) => ({ - defaults: { - ...defaultParams, - }, - theme: '', - setDefaults(params) { - set((prev) => ({ - defaults: { - ...prev.defaults, - ...params, - } - })); - }, - setTheme(theme) { - set((prev) => ({ - theme, - })); - } - }); - - const createResetSlice: Slice = (set) => ({ - resetAll() { - set((prev) => { - const next = { ...prev }; - next.resetImg2Img(); - next.resetInpaint(); - next.resetTxt2Img(); - next.resetUpscale(); - next.resetBlend(); - return next; - }); - }, - }); - - const createProfileSlice: Slice = (set) => ({ - profiles: [], - saveProfile(profile: ProfileItem) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profile.name); - if (idx >= 0) { - profiles[idx] = profile; - } else { - profiles.push(profile); - } - return { - ...prev, - profiles, - }; - }); - }, - removeProfile(profileName: string) { - set((prev) => { - const profiles = [...prev.profiles]; - const idx = profiles.findIndex((it) => it.name === profileName); - if (idx >= 0) { - profiles.splice(idx, 1); - } - return { - ...prev, - profiles, - }; - }); - } - }); - - // eslint-disable-next-line sonarjs/cognitive-complexity - const createModelSlice: Slice = (set) => ({ - extras: { - correction: [], - diffusion: [], - networks: [], - sources: [], - upscaling: [], - }, - setExtras(extras) { - set((prev) => ({ - ...prev, - extras: { - ...prev.extras, - ...extras, - }, - })); - }, - setCorrectionModel(model) { - set((prev) => { - const correction = [...prev.extras.correction]; - const exists = correction.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - correction.push(model); - } else { - correction[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - }, - setDiffusionModel(model) { - set((prev) => { - const diffusion = [...prev.extras.diffusion]; - const exists = diffusion.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - diffusion.push(model); - } else { - diffusion[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - }, - setExtraNetwork(model) { - set((prev) => { - const networks = [...prev.extras.networks]; - const exists = networks.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - networks.push(model); - } else { - networks[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - }, - setExtraSource(model) { - set((prev) => { - const sources = [...prev.extras.sources]; - const exists = sources.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - sources.push(model); - } else { - sources[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - }, - setUpscalingModel(model) { - set((prev) => { - const upscaling = [...prev.extras.upscaling]; - const exists = upscaling.findIndex((it) => model.name === it.name); - if (exists === MISSING_INDEX) { - upscaling.push(model); - } else { - upscaling[exists] = model; - } - - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - removeCorrectionModel(model) { - set((prev) => { - const correction = prev.extras.correction.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - correction, - }, - }; - }); - - }, - removeDiffusionModel(model) { - set((prev) => { - const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - diffusion, - }, - }; - }); - - }, - removeExtraNetwork(model) { - set((prev) => { - const networks = prev.extras.networks.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - networks, - }, - }; - }); - - }, - removeExtraSource(model) { - set((prev) => { - const sources = prev.extras.sources.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - sources, - }, - }; - }); - - }, - removeUpscalingModel(model) { - set((prev) => { - const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; - return { - ...prev, - extras: { - ...prev.extras, - upscaling, - }, - }; - }); - }, - }); - - return { - createDefaultSlice, - createHistorySlice, - createImg2ImgSlice, - createInpaintSlice, - createTxt2ImgSlice, - createUpscaleSlice, - createBlendSlice, - createResetSlice, - createModelSlice, - createProfileSlice, - }; -} diff --git a/gui/src/state/blend.ts b/gui/src/state/blend.ts index 42d292f3..e433a07e 100644 --- a/gui/src/state/blend.ts +++ b/gui/src/state/blend.ts @@ -4,7 +4,7 @@ import { ModelParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; export interface BlendSlice { blend: TabState; @@ -19,3 +19,66 @@ export interface BlendSlice { setBlendModel(model: Partial): void; setBlendUpscale(params: Partial): void; } + +export function createBlendSlice( + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [], + }, + blendBrush: { + ...DEFAULT_BRUSH, + }, + blendModel: { + ...defaultModel, + }, + blendUpscale: { + ...defaultUpscale, + }, + resetBlend() { + set((prev) => ({ + blend: { + // eslint-disable-next-line no-null/no-null + mask: null, + sources: [] as Array, + }, + } as Partial)); + }, + setBlend(blend) { + set((prev) => ({ + blend: { + ...prev.blend, + ...blend, + }, + } as Partial)); + }, + setBlendBrush(brush) { + set((prev) => ({ + blendBrush: { + ...prev.blendBrush, + ...brush, + }, + } as Partial)); + }, + setBlendModel(model) { + set((prev) => ({ + blendModel: { + ...prev.blendModel, + ...model, + }, + } as Partial)); + }, + setBlendUpscale(params) { + set((prev) => ({ + blendUpscale: { + ...prev.blendUpscale, + ...params, + }, + } as Partial)); + }, + }); +} diff --git a/gui/src/state/default.ts b/gui/src/state/default.ts index b43b7ceb..c578263b 100644 --- a/gui/src/state/default.ts +++ b/gui/src/state/default.ts @@ -1,7 +1,7 @@ import { BaseImgParams, } from '../types/params.js'; -import { TabState, Theme } from './types.js'; +import { Slice, TabState, Theme } from './types.js'; export interface DefaultSlice { defaults: TabState; @@ -10,3 +10,25 @@ export interface DefaultSlice { setDefaults(param: Partial): void; setTheme(theme: Theme): void; } + +export function createDefaultSlice(defaultParams: Required): Slice { + return (set) => ({ + defaults: { + ...defaultParams, + }, + theme: '', + setDefaults(params) { + set((prev) => ({ + defaults: { + ...prev.defaults, + ...params, + } + } as Partial)); + }, + setTheme(theme) { + set((prev) => ({ + theme, + } as Partial)); + } + }); +} diff --git a/gui/src/state/full.ts b/gui/src/state/full.ts new file mode 100644 index 00000000..e46c328a --- /dev/null +++ b/gui/src/state/full.ts @@ -0,0 +1,152 @@ +/* eslint-disable camelcase */ +import { Maybe } from '@apextoaster/js-utils'; +import { Logger } from 'noicejs'; +import { createContext } from 'react'; +import { StoreApi } from 'zustand'; + +import { + ApiClient, +} from '../client/base.js'; +import { PipelineGrid } from '../client/utils.js'; +import { Config, ServerParams } from '../config.js'; +import { BlendSlice, createBlendSlice } from './blend.js'; +import { DefaultSlice, createDefaultSlice } from './default.js'; +import { HistorySlice, createHistorySlice } from './history.js'; +import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js'; +import { InpaintSlice, createInpaintSlice } from './inpaint.js'; +import { ModelSlice, createModelSlice } from './model.js'; +import { ProfileSlice, createProfileSlice } from './profile.js'; +import { ResetSlice, createResetSlice } from './reset.js'; +import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js'; +import { UpscaleSlice, createUpscaleSlice } from './upscale.js'; +import { + BaseImgParams, + HighresParams, + ModelParams, + UpscaleParams, +} from '../types/params.js'; + +/** + * Full merged state including all slices. + */ +export type OnnxState + = DefaultSlice + & HistorySlice + & Img2ImgSlice + & InpaintSlice + & ModelSlice + & Txt2ImgSlice + & UpscaleSlice + & BlendSlice + & ResetSlice + & ProfileSlice; + +/** + * React context binding for API client. + */ +export const ClientContext = createContext>(undefined); + +/** + * React context binding for merged config, including server parameters. + */ +export const ConfigContext = createContext>>(undefined); + +/** + * React context binding for bunyan logger. + */ +export const LoggerContext = createContext>(undefined); + +/** + * React context binding for zustand state store. + */ +export const StateContext = createContext>>(undefined); + +/** + * Key for zustand persistence, typically local storage. + */ +export const STATE_KEY = 'onnx-web'; + +/** + * Current state version for zustand persistence. + */ +export const STATE_VERSION = 7; + +export const BLEND_SOURCES = 2; + +export function baseParamsFromServer(defaults: ServerParams): Required { + return { + batch: defaults.batch.default, + cfg: defaults.cfg.default, + eta: defaults.eta.default, + negativePrompt: defaults.negativePrompt.default, + prompt: defaults.prompt.default, + scheduler: defaults.scheduler.default, + steps: defaults.steps.default, + seed: defaults.seed.default, + tiled_vae: defaults.tiled_vae.default, + unet_overlap: defaults.unet_overlap.default, + unet_tile: defaults.unet_tile.default, + vae_overlap: defaults.vae_overlap.default, + vae_tile: defaults.vae_tile.default, + }; +} + +/** + * Prepare the state slice constructors. + * + * In the default state, image sources should be null and booleans should be false. Everything + * else should be initialized from the default value in the base parameters. + */ +export function createStateSlices(server: ServerParams) { + const defaultParams = baseParamsFromServer(server); + const defaultHighres: HighresParams = { + enabled: false, + highresIterations: server.highresIterations.default, + highresMethod: '', + highresSteps: server.highresSteps.default, + highresScale: server.highresScale.default, + highresStrength: server.highresStrength.default, + }; + const defaultModel: ModelParams = { + control: server.control.default, + correction: server.correction.default, + model: server.model.default, + pipeline: server.pipeline.default, + platform: server.platform.default, + upscaling: server.upscaling.default, + }; + const defaultUpscale: UpscaleParams = { + denoise: server.denoise.default, + enabled: false, + faces: false, + faceOutscale: server.faceOutscale.default, + faceStrength: server.faceStrength.default, + outscale: server.outscale.default, + scale: server.scale.default, + upscaleOrder: server.upscaleOrder.default, + }; + const defaultGrid: PipelineGrid = { + enabled: false, + columns: { + parameter: 'seed', + value: '', + }, + rows: { + parameter: 'seed', + value: '', + }, + }; + + return { + createBlendSlice: createBlendSlice(defaultModel, defaultUpscale), + createDefaultSlice: createDefaultSlice(defaultParams), + createHistorySlice: createHistorySlice(), + createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale), + createModelSlice: createModelSlice(), + createProfileSlice: createProfileSlice(), + createResetSlice: createResetSlice(), + createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid), + createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale), + }; +} diff --git a/gui/src/state/history.ts b/gui/src/state/history.ts index 44f7b2f3..4eb58271 100644 --- a/gui/src/state/history.ts +++ b/gui/src/state/history.ts @@ -1,5 +1,6 @@ import { Maybe } from '@apextoaster/js-utils'; import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; +import { DEFAULT_HISTORY, Slice } from './types.js'; export interface HistoryItem { image: ImageResponse; @@ -16,3 +17,51 @@ export interface HistorySlice { setLimit(limit: number): void; setReady(image: ImageResponse, ready: ReadyResponse): void; } + +export function createHistorySlice(): Slice { + return (set) => ({ + history: [], + limit: DEFAULT_HISTORY.limit, + pushHistory(image, retry) { + set((prev) => ({ + ...prev, + history: [ + { + image, + ready: undefined, + retry, + }, + ...prev.history, + ].slice(0, prev.limit + DEFAULT_HISTORY.scrollback), + })); + }, + removeHistory(image) { + set((prev) => ({ + ...prev, + history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key), + })); + }, + setLimit(limit) { + set((prev) => ({ + ...prev, + limit, + })); + }, + setReady(image, ready) { + set((prev) => { + const history = [...prev.history]; + const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key); + if (idx >= 0) { + history[idx].ready = ready; + } else { + // TODO: error + } + + return { + ...prev, + history, + }; + }); + }, + }); +} diff --git a/gui/src/state/img2img.ts b/gui/src/state/img2img.ts index cbe204d1..b8f986b0 100644 --- a/gui/src/state/img2img.ts +++ b/gui/src/state/img2img.ts @@ -1,11 +1,13 @@ +import { ServerParams } from '../config.js'; import { + BaseImgParams, HighresParams, Img2ImgParams, ModelParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface Img2ImgSlice { img2img: TabState; @@ -20,3 +22,77 @@ export interface Img2ImgSlice { setImg2ImgHighres(params: Partial): void; setImg2ImgUpscale(params: Partial): void; } + +// eslint-disable-next-line max-params +export function createImg2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams +): Slice { + return (set) => ({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + img2imgHighres: { + ...defaultHighres, + }, + img2imgModel: { + ...defaultModel, + }, + img2imgUpscale: { + ...defaultUpscale, + }, + resetImg2Img() { + set({ + img2img: { + ...defaultParams, + loopback: server.loopback.default, + // eslint-disable-next-line no-null/no-null + source: null, + sourceFilter: '', + strength: server.strength.default, + }, + } as Partial); + }, + setImg2Img(params) { + set((prev) => ({ + img2img: { + ...prev.img2img, + ...params, + }, + } as Partial)); + }, + setImg2ImgHighres(params) { + set((prev) => ({ + img2imgHighres: { + ...prev.img2imgHighres, + ...params, + }, + } as Partial)); + }, + setImg2ImgModel(params) { + set((prev) => ({ + img2imgModel: { + ...prev.img2imgModel, + ...params, + }, + } as Partial)); + }, + setImg2ImgUpscale(params) { + set((prev) => ({ + img2imgUpscale: { + ...prev.img2imgUpscale, + ...params, + }, + } as Partial)); + }, + }); + +} diff --git a/gui/src/state/inpaint.ts b/gui/src/state/inpaint.ts index 12756ac2..3eac9113 100644 --- a/gui/src/state/inpaint.ts +++ b/gui/src/state/inpaint.ts @@ -1,4 +1,6 @@ +import { ServerParams } from '../config.js'; import { + BaseImgParams, BrushParams, HighresParams, InpaintParams, @@ -6,8 +8,7 @@ import { OutpaintPixels, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; - +import { DEFAULT_BRUSH, Slice, TabState } from './types.js'; export interface InpaintSlice { inpaint: TabState; inpaintBrush: BrushParams; @@ -25,3 +26,110 @@ export interface InpaintSlice { setInpaintUpscale(params: Partial): void; setOutpaint(pixels: Partial): void; } + +// eslint-disable-next-line max-params +export function createInpaintSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + inpaintBrush: { + ...DEFAULT_BRUSH, + }, + inpaintHighres: { + ...defaultHighres, + }, + inpaintModel: { + ...defaultModel, + }, + inpaintUpscale: { + ...defaultUpscale, + }, + outpaint: { + enabled: false, + left: server.left.default, + right: server.right.default, + top: server.top.default, + bottom: server.bottom.default, + }, + resetInpaint() { + set({ + inpaint: { + ...defaultParams, + fillColor: server.fillColor.default, + filter: server.filter.default, + // eslint-disable-next-line no-null/no-null + mask: null, + noise: server.noise.default, + // eslint-disable-next-line no-null/no-null + source: null, + strength: server.strength.default, + tileOrder: server.tileOrder.default, + }, + } as Partial); + }, + setInpaint(params) { + set((prev) => ({ + inpaint: { + ...prev.inpaint, + ...params, + }, + } as Partial)); + }, + setInpaintBrush(brush) { + set((prev) => ({ + inpaintBrush: { + ...prev.inpaintBrush, + ...brush, + }, + } as Partial)); + }, + setInpaintHighres(params) { + set((prev) => ({ + inpaintHighres: { + ...prev.inpaintHighres, + ...params, + }, + } as Partial)); + }, + setInpaintModel(params) { + set((prev) => ({ + inpaintModel: { + ...prev.inpaintModel, + ...params, + }, + } as Partial)); + }, + setInpaintUpscale(params) { + set((prev) => ({ + inpaintUpscale: { + ...prev.inpaintUpscale, + ...params, + }, + } as Partial)); + }, + setOutpaint(pixels) { + set((prev) => ({ + outpaint: { + ...prev.outpaint, + ...pixels, + } + } as Partial)); + }, + }); +} diff --git a/gui/src/state/model.ts b/gui/src/state/model.ts new file mode 100644 index 00000000..3a473182 --- /dev/null +++ b/gui/src/state/model.ts @@ -0,0 +1,202 @@ +import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; +import { MISSING_INDEX, Slice } from './types.js'; + +export interface ModelSlice { + extras: ExtrasFile; + + removeCorrectionModel(model: CorrectionModel): void; + removeDiffusionModel(model: DiffusionModel): void; + removeExtraNetwork(model: ExtraNetwork): void; + removeExtraSource(model: ExtraSource): void; + removeUpscalingModel(model: UpscalingModel): void; + + setExtras(extras: Partial): void; + + setCorrectionModel(model: CorrectionModel): void; + setDiffusionModel(model: DiffusionModel): void; + setExtraNetwork(model: ExtraNetwork): void; + setExtraSource(model: ExtraSource): void; + setUpscalingModel(model: UpscalingModel): void; +} + +// eslint-disable-next-line sonarjs/cognitive-complexity +export function createModelSlice(): Slice { + // eslint-disable-next-line sonarjs/cognitive-complexity + return (set) => ({ + extras: { + correction: [], + diffusion: [], + networks: [], + sources: [], + upscaling: [], + }, + setExtras(extras) { + set((prev) => ({ + ...prev, + extras: { + ...prev.extras, + ...extras, + }, + })); + }, + setCorrectionModel(model) { + set((prev) => { + const correction = [...prev.extras.correction]; + const exists = correction.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + correction.push(model); + } else { + correction[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + }, + setDiffusionModel(model) { + set((prev) => { + const diffusion = [...prev.extras.diffusion]; + const exists = diffusion.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + diffusion.push(model); + } else { + diffusion[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + }, + setExtraNetwork(model) { + set((prev) => { + const networks = [...prev.extras.networks]; + const exists = networks.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + networks.push(model); + } else { + networks[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + }, + setExtraSource(model) { + set((prev) => { + const sources = [...prev.extras.sources]; + const exists = sources.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + sources.push(model); + } else { + sources[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + }, + setUpscalingModel(model) { + set((prev) => { + const upscaling = [...prev.extras.upscaling]; + const exists = upscaling.findIndex((it) => model.name === it.name); + if (exists === MISSING_INDEX) { + upscaling.push(model); + } else { + upscaling[exists] = model; + } + + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + removeCorrectionModel(model) { + set((prev) => { + const correction = prev.extras.correction.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + correction, + }, + }; + }); + + }, + removeDiffusionModel(model) { + set((prev) => { + const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + diffusion, + }, + }; + }); + + }, + removeExtraNetwork(model) { + set((prev) => { + const networks = prev.extras.networks.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + networks, + }, + }; + }); + + }, + removeExtraSource(model) { + set((prev) => { + const sources = prev.extras.sources.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + sources, + }, + }; + }); + + }, + removeUpscalingModel(model) { + set((prev) => { + const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);; + return { + ...prev, + extras: { + ...prev.extras, + upscaling, + }, + }; + }); + }, + }); +} diff --git a/gui/src/state/models.ts b/gui/src/state/models.ts deleted file mode 100644 index e0faa27d..00000000 --- a/gui/src/state/models.ts +++ /dev/null @@ -1,19 +0,0 @@ -import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js'; - -export interface ModelSlice { - extras: ExtrasFile; - - removeCorrectionModel(model: CorrectionModel): void; - removeDiffusionModel(model: DiffusionModel): void; - removeExtraNetwork(model: ExtraNetwork): void; - removeExtraSource(model: ExtraSource): void; - removeUpscalingModel(model: UpscalingModel): void; - - setExtras(extras: Partial): void; - - setCorrectionModel(model: CorrectionModel): void; - setDiffusionModel(model: DiffusionModel): void; - setExtraNetwork(model: ExtraNetwork): void; - setExtraSource(model: ExtraSource): void; - setUpscalingModel(model: UpscalingModel): void; -} diff --git a/gui/src/state/profile.ts b/gui/src/state/profile.ts index 7ffdfac9..73d52eab 100644 --- a/gui/src/state/profile.ts +++ b/gui/src/state/profile.ts @@ -1,5 +1,6 @@ import { Maybe } from '@apextoaster/js-utils'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; +import { Slice } from './types.js'; export interface ProfileItem { name: string; @@ -15,3 +16,37 @@ export interface ProfileSlice { saveProfile(profile: ProfileItem): void; } + +export function createProfileSlice(): Slice { + return (set) => ({ + profiles: [], + saveProfile(profile: ProfileItem) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profile.name); + if (idx >= 0) { + profiles[idx] = profile; + } else { + profiles.push(profile); + } + return { + ...prev, + profiles, + }; + }); + }, + removeProfile(profileName: string) { + set((prev) => { + const profiles = [...prev.profiles]; + const idx = profiles.findIndex((it) => it.name === profileName); + if (idx >= 0) { + profiles.splice(idx, 1); + } + return { + ...prev, + profiles, + }; + }); + } + }); +} diff --git a/gui/src/state/reset.ts b/gui/src/state/reset.ts index 66b545a5..53272e5c 100644 --- a/gui/src/state/reset.ts +++ b/gui/src/state/reset.ts @@ -1,3 +1,28 @@ +import { BlendSlice } from './blend.js'; +import { Img2ImgSlice } from './img2img.js'; +import { InpaintSlice } from './inpaint.js'; +import { Txt2ImgSlice } from './txt2img.js'; +import { Slice } from './types.js'; +import { UpscaleSlice } from './upscale.js'; + +export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice; + export interface ResetSlice { resetAll(): void; } + +export function createResetSlice(): Slice { + return (set) => ({ + resetAll() { + set((prev) => { + const next = { ...prev }; + next.resetImg2Img(); + next.resetInpaint(); + next.resetTxt2Img(); + next.resetUpscale(); + next.resetBlend(); + return next; + }); + }, + }); +} diff --git a/gui/src/state/txt2img.ts b/gui/src/state/txt2img.ts index d8cd95eb..8bec3273 100644 --- a/gui/src/state/txt2img.ts +++ b/gui/src/state/txt2img.ts @@ -1,11 +1,13 @@ import { PipelineGrid } from '../client/utils.js'; +import { ServerParams } from '../config.js'; import { + BaseImgParams, HighresParams, ModelParams, Txt2ImgParams, UpscaleParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface Txt2ImgSlice { txt2img: TabState; @@ -22,3 +24,82 @@ export interface Txt2ImgSlice { setTxt2ImgUpscale(params: Partial): void; setTxt2ImgVariable(params: Partial): void; } + +// eslint-disable-next-line max-params +export function createTxt2ImgSlice( + server: ServerParams, + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, + defaultGrid: PipelineGrid, +): Slice { + return (set) => ({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + txt2imgHighres: { + ...defaultHighres, + }, + txt2imgModel: { + ...defaultModel, + }, + txt2imgUpscale: { + ...defaultUpscale, + }, + txt2imgVariable: { + ...defaultGrid, + }, + setTxt2Img(params) { + set((prev) => ({ + txt2img: { + ...prev.txt2img, + ...params, + }, + } as Partial)); + }, + setTxt2ImgHighres(params) { + set((prev) => ({ + txt2imgHighres: { + ...prev.txt2imgHighres, + ...params, + }, + } as Partial)); + }, + setTxt2ImgModel(params) { + set((prev) => ({ + txt2imgModel: { + ...prev.txt2imgModel, + ...params, + }, + } as Partial)); + }, + setTxt2ImgUpscale(params) { + set((prev) => ({ + txt2imgUpscale: { + ...prev.txt2imgUpscale, + ...params, + }, + } as Partial)); + }, + setTxt2ImgVariable(params) { + set((prev) => ({ + txt2imgVariable: { + ...prev.txt2imgVariable, + ...params, + }, + } as Partial)); + }, + resetTxt2Img() { + set({ + txt2img: { + ...defaultParams, + width: server.width.default, + height: server.height.default, + }, + } as Partial); + }, + }); +} diff --git a/gui/src/state/types.ts b/gui/src/state/types.ts index 98843c86..3b2144bd 100644 --- a/gui/src/state/types.ts +++ b/gui/src/state/types.ts @@ -1,4 +1,5 @@ import { PaletteMode } from '@mui/material'; +import { StateCreator } from 'zustand'; import { ConfigFiles, ConfigState } from '../config.js'; export const MISSING_INDEX = -1; @@ -9,3 +10,37 @@ export type Theme = PaletteMode | ''; // tri-state, '' is unset * Combine optional files and required ranges. */ export type TabState = ConfigFiles> & ConfigState>; + +/** + * Shorthand for state creator to reduce repeated arguments. + */ +export type Slice = StateCreator; + +/** + * Default parameters for the inpaint brush. + * + * Not provided by the server yet. + */ +export const DEFAULT_BRUSH = { + color: 255, + size: 8, + strength: 0.5, +}; + +/** + * Default parameters for the image history. + * + * Not provided by the server yet. + */ +export const DEFAULT_HISTORY = { + /** + * The number of images to be shown. + */ + limit: 4, + + /** + * The number of additional images to be kept in history, so they can scroll + * back into view when you delete one. Does not include deleted images. + */ + scrollback: 2, +}; diff --git a/gui/src/state/upscale.ts b/gui/src/state/upscale.ts index af0a344a..e78d689a 100644 --- a/gui/src/state/upscale.ts +++ b/gui/src/state/upscale.ts @@ -1,10 +1,11 @@ import { + BaseImgParams, HighresParams, ModelParams, UpscaleParams, UpscaleReqParams, } from '../types/params.js'; -import { TabState } from './types.js'; +import { Slice, TabState } from './types.js'; export interface UpscaleSlice { upscale: TabState; @@ -19,3 +20,68 @@ export interface UpscaleSlice { setUpscaleModel(params: Partial): void; setUpscaleUpscale(params: Partial): void; } + +export function createUpscaleSlice( + defaultParams: Required, + defaultHighres: HighresParams, + defaultModel: ModelParams, + defaultUpscale: UpscaleParams, +): Slice { + return (set) => ({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + upscaleHighres: { + ...defaultHighres, + }, + upscaleModel: { + ...defaultModel, + }, + upscaleUpscale: { + ...defaultUpscale, + }, + resetUpscale() { + set({ + upscale: { + ...defaultParams, + // eslint-disable-next-line no-null/no-null + source: null, + }, + } as Partial); + }, + setUpscale(source) { + set((prev) => ({ + upscale: { + ...prev.upscale, + ...source, + }, + } as Partial)); + }, + setUpscaleHighres(params) { + set((prev) => ({ + upscaleHighres: { + ...prev.upscaleHighres, + ...params, + }, + } as Partial)); + }, + setUpscaleModel(params) { + set((prev) => ({ + upscaleModel: { + ...prev.upscaleModel, + ...defaultModel, + }, + } as Partial)); + }, + setUpscaleUpscale(params) { + set((prev) => ({ + upscaleUpscale: { + ...prev.upscaleUpscale, + ...params, + }, + } as Partial)); + }, + }); +}