1
0
Fork 0

add basic variables to txt2img tab

This commit is contained in:
Sean Sube 2023-09-10 20:59:33 -05:00
parent 1fb965633e
commit 9d4272eb09
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 239 additions and 9 deletions

View File

@ -7,6 +7,7 @@ import {
ApiClient,
BaseImgParams,
BlendParams,
ChainPipeline,
FilterResponse,
HighresParams,
ImageResponse,
@ -430,6 +431,16 @@ export function makeClient(root: string, token: Maybe<string> = undefined, f = f
}
};
},
async chain(chain: ChainPipeline): Promise<ImageResponse> {
const url = makeApiUrl(root, 'chain');
const body = JSON.stringify(chain);
// eslint-disable-next-line no-return-await
return await parseRequest(url, {
body,
method: 'POST',
});
},
async ready(key: string): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');
path.searchParams.append('output', key);

View File

@ -39,6 +39,9 @@ export const LOCAL_CLIENT = {
async outpaint(model, params, upscale) {
throw new NoServerError();
},
async chain(chain) {
throw new NoServerError();
},
async noises() {
throw new NoServerError();
},

View File

@ -162,6 +162,22 @@ export interface HighresParams {
highresStrength: number;
}
export interface Txt2ImgStage {
name: string;
type: 'source-txt2img';
params: Txt2ImgParams;
}
export interface Img2ImgStage {
name: string;
type: 'blend-img2img';
params: Img2ImgParams;
}
export interface ChainPipeline {
stages: Array<Txt2ImgStage | Img2ImgStage>;
}
/**
* Output image data within the response.
*/
@ -354,6 +370,8 @@ export interface ApiClient {
*/
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
chain(chain: ChainPipeline): Promise<ImageResponse>;
/**
* Check whether job has finished and its output is ready.
*/

42
gui/src/client/utils.ts Normal file
View File

@ -0,0 +1,42 @@
import { ChainPipeline, HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from './types.js';
export interface PipelineVariable {
parameter: 'prompt' | 'cfg' | 'seed' | 'steps';
input: string;
values: Array<string>;
}
export interface PipelineGrid {
enabled: boolean;
columns: PipelineVariable;
rows: PipelineVariable;
}
// eslint-disable-next-line max-params
export function buildPipelineForTxt2ImgGrid(grid: PipelineGrid, model: ModelParams, params: Txt2ImgParams, upscale?: UpscaleParams, highres?: HighresParams): ChainPipeline {
const pipeline: ChainPipeline = {
stages: [],
};
let i = 0;
for (const column of grid.columns.values) {
for (const row of grid.rows.values) {
pipeline.stages.push({
name: `cell-${i}`,
type: 'source-txt2img',
params: {
...params,
[grid.columns.parameter]: column,
[grid.rows.parameter]: row,
},
});
i += 1;
}
}
// TODO: add final grid stage
return pipeline;
}

View File

@ -1,4 +1,4 @@
import { mustExist } from '@apextoaster/js-utils';
import { Maybe, doesExist, mustExist } from '@apextoaster/js-utils';
import { Delete, Replay } from '@mui/icons-material';
import { Alert, Box, Card, CardContent, IconButton, Tooltip } from '@mui/material';
import { Stack } from '@mui/system';
@ -15,7 +15,7 @@ import { ClientContext, ConfigContext, OnnxState, StateContext } from '../../sta
export interface ErrorCardProps {
image: ImageResponse;
ready: ReadyResponse;
retry: RetryParams;
retry: Maybe<RetryParams>;
}
export function ErrorCard(props: ErrorCardProps) {
@ -30,8 +30,11 @@ export function ErrorCard(props: ErrorCardProps) {
async function retryImage() {
removeHistory(image);
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry);
if (doesExist(retryParams)) {
const { image: nextImage, retry: nextRetry } = await client.retry(retryParams);
pushHistory(nextImage, nextRetry);
}
}
const retry = useMutation(retryImage);

View File

@ -0,0 +1,107 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Checkbox, FormControl, InputLabel, MenuItem, Select, Stack, TextField } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useStore } from 'zustand';
import { PipelineGrid } from '../../client/utils.js';
import { OnnxState, StateContext } from '../../state.js';
export interface VariableControlProps {
selectGrid: (state: OnnxState) => PipelineGrid;
setGrid: (grid: Partial<PipelineGrid>) => void;
}
export type VariableKey = 'prompt' | 'steps' | 'seed';
export function VariableControl(props: VariableControlProps) {
const store = mustExist(useContext(StateContext));
const grid = useStore(store, props.selectGrid);
return <Stack direction='column' spacing={2}>
<Stack direction='row' spacing={2}>
<InputLabel>Grid Mode</InputLabel>
<Checkbox checked={grid.enabled} onChange={() => props.setGrid({
enabled: grid.enabled === false,
})} />
</Stack>
<Stack direction='row' spacing={2}>
<FormControl>
<InputLabel id='TODO'>Columns</InputLabel>
<Select onChange={(event) => props.setGrid({
columns: {
parameter: event.target.value as VariableKey,
input: '',
values: [],
},
})} value={grid.columns.parameter}>
<MenuItem key='prompt' value='prompt'>Prompt</MenuItem>
<MenuItem key='seed' value='seed'>Seed</MenuItem>
<MenuItem key='steps' value='steps'>Steps</MenuItem>
</Select>
</FormControl>
<TextField label={grid.columns.parameter} value={grid.columns.input} onChange={(event) => props.setGrid({
columns: {
parameter: grid.columns.parameter,
input: event.target.value,
values: rangeSplit(grid.columns.parameter, event.target.value),
},
})} />
</Stack>
<Stack direction='row' spacing={2}>
<FormControl>
<InputLabel id='TODO'>Rows</InputLabel>
<Select onChange={(event) => props.setGrid({
rows: {
parameter: event.target.value as VariableKey,
input: '',
values: [],
}
})} value={grid.rows.parameter}>
<MenuItem key='prompt' value='prompt'>Prompt</MenuItem>
<MenuItem key='seed' value='seed'>Seed</MenuItem>
<MenuItem key='steps' value='steps'>Steps</MenuItem>
</Select>
</FormControl>
<TextField label={grid.rows.parameter} value={grid.rows.input} onChange={(event) => props.setGrid({
rows: {
parameter: grid.rows.parameter,
input: event.target.value,
values: rangeSplit(grid.rows.parameter, event.target.value),
}
})} />
</Stack>
</Stack>;
}
export function rangeSplit(parameter: string, value: string): Array<string> {
// string values
if (parameter === 'prompt') {
return value.split('\n');
}
return value.split(',').map((it) => it.trim()).flatMap((it) => expandRanges(it));
}
export const EXPR_STRICT_NUMBER = /^[0-9]+$/;
export const EXPR_NUMBER_RANGE = /^([0-9]+)-([0-9]+)$/;
export function expandRanges(range: string): Array<string> {
if (EXPR_STRICT_NUMBER.test(range)) {
// entirely numeric, return without parsing
return [range];
}
if (EXPR_NUMBER_RANGE.test(range)) {
const match = EXPR_NUMBER_RANGE.exec(range);
if (doesExist(match)) {
const [_full, startStr, endStr] = Array.from(match);
const start = parseInt(startStr, 10);
const end = parseInt(endStr, 10);
return new Array(end - start).fill(0).map((_value, idx) => (idx + start).toFixed(0));
}
}
return [];
}

View File

@ -15,15 +15,27 @@ import { ModelControl } from '../control/ModelControl.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { NumericField } from '../input/NumericField.js';
import { Profiles } from '../Profiles.js';
import { VariableControl } from '../control/VariableControl.js';
import { PipelineGrid, buildPipelineForTxt2ImgGrid } from '../../client/utils.js';
export function Txt2Img() {
const { params } = mustExist(useContext(ConfigContext));
async function generateImage() {
const state = store.getState();
const { image, retry } = await client.txt2img(model, selectParams(state), selectUpscale(state), selectHighres(state));
const grid = selectVariable(state);
const params2 = selectParams(state);
const upscale = selectUpscale(state);
const highres = selectHighres(state);
pushHistory(image, retry);
if (grid.enabled) {
const chain = buildPipelineForTxt2ImgGrid(grid, model, params2, upscale, highres);
const image = await client.chain(chain);
pushHistory(image);
} else {
const { image, retry } = await client.txt2img(model, params2, upscale, highres);
pushHistory(image, retry);
}
}
const client = mustExist(useContext(ClientContext));
@ -33,7 +45,7 @@ export function Txt2Img() {
});
const store = mustExist(useContext(StateContext));
const { pushHistory, setHighres, setModel, setParams, setUpscale } = useStore(store, selectActions, shallow);
const { pushHistory, setHighres, setModel, setParams, setUpscale, setVariable } = useStore(store, selectActions, shallow);
const { height, width } = useStore(store, selectReactParams, shallow);
const model = useStore(store, selectModel);
@ -79,6 +91,7 @@ export function Txt2Img() {
</Stack>
<HighresControl selectHighres={selectHighres} setHighres={setHighres} />
<UpscaleControl selectUpscale={selectUpscale} setUpscale={setUpscale} />
<VariableControl selectGrid={selectVariable} setGrid={setVariable} />
<Button
variant='contained'
onClick={() => generate.mutate()}
@ -99,6 +112,8 @@ export function selectActions(state: OnnxState) {
setParams: state.setTxt2Img,
// eslint-disable-next-line @typescript-eslint/unbound-method
setUpscale: state.setTxt2ImgUpscale,
// eslint-disable-next-line @typescript-eslint/unbound-method
setVariable: state.setTxt2ImgVariable,
};
}
@ -124,3 +139,7 @@ export function selectHighres(state: OnnxState): HighresParams {
export function selectUpscale(state: OnnxState): UpscaleParams {
return state.txt2imgUpscale;
}
export function selectVariable(state: OnnxState): PipelineGrid {
return state.txt2imgVariable;
}

View File

@ -25,6 +25,7 @@ import {
} from './client/types.js';
import { Config, ConfigFiles, ConfigState, ServerParams } from './config.js';
import { CorrectionModel, DiffusionModel, ExtraNetwork, ExtraSource, ExtrasFile, UpscalingModel } from './types.js';
import { PipelineGrid } from './client/utils.js';
export const MISSING_INDEX = -1;
@ -38,7 +39,7 @@ export type TabState<TabParams> = ConfigFiles<Required<TabParams>> & ConfigState
export interface HistoryItem {
image: ImageResponse;
ready: Maybe<ReadyResponse>;
retry: RetryParams;
retry: Maybe<RetryParams>;
}
export interface ProfileItem {
@ -60,7 +61,7 @@ interface HistorySlice {
history: Array<HistoryItem>;
limit: number;
pushHistory(image: ImageResponse, retry: RetryParams): void;
pushHistory(image: ImageResponse, retry?: RetryParams): void;
removeHistory(image: ImageResponse): void;
setLimit(limit: number): void;
setReady(image: ImageResponse, ready: ReadyResponse): void;
@ -90,6 +91,7 @@ interface Txt2ImgSlice {
txt2imgModel: ModelParams;
txt2imgHighres: HighresParams;
txt2imgUpscale: UpscaleParams;
txt2imgVariable: PipelineGrid;
resetTxt2Img(): void;
@ -97,6 +99,7 @@ interface Txt2ImgSlice {
setTxt2ImgModel(params: Partial<ModelParams>): void;
setTxt2ImgHighres(params: Partial<HighresParams>): void;
setTxt2ImgUpscale(params: Partial<UpscaleParams>): void;
setTxt2ImgVariable(params: Partial<PipelineGrid>): void;
}
interface Img2ImgSlice {
@ -305,6 +308,19 @@ export function createStateSlices(server: ServerParams) {
scale: server.scale.default,
upscaleOrder: server.upscaleOrder.default,
};
const defaultGrid: PipelineGrid = {
enabled: false,
columns: {
input: '',
parameter: 'seed',
values: [],
},
rows: {
input: '',
parameter: 'seed',
values: [],
},
};
const createTxt2ImgSlice: Slice<Txt2ImgSlice> = (set) => ({
txt2img: {
@ -321,6 +337,9 @@ export function createStateSlices(server: ServerParams) {
txt2imgUpscale: {
...defaultUpscale,
},
txt2imgVariable: {
...defaultGrid,
},
setTxt2Img(params) {
set((prev) => ({
txt2img: {
@ -353,6 +372,14 @@ export function createStateSlices(server: ServerParams) {
},
}));
},
setTxt2ImgVariable(params) {
set((prev) => ({
txt2imgVariable: {
...prev.txt2imgVariable,
...params,
},
}));
},
resetTxt2Img() {
set({
txt2img: {