1
0
Fork 0

continue refactoring to use selectors

This commit is contained in:
Sean Sube 2023-07-22 18:21:54 -05:00
parent 97daf1aa7c
commit 0ba21dfc27
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
20 changed files with 290 additions and 197 deletions

View File

@ -7,7 +7,7 @@ import { useContext, useMemo } from 'react';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { StateContext } from '../state.js';
import { OnnxState, StateContext } from '../state.js';
import { ImageHistory } from './ImageHistory.js';
import { Logo } from './Logo.js';
import { Blend } from './tab/Blend.js';
@ -22,7 +22,8 @@ import { getTab, getTheme, TAB_LABELS } from './utils.js';
export function OnnxWeb() {
/* checks for system light/dark mode preference */
const prefersDarkMode = useMediaQuery('(prefers-color-scheme: dark)');
const stateTheme = useStore(mustExist(useContext(StateContext)), (s) => s.theme);
const store = mustExist(useContext(StateContext));
const stateTheme = useStore(store, selectTheme);
const theme = useMemo(
() => createTheme({
@ -80,3 +81,7 @@ export function OnnxWeb() {
</ThemeProvider>
);
}
export function selectTheme(state: OnnxState) {
return state.theme;
}

View File

@ -37,13 +37,9 @@ export interface ProfilesProps {
}
export function Profiles(props: ProfilesProps) {
const state = mustExist(useContext(StateContext));
const profiles = useStore(state, (s) => s.profiles);
// eslint-disable-next-line @typescript-eslint/unbound-method
const saveProfile = useStore(state, (s) => s.saveProfile);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeProfile = useStore(state, (s) => s.removeProfile);
const store = mustExist(useContext(StateContext));
const { removeProfile, saveProfile } = useStore(store, selectActions);
const profiles = useStore(store, selectProfiles);
const [dialogOpen, setDialogOpen] = useState(false);
const [profileName, setProfileName] = useState('');
@ -116,12 +112,12 @@ export function Profiles(props: ProfilesProps) {
<Button
variant='contained'
onClick={() => {
const innerState = state.getState();
const state = store.getState();
saveProfile({
params: props.selectParams(innerState),
params: props.selectParams(state),
name: profileName,
highres: props.selectHighres(innerState),
upscale: props.selectUpscale(innerState),
highres: props.selectHighres(state),
upscale: props.selectUpscale(state),
});
setDialogOpen(false);
setProfileName('');
@ -161,11 +157,11 @@ export function Profiles(props: ProfilesProps) {
/>
</Button>
<Button component='label' variant='contained' onClick={() => {
const innerState = state.getState();
const state = store.getState();
downloadParamsAsFile({
params: props.selectParams(innerState),
highres: props.selectHighres(innerState),
upscale: props.selectUpscale(innerState),
params: props.selectParams(state),
highres: props.selectHighres(state),
upscale: props.selectUpscale(state),
});
}}>
<Download />
@ -173,6 +169,19 @@ export function Profiles(props: ProfilesProps) {
</Stack>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
removeProfile: state.removeProfile,
// eslint-disable-next-line @typescript-eslint/unbound-method
saveProfile: state.saveProfile,
};
}
export function selectProfiles(state: OnnxState) {
return state.profiles;
}
export async function loadParamsFromFile(file: File): Promise<DeepPartial<ImageMetadata>> {
const parts = file.name.toLocaleLowerCase().split('.');
const ext = parts[parts.length - 1];

View File

@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ImageResponse, ReadyResponse, RetryParams } from '../../client/types.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
export interface ErrorCardProps {
image: ImageResponse;
@ -20,14 +20,11 @@ export interface ErrorCardProps {
export function ErrorCard(props: ErrorCardProps) {
const { image, ready, retry: retryParams } = props;
const client = mustExist(React.useContext(ClientContext));
const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeHistory = useStore(state, (s) => s.removeHistory);
const { pushHistory, removeHistory } = useStore(state, selectActions);
const { t } = useTranslation();
async function retryImage() {
@ -72,3 +69,12 @@ export function ErrorCard(props: ErrorCardProps) {
</CardContent>
</Card>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
pushHistory: state.pushHistory,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeHistory: state.removeHistory,
};
}

View File

@ -8,7 +8,7 @@ import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { ImageResponse } from '../../client/types.js';
import { BLEND_SOURCES, ConfigContext, StateContext } from '../../state.js';
import { BLEND_SOURCES, ConfigContext, OnnxState, StateContext } from '../../state.js';
import { range, visibleIndex } from '../../utils.js';
export interface ImageCardProps {
@ -32,15 +32,8 @@ export function ImageCard(props: ImageCardProps) {
const [saveAnchor, setSaveAnchor] = useState<Maybe<HTMLElement>>();
const config = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const setImg2Img = useStore(state, (s) => s.setImg2Img);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);
const store = mustExist(useContext(StateContext));
const { setBlend, setImg2Img, setInpaint, setUpscale } = useStore(store, selectActions);
async function loadSource() {
const req = await fetch(outputs[index].url);
@ -73,7 +66,7 @@ export function ImageCard(props: ImageCardProps) {
async function copySourceToBlend(idx: number) {
const blob = await loadSource();
const sources = mustDefault(state.getState().blend.sources, []);
const sources = mustDefault(store.getState().blend.sources, []);
const newSources = [...sources];
newSources[idx] = blob;
setBlend({
@ -229,3 +222,16 @@ export function ImageCard(props: ImageCardProps) {
</CardContent>
</Card>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
setBlend: state.setBlend,
// eslint-disable-next-line @typescript-eslint/unbound-method
setImg2Img: state.setImg2Img,
// eslint-disable-next-line @typescript-eslint/unbound-method
setInpaint: state.setInpaint,
// eslint-disable-next-line @typescript-eslint/unbound-method
setUpscale: state.setUpscale,
};
}

View File

@ -9,7 +9,7 @@ import { useStore } from 'zustand';
import { ImageResponse } from '../../client/types.js';
import { POLL_TIME } from '../../config.js';
import { ClientContext, ConfigContext, StateContext } from '../../state.js';
import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../state.js';
const LOADING_PERCENT = 100;
const LOADING_OVERAGE = 99;
@ -23,14 +23,11 @@ export function LoadingCard(props: LoadingCardProps) {
const { image, index } = props;
const { steps } = props.image.params;
const client = mustExist(React.useContext(ClientContext));
const client = mustExist(useContext(ClientContext));
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeHistory = useStore(state, (s) => s.removeHistory);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setReady = useStore(state, (s) => s.setReady);
const store = mustExist(useContext(StateContext));
const { removeHistory, setReady } = useStore(store, selectActions);
const { t } = useTranslation();
const cancel = useMutation(() => client.cancel(image.outputs[index].key));
@ -118,3 +115,12 @@ export function LoadingCard(props: LoadingCardProps) {
</CardContent>
</Card>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
removeHistory: state.removeHistory,
// eslint-disable-next-line @typescript-eslint/unbound-method
setReady: state.setReady,
};
}

View File

@ -18,8 +18,8 @@ export function HighresControl(props: HighresControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selectHighres, setHighres } = props;
const state = mustExist(useContext(StateContext));
const highres = useStore(state, selectHighres);
const store = mustExist(useContext(StateContext));
const highres = useStore(store, selectHighres);
const { params } = mustExist(useContext(ConfigContext));
const { t } = useTranslation();

View File

@ -2,10 +2,12 @@ import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Casino } from '@mui/icons-material';
import { Button, Checkbox, FormControlLabel, Stack } from '@mui/material';
import { useQuery } from '@tanstack/react-query';
import { omit } from 'lodash';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { shallow } from 'zustand/shallow';
import { BaseImgParams } from '../../client/types.js';
import { STALE_TIME } from '../../config.js';
@ -14,20 +16,30 @@ import { NumericField } from '../input/NumericField.js';
import { PromptInput } from '../input/PromptInput.js';
import { QueryList } from '../input/QueryList.js';
const { useMemo } = React;
type BaseParamsWithoutPrompt = Omit<BaseImgParams, 'prompt' | 'negativePrompt'>;
export interface ImageControlProps {
onChange(params: BaseImgParams): void;
onChange(params: Partial<BaseImgParams>): void;
selector(state: OnnxState): BaseImgParams;
}
export function omitPrompt(selector: (state: OnnxState) => BaseImgParams): (state: OnnxState) => BaseParamsWithoutPrompt {
return (state) => omit(selector(state), 'prompt', 'negativePrompt');
}
/**
* Doesn't need to use state directly, the parent component knows which params to pass
*/
export function ImageControl(props: ImageControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { onChange, selector } = props;
const selectOmitPrompt = useMemo(() => omitPrompt(selector), [selector]);
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const controlState = useStore(state, selector);
const store = mustExist(useContext(StateContext));
const state = useStore(store, selectOmitPrompt, shallow);
const { t } = useTranslation();
const client = mustExist(useContext(ClientContext));
@ -36,7 +48,7 @@ export function ImageControl(props: ImageControlProps) {
});
// max stride is the lesser of tile size and server's max stride
const maxStride = Math.min(controlState.tiles, params.stride.max);
const maxStride = Math.min(state.tiles, params.stride.max);
return <Stack spacing={2}>
<Stack direction='row' spacing={4}>
@ -47,11 +59,11 @@ export function ImageControl(props: ImageControlProps) {
query={{
result: schedulers,
}}
value={mustDefault(controlState.scheduler, '')}
value={mustDefault(state.scheduler, '')}
onChange={(value) => {
if (doesExist(onChange)) {
onChange({
...controlState,
...state,
scheduler: value,
});
}
@ -63,11 +75,11 @@ export function ImageControl(props: ImageControlProps) {
min={params.eta.min}
max={params.eta.max}
step={params.eta.step}
value={controlState.eta}
value={state.eta}
onChange={(eta) => {
if (doesExist(onChange)) {
onChange({
...controlState,
...state,
eta,
});
}
@ -79,11 +91,11 @@ export function ImageControl(props: ImageControlProps) {
min={params.cfg.min}
max={params.cfg.max}
step={params.cfg.step}
value={controlState.cfg}
value={state.cfg}
onChange={(cfg) => {
if (doesExist(onChange)) {
onChange({
...controlState,
...state,
cfg,
});
}
@ -94,10 +106,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.steps.min}
max={params.steps.max}
step={params.steps.step}
value={controlState.steps}
value={state.steps}
onChange={(steps) => {
onChange({
...controlState,
...state,
steps,
});
}}
@ -107,10 +119,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.seed.min}
max={params.seed.max}
step={params.seed.step}
value={controlState.seed}
value={state.seed}
onChange={(seed) => {
onChange({
...controlState,
...state,
seed,
});
}}
@ -121,7 +133,7 @@ export function ImageControl(props: ImageControlProps) {
onClick={() => {
const seed = Math.floor(Math.random() * params.seed.max);
props.onChange({
...controlState,
...state,
seed,
});
}}
@ -135,10 +147,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.batch.min}
max={params.batch.max}
step={params.batch.step}
value={controlState.batch}
value={state.batch}
onChange={(batch) => {
props.onChange({
...controlState,
...state,
batch,
});
}}
@ -148,10 +160,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.tiles.min}
max={params.tiles.max}
step={params.tiles.step}
value={controlState.tiles}
value={state.tiles}
onChange={(tiles) => {
props.onChange({
...controlState,
...state,
tiles,
});
}}
@ -162,10 +174,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.overlap.min}
max={params.overlap.max}
step={params.overlap.step}
value={controlState.overlap}
value={state.overlap}
onChange={(overlap) => {
props.onChange({
...controlState,
...state,
overlap,
});
}}
@ -175,10 +187,10 @@ export function ImageControl(props: ImageControlProps) {
min={params.stride.min}
max={maxStride}
step={params.stride.step}
value={controlState.stride}
value={state.stride}
onChange={(stride) => {
props.onChange({
...controlState,
...state,
stride,
});
}}
@ -186,23 +198,22 @@ export function ImageControl(props: ImageControlProps) {
<FormControlLabel
label={t('parameter.tiledVAE')}
control={<Checkbox
checked={controlState.tiledVAE}
checked={state.tiledVAE}
value='check'
onChange={(event) => {
props.onChange({
...controlState,
tiledVAE: controlState.tiledVAE === false,
...state,
tiledVAE: state.tiledVAE === false,
});
}}
/>}
/>
</Stack>
<PromptInput
prompt={controlState.prompt}
negativePrompt={controlState.negativePrompt}
selector={selector}
onChange={(value) => {
props.onChange({
...controlState,
...state,
...value,
});
}}

View File

@ -5,15 +5,14 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { ConfigContext, StateContext } from '../../state.js';
import { ConfigContext, OnnxState, StateContext } from '../../state.js';
import { NumericField } from '../input/NumericField.js';
export function OutpaintControl() {
const { params } = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
const outpaint = useStore(state, (s) => s.outpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setOutpaint = useStore(state, (s) => s.setOutpaint);
const store = mustExist(useContext(StateContext));
const {setOutpaint} = useStore(store, selectActions);
const outpaint = useStore(store, selectOutpaint);
const { t } = useTranslation();
return <Stack direction='row' spacing={4}>
@ -83,3 +82,14 @@ export function OutpaintControl() {
/>
</Stack>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
setOutpaint: state.setOutpaint,
};
}
export function selectOutpaint(state: OnnxState) {
return state.outpaint;
}

View File

@ -18,8 +18,8 @@ export function UpscaleControl(props: UpscaleControlProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selectUpscale, setUpscale } = props;
const state = mustExist(useContext(StateContext));
const upscale = useStore(state, selectUpscale);
const store = mustExist(useContext(StateContext));
const upscale = useStore(store, selectUpscale);
const { params } = mustExist(useContext(ConfigContext));
const { t } = useTranslation();

View File

@ -25,8 +25,8 @@ export function EditableList<T>(props: EditableListProps<T>) {
const { newItem, removeItem, renderItem, setItem, selector } = props;
const { t } = useTranslation();
const state = mustExist(useContext(StateContext));
const items = useStore(state, selector);
const store = mustExist(useContext(StateContext));
const items = useStore(store, selector);
const [nextLabel, setNextLabel] = useState('');
const [nextSource, setNextSource] = useState('');
const RenderMemo = useMemo(() => memo(renderItem), [renderItem]);

View File

@ -1,23 +1,28 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { mustExist } from '@apextoaster/js-utils';
import { TextField } from '@mui/material';
import { Stack } from '@mui/system';
import { useQuery } from '@tanstack/react-query';
import * as React from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { QueryMenu } from '../input/QueryMenu.js';
import { STALE_TIME } from '../../config.js';
import { ClientContext } from '../../state.js';
import { ClientContext, OnnxState, StateContext } from '../../state.js';
import { QueryMenu } from '../input/QueryMenu.js';
const { useContext } = React;
/**
* @todo replace with a selector
*/
export interface PromptValue {
prompt: string;
negativePrompt?: string;
}
export interface PromptInputProps extends PromptValue {
onChange: (value: PromptValue) => void;
export interface PromptInputProps {
selector(state: OnnxState): PromptValue;
onChange(value: PromptValue): void;
}
export const PROMPT_GROUP = 75;
@ -31,16 +36,20 @@ function splitPrompt(prompt: string): Array<string> {
}
export function PromptInput(props: PromptInputProps) {
const { prompt = '', negativePrompt = '' } = props;
// eslint-disable-next-line @typescript-eslint/unbound-method
const { selector, onChange } = props;
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const store = mustExist(useContext(StateContext));
const { prompt, negativePrompt } = useStore(store, selector);
const client = mustExist(useContext(ClientContext));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
const tokens = splitPrompt(prompt);
const groups = Math.ceil(tokens.length / PROMPT_GROUP);
const { t } = useTranslation();
const helper = t('input.prompt.tokens', {
groups,
@ -48,7 +57,7 @@ export function PromptInput(props: PromptInputProps) {
});
function addToken(type: string, name: string, weight = 1.0) {
props.onChange({
onChange({
prompt: `<${type}:${name}:1.0> ${prompt}`,
negativePrompt,
});
@ -61,12 +70,10 @@ export function PromptInput(props: PromptInputProps) {
variant='outlined'
value={prompt}
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
prompt: event.target.value,
negativePrompt,
});
}
}}
/>
<TextField
@ -74,12 +81,10 @@ export function PromptInput(props: PromptInputProps) {
variant='outlined'
value={negativePrompt}
onChange={(event) => {
if (doesExist(props.onChange)) {
props.onChange({
prompt,
negativePrompt: event.target.value,
});
}
}}
/>
<Stack direction='row' spacing={2}>

View File

@ -1,9 +1,9 @@
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Alert, FormControl, FormLabel, InputLabel, LinearProgress, MenuItem, Select, Typography } from '@mui/material';
import { Alert, FormControl, FormLabel, InputLabel, LinearProgress, MenuItem, Select } from '@mui/material';
import { UseQueryResult } from '@tanstack/react-query';
import * as React from 'react';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { UseQueryResult } from '@tanstack/react-query';
export interface QueryListComplete {
result: UseQueryResult<Array<string>>;

View File

@ -1,9 +1,11 @@
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { KeyboardArrowDown } from '@mui/icons-material';
import { Alert, Box, Button, FormControl, FormLabel, LinearProgress, Menu, MenuItem, Typography } from '@mui/material';
import { Alert, Box, Button, FormControl, FormLabel, LinearProgress, Menu, MenuItem } from '@mui/material';
import { UseQueryResult } from '@tanstack/react-query';
import * as React from 'react';
import { useTranslation } from 'react-i18next';
import { UseQueryResult } from '@tanstack/react-query';
const { useState } = React;
export interface QueryMenuComplete {
result: UseQueryResult<Array<string>>;
@ -53,7 +55,7 @@ export function QueryMenu<T>(props: QueryMenuProps<T>) {
const { t } = useTranslation();
const [anchor, setAnchor] = React.useState<Maybe<HTMLElement>>(undefined);
const [anchor, setAnchor] = useState<Maybe<HTMLElement>>(undefined);
function closeMenu() {
setAnchor(undefined);

View File

@ -6,7 +6,7 @@ import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { useStore } from 'zustand';
import { BlendParams, ModelParams, UpscaleParams } from '../../client/types.js';
import { BlendParams, BrushParams, ModelParams, UpscaleParams } from '../../client/types.js';
import { IMAGE_FILTER } from '../../config.js';
import { BLEND_SOURCES, ClientContext, OnnxState, StateContext, TabState } from '../../state.js';
import { range } from '../../utils.js';
@ -16,7 +16,7 @@ import { MaskCanvas } from '../input/MaskCanvas.js';
export function Blend() {
async function uploadSource() {
const { blend, blendModel, blendUpscale } = state.getState();
const { blend, blendModel, blendUpscale } = store.getState();
const { image, retry } = await client.blend(blendModel, {
...blend,
mask: mustExist(blend.mask),
@ -32,17 +32,10 @@ export function Blend() {
onSuccess: () => query.invalidateQueries(['ready']),
});
const state = mustExist(useContext(StateContext));
const brush = useStore(state, (s) => s.blendBrush);
const blend = useStore(state, selectParams);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBlend = useStore(state, (s) => s.setBlend);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setBrush = useStore(state, (s) => s.setBlendBrush);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setBlendUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const store = mustExist(useContext(StateContext));
const { pushHistory, setBlend, setBrush, setUpscale } = useStore(store, selectActions);
const brush = useStore(store, selectBrush);
const blend = useStore(store, selectParams);
const { t } = useTranslation();
const sources = mustDefault(blend.sources, []);
@ -87,6 +80,23 @@ export function Blend() {
</Box>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
pushHistory: state.pushHistory,
// eslint-disable-next-line @typescript-eslint/unbound-method
setBlend: state.setBlend,
// eslint-disable-next-line @typescript-eslint/unbound-method
setBrush: state.setBlendBrush,
// eslint-disable-next-line @typescript-eslint/unbound-method
setUpscale: state.setBlendUpscale,
};
}
export function selectBrush(state: OnnxState): BrushParams {
return state.blendBrush;
}
export function selectModel(state: OnnxState): ModelParams {
return state.blendModel;
}

View File

@ -23,13 +23,13 @@ export function Img2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function uploadSource() {
const innerState = state.getState();
const img2img = selectParams(innerState);
const state = store.getState();
const img2img = selectParams(state);
const { image, retry } = await client.img2img(model, {
...img2img,
source: mustExist(img2img.source), // TODO: show an error if this doesn't exist
}, selectUpscale(innerState), selectHighres(innerState));
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
}
@ -47,10 +47,10 @@ export function Img2Img() {
staleTime: STALE_TIME,
});
const state = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setImg2Img, setModel, setUpscale } = useStore(state, selectActions, shallow);
const { loopback, source, sourceFilter, strength } = useStore(state, selectReactParams, shallow);
const model = useStore(state, selectModel);
const store = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setImg2Img, setModel, setUpscale } = useStore(store, selectActions, shallow);
const { loopback, source, sourceFilter, strength } = useStore(store, selectReactParams, shallow);
const model = useStore(store, selectModel);
const { t } = useTranslation();
@ -75,7 +75,7 @@ export function Img2Img() {
});
}}
/>
<ImageControl selector={(s) => s.img2img} onChange={setImg2Img} />
<ImageControl selector={selectParams} onChange={setImg2Img} />
<Stack direction='row' spacing={2}>
<QueryList
id='control'

View File

@ -33,9 +33,9 @@ export function Inpaint() {
});
async function uploadSource(): Promise<void> {
const innerState = state.getState();
const { outpaint } = innerState;
const inpaint = selectParams(innerState);
const state = store.getState();
const { outpaint } = state;
const inpaint = selectParams(state);
if (outpaint.enabled) {
const { image, retry } = await client.outpaint(model, {
@ -43,7 +43,7 @@ export function Inpaint() {
...outpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(innerState), selectHighres(innerState));
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
} else {
@ -51,7 +51,7 @@ export function Inpaint() {
...inpaint,
mask: mustExist(mask),
source: mustExist(source),
}, selectUpscale(innerState), selectHighres(innerState));
}, selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
}
@ -65,11 +65,11 @@ export function Inpaint() {
return model.model.includes('inpaint');
}
const state = mustExist(useContext(StateContext));
const { pushHistory, setBrush, setHighres, setModel, setInpaint, setUpscale } = useStore(state, selectActions, shallow);
const { source, mask, strength, noise, filter, tileOrder, fillColor } = useStore(state, selectReactParams, shallow);
const model = useStore(state, selectModel);
const brush = useStore(state, selectBrush);
const store = mustExist(useContext(StateContext));
const { pushHistory, setBrush, setHighres, setModel, setInpaint, setUpscale } = useStore(store, selectActions, shallow);
const { source, mask, strength, noise, filter, tileOrder, fillColor } = useStore(store, selectReactParams, shallow);
const model = useStore(store, selectModel);
const brush = useStore(store, selectBrush);
const { t } = useTranslation();
@ -132,7 +132,7 @@ export function Inpaint() {
setBrush={setBrush}
/>
<ImageControl
selector={(s) => s.inpaint}
selector={selectParams}
onChange={(newParams) => {
setInpaint(newParams);
}}

View File

@ -77,44 +77,35 @@ function selectExtraSources(state: OnnxState): Array<ExtraSource> {
}
export function Models() {
const state = mustExist(React.useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtras = useStore(state, (s) => s.setExtras);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setCorrectionModel = useStore(state, (s) => s.setCorrectionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setDiffusionModel = useStore(state, (s) => s.setDiffusionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtraNetwork = useStore(state, (s) => s.setExtraNetwork);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setExtraSource = useStore(state, (s) => s.setExtraSource);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscalingModel = useStore(state, (s) => s.setUpscalingModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeCorrectionModel = useStore(state, (s) => s.removeCorrectionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeDiffusionModel = useStore(state, (s) => s.removeDiffusionModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeExtraNetwork = useStore(state, (s) => s.removeExtraNetwork);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeExtraSource = useStore(state, (s) => s.removeExtraSource);
// eslint-disable-next-line @typescript-eslint/unbound-method
const removeUpscalingModel = useStore(state, (s) => s.removeUpscalingModel);
const client = mustExist(useContext(ClientContext));
const store = mustExist(useContext(StateContext));
const {
setCorrectionModel,
setDiffusionModel,
setExtraNetwork,
setExtraSource,
setExtras,
setUpscalingModel,
removeCorrectionModel,
removeDiffusionModel,
removeExtraNetwork,
removeExtraSource,
removeUpscalingModel,
} = useStore(store, selectActions);
const client = mustExist(useContext(ClientContext));
const result = useQuery(['extras'], async () => client.extras(), {
staleTime: STALE_TIME,
});
const query = useQueryClient();
const write = useMutation(writeExtras, {
onSuccess: () => query.invalidateQueries([ 'extras' ]),
onSuccess: () => query.invalidateQueries(['extras']),
});
const { t } = useTranslation();
useEffect(() => {
if (result.status === 'success' && doesExist(result.data)) {
setExtras(mergeModels(state.getState().extras, result.data));
setExtras(mergeModels(store.getState().extras, result.data));
}
}, [result.status]);
@ -127,18 +118,18 @@ export function Models() {
if (result.status === 'loading') {
return <Stack spacing={2} direction='row' sx={{ alignItems: 'center' }}>
<CircularProgress />
</Stack> ;
</Stack>;
}
async function writeExtras() {
const resp = await client.writeExtras(state.getState().extras);
const resp = await client.writeExtras(store.getState().extras);
// TODO: do something with resp
}
return <Stack spacing={2}>
<Accordion>
<AccordionSummary>
{t('modelType.diffusion', {count: 10})}
{t('modelType.diffusion', { count: 10 })}
</AccordionSummary>
<AccordionDetails>
<EditableList<DiffusionModel>
@ -157,7 +148,7 @@ export function Models() {
</Accordion>
<Accordion>
<AccordionSummary>
{t('modelType.correction', {count: 10})}
{t('modelType.correction', { count: 10 })}
</AccordionSummary>
<AccordionDetails>
<EditableList
@ -176,7 +167,7 @@ export function Models() {
</Accordion>
<Accordion>
<AccordionSummary>
{t('modelType.upscaling', {count: 10})}
{t('modelType.upscaling', { count: 10 })}
</AccordionSummary>
<AccordionDetails>
<EditableList
@ -196,7 +187,7 @@ export function Models() {
</Accordion>
<Accordion>
<AccordionSummary>
{t('modelType.network', {count: 10})}
{t('modelType.network', { count: 10 })}
</AccordionSummary>
<AccordionDetails>
<EditableList
@ -217,7 +208,7 @@ export function Models() {
</Accordion>
<Accordion>
<AccordionSummary>
{t('modelType.source', {count: 10})}
{t('modelType.source', { count: 10 })}
</AccordionSummary>
<AccordionDetails>
<EditableList
@ -237,3 +228,30 @@ export function Models() {
<Button color='warning' onClick={() => write.mutate()}>{t('convert')}</Button>
</Stack>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
setExtras: state.setExtras,
// eslint-disable-next-line @typescript-eslint/unbound-method
setCorrectionModel: state.setCorrectionModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
setDiffusionModel: state.setDiffusionModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
setExtraNetwork: state.setExtraNetwork,
// eslint-disable-next-line @typescript-eslint/unbound-method
setExtraSource: state.setExtraSource,
// eslint-disable-next-line @typescript-eslint/unbound-method
setUpscalingModel: state.setUpscalingModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeCorrectionModel: state.removeCorrectionModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeDiffusionModel: state.removeDiffusionModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeExtraNetwork: state.removeExtraNetwork,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeExtraSource: state.removeExtraSource,
// eslint-disable-next-line @typescript-eslint/unbound-method
removeUpscalingModel: state.removeUpscalingModel,
};
}

View File

@ -20,8 +20,8 @@ export function Txt2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function generateImage() {
const innerState = state.getState();
const { image, retry } = await client.txt2img(model, selectParams(innerState), selectUpscale(innerState), selectHighres(innerState));
const state = store.getState();
const { image, retry } = await client.txt2img(model, selectParams(state), selectUpscale(state), selectHighres(state));
pushHistory(image, retry);
}
@ -32,10 +32,10 @@ export function Txt2Img() {
onSuccess: () => query.invalidateQueries([ 'ready' ]),
});
const state = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(state, selectActions, shallow);
const { height, width } = useStore(state, selectReactParams, shallow);
const model = useStore(state, selectModel);
const store = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(store, selectActions, shallow);
const { height, width } = useStore(store, selectReactParams, shallow);
const model = useStore(store, selectModel);
const { t } = useTranslation();

View File

@ -18,7 +18,7 @@ import { Profiles } from '../Profiles.js';
export function Upscale() {
async function uploadSource() {
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = state.getState();
const { upscaleHighres, upscaleUpscale, upscaleModel, upscale } = store.getState();
const { image, retry } = await client.upscale(upscaleModel, {
...upscale,
source: mustExist(upscale.source), // TODO: show an error if this doesn't exist
@ -33,19 +33,10 @@ export function Upscale() {
onSuccess: () => query.invalidateQueries([ 'ready' ]),
});
const state = mustExist(useContext(StateContext));
const model = useStore(state, (s) => s.upscaleModel);
const params = useStore(state, (s) => s.upscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setModel = useStore(state, (s) => s.setUpscalingModel);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setHighres = useStore(state, (s) => s.setUpscaleHighres);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setUpscale = useStore(state, (s) => s.setUpscaleUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setParams = useStore(state, (s) => s.setUpscale);
// eslint-disable-next-line @typescript-eslint/unbound-method
const pushHistory = useStore(state, (s) => s.pushHistory);
const store = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(store, selectActions);
const model = useStore(store, selectModel);
const params = useStore(store, selectParams);
const { t } = useTranslation();
return <Box>
@ -70,8 +61,7 @@ export function Upscale() {
}}
/>
<PromptInput
prompt={params.prompt}
negativePrompt={params.negativePrompt}
selector={selectParams}
onChange={(value) => {
setParams(value);
}}
@ -87,6 +77,21 @@ export function Upscale() {
</Box>;
}
export function selectActions(state: OnnxState) {
return {
// eslint-disable-next-line @typescript-eslint/unbound-method
pushHistory: state.pushHistory,
// eslint-disable-next-line @typescript-eslint/unbound-method
setHighres: state.setUpscaleHighres,
// eslint-disable-next-line @typescript-eslint/unbound-method
setModel: state.setUpscaleModel,
// eslint-disable-next-line @typescript-eslint/unbound-method
setParams: state.setUpscale,
// eslint-disable-next-line @typescript-eslint/unbound-method
setUpscale: state.setUpscaleUpscale,
};
}
export function selectModel(state: OnnxState): ModelParams {
return state.upscaleModel;
}

View File

@ -35,13 +35,13 @@ export type Theme = PaletteMode | ''; // tri-state, '' is unset
*/
export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
interface HistoryItem {
export interface HistoryItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
retry: RetryParams;
}
interface ProfileItem {
export interface ProfileItem {
name: string;
params: BaseImgParams | Txt2ImgParams;
highres?: Maybe<HighresParams>;