1
0
Fork 0

feat(api): add img2img and inpaint chain stages

This commit is contained in:
Sean Sube 2023-01-28 12:42:02 -06:00
parent 4188b019a1
commit dcbd059082
13 changed files with 209 additions and 48 deletions

View File

@ -4,15 +4,24 @@ from .base import (
StageCallback, StageCallback,
StageParams, StageParams,
) )
from .blend_img2img import (
blend_img2img,
)
from .blend_inpaint import (
blend_inpaint,
)
from .correct_gfpgan import ( from .correct_gfpgan import (
correct_gfpgan, correct_gfpgan,
) )
from .generate_txt2img import (
generate_txt2img,
)
from .persist_disk import ( from .persist_disk import (
persist_disk, persist_disk,
) )
from .persist_s3 import (
persist_s3,
)
from .source_txt2img import (
source_txt2img,
)
from .upscale_outpaint import ( from .upscale_outpaint import (
upscale_outpaint, upscale_outpaint,
) )

View File

@ -7,6 +7,7 @@ from ..params import (
StageParams, StageParams,
) )
from ..utils import ( from ..utils import (
is_debug,
ServerContext, ServerContext,
) )
from .utils import ( from .utils import (
@ -68,10 +69,13 @@ class ChainPipeline:
print('source image larger than tile size of %s, tiling stage' % ( print('source image larger than tile size of %s, tiling stage' % (
stage_params.tile_size)) stage_params.tile_size))
def stage_tile(tile: Image.Image) -> Image.Image: def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(ctx, stage_params, params, tile, tile = stage_pipe(ctx, stage_params, params, tile,
**kwargs) **kwargs)
tile.save(path.join(ctx.output_path, 'last-tile.png'))
if is_debug():
tile.save(path.join(ctx.output_path, 'last-tile.png'))
return tile return tile
image = process_tiles( image = process_tiles(

View File

@ -0,0 +1,48 @@
from diffusers import (
OnnxStableDiffusionImg2ImgPipeline,
)
from PIL import Image
from ..diffusion import (
load_pipeline,
)
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
ServerContext,
)
import numpy as np
def blend_img2img(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
strength: float,
) -> Image.Image:
print('generating image using img2img', params.prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
image=source_image,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
strength=strength,
)
output = result.images[0]
print('final output image size', output.size)
return output

View File

@ -0,0 +1,99 @@
from diffusers import (
OnnxStableDiffusionInpaintPipeline,
)
from PIL import Image
from typing import Callable, Tuple
from ..diffusion import (
get_latents_from_seed,
load_pipeline,
)
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import (
Border,
ImageParams,
Size,
StageParams,
)
from ..utils import (
base_join,
is_debug,
ServerContext,
)
from .utils import (
process_tiles,
)
import numpy as np
def blend_inpaint(
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
*,
expand: Border,
mask_image: Image.Image = None,
fill_color: str = 'white',
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
) -> Image.Image:
print('upscaling image by expanding borders', expand)
if mask_image is None:
# if no mask was provided, keep the full source image
mask_image = Image.new('RGB', source_image.size, 'black')
source_image, mask_image, noise_image, _full_dims = expand_image(
source_image,
mask_image,
expand,
fill=fill_color,
noise_source=noise_source,
mask_filter=mask_filter)
if is_debug():
source_image.save(base_join(ctx.output_path, 'last-source.png'))
mask_image.save(base_join(ctx.output_path, 'last-mask.png'))
noise_image.save(base_join(ctx.output_path, 'last-noise.png'))
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
size = Size(*image.size)
mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug():
image.save(base_join(ctx.output_path, 'tile-source.png'))
mask.save(base_join(ctx.output_path, 'tile-mask.png'))
# TODO: must use inpainting model here
model = '../models/stable-diffusion-onnx-v1-inpainting'
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
model, params.provider, params.scheduler)
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
result = pipe(
params.prompt,
generator=rng,
guidance_scale=params.cfg,
height=size.height,
image=image,
latents=latents,
mask_image=mask,
negative_prompt=params.negative_prompt,
num_inference_steps=params.steps,
width=size.width,
)
return result.images[0]
output = process_tiles(source_image, 512, 1, [outpaint])
print('final output image size', output.size)
return output

View File

@ -57,7 +57,7 @@ def correct_gfpgan(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
upsampler: Optional[RealESRGANer] = None, upsampler: Optional[RealESRGANer] = None,
) -> Image: ) -> Image.Image:
if upscale.correction_model is None: if upscale.correction_model is None:
print('no face model given, skipping') print('no face model given, skipping')
return image return image

View File

@ -13,8 +13,8 @@ from ..utils import (
def persist_disk( def persist_disk(
ctx: ServerContext, ctx: ServerContext,
stage: StageParams, _stage: StageParams,
params: ImageParams, _params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*, *,
output: str, output: str,

View File

@ -2,43 +2,31 @@ from diffusers import (
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
) )
from PIL import Image from PIL import Image
from typing import Callable
from ..diffusion import ( from ..diffusion import (
get_latents_from_seed, get_latents_from_seed,
load_pipeline, load_pipeline,
) )
from ..image import (
expand_image,
mask_filter_none,
noise_source_histogram,
)
from ..params import ( from ..params import (
Border,
ImageParams, ImageParams,
Size, Size,
StageParams, StageParams,
) )
from ..utils import ( from ..utils import (
base_join,
is_debug,
ServerContext, ServerContext,
) )
from .utils import (
process_tiles,
)
import numpy as np import numpy as np
def generate_txt2img( def source_txt2img(
ctx: ServerContext, ctx: ServerContext,
stage: StageParams, stage: StageParams,
params: ImageParams, params: ImageParams,
source_image: Image.Image, source_image: Image.Image,
*, *,
size: Size, size: Size,
) -> Image: ) -> Image.Image:
print('generating image using txt2img', params.prompt) print('generating image using txt2img', params.prompt)
if source_image is not None: if source_image is not None:

View File

@ -75,7 +75,7 @@ def upscale_resrgan(
source_image: Image.Image, source_image: Image.Image,
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
) -> Image: ) -> Image.Image:
print('upscaling image with Real ESRGAN', upscale.scale) print('upscaling image with Real ESRGAN', upscale.scale)
output = np.array(source_image) output = np.array(source_image)

View File

@ -1,13 +1,18 @@
from PIL import Image from PIL import Image
from typing import Callable, List from typing import List, Protocol, Tuple
class TileCallback(Protocol):
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
pass
def process_tiles( def process_tiles(
source: Image.Image, source: Image.Image,
tile: int, tile: int,
scale: int, scale: int,
filters: List[Callable], filters: List[TileCallback],
) -> Image: ) -> Image.Image:
width, height = source.size width, height = source.size
image = Image.new('RGB', (width * scale, height * scale)) image = Image.new('RGB', (width * scale, height * scale))

View File

@ -104,7 +104,7 @@ def run_txt2img_pipeline(
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams upscale: UpscaleParams
): ) -> None:
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
@ -139,9 +139,9 @@ def run_img2img_pipeline(
params: ImageParams, params: ImageParams,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image, source_image: Image.Image,
strength: float, strength: float,
): ) -> None:
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
@ -176,14 +176,14 @@ def run_inpaint_pipeline(
size: Size, size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image, source_image: Image.Image,
mask_image: Image, mask_image: Image.Image,
expand: Border, expand: Border,
noise_source: Any, noise_source: Any,
mask_filter: Any, mask_filter: Any,
strength: float, strength: float,
fill_color: str, fill_color: str,
): ) -> None:
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler) params.model, params.provider, params.scheduler)
@ -241,8 +241,8 @@ def run_upscale_pipeline(
_size: Size, _size: Size,
output: str, output: str,
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image source_image: Image.Image,
): ) -> None:
image = run_upscale_correction( image = run_upscale_correction(
ctx, StageParams(), params, source_image, upscale=upscale) ctx, StageParams(), params, source_image, upscale=upscale)

View File

@ -1,6 +1,5 @@
from numpy import random from numpy import random
from PIL import Image, ImageChops, ImageFilter from PIL import Image, ImageChops, ImageFilter
from typing import Callable, List
import numpy as np import numpy as np
@ -14,7 +13,7 @@ def get_pixel_index(x: int, y: int, width: int) -> int:
return (y * width) + x return (y * width) + x
def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
width, height = dims width, height = dims
noise = Image.new('RGB', (width, height), fill) noise = Image.new('RGB', (width, height), fill)
@ -23,7 +22,7 @@ def mask_filter_none(mask_image: Image, dims: Point, origin: Point, fill='white'
return noise return noise
def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
''' '''
Gaussian blur with multiply, source image centered on white canvas. Gaussian blur with multiply, source image centered on white canvas.
''' '''
@ -36,7 +35,7 @@ def mask_filter_gaussian_multiply(mask_image: Image, dims: Point, origin: Point,
return noise return noise
def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
''' '''
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
''' '''
@ -49,7 +48,7 @@ def mask_filter_gaussian_screen(mask_image: Image, dims: Point, origin: Point, r
return noise return noise
def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
''' '''
Identity transform, source image centered on white canvas. Identity transform, source image centered on white canvas.
''' '''
@ -61,7 +60,7 @@ def noise_source_fill_edge(source_image: Image, dims: Point, origin: Point, fill
return noise return noise
def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill='white', **kw) -> Image: def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image:
''' '''
Fill the whole canvas, no source or noise. Fill the whole canvas, no source or noise.
''' '''
@ -72,7 +71,7 @@ def noise_source_fill_mask(source_image: Image, dims: Point, origin: Point, fill
return noise return noise
def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, rounds=3, **kw) -> Image: def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image:
''' '''
Gaussian blur, source image centered on white canvas. Gaussian blur, source image centered on white canvas.
''' '''
@ -85,7 +84,7 @@ def noise_source_gaussian(source_image: Image, dims: Point, origin: Point, round
return noise return noise
def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw) -> Image: def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -107,7 +106,7 @@ def noise_source_uniform(source_image: Image, dims: Point, origin: Point, **kw)
return noise return noise
def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -> Image: def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -129,7 +128,7 @@ def noise_source_normal(source_image: Image, dims: Point, origin: Point, **kw) -
return noise return noise
def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw) -> Image: def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image:
r, g, b = source_image.split() r, g, b = source_image.split()
width, height = dims width, height = dims
size = width * height size = width * height
@ -161,8 +160,8 @@ def noise_source_histogram(source_image: Image, dims: Point, origin: Point, **kw
# very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232 # very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232
def expand_image( def expand_image(
source_image: Image, source_image: Image.Image,
mask_image: Image, mask_image: Image.Image,
expand: Border, expand: Border,
fill='white', fill='white',
noise_source=noise_source_histogram, noise_source=noise_source_histogram,

View File

@ -24,8 +24,9 @@ from typing import Tuple
from .chain import ( from .chain import (
correct_gfpgan, correct_gfpgan,
generate_txt2img, source_txt2img,
persist_disk, persist_disk,
persist_s3,
upscale_outpaint, upscale_outpaint,
upscale_resrgan, upscale_resrgan,
upscale_stable_diffusion, upscale_stable_diffusion,
@ -546,7 +547,7 @@ def chain():
# parse body as json, list of stages # parse body as json, list of stages
example = ChainPipeline(stages=[ example = ChainPipeline(stages=[
(generate_txt2img, StageParams(), { (source_txt2img, StageParams(), {
'size': size, 'size': size,
}), }),
(upscale_outpaint, StageParams(), { (upscale_outpaint, StageParams(), {
@ -561,6 +562,10 @@ def chain():
(persist_disk, StageParams(tile_size=8192), { (persist_disk, StageParams(tile_size=8192), {
'output': output, 'output': output,
}), }),
(persist_s3, StageParams(tile_size=8192), {
'bucket': 'storage-stable-diffusion',
'output': output,
}),
]) ])
# build and run chain pipeline # build and run chain pipeline

View File

@ -24,6 +24,10 @@ def run_upscale_correction(
*, *,
upscale: UpscaleParams, upscale: UpscaleParams,
) -> Image.Image: ) -> Image.Image:
'''
This is a convenience method for a chain pipeline that will run upscaling and
correction, based on the `upscale` params.
'''
print('running upscale pipeline') print('running upscale pipeline')
chain = ChainPipeline() chain = ChainPipeline()