feat(gui): save source and mask images while changing tabs
This commit is contained in:
parent
e872eeacec
commit
4e82241491
|
@ -1,32 +1,33 @@
|
|||
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { PhotoCamera } from '@mui/icons-material';
|
||||
import { Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
|
||||
const { useState } = React;
|
||||
|
||||
export interface ImageInputProps {
|
||||
filter: string;
|
||||
hidden?: boolean;
|
||||
image?: Maybe<Blob>;
|
||||
label: string;
|
||||
|
||||
onChange: (file: File) => void;
|
||||
renderImage?: (image: string | undefined) => React.ReactNode;
|
||||
renderImage?: (image: Maybe<Blob>) => React.ReactNode;
|
||||
}
|
||||
|
||||
export function ImageInput(props: ImageInputProps) {
|
||||
const [image, setImage] = useState<string>();
|
||||
|
||||
function renderImage() {
|
||||
if (mustDefault(props.hidden, false)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (doesExist(props.renderImage)) {
|
||||
return props.renderImage(image);
|
||||
return props.renderImage(props.image);
|
||||
}
|
||||
|
||||
return <img src={image} />;
|
||||
if (doesExist(props.image)) {
|
||||
return <img src={URL.createObjectURL(props.image)} />;
|
||||
} else {
|
||||
return <div>Please select an image.</div>;
|
||||
}
|
||||
}
|
||||
|
||||
return <Stack direction='row' spacing={2}>
|
||||
|
@ -41,11 +42,6 @@ export function ImageInput(props: ImageInputProps) {
|
|||
if (doesExist(files) && files.length > 0) {
|
||||
const file = mustExist(files[0]);
|
||||
|
||||
if (doesExist(image)) {
|
||||
URL.revokeObjectURL(image);
|
||||
}
|
||||
|
||||
setImage(URL.createObjectURL(file));
|
||||
props.onChange(file);
|
||||
}
|
||||
}}
|
||||
|
|
|
@ -10,7 +10,7 @@ import { ImageControl } from './ImageControl.js';
|
|||
import { ImageInput } from './ImageInput.js';
|
||||
import { NumericField } from './NumericField.js';
|
||||
|
||||
const { useContext, useState } = React;
|
||||
const { useContext } = React;
|
||||
|
||||
export interface Img2ImgProps {
|
||||
config: ConfigParams;
|
||||
|
@ -27,7 +27,7 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
...params,
|
||||
model,
|
||||
platform,
|
||||
source: mustExist(source), // TODO: show an error if this doesn't exist
|
||||
source: mustExist(params.source), // TODO: show an error if this doesn't exist
|
||||
});
|
||||
|
||||
setLoading(output);
|
||||
|
@ -46,11 +46,13 @@ export function Img2Img(props: Img2ImgProps) {
|
|||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setLoading = useStore(state, (s) => s.setLoading);
|
||||
|
||||
const [source, setSource] = useState<File>();
|
||||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={setSource} />
|
||||
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
|
||||
setImg2Img({
|
||||
source: file,
|
||||
});
|
||||
}} />
|
||||
<ImageControl config={config} params={params} onChange={(newParams) => {
|
||||
setImg2Img(newParams);
|
||||
}} />
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import { throttle } from 'lodash';
|
||||
import * as React from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import { useMutation, useQueryClient } from 'react-query';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER } from '../config.js';
|
||||
import { ConfigParams, DEFAULT_BRUSH, IMAGE_FILTER, SAVE_TIME } from '../config.js';
|
||||
import { ClientContext, StateContext } from '../state.js';
|
||||
import { ImageControl } from './ImageControl.js';
|
||||
import { ImageInput } from './ImageInput.js';
|
||||
|
@ -67,51 +69,49 @@ export function Inpaint(props: InpaintProps) {
|
|||
const { config, model, platform } = props;
|
||||
const client = mustExist(useContext(ClientContext));
|
||||
|
||||
async function uploadSource() {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
return new Promise<void>((res, rej) => {
|
||||
canvas.toBlob((blob) => {
|
||||
client.inpaint({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(blob),
|
||||
source: mustExist(source),
|
||||
}).then((output) => {
|
||||
setLoading(output);
|
||||
res();
|
||||
}).catch((err) => rej(err));
|
||||
});
|
||||
function drawSource(file: Blob): Promise<void> {
|
||||
const image = new Image();
|
||||
return new Promise<void>((res, _rej) => {
|
||||
image.onload = () => {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
const ctx = mustExist(canvas.getContext('2d'));
|
||||
ctx.drawImage(image, 0, 0);
|
||||
URL.revokeObjectURL(src);
|
||||
|
||||
// putting a save call here has a tendency to go into an infinite loop
|
||||
res();
|
||||
};
|
||||
|
||||
const src = URL.createObjectURL(file);
|
||||
image.src = src;
|
||||
});
|
||||
}
|
||||
|
||||
function drawSource(file: File) {
|
||||
const image = new Image();
|
||||
image.onload = () => {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
const ctx = mustExist(canvas.getContext('2d'));
|
||||
ctx.drawImage(image, 0, 0);
|
||||
URL.revokeObjectURL(src);
|
||||
};
|
||||
|
||||
const src = URL.createObjectURL(file);
|
||||
image.src = src;
|
||||
function saveMask(): Promise<void> {
|
||||
return new Promise((res, _rej) => {
|
||||
if (doesExist(canvasRef.current)) {
|
||||
canvasRef.current.toBlob((blob) => {
|
||||
setInpaint({
|
||||
mask: mustExist(blob),
|
||||
});
|
||||
res();
|
||||
});
|
||||
} else {
|
||||
res();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function changeMask(file: File) {
|
||||
setMask(file);
|
||||
async function uploadSource(): Promise<void> {
|
||||
const output = await client.inpaint({
|
||||
...params,
|
||||
model,
|
||||
platform,
|
||||
mask: mustExist(params.mask),
|
||||
source: mustExist(params.source),
|
||||
});
|
||||
|
||||
// always draw the mask to the canvas
|
||||
drawSource(file);
|
||||
}
|
||||
|
||||
function changeSource(file: File) {
|
||||
setSource(file);
|
||||
|
||||
// draw the source to the canvas if the mask has not been set
|
||||
if (doesExist(mask) === false) {
|
||||
drawSource(file);
|
||||
}
|
||||
setLoading(output);
|
||||
}
|
||||
|
||||
function floodMask(flooder: (n: number) => number) {
|
||||
|
@ -133,46 +133,11 @@ export function Inpaint(props: InpaintProps) {
|
|||
}
|
||||
|
||||
ctx.putImageData(image, 0, 0);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
save();
|
||||
}
|
||||
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.inpaint);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setInpaint = useStore(state, (s) => s.setInpaint);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setLoading = useStore(state, (s) => s.setLoading);
|
||||
|
||||
const query = useQueryClient();
|
||||
const upload = useMutation(uploadSource, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
|
||||
});
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
|
||||
// painting state
|
||||
const [clicks, setClicks] = useState<Array<Point>>([]);
|
||||
const [painting, setPainting] = useState(false);
|
||||
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
|
||||
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
|
||||
|
||||
// image state
|
||||
const [mask, setMask] = useState<File>();
|
||||
const [source, setSource] = useState<File>();
|
||||
|
||||
useEffect(() => {
|
||||
const canvas = mustExist(canvasRef.current);
|
||||
const ctx = mustExist(canvas.getContext('2d'));
|
||||
ctx.fillStyle = grayToRGB(brushColor);
|
||||
|
||||
for (const click of clicks) {
|
||||
ctx.beginPath();
|
||||
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
|
||||
ctx.fill();
|
||||
}
|
||||
|
||||
clicks.length = 0;
|
||||
}, [clicks.length]);
|
||||
|
||||
function renderCanvas() {
|
||||
return <canvas
|
||||
ref={canvasRef}
|
||||
|
@ -217,10 +182,74 @@ export function Inpaint(props: InpaintProps) {
|
|||
/>;
|
||||
}
|
||||
|
||||
const save = useCallback(throttle(saveMask, SAVE_TIME), []);
|
||||
const state = mustExist(useContext(StateContext));
|
||||
const params = useStore(state, (s) => s.inpaint);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setInpaint = useStore(state, (s) => s.setInpaint);
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
const setLoading = useStore(state, (s) => s.setLoading);
|
||||
|
||||
const query = useQueryClient();
|
||||
const upload = useMutation(uploadSource, {
|
||||
onSuccess: () => query.invalidateQueries({ queryKey: 'ready' }),
|
||||
});
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
const canvasRef = useRef<HTMLCanvasElement>(null);
|
||||
|
||||
// painting state
|
||||
const [clicks, setClicks] = useState<Array<Point>>([]);
|
||||
const [painting, setPainting] = useState(false);
|
||||
const [brushColor, setBrushColor] = useState(DEFAULT_BRUSH.color);
|
||||
const [brushSize, setBrushSize] = useState(DEFAULT_BRUSH.size);
|
||||
|
||||
useEffect(function changeMask() {
|
||||
// always draw the new mask to the canvas
|
||||
if (doesExist(params.mask)) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
drawSource(params.mask);
|
||||
}
|
||||
}, [params.mask]);
|
||||
|
||||
useEffect(function changeSource() {
|
||||
// draw the source to the canvas if the mask has not been set
|
||||
if (doesExist(params.source) && doesExist(params.mask) === false) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
drawSource(params.source);
|
||||
}
|
||||
}, [params.source]);
|
||||
|
||||
useEffect(() => {
|
||||
// including clicks.length prevents the initial render from saving a blank canvas
|
||||
if (doesExist(canvasRef.current) && clicks.length > 0) {
|
||||
const ctx = mustExist(canvasRef.current.getContext('2d'));
|
||||
ctx.fillStyle = grayToRGB(brushColor);
|
||||
|
||||
for (const click of clicks) {
|
||||
ctx.beginPath();
|
||||
ctx.arc(click.x, click.y, brushSize, 0, FULL_CIRCLE);
|
||||
ctx.fill();
|
||||
}
|
||||
|
||||
clicks.length = 0;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
save();
|
||||
}
|
||||
}, [clicks.length]);
|
||||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
<ImageInput filter={IMAGE_FILTER} label='Source' onChange={changeSource} />
|
||||
<ImageInput filter={IMAGE_FILTER} label='Mask' onChange={changeMask} renderImage={renderCanvas} />
|
||||
<ImageInput filter={IMAGE_FILTER} image={params.source} label='Source' onChange={(file) => {
|
||||
setInpaint({
|
||||
source: file,
|
||||
});
|
||||
}} />
|
||||
<ImageInput filter={IMAGE_FILTER} image={params.mask} label='Mask' onChange={(file) => {
|
||||
setInpaint({
|
||||
mask: file,
|
||||
});
|
||||
}} renderImage={renderCanvas} />
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
decimal
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import { Maybe } from '@apextoaster/js-utils';
|
||||
|
||||
import { Img2ImgParams, STATUS_SUCCESS, Txt2ImgParams } from './api/client.js';
|
||||
|
||||
export interface ConfigNumber {
|
||||
|
@ -12,16 +14,20 @@ export interface ConfigString {
|
|||
keys: Array<string>;
|
||||
}
|
||||
|
||||
export type KeyFilter<T extends object> = {
|
||||
[K in keyof T]: T[K] extends number ? K : T[K] extends string ? K : never;
|
||||
export type KeyFilter<T extends object, TValid = number | string> = {
|
||||
[K in keyof T]: T[K] extends TValid ? K : never;
|
||||
}[keyof T];
|
||||
|
||||
export type ConfigFiles<T extends object> = {
|
||||
[K in KeyFilter<T, Blob | File>]: Maybe<T[K]>;
|
||||
};
|
||||
|
||||
export type ConfigRanges<T extends object> = {
|
||||
[K in KeyFilter<T>]: T[K] extends number ? ConfigNumber : T[K] extends string ? ConfigString : never;
|
||||
};
|
||||
|
||||
export type ConfigState<T extends object> = {
|
||||
[K in KeyFilter<T>]: T[K] extends number ? number : T[K] extends string ? string : never;
|
||||
export type ConfigState<T extends object, TValid = number | string> = {
|
||||
[K in KeyFilter<T, TValid>]: T[K] extends TValid ? T[K] : never;
|
||||
};
|
||||
|
||||
export type ConfigParams = ConfigRanges<Required<Img2ImgParams & Txt2ImgParams>>;
|
||||
|
@ -45,6 +51,7 @@ export const DEFAULT_BRUSH = {
|
|||
export const IMAGE_FILTER = '.bmp, .jpg, .jpeg, .png';
|
||||
export const STALE_TIME = 300_000; // 5 minutes
|
||||
export const POLL_TIME = 5_000; // 5 seconds
|
||||
export const SAVE_TIME = 5_000; // 5 seconds
|
||||
|
||||
export async function loadConfig(): Promise<Config> {
|
||||
const configPath = new URL('./config.json', window.origin);
|
||||
|
|
|
@ -12,8 +12,6 @@ import { OnnxWeb } from './components/OnnxWeb.js';
|
|||
import { loadConfig } from './config.js';
|
||||
import { ClientContext, createStateSlices, OnnxState, StateContext } from './state.js';
|
||||
|
||||
const { createContext } = React;
|
||||
|
||||
export async function main() {
|
||||
// load config from GUI server
|
||||
const config = await loadConfig();
|
||||
|
@ -41,6 +39,20 @@ export async function main() {
|
|||
...createDefaultSlice(...slice),
|
||||
}), {
|
||||
name: 'onnx-web',
|
||||
partialize(s) {
|
||||
return {
|
||||
...s,
|
||||
img2img: {
|
||||
...s.img2img,
|
||||
source: undefined,
|
||||
},
|
||||
inpaint: {
|
||||
...s.inpaint,
|
||||
mask: undefined,
|
||||
source: undefined,
|
||||
},
|
||||
};
|
||||
},
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
version: 3,
|
||||
}));
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/* eslint-disable no-null/no-null */
|
||||
import { Maybe } from '@apextoaster/js-utils';
|
||||
import { createContext } from 'react';
|
||||
import { StateCreator, StoreApi } from 'zustand';
|
||||
|
@ -11,9 +12,9 @@ import {
|
|||
paramsFromConfig,
|
||||
Txt2ImgParams,
|
||||
} from './api/client.js';
|
||||
import { ConfigParams, ConfigState } from './config.js';
|
||||
import { ConfigFiles, ConfigParams, ConfigState } from './config.js';
|
||||
|
||||
type TabState<TabParams extends BaseImgParams> = ConfigState<Required<TabParams>>;
|
||||
type TabState<TabParams extends BaseImgParams> = ConfigFiles<Required<TabParams>> & ConfigState<Required<TabParams>>;
|
||||
|
||||
interface Txt2ImgSlice {
|
||||
txt2img: TabState<Txt2ImgParams>;
|
||||
|
@ -86,6 +87,7 @@ export function createStateSlices(base: ConfigParams) {
|
|||
const createImg2ImgSlice: StateCreator<OnnxState, [], [], Img2ImgSlice> = (set) => ({
|
||||
img2img: {
|
||||
...defaults,
|
||||
source: null,
|
||||
strength: base.strength.default,
|
||||
},
|
||||
setImg2Img(params) {
|
||||
|
@ -100,6 +102,7 @@ export function createStateSlices(base: ConfigParams) {
|
|||
set({
|
||||
img2img: {
|
||||
...defaults,
|
||||
source: null,
|
||||
strength: base.strength.default,
|
||||
},
|
||||
});
|
||||
|
@ -109,6 +112,8 @@ export function createStateSlices(base: ConfigParams) {
|
|||
const createInpaintSlice: StateCreator<OnnxState, [], [], InpaintSlice> = (set) => ({
|
||||
inpaint: {
|
||||
...defaults,
|
||||
mask: null,
|
||||
source: null,
|
||||
},
|
||||
setInpaint(params) {
|
||||
set((prev) => ({
|
||||
|
@ -122,6 +127,8 @@ export function createStateSlices(base: ConfigParams) {
|
|||
set({
|
||||
inpaint: {
|
||||
...defaults,
|
||||
mask: null,
|
||||
source: null,
|
||||
},
|
||||
});
|
||||
},
|
||||
|
@ -130,7 +137,6 @@ export function createStateSlices(base: ConfigParams) {
|
|||
const createHistorySlice: StateCreator<OnnxState, [], [], HistorySlice> = (set) => ({
|
||||
history: [],
|
||||
limit: 4,
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
loading: null,
|
||||
pushHistory(image) {
|
||||
set((prev) => ({
|
||||
|
@ -139,7 +145,6 @@ export function createStateSlices(base: ConfigParams) {
|
|||
image,
|
||||
...prev.history,
|
||||
],
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
loading: null,
|
||||
}));
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue