diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index 0312d85e..5aa56c56 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,6 +1,7 @@ from .base import ChainPipeline, PipelineStage, StageCallback, StageParams from .blend_img2img import blend_img2img from .blend_inpaint import blend_inpaint +from .blend_mask import blend_mask from .correct_codeformer import correct_codeformer from .correct_gfpgan import correct_gfpgan from .persist_disk import persist_disk diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 3257ffc5..b101125c 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -27,7 +27,7 @@ def blend_img2img( **kwargs, ) -> Image.Image: prompt = prompt or params.prompt - logger.info("generating image using img2img, %s steps: %s", params.steps, prompt) + logger.info("blending image using img2img, %s steps: %s", params.steps, prompt) pipe = load_pipeline( OnnxStableDiffusionImg2ImgPipeline, diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index b828fac9..114f095a 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -32,7 +32,9 @@ def blend_inpaint( callback: ProgressCallback = None, **kwargs, ) -> Image.Image: - logger.info("upscaling image by expanding borders", expand) + logger.info( + "blending image using inpaint, %s steps: %s", params.steps, params.prompt + ) if mask_image is None: # if no mask was provided, keep the full source image diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py new file mode 100644 index 00000000..b8f0d390 --- /dev/null +++ b/api/onnx_web/chain/blend_mask.py @@ -0,0 +1,36 @@ +from logging import getLogger +from typing import List, Optional + +from PIL import Image + +from onnx_web.output import save_image + +from ..device_pool import JobContext, ProgressCallback +from ..params import ImageParams, StageParams +from ..utils import ServerContext, is_debug + +logger = getLogger(__name__) + + +def blend_mask( + _job: JobContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + *, + sources: Optional[List[Image.Image]] = None, + mask: Optional[Image.Image] = None, + _callback: ProgressCallback = None, + **kwargs, +) -> Image.Image: + logger.info("blending image using mask") + + l_mask = Image.new("RGBA", mask.size, color="black") + l_mask.alpha_composite(mask) + l_mask = l_mask.convert("L") + + if is_debug(): + save_image(server, "last-mask.png", mask) + save_image(server, "last-mask-l.png", l_mask) + + return Image.composite(sources[0], sources[1], l_mask) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 2b776b9c..e86ec9c4 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -1,11 +1,12 @@ from logging import getLogger -from typing import Any +from typing import Any, List import numpy as np import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline from PIL import Image, ImageChops +from onnx_web.chain import blend_mask from onnx_web.chain.base import ChainProgress from ..chain import upscale_outpaint @@ -231,3 +232,39 @@ def run_upscale_pipeline( run_gc() logger.info("finished upscale job: %s", dest) + + +def run_blend_pipeline( + job: JobContext, + server: ServerContext, + params: ImageParams, + size: Size, + output: str, + upscale: UpscaleParams, + sources: List[Image.Image], + mask: Image.Image, +) -> None: + progress = job.get_progress_callback() + stage = StageParams() + + image = blend_mask( + job, + server, + stage, + params, + sources=sources, + mask=mask, + callback=progress, + ) + + image = run_upscale_correction( + job, server, stage, params, image, upscale=upscale, callback=progress + ) + + dest = save_image(server, output, image) + save_params(server, output, params, size, upscale=upscale) + + del image + run_gc() + + logger.info("finished blend job: %s", dest) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index c92b7629..12e9046d 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -34,6 +34,7 @@ from .chain import ( from .device_pool import DevicePoolExecutor from .diffusion.load import pipeline_schedulers from .diffusion.run import ( + run_blend_pipeline, run_img2img_pipeline, run_inpaint_pipeline, run_txt2img_pipeline, @@ -735,6 +736,42 @@ def chain(): return jsonify(json_params(output, params, size)) +@app.route("/api/blend", methods=["POST"]) +def blend(): + if "mask" not in request.files: + return error_reply("mask image is required") + + mask_file = request.files.get("mask") + mask = Image.open(BytesIO(mask_file.read())).convert("RGBA") + + source_file = request.files.get("source:0") + source_0 = Image.open(BytesIO(source_file.read())).convert("RGBA") + + source_file = request.files.get("source:1") + source_1 = Image.open(BytesIO(source_file.read())).convert("RGBA") + + device, params, size = pipeline_from_request() + upscale = upscale_from_request() + + output = make_output_name(context, "upscale", params, size) + logger.info("upscale job queued for: %s", output) + + executor.submit( + output, + run_blend_pipeline, + context, + params, + size, + output, + upscale, + [source_0, source_1], + mask, + needs_device=device, + ) + + return jsonify(json_params(output, params, size)) + + @app.route("/api/cancel", methods=["PUT"]) def cancel(): output_file = request.args.get("output", None) diff --git a/gui/src/client.ts b/gui/src/client.ts index a8198c94..a0a69645 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -2,6 +2,7 @@ import { doesExist } from '@apextoaster/js-utils'; import { ServerParams } from './config.js'; +import { range } from './utils.js'; /** * Shared parameters for anything using models, which is pretty much everything. @@ -482,7 +483,26 @@ export function makeClient(root: string, f = fetch): ApiClient { }); }, async blend(model: ModelParams, params: BlendParams, upscale: UpscaleParams): Promise { - throw new Error('TODO'); + const url = makeApiUrl(root, 'blend'); + appendModelToURL(url, model); + + if (doesExist(upscale)) { + appendUpscaleToURL(url, upscale); + } + + const body = new FormData(); + body.append('mask', params.mask, 'mask'); + + for (const i of range(params.sources.length)) { + const name = `source:${i.toFixed(0)}`; + body.append(name, params.sources[i], name); + } + + // eslint-disable-next-line no-return-await + return await throttleRequest(url, { + body, + method: 'POST', + }); }, async ready(params: ImageResponse): Promise { const path = makeApiUrl(root, 'ready'); diff --git a/gui/src/components/ImageCard.tsx b/gui/src/components/ImageCard.tsx index ec9feaa9..7bbcfe4e 100644 --- a/gui/src/components/ImageCard.tsx +++ b/gui/src/components/ImageCard.tsx @@ -1,14 +1,15 @@ -import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; -import { Blender, Brush, ContentCopy, CropFree, Delete, Download, ZoomOutMap } from '@mui/icons-material'; -import { Box, Card, CardContent, CardMedia, Grid, IconButton, Paper, Tooltip } from '@mui/material'; +import { doesExist, Maybe, mustDefault, mustExist } from '@apextoaster/js-utils'; +import { Blender, Brush, ContentCopy, Delete, Download, ZoomOutMap } from '@mui/icons-material'; +import { Box, Card, CardContent, CardMedia, Grid, IconButton, Menu, MenuItem, Paper, Tooltip } from '@mui/material'; import * as React from 'react'; -import { useContext } from 'react'; +import { useContext, useState } from 'react'; import { useHash } from 'react-use/lib/useHash'; import { useStore } from 'zustand'; import { ImageResponse } from '../client.js'; -import { ConfigContext, StateContext } from '../state.js'; +import { BLEND_SOURCES, ConfigContext, StateContext } from '../state.js'; import { MODEL_LABELS, SCHEDULER_LABELS } from '../strings.js'; +import { range, visibleIndex } from '../utils.js'; export interface ImageCardProps { value: ImageResponse; @@ -27,6 +28,8 @@ export function ImageCard(props: ImageCardProps) { const { params, output, size } = value; const [_hash, setHash] = useHash(); + const [anchor, setAnchor] = useState>(); + const config = mustExist(useContext(ConfigContext)); const state = mustExist(useContext(StateContext)); // eslint-disable-next-line @typescript-eslint/unbound-method @@ -67,11 +70,13 @@ export function ImageCard(props: ImageCardProps) { setHash('upscale'); } - async function copySourceToBlend() { + async function copySourceToBlend(idx: number) { const blob = await loadSource(); - // TODO: push instead + const sources = mustDefault(state.getState().blend.sources, []); + const newSources = [...sources]; + newSources[idx] = blob; setBlend({ - sources: [blob], + sources: newSources, }); setHash('blend'); } @@ -86,6 +91,10 @@ export function ImageCard(props: ImageCardProps) { window.open(output.url, '_blank'); } + function close() { + setAnchor(undefined); + } + const model = mustDefault(MODEL_LABELS[params.model], params.model); const scheduler = mustDefault(SCHEDULER_LABELS[params.scheduler], params.scheduler); @@ -137,10 +146,24 @@ export function ImageCard(props: ImageCardProps) { - + { + setAnchor(event.currentTarget); + }}> + + {range(BLEND_SOURCES).map((idx) => { + copySourceToBlend(idx).catch((err) => { + // TODO + }); + close(); + }}>{visibleIndex(idx)})} + diff --git a/gui/src/components/control/ImageControl.tsx b/gui/src/components/control/ImageControl.tsx index 6420cfba..a01a16f8 100644 --- a/gui/src/components/control/ImageControl.tsx +++ b/gui/src/components/control/ImageControl.tsx @@ -22,7 +22,7 @@ export interface ImageControlProps { } /** - * doesn't need to use state, the parent component knows which params to pass + * Doesn't need to use state directly, the parent component knows which params to pass */ export function ImageControl(props: ImageControlProps) { const { params } = mustExist(useContext(ConfigContext)); diff --git a/gui/src/components/input/MaskCanvas.tsx b/gui/src/components/input/MaskCanvas.tsx index 5f3aeb19..381a1759 100644 --- a/gui/src/components/input/MaskCanvas.tsx +++ b/gui/src/components/input/MaskCanvas.tsx @@ -1,13 +1,12 @@ import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils'; import { FormatColorFill, Gradient, InvertColors, Undo } from '@mui/icons-material'; import { Button, Stack, Typography } from '@mui/material'; -import { createLogger } from 'browser-bunyan'; import { throttle } from 'lodash'; import React, { RefObject, useContext, useEffect, useMemo, useRef } from 'react'; import { useStore } from 'zustand'; import { SAVE_TIME } from '../../config.js'; -import { ConfigContext, StateContext } from '../../state.js'; +import { ConfigContext, LoggerContext, StateContext } from '../../state.js'; import { imageFromBlob } from '../../utils.js'; import { NumericField } from './NumericField'; @@ -42,11 +41,10 @@ export interface MaskCanvasProps { onSave: (blob: Blob) => void; } -const logger = createLogger({ name: 'react', level: 'debug' }); // TODO: hackeroni and cheese - export function MaskCanvas(props: MaskCanvasProps) { const { source, mask } = props; const { params } = mustExist(useContext(ConfigContext)); + const logger = mustExist(useContext(LoggerContext)); function composite() { if (doesExist(viewRef.current)) { diff --git a/gui/src/components/tab/Blend.tsx b/gui/src/components/tab/Blend.tsx index 51660eb2..2460e400 100644 --- a/gui/src/components/tab/Blend.tsx +++ b/gui/src/components/tab/Blend.tsx @@ -1,4 +1,4 @@ -import { doesExist, mustDefault, mustExist } from '@apextoaster/js-utils'; +import { mustDefault, mustExist } from '@apextoaster/js-utils'; import { Box, Button, Stack } from '@mui/material'; import * as React from 'react'; import { useContext } from 'react'; @@ -6,7 +6,8 @@ import { useMutation, useQueryClient } from 'react-query'; import { useStore } from 'zustand'; import { IMAGE_FILTER } from '../../config.js'; -import { ClientContext, StateContext } from '../../state.js'; +import { BLEND_SOURCES, ClientContext, StateContext } from '../../state.js'; +import { range } from '../../utils.js'; import { UpscaleControl } from '../control/UpscaleControl.js'; import { ImageInput } from '../input/ImageInput.js'; import { MaskCanvas } from '../input/MaskCanvas.js'; @@ -41,22 +42,30 @@ export function Blend() { return - { - setBlend({ - sources: [file], - }); - }} - /> + {range(BLEND_SOURCES).map((idx) => + { + const newSources = [...sources]; + newSources[idx] = file; + + setBlend({ + sources: newSources, + }); + }} + /> + )} { - // TODO + onSave={(mask) => { + setBlend({ + mask, + }); }} /> diff --git a/gui/src/main.tsx b/gui/src/main.tsx index f673dc60..a7004578 100644 --- a/gui/src/main.tsx +++ b/gui/src/main.tsx @@ -6,6 +6,7 @@ import { QueryClient, QueryClientProvider } from 'react-query'; import { satisfies } from 'semver'; import { createStore } from 'zustand'; import { createJSONStorage, persist } from 'zustand/middleware'; +import { createLogger } from 'browser-bunyan'; import { makeClient } from './client.js'; import { ParamsVersionError } from './components/error/ParamsVersion.js'; @@ -13,7 +14,7 @@ import { ServerParamsError } from './components/error/ServerParams.js'; import { OnnxError } from './components/OnnxError.js'; import { OnnxWeb } from './components/OnnxWeb.js'; import { getApiRoot, loadConfig, mergeConfig, PARAM_VERSION } from './config.js'; -import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext } from './state.js'; +import { ClientContext, ConfigContext, createStateSlices, OnnxState, STATE_VERSION, StateContext, LoggerContext } from './state.js'; export const INITIAL_LOAD_TIMEOUT = 5_000; @@ -91,6 +92,12 @@ export async function main() { version: STATE_VERSION, })); + const logger = createLogger({ + name: 'onnx-web', + system: 'react', + level: 'debug', + }); + // prep react-query client const query = new QueryClient(); @@ -98,9 +105,11 @@ export async function main() { app.render( - - - + + + + + ); diff --git a/gui/src/state.ts b/gui/src/state.ts index 854bf964..535fbc55 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -1,6 +1,7 @@ /* eslint-disable max-lines */ /* eslint-disable no-null/no-null */ import { doesExist, Maybe } from '@apextoaster/js-utils'; +import { Logger } from 'noicejs'; import { createContext } from 'react'; import { StateCreator, StoreApi } from 'zustand'; @@ -145,6 +146,11 @@ export const ClientContext = createContext>(undefined); */ export const ConfigContext = createContext>>(undefined); +/** + * React context binding for bunyan logger. + */ +export const LoggerContext = createContext>(undefined); + /** * React context binding for zustand state store. */ @@ -160,6 +166,8 @@ export const STATE_KEY = 'onnx-web'; */ export const STATE_VERSION = 5; +export const BLEND_SOURCES = 2; + /** * Default parameters for the inpaint brush. * diff --git a/gui/src/utils.ts b/gui/src/utils.ts index 6d564cc7..a209721c 100644 --- a/gui/src/utils.ts +++ b/gui/src/utils.ts @@ -10,3 +10,11 @@ export function imageFromBlob(blob: Blob): Promise { image.src = src; }); } + +export function range(max: number): Array { + return [...Array(max).keys()]; +} + +export function visibleIndex(idx: number): string { + return (idx + 1).toFixed(0); +} diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index ebd137ae..8f441e9b 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -15,6 +15,7 @@ "cSpell.words": [ "astype", "basicsr", + "Civitai", "ckpt", "codebook", "codeformer",