feat: implement blend tab and copy buttons (#62)
This commit is contained in:
parent
1de591e15f
commit
7fa1783be4
|
@ -1,6 +1,7 @@
|
|||
from .base import ChainPipeline, PipelineStage, StageCallback, StageParams
|
||||
from .blend_img2img import blend_img2img
|
||||
from .blend_inpaint import blend_inpaint
|
||||
from .blend_mask import blend_mask
|
||||
from .correct_codeformer import correct_codeformer
|
||||
from .correct_gfpgan import correct_gfpgan
|
||||
from .persist_disk import persist_disk
|
||||
|
|
|
@ -27,7 +27,7 @@ def blend_img2img(
|
|||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
|
||||
logger.info("blending image using img2img, %s steps: %s", params.steps, prompt)
|
||||
|
||||
pipe = load_pipeline(
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
|
|
|
@ -32,7 +32,9 @@ def blend_inpaint(
|
|||
callback: ProgressCallback = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info("upscaling image by expanding borders", expand)
|
||||
logger.info(
|
||||
"blending image using inpaint, %s steps: %s", params.steps, params.prompt
|
||||
)
|
||||
|
||||
if mask_image is None:
|
||||
# if no mask was provided, keep the full source image
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
from logging import getLogger
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from onnx_web.output import save_image
|
||||
|
||||
from ..device_pool import JobContext, ProgressCallback
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def blend_mask(
|
||||
_job: JobContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
_params: ImageParams,
|
||||
*,
|
||||
sources: Optional[List[Image.Image]] = None,
|
||||
mask: Optional[Image.Image] = None,
|
||||
_callback: ProgressCallback = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
logger.info("blending image using mask")
|
||||
|
||||
l_mask = Image.new("RGBA", mask.size, color="black")
|
||||
l_mask.alpha_composite(mask)
|
||||
l_mask = l_mask.convert("L")
|
||||
|
||||
if is_debug():
|
||||
save_image(server, "last-mask.png", mask)
|
||||
save_image(server, "last-mask-l.png", l_mask)
|
||||
|
||||
return Image.composite(sources[0], sources[1], l_mask)
|
|
@ -1,11 +1,12 @@
|
|||
from logging import getLogger
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from onnx_web.chain import blend_mask
|
||||
from onnx_web.chain.base import ChainProgress
|
||||
|
||||
from ..chain import upscale_outpaint
|
||||
|
@ -231,3 +232,39 @@ def run_upscale_pipeline(
|
|||
run_gc()
|
||||
|
||||
logger.info("finished upscale job: %s", dest)
|
||||
|
||||
|
||||
def run_blend_pipeline(
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
sources: List[Image.Image],
|
||||
mask: Image.Image,
|
||||
) -> None:
|
||||
progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
image = blend_mask(
|
||||
job,
|
||||
server,
|
||||
stage,
|
||||
params,
|
||||
sources=sources,
|
||||
mask=mask,
|
||||
callback=progress,
|
||||
)
|
||||
|
||||
image = run_upscale_correction(
|
||||
job, server, stage, params, image, upscale=upscale, callback=progress
|
||||
)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
||||
logger.info("finished blend job: %s", dest)
|
||||
|
|
|
@ -34,6 +34,7 @@ from .chain import (
|
|||
from .device_pool import DevicePoolExecutor
|
||||
from .diffusion.load import pipeline_schedulers
|
||||
from .diffusion.run import (
|
||||
run_blend_pipeline,
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
run_txt2img_pipeline,
|
||||
|
@ -735,6 +736,42 @@ def chain():
|
|||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
@app.route("/api/blend", methods=["POST"])
|
||||
def blend():
|
||||
if "mask" not in request.files:
|
||||
return error_reply("mask image is required")
|
||||
|
||||
mask_file = request.files.get("mask")
|
||||
mask = Image.open(BytesIO(mask_file.read())).convert("RGBA")
|
||||
|
||||
source_file = request.files.get("source:0")
|
||||
source_0 = Image.open(BytesIO(source_file.read())).convert("RGBA")
|
||||
|
||||
source_file = request.files.get("source:1")
|
||||
source_1 = Image.open(BytesIO(source_file.read())).convert("RGBA")
|
||||
|
||||
device, params, size = pipeline_from_request()
|
||||
upscale = upscale_from_request()
|
||||
|
||||
output = make_output_name(context, "upscale", params, size)
|
||||
logger.info("upscale job queued for: %s", output)
|
||||
|
||||
executor.submit(
|
||||
output,
|
||||
run_blend_pipeline,
|
||||
context,
|
||||
params,
|
||||
size,
|
||||
output,
|
||||
upscale,
|
||||
[source_0, source_1],
|
||||
mask,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
||||
@app.route("/api/cancel", methods=["PUT"])
|
||||
def cancel():
|
||||
output_file = request.args.get("output", None)
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import { doesExist } from '@apextoaster/js-utils';
|
||||
|
||||
import { ServerParams } from './config.js';
|
||||
import { range } from './utils.js';
|
||||
|
||||
/**
|
||||
* Shared parameters for anything using models, which is pretty much everything.
|
||||
|
@ -482,7 +483,26 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
});
|
||||
},
|
||||
async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise<ImageResponse> {
|
||||
throw new Error('TODO');
|
||||
const url = makeApiUrl(root, 'blend');
|
||||
appendModelToURL(url, model);
|
||||
|
||||
if (doesExist(upscale)) {
|
||||
appendUpscaleToURL(url, upscale);
|
||||
}
|
||||
|
||||
const body = new FormData();
|
||||
body.append('mask', params.mask, 'mask');
|
||||
|
||||
for (const i of range(params.sources.length)) {
|
||||
const name = `source:${i.toFixed(0)}`;
|
||||
body.append(name, params.sources[i], name);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-return-await
|
||||
return await throttleRequest(url, {
|
||||
body,
|
||||
method: 'POST',
|
||||
});
|
||||
},
|
||||
async ready(params: ImageResponse): Promise<ReadyResponse> {
|
||||
const path = makeApiUrl(root, 'ready');
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { Blender, Brush, ContentCopy, CropFree, Delete, Download, ZoomOutMap } from '@mui/icons-material';
|
||||
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material';
|
||||
import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material';
|
||||
import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useContext, useState } from 'react';
|
||||
import { useHash } from 'react-use/lib/useHash';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { ImageResponse } from '../client.js';
|
||||
import { ConfigContext, StateContext } from '../state.js';
|
||||
import { BLEND_SOURCES, ConfigContext, StateContext } from '../state.js';
|
||||
import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js';
|
||||
import { range, visibleIndex } from '../utils.js';
|
||||
|
||||
export interface ImageCardProps {
|
||||
value: ImageResponse;
|
||||
|
@ -27,6 +28,8 @@ export function ImageCard(props: ImageCardProps) {
|
|||
const { params, output, size } = value;
|
||||
|
||||
const [_hash, setHash] = useHash();
|
||||
const [anchor, setAnchor] = useState<Maybe<HTMLElement>>();
|
||||
|
||||
const config = mustExist(useContext(ConfigContext));
|
||||
const state = mustExist(useContext(StateContext));
|
||||
// eslint-disable-next-line @typescript-eslint/unbound-method
|
||||
|
@ -67,11 +70,13 @@ export function ImageCard(props: ImageCardProps) {
|
|||
setHash('upscale');
|
||||
}
|
||||
|
||||
async function copySourceToBlend() {
|
||||
async function copySourceToBlend(idx: number) {
|
||||
const blob = await loadSource();
|
||||
// TODO: push instead
|
||||
const sources = mustDefault(state.getState().blend.sources, []);
|
||||
const newSources = [...sources];
|
||||
newSources[idx] = blob;
|
||||
setBlend({
|
||||
sources: [blob],
|
||||
sources: newSources,
|
||||
});
|
||||
setHash('blend');
|
||||
}
|
||||
|
@ -86,6 +91,10 @@ export function ImageCard(props: ImageCardProps) {
|
|||
window.open(output.url, '_blank');
|
||||
}
|
||||
|
||||
function close() {
|
||||
setAnchor(undefined);
|
||||
}
|
||||
|
||||
const model = mustDefault(MODEL_LABELS[params.model], params.model);
|
||||
const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler);
|
||||
|
||||
|
@ -137,10 +146,24 @@ export function ImageCard(props: ImageCardProps) {
|
|||
</GridItem>
|
||||
<GridItem xs={2}>
|
||||
<Tooltip title='Blend'>
|
||||
<IconButton onClick={copySourceToBlend}>
|
||||
<IconButton onClick={(event) => {
|
||||
setAnchor(event.currentTarget);
|
||||
}}>
|
||||
<Blender />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
<Menu
|
||||
anchorEl={anchor}
|
||||
open={doesExist(anchor)}
|
||||
onClose={close}
|
||||
>
|
||||
{range(BLEND_SOURCES).map((idx) => <MenuItem key={idx} onClick={() => {
|
||||
copySourceToBlend(idx).catch((err) => {
|
||||
// TODO
|
||||
});
|
||||
close();
|
||||
}}>{visibleIndex(idx)}</MenuItem>)}
|
||||
</Menu>
|
||||
</GridItem>
|
||||
<GridItem xs={2}>
|
||||
<Tooltip title='Delete'>
|
||||
|
|
|
@ -22,7 +22,7 @@ export interface ImageControlProps {
|
|||
}
|
||||
|
||||
/**
|
||||
* doesn't need to use state, the parent component knows which params to pass
|
||||
* Doesn't need to use state directly, the parent component knows which params to pass
|
||||
*/
|
||||
export function ImageControl(props: ImageControlProps) {
|
||||
const { params } = mustExist(useContext(ConfigContext));
|
||||
|
|
|
@ -1,13 +1,12 @@
|
|||
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
|
||||
import { FormatColorFill, Gradient, InvertColors, Undo } from '@mui/icons-material';
|
||||
import { Button, Stack, Typography } from '@mui/material';
|
||||
import { createLogger } from 'browser-bunyan';
|
||||
import { throttle } from 'lodash';
|
||||
import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
||||
import { SAVE_TIME } from '../../config.js';
|
||||
import { ConfigContext, StateContext } from '../../state.js';
|
||||
import { ConfigContext, LoggerContext, StateContext } from '../../state.js';
|
||||
import { imageFromBlob } from '../../utils.js';
|
||||
import { NumericField } from './NumericField';
|
||||
|
||||
|
@ -42,11 +41,10 @@ export interface MaskCanvasProps {
|
|||
onSave: (blob: Blob) => void;
|
||||
}
|
||||
|
||||
const logger = createLogger({ name: 'react', level: 'debug' }); // TODO: hackeroni and cheese
|
||||
|
||||
export function MaskCanvas(props: MaskCanvasProps) {
|
||||
const { source, mask } = props;
|
||||
const { params } = mustExist(useContext(ConfigContext));
|
||||
const logger = mustExist(useContext(LoggerContext));
|
||||
|
||||
function composite() {
|
||||
if (doesExist(viewRef.current)) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { mustDefault, mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
|
@ -6,7 +6,8 @@ import { useMutation, useQueryClient } from 'react-query';
|
|||
import { useStore } from 'zustand';
|
||||
|
||||
import { IMAGE_FILTER } from '../../config.js';
|
||||
import { ClientContext, StateContext } from '../../state.js';
|
||||
import { BLEND_SOURCES, ClientContext, StateContext } from '../../state.js';
|
||||
import { range } from '../../utils.js';
|
||||
import { UpscaleControl } from '../control/UpscaleControl.js';
|
||||
import { ImageInput } from '../input/ImageInput.js';
|
||||
import { MaskCanvas } from '../input/MaskCanvas.js';
|
||||
|
@ -41,22 +42,30 @@ export function Blend() {
|
|||
|
||||
return <Box>
|
||||
<Stack spacing={2}>
|
||||
{range(BLEND_SOURCES).map((idx) =>
|
||||
<ImageInput
|
||||
key={`source-${idx.toFixed(0)}`}
|
||||
filter={IMAGE_FILTER}
|
||||
image={sources[0]}
|
||||
image={sources[idx]}
|
||||
hideSelection={true}
|
||||
label='Source'
|
||||
onChange={(file) => {
|
||||
const newSources = [...sources];
|
||||
newSources[idx] = file;
|
||||
|
||||
setBlend({
|
||||
sources: [file],
|
||||
sources: newSources,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<MaskCanvas
|
||||
source={sources[0]}
|
||||
mask={blend.mask}
|
||||
onSave={() => {
|
||||
// TODO
|
||||
onSave={(mask) => {
|
||||
setBlend({
|
||||
mask,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<UpscaleControl />
|
||||
|
|
|
@ -6,6 +6,7 @@ import { QueryClient, QueryClientProvider } from 'react-query';
|
|||
import { satisfies } from 'semver';
|
||||
import { createStore } from 'zustand';
|
||||
import { createJSONStorage, persist } from 'zustand/middleware';
|
||||
import { createLogger } from 'browser-bunyan';
|
||||
|
||||
import { makeClient } from './client.js';
|
||||
import { ParamsVersionError } from './components/error/ParamsVersion.js';
|
||||
|
@ -13,7 +14,7 @@ import { ServerParamsError } from './components/error/ServerParams.js';
|
|||
import { OnnxError } from './components/OnnxError.js';
|
||||
import { OnnxWeb } from './components/OnnxWeb.js';
|
||||
import { getApiRoot, loadConfig, mergeConfig, PARAM_VERSION } from './config.js';
|
||||
import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext } from './state.js';
|
||||
import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext, LoggerContext } from './state.js';
|
||||
|
||||
export const INITIAL_LOAD_TIMEOUT = 5_000;
|
||||
|
||||
|
@ -91,6 +92,12 @@ export async function main() {
|
|||
version: STATE_VERSION,
|
||||
}));
|
||||
|
||||
const logger = createLogger({
|
||||
name: 'onnx-web',
|
||||
system: 'react',
|
||||
level: 'debug',
|
||||
});
|
||||
|
||||
// prep react-query client
|
||||
const query = new QueryClient();
|
||||
|
||||
|
@ -98,9 +105,11 @@ export async function main() {
|
|||
app.render(<QueryClientProvider client={query}>
|
||||
<ClientContext.Provider value={client}>
|
||||
<ConfigContext.Provider value={completeConfig}>
|
||||
<LoggerContext.Provider value={logger}>
|
||||
<StateContext.Provider value={state}>
|
||||
<OnnxWeb />
|
||||
</StateContext.Provider>
|
||||
</LoggerContext.Provider>
|
||||
</ConfigContext.Provider>
|
||||
</ClientContext.Provider>
|
||||
</QueryClientProvider>);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
/* eslint-disable max-lines */
|
||||
/* eslint-disable no-null/no-null */
|
||||
import { doesExist, Maybe } from '@apextoaster/js-utils';
|
||||
import { Logger } from 'noicejs';
|
||||
import { createContext } from 'react';
|
||||
import { StateCreator, StoreApi } from 'zustand';
|
||||
|
||||
|
@ -145,6 +146,11 @@ export const ClientContext = createContext<Maybe<ApiClient>>(undefined);
|
|||
*/
|
||||
export const ConfigContext = createContext<Maybe<Config<ServerParams>>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for bunyan logger.
|
||||
*/
|
||||
export const LoggerContext = createContext<Maybe<Logger>>(undefined);
|
||||
|
||||
/**
|
||||
* React context binding for zustand state store.
|
||||
*/
|
||||
|
@ -160,6 +166,8 @@ export const STATE_KEY = 'onnx-web';
|
|||
*/
|
||||
export const STATE_VERSION = 5;
|
||||
|
||||
export const BLEND_SOURCES = 2;
|
||||
|
||||
/**
|
||||
* Default parameters for the inpaint brush.
|
||||
*
|
||||
|
|
|
@ -10,3 +10,11 @@ export function imageFromBlob(blob: Blob): Promise<HTMLImageElement> {
|
|||
image.src = src;
|
||||
});
|
||||
}
|
||||
|
||||
export function range(max: number): Array<number> {
|
||||
return [...Array(max).keys()];
|
||||
}
|
||||
|
||||
export function visibleIndex(idx: number): string {
|
||||
return (idx + 1).toFixed(0);
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"cSpell.words": [
|
||||
"astype",
|
||||
"basicsr",
|
||||
"Civitai",
|
||||
"ckpt",
|
||||
"codebook",
|
||||
"codeformer",
|
||||
|
|
Loading…
Reference in New Issue