diff --git a/gui/src/components/ImageControl.tsx b/gui/src/components/ImageControl.tsx index cd098673..4506518b 100644 --- a/gui/src/components/ImageControl.tsx +++ b/gui/src/components/ImageControl.tsx @@ -3,10 +3,11 @@ import { Casino } from '@mui/icons-material'; import { Button, Stack, TextField } from '@mui/material'; import * as React from 'react'; import { useQuery } from 'react-query'; +import { useStore } from 'zustand'; import { BaseImgParams } from '../client.js'; import { ConfigParams, STALE_TIME } from '../config.js'; -import { ClientContext } from '../state.js'; +import { ClientContext, OnnxState, StateContext } from '../state.js'; import { SCHEDULER_LABELS } from '../strings.js'; import { NumericField } from './NumericField.js'; import { QueryList } from './QueryList.js'; @@ -15,7 +16,8 @@ const { useContext } = React; export interface ImageControlProps { config: ConfigParams; - params: BaseImgParams; + + selector: (state: OnnxState) => BaseImgParams; onChange?: (params: BaseImgParams) => void; } @@ -24,7 +26,10 @@ export interface ImageControlProps { * doesn't need to use state, the parent component knows which params to pass */ export function ImageControl(props: ImageControlProps) { - const { config, params } = props; + const { config } = props; + + const state = mustExist(useContext(StateContext)); + const params = useStore(state, props.selector); const client = mustExist(useContext(ClientContext)); const schedulers = useQuery('schedulers', async () => client.schedulers(), { diff --git a/gui/src/components/Img2Img.tsx b/gui/src/components/Img2Img.tsx index 7105373f..aa67f5e4 100644 --- a/gui/src/components/Img2Img.tsx +++ b/gui/src/components/Img2Img.tsx @@ -24,13 +24,13 @@ export function Img2Img(props: Img2ImgProps) { const { config, model, platform } = props; async function uploadSource() { - const upscale = state.getState().upscale; + const { img2img, upscale } = state.getState(); const output = await client.img2img({ - ...params, + ...img2img, model, platform, - source: mustExist(params.source), // TODO: show an error if this doesn't exist + source: mustExist(img2img.source), // TODO: show an error if this doesn't exist }, upscale); setLoading(output); @@ -43,7 +43,8 @@ export function Img2Img(props: Img2ImgProps) { }); const state = mustExist(useContext(StateContext)); - const params = useStore(state, (s) => s.img2img); + const source = useStore(state, (s) => s.img2img.source); + const strength = useStore(state, (s) => s.img2img.strength); // eslint-disable-next-line @typescript-eslint/unbound-method const setImg2Img = useStore(state, (s) => s.setImg2Img); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -51,21 +52,19 @@ export function Img2Img(props: Img2ImgProps) { return - { + { setImg2Img({ source: file, }); }} /> - { - setImg2Img(newParams); - }} /> + s.img2img} onChange={setImg2Img} /> { setImg2Img({ strength: value, diff --git a/gui/src/components/Inpaint.tsx b/gui/src/components/Inpaint.tsx index 4fa7872f..3c0966fc 100644 --- a/gui/src/components/Inpaint.tsx +++ b/gui/src/components/Inpaint.tsx @@ -35,27 +35,26 @@ export function Inpaint(props: InpaintProps) { async function uploadSource(): Promise { // these are not watched by the component, only sent by the mutation - const outpaint = state.getState().outpaint; - const upscale = state.getState().upscale; + const { inpaint, outpaint, upscale } = state.getState(); if (outpaint.enabled) { const output = await client.outpaint({ - ...params, + ...inpaint, ...outpaint, model, platform, - mask: mustExist(params.mask), - source: mustExist(params.source), + mask: mustExist(mask), + source: mustExist(source), }, upscale); setLoading(output); } else { const output = await client.inpaint({ - ...params, + ...inpaint, model, platform, - mask: mustExist(params.mask), - source: mustExist(params.source), + mask: mustExist(mask), + source: mustExist(source), }, upscale); setLoading(output); @@ -63,7 +62,10 @@ export function Inpaint(props: InpaintProps) { } const state = mustExist(useContext(StateContext)); - const params = useStore(state, (s) => s.inpaint); + const filter = useStore(state, (s) => s.inpaint.filter); + const noise = useStore(state, (s) => s.inpaint.noise); + const mask = useStore(state, (s) => s.inpaint.mask); + const source = useStore(state, (s) => s.inpaint.source); // eslint-disable-next-line @typescript-eslint/unbound-method const setInpaint = useStore(state, (s) => s.setInpaint); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -78,7 +80,7 @@ export function Inpaint(props: InpaintProps) { { setInpaint({ @@ -88,7 +90,7 @@ export function Inpaint(props: InpaintProps) { /> { setInpaint({ @@ -98,11 +100,11 @@ export function Inpaint(props: InpaintProps) { renderImage={(image) => { + onSave={(file) => { setInpaint({ - mask, + mask: file, }); }} /> @@ -110,7 +112,7 @@ export function Inpaint(props: InpaintProps) { /> s.inpaint} onChange={(newParams) => { setInpaint(newParams); }} @@ -121,10 +123,10 @@ export function Inpaint(props: InpaintProps) { labels={MASK_LABELS} name='Mask Filter' result={masks} - value={params.filter} - onChange={(filter) => { + value={filter} + onChange={(newFilter) => { setInpaint({ - filter, + filter: newFilter, }); }} /> @@ -133,10 +135,10 @@ export function Inpaint(props: InpaintProps) { labels={NOISE_LABELS} name='Noise Source' result={noises} - value={params.noise} - onChange={(noise) => { + value={noise} + onChange={(newNoise) => { setInpaint({ - noise, + noise: newNoise, }); }} /> diff --git a/gui/src/components/Txt2Img.tsx b/gui/src/components/Txt2Img.tsx index 2bb32b59..1a85f9d7 100644 --- a/gui/src/components/Txt2Img.tsx +++ b/gui/src/components/Txt2Img.tsx @@ -23,9 +23,9 @@ export function Txt2Img(props: Txt2ImgProps) { const { config, model, platform } = props; async function generateImage() { - const upscale = state.getState().upscale; + const { txt2img, upscale } = state.getState(); const output = await client.txt2img({ - ...params, + ...txt2img, model, platform, }, upscale); @@ -40,7 +40,8 @@ export function Txt2Img(props: Txt2ImgProps) { }); const state = mustExist(useContext(StateContext)); - const params = useStore(state, (s) => s.txt2img); + const height = useStore(state, (s) => s.txt2img.height); + const width = useStore(state, (s) => s.txt2img.width); // eslint-disable-next-line @typescript-eslint/unbound-method const setTxt2Img = useStore(state, (s) => s.setTxt2Img); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -48,16 +49,14 @@ export function Txt2Img(props: Txt2ImgProps) { return - { - setTxt2Img(newParams); - }} /> + s.txt2img} onChange={setTxt2Img} /> { setTxt2Img({ width: value, @@ -69,7 +68,7 @@ export function Txt2Img(props: Txt2ImgProps) { min={config.height.min} max={config.height.max} step={config.height.step} - value={params.height} + value={height} onChange={(value) => { setTxt2Img({ height: value,