From 626ca18d7f9ad53317b86c1fc034fcc86b348b3d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 21 Jul 2023 22:11:45 -0500 Subject: [PATCH] include highres and upscale in params loading --- gui/src/client/types.ts | 16 ++++- gui/src/components/Profiles.tsx | 72 ++++++++++++--------- gui/src/components/control/ModelControl.tsx | 3 +- gui/src/components/tab/Img2Img.tsx | 42 +++++++++--- gui/src/components/tab/Inpaint.tsx | 57 ++++++++++------ gui/src/components/tab/Txt2Img.tsx | 36 +++++++++-- gui/src/state.ts | 6 +- gui/src/types.ts | 4 ++ 8 files changed, 164 insertions(+), 72 deletions(-) diff --git a/gui/src/client/types.ts b/gui/src/client/types.ts index ee5907e6..b9d22495 100644 --- a/gui/src/client/types.ts +++ b/gui/src/client/types.ts @@ -60,8 +60,8 @@ export interface BaseImgParams { * Parameters for txt2img requests. */ export interface Txt2ImgParams extends BaseImgParams { - width?: number; - height?: number; + width: number; + height: number; } /** @@ -71,7 +71,7 @@ export interface Img2ImgParams extends BaseImgParams { source: Blob; loopback: number; - sourceFilter?: string; + sourceFilter: string; strength: number; } @@ -267,6 +267,16 @@ export interface ImageResponseWithRetry { retry: RetryParams; } +export interface ImageMetadata { + highres: HighresParams; + outputs: string | Array; + params: Txt2ImgParams | Img2ImgParams | InpaintParams; + upscale: UpscaleParams; + + input_size: ImageSize; + size: ImageSize; +} + export interface ApiClient { extras(): Promise; diff --git a/gui/src/components/Profiles.tsx b/gui/src/components/Profiles.tsx index 0d775e9d..f692ebfe 100644 --- a/gui/src/components/Profiles.tsx +++ b/gui/src/components/Profiles.tsx @@ -1,4 +1,4 @@ -import { InvalidArgumentError, Maybe, doesExist, mustExist } from '@apextoaster/js-utils'; +import { doesExist, InvalidArgumentError, Maybe, mustExist } from '@apextoaster/js-utils'; import { Delete as DeleteIcon, Download, ImageSearch, Save as SaveIcon } from '@mui/icons-material'; import { Autocomplete, @@ -20,19 +20,20 @@ import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; -import { BaseImgParams, HighresParams, Txt2ImgParams, UpscaleParams } from '../client/types.js'; +import { BaseImgParams, HighresParams, ImageMetadata, Txt2ImgParams, UpscaleParams } from '../client/types.js'; import { StateContext } from '../state.js'; +import { DeepPartial } from '../types.js'; -const { useState, Fragment } = React; +const { useState } = React; export interface ProfilesProps { highres: HighresParams; params: BaseImgParams; upscale: UpscaleParams; - setHighres(params: HighresParams): void; - setParams(params: BaseImgParams): void; - setUpscale(params: UpscaleParams): void; + setHighres(params: Partial): void; + setParams(params: Partial): void; + setUpscale(params: Partial): void; } export function Profiles(props: ProfilesProps) { @@ -50,7 +51,7 @@ export function Profiles(props: ProfilesProps) { return option.name} @@ -59,7 +60,7 @@ export function Profiles(props: ProfilesProps) { { + { event.preventDefault(); removeProfile(option.name); }}> @@ -71,7 +72,7 @@ export function Profiles(props: ProfilesProps) { )} renderInput={(params) => ( - + - @@ -100,7 +101,7 @@ export function Profiles(props: ProfilesProps) { {t('profile.saveProfile')} setProfileName(event.target.value)} @@ -118,8 +119,8 @@ export function Profiles(props: ProfilesProps) { saveProfile({ params: props.params, name: profileName, - highResParams: props.highres, - upscaleParams: props.upscale, + highres: props.highres, + upscale: props.upscale, }); setDialogOpen(false); setProfileName(''); @@ -127,7 +128,7 @@ export function Profiles(props: ProfilesProps) { >{t('profile.save')} - ; } -export async function loadParamsFromFile(file: File): Promise> { +export async function loadParamsFromFile(file: File): Promise> { const parts = file.name.toLocaleLowerCase().split('.'); const ext = parts[parts.length - 1]; @@ -182,10 +188,8 @@ export async function loadParamsFromFile(file: File): Promise): void { + const dataStr = 'data:text/json;charset=utf-8,' + encodeURIComponent(JSON.stringify(data)); const elem = document.createElement('a'); elem.setAttribute('href', dataStr); elem.setAttribute('download', 'parameters.json'); @@ -194,7 +198,7 @@ export function downloadParamsAsFile(params: Txt2ImgParams): void { elem.remove(); } -export async function parseImageParams(file: File): Promise> { +export async function parseImageParams(file: File): Promise> { const tags = await ExifReader.load(file); // handle lowercase variation from my earlier mistakes @@ -234,8 +238,8 @@ export function decodeTag(tag: Maybe> { - const data = JSON.parse(json); +export async function parseJSONParams(json: string): Promise> { + const data = JSON.parse(json) as DeepPartial; const params: Partial = { ...data.params, }; @@ -246,7 +250,11 @@ export async function parseJSONParams(json: string): Promise> { +export async function parseAutoComment(comment: string): Promise> { if (isProbablyJSON(comment)) { return parseJSONParams(comment); } @@ -306,5 +314,7 @@ export async function parseAutoComment(comment: string): Promise client.restart()); diff --git a/gui/src/components/tab/Img2Img.tsx b/gui/src/components/tab/Img2Img.tsx index e8e504bb..72b0e244 100644 --- a/gui/src/components/tab/Img2Img.tsx +++ b/gui/src/components/tab/Img2Img.tsx @@ -1,20 +1,21 @@ import { doesExist, mustExist } from '@apextoaster/js-utils'; import { Box, Button, Stack } from '@mui/material'; +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import * as React from 'react'; import { useContext } from 'react'; import { useTranslation } from 'react-i18next'; -import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; import { useStore } from 'zustand'; +import { HighresParams, Img2ImgParams, ModelParams, UpscaleParams } from '../../client/types.js'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; +import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; +import { ModelControl } from '../control/ModelControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { ImageInput } from '../input/ImageInput.js'; import { NumericField } from '../input/NumericField.js'; import { QueryList } from '../input/QueryList.js'; -import { HighresControl } from '../control/HighresControl.js'; -import { ModelControl } from '../control/ModelControl.js'; import { Profiles } from '../Profiles.js'; export function Img2Img() { @@ -43,11 +44,11 @@ export function Img2Img() { }); const state = mustExist(useContext(StateContext)); - const model = useStore(state, (s) => s.img2imgModel); + const model = useStore(state, selectModel); const source = useStore(state, (s) => s.img2img.source); - const img2img = useStore(state, (s) => s.img2img); - const highres = useStore(state, (s) => s.img2imgHighres); - const upscale = useStore(state, (s) => s.img2imgUpscale); + const img2img = useStore(state, selectParams); + const highres = useStore(state, selectHighres); + const upscale = useStore(state, selectUpscale); // eslint-disable-next-line @typescript-eslint/unbound-method const setImg2Img = useStore(state, (s) => s.setImg2Img); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -62,7 +63,14 @@ export function Img2Img() { return - + ; } + +export function selectModel(state: OnnxState): ModelParams { + return state.img2imgModel; +} + +export function selectParams(state: OnnxState): TabState { + return state.img2img; +} + +export function selectHighres(state: OnnxState): HighresParams { + return state.img2imgHighres; +} + +export function selectUpscale(state: OnnxState): UpscaleParams { + return state.img2imgUpscale; +} diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index 75a40f63..f578c3ab 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next'; import { useStore } from 'zustand'; import { IMAGE_FILTER, STALE_TIME } from '../../config.js'; -import { ClientContext, ConfigContext, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; import { ModelControl } from '../control/ModelControl.js'; @@ -18,6 +18,7 @@ import { MaskCanvas } from '../input/MaskCanvas.js'; import { NumericField } from '../input/NumericField.js'; import { QueryList } from '../input/QueryList.js'; import { Profiles } from '../Profiles.js'; +import { ModelParams, InpaintParams, HighresParams, UpscaleParams } from '../../client/types.js'; export function Inpaint() { const { params } = mustExist(useContext(ConfigContext)); @@ -35,16 +36,16 @@ export function Inpaint() { const { image, retry } = await client.outpaint(model, { ...inpaint, ...outpaint, - mask: mustExist(mask), - source: mustExist(source), + mask: mustExist(inpaint.mask), + source: mustExist(inpaint.source), }, upscale, highres); pushHistory(image, retry); } else { const { image, retry } = await client.inpaint(model, { ...inpaint, - mask: mustExist(mask), - source: mustExist(source), + mask: mustExist(inpaint.mask), + source: mustExist(inpaint.source), }, upscale, highres); pushHistory(image, retry); @@ -52,7 +53,7 @@ export function Inpaint() { } function preventInpaint(): boolean { - return doesExist(source) === false || doesExist(mask) === false; + return doesExist(inpaint.source) === false || doesExist(inpaint.mask) === false; } function supportsInpaint(): boolean { @@ -60,15 +61,12 @@ export function Inpaint() { } const state = mustExist(useContext(StateContext)); - const mask = useStore(state, (s) => s.inpaint.mask); - const source = useStore(state, (s) => s.inpaint.source); - const inpaint = useStore(state, (s) => s.inpaint); + const inpaint = useStore(state, selectParams); + const highres = useStore(state, selectHighres); + const model = useStore(state, selectModel); + const upscale = useStore(state, selectUpscale); const outpaint = useStore(state, (s) => s.outpaint); - const brush = useStore(state, (s) => s.inpaintBrush); - const highres = useStore(state, (s) => s.inpaintHighres); - const model = useStore(state, (s) => s.inpaintModel); - const upscale = useStore(state, (s) => s.inpaintUpscale); // eslint-disable-next-line @typescript-eslint/unbound-method const setInpaint = useStore(state, (s) => s.setInpaint); @@ -100,12 +98,19 @@ export function Inpaint() { return - + {renderBanner()} { @@ -116,7 +121,7 @@ export function Inpaint() { /> { @@ -127,8 +132,8 @@ export function Inpaint() { /> { setInpaint({ mask: file, @@ -232,3 +237,19 @@ export function Inpaint() { ; } + +export function selectModel(state: OnnxState): ModelParams { + return state.inpaintModel; +} + +export function selectParams(state: OnnxState): TabState { + return state.inpaint; +} + +export function selectHighres(state: OnnxState): HighresParams { + return state.inpaintHighres; +} + +export function selectUpscale(state: OnnxState): UpscaleParams { + return state.inpaintUpscale; +} diff --git a/gui/src/components/tab/Txt2Img.tsx b/gui/src/components/tab/Txt2Img.tsx index b60c71af..3ad2133e 100644 --- a/gui/src/components/tab/Txt2Img.tsx +++ b/gui/src/components/tab/Txt2Img.tsx @@ -6,13 +6,14 @@ import { useTranslation } from 'react-i18next'; import { useMutation, useQueryClient } from '@tanstack/react-query'; import { useStore } from 'zustand'; -import { ClientContext, ConfigContext, StateContext } from '../../state.js'; +import { ClientContext, ConfigContext, OnnxState, StateContext, TabState } from '../../state.js'; import { HighresControl } from '../control/HighresControl.js'; import { ImageControl } from '../control/ImageControl.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { NumericField } from '../input/NumericField.js'; import { ModelControl } from '../control/ModelControl.js'; import { Profiles } from '../Profiles.js'; +import { HighresParams, ModelParams, Txt2ImgParams, UpscaleParams } from '../../client/types.js'; export function Txt2Img() { const { params } = mustExist(useContext(ConfigContext)); @@ -30,10 +31,10 @@ export function Txt2Img() { }); const state = mustExist(useContext(StateContext)); - const txt2img = useStore(state, (s) => s.txt2img); - const model = useStore(state, (s) => s.txt2imgModel); - const highres = useStore(state, (s) => s.txt2imgHighres); - const upscale = useStore(state, (s) => s.txt2imgUpscale); + const txt2img = useStore(state, selectParams); + const model = useStore(state, selectModel); + const highres = useStore(state, selectHighres); + const upscale = useStore(state, selectUpscale); // eslint-disable-next-line @typescript-eslint/unbound-method const setParams = useStore(state, (s) => s.setTxt2Img); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -48,7 +49,14 @@ export function Txt2Img() { return - + s.txt2img} onChange={setParams} /> @@ -86,3 +94,19 @@ export function Txt2Img() { ; } + +export function selectModel(state: OnnxState): ModelParams { + return state.txt2imgModel; +} + +export function selectParams(state: OnnxState): TabState { + return state.txt2img; +} + +export function selectHighres(state: OnnxState): HighresParams { + return state.txt2imgHighres; +} + +export function selectUpscale(state: OnnxState): UpscaleParams { + return state.txt2imgUpscale; +} diff --git a/gui/src/state.ts b/gui/src/state.ts index ec64096b..c05eb81c 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -43,9 +43,9 @@ interface HistoryItem { interface ProfileItem { name: string; - params: Txt2ImgParams; - highResParams?: Maybe; - upscaleParams?: Maybe; + params: BaseImgParams | Txt2ImgParams; + highres?: Maybe; + upscale?: Maybe; } interface DefaultSlice { diff --git a/gui/src/types.ts b/gui/src/types.ts index 8d1907b6..ccdd1a57 100644 --- a/gui/src/types.ts +++ b/gui/src/types.ts @@ -72,3 +72,7 @@ export interface ExtrasFile { networks: Array; sources: Array; } + +export type DeepPartial = T extends object ? { + [P in keyof T]?: DeepPartial; +} : T;