diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 6bed1e46..1271caab 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -7,6 +7,7 @@ import { ApiClient, BaseImgParams, BlendParams, + ChainPipeline, FilterResponse, HighresParams, ImageResponse, @@ -430,6 +431,16 @@ export function makeClient(root: string, token: Maybe = undefined, f = f } }; }, + async chain(chain: ChainPipeline): Promise { + const url = makeApiUrl(root, 'chain'); + const body = JSON.stringify(chain); + + // eslint-disable-next-line no-return-await + return await parseRequest(url, { + body, + method: 'POST', + }); + }, async ready(key: string): Promise { const path = makeApiUrl(root, 'ready'); path.searchParams.append('output', key); diff --git a/gui/src/client/local.ts b/gui/src/client/local.ts index 97f785a8..273e5168 100644 --- a/gui/src/client/local.ts +++ b/gui/src/client/local.ts @@ -39,6 +39,9 @@ export const LOCAL_CLIENT = { async outpaint(model, params, upscale) { throw new NoServerError(); }, + async chain(chain) { + throw new NoServerError(); + }, async noises() { throw new NoServerError(); }, diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts index b9d22495..d1912d7b 100644 --- a/gui/src/client/types.ts +++ b/gui/src/client/types.ts @@ -162,6 +162,22 @@ export interface HighresParams { highresStrength: number; } +export interface Txt2ImgStage { + name: string; + type: 'source-txt2img'; + params: Txt2ImgParams; +} + +export interface Img2ImgStage { + name: string; + type: 'blend-img2img'; + params: Img2ImgParams; +} + +export interface ChainPipeline { + stages: Array; +} + /** * Output image data within the response. */ @@ -354,6 +370,8 @@ export interface ApiClient { */ blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise; + chain(chain: ChainPipeline): Promise; + /** * Check whether job has finished and its output is ready. */ diff --git a/gui/src/client/utils.ts b/gui/src/client/utils.ts new file mode 100644 index 00000000..d1d69141 --- /dev/null +++ b/gui/src/client/utils.ts @@ -0,0 +1,42 @@ +import { ChainPipeline, HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from './types.js'; + +export interface PipelineVariable { + parameter: 'prompt' | 'cfg' | 'seed' | 'steps'; + input: string; + values: Array; +} + +export interface PipelineGrid { + enabled: boolean; + columns: PipelineVariable; + rows: PipelineVariable; +} + +// eslint-disable-next-line max-params +export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline { + const pipeline: ChainPipeline = { + stages: [], + }; + + let i = 0; + + for (const column of grid.columns.values) { + for (const row of grid.rows.values) { + pipeline.stages.push({ + name: `cell-${i}`, + type: 'source-txt2img', + params: { + ...params, + [grid.columns.parameter]: column, + [grid.rows.parameter]: row, + }, + }); + + i += 1; + } + } + + // TODO: add final grid stage + + return pipeline; +} diff --git a/gui/src/components/card/ErrorCard.tsx b/gui/src/components/card/ErrorCard.tsx index fe683e67..9bb6e455 100644 --- a/gui/src/components/card/ErrorCard.tsx +++ b/gui/src/components/card/ErrorCard.tsx @@ -1,4 +1,4 @@ -import { mustExist } from '@apextoaster/js-utils'; +import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils'; import { Delete, Replay } from '@mui/icons-material'; import { Alert, Box, Card, CardContent, IconButton, Tooltip } from '@mui/material'; import { Stack } from '@mui/system'; @@ -15,7 +15,7 @@ import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../sta export interface ErrorCardProps { image: ImageResponse; ready: ReadyResponse; - retry: RetryParams; + retry: Maybe; } export function ErrorCard(props: ErrorCardProps) { @@ -30,8 +30,11 @@ export function ErrorCard(props: ErrorCardProps) { async function retryImage() { removeHistory(image); - const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); - pushHistory(nextImage, nextRetry); + + if (doesExist(retryParams)) { + const { image: nextImage, retry: nextRetry } = await client.retry(retryParams); + pushHistory(nextImage, nextRetry); + } } const retry = useMutation(retryImage); diff --git a/gui/src/components/control/VariableControl.tsx b/gui/src/components/control/VariableControl.tsx new file mode 100644 index 00000000..cd159954 --- /dev/null +++ b/gui/src/components/control/VariableControl.tsx @@ -0,0 +1,107 @@ +import { doesExist, mustExist } from '@apextoaster/js-utils'; +import { Checkbox, FormControl, InputLabel, MenuItem, Select, Stack, TextField } from '@mui/material'; +import * as React from 'react'; +import { useContext } from 'react'; +import { useStore } from 'zustand'; + +import { PipelineGrid } from '../../client/utils.js'; +import { OnnxState, StateContext } from '../../state.js'; + +export interface VariableControlProps { + selectGrid: (state: OnnxState) => PipelineGrid; + setGrid: (grid: Partial) => void; +} + +export type VariableKey = 'prompt' | 'steps' | 'seed'; + +export function VariableControl(props: VariableControlProps) { + const store = mustExist(useContext(StateContext)); + const grid = useStore(store, props.selectGrid); + + return + + Grid Mode + props.setGrid({ + enabled: grid.enabled === false, + })} /> + + + + Columns + + + props.setGrid({ + columns: { + parameter: grid.columns.parameter, + input: event.target.value, + values: rangeSplit(grid.columns.parameter, event.target.value), + }, + })} /> + + + + Rows + + + props.setGrid({ + rows: { + parameter: grid.rows.parameter, + input: event.target.value, + values: rangeSplit(grid.rows.parameter, event.target.value), + } + })} /> + + ; +} + +export function rangeSplit(parameter: string, value: string): Array { + // string values + if (parameter === 'prompt') { + return value.split('\n'); + } + + return value.split(',').map((it) => it.trim()).flatMap((it) => expandRanges(it)); +} + +export const EXPR_STRICT_NUMBER = /^[0-9]+$/; +export const EXPR_NUMBER_RANGE = /^([0-9]+)-([0-9]+)$/; + +export function expandRanges(range: string): Array { + if (EXPR_STRICT_NUMBER.test(range)) { + // entirely numeric, return without parsing + return [range]; + } + + if (EXPR_NUMBER_RANGE.test(range)) { + const match = EXPR_NUMBER_RANGE.exec(range); + if (doesExist(match)) { + const [_full, startStr, endStr] = Array.from(match); + const start = parseInt(startStr, 10); + const end = parseInt(endStr, 10); + + return new Array(end - start).fill(0).map((_value, idx) => (idx + start).toFixed(0)); + } + } + + return []; +} diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index 76e669ce..d5347926 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -15,15 +15,27 @@ import { ModelControl } from '../control/ModelControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { NumericField } from '../input/NumericField.js'; import { Profiles } from '../Profiles.js'; +import { VariableControl } from '../control/VariableControl.js'; +import { PipelineGrid, buildPipelineForTxt2ImgGrid } from '../../client/utils.js'; export function Txt2Img() { const { params } = mustExist(useContext(ConfigContext)); async function generateImage() { const state = store.getState(); - const { image, retry } = await client.txt2img(model, selectParams(state), selectUpscale(state), selectHighres(state)); + const grid = selectVariable(state); + const params2 = selectParams(state); + const upscale = selectUpscale(state); + const highres = selectHighres(state); - pushHistory(image, retry); + if (grid.enabled) { + const chain = buildPipelineForTxt2ImgGrid(grid, model, params2, upscale, highres); + const image = await client.chain(chain); + pushHistory(image); + } else { + const { image, retry } = await client.txt2img(model, params2, upscale, highres); + pushHistory(image, retry); + } } const client = mustExist(useContext(ClientContext)); @@ -33,7 +45,7 @@ export function Txt2Img() { }); const store = mustExist(useContext(StateContext)); - const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(store, selectActions, shallow); + const { pushHistory, setHighres, setModel, setParams, setUpscale, setVariable } = useStore(store, selectActions, shallow); const { height, width } = useStore(store, selectReactParams, shallow); const model = useStore(store, selectModel); @@ -79,6 +91,7 @@ export function Txt2Img() { +