From 3679735d86e982c18ae9534118567b2a24990868 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 18 Jan 2023 08:41:02 -0600 Subject: [PATCH] feat: add fill color control to inpaint --- api/onnx_web/image.py | 22 ++++++++-------- api/onnx_web/pipeline.py | 2 ++ api/onnx_web/serve.py | 5 +++- gui/src/client.ts | 1 + gui/src/components/input/ImageInput.tsx | 1 - gui/src/components/tab/Inpaint.tsx | 34 +++++++++++++++---------- gui/src/state.ts | 2 ++ 7 files changed, 41 insertions(+), 26 deletions(-) diff --git a/api/onnx_web/image.py b/api/onnx_web/image.py index de0d3f43..5fec85e9 100644 --- a/api/onnx_web/image.py +++ b/api/onnx_web/image.py @@ -14,7 +14,7 @@ def get_pixel_index(x: int, y: int, width: int) -> int: return (y * width) + x -def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white') -> Image: +def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: width, height = dims noise = Image.new('RGB', (width, height), fill) @@ -23,7 +23,7 @@ def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white' return noise -def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3) -> Image: +def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: ''' Gaussian blur with multiply, source image centered on white canvas. ''' @@ -36,7 +36,7 @@ def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, return noise -def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3) -> Image: +def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: ''' Gaussian blur, source image centered on white canvas. ''' @@ -49,7 +49,7 @@ def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, r return noise -def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white') -> Image: +def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: ''' Identity transform, source image centered on white canvas. ''' @@ -61,7 +61,7 @@ def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill return noise -def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white') -> Image: +def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: ''' Fill the whole canvas, no source or noise. ''' @@ -72,7 +72,7 @@ def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill return noise -def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3) -> Image: +def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: ''' Gaussian blur, source image centered on white canvas. ''' @@ -85,7 +85,7 @@ def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, round return noise -def noise_source_uniform(source_image: Image, dims: Point, origin: Point) -> Image: +def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw) -> Image: width, height = dims size = width * height @@ -107,7 +107,7 @@ def noise_source_uniform(source_image: Image, dims: Point, origin: Point) -> Ima return noise -def noise_source_normal(source_image: Image, dims: Point, origin: Point) -> Image: +def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -> Image: width, height = dims size = width * height @@ -129,7 +129,7 @@ def noise_source_normal(source_image: Image, dims: Point, origin: Point) -> Imag return noise -def noise_source_histogram(source_image: Image, dims: Point, origin: Point) -> Image: +def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw) -> Image: r, g, b = source_image.split() width, height = dims size = width * height @@ -177,8 +177,8 @@ def expand_image( full_source = Image.new('RGB', dims, fill) full_source.paste(source_image, origin) - full_mask = mask_filter(mask_image, dims, origin) - full_noise = noise_source(source_image, dims, origin) + full_mask = mask_filter(mask_image, dims, origin, fill=fill) + full_noise = noise_source(source_image, dims, origin, fill=fill) full_noise = ImageChops.multiply(full_noise, full_mask) full_source = Image.composite(full_noise, full_source, full_mask.convert('L')) diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/pipeline.py index 02d1f0a6..481758f8 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/pipeline.py @@ -150,6 +150,7 @@ def run_inpaint_pipeline( noise_source: Any, mask_filter: Any, strength: float, + fill_color: str, ): pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, params.model, params.provider, params.scheduler) @@ -162,6 +163,7 @@ def run_inpaint_pipeline( source_image, mask_image, expand, + fill=fill_color, noise_source=noise_source, mask_filter=mask_filter) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 7b4ba1a4..e63f7282 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -390,6 +390,7 @@ def inpaint(): expand = border_from_request() upscale = upscale_from_request() + fill_color = request.args.get('fillColor', 'white') mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none') noise_source = get_from_map( request.args, 'noise', noise_sources, 'histogram') @@ -411,6 +412,7 @@ def inpaint(): mask_filter.__name__, noise_source.__name__, strength, + fill_color, ) ) print("inpaint output: %s" % output) @@ -430,7 +432,8 @@ def inpaint(): expand, noise_source, mask_filter, - strength) + strength, + fill_color) return jsonify({ 'output': output, diff --git a/gui/src/client.ts b/gui/src/client.ts index e7c68e1b..596881c0 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -48,6 +48,7 @@ export interface InpaintParams extends BaseImgParams { filter: string; noise: string; strength: number; + fillColor: string; } export interface OutpaintPixels { diff --git a/gui/src/components/input/ImageInput.tsx b/gui/src/components/input/ImageInput.tsx index aee071b8..572f1289 100644 --- a/gui/src/components/input/ImageInput.tsx +++ b/gui/src/components/input/ImageInput.tsx @@ -44,7 +44,6 @@ export function ImageInput(props: ImageInputProps) { const { files } = event.target; if (doesExist(files) && files.length > 0) { const file = mustExist(files[0]); - props.onChange(file); } }} diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index c06c7ba4..752f9531 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -53,6 +53,7 @@ export function Inpaint() { } const state = mustExist(useContext(StateContext)); + const fillColor = useStore(state, (s) => s.inpaint.fillColor); const filter = useStore(state, (s) => s.inpaint.filter); const noise = useStore(state, (s) => s.inpaint.noise); const mask = useStore(state, (s) => s.inpaint.mask); @@ -109,19 +110,6 @@ export function Inpaint() { }} /> - { - setInpaint({ - strength: value, - }); - }} - /> - {/* TODO: numeric input for blend strength */} + { + setInpaint({ + strength: value, + }); + }} + /> + + { + setInpaint({ + fillColor: event.target.value, + }); + }} /> + + diff --git a/gui/src/state.ts b/gui/src/state.ts index 663402a9..46c89ac7 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -153,6 +153,7 @@ export function createStateSlices(base: ServerParams) { const createInpaintSlice: StateCreator = (set) => ({ inpaint: { ...defaults, + fillColor: '', filter: 'none', mask: null, noise: 'histogram', @@ -171,6 +172,7 @@ export function createStateSlices(base: ServerParams) { set({ inpaint: { ...defaults, + fillColor: '', filter: 'none', mask: null, noise: 'histogram',