1
0
Fork 0

break up state slice factories

This commit is contained in:
Sean Sube 2023-12-14 21:30:06 -06:00
parent d86286ab1e
commit 0dfc1b61d2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
39 changed files with 951 additions and 868 deletions

View File

@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { OnnxState, StateContext } from '../state.js'; import { OnnxState, StateContext } from '../state/full.js';
import { ErrorCard } from './card/ErrorCard.js'; import { ErrorCard } from './card/ErrorCard.js';
import { ImageCard } from './card/ImageCard.js'; import { ImageCard } from './card/ImageCard.js';
import { LoadingCard } from './card/LoadingCard.js'; import { LoadingCard } from './card/LoadingCard.js';

View File

@ -2,7 +2,7 @@ import { Box, Button, Container, Stack, Typography } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { ReactNode } from 'react'; import { ReactNode } from 'react';
import { STATE_KEY } from '../state.js'; import { STATE_KEY } from '../state/full.js';
import { Logo } from './Logo.js'; import { Logo } from './Logo.js';
export interface OnnxErrorProps { export interface OnnxErrorProps {

View File

@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react';
import { useHash } from 'react-use/lib/useHash'; import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { OnnxState, StateContext } from '../state.js'; import { OnnxState, StateContext } from '../state/full.js';
import { ImageHistory } from './ImageHistory.js'; import { ImageHistory } from './ImageHistory.js';
import { Logo } from './Logo.js'; import { Logo } from './Logo.js';
import { Blend } from './tab/Blend.js'; import { Blend } from './tab/Blend.js';

View File

@ -21,7 +21,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { OnnxState, StateContext } from '../state.js'; import { OnnxState, StateContext } from '../state/full.js';
import { ImageMetadata } from '../types/api.js'; import { ImageMetadata } from '../types/api.js';
import { DeepPartial } from '../types/model.js'; import { DeepPartial } from '../types/model.js';
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';

View File

@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../../types/api.js';
export interface ErrorCardProps { export interface ErrorCardProps {

View File

@ -8,7 +8,7 @@ import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js'; import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js'; import { ImageResponse } from '../../types/api.js';
import { range, visibleIndex } from '../../utils.js'; import { range, visibleIndex } from '../../utils.js';

View File

@ -9,7 +9,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { POLL_TIME } from '../../config.js'; import { POLL_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { ImageResponse } from '../../types/api.js'; import { ImageResponse } from '../../types/api.js';
const LOADING_PERCENT = 100; const LOADING_PERCENT = 100;

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { HighresParams } from '../../types/params.js'; import { HighresParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';

View File

@ -11,7 +11,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { BaseImgParams } from '../../types/params.js'; import { BaseImgParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';
import { PromptInput } from '../input/PromptInput.js'; import { PromptInput } from '../input/PromptInput.js';

View File

@ -6,7 +6,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext } from '../../state.js'; import { ClientContext } from '../../state/full.js';
import { ModelParams } from '../../types/params.js'; import { ModelParams } from '../../types/params.js';
import { QueryList } from '../input/QueryList.js'; import { QueryList } from '../input/QueryList.js';

View File

@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';
export function OutpaintControl() { export function OutpaintControl() {

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigContext, OnnxState, StateContext } from '../../state.js'; import { ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { UpscaleParams } from '../../types/params.js'; import { UpscaleParams } from '../../types/params.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';

View File

@ -5,7 +5,7 @@ import { useContext } from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { PipelineGrid } from '../../client/utils.js'; import { PipelineGrid } from '../../client/utils.js';
import { OnnxState, StateContext } from '../../state.js'; import { OnnxState, StateContext } from '../../state/full.js';
import { VARIABLE_PARAMETERS } from '../../types/chain.js'; import { VARIABLE_PARAMETERS } from '../../types/chain.js';
export interface VariableControlProps { export interface VariableControlProps {

View File

@ -4,7 +4,7 @@ import * as React from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { OnnxState, StateContext } from '../../state.js'; import { OnnxState, StateContext } from '../../state/full.js';
const { useContext, useState, memo, useMemo } = React; const { useContext, useState, memo, useMemo } = React;

View File

@ -6,7 +6,7 @@ import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { SAVE_TIME } from '../../config.js'; import { SAVE_TIME } from '../../config.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; import { ConfigContext, LoggerContext, StateContext } from '../../state/full.js';
import { BrushParams } from '../../types/params.js'; import { BrushParams } from '../../types/params.js';
import { imageFromBlob } from '../../utils.js'; import { imageFromBlob } from '../../utils.js';
import { NumericField } from './NumericField'; import { NumericField } from './NumericField';

View File

@ -8,7 +8,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { QueryMenu } from '../input/QueryMenu.js'; import { QueryMenu } from '../input/QueryMenu.js';
import { ModelResponse } from '../../types/api.js'; import { ModelResponse } from '../../types/api.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER } from '../../config.js'; import { IMAGE_FILTER } from '../../config.js';
import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; import { BLEND_SOURCES, ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { range } from '../../utils.js'; import { range } from '../../utils.js';
import { UpscaleControl } from '../control/UpscaleControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js'; import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js'; import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js'; import { BrushParams, HighresParams, InpaintParams, ModelParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js'; import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js'; import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,7 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { STALE_TIME } from '../../config.js'; import { STALE_TIME } from '../../config.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js'; import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { import {
CorrectionModel, CorrectionModel,
DiffusionModel, DiffusionModel,

View File

@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { getApiRoot } from '../../config.js'; import { getApiRoot } from '../../config.js';
import { ConfigContext, StateContext, STATE_KEY } from '../../state.js'; import { ConfigContext, StateContext, STATE_KEY } from '../../state/full.js';
import { getTheme } from '../utils.js'; import { getTheme } from '../utils.js';
import { NumericField } from '../input/NumericField.js'; import { NumericField } from '../input/NumericField.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js'; import { PipelineGrid, makeTxt2ImgGridPipeline } from '../../client/utils.js';
import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js'; import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js'; import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js'; import { HighresControl } from '../control/HighresControl.js';

View File

@ -8,7 +8,8 @@ import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow'; import { shallow } from 'zustand/shallow';
import { IMAGE_FILTER } from '../../config.js'; import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, OnnxState, StateContext, TabState } from '../../state.js'; import { ClientContext, OnnxState, StateContext } from '../../state/full.js';
import { TabState } from '../../state/types.js';
import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js'; import { HighresParams, ModelParams, UpscaleParams, UpscaleReqParams } from '../../types/params.js';
import { Profiles } from '../Profiles.js'; import { Profiles } from '../Profiles.js';
import { HighresControl } from '../control/HighresControl.js'; import { HighresControl } from '../control/HighresControl.js';

View File

@ -1,6 +1,6 @@
import { PaletteMode } from '@mui/material'; import { PaletteMode } from '@mui/material';
import { Theme } from '../state.js'; import { Theme } from '../state/types.js';
import { trimHash } from '../utils.js'; import { trimHash } from '../utils.js';
export const TAB_LABELS = [ export const TAB_LABELS = [

View File

@ -28,7 +28,7 @@ import {
STATE_KEY, STATE_KEY,
STATE_VERSION, STATE_VERSION,
StateContext, StateContext,
} from './state.js'; } from './state/full.js';
import { I18N_STRINGS } from './strings/all.js'; import { I18N_STRINGS } from './strings/all.js';
export const INITIAL_LOAD_TIMEOUT = 5_000; export const INITIAL_LOAD_TIMEOUT = 5_000;

View File

@ -1,817 +0,0 @@
/* eslint-disable camelcase */
/* eslint-disable max-lines */
/* eslint-disable no-null/no-null */
import { Maybe } from '@apextoaster/js-utils';
import { Logger } from 'noicejs';
import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand';
import {
ApiClient,
} from './client/base.js';
import { PipelineGrid } from './client/utils.js';
import { Config, ServerParams } from './config.js';
import {
BaseImgParams,
HighresParams,
ModelParams,
UpscaleParams,
} from './types/params.js';
import { DefaultSlice } from './state/default.js';
import { HistorySlice } from './state/history.js';
import { Img2ImgSlice } from './state/img2img.js';
import { InpaintSlice } from './state/inpaint.js';
import { ModelSlice } from './state/models.js';
import { Txt2ImgSlice } from './state/txt2img.js';
import { UpscaleSlice } from './state/upscale.js';
import { ResetSlice } from './state/reset.js';
import { ProfileItem, ProfileSlice } from './state/profile.js';
import { BlendSlice } from './state/blend.js';
import { MISSING_INDEX } from './state/types.js';
/**
* Full merged state including all slices.
*/
export type OnnxState
= DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& Txt2ImgSlice
& UpscaleSlice
& BlendSlice
& ResetSlice
& ProfileSlice;
/**
* Shorthand for state creator to reduce repeated arguments.
*/
export type Slice<T> = StateCreator<OnnxState, [], [], T>;
/**
* React context binding for API client.
*/
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
/**
* React context binding for merged config, including server parameters.
*/
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
/**
* React context binding for bunyan logger.
*/
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
/**
* React context binding for zustand state store.
*/
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
/**
* Key for zustand persistence, typically local storage.
*/
export const STATE_KEY = 'onnx-web';
/**
* Current state version for zustand persistence.
*/
export const STATE_VERSION = 7;
export const BLEND_SOURCES = 2;
/**
* Default parameters for the inpaint brush.
*
* Not provided by the server yet.
*/
export const DEFAULT_BRUSH = {
color: 255,
size: 8,
strength: 0.5,
};
/**
* Default parameters for the image history.
*
* Not provided by the server yet.
*/
export const DEFAULT_HISTORY = {
/**
* The number of images to be shown.
*/
limit: 4,
/**
* The number of additional images to be kept in history, so they can scroll
* back into view when you delete one. Does not include deleted images.
*/
scrollback: 2,
};
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
seed: defaults.seed.default,
tiled_vae: defaults.tiled_vae.default,
unet_overlap: defaults.unet_overlap.default,
unet_tile: defaults.unet_tile.default,
vae_overlap: defaults.vae_overlap.default,
vae_tile: defaults.vae_tile.default,
};
}
/**
* Prepare the state slice constructors.
*
* In the default state, image sources should be null and booleans should be false. Everything
* else should be initialized from the default value in the base parameters.
*/
export function createStateSlices(server: ServerParams) {
const defaultParams = baseParamsFromServer(server);
const defaultHighres: HighresParams = {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
};
const defaultModel: ModelParams = {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
};
const defaultUpscale: UpscaleParams = {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const defaultGrid: PipelineGrid = {
enabled: false,
columns: {
parameter: 'seed',
value: '',
},
rows: {
parameter: 'seed',
value: '',
},
};
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
txt2imgHighres: {
...defaultHighres,
},
txt2imgModel: {
...defaultModel,
},
txt2imgUpscale: {
...defaultUpscale,
},
txt2imgVariable: {
...defaultGrid,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
}));
},
setTxt2ImgHighres(params) {
set((prev) => ({
txt2imgHighres: {
...prev.txt2imgHighres,
...params,
},
}));
},
setTxt2ImgModel(params) {
set((prev) => ({
txt2imgModel: {
...prev.txt2imgModel,
...params,
},
}));
},
setTxt2ImgUpscale(params) {
set((prev) => ({
txt2imgUpscale: {
...prev.txt2imgUpscale,
...params,
},
}));
},
setTxt2ImgVariable(params) {
set((prev) => ({
txt2imgVariable: {
...prev.txt2imgVariable,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
});
},
});
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
img2img: {
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
img2imgHighres: {
...defaultHighres,
},
img2imgModel: {
...defaultModel,
},
img2imgUpscale: {
...defaultUpscale,
},
resetImg2Img() {
set({
img2img: {
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
});
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
}));
},
setImg2ImgHighres(params) {
set((prev) => ({
img2imgHighres: {
...prev.img2imgHighres,
...params,
},
}));
},
setImg2ImgModel(params) {
set((prev) => ({
img2imgModel: {
...prev.img2imgModel,
...params,
},
}));
},
setImg2ImgUpscale(params) {
set((prev) => ({
img2imgUpscale: {
...prev.img2imgUpscale,
...params,
},
}));
},
});
const createInpaintSlice: Slice<InpaintSlice> = (set) => ({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
inpaintBrush: {
...DEFAULT_BRUSH,
},
inpaintHighres: {
...defaultHighres,
},
inpaintModel: {
...defaultModel,
},
inpaintUpscale: {
...defaultUpscale,
},
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
resetInpaint() {
set({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
});
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
}));
},
setInpaintBrush(brush) {
set((prev) => ({
inpaintBrush: {
...prev.inpaintBrush,
...brush,
},
}));
},
setInpaintHighres(params) {
set((prev) => ({
inpaintHighres: {
...prev.inpaintHighres,
...params,
},
}));
},
setInpaintModel(params) {
set((prev) => ({
inpaintModel: {
...prev.inpaintModel,
...params,
},
}));
},
setInpaintUpscale(params) {
set((prev) => ({
inpaintUpscale: {
...prev.inpaintUpscale,
...params,
},
}));
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
}));
},
});
const createHistorySlice: Slice<HistorySlice> = (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
pushHistory(image, retry) {
set((prev) => ({
...prev,
history: [
{
image,
ready: undefined,
retry,
},
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
}));
},
removeHistory(image) {
set((prev) => ({
...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
}));
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setReady(image, ready) {
set((prev) => {
const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
if (idx >= 0) {
history[idx].ready = ready;
} else {
// TODO: error
}
return {
...prev,
history,
};
});
},
});
const createUpscaleSlice: Slice<UpscaleSlice> = (set) => ({
upscale: {
...defaultParams,
source: null,
},
upscaleHighres: {
...defaultHighres,
},
upscaleModel: {
...defaultModel,
},
upscaleUpscale: {
...defaultUpscale,
},
resetUpscale() {
set({
upscale: {
...defaultParams,
source: null,
},
});
},
setUpscale(source) {
set((prev) => ({
upscale: {
...prev.upscale,
...source,
},
}));
},
setUpscaleHighres(params) {
set((prev) => ({
upscaleHighres: {
...prev.upscaleHighres,
...params,
},
}));
},
setUpscaleModel(params) {
set((prev) => ({
upscaleModel: {
...prev.upscaleModel,
...defaultModel,
},
}));
},
setUpscaleUpscale(params) {
set((prev) => ({
upscaleUpscale: {
...prev.upscaleUpscale,
...params,
},
}));
},
});
const createBlendSlice: Slice<BlendSlice> = (set) => ({
blend: {
mask: null,
sources: [],
},
blendBrush: {
...DEFAULT_BRUSH,
},
blendModel: {
...defaultModel,
},
blendUpscale: {
...defaultUpscale,
},
resetBlend() {
set({
blend: {
mask: null,
sources: [],
},
});
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
}));
},
setBlendBrush(brush) {
set((prev) => ({
blendBrush: {
...prev.blendBrush,
...brush,
},
}));
},
setBlendModel(model) {
set((prev) => ({
blendModel: {
...prev.blendModel,
...model,
},
}));
},
setBlendUpscale(params) {
set((prev) => ({
blendUpscale: {
...prev.blendUpscale,
...params,
},
}));
},
});
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
defaults: {
...defaultParams,
},
theme: '',
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
}));
},
setTheme(theme) {
set((prev) => ({
theme,
}));
}
});
const createResetSlice: Slice<ResetSlice> = (set) => ({
resetAll() {
set((prev) => {
const next = { ...prev };
next.resetImg2Img();
next.resetInpaint();
next.resetTxt2Img();
next.resetUpscale();
next.resetBlend();
return next;
});
},
});
const createProfileSlice: Slice<ProfileSlice> = (set) => ({
profiles: [],
saveProfile(profile: ProfileItem) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profile.name);
if (idx >= 0) {
profiles[idx] = profile;
} else {
profiles.push(profile);
}
return {
...prev,
profiles,
};
});
},
removeProfile(profileName: string) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profileName);
if (idx >= 0) {
profiles.splice(idx, 1);
}
return {
...prev,
profiles,
};
});
}
});
// eslint-disable-next-line sonarjs/cognitive-complexity
const createModelSlice: Slice<ModelSlice> = (set) => ({
extras: {
correction: [],
diffusion: [],
networks: [],
sources: [],
upscaling: [],
},
setExtras(extras) {
set((prev) => ({
...prev,
extras: {
...prev.extras,
...extras,
},
}));
},
setCorrectionModel(model) {
set((prev) => {
const correction = [...prev.extras.correction];
const exists = correction.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
correction.push(model);
} else {
correction[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
setDiffusionModel(model) {
set((prev) => {
const diffusion = [...prev.extras.diffusion];
const exists = diffusion.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
diffusion.push(model);
} else {
diffusion[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
setExtraNetwork(model) {
set((prev) => {
const networks = [...prev.extras.networks];
const exists = networks.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
networks.push(model);
} else {
networks[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
setExtraSource(model) {
set((prev) => {
const sources = [...prev.extras.sources];
const exists = sources.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
sources.push(model);
} else {
sources[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
setUpscalingModel(model) {
set((prev) => {
const upscaling = [...prev.extras.upscaling];
const exists = upscaling.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
upscaling.push(model);
} else {
upscaling[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
removeCorrectionModel(model) {
set((prev) => {
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
removeDiffusionModel(model) {
set((prev) => {
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
removeExtraNetwork(model) {
set((prev) => {
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
removeExtraSource(model) {
set((prev) => {
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
removeUpscalingModel(model) {
set((prev) => {
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
return {
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createBlendSlice,
createResetSlice,
createModelSlice,
createProfileSlice,
};
}

View File

@ -4,7 +4,7 @@ import {
ModelParams, ModelParams,
UpscaleParams, UpscaleParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState } from './types.js'; import { DEFAULT_BRUSH, Slice, TabState } from './types.js';
export interface BlendSlice { export interface BlendSlice {
blend: TabState<BlendParams>; blend: TabState<BlendParams>;
@ -19,3 +19,66 @@ export interface BlendSlice {
setBlendModel(model: Partial<ModelParams>): void; setBlendModel(model: Partial<ModelParams>): void;
setBlendUpscale(params: Partial<UpscaleParams>): void; setBlendUpscale(params: Partial<UpscaleParams>): void;
} }
export function createBlendSlice<TState extends BlendSlice>(
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, BlendSlice> {
return (set) => ({
blend: {
// eslint-disable-next-line no-null/no-null
mask: null,
sources: [],
},
blendBrush: {
...DEFAULT_BRUSH,
},
blendModel: {
...defaultModel,
},
blendUpscale: {
...defaultUpscale,
},
resetBlend() {
set((prev) => ({
blend: {
// eslint-disable-next-line no-null/no-null
mask: null,
sources: [] as Array<Blob>,
},
} as Partial<TState>));
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
} as Partial<TState>));
},
setBlendBrush(brush) {
set((prev) => ({
blendBrush: {
...prev.blendBrush,
...brush,
},
} as Partial<TState>));
},
setBlendModel(model) {
set((prev) => ({
blendModel: {
...prev.blendModel,
...model,
},
} as Partial<TState>));
},
setBlendUpscale(params) {
set((prev) => ({
blendUpscale: {
...prev.blendUpscale,
...params,
},
} as Partial<TState>));
},
});
}

View File

@ -1,7 +1,7 @@
import { import {
BaseImgParams, BaseImgParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState, Theme } from './types.js'; import { Slice, TabState, Theme } from './types.js';
export interface DefaultSlice { export interface DefaultSlice {
defaults: TabState<BaseImgParams>; defaults: TabState<BaseImgParams>;
@ -10,3 +10,25 @@ export interface DefaultSlice {
setDefaults(param: Partial<BaseImgParams>): void; setDefaults(param: Partial<BaseImgParams>): void;
setTheme(theme: Theme): void; setTheme(theme: Theme): void;
} }
export function createDefaultSlice<TState extends DefaultSlice>(defaultParams: Required<BaseImgParams>): Slice<TState, DefaultSlice> {
return (set) => ({
defaults: {
...defaultParams,
},
theme: '',
setDefaults(params) {
set((prev) => ({
defaults: {
...prev.defaults,
...params,
}
} as Partial<TState>));
},
setTheme(theme) {
set((prev) => ({
theme,
} as Partial<TState>));
}
});
}

152
gui/src/state/full.ts Normal file
View File

@ -0,0 +1,152 @@
/* eslint-disable camelcase */
import { Maybe } from '@apextoaster/js-utils';
import { Logger } from 'noicejs';
import { createContext } from 'react';
import { StoreApi } from 'zustand';
import {
ApiClient,
} from '../client/base.js';
import { PipelineGrid } from '../client/utils.js';
import { Config, ServerParams } from '../config.js';
import { BlendSlice, createBlendSlice } from './blend.js';
import { DefaultSlice, createDefaultSlice } from './default.js';
import { HistorySlice, createHistorySlice } from './history.js';
import { Img2ImgSlice, createImg2ImgSlice } from './img2img.js';
import { InpaintSlice, createInpaintSlice } from './inpaint.js';
import { ModelSlice, createModelSlice } from './model.js';
import { ProfileSlice, createProfileSlice } from './profile.js';
import { ResetSlice, createResetSlice } from './reset.js';
import { Txt2ImgSlice, createTxt2ImgSlice } from './txt2img.js';
import { UpscaleSlice, createUpscaleSlice } from './upscale.js';
import {
BaseImgParams,
HighresParams,
ModelParams,
UpscaleParams,
} from '../types/params.js';
/**
* Full merged state including all slices.
*/
export type OnnxState
= DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& Txt2ImgSlice
& UpscaleSlice
& BlendSlice
& ResetSlice
& ProfileSlice;
/**
* React context binding for API client.
*/
export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
/**
* React context binding for merged config, including server parameters.
*/
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
/**
* React context binding for bunyan logger.
*/
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
/**
* React context binding for zustand state store.
*/
export const StateContext = createContext<Maybe<StoreApi<OnnxState>>>(undefined);
/**
* Key for zustand persistence, typically local storage.
*/
export const STATE_KEY = 'onnx-web';
/**
* Current state version for zustand persistence.
*/
export const STATE_VERSION = 7;
export const BLEND_SOURCES = 2;
export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgParams> {
return {
batch: defaults.batch.default,
cfg: defaults.cfg.default,
eta: defaults.eta.default,
negativePrompt: defaults.negativePrompt.default,
prompt: defaults.prompt.default,
scheduler: defaults.scheduler.default,
steps: defaults.steps.default,
seed: defaults.seed.default,
tiled_vae: defaults.tiled_vae.default,
unet_overlap: defaults.unet_overlap.default,
unet_tile: defaults.unet_tile.default,
vae_overlap: defaults.vae_overlap.default,
vae_tile: defaults.vae_tile.default,
};
}
/**
* Prepare the state slice constructors.
*
* In the default state, image sources should be null and booleans should be false. Everything
* else should be initialized from the default value in the base parameters.
*/
export function createStateSlices(server: ServerParams) {
const defaultParams = baseParamsFromServer(server);
const defaultHighres: HighresParams = {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
};
const defaultModel: ModelParams = {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
};
const defaultUpscale: UpscaleParams = {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const defaultGrid: PipelineGrid = {
enabled: false,
columns: {
parameter: 'seed',
value: '',
},
rows: {
parameter: 'seed',
value: '',
},
};
return {
createBlendSlice: createBlendSlice(defaultModel, defaultUpscale),
createDefaultSlice: createDefaultSlice(defaultParams),
createHistorySlice: createHistorySlice(),
createImg2ImgSlice: createImg2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
createInpaintSlice: createInpaintSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale),
createModelSlice: createModelSlice(),
createProfileSlice: createProfileSlice(),
createResetSlice: createResetSlice(),
createTxt2ImgSlice: createTxt2ImgSlice(server, defaultParams, defaultHighres, defaultModel, defaultUpscale, defaultGrid),
createUpscaleSlice: createUpscaleSlice(defaultParams, defaultHighres, defaultModel, defaultUpscale),
};
}

View File

@ -1,5 +1,6 @@
import { Maybe } from '@apextoaster/js-utils'; import { Maybe } from '@apextoaster/js-utils';
import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js'; import { ImageResponse, ReadyResponse, RetryParams } from '../types/api.js';
import { DEFAULT_HISTORY, Slice } from './types.js';
export interface HistoryItem { export interface HistoryItem {
image: ImageResponse; image: ImageResponse;
@ -16,3 +17,51 @@ export interface HistorySlice {
setLimit(limit: number): void; setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void; setReady(image: ImageResponse, ready: ReadyResponse): void;
} }
export function createHistorySlice<TState extends HistorySlice>(): Slice<TState, HistorySlice> {
return (set) => ({
history: [],
limit: DEFAULT_HISTORY.limit,
pushHistory(image, retry) {
set((prev) => ({
...prev,
history: [
{
image,
ready: undefined,
retry,
},
...prev.history,
].slice(0, prev.limit + DEFAULT_HISTORY.scrollback),
}));
},
removeHistory(image) {
set((prev) => ({
...prev,
history: prev.history.filter((it) => it.image.outputs[0].key !== image.outputs[0].key),
}));
},
setLimit(limit) {
set((prev) => ({
...prev,
limit,
}));
},
setReady(image, ready) {
set((prev) => {
const history = [...prev.history];
const idx = history.findIndex((it) => it.image.outputs[0].key === image.outputs[0].key);
if (idx >= 0) {
history[idx].ready = ready;
} else {
// TODO: error
}
return {
...prev,
history,
};
});
},
});
}

View File

@ -1,11 +1,13 @@
import { ServerParams } from '../config.js';
import { import {
BaseImgParams,
HighresParams, HighresParams,
Img2ImgParams, Img2ImgParams,
ModelParams, ModelParams,
UpscaleParams, UpscaleParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState } from './types.js'; import { Slice, TabState } from './types.js';
export interface Img2ImgSlice { export interface Img2ImgSlice {
img2img: TabState<Img2ImgParams>; img2img: TabState<Img2ImgParams>;
@ -20,3 +22,77 @@ export interface Img2ImgSlice {
setImg2ImgHighres(params: Partial<HighresParams>): void; setImg2ImgHighres(params: Partial<HighresParams>): void;
setImg2ImgUpscale(params: Partial<UpscaleParams>): void; setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
} }
// eslint-disable-next-line max-params
export function createImg2ImgSlice<TState extends Img2ImgSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams
): Slice<TState, Img2ImgSlice> {
return (set) => ({
img2img: {
...defaultParams,
loopback: server.loopback.default,
// eslint-disable-next-line no-null/no-null
source: null,
sourceFilter: '',
strength: server.strength.default,
},
img2imgHighres: {
...defaultHighres,
},
img2imgModel: {
...defaultModel,
},
img2imgUpscale: {
...defaultUpscale,
},
resetImg2Img() {
set({
img2img: {
...defaultParams,
loopback: server.loopback.default,
// eslint-disable-next-line no-null/no-null
source: null,
sourceFilter: '',
strength: server.strength.default,
},
} as Partial<TState>);
},
setImg2Img(params) {
set((prev) => ({
img2img: {
...prev.img2img,
...params,
},
} as Partial<TState>));
},
setImg2ImgHighres(params) {
set((prev) => ({
img2imgHighres: {
...prev.img2imgHighres,
...params,
},
} as Partial<TState>));
},
setImg2ImgModel(params) {
set((prev) => ({
img2imgModel: {
...prev.img2imgModel,
...params,
},
} as Partial<TState>));
},
setImg2ImgUpscale(params) {
set((prev) => ({
img2imgUpscale: {
...prev.img2imgUpscale,
...params,
},
} as Partial<TState>));
},
});
}

View File

@ -1,4 +1,6 @@
import { ServerParams } from '../config.js';
import { import {
BaseImgParams,
BrushParams, BrushParams,
HighresParams, HighresParams,
InpaintParams, InpaintParams,
@ -6,8 +8,7 @@ import {
OutpaintPixels, OutpaintPixels,
UpscaleParams, UpscaleParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState } from './types.js'; import { DEFAULT_BRUSH, Slice, TabState } from './types.js';
export interface InpaintSlice { export interface InpaintSlice {
inpaint: TabState<InpaintParams>; inpaint: TabState<InpaintParams>;
inpaintBrush: BrushParams; inpaintBrush: BrushParams;
@ -25,3 +26,110 @@ export interface InpaintSlice {
setInpaintUpscale(params: Partial<UpscaleParams>): void; setInpaintUpscale(params: Partial<UpscaleParams>): void;
setOutpaint(pixels: Partial<OutpaintPixels>): void; setOutpaint(pixels: Partial<OutpaintPixels>): void;
} }
// eslint-disable-next-line max-params
export function createInpaintSlice<TState extends InpaintSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, InpaintSlice> {
return (set) => ({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
// eslint-disable-next-line no-null/no-null
mask: null,
noise: server.noise.default,
// eslint-disable-next-line no-null/no-null
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
inpaintBrush: {
...DEFAULT_BRUSH,
},
inpaintHighres: {
...defaultHighres,
},
inpaintModel: {
...defaultModel,
},
inpaintUpscale: {
...defaultUpscale,
},
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
resetInpaint() {
set({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
// eslint-disable-next-line no-null/no-null
mask: null,
noise: server.noise.default,
// eslint-disable-next-line no-null/no-null
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
} as Partial<TState>);
},
setInpaint(params) {
set((prev) => ({
inpaint: {
...prev.inpaint,
...params,
},
} as Partial<TState>));
},
setInpaintBrush(brush) {
set((prev) => ({
inpaintBrush: {
...prev.inpaintBrush,
...brush,
},
} as Partial<TState>));
},
setInpaintHighres(params) {
set((prev) => ({
inpaintHighres: {
...prev.inpaintHighres,
...params,
},
} as Partial<TState>));
},
setInpaintModel(params) {
set((prev) => ({
inpaintModel: {
...prev.inpaintModel,
...params,
},
} as Partial<TState>));
},
setInpaintUpscale(params) {
set((prev) => ({
inpaintUpscale: {
...prev.inpaintUpscale,
...params,
},
} as Partial<TState>));
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
} as Partial<TState>));
},
});
}

202
gui/src/state/model.ts Normal file
View File

@ -0,0 +1,202 @@
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js';
import { MISSING_INDEX, Slice } from './types.js';
export interface ModelSlice {
extras: ExtrasFile;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}
// eslint-disable-next-line sonarjs/cognitive-complexity
export function createModelSlice<TState extends ModelSlice>(): Slice<TState, ModelSlice> {
// eslint-disable-next-line sonarjs/cognitive-complexity
return (set) => ({
extras: {
correction: [],
diffusion: [],
networks: [],
sources: [],
upscaling: [],
},
setExtras(extras) {
set((prev) => ({
...prev,
extras: {
...prev.extras,
...extras,
},
}));
},
setCorrectionModel(model) {
set((prev) => {
const correction = [...prev.extras.correction];
const exists = correction.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
correction.push(model);
} else {
correction[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
setDiffusionModel(model) {
set((prev) => {
const diffusion = [...prev.extras.diffusion];
const exists = diffusion.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
diffusion.push(model);
} else {
diffusion[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
setExtraNetwork(model) {
set((prev) => {
const networks = [...prev.extras.networks];
const exists = networks.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
networks.push(model);
} else {
networks[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
setExtraSource(model) {
set((prev) => {
const sources = [...prev.extras.sources];
const exists = sources.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
sources.push(model);
} else {
sources[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
setUpscalingModel(model) {
set((prev) => {
const upscaling = [...prev.extras.upscaling];
const exists = upscaling.findIndex((it) => model.name === it.name);
if (exists === MISSING_INDEX) {
upscaling.push(model);
} else {
upscaling[exists] = model;
}
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
removeCorrectionModel(model) {
set((prev) => {
const correction = prev.extras.correction.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
correction,
},
};
});
},
removeDiffusionModel(model) {
set((prev) => {
const diffusion = prev.extras.diffusion.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
diffusion,
},
};
});
},
removeExtraNetwork(model) {
set((prev) => {
const networks = prev.extras.networks.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
networks,
},
};
});
},
removeExtraSource(model) {
set((prev) => {
const sources = prev.extras.sources.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
sources,
},
};
});
},
removeUpscalingModel(model) {
set((prev) => {
const upscaling = prev.extras.upscaling.filter((it) => model.name !== it.name);;
return {
...prev,
extras: {
...prev.extras,
upscaling,
},
};
});
},
});
}

View File

@ -1,19 +0,0 @@
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from '../types/model.js';
export interface ModelSlice {
extras: ExtrasFile;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}

View File

@ -1,5 +1,6 @@
import { Maybe } from '@apextoaster/js-utils'; import { Maybe } from '@apextoaster/js-utils';
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js'; import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../types/params.js';
import { Slice } from './types.js';
export interface ProfileItem { export interface ProfileItem {
name: string; name: string;
@ -15,3 +16,37 @@ export interface ProfileSlice {
saveProfile(profile: ProfileItem): void; saveProfile(profile: ProfileItem): void;
} }
export function createProfileSlice<TState extends ProfileSlice>(): Slice<TState, ProfileSlice> {
return (set) => ({
profiles: [],
saveProfile(profile: ProfileItem) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profile.name);
if (idx >= 0) {
profiles[idx] = profile;
} else {
profiles.push(profile);
}
return {
...prev,
profiles,
};
});
},
removeProfile(profileName: string) {
set((prev) => {
const profiles = [...prev.profiles];
const idx = profiles.findIndex((it) => it.name === profileName);
if (idx >= 0) {
profiles.splice(idx, 1);
}
return {
...prev,
profiles,
};
});
}
});
}

View File

@ -1,3 +1,28 @@
import { BlendSlice } from './blend.js';
import { Img2ImgSlice } from './img2img.js';
import { InpaintSlice } from './inpaint.js';
import { Txt2ImgSlice } from './txt2img.js';
import { Slice } from './types.js';
import { UpscaleSlice } from './upscale.js';
export type SlicesWithReset = Txt2ImgSlice & Img2ImgSlice & InpaintSlice & UpscaleSlice & BlendSlice;
export interface ResetSlice { export interface ResetSlice {
resetAll(): void; resetAll(): void;
} }
export function createResetSlice<TState extends ResetSlice & SlicesWithReset>(): Slice<TState, ResetSlice> {
return (set) => ({
resetAll() {
set((prev) => {
const next = { ...prev };
next.resetImg2Img();
next.resetInpaint();
next.resetTxt2Img();
next.resetUpscale();
next.resetBlend();
return next;
});
},
});
}

View File

@ -1,11 +1,13 @@
import { PipelineGrid } from '../client/utils.js'; import { PipelineGrid } from '../client/utils.js';
import { ServerParams } from '../config.js';
import { import {
BaseImgParams,
HighresParams, HighresParams,
ModelParams, ModelParams,
Txt2ImgParams, Txt2ImgParams,
UpscaleParams, UpscaleParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState } from './types.js'; import { Slice, TabState } from './types.js';
export interface Txt2ImgSlice { export interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>; txt2img: TabState<Txt2ImgParams>;
@ -22,3 +24,82 @@ export interface Txt2ImgSlice {
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void; setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
setTxt2ImgVariable(params: Partial<PipelineGrid>): void; setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
} }
// eslint-disable-next-line max-params
export function createTxt2ImgSlice<TState extends Txt2ImgSlice>(
server: ServerParams,
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
defaultGrid: PipelineGrid,
): Slice<TState, Txt2ImgSlice> {
return (set) => ({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
txt2imgHighres: {
...defaultHighres,
},
txt2imgModel: {
...defaultModel,
},
txt2imgUpscale: {
...defaultUpscale,
},
txt2imgVariable: {
...defaultGrid,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
...prev.txt2img,
...params,
},
} as Partial<TState>));
},
setTxt2ImgHighres(params) {
set((prev) => ({
txt2imgHighres: {
...prev.txt2imgHighres,
...params,
},
} as Partial<TState>));
},
setTxt2ImgModel(params) {
set((prev) => ({
txt2imgModel: {
...prev.txt2imgModel,
...params,
},
} as Partial<TState>));
},
setTxt2ImgUpscale(params) {
set((prev) => ({
txt2imgUpscale: {
...prev.txt2imgUpscale,
...params,
},
} as Partial<TState>));
},
setTxt2ImgVariable(params) {
set((prev) => ({
txt2imgVariable: {
...prev.txt2imgVariable,
...params,
},
} as Partial<TState>));
},
resetTxt2Img() {
set({
txt2img: {
...defaultParams,
width: server.width.default,
height: server.height.default,
},
} as Partial<TState>);
},
});
}

View File

@ -1,4 +1,5 @@
import { PaletteMode } from '@mui/material'; import { PaletteMode } from '@mui/material';
import { StateCreator } from 'zustand';
import { ConfigFiles, ConfigState } from '../config.js'; import { ConfigFiles, ConfigState } from '../config.js';
export const MISSING_INDEX = -1; export const MISSING_INDEX = -1;
@ -9,3 +10,37 @@ export type Theme = PaletteMode | ''; // tri-state, '' is unset
* Combine optional files and required ranges. * Combine optional files and required ranges.
*/ */
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>; export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
/**
* Shorthand for state creator to reduce repeated arguments.
*/
export type Slice<TState, TValue> = StateCreator<TState, [], [], TValue>;
/**
* Default parameters for the inpaint brush.
*
* Not provided by the server yet.
*/
export const DEFAULT_BRUSH = {
color: 255,
size: 8,
strength: 0.5,
};
/**
* Default parameters for the image history.
*
* Not provided by the server yet.
*/
export const DEFAULT_HISTORY = {
/**
* The number of images to be shown.
*/
limit: 4,
/**
* The number of additional images to be kept in history, so they can scroll
* back into view when you delete one. Does not include deleted images.
*/
scrollback: 2,
};

View File

@ -1,10 +1,11 @@
import { import {
BaseImgParams,
HighresParams, HighresParams,
ModelParams, ModelParams,
UpscaleParams, UpscaleParams,
UpscaleReqParams, UpscaleReqParams,
} from '../types/params.js'; } from '../types/params.js';
import { TabState } from './types.js'; import { Slice, TabState } from './types.js';
export interface UpscaleSlice { export interface UpscaleSlice {
upscale: TabState<UpscaleReqParams>; upscale: TabState<UpscaleReqParams>;
@ -19,3 +20,68 @@ export interface UpscaleSlice {
setUpscaleModel(params: Partial<ModelParams>): void; setUpscaleModel(params: Partial<ModelParams>): void;
setUpscaleUpscale(params: Partial<UpscaleParams>): void; setUpscaleUpscale(params: Partial<UpscaleParams>): void;
} }
export function createUpscaleSlice<TState extends UpscaleSlice>(
defaultParams: Required<BaseImgParams>,
defaultHighres: HighresParams,
defaultModel: ModelParams,
defaultUpscale: UpscaleParams,
): Slice<TState, UpscaleSlice> {
return (set) => ({
upscale: {
...defaultParams,
// eslint-disable-next-line no-null/no-null
source: null,
},
upscaleHighres: {
...defaultHighres,
},
upscaleModel: {
...defaultModel,
},
upscaleUpscale: {
...defaultUpscale,
},
resetUpscale() {
set({
upscale: {
...defaultParams,
// eslint-disable-next-line no-null/no-null
source: null,
},
} as Partial<TState>);
},
setUpscale(source) {
set((prev) => ({
upscale: {
...prev.upscale,
...source,
},
} as Partial<TState>));
},
setUpscaleHighres(params) {
set((prev) => ({
upscaleHighres: {
...prev.upscaleHighres,
...params,
},
} as Partial<TState>));
},
setUpscaleModel(params) {
set((prev) => ({
upscaleModel: {
...prev.upscaleModel,
...defaultModel,
},
} as Partial<TState>));
},
setUpscaleUpscale(params) {
set((prev) => ({
upscaleUpscale: {
...prev.upscaleUpscale,
...params,
},
} as Partial<TState>));
},
});
}