1
0
Fork 0

feat(gui): move model controls into each tab

This commit is contained in:
Sean Sube 2023-07-21 17:38:01 -05:00
parent 27a21dfa62
commit f14f197264
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
17 changed files with 686 additions and 539 deletions

View File

@ -140,9 +140,7 @@ export interface UpscaleParams {
/**
* Parameters for upscale requests.
*/
export interface UpscaleReqParams {
prompt: string;
negativePrompt?: string;
export interface UpscaleReqParams extends BaseImgParams {
source: Blob;
}

View File

@ -8,7 +8,6 @@ import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { StateContext } from '../state.js';
import { ModelControl } from './control/ModelControl.js';
import { ImageHistory } from './ImageHistory.js';
import { Logo } from './Logo.js';
import { Blend } from './tab/Blend.js';
@ -43,9 +42,6 @@ export function OnnxWeb() {
<Box sx={{ my: 4 }}>
<Logo />
</Box>
<Box sx={{ mx: 4, my: 4 }}>
<ModelControl />
</Box>
<TabContext value={getTab(hash)}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
<TabList onChange={(_e, idx) => {

View File

@ -20,29 +20,35 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { BaseImgParams, Txt2ImgParams } from '../client/types.js';
import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../client/types.js';
import { StateContext } from '../state.js';
const { useState, Fragment } = React;
export interface ProfilesProps {
highres: HighresParams;
params: BaseImgParams;
setParams: ((params: BaseImgParams) => void) | undefined;
upscale: UpscaleParams;
setHighres(params: HighresParams): void;
setParams(params: BaseImgParams): void;
setUpscale(params: UpscaleParams): void;
}
export function Profiles(props: ProfilesProps) {
const state = mustExist(useContext(StateContext));
const profiles = useStore(state, (s) => s.profiles);
// eslint-disable-next-line @typescript-eslint/unbound-method
const saveProfile = useStore(state, (s) => s.saveProfile);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeProfile = useStore(state, (s) => s.removeProfile);
const profiles = useStore(state, (s) => s.profiles);
const highres = useStore(state, (s) => s.highres);
const upscale = useStore(state, (s) => s.upscale);
const [dialogOpen, setDialogOpen] = React.useState(false);
const [profileName, setProfileName] = React.useState('');
const [dialogOpen, setDialogOpen] = useState(false);
const [profileName, setProfileName] = useState('');
const { t } = useTranslation();
return <>
return <Stack direction='row' spacing={2}>
<Autocomplete
id="profile-select"
options={profiles}
@ -77,36 +83,10 @@ export function Profiles(props: ProfilesProps) {
<Button type="button" variant="contained" onClick={() => setDialogOpen(true)}>
<SaveIcon />
</Button>
<Button component='label' variant="contained">
<ImageSearch />
<input
hidden
accept={'.json,.jpg,.jpeg,.png,.txt,.webp'}
type='file'
onChange={(event) => {
const { files } = event.target;
if (doesExist(files) && files.length > 0) {
const file = mustExist(files[0]);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
loadParamsFromFile(file).then((newParams) => {
if (doesExist(props.setParams) && doesExist(newParams)) {
props.setParams({
...props.params,
...newParams,
});
}
});
}
}}
onClick={(event) => {
event.currentTarget.value = '';
}}
/>
</Button>
</Stack>
)}
onChange={(event, value) => {
if (doesExist(value) && doesExist(props.setParams)) {
if (doesExist(value)) {
props.setParams({
...value.params
});
@ -138,8 +118,8 @@ export function Profiles(props: ProfilesProps) {
saveProfile({
params: props.params,
name: profileName,
highResParams: highres,
upscaleParams: upscale,
highResParams: props.highres,
upscaleParams: props.upscale,
});
setDialogOpen(false);
setProfileName('');
@ -147,7 +127,33 @@ export function Profiles(props: ProfilesProps) {
>{t('profile.save')}</Button>
</DialogActions>
</Dialog>
</>;
<Button component='label' variant="contained">
<ImageSearch />
<input
hidden
accept={'.json,.jpg,.jpeg,.png,.txt,.webp'}
type='file'
onChange={(event) => {
const { files } = event.target;
if (doesExist(files) && files.length > 0) {
const file = mustExist(files[0]);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
loadParamsFromFile(file).then((newParams) => {
if (doesExist(newParams)) {
props.setParams({
...props.params,
...newParams,
});
}
});
}
}}
onClick={(event) => {
event.currentTarget.value = '';
}}
/>
</Button>
</Stack>;
}
export async function loadParamsFromFile(file: File): Promise<Partial<Txt2ImgParams>> {

View File

@ -38,7 +38,7 @@ export function ImageCard(props: ImageCardProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscaleTab);
const setUpscale = useStore(state, (s) => s.setUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);

View File

@ -3,17 +3,21 @@ import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select,
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, StateContext } from '../../state.js';
import { HighresParams } from '../../client/types.js';
import { ConfigContext } from '../../state.js';
import { NumericField } from '../input/NumericField.js';
export function HighresControl() {
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const highres = useStore(state, (s) => s.highres);
export interface HighresControlProps {
highres: HighresParams;
setHighres(params: Partial<HighresParams>): void;
}
export function HighresControl(props: HighresControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setHighres);
const { highres, setHighres } = props;
const { params } = mustExist(useContext(ConfigContext));
const { t } = useTranslation();
return <Stack direction='row' spacing={4}>
@ -22,7 +26,7 @@ export function HighresControl() {
control={<Checkbox
checked={highres.enabled}
value='check'
onChange={(event) => {
onChange={(_event) => {
setHighres({
enabled: highres.enabled === false,
});

View File

@ -13,21 +13,21 @@ import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../sta
import { NumericField } from '../input/NumericField.js';
import { PromptInput } from '../input/PromptInput.js';
import { QueryList } from '../input/QueryList.js';
import { Profiles } from '../Profiles.js';
export interface ImageControlProps {
selector: (state: OnnxState) => BaseImgParams;
onChange?: (params: BaseImgParams) => void;
onChange(params: BaseImgParams): void;
selector(state: OnnxState): BaseImgParams;
}
/**
* Doesn't need to use state directly, the parent component knows which params to pass
*/
export function ImageControl(props: ImageControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { onChange, selector } = props;
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const controlState = useStore(state, props.selector);
const controlState = useStore(state, selector);
const { t } = useTranslation();
const client = mustExist(useContext(ClientContext));
@ -40,7 +40,6 @@ export function ImageControl(props: ImageControlProps) {
return <Stack spacing={2}>
<Stack direction='row' spacing={4}>
<Profiles params={controlState} setParams={props.onChange} />
<QueryList
id='schedulers'
labelKey='scheduler'
@ -50,8 +49,8 @@ export function ImageControl(props: ImageControlProps) {
}}
value={mustDefault(controlState.scheduler, '')}
onChange={(value) => {
if (doesExist(props.onChange)) {
props.onChange({
if (doesExist(onChange)) {
onChange({
...controlState,
scheduler: value,
});
@ -66,8 +65,8 @@ export function ImageControl(props: ImageControlProps) {
step={params.eta.step}
value={controlState.eta}
onChange={(eta) => {
if (doesExist(props.onChange)) {
props.onChange({
if (doesExist(onChange)) {
onChange({
...controlState,
eta,
});
@ -82,8 +81,8 @@ export function ImageControl(props: ImageControlProps) {
step={params.cfg.step}
value={controlState.cfg}
onChange={(cfg) => {
if (doesExist(props.onChange)) {
props.onChange({
if (doesExist(onChange)) {
onChange({
...controlState,
cfg,
});
@ -97,12 +96,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.steps.step}
value={controlState.steps}
onChange={(steps) => {
if (doesExist(props.onChange)) {
props.onChange({
onChange({
...controlState,
steps,
});
}
}}
/>
<NumericField
@ -112,12 +109,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.seed.step}
value={controlState.seed}
onChange={(seed) => {
if (doesExist(props.onChange)) {
props.onChange({
onChange({
...controlState,
seed,
});
}
}}
/>
<Button
@ -125,12 +120,10 @@ export function ImageControl(props: ImageControlProps) {
startIcon={<Casino />}
onClick={() => {
const seed = Math.floor(Math.random() * params.seed.max);
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
seed,
});
}
}}
>
{t('parameter.newSeed')}
@ -144,12 +137,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.batch.step}
value={controlState.batch}
onChange={(batch) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
batch,
});
}
}}
/>
<NumericField
@ -159,12 +150,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.tiles.step}
value={controlState.tiles}
onChange={(tiles) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
tiles,
});
}
}}
/>
<NumericField
@ -175,12 +164,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.overlap.step}
value={controlState.overlap}
onChange={(overlap) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
overlap,
});
}
}}
/>
<NumericField
@ -190,12 +177,10 @@ export function ImageControl(props: ImageControlProps) {
step={params.stride.step}
value={controlState.stride}
onChange={(stride) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
stride,
});
}
}}
/>
<FormControlLabel
@ -204,12 +189,10 @@ export function ImageControl(props: ImageControlProps) {
checked={controlState.tiledVAE}
value='check'
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
tiledVAE: controlState.tiledVAE === false,
});
}
}}
/>}
/>
@ -218,12 +201,10 @@ export function ImageControl(props: ImageControlProps) {
prompt={controlState.prompt}
negativePrompt={controlState.negativePrompt}
onChange={(value) => {
if (doesExist(props.onChange)) {
props.onChange({
...controlState,
...value,
});
}
}}
/>
</Stack>;

View File

@ -4,25 +4,25 @@ import { useMutation, useQuery } from '@tanstack/react-query';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { ModelParams } from '../../client/types.js';
import { STALE_TIME } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { QueryList } from '../input/QueryList.js';
import { QueryMenu } from '../input/QueryMenu.js';
import { getTab } from '../utils.js';
export function ModelControl() {
export interface ModelControlProps {
model: ModelParams;
setModel(params: Partial<ModelParams>): void;
}
export function ModelControl(props: ModelControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { model, setModel } = props;
const client = mustExist(useContext(ClientContext));
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.model);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation();
const [hash, _setHash] = useHash();
const restart = useMutation(['restart'], async () => client.restart());
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
@ -34,32 +34,7 @@ export function ModelControl() {
staleTime: STALE_TIME,
});
function addToken(type: string, name: string, weight = 1.0) {
const tab = getTab(hash);
const current = state.getState();
switch (tab) {
case 'txt2img': {
const { prompt } = current.txt2img;
current.setTxt2Img({
prompt: `<${type}:${name}:1.0> ${prompt}`,
});
break;
}
case 'img2img': {
const { prompt } = current.img2img;
current.setImg2Img({
prompt: `<${type}:${name}:1.0> ${prompt}`,
});
break;
}
default:
// not supported yet
}
}
return <Stack direction='column' spacing={2}>
<Stack direction='row' spacing={2}>
return <Stack direction='row' spacing={2}>
<QueryList
id='platforms'
labelKey='platform'
@ -67,7 +42,7 @@ export function ModelControl() {
query={{
result: platforms,
}}
value={params.platform}
value={model.platform}
onChange={(platform) => {
setModel({
platform,
@ -81,7 +56,7 @@ export function ModelControl() {
query={{
result: pipelines,
}}
value={params.pipeline}
value={model.pipeline}
onChange={(pipeline) => {
setModel({
pipeline,
@ -91,27 +66,27 @@ export function ModelControl() {
<QueryList
id='diffusion'
labelKey='model'
name={t('modelType.diffusion', {count: 1})}
name={t('modelType.diffusion', { count: 1 })}
query={{
result: models,
selector: (result) => result.diffusion,
}}
value={params.model}
onChange={(model) => {
value={model.model}
onChange={(newModel) => {
setModel({
model,
model: newModel,
});
}}
/>
<QueryList
id='upscaling'
labelKey='model'
name={t('modelType.upscaling', {count: 1})}
name={t('modelType.upscaling', { count: 1 })}
query={{
result: models,
selector: (result) => result.upscaling,
}}
value={params.upscaling}
value={model.upscaling}
onChange={(upscaling) => {
setModel({
upscaling,
@ -121,48 +96,21 @@ export function ModelControl() {
<QueryList
id='correction'
labelKey='model'
name={t('modelType.correction', {count: 1})}
name={t('modelType.correction', { count: 1 })}
query={{
result: models,
selector: (result) => result.correction,
}}
value={params.correction}
value={model.correction}
onChange={(correction) => {
setModel({
correction,
});
}}
/>
</Stack>
<Stack direction='row' spacing={2}>
<QueryMenu
id='inversion'
labelKey='model.inversion'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
onSelect={(name) => {
addToken('inversion', name);
}}
/>
<QueryMenu
id='lora'
labelKey='model.lora'
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
onSelect={(name) => {
addToken('lora', name);
}}
/>
<Button
variant='outlined'
onClick={() => restart.mutate()}
>{t('admin.restart')}</Button>
</Stack>
</Stack>;
}

View File

@ -3,17 +3,21 @@ import { Checkbox, FormControl, FormControlLabel, InputLabel, MenuItem, Select,
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, StateContext } from '../../state.js';
import { UpscaleParams } from '../../client/types.js';
import { ConfigContext } from '../../state.js';
import { NumericField } from '../input/NumericField.js';
export function UpscaleControl() {
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const upscale = useStore(state, (s) => s.upscale);
export interface UpscaleControlProps {
upscale: UpscaleParams;
setUpscale(params: Partial<UpscaleParams>): void;
}
export function UpscaleControl(props: UpscaleControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscale);
const { upscale, setUpscale } = props;
const { params } = mustExist(useContext(ConfigContext));
const { t } = useTranslation();
return <Stack direction='row' spacing={4}>
@ -22,7 +26,7 @@ export function UpscaleControl() {
control={<Checkbox
checked={upscale.enabled}
value='check'
onChange={(event) => {
onChange={(_event) => {
setUpscale({
enabled: upscale.enabled === false,
});

View File

@ -4,8 +4,8 @@ import { Button, Stack, Typography } from '@mui/material';
import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { BrushParams } from '../../client/types.js';
import { SAVE_TIME } from '../../config.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
import { imageFromBlob } from '../../utils.js';
@ -36,14 +36,17 @@ export interface Point {
}
export interface MaskCanvasProps {
brush: BrushParams;
source?: Maybe<Blob>;
mask?: Maybe<Blob>;
onSave: (blob: Blob) => void;
onSave(blob: Blob): void;
setBrush(brush: Partial<BrushParams>): void;
}
export function MaskCanvas(props: MaskCanvasProps) {
const { source, mask } = props;
// eslint-disable-next-line @typescript-eslint/unbound-method
const { source, mask, brush, setBrush } = props;
const { params } = mustExist(useContext(ConfigContext));
const logger = mustExist(useContext(LoggerContext));
@ -202,9 +205,6 @@ export function MaskCanvas(props: MaskCanvasProps) {
});
const state = mustExist(useContext(StateContext));
const brush = useStore(state, (s) => s.brush);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBrush = useStore(state, (s) => s.setBrush);
const { t } = useTranslation();
useEffect(() => {

View File

@ -1,16 +1,23 @@
import { doesExist, Maybe } from '@apextoaster/js-utils';
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material';
import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
import { useTranslation } from 'react-i18next';
import { QueryMenu } from '../input/QueryMenu.js';
import { STALE_TIME } from '../../config.js';
import { ClientContext } from '../../state.js';
const { useContext } = React;
export interface PromptValue {
prompt: string;
negativePrompt?: string;
}
export interface PromptInputProps extends PromptValue {
onChange?: Maybe<(value: PromptValue) => void>;
onChange: (value: PromptValue) => void;
}
export const PROMPT_GROUP = 75;
@ -29,12 +36,24 @@ export function PromptInput(props: PromptInputProps) {
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
tokens: tokens.length,
});
function addToken(type: string, name: string, weight = 1.0) {
props.onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt,
});
}
return <Stack spacing={2}>
<TextField
label={t('parameter.prompt')}
@ -63,5 +82,31 @@ export function PromptInput(props: PromptInputProps) {
}
}}
/>
<Stack direction='row' spacing={2}>
<QueryMenu
id='inversion'
labelKey='model.inversion'
name={t('modelType.inversion')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'inversion').map((network) => network.name),
}}
onSelect={(name) => {
addToken('inversion', name);
}}
/>
<QueryMenu
id='lora'
labelKey='model.lora'
name={t('modelType.lora')}
query={{
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'lora').map((network) => network.name),
}}
onSelect={(name) => {
addToken('lora', name);
}}
/>
</Stack>
</Stack>;
}

View File

@ -15,12 +15,12 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { model, blend, upscale } = state.getState();
const { image, retry } = await client.blend(model, {
const { blend, blendModel, blendUpscale } = state.getState();
const { image, retry } = await client.blend(blendModel, {
...blend,
mask: mustExist(blend.mask),
sources: mustExist(blend.sources), // TODO: show an error if this doesn't exist
}, upscale);
}, blendUpscale);
pushHistory(image, retry);
}
@ -32,10 +32,16 @@ export function Blend() {
});
const state = mustExist(useContext(StateContext));
const brush = useStore(state, (s) => s.blendBrush);
const blend = useStore(state, (s) => s.blend);
const upscale = useStore(state, (s) => s.blendUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBrush = useStore(state, (s) => s.setBlendBrush);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setBlendUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const { t } = useTranslation();
@ -61,6 +67,7 @@ export function Blend() {
/>
)}
<MaskCanvas
brush={brush}
source={sources[0]}
mask={blend.mask}
onSave={(mask) => {
@ -68,8 +75,9 @@ export function Blend() {
mask,
});
}}
setBrush={setBrush}
/>
<UpscaleControl />
<UpscaleControl upscale={upscale} setUpscale={setUpscale} />
<Button
disabled={sources.length < 2}
variant='contained'

View File

@ -14,12 +14,13 @@ import { ImageInput } from '../input/ImageInput.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
import { HighresControl } from '../control/HighresControl.js';
import { ModelControl } from '../control/ModelControl.js';
import { Profiles } from '../Profiles.js';
export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function uploadSource() {
const { model, img2img, upscale, highres } = state.getState();
const { image, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
@ -42,21 +43,27 @@ export function Img2Img() {
});
const state = mustExist(useContext(StateContext));
const control = useStore(state, (s) => s.model.control);
const model = useStore(state, (s) => s.img2imgModel);
const source = useStore(state, (s) => s.img2img.source);
const sourceFilter = useStore(state, (s) => s.img2img.sourceFilter);
const strength = useStore(state, (s) => s.img2img.strength);
const loopback = useStore(state, (s) => s.img2img.loopback);
const img2img = useStore(state, (s) => s.img2img);
const highres = useStore(state, (s) => s.img2imgHighres);
const upscale = useStore(state, (s) => s.img2imgUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setModel);
const setHighres = useStore(state, (s) => s.setImg2ImgHighres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setImg2ImgUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setImg2ImgModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const { t } = useTranslation();
return <Box>
<Stack spacing={2}>
<Profiles params={img2img} setParams={setImg2Img} highres={highres} setHighres={setHighres} upscale={upscale} setUpscale={setUpscale} />
<ModelControl model={model} setModel={setModel} />
<ImageInput
filter={IMAGE_FILTER}
image={source}
@ -77,7 +84,7 @@ export function Img2Img() {
result: models,
selector: (result) => result.networks.filter((network) => network.type === 'control').map((network) => network.name),
}}
value={control}
value={model.control}
onChange={(newControl) => {
setModel({
control: newControl,
@ -93,7 +100,7 @@ export function Img2Img() {
selector: (f) => f.source,
}}
showNone
value={sourceFilter}
value={img2img.sourceFilter}
onChange={(newFilter) => {
setImg2Img({
sourceFilter: newFilter,
@ -106,7 +113,7 @@ export function Img2Img() {
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
value={img2img.strength}
onChange={(value) => {
setImg2Img({
strength: value,
@ -118,7 +125,7 @@ export function Img2Img() {
min={params.loopback.min}
max={params.loopback.max}
step={params.loopback.step}
value={loopback}
value={img2img.loopback}
onChange={(value) => {
setImg2Img({
loopback: value,
@ -126,8 +133,8 @@ export function Img2Img() {
}}
/>
</Stack>
<HighresControl />
<UpscaleControl />
<HighresControl highres={highres} setHighres={setHighres} />
<UpscaleControl upscale={upscale} setUpscale={setUpscale} />
<Button
disabled={doesExist(source) === false}
variant='contained'

View File

@ -1,21 +1,23 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Alert, Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER, STALE_TIME } from '../../config.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { HighresControl } from '../control/HighresControl.js';
import { ImageControl } from '../control/ImageControl.js';
import { ModelControl } from '../control/ModelControl.js';
import { OutpaintControl } from '../control/OutpaintControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { MaskCanvas } from '../input/MaskCanvas.js';
import { NumericField } from '../input/NumericField.js';
import { QueryList } from '../input/QueryList.js';
import { HighresControl } from '../control/HighresControl.js';
import { Profiles } from '../Profiles.js';
export function Inpaint() {
const { params } = mustExist(useContext(ConfigContext));
@ -29,9 +31,6 @@ export function Inpaint() {
});
async function uploadSource(): Promise<void> {
// these are not watched by the component, only sent by the mutation
const { model, inpaint, outpaint, upscale, highres } = state.getState();
if (outpaint.enabled) {
const { image, retry } = await client.outpaint(model, {
...inpaint,
@ -57,21 +56,31 @@ export function Inpaint() {
}
function supportsInpaint(): boolean {
return diffusionModel.includes('inpaint');
return model.model.includes('inpaint');
}
const state = mustExist(useContext(StateContext));
const fillColor = useStore(state, (s) => s.inpaint.fillColor);
const filter = useStore(state, (s) => s.inpaint.filter);
const noise = useStore(state, (s) => s.inpaint.noise);
const mask = useStore(state, (s) => s.inpaint.mask);
const source = useStore(state, (s) => s.inpaint.source);
const strength = useStore(state, (s) => s.inpaint.strength);
const tileOrder = useStore(state, (s) => s.inpaint.tileOrder);
const diffusionModel = useStore(state, (s) => s.model.model);
const inpaint = useStore(state, (s) => s.inpaint);
const outpaint = useStore(state, (s) => s.outpaint);
const brush = useStore(state, (s) => s.inpaintBrush);
const highres = useStore(state, (s) => s.inpaintHighres);
const model = useStore(state, (s) => s.inpaintModel);
const upscale = useStore(state, (s) => s.inpaintUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBrush = useStore(state, (s) => s.setInpaintBrush);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setInpaintModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setInpaintHighres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setInpaintUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const { t } = useTranslation();
@ -91,6 +100,8 @@ export function Inpaint() {
return <Box>
<Stack spacing={2}>
<Profiles params={inpaint} setParams={setInpaint} highres={highres} setHighres={setHighres} upscale={upscale} setUpscale={setUpscale} />
<ModelControl model={model} setModel={setModel} />
{renderBanner()}
<ImageInput
filter={IMAGE_FILTER}
@ -115,6 +126,7 @@ export function Inpaint() {
}}
/>
<MaskCanvas
brush={brush}
source={source}
mask={mask}
onSave={(file) => {
@ -122,6 +134,7 @@ export function Inpaint() {
mask: file,
});
}}
setBrush={setBrush}
/>
<ImageControl
selector={(s) => s.inpaint}
@ -134,7 +147,7 @@ export function Inpaint() {
min={params.strength.min}
max={params.strength.max}
step={params.strength.step}
value={strength}
value={inpaint.strength}
onChange={(value) => {
setInpaint({
strength: value,
@ -150,7 +163,7 @@ export function Inpaint() {
result: filters,
selector: (f) => f.mask,
}}
value={filter}
value={inpaint.filter}
onChange={(newFilter) => {
setInpaint({
filter: newFilter,
@ -164,7 +177,7 @@ export function Inpaint() {
query={{
result: noises,
}}
value={noise}
value={inpaint.noise}
onChange={(newNoise) => {
setInpaint({
noise: newNoise,
@ -176,7 +189,7 @@ export function Inpaint() {
<Select
labelId={'outpaint-tiling'}
label={t('parameter.tileOrder')}
value={tileOrder}
value={inpaint.tileOrder}
onChange={(e) => {
setInpaint({
tileOrder: e.target.value,
@ -194,7 +207,7 @@ export function Inpaint() {
sx={{ mx: 1 }}
control={
<input
defaultValue={fillColor}
defaultValue={inpaint.fillColor}
name='fill-color'
type='color'
onBlur={(event) => {
@ -208,8 +221,8 @@ export function Inpaint() {
</Stack>
</Stack>
<OutpaintControl />
<HighresControl />
<UpscaleControl />
<HighresControl highres={highres} setHighres={setHighres} />
<UpscaleControl upscale={upscale} setUpscale={setUpscale} />
<Button
disabled={preventInpaint()}
variant='contained'

View File

@ -11,12 +11,13 @@ import { HighresControl } from '../control/HighresControl.js';
import { ImageControl } from '../control/ImageControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { NumericField } from '../input/NumericField.js';
import { ModelControl } from '../control/ModelControl.js';
import { Profiles } from '../Profiles.js';
export function Txt2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function generateImage() {
const { model, txt2img, upscale, highres } = state.getState();
const { image, retry } = await client.txt2img(model, txt2img, upscale, highres);
pushHistory(image, retry);
@ -29,26 +30,36 @@ export function Txt2Img() {
});
const state = mustExist(useContext(StateContext));
const height = useStore(state, (s) => s.txt2img.height);
const width = useStore(state, (s) => s.txt2img.width);
const txt2img = useStore(state, (s) => s.txt2img);
const model = useStore(state, (s) => s.txt2imgModel);
const highres = useStore(state, (s) => s.txt2imgHighres);
const upscale = useStore(state, (s) => s.txt2imgUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setTxt2Img = useStore(state, (s) => s.setTxt2Img);
const setParams = useStore(state, (s) => s.setTxt2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setTxt2ImgHighres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setTxt2ImgUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setTxt2ImgModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const { t } = useTranslation();
return <Box>
<Stack spacing={2}>
<ImageControl selector={(s) => s.txt2img} onChange={setTxt2Img} />
<Profiles params={txt2img} setParams={setParams} highres={highres} setHighres={setHighres} upscale={upscale} setUpscale={setUpscale} />
<ModelControl model={model} setModel={setModel} />
<ImageControl selector={(s) => s.txt2img} onChange={setParams} />
<Stack direction='row' spacing={4}>
<NumericField
label={t('parameter.width')}
min={params.width.min}
max={params.width.max}
step={params.width.step}
value={width}
value={txt2img.width}
onChange={(value) => {
setTxt2Img({
setParams({
width: value,
});
}}
@ -58,16 +69,16 @@ export function Txt2Img() {
min={params.height.min}
max={params.height.max}
step={params.height.step}
value={height}
value={txt2img.height}
onChange={(value) => {
setTxt2Img({
setParams({
height: value,
});
}}
/>
</Stack>
<HighresControl />
<UpscaleControl />
<HighresControl highres={highres} setHighres={setHighres} />
<UpscaleControl upscale={upscale} setUpscale={setUpscale} />
<Button
variant='contained'
onClick={() => generate.mutate()}

View File

@ -1,25 +1,27 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { HighresControl } from '../control/HighresControl.js';
import { ModelControl } from '../control/ModelControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { PromptInput } from '../input/PromptInput.js';
import { HighresControl } from '../control/HighresControl.js';
import { Profiles } from '../Profiles.js';
export function Upscale() {
async function uploadSource() {
const { highres, model, upscale } = state.getState();
const { image, retry } = await client.upscale(model, {
...params,
source: mustExist(params.source), // TODO: show an error if this doesn't exist
}, upscale, highres);
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = state.getState();
const { image, retry } = await client.upscale(upscaleModel, {
...upscale,
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
}, upscaleUpscale, upscaleHighres);
pushHistory(image, retry);
}
@ -31,21 +33,32 @@ export function Upscale() {
});
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.upscaleTab);
const highres = useStore(state, (s) => s.upscaleHighres);
const model = useStore(state, (s) => s.upscaleModel);
const params = useStore(state, (s) => s.upscale);
const upscale = useStore(state, (s) => s.upscaleUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setSource = useStore(state, (s) => s.setUpscaleTab);
const setModel = useStore(state, (s) => s.setUpscalingModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setUpscaleHighres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscaleUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setParams = useStore(state, (s) => s.setUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const { t } = useTranslation();
return <Box>
<Stack spacing={2}>
<Profiles params={params} setParams={setParams} highres={highres} setHighres={setHighres} upscale={upscale} setUpscale={setUpscale} />
<ModelControl model={model} setModel={setModel} />
<ImageInput
filter={IMAGE_FILTER}
image={params.source}
label={t('input.image.source')}
onChange={(file) => {
setSource({
setParams({
source: file,
});
}}
@ -54,11 +67,11 @@ export function Upscale() {
prompt={params.prompt}
negativePrompt={params.negativePrompt}
onChange={(value) => {
setSource(value);
setParams(value);
}}
/>
<HighresControl />
<UpscaleControl />
<HighresControl highres={highres} setHighres={setHighres} />
<UpscaleControl upscale={upscale} setUpscale={setUpscale} />
<Button
disabled={doesExist(params.source) === false}
variant='contained'

View File

@ -47,35 +47,27 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
// prep zustand with a slice for each tab, using local storage
const {
createBrushSlice,
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createModelSlice,
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createHighresSlice,
createBlendSlice,
createResetSlice,
createExtraSlice,
createProfileSlice,
} = createStateSlices(params);
const state = createStore<OnnxState, [['zustand/persist', OnnxState]]>(persist((...slice) => ({
...createBrushSlice(...slice),
...createDefaultSlice(...slice),
...createHistorySlice(...slice),
...createImg2ImgSlice(...slice),
...createInpaintSlice(...slice),
...createModelSlice(...slice),
...createTxt2ImgSlice(...slice),
...createOutpaintSlice(...slice),
...createUpscaleSlice(...slice),
...createHighresSlice(...slice),
...createBlendSlice(...slice),
...createResetSlice(...slice),
...createExtraSlice(...slice),
...createProfileSlice(...slice),
}), {
name: STATE_KEY,
@ -91,8 +83,8 @@ export async function renderApp(config: Config, params: ServerParams, logger: Lo
mask: undefined,
source: undefined,
},
upscaleTab: {
...s.upscaleTab,
upscale: {
...s.upscale,
source: undefined,
},
blend: {

View File

@ -48,12 +48,6 @@ interface ProfileItem {
upscaleParams?: Maybe<UpscaleParams>;
}
interface BrushSlice {
brush: BrushParams;
setBrush(brush: Partial<BrushParams>): void;
}
interface DefaultSlice {
defaults: TabState<BaseImgParams>;
theme: Theme;
@ -62,24 +56,6 @@ interface DefaultSlice {
setTheme(theme: Theme): void;
}
interface ExtraSlice {
extras: ExtrasFile;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
}
interface HistorySlice {
history: Array<HistoryItem>;
limit: number;
@ -91,60 +67,96 @@ interface HistorySlice {
}
interface ModelSlice {
model: ModelParams;
extras: ExtrasFile;
setModel(model: Partial<ModelParams>): void;
removeCorrectionModel(model: CorrectionModel): void;
removeDiffusionModel(model: DiffusionModel): void;
removeExtraNetwork(model: ExtraNetwork): void;
removeExtraSource(model: ExtraSource): void;
removeUpscalingModel(model: UpscalingModel): void;
setExtras(extras: Partial<ExtrasFile>): void;
setCorrectionModel(model: CorrectionModel): void;
setDiffusionModel(model: DiffusionModel): void;
setExtraNetwork(model: ExtraNetwork): void;
setExtraSource(model: ExtraSource): void;
setUpscalingModel(model: UpscalingModel): void;
}
// #region tab slices
interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>;
txt2imgModel: ModelParams;
txt2imgHighres: HighresParams;
txt2imgUpscale: UpscaleParams;
resetTxt2Img(): void;
setTxt2Img(params: Partial<Txt2ImgParams>): void;
resetTxt2Img(): void;
setTxt2ImgModel(params: Partial<ModelParams>): void;
setTxt2ImgHighres(params: Partial<HighresParams>): void;
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
}
interface Img2ImgSlice {
img2img: TabState<Img2ImgParams>;
img2imgModel: ModelParams;
img2imgHighres: HighresParams;
img2imgUpscale: UpscaleParams;
resetImg2Img(): void;
setImg2Img(params: Partial<Img2ImgParams>): void;
resetImg2Img(): void;
setImg2ImgModel(params: Partial<ModelParams>): void;
setImg2ImgHighres(params: Partial<HighresParams>): void;
setImg2ImgUpscale(params: Partial<UpscaleParams>): void;
}
interface InpaintSlice {
inpaint: TabState<InpaintParams>;
setInpaint(params: Partial<InpaintParams>): void;
resetInpaint(): void;
}
interface OutpaintSlice {
inpaintBrush: BrushParams;
inpaintModel: ModelParams;
inpaintHighres: HighresParams;
inpaintUpscale: UpscaleParams;
outpaint: OutpaintPixels;
resetInpaint(): void;
setInpaint(params: Partial<InpaintParams>): void;
setInpaintBrush(brush: Partial<BrushParams>): void;
setInpaintModel(params: Partial<ModelParams>): void;
setInpaintHighres(params: Partial<HighresParams>): void;
setInpaintUpscale(params: Partial<UpscaleParams>): void;
setOutpaint(pixels: Partial<OutpaintPixels>): void;
}
interface HighresSlice {
highres: HighresParams;
setHighres(params: Partial<HighresParams>): void;
resetHighres(): void;
}
interface UpscaleSlice {
upscale: UpscaleParams;
upscaleTab: TabState<UpscaleReqParams>;
upscale: TabState<UpscaleReqParams>;
upscaleHighres: HighresParams;
upscaleModel: ModelParams;
upscaleUpscale: UpscaleParams;
setUpscale(upscale: Partial<UpscaleParams>): void;
setUpscaleTab(params: Partial<UpscaleReqParams>): void;
resetUpscaleTab(): void;
resetUpscale(): void;
setUpscale(params: Partial<UpscaleReqParams>): void;
setUpscaleHighres(params: Partial<HighresParams>): void;
setUpscaleModel(params: Partial<ModelParams>): void;
setUpscaleUpscale(params: Partial<UpscaleParams>): void;
}
interface BlendSlice {
blend: TabState<BlendParams>;
blendBrush: BrushParams;
blendModel: ModelParams;
blendUpscale: UpscaleParams;
resetBlend(): void;
setBlend(blend: Partial<BlendParams>): void;
resetBlend(): void;
setBlendBrush(brush: Partial<BrushParams>): void;
setBlendModel(model: Partial<ModelParams>): void;
setBlendUpscale(params: Partial<UpscaleParams>): void;
}
interface ResetSlice {
@ -154,8 +166,9 @@ interface ResetSlice {
interface ProfileSlice {
profiles: Array<ProfileItem>;
saveProfile(profile: ProfileItem): void;
removeProfile(profileName: string): void;
saveProfile(profile: ProfileItem): void;
}
// #endregion
@ -163,19 +176,16 @@ interface ProfileSlice {
* Full merged state including all slices.
*/
export type OnnxState
= BrushSlice
& DefaultSlice
= DefaultSlice
& HistorySlice
& Img2ImgSlice
& InpaintSlice
& ModelSlice
& OutpaintSlice
& Txt2ImgSlice
& HighresSlice
& UpscaleSlice
& BlendSlice
& ResetSlice
& ExtraSlice
& ModelSlice
& ProfileSlice;
/**
@ -268,14 +278,49 @@ export function baseParamsFromServer(defaults: ServerParams): Required<BaseImgPa
* else should be initialized from the default value in the base parameters.
*/
export function createStateSlices(server: ServerParams) {
const base = baseParamsFromServer(server);
const defaultParams = baseParamsFromServer(server);
const defaultHighres: HighresParams = {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
};
const defaultModel: ModelParams = {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
};
const defaultUpscale: UpscaleParams = {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
txt2img: {
...base,
...defaultParams,
width: server.width.default,
height: server.height.default,
},
txt2imgHighres: {
...defaultHighres,
},
txt2imgModel: {
...defaultModel,
},
txt2imgUpscale: {
...defaultUpscale,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
@ -284,10 +329,34 @@ export function createStateSlices(server: ServerParams) {
},
}));
},
setTxt2ImgHighres(params) {
set((prev) => ({
txt2imgHighres: {
...prev.txt2imgHighres,
...params,
},
}));
},
setTxt2ImgModel(params) {
set((prev) => ({
txt2imgModel: {
...prev.txt2imgModel,
...params,
},
}));
},
setTxt2ImgUpscale(params) {
set((prev) => ({
txt2imgUpscale: {
...prev.txt2imgUpscale,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {
...base,
...defaultParams,
width: server.width.default,
height: server.height.default,
},
@ -297,12 +366,32 @@ export function createStateSlices(server: ServerParams) {
const createImg2ImgSlice: Slice<Img2ImgSlice> = (set) => ({
img2img: {
...base,
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
img2imgHighres: {
...defaultHighres,
},
img2imgModel: {
...defaultModel,
},
img2imgUpscale: {
...defaultUpscale,
},
resetImg2Img() {
set({
img2img: {
...defaultParams,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
},
});
},
setImg2Img(params) {
set((prev) => ({
img2img: {
@ -311,22 +400,35 @@ export function createStateSlices(server: ServerParams) {
},
}));
},
resetImg2Img() {
set({
img2img: {
...base,
loopback: server.loopback.default,
source: null,
sourceFilter: '',
strength: server.strength.default,
setImg2ImgHighres(params) {
set((prev) => ({
img2imgHighres: {
...prev.img2imgHighres,
...params,
},
});
}));
},
setImg2ImgModel(params) {
set((prev) => ({
img2imgModel: {
...prev.img2imgModel,
...params,
},
}));
},
setImg2ImgUpscale(params) {
set((prev) => ({
img2imgUpscale: {
...prev.img2imgUpscale,
...params,
},
}));
},
});
const createInpaintSlice: Slice<InpaintSlice> = (set) => ({
inpaint: {
...base,
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
@ -335,6 +437,39 @@ export function createStateSlices(server: ServerParams) {
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
inpaintBrush: {
...DEFAULT_BRUSH,
},
inpaintHighres: {
...defaultHighres,
},
inpaintModel: {
...defaultModel,
},
inpaintUpscale: {
...defaultUpscale,
},
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
resetInpaint() {
set({
inpaint: {
...defaultParams,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
},
});
},
setInpaint(params) {
set((prev) => ({
inpaint: {
@ -343,19 +478,45 @@ export function createStateSlices(server: ServerParams) {
},
}));
},
resetInpaint() {
set({
inpaint: {
...base,
fillColor: server.fillColor.default,
filter: server.filter.default,
mask: null,
noise: server.noise.default,
source: null,
strength: server.strength.default,
tileOrder: server.tileOrder.default,
setInpaintBrush(brush) {
set((prev) => ({
inpaintBrush: {
...prev.inpaintBrush,
...brush,
},
});
}));
},
setInpaintHighres(params) {
set((prev) => ({
inpaintHighres: {
...prev.inpaintHighres,
...params,
},
}));
},
setInpaintModel(params) {
set((prev) => ({
inpaintModel: {
...prev.inpaintModel,
...params,
},
}));
},
setInpaintUpscale(params) {
set((prev) => ({
inpaintUpscale: {
...prev.inpaintUpscale,
...params,
},
}));
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
}));
},
});
@ -405,109 +566,59 @@ export function createStateSlices(server: ServerParams) {
},
});
const createOutpaintSlice: Slice<OutpaintSlice> = (set) => ({
outpaint: {
enabled: false,
left: server.left.default,
right: server.right.default,
top: server.top.default,
bottom: server.bottom.default,
},
setOutpaint(pixels) {
set((prev) => ({
outpaint: {
...prev.outpaint,
...pixels,
}
}));
},
});
const createBrushSlice: Slice<BrushSlice> = (set) => ({
brush: {
...DEFAULT_BRUSH,
},
setBrush(brush) {
set((prev) => ({
brush: {
...prev.brush,
...brush,
},
}));
},
});
const createUpscaleSlice: Slice<UpscaleSlice> = (set) => ({
upscale: {
denoise: server.denoise.default,
enabled: false,
faces: false,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
},
upscaleTab: {
negativePrompt: server.negativePrompt.default,
prompt: server.prompt.default,
...defaultParams,
source: null,
},
setUpscale(upscale) {
upscaleHighres: {
...defaultHighres,
},
upscaleModel: {
...defaultModel,
},
upscaleUpscale: {
...defaultUpscale,
},
resetUpscale() {
set({
upscale: {
...defaultParams,
source: null,
},
});
},
setUpscale(source) {
set((prev) => ({
upscale: {
...prev.upscale,
...upscale,
},
}));
},
setUpscaleTab(source) {
set((prev) => ({
upscaleTab: {
...prev.upscaleTab,
...source,
},
}));
},
resetUpscaleTab() {
set({
upscaleTab: {
negativePrompt: server.negativePrompt.default,
prompt: server.prompt.default,
source: null,
},
});
},
});
const createHighresSlice: Slice<HighresSlice> = (set) => ({
highres: {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
},
setHighres(params) {
setUpscaleHighres(params) {
set((prev) => ({
highres: {
...prev.highres,
upscaleHighres: {
...prev.upscaleHighres,
...params,
},
}));
},
resetHighres() {
set({
highres: {
enabled: false,
highresIterations: server.highresIterations.default,
highresMethod: '',
highresSteps: server.highresSteps.default,
highresScale: server.highresScale.default,
highresStrength: server.highresStrength.default,
setUpscaleModel(params) {
set((prev) => ({
upscaleModel: {
...prev.upscaleModel,
...defaultModel,
},
});
}));
},
setUpscaleUpscale(params) {
set((prev) => ({
upscaleUpscale: {
...prev.upscaleUpscale,
...params,
},
}));
},
});
@ -516,13 +627,14 @@ export function createStateSlices(server: ServerParams) {
mask: null,
sources: [],
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
blendBrush: {
...DEFAULT_BRUSH,
},
}));
blendModel: {
...defaultModel,
},
blendUpscale: {
...defaultUpscale,
},
resetBlend() {
set({
@ -532,11 +644,43 @@ export function createStateSlices(server: ServerParams) {
},
});
},
setBlend(blend) {
set((prev) => ({
blend: {
...prev.blend,
...blend,
},
}));
},
setBlendBrush(brush) {
set((prev) => ({
blendBrush: {
...prev.blendBrush,
...brush,
},
}));
},
setBlendModel(model) {
set((prev) => ({
blendModel: {
...prev.blendModel,
...model,
},
}));
},
setBlendUpscale(params) {
set((prev) => ({
blendUpscale: {
...prev.blendUpscale,
...params,
},
}));
},
});
const createDefaultSlice: Slice<DefaultSlice> = (set) => ({
defaults: {
...base,
...defaultParams,
},
theme: '',
setDefaults(params) {
@ -554,25 +698,6 @@ export function createStateSlices(server: ServerParams) {
}
});
const createModelSlice: Slice<ModelSlice> = (set) => ({
model: {
control: server.control.default,
correction: server.correction.default,
model: server.model.default,
pipeline: server.pipeline.default,
platform: server.platform.default,
upscaling: server.upscaling.default,
},
setModel(params) {
set((prev) => ({
model: {
...prev.model,
...params,
}
}));
},
});
const createResetSlice: Slice<ResetSlice> = (set) => ({
resetAll() {
set((prev) => {
@ -580,7 +705,7 @@ export function createStateSlices(server: ServerParams) {
next.resetImg2Img();
next.resetInpaint();
next.resetTxt2Img();
next.resetUpscaleTab();
next.resetUpscale();
next.resetBlend();
return next;
});
@ -620,7 +745,7 @@ export function createStateSlices(server: ServerParams) {
});
// eslint-disable-next-line sonarjs/cognitive-complexity
const createExtraSlice: Slice<ExtraSlice> = (set) => ({
const createModelSlice: Slice<ModelSlice> = (set) => ({
extras: {
correction: [],
diffusion: [],
@ -799,19 +924,15 @@ export function createStateSlices(server: ServerParams) {
});
return {
createBrushSlice,
createDefaultSlice,
createHistorySlice,
createImg2ImgSlice,
createInpaintSlice,
createModelSlice,
createOutpaintSlice,
createTxt2ImgSlice,
createUpscaleSlice,
createHighresSlice,
createBlendSlice,
createResetSlice,
createExtraSlice,
createModelSlice,
createProfileSlice,
};
}