1
0
Fork 0

feat: implement blend tab and copy buttons (#62)

This commit is contained in:
Sean Sube 2023-02-13 17:34:42 -06:00
parent 1de591e15f
commit 7fa1783be4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
15 changed files with 226 additions and 37 deletions

View File

@ -1,6 +1,7 @@
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
from .blend_img2img import blend_img2img
from .blend_inpaint import blend_inpaint
from .blend_mask import blend_mask
from .correct_codeformer import correct_codeformer
from .correct_gfpgan import correct_gfpgan
from .persist_disk import persist_disk

View File

@ -27,7 +27,7 @@ def blend_img2img(
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)
pipe = load_pipeline(
OnnxStableDiffusionImg2ImgPipeline,

View File

@ -32,7 +32,9 @@ def blend_inpaint(
callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
logger.info("upscaling image by expanding borders", expand)
logger.info(
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
)
if mask_image is None:
# if no mask was provided, keep the full source image

View File

@ -0,0 +1,36 @@
from logging import getLogger
from typing import List, Optional
from PIL import Image
from onnx_web.output import save_image
from ..device_pool import JobContext, ProgressCallback
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
logger = getLogger(__name__)
def blend_mask(
_job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
*,
sources: Optional[List[Image.Image]] = None,
mask: Optional[Image.Image] = None,
_callback: ProgressCallback = None,
**kwargs,
) -> Image.Image:
logger.info("blending image using mask")
l_mask = Image.new("RGBA", mask.size, color="black")
l_mask.alpha_composite(mask)
l_mask = l_mask.convert("L")
if is_debug():
save_image(server, "last-mask.png", mask)
save_image(server, "last-mask-l.png", l_mask)
return Image.composite(sources[0], sources[1], l_mask)

View File

@ -1,11 +1,12 @@
from logging import getLogger
from typing import Any
from typing import Any, List
import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
from PIL import Image, ImageChops
from onnx_web.chain import blend_mask
from onnx_web.chain.base import ChainProgress
from ..chain import upscale_outpaint
@ -231,3 +232,39 @@ def run_upscale_pipeline(
run_gc()
logger.info("finished upscale job: %s", dest)
def run_blend_pipeline(
job: JobContext,
server: ServerContext,
params: ImageParams,
size: Size,
output: str,
upscale: UpscaleParams,
sources: List[Image.Image],
mask: Image.Image,
) -> None:
progress = job.get_progress_callback()
stage = StageParams()
image = blend_mask(
job,
server,
stage,
params,
sources=sources,
mask=mask,
callback=progress,
)
image = run_upscale_correction(
job, server, stage, params, image, upscale=upscale, callback=progress
)
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)
del image
run_gc()
logger.info("finished blend job: %s", dest)

View File

@ -34,6 +34,7 @@ from .chain import (
from .device_pool import DevicePoolExecutor
from .diffusion.load import pipeline_schedulers
from .diffusion.run import (
run_blend_pipeline,
run_img2img_pipeline,
run_inpaint_pipeline,
run_txt2img_pipeline,
@ -735,6 +736,42 @@ def chain():
return jsonify(json_params(output, params, size))
@app.route("/api/blend", methods=["POST"])
def blend():
if "mask" not in request.files:
return error_reply("mask image is required")
mask_file = request.files.get("mask")
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
source_file = request.files.get("source:0")
source_0 = Image.open(BytesIO(source_file.read())).convert("RGBA")
source_file = request.files.get("source:1")
source_1 = Image.open(BytesIO(source_file.read())).convert("RGBA")
device, params, size = pipeline_from_request()
upscale = upscale_from_request()
output = make_output_name(context, "upscale", params, size)
logger.info("upscale job queued for: %s", output)
executor.submit(
output,
run_blend_pipeline,
context,
params,
size,
output,
upscale,
[source_0, source_1],
mask,
needs_device=device,
)
return jsonify(json_params(output, params, size))
@app.route("/api/cancel", methods=["PUT"])
def cancel():
output_file = request.args.get("output", None)

View File

@ -2,6 +2,7 @@
import { doesExist } from '@apextoaster/js-utils';
import { ServerParams } from './config.js';
import { range } from './utils.js';
/**
* Shared parameters for anything using models, which is pretty much everything.
@ -482,7 +483,26 @@ export function makeClient(root: string, f = fetch): ApiClient {
});
},
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
throw new Error('TODO');
const url = makeApiUrl(root, 'blend');
appendModelToURL(url, model);
if (doesExist(upscale)) {
appendUpscaleToURL(url, upscale);
}
const body = new FormData();
body.append('mask', params.mask, 'mask');
for (const i of range(params.sources.length)) {
const name = `source:${i.toFixed(0)}`;
body.append(name, params.sources[i], name);
}
// eslint-disable-next-line no-return-await
return await throttleRequest(url, {
body,
method: 'POST',
});
},
async ready(params: ImageResponse): Promise<ReadyResponse> {
const path = makeApiUrl(root, 'ready');

View File

@ -1,14 +1,15 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Blender, Brush, ContentCopy, CropFree, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
import { Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
import { useContext, useState } from 'react';
import { useHash } from 'react-use/lib/useHash';
import { useStore } from 'zustand';
import { ImageResponse } from '../client.js';
import { ConfigContext, StateContext } from '../state.js';
import { BLEND_SOURCES, ConfigContext, StateContext } from '../state.js';
import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js';
import { range, visibleIndex } from '../utils.js';
export interface ImageCardProps {
value: ImageResponse;
@ -27,6 +28,8 @@ export function ImageCard(props: ImageCardProps) {
const { params, output, size } = value;
const [_hash, setHash] = useHash();
const [anchor, setAnchor] = useState<Maybe<HTMLElement>>();
const config = mustExist(useContext(ConfigContext));
const state = mustExist(useContext(StateContext));
// eslint-disable-next-line @typescript-eslint/unbound-method
@ -67,11 +70,13 @@ export function ImageCard(props: ImageCardProps) {
setHash('upscale');
}
async function copySourceToBlend() {
async function copySourceToBlend(idx: number) {
const blob = await loadSource();
// TODO: push instead
const sources = mustDefault(state.getState().blend.sources, []);
const newSources = [...sources];
newSources[idx] = blob;
setBlend({
sources: [blob],
sources: newSources,
});
setHash('blend');
}
@ -86,6 +91,10 @@ export function ImageCard(props: ImageCardProps) {
window.open(output.url, '_blank');
}
function close() {
setAnchor(undefined);
}
const model = mustDefault(MODEL_LABELS[params.model], params.model);
const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler);
@ -137,10 +146,24 @@ export function ImageCard(props: ImageCardProps) {
</GridItem>
<GridItem xs={2}>
<Tooltip title='Blend'>
<IconButton onClick={copySourceToBlend}>
<IconButton onClick={(event) => {
setAnchor(event.currentTarget);
}}>
<Blender />
</IconButton>
</Tooltip>
<Menu
anchorEl={anchor}
open={doesExist(anchor)}
onClose={close}
>
{range(BLEND_SOURCES).map((idx) => <MenuItem key={idx} onClick={() => {
copySourceToBlend(idx).catch((err) => {
// TODO
});
close();
}}>{visibleIndex(idx)}</MenuItem>)}
</Menu>
</GridItem>
<GridItem xs={2}>
<Tooltip title='Delete'>

View File

@ -22,7 +22,7 @@ export interface ImageControlProps {
}
/**
* doesn't need to use state, the parent component knows which params to pass
* Doesn't need to use state directly, the parent component knows which params to pass
*/
export function ImageControl(props: ImageControlProps) {
const { params } = mustExist(useContext(ConfigContext));

View File

@ -1,13 +1,12 @@
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient, InvertColors, Undo } from '@mui/icons-material';
import { Button, Stack, Typography } from '@mui/material';
import { createLogger } from 'browser-bunyan';
import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react';
import { useStore } from 'zustand';
import { SAVE_TIME } from '../../config.js';
import { ConfigContext, StateContext } from '../../state.js';
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
import { imageFromBlob } from '../../utils.js';
import { NumericField } from './NumericField';
@ -42,11 +41,10 @@ export interface MaskCanvasProps {
onSave: (blob: Blob) => void;
}
const logger = createLogger({ name: 'react', level: 'debug' }); // TODO: hackeroni and cheese
export function MaskCanvas(props: MaskCanvasProps) {
const { source, mask } = props;
const { params } = mustExist(useContext(ConfigContext));
const logger = mustExist(useContext(LoggerContext));
function composite() {
if (doesExist(viewRef.current)) {

View File

@ -1,4 +1,4 @@
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
import { mustDefault, mustExist } from '@apextoaster/js-utils';
import { Box, Button, Stack } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
@ -6,7 +6,8 @@ import { useMutation, useQueryClient } from 'react-query';
import { useStore } from 'zustand';
import { IMAGE_FILTER } from '../../config.js';
import { ClientContext, StateContext } from '../../state.js';
import { BLEND_SOURCES, ClientContext, StateContext } from '../../state.js';
import { range } from '../../utils.js';
import { UpscaleControl } from '../control/UpscaleControl.js';
import { ImageInput } from '../input/ImageInput.js';
import { MaskCanvas } from '../input/MaskCanvas.js';
@ -41,22 +42,30 @@ export function Blend() {
return <Box>
<Stack spacing={2}>
<ImageInput
filter={IMAGE_FILTER}
image={sources[0]}
hideSelection={true}
label='Source'
onChange={(file) => {
setBlend({
sources: [file],
});
}}
/>
{range(BLEND_SOURCES).map((idx) =>
<ImageInput
key={`source-${idx.toFixed(0)}`}
filter={IMAGE_FILTER}
image={sources[idx]}
hideSelection={true}
label='Source'
onChange={(file) => {
const newSources = [...sources];
newSources[idx] = file;
setBlend({
sources: newSources,
});
}}
/>
)}
<MaskCanvas
source={sources[0]}
mask={blend.mask}
onSave={() => {
// TODO
onSave={(mask) => {
setBlend({
mask,
});
}}
/>
<UpscaleControl />

View File

@ -6,6 +6,7 @@ import { QueryClient, QueryClientProvider } from 'react-query';
import { satisfies } from 'semver';
import { createStore } from 'zustand';
import { createJSONStorage, persist } from 'zustand/middleware';
import { createLogger } from 'browser-bunyan';
import { makeClient } from './client.js';
import { ParamsVersionError } from './components/error/ParamsVersion.js';
@ -13,7 +14,7 @@ import { ServerParamsError } from './components/error/ServerParams.js';
import { OnnxError } from './components/OnnxError.js';
import { OnnxWeb } from './components/OnnxWeb.js';
import { getApiRoot, loadConfig, mergeConfig, PARAM_VERSION } from './config.js';
import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext } from './state.js';
import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext, LoggerContext } from './state.js';
export const INITIAL_LOAD_TIMEOUT = 5_000;
@ -91,6 +92,12 @@ export async function main() {
version: STATE_VERSION,
}));
const logger = createLogger({
name: 'onnx-web',
system: 'react',
level: 'debug',
});
// prep react-query client
const query = new QueryClient();
@ -98,9 +105,11 @@ export async function main() {
app.render(<QueryClientProvider client={query}>
<ClientContext.Provider value={client}>
<ConfigContext.Provider value={completeConfig}>
<StateContext.Provider value={state}>
<OnnxWeb />
</StateContext.Provider>
<LoggerContext.Provider value={logger}>
<StateContext.Provider value={state}>
<OnnxWeb />
</StateContext.Provider>
</LoggerContext.Provider>
</ConfigContext.Provider>
</ClientContext.Provider>
</QueryClientProvider>);

View File

@ -1,6 +1,7 @@
/* eslint-disable max-lines */
/* eslint-disable no-null/no-null */
import { doesExist, Maybe } from '@apextoaster/js-utils';
import { Logger } from 'noicejs';
import { createContext } from 'react';
import { StateCreator, StoreApi } from 'zustand';
@ -145,6 +146,11 @@ export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
*/
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
/**
* React context binding for bunyan logger.
*/
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
/**
* React context binding for zustand state store.
*/
@ -160,6 +166,8 @@ export const STATE_KEY = 'onnx-web';
*/
export const STATE_VERSION = 5;
export const BLEND_SOURCES = 2;
/**
* Default parameters for the inpaint brush.
*

View File

@ -10,3 +10,11 @@ export function imageFromBlob(blob: Blob): Promise<HTMLImageElement> {
image.src = src;
});
}
export function range(max: number): Array<number> {
return [...Array(max).keys()];
}
export function visibleIndex(idx: number): string {
return (idx + 1).toFixed(0);
}

View File

@ -15,6 +15,7 @@
"cSpell.words": [
"astype",
"basicsr",
"Civitai",
"ckpt",
"codebook",
"codeformer",