1
0
Fork 0

feat(gui): save source and mask images while changing tabs

This commit is contained in:
Sean Sube 2023-01-13 14:39:07 -06:00
parent e872eeacec
commit 4e82241491
6 changed files with 159 additions and 108 deletions

View File

@ -1,32 +1,33 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { PhotoCamera } from '@mui/icons-material'; import { PhotoCamera } from '@mui/icons-material';
import { Button, Stack } from '@mui/material'; import { Button, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
const { useState } = React;
export interface ImageInputProps { export interface ImageInputProps {
filter: string; filter: string;
hidden?: boolean; hidden?: boolean;
image?: Maybe<Blob>;
label: string; label: string;
onChange: (file: File) => void; onChange: (file: File) => void;
renderImage?: (image: string | undefined) => React.ReactNode; renderImage?: (image: Maybe<Blob>) => React.ReactNode;
} }
export function ImageInput(props: ImageInputProps) { export function ImageInput(props: ImageInputProps) {
const [image, setImage] = useState<string>();
function renderImage() { function renderImage() {
if (mustDefault(props.hidden, false)) { if (mustDefault(props.hidden, false)) {
return undefined; return undefined;
} }
if (doesExist(props.renderImage)) { if (doesExist(props.renderImage)) {
return props.renderImage(image); return props.renderImage(props.image);
} }
return <img src={image} />; if (doesExist(props.image)) {
return <img src={URL.createObjectURL(props.image)} />;
} else {
return <div>Please select an image.</div>;
}
} }
return <Stack direction='row' spacing={2}> return <Stack direction='row' spacing={2}>
@ -41,11 +42,6 @@ export function ImageInput(props: ImageInputProps) {
if (doesExist(files) && files.length > 0) { if (doesExist(files) && files.length > 0) {
const file = mustExist(files[0]); const file = mustExist(files[0]);
if (doesExist(image)) {
URL.revokeObjectURL(image);
}
setImage(URL.createObjectURL(file));
props.onChange(file); props.onChange(file);
} }
}} }}

View File

@ -10,7 +10,7 @@ import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js'; import { ImageInput } from './ImageInput.js';
import { NumericField } from './NumericField.js'; import { NumericField } from './NumericField.js';
const { useContext, useState } = React; const { useContext } = React;
export interface Img2ImgProps { export interface Img2ImgProps {
config: ConfigParams; config: ConfigParams;
@ -27,7 +27,7 @@ export function Img2Img(props: Img2ImgProps) {
...params, ...params,
model, model,
platform, platform,
source: mustExist(source), // TODO: show an error if this doesn't exist source: mustExist(params.source), // TODO: show an error if this doesn't exist
}); });
setLoading(output); setLoading(output);
@ -46,11 +46,13 @@ export function Img2Img(props: Img2ImgProps) {
// eslint-disable-next-line @typescript-eslint/unbound-method // eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading); const setLoading = useStore(state, (s) => s.setLoading);
const [source, setSource] = useState<File>();
return <Box> return <Box>
<Stack spacing={2}> <Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={setSource} /> <ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
setImg2Img({
source: file,
});
}} />
<ImageControl config={config} params={params} onChange={(newParams) => { <ImageControl config={config} params={params} onChange={(newParams) => {
setImg2Img(newParams); setImg2Img(newParams);
}} /> }} />

View File

@ -1,11 +1,13 @@
import { doesExist, mustExist } from '@apextoaster/js-utils'; import { doesExist, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material'; import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Box, Button, Stack } from '@mui/material'; import { Box, Button, Stack } from '@mui/material';
import { throttle } from 'lodash';
import * as React from 'react'; import * as React from 'react';
import { useCallback } from 'react';
import { useMutation, useQueryClient } from 'react-query'; import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js'; import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER, SAVE_TIME } from '../config.js';
import { ClientContext, StateContext } from '../state.js'; import { ClientContext, StateContext } from '../state.js';
import { ImageControl } from './ImageControl.js'; import { ImageControl } from './ImageControl.js';
import { ImageInput } from './ImageInput.js'; import { ImageInput } from './ImageInput.js';
@ -67,51 +69,49 @@ export function Inpaint(props: InpaintProps) {
const { config, model, platform } = props; const { config, model, platform } = props;
const client = mustExist(useContext(ClientContext)); const client = mustExist(useContext(ClientContext));
async function uploadSource() { function drawSource(file: Blob): Promise<void> {
const canvas = mustExist(canvasRef.current); const image = new Image();
return new Promise<void>((res, rej) => { return new Promise<void>((res, _rej) => {
canvas.toBlob((blob) => { image.onload = () => {
client.inpaint({ const canvas = mustExist(canvasRef.current);
...params, const ctx = mustExist(canvas.getContext('2d'));
model, ctx.drawImage(image, 0, 0);
platform, URL.revokeObjectURL(src);
mask: mustExist(blob),
source: mustExist(source), // putting a save call here has a tendency to go into an infinite loop
}).then((output) => { res();
setLoading(output); };
res();
}).catch((err) => rej(err)); const src = URL.createObjectURL(file);
}); image.src = src;
}); });
} }
function drawSource(file: File) { function saveMask(): Promise<void> {
const image = new Image(); return new Promise((res, _rej) => {
image.onload = () => { if (doesExist(canvasRef.current)) {
const canvas = mustExist(canvasRef.current); canvasRef.current.toBlob((blob) => {
const ctx = mustExist(canvas.getContext('2d')); setInpaint({
ctx.drawImage(image, 0, 0); mask: mustExist(blob),
URL.revokeObjectURL(src); });
}; res();
});
const src = URL.createObjectURL(file); } else {
image.src = src; res();
}
});
} }
function changeMask(file: File) { async function uploadSource(): Promise<void> {
setMask(file); const output = await client.inpaint({
...params,
model,
platform,
mask: mustExist(params.mask),
source: mustExist(params.source),
});
// always draw the mask to the canvas setLoading(output);
drawSource(file);
}
function changeSource(file: File) {
setSource(file);
// draw the source to the canvas if the mask has not been set
if (doesExist(mask) === false) {
drawSource(file);
}
} }
function floodMask(flooder: (n: number) => number) { function floodMask(flooder: (n: number) => number) {
@ -133,46 +133,11 @@ export function Inpaint(props: InpaintProps) {
} }
ctx.putImageData(image, 0, 0); ctx.putImageData(image, 0, 0);
// eslint-disable-next-line @typescript-eslint/no-floating-promises
save();
} }
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);
// painting state
const [clicks, setClicks] = useState<Array<Point>>([]);
const [painting, setPainting] = useState(false);
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
// image state
const [mask, setMask] = useState<File>();
const [source, setSource] = useState<File>();
useEffect(() => {
const canvas = mustExist(canvasRef.current);
const ctx = mustExist(canvas.getContext('2d'));
ctx.fillStyle = grayToRGB(brushColor);
for (const click of clicks) {
ctx.beginPath();
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
ctx.fill();
}
clicks.length = 0;
}, [clicks.length]);
function renderCanvas() { function renderCanvas() {
return <canvas return <canvas
ref={canvasRef} ref={canvasRef}
@ -217,10 +182,74 @@ export function Inpaint(props: InpaintProps) {
/>; />;
} }
const save = useCallback(throttle(saveMask, SAVE_TIME), []);
const state = mustExist(useContext(StateContext));
const params = useStore(state, (s) => s.inpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setInpaint = useStore(state, (s) => s.setInpaint);
// eslint-disable-next-line @typescript-eslint/unbound-method
const setLoading = useStore(state, (s) => s.setLoading);
const query = useQueryClient();
const upload = useMutation(uploadSource, {
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
});
// eslint-disable-next-line no-null/no-null
const canvasRef = useRef<HTMLCanvasElement>(null);
// painting state
const [clicks, setClicks] = useState<Array<Point>>([]);
const [painting, setPainting] = useState(false);
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
useEffect(function changeMask() {
// always draw the new mask to the canvas
if (doesExist(params.mask)) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
drawSource(params.mask);
}
}, [params.mask]);
useEffect(function changeSource() {
// draw the source to the canvas if the mask has not been set
if (doesExist(params.source) && doesExist(params.mask) === false) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
drawSource(params.source);
}
}, [params.source]);
useEffect(() => {
// including clicks.length prevents the initial render from saving a blank canvas
if (doesExist(canvasRef.current) && clicks.length > 0) {
const ctx = mustExist(canvasRef.current.getContext('2d'));
ctx.fillStyle = grayToRGB(brushColor);
for (const click of clicks) {
ctx.beginPath();
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
ctx.fill();
}
clicks.length = 0;
// eslint-disable-next-line @typescript-eslint/no-floating-promises
save();
}
}, [clicks.length]);
return <Box> return <Box>
<Stack spacing={2}> <Stack spacing={2}>
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={changeSource} /> <ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
<ImageInput filter={IMAGE_FILTER} label='Mask' onChange={changeMask} renderImage={renderCanvas} /> setInpaint({
source: file,
});
}} />
<ImageInput filter={IMAGE_FILTER} image={params.mask} label='Mask' onChange={(file) => {
setInpaint({
mask: file,
});
}} renderImage={renderCanvas} />
<Stack direction='row' spacing={4}> <Stack direction='row' spacing={4}>
<NumericField <NumericField
decimal decimal

View File

@ -1,3 +1,5 @@
import { Maybe } from '@apextoaster/js-utils';
import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js'; import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js';
export interface ConfigNumber { export interface ConfigNumber {
@ -12,16 +14,20 @@ export interface ConfigString {
keys: Array<string>; keys: Array<string>;
} }
export type KeyFilter<T extends object> = { export type KeyFilter<T extends object, TValid = number | string> = {
[K in keyof T]: T[K] extends number ? K : T[K] extends string ? K : never; [K in keyof T]: T[K] extends TValid ? K : never;
}[keyof T]; }[keyof T];
export type ConfigFiles<T extends object> = {
[K in KeyFilter<T, Blob | File>]: Maybe<T[K]>;
};
export type ConfigRanges<T extends object> = { export type ConfigRanges<T extends object> = {
[K in KeyFilter<T>]: T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never; [K in KeyFilter<T>]: T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never;
}; };
export type ConfigState<T extends object> = { export type ConfigState<T extends object, TValid = number | string> = {
[K in KeyFilter<T>]: T[K] extends number ? number : T[K] extends string ? string : never; [K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
}; };
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams>>; export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams>>;
@ -45,6 +51,7 @@ export const DEFAULT_BRUSH = {
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png'; export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
export const STALE_TIME = 300_000; // 5 minutes export const STALE_TIME = 300_000; // 5 minutes
export const POLL_TIME = 5_000; // 5 seconds export const POLL_TIME = 5_000; // 5 seconds
export const SAVE_TIME = 5_000; // 5 seconds
export async function loadConfig(): Promise<Config> { export async function loadConfig(): Promise<Config> {
const configPath = new URL('./config.json', window.origin); const configPath = new URL('./config.json', window.origin);

View File

@ -12,8 +12,6 @@ import { OnnxWeb } from './components/OnnxWeb.js';
import { loadConfig } from './config.js'; import { loadConfig } from './config.js';
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js'; import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';
const { createContext } = React;
export async function main() { export async function main() {
// load config from GUI server // load config from GUI server
const config = await loadConfig(); const config = await loadConfig();
@ -41,6 +39,20 @@ export async function main() {
...createDefaultSlice(...slice), ...createDefaultSlice(...slice),
}), { }), {
name: 'onnx-web', name: 'onnx-web',
partialize(s) {
return {
...s,
img2img: {
...s.img2img,
source: undefined,
},
inpaint: {
...s.inpaint,
mask: undefined,
source: undefined,
},
};
},
storage: createJSONStorage(() => localStorage), storage: createJSONStorage(() => localStorage),
version: 3, version: 3,
})); }));

View File

@ -1,3 +1,4 @@
/* eslint-disable no-null/no-null */
import { Maybe } from '@apextoaster/js-utils'; import { Maybe } from '@apextoaster/js-utils';
import { createContext } from 'react'; import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand'; import { StateCreator, StoreApi } from 'zustand';
@ -11,9 +12,9 @@ import {
paramsFromConfig, paramsFromConfig,
Txt2ImgParams, Txt2ImgParams,
} from './api/client.js'; } from './api/client.js';
import { ConfigParams, ConfigState } from './config.js'; import { ConfigFiles, ConfigParams, ConfigState } from './config.js';
type TabState<TabParams extends BaseImgParams> = ConfigState<Required<TabParams>>; type TabState<TabParams extends BaseImgParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
interface Txt2ImgSlice { interface Txt2ImgSlice {
txt2img: TabState<Txt2ImgParams>; txt2img: TabState<Txt2ImgParams>;
@ -86,6 +87,7 @@ export function createStateSlices(base: ConfigParams) {
const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({ const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
img2img: { img2img: {
...defaults, ...defaults,
source: null,
strength: base.strength.default, strength: base.strength.default,
}, },
setImg2Img(params) { setImg2Img(params) {
@ -100,6 +102,7 @@ export function createStateSlices(base: ConfigParams) {
set({ set({
img2img: { img2img: {
...defaults, ...defaults,
source: null,
strength: base.strength.default, strength: base.strength.default,
}, },
}); });
@ -109,6 +112,8 @@ export function createStateSlices(base: ConfigParams) {
const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({ const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
inpaint: { inpaint: {
...defaults, ...defaults,
mask: null,
source: null,
}, },
setInpaint(params) { setInpaint(params) {
set((prev) => ({ set((prev) => ({
@ -122,6 +127,8 @@ export function createStateSlices(base: ConfigParams) {
set({ set({
inpaint: { inpaint: {
...defaults, ...defaults,
mask: null,
source: null,
}, },
}); });
}, },
@ -130,7 +137,6 @@ export function createStateSlices(base: ConfigParams) {
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({ const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
history: [], history: [],
limit: 4, limit: 4,
// eslint-disable-next-line no-null/no-null
loading: null, loading: null,
pushHistory(image) { pushHistory(image) {
set((prev) => ({ set((prev) => ({
@ -139,7 +145,6 @@ export function createStateSlices(base: ConfigParams) {
image, image,
...prev.history, ...prev.history,
], ],
// eslint-disable-next-line no-null/no-null
loading: null, loading: null,
})); }));
}, },